from typing import Optional
from mlflow.entities.logged_model_parameter import LoggedModelParameter as ModelParam
from mlflow.entities.metric import Metric
from mlflow.entities.model_registry._model_registry_entity import _ModelRegistryEntity
from mlflow.entities.model_registry.model_version_deployment_job_state import (
    ModelVersionDeploymentJobState,
)
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.entities.model_registry.model_version_tag import ModelVersionTag
from mlflow.protos.model_registry_pb2 import ModelVersion as ProtoModelVersion
from mlflow.protos.model_registry_pb2 import ModelVersionTag as ProtoModelVersionTag
class ModelVersion(_ModelRegistryEntity):
    """
    MLflow entity for Model Version.

    Args:
        name: Unique name within Model Registry.
        version: Version of the model.
        creation_timestamp: Creation time in milliseconds since the Unix epoch.
        last_updated_timestamp: Time of the last update in milliseconds since the Unix epoch.
        description: Description of this model version.
        user_id: User ID that created this model version.
        current_stage: Current stage of this model version.
        source: Source path for the model.
        run_id: MLflow run ID that generated this model.
        status: Model Registry status string; defaults to the string form of
            ``ModelVersionStatus.READY``.
        status_message: Descriptive message for error status conditions.
        tags: ``ModelVersionTag`` objects. They are stored as a key -> value dict, so a later tag
            with a duplicate key overrides an earlier one.
        run_link: Link referring to the exact run that generated this model version.
        aliases: Aliases for this model version.
        model_id: ID of the model associated with this version.
        params: Parameters associated with this model version.
        metrics: Metrics associated with this model version.
        deployment_job_state: Deployment job state for this model version.
    """

    def __init__(
        self,
        name: str,
        version: str,
        creation_timestamp: int,
        last_updated_timestamp: Optional[int] = None,
        description: Optional[str] = None,
        user_id: Optional[str] = None,
        current_stage: Optional[str] = None,
        source: Optional[str] = None,
        run_id: Optional[str] = None,
        status: str = ModelVersionStatus.to_string(ModelVersionStatus.READY),
        status_message: Optional[str] = None,
        tags: Optional[list[ModelVersionTag]] = None,
        run_link: Optional[str] = None,
        aliases: Optional[list[str]] = None,
        # TODO: Make model_id a required field
        # (currently optional to minimize breakages during prototype development)
        model_id: Optional[str] = None,
        params: Optional[list[ModelParam]] = None,
        metrics: Optional[list[Metric]] = None,
        deployment_job_state: Optional[ModelVersionDeploymentJobState] = None,
    ):
        super().__init__()
        self._name: str = name
        self._version: str = version
        self._creation_time: int = creation_timestamp
        self._last_updated_timestamp: Optional[int] = last_updated_timestamp
        self._description: Optional[str] = description
        self._user_id: Optional[str] = user_id
        self._current_stage: Optional[str] = current_stage
        self._source: Optional[str] = source
        self._run_id: Optional[str] = run_id
        self._run_link: Optional[str] = run_link
        self._status: str = status
        self._status_message: Optional[str] = status_message
        # Tags are flattened to a plain key -> value mapping; the last tag for a key wins.
        self._tags: dict[str, str] = {tag.key: tag.value for tag in (tags or [])}
        self._aliases: list[str] = aliases or []
        self._model_id: Optional[str] = model_id
        self._params: Optional[list[ModelParam]] = params
        self._metrics: Optional[list[Metric]] = metrics
        self._deployment_job_state: Optional[ModelVersionDeploymentJobState] = deployment_job_state

    @property
    def name(self) -> str:
        """String. Unique name within Model Registry."""
        return self._name

    @name.setter
    def name(self, new_name: str):
        self._name = new_name

    @property
    def version(self) -> str:
        """Version"""
        return self._version

    @property
    def creation_timestamp(self) -> int:
        """Integer. Model version creation timestamp (milliseconds since the Unix epoch)."""
        return self._creation_time

    @property
    def last_updated_timestamp(self) -> Optional[int]:
        """Integer. Timestamp of last update for this model version (milliseconds since the Unix
        epoch).
        """
        return self._last_updated_timestamp

    @last_updated_timestamp.setter
    def last_updated_timestamp(self, updated_timestamp: int):
        self._last_updated_timestamp = updated_timestamp

    @property
    def description(self) -> Optional[str]:
        """String. Description"""
        return self._description

    @description.setter
    def description(self, description: Optional[str]):
        self._description = description

    @property
    def user_id(self) -> Optional[str]:
        """String. User ID that created this model version."""
        return self._user_id

    @property
    def current_stage(self) -> Optional[str]:
        """String. Current stage of this model version."""
        return self._current_stage

    @current_stage.setter
    def current_stage(self, stage: str):
        self._current_stage = stage

    @property
    def source(self) -> Optional[str]:
        """String. Source path for the model."""
        return self._source

    @property
    def run_id(self) -> Optional[str]:
        """String. MLflow run ID that generated this model."""
        return self._run_id

    @property
    def run_link(self) -> Optional[str]:
        """String. MLflow run link referring to the exact run that generated this model version."""
        return self._run_link

    @property
    def status(self) -> str:
        """String. Current Model Registry status for this model."""
        return self._status

    @property
    def status_message(self) -> Optional[str]:
        """String. Descriptive message for error status conditions."""
        return self._status_message

    @property
    def tags(self) -> dict[str, str]:
        """Dictionary of tag key (string) -> tag value for the current model version."""
        return self._tags

    @property
    def aliases(self) -> list[str]:
        """List of aliases (string) for the current model version."""
        return self._aliases

    @aliases.setter
    def aliases(self, aliases: list[str]):
        self._aliases = aliases

    @property
    def model_id(self) -> Optional[str]:
        """String. ID of the model associated with this version."""
        return self._model_id

    @property
    def params(self) -> Optional[list[ModelParam]]:
        """List of parameters associated with this model version."""
        return self._params

    @property
    def metrics(self) -> Optional[list[Metric]]:
        """List of metrics associated with this model version."""
        return self._metrics

    @property
    def deployment_job_state(self) -> Optional[ModelVersionDeploymentJobState]:
        """Deployment job state for the current model version."""
        return self._deployment_job_state

    @classmethod
    def _properties(cls) -> list[str]:
        # aggregate with base class properties since cls.__dict__ does not do it automatically
        return sorted(cls._get_properties_helper())

    def _add_tag(self, tag: ModelVersionTag):
        """Insert or overwrite a single tag in the key -> value tag mapping."""
        self._tags[tag.key] = tag.value

    # proto mappers
    @classmethod
    def from_proto(cls, proto) -> "ModelVersion":
        """Build a ``ModelVersion`` entity from its protobuf representation.

        Args:
            proto: A ``mlflow.protos.model_registry_pb2.ModelVersion`` message.

        Returns:
            The corresponding ``ModelVersion`` entity. Params, metrics and model ID are not read
            from the proto (see TODO below) and are left as ``None``.
        """
        model_version = cls(
            name=proto.name,
            version=proto.version,
            creation_timestamp=proto.creation_timestamp,
            last_updated_timestamp=proto.last_updated_timestamp,
            description=proto.description if proto.HasField("description") else None,
            user_id=proto.user_id,
            current_stage=proto.current_stage,
            source=proto.source,
            run_id=proto.run_id if proto.HasField("run_id") else None,
            status=ModelVersionStatus.to_string(proto.status),
            status_message=(
                proto.status_message if proto.HasField("status_message") else None
            ),
            run_link=proto.run_link,
            # Copy into a plain list so the entity does not hold a live reference to the
            # proto's repeated-field container.
            aliases=list(proto.aliases),
            # NOTE(review): built unconditionally, so an unset field yields a default-valued
            # state object rather than None — confirm whether callers rely on this.
            deployment_job_state=ModelVersionDeploymentJobState.from_proto(
                proto.deployment_job_state
            ),
        )
        for tag in proto.tags:
            model_version._add_tag(ModelVersionTag.from_proto(tag))
        # TODO: Include params, metrics, and model ID in proto
        return model_version

    def to_proto(self):
        """Convert this entity to its protobuf representation.

        Returns:
            A ``mlflow.protos.model_registry_pb2.ModelVersion`` message. Optional fields that are
            ``None`` on the entity are left unset; tags and aliases are always copied.
        """
        model_version = ProtoModelVersion()
        model_version.name = self.name
        model_version.version = str(self.version)
        model_version.creation_timestamp = self.creation_timestamp
        if self.last_updated_timestamp is not None:
            model_version.last_updated_timestamp = self.last_updated_timestamp
        if self.description is not None:
            model_version.description = self.description
        if self.user_id is not None:
            model_version.user_id = self.user_id
        if self.current_stage is not None:
            model_version.current_stage = self.current_stage
        if self.source is not None:
            model_version.source = str(self.source)
        if self.run_id is not None:
            model_version.run_id = str(self.run_id)
        if self.run_link is not None:
            model_version.run_link = str(self.run_link)
        if self.status is not None:
            model_version.status = ModelVersionStatus.from_string(self.status)
        if self.status_message is not None:
            model_version.status_message = self.status_message
        model_version.tags.extend(
            [ProtoModelVersionTag(key=key, value=value) for key, value in self._tags.items()]
        )
        model_version.aliases.extend(self.aliases)
        if self.deployment_job_state is not None:
            # Embedded message fields cannot be assigned directly in the protobuf Python API,
            # so the converted sub-message is copied into the field.
            model_version.deployment_job_state.CopyFrom(self.deployment_job_state.to_proto())
        # TODO: Include params, metrics, and model ID in proto
        return model_version