from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
from google.protobuf.json_format import MessageToDict
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
from mlflow.protos.datasets_pb2 import DatasetRecord as ProtoDatasetRecord
from mlflow.protos.datasets_pb2 import DatasetRecordSource as ProtoDatasetRecordSource
@dataclass
class DatasetRecord(_MlflowObject):
    """Represents a single record in an evaluation dataset.

    A DatasetRecord contains the input data, expected outputs (ground truth),
    and metadata for a single evaluation example. Records are meant to be
    treated as immutable once created and are uniquely identified by their
    dataset_record_id.

    Note that immutability is a convention only: the dataclass is not frozen,
    and ``__post_init__`` fills in ``tags``, ``source_id`` and ``source_type``
    when they were not supplied.
    """

    # Required fields.
    dataset_id: str
    inputs: dict[str, Any]
    dataset_record_id: str
    created_time: int
    last_update_time: int

    # Optional fields.
    expectations: dict[str, Any] | None = None
    tags: dict[str, str] | None = None  # replaced by {} in __post_init__ when None
    source: DatasetRecordSource | None = None
    source_id: str | None = None  # derived from ``source`` when empty
    source_type: str | None = None  # derived from ``source`` when empty
    created_by: str | None = None
    last_updated_by: str | None = None

    def __post_init__(self) -> None:
        """Validate ``inputs`` and derive defaults from ``source``.

        Raises:
            ValueError: If ``inputs`` is ``None``.
        """
        if self.inputs is None:
            raise ValueError("inputs must be provided")

        if self.tags is None:
            self.tags = {}

        if self.source and isinstance(self.source, DatasetRecordSource):
            if not self.source_id:
                # TRACE sources keep their identifier under "trace_id"; every
                # other source type is looked up under "source_id".
                id_key = (
                    "trace_id"
                    if self.source.source_type == DatasetRecordSourceType.TRACE
                    else "source_id"
                )
                self.source_id = self.source.source_data.get(id_key)
            if not self.source_type:
                self.source_type = self.source.source_type.value

    def to_proto(self) -> ProtoDatasetRecord:
        """Convert this record into its protobuf message.

        ``inputs``, ``expectations``, ``tags`` and ``source`` are stored on the
        message as JSON strings. Optional attributes that are ``None`` are left
        unset on the message.

        Returns:
            A populated ``ProtoDatasetRecord``.
        """
        proto = ProtoDatasetRecord()
        proto.dataset_record_id = self.dataset_record_id
        proto.dataset_id = self.dataset_id
        proto.inputs = json.dumps(self.inputs)
        proto.created_time = self.created_time
        proto.last_update_time = self.last_update_time

        if self.expectations is not None:
            proto.expectations = json.dumps(self.expectations)
        if self.tags is not None:
            proto.tags = json.dumps(self.tags)
        if self.source is not None:
            proto.source = json.dumps(self.source.to_dict())
        if self.source_id is not None:
            proto.source_id = self.source_id
        if self.source_type is not None:
            # ``Value`` maps the enum member's name to its numeric proto value.
            proto.source_type = ProtoDatasetRecordSource.SourceType.Value(self.source_type)
        if self.created_by is not None:
            proto.created_by = self.created_by
        if self.last_updated_by is not None:
            proto.last_updated_by = self.last_updated_by

        return proto

    @classmethod
    def from_proto(cls, proto: ProtoDatasetRecord) -> "DatasetRecord":
        """Build a record from its protobuf message.

        JSON-encoded fields are decoded. A missing ``inputs`` field becomes an
        empty dict; every other missing optional field becomes ``None``.

        Args:
            proto: The protobuf message to read from.

        Returns:
            The corresponding ``DatasetRecord``.
        """
        inputs = json.loads(proto.inputs) if proto.HasField("inputs") else {}
        expectations = json.loads(proto.expectations) if proto.HasField("expectations") else None
        tags = json.loads(proto.tags) if proto.HasField("tags") else None

        source = None
        if proto.HasField("source"):
            source = DatasetRecordSource.from_dict(json.loads(proto.source))

        # NOTE(review): ``DatasetRecordSourceType.from_proto`` is defined outside
        # this file; it presumably yields a value usable as ``source_type`` here
        # and by ``to_proto`` -- confirm against its definition.
        source_type = (
            DatasetRecordSourceType.from_proto(proto.source_type)
            if proto.HasField("source_type")
            else None
        )

        return cls(
            dataset_id=proto.dataset_id,
            inputs=inputs,
            dataset_record_id=proto.dataset_record_id,
            created_time=proto.created_time,
            last_update_time=proto.last_update_time,
            expectations=expectations,
            tags=tags,
            source=source,
            source_id=proto.source_id if proto.HasField("source_id") else None,
            source_type=source_type,
            created_by=proto.created_by if proto.HasField("created_by") else None,
            last_updated_by=proto.last_updated_by if proto.HasField("last_updated_by") else None,
        )

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of this record.

        The dict is produced from the protobuf message, after which the
        JSON-string fields are decoded back into Python objects.

        Returns:
            A dictionary keyed by the proto field names.
        """
        d = MessageToDict(
            self.to_proto(),
            preserving_proto_field_name=True,
        )

        d["inputs"] = json.loads(d["inputs"])
        # These fields only appear in the dict when they were set on the proto.
        for key in ("expectations", "tags", "source"):
            if key in d:
                d[key] = json.loads(d[key])

        # Overwrite the timestamps with the attribute values. Presumably needed
        # because MessageToDict renders 64-bit integers as strings -- confirm
        # against the proto schema.
        d["created_time"] = self.created_time
        d["last_update_time"] = self.last_update_time
        return d

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatasetRecord":
        """Build a record from a plain dict.

        Args:
            data: Mapping that must contain ``dataset_id``,
                ``dataset_record_id``, ``inputs``, ``created_time`` and
                ``last_update_time``. All other keys are optional.

        Returns:
            The corresponding ``DatasetRecord``.

        Raises:
            ValueError: If any required key is absent from ``data``.
        """
        # Checked in this order so the first missing key is the one reported.
        required_keys = (
            "dataset_id",
            "dataset_record_id",
            "inputs",
            "created_time",
            "last_update_time",
        )
        for key in required_keys:
            if key not in data:
                raise ValueError(f"{key} is required")

        source = None
        if data.get("source"):
            source = DatasetRecordSource.from_dict(data["source"])

        return cls(
            dataset_id=data["dataset_id"],
            inputs=data["inputs"],
            dataset_record_id=data["dataset_record_id"],
            created_time=data["created_time"],
            last_update_time=data["last_update_time"],
            expectations=data.get("expectations"),
            tags=data.get("tags"),
            source=source,
            source_id=data.get("source_id"),
            source_type=data.get("source_type"),
            created_by=data.get("created_by"),
            last_updated_by=data.get("last_updated_by"),
        )

    def __eq__(self, other: object) -> bool:
        """Compare two records by identity and content fields.

        ``created_time``, ``last_update_time``, ``created_by`` and
        ``last_updated_by`` do not take part in the comparison.
        """
        if not isinstance(other, DatasetRecord):
            # Defer to the other operand's comparison instead of forcing False;
            # ``==`` still evaluates to False when neither side handles it.
            return NotImplemented
        return (
            self.dataset_record_id == other.dataset_record_id
            and self.dataset_id == other.dataset_id
            and self.inputs == other.inputs
            and self.expectations == other.expectations
            and self.tags == other.tags
            and self.source == other.source
            and self.source_id == other.source_id
            and self.source_type == other.source_type
        )