from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
from google.protobuf.json_format import MessageToDict
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
from mlflow.protos.datasets_pb2 import DatasetRecord as ProtoDatasetRecord
from mlflow.protos.datasets_pb2 import DatasetRecordSource as ProtoDatasetRecordSource
@dataclass
class DatasetRecord(_MlflowObject):
    """Represents a single record in an evaluation dataset.

    A DatasetRecord contains the input data, expected outputs (ground truth),
    and metadata for a single evaluation example. Records are meant to be
    treated as immutable once created (note that the dataclass itself is not
    frozen) and are identified by their ``dataset_record_id``.

    Attributes:
        dataset_id: Identifier of the dataset the record is associated with.
        inputs: Input payload; must not be ``None`` and must be JSON-serializable
            for :meth:`to_proto` / :meth:`to_dict` to succeed.
        dataset_record_id: Identifier of this record.
        created_time: Creation timestamp as an integer (units are not
            established in this module).
        last_update_time: Last-update timestamp as an integer (same caveat).
        expectations: Optional JSON-serializable expected outputs.
        tags: Optional tags; ``None`` is normalized to ``{}`` on construction.
        source: Optional :class:`DatasetRecordSource` describing the origin.
        source_id: Optional source identifier; derived from ``source`` when
            not given explicitly.
        source_type: Optional source type name; derived from ``source`` when
            not given explicitly.
        created_by: Optional creator identifier.
        last_updated_by: Optional identifier of the last updater.
    """

    dataset_id: str
    inputs: dict[str, Any]
    dataset_record_id: str
    created_time: int
    last_update_time: int
    expectations: dict[str, Any] | None = None
    tags: dict[str, str] | None = None
    source: DatasetRecordSource | None = None
    source_id: str | None = None
    source_type: str | None = None
    created_by: str | None = None
    last_updated_by: str | None = None

    def __post_init__(self):
        """Validate ``inputs`` and fill in defaults derived from ``source``.

        Raises:
            ValueError: If ``inputs`` is ``None``.
        """
        if self.inputs is None:
            raise ValueError("inputs must be provided")
        if self.tags is None:
            self.tags = {}
        if self.source and isinstance(self.source, DatasetRecordSource):
            # Explicitly supplied source_id / source_type always win; only
            # fill them in from the source object when they are missing.
            if not self.source_id:
                # Trace sources store their identifier under "trace_id";
                # every other source type is looked up under "source_id".
                if self.source.source_type == DatasetRecordSourceType.TRACE:
                    self.source_id = self.source.source_data.get("trace_id")
                else:
                    self.source_id = self.source.source_data.get("source_id")
            if not self.source_type:
                self.source_type = self.source.source_type.value

    def to_proto(self) -> ProtoDatasetRecord:
        """Convert this record into its protobuf representation.

        Structured fields (``inputs``, ``expectations``, ``tags``, ``source``)
        are stored on the proto as JSON strings. Optional fields that are
        ``None`` are left unset.

        Returns:
            The populated ``ProtoDatasetRecord`` message.
        """
        proto = ProtoDatasetRecord()
        proto.dataset_record_id = self.dataset_record_id
        proto.dataset_id = self.dataset_id
        proto.inputs = json.dumps(self.inputs)
        proto.created_time = self.created_time
        proto.last_update_time = self.last_update_time
        if self.expectations is not None:
            proto.expectations = json.dumps(self.expectations)
        if self.tags is not None:
            proto.tags = json.dumps(self.tags)
        if self.source is not None:
            proto.source = json.dumps(self.source.to_dict())
        if self.source_id is not None:
            proto.source_id = self.source_id
        if self.source_type is not None:
            # Map the source type name onto the proto enum's numeric value.
            proto.source_type = ProtoDatasetRecordSource.SourceType.Value(self.source_type)
        if self.created_by is not None:
            proto.created_by = self.created_by
        if self.last_updated_by is not None:
            proto.last_updated_by = self.last_updated_by
        return proto

    @classmethod
    def from_proto(cls, proto: ProtoDatasetRecord) -> "DatasetRecord":
        """Build a record from its protobuf representation.

        Args:
            proto: The ``ProtoDatasetRecord`` message to convert.

        Returns:
            A new ``DatasetRecord``. An unset ``inputs`` field becomes ``{}``;
            other unset optional fields become ``None``.
        """
        inputs = json.loads(proto.inputs) if proto.HasField("inputs") else {}
        expectations = json.loads(proto.expectations) if proto.HasField("expectations") else None
        tags = json.loads(proto.tags) if proto.HasField("tags") else None

        source = None
        if proto.HasField("source"):
            source = DatasetRecordSource.from_dict(json.loads(proto.source))

        source_type = None
        if proto.HasField("source_type"):
            source_type = DatasetRecordSourceType.from_proto(proto.source_type)

        return cls(
            dataset_id=proto.dataset_id,
            inputs=inputs,
            dataset_record_id=proto.dataset_record_id,
            created_time=proto.created_time,
            last_update_time=proto.last_update_time,
            expectations=expectations,
            tags=tags,
            source=source,
            source_id=proto.source_id if proto.HasField("source_id") else None,
            source_type=source_type,
            created_by=proto.created_by if proto.HasField("created_by") else None,
            last_updated_by=proto.last_updated_by if proto.HasField("last_updated_by") else None,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert this record into a plain dictionary.

        The dictionary is produced from the protobuf form (keeping the proto
        field names), after which the JSON-string fields are decoded back into
        Python objects.

        Returns:
            A dictionary representation of the record.
        """
        d = MessageToDict(
            self.to_proto(),
            preserving_proto_field_name=True,
        )
        d["inputs"] = json.loads(d["inputs"])
        # Optional JSON-string fields are only present when set on the proto.
        for json_field in ("expectations", "tags", "source"):
            if json_field in d:
                d[json_field] = json.loads(d[json_field])
        # Overwrite the timestamps with the original Python ints.
        # NOTE(review): presumably needed because MessageToDict renders 64-bit
        # integer fields as strings -- confirm against the proto definition.
        d["created_time"] = self.created_time
        d["last_update_time"] = self.last_update_time
        return d

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatasetRecord":
        """Build a record from a dictionary such as the one from :meth:`to_dict`.

        Args:
            data: Mapping that must contain ``dataset_id``,
                ``dataset_record_id``, ``inputs``, ``created_time`` and
                ``last_update_time``; all other keys are optional.

        Returns:
            A new ``DatasetRecord``.

        Raises:
            ValueError: If any required key is missing from ``data``.
        """
        # Validate required fields (checked in this order so the first
        # missing key determines the error message).
        required_fields = (
            "dataset_id",
            "dataset_record_id",
            "inputs",
            "created_time",
            "last_update_time",
        )
        for field_name in required_fields:
            if field_name not in data:
                raise ValueError(f"{field_name} is required")

        source = None
        if data.get("source"):
            source = DatasetRecordSource.from_dict(data["source"])

        return cls(
            dataset_id=data["dataset_id"],
            inputs=data["inputs"],
            dataset_record_id=data["dataset_record_id"],
            created_time=data["created_time"],
            last_update_time=data["last_update_time"],
            expectations=data.get("expectations"),
            tags=data.get("tags"),
            source=source,
            source_id=data.get("source_id"),
            source_type=data.get("source_type"),
            created_by=data.get("created_by"),
            last_updated_by=data.get("last_updated_by"),
        )

    def __eq__(self, other: object) -> bool:
        """Compare two records by identity and content fields.

        Timestamps (``created_time``, ``last_update_time``) and the
        ``created_by`` / ``last_updated_by`` fields are deliberately not part
        of the comparison.
        """
        if not isinstance(other, DatasetRecord):
            return False
        return (
            self.dataset_record_id == other.dataset_record_id
            and self.dataset_id == other.dataset_id
            and self.inputs == other.inputs
            and self.expectations == other.expectations
            and self.tags == other.tags
            and self.source == other.source
            and self.source_id == other.source_id
            and self.source_type == other.source_type
        )