from typing import Optional
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.protos.service_pb2 import Metric as ProtoMetric
from mlflow.protos.service_pb2 import MetricWithRunId as ProtoMetricWithRunId
class Metric(_MlflowObject):
    """
    Metric object.

    Holds a single metric measurement (key, value, timestamp, step) together with the
    optional model, dataset and run identifiers it is associated with.
    """

    def __init__(
        self,
        key,
        value,
        timestamp,
        step,
        model_id: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_digest: Optional[str] = None,
        run_id: Optional[str] = None,
    ):
        """
        Args:
            key: String key corresponding to the metric name.
            value: Float value of the metric.
            timestamp: Metric timestamp as an integer (milliseconds since the Unix epoch).
            step: Integer metric step (x-coordinate).
            model_id: ID of the Model associated with the metric.
            dataset_name: Name of the dataset associated with the metric.
            dataset_digest: Digest of the dataset associated with the metric.
            run_id: Run ID associated with the metric.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` if exactly one of
                ``dataset_name`` and ``dataset_digest`` is ``None``.
        """
        # The dataset name and digest are only accepted as a pair: exactly one ``None``
        # among the two means one was supplied without the other.
        if (dataset_name, dataset_digest).count(None) == 1:
            raise MlflowException(
                "Both dataset_name and dataset_digest must be provided if one is provided",
                INVALID_PARAMETER_VALUE,
            )
        self._key = key
        self._value = value
        self._timestamp = timestamp
        self._step = step
        self._model_id = model_id
        self._dataset_name = dataset_name
        self._dataset_digest = dataset_digest
        self._run_id = run_id

    @property
    def key(self):
        """String key corresponding to the metric name."""
        return self._key

    @property
    def value(self):
        """Float value of the metric."""
        return self._value

    @property
    def timestamp(self):
        """Metric timestamp as an integer (milliseconds since the Unix epoch)."""
        return self._timestamp

    @property
    def step(self):
        """Integer metric step (x-coordinate)."""
        return self._step

    @property
    def model_id(self):
        """ID of the Model associated with the metric."""
        return self._model_id

    @property
    def dataset_name(self) -> Optional[str]:
        """String. Name of the dataset associated with the metric."""
        return self._dataset_name

    @property
    def dataset_digest(self) -> Optional[str]:
        """String. Digest of the dataset associated with the metric."""
        return self._dataset_digest

    @property
    def run_id(self) -> Optional[str]:
        """String. Run ID associated with the metric."""
        return self._run_id

    def to_proto(self):
        """
        Convert the Metric object to its protobuf message.

        Returns:
            A ``ProtoMetric`` with key, value, timestamp and step always set. The optional
            fields (model_id, dataset_name, dataset_digest, run_id) are set only when
            their value is truthy.
        """
        metric = ProtoMetric()
        metric.key = self.key
        metric.value = self.value
        metric.timestamp = self.timestamp
        metric.step = self.step
        # Optional fields are left untouched on the message when they are None or empty.
        if self.model_id:
            metric.model_id = self.model_id
        if self.dataset_name:
            metric.dataset_name = self.dataset_name
        if self.dataset_digest:
            metric.dataset_digest = self.dataset_digest
        if self.run_id:
            metric.run_id = self.run_id
        return metric

    @classmethod
    def from_proto(cls, proto):
        """
        Create a Metric object from a protobuf message.

        Args:
            proto: Message exposing ``key``, ``value``, ``timestamp``, ``step``,
                ``model_id``, ``dataset_name``, ``dataset_digest`` and ``run_id``.

        Returns:
            An instance of ``cls`` built from the message fields.
        """
        # ``or None`` normalizes falsy optional values (e.g. empty strings) to None,
        # mirroring to_proto, which omits falsy optional fields.
        return cls(
            proto.key,
            proto.value,
            proto.timestamp,
            proto.step,
            model_id=proto.model_id or None,
            dataset_name=proto.dataset_name or None,
            dataset_digest=proto.dataset_digest or None,
            run_id=proto.run_id or None,
        )

    def __eq__(self, __o):
        """
        Return True when ``__o`` is an instance of this object's class and both objects
        hold identical instance attributes; return False otherwise.
        """
        if isinstance(__o, self.__class__):
            return self.__dict__ == __o.__dict__
        return False

    def __hash__(self):
        """Hash the tuple of all eight metric fields."""
        return hash(
            (
                self._key,
                self._value,
                self._timestamp,
                self._step,
                self._model_id,
                self._dataset_name,
                self._dataset_digest,
                self._run_id,
            )
        )

    def to_dictionary(self):
        """
        Convert the Metric object to a dictionary.

        Returns:
            dict: The Metric object represented as a dictionary.
        """
        return {
            "key": self.key,
            "value": self.value,
            "timestamp": self.timestamp,
            "step": self.step,
            "model_id": self.model_id,
            "dataset_name": self.dataset_name,
            "dataset_digest": self.dataset_digest,
            "run_id": self.run_id,
        }

    @classmethod
    def from_dictionary(cls, metric_dict):
        """
        Create a Metric object from a dictionary.

        Every entry of ``metric_dict`` is forwarded to the constructor as a keyword
        argument, so its keys must match the constructor's parameter names.

        Args:
            metric_dict (dict): Dictionary containing metric information. The keys
                "key", "value", "timestamp" and "step" are required.

        Returns:
            Metric: The Metric object created from the dictionary.

        Raises:
            MlflowException: With ``INVALID_PARAMETER_VALUE`` if any required key is
                missing from ``metric_dict``.
        """
        required_keys = ["key", "value", "timestamp", "step"]
        missing_keys = [key for key in required_keys if key not in metric_dict]
        if missing_keys:
            raise MlflowException(
                f"Missing required keys {missing_keys} in metric dictionary",
                INVALID_PARAMETER_VALUE,
            )
        return cls(**metric_dict)
class MetricWithRunId(Metric):
    """
    Metric tagged with the ID of a run.

    Only the key, value, timestamp and step of the wrapped metric are copied; the run ID
    is the one passed to the constructor.
    """

    def __init__(self, metric: Metric, run_id):
        """
        Args:
            metric: Metric whose key, value, timestamp and step are copied.
            run_id: Run ID to attach to the copied metric.
        """
        # The parent constructor stores run_id in ``_run_id``.
        super().__init__(
            metric.key,
            metric.value,
            metric.timestamp,
            metric.step,
            run_id=run_id,
        )

    @property
    def run_id(self):
        """Run ID attached to this metric."""
        return self._run_id

    def to_dict(self):
        """
        Return the key, value, timestamp, step and run_id of this metric as a dictionary.
        """
        field_names = ("key", "value", "timestamp", "step", "run_id")
        return {name: getattr(self, name) for name in field_names}

    def to_proto(self):
        """
        Convert this object to a ``ProtoMetricWithRunId`` message with key, value,
        timestamp, step and run_id set.
        """
        proto = ProtoMetricWithRunId()
        # Assign field by field, in declaration order, so invalid values are rejected
        # by the message's own attribute checks.
        for name in ("key", "value", "timestamp", "step", "run_id"):
            setattr(proto, name, getattr(self, name))
        return proto