Source code for mlflow.genai.datasets.evaluation_dataset

from typing import TYPE_CHECKING, Any

from mlflow.data import Dataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin
from mlflow.entities.evaluation_dataset import EvaluationDataset as _EntityEvaluationDataset
from mlflow.genai.datasets.databricks_evaluation_dataset_source import (
    DatabricksEvaluationDatasetSource,
)

if TYPE_CHECKING:
    import pandas as pd
    import pyspark.sql


class EvaluationDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """
    The public API for evaluation datasets in MLflow's GenAI module.

    This class provides a unified interface for evaluation datasets, supporting both:

    - Standard MLflow evaluation datasets (backed by MLflow's tracking store)
    - Databricks managed datasets (backed by Unity Catalog tables) through the
      databricks-agents library
    """

    def __init__(self, dataset):
        """
        Initialize the wrapper with either a managed dataset or an MLflow dataset.

        Args:
            dataset: Either a Databricks managed dataset
                (databricks.agents.datasets.Dataset) or an MLflow EvaluationDataset entity
                (mlflow.entities.evaluation_dataset.EvaluationDataset). The type is
                determined at runtime.
        """
        if isinstance(dataset, _EntityEvaluationDataset):
            self._databricks_dataset = None
            self._mlflow_dataset = dataset
        else:
            self._databricks_dataset = dataset
            self._mlflow_dataset = None
        self._df = None

    def __eq__(self, other):
        """Check equality with another dataset."""
        if isinstance(other, _EntityEvaluationDataset) and self._mlflow_dataset:
            return self._mlflow_dataset == other
        if isinstance(other, EvaluationDataset):
            if self._mlflow_dataset and other._mlflow_dataset:
                return self._mlflow_dataset == other._mlflow_dataset
            if self._databricks_dataset and other._databricks_dataset:
                return self._databricks_dataset == other._databricks_dataset
        return False

    def __setattr__(self, name, value):
        """Allow setting internal attributes on the wrapped dataset."""
        object.__setattr__(self, name, value)
        if name == "_records" and hasattr(self, "_mlflow_dataset") and self._mlflow_dataset:
            self._mlflow_dataset._records = value

    def __getattr__(self, name):
        """
        Dynamic attribute delegation for simple pass-through properties.

        This handles attributes that don't require special logic and can be directly
        delegated to the underlying dataset implementation.
        """
        if name.startswith("_"):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        if self._mlflow_dataset and hasattr(self._mlflow_dataset, name):
            return getattr(self._mlflow_dataset, name)
        elif self._databricks_dataset and hasattr(self._databricks_dataset, name):
            return getattr(self._databricks_dataset, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    @property
    def digest(self) -> str | None:
        """String digest (hash) of the dataset provided by the caller that uniquely
        identifies the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.digest
        if self._databricks_dataset.digest is None:
            from mlflow.data.digest_utils import compute_pandas_digest

            return compute_pandas_digest(self.to_df())
        return self._databricks_dataset.digest

    @property
    def name(self) -> str:
        """The name of the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.name
        return self._databricks_dataset.name if self._databricks_dataset else None

    @property
    def dataset_id(self) -> str:
        """The unique identifier of the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.dataset_id
        return self._databricks_dataset.dataset_id if self._databricks_dataset else None

    @property
    def source(self):
        """Source information for the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.source
        return DatabricksEvaluationDatasetSource(
            table_name=self.name, dataset_id=self.dataset_id
        )

    @property
    def source_type(self) -> str | None:
        """The type of the dataset source."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.source._get_source_type()
        return self._databricks_dataset.source_type

    @property
    def created_time(self) -> int | str | None:
        """The time the dataset was created."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.created_time
        return self._databricks_dataset.create_time

    @property
    def create_time(self) -> int | str | None:
        """Alias for created_time (for backward compatibility with managed datasets)."""
        return self.created_time

    @property
    def tags(self) -> dict[str, Any] | None:
        """The tags for the dataset (MLflow only)."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.tags
        raise NotImplementedError(
            "Tags are not available for Databricks managed datasets. "
            "Tags are managed through Unity Catalog. Use Unity Catalog APIs to manage "
            "dataset tags."
        )

    @property
    def experiment_ids(self) -> list[str]:
        """The experiment IDs associated with the dataset (MLflow only)."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.experiment_ids
        raise NotImplementedError(
            "Experiment associations are not available for Databricks managed datasets. "
            "Dataset associations are managed through Unity Catalog."
        )

    @property
    def records(self):
        """The records in the dataset (MLflow only)."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.records
        raise NotImplementedError(
            "Records access is not supported for Databricks managed datasets"
        )

    @property
    def schema(self) -> str | None:
        """The schema of the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.schema
        return self._databricks_dataset.schema if self._databricks_dataset else None

    @property
    def profile(self) -> str | None:
        """The profile of the dataset."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.profile
        return self._databricks_dataset.profile if self._databricks_dataset else None
    def set_profile(self, profile: str) -> "EvaluationDataset":
        """Set the profile of the dataset."""
        if self._mlflow_dataset:
            self._mlflow_dataset._profile = profile
            return self
        dataset = self._databricks_dataset.set_profile(profile)
        return EvaluationDataset(dataset)

    def merge_records(
        self,
        records: "list[dict[str, Any]] | pd.DataFrame | pyspark.sql.DataFrame",
    ) -> "EvaluationDataset":
        """Merge records into the dataset."""
        if self._mlflow_dataset:
            self._mlflow_dataset.merge_records(records)
            return self
        dataset = self._databricks_dataset.merge_records(records)
        return EvaluationDataset(dataset)

    def to_df(self) -> "pd.DataFrame":
        """Convert the dataset to a pandas DataFrame."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.to_df()
        if self._df is None:
            self._df = self._databricks_dataset.to_df()
        return self._df

    def has_records(self) -> bool:
        """Check if dataset records are loaded without triggering a load."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.has_records()
        return self._df is not None
    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.to_dict()
        raise NotImplementedError(
            "Serialization to dict is not supported for Databricks managed datasets. "
            "Databricks datasets are persisted in Unity Catalog tables and don't "
            "require serialization."
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "EvaluationDataset":
        """
        Create instance from dictionary representation.

        Note: This creates an MLflow dataset from serialized data. Databricks managed
        datasets are loaded directly from Unity Catalog, not from dict.
        """
        mlflow_dataset = _EntityEvaluationDataset.from_dict(data)
        return cls(mlflow_dataset)

    def to_proto(self):
        """Convert to protobuf representation."""
        if self._mlflow_dataset:
            return self._mlflow_dataset.to_proto()
        raise NotImplementedError(
            "Protobuf serialization is not supported for Databricks managed datasets. "
            "Databricks datasets are persisted in Unity Catalog tables and don't "
            "require serialization."
        )

    @classmethod
    def from_proto(cls, proto):
        """
        Create instance from protobuf representation.

        Note: This creates an MLflow dataset from serialized protobuf data. Databricks
        managed datasets are loaded directly from Unity Catalog, not from protobuf.
        """
        mlflow_dataset = _EntityEvaluationDataset.from_proto(proto)
        return cls(mlflow_dataset)
    def _to_pyfunc_dataset(self):
        """Support for PyFuncConvertibleDatasetMixin."""
        return self.to_evaluation_dataset()

    def to_evaluation_dataset(self, path=None, feature_names=None):
        """
        Converts the dataset to the legacy EvaluationDataset for model evaluation.

        Required for use with mlflow.evaluate().
        """
        from mlflow.data.evaluation_dataset import EvaluationDataset as LegacyEvaluationDataset

        return LegacyEvaluationDataset(
            data=self.to_df(),
            path=path,
            feature_names=feature_names,
            name=self.name,
            digest=self.digest,
        )

    def _to_mlflow_entity(self):
        """Convert to MLflow Dataset entity for logging."""
        from mlflow.entities import Dataset as DatasetEntity

        return DatasetEntity(
            name=self.name,
            digest=self.digest,
            source_type=self.source_type,
            source=self.source.to_json(),
            schema=self.schema,
            profile=self.profile,
        )
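
The snippet below is a minimal usage sketch, not part of the module source: it exercises the wrapper's delegation and caching behavior using a hypothetical stand-in object in place of a real Databricks managed dataset. The FakeManagedDataset class and its attribute values are assumptions for illustration only; real callers obtain datasets through the mlflow.genai.datasets APIs or the databricks-agents library.

import pandas as pd

from mlflow.genai.datasets.evaluation_dataset import EvaluationDataset


class FakeManagedDataset:
    """Hypothetical stand-in exposing the attributes the wrapper delegates to."""

    name = "catalog.schema.eval_table"
    dataset_id = "dataset-123"
    digest = "abc123"
    source_type = "databricks-uc-table"
    schema = None
    profile = None
    create_time = "2024-01-01T00:00:00Z"

    def to_df(self):
        return pd.DataFrame([{"inputs": {"question": "What is MLflow?"}}])


dataset = EvaluationDataset(FakeManagedDataset())

print(dataset.name)           # delegated to the wrapped dataset's name
print(dataset.digest)         # a non-None digest is returned as-is
df = dataset.to_df()          # first call loads and caches the DataFrame in _df
print(dataset.has_records())  # True, since _df is now populated

# Convert to the legacy EvaluationDataset for use with mlflow.evaluate()
legacy = dataset.to_evaluation_dataset()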