"""
Databricks Agent Datasets Python SDK. For more details, see Databricks Agent Evaluation:
<https://docs.databricks.com/en/generative-ai/agent-evaluation/index.html>
The API docs can be found here:
<https://api-docs.databricks.com/python/databricks-agents/latest/databricks_agent_eval.html#datasets>
"""
import logging
import time
from typing import Any
from mlflow.genai.datasets.evaluation_dataset import EvaluationDataset
from mlflow.store.tracking import SEARCH_EVALUATION_DATASETS_MAX_RESULTS
from mlflow.telemetry.events import CreateDatasetEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.tracking import get_tracking_uri
from mlflow.utils.annotations import deprecated_parameter, experimental
from mlflow.utils.databricks_utils import is_databricks_default_tracking_uri
_logger = logging.getLogger(__name__)
_ERROR_MSG = (
"The `databricks-agents` package is required to use `mlflow.genai.datasets`. "
"Please install it with `pip install databricks-agents`."
)
def _validate_databricks_params(
name: str | None,
dataset_id: str | None = None,
) -> None:
"""
Validate parameters for Databricks environment.
Args:
name: The dataset name parameter (required)
dataset_id: The dataset ID parameter (should not be provided)
Raises:
ValueError: If name is missing or dataset_id is provided
"""
if name is None:
raise ValueError("Parameter 'name' is required.")
if dataset_id is not None:
raise ValueError(
"Parameter 'dataset_id' is only supported outside of Databricks environments. "
"Use 'name' parameter instead."
)
def _validate_non_databricks_params(
name: str | None,
dataset_id: str | None = None,
) -> None:
"""
Validate parameters for non-Databricks environment.
Args:
name: The dataset name parameter (should not be provided)
dataset_id: The dataset ID parameter (required)
Raises:
ValueError: If dataset_id is missing or name is provided
"""
if name is not None:
raise ValueError(
"Parameter 'name' is only supported in Databricks environments. "
"Use 'dataset_id' parameter instead."
)
if dataset_id is None:
raise ValueError(
"Parameter 'dataset_id' is required. "
"Use search_datasets() to find the dataset ID by name if needed."
)
@deprecated_parameter("uc_table_name", "name")
@record_usage_event(CreateDatasetEvent)
def create_dataset(
name: str | None = None,
experiment_id: str | list[str] | None = None,
tags: dict[str, Any] | None = None,
) -> "EvaluationDataset":
"""
Create a dataset with the given name and associate it with the given experiment.
Args:
name: The name of the dataset. In Databricks, this is the UC table name.
experiment_id: The ID of the experiment(s) to associate the dataset with. If not provided,
the current experiment is inferred from the environment.
tags: Dictionary of tags to apply to the dataset. Not supported in Databricks.
Returns:
An EvaluationDataset object representing the created dataset.
Examples:
.. code-block:: python
from mlflow.genai.datasets import create_dataset
# Create a dataset with a single experiment
dataset = create_dataset(
name="customer_support_qa_v1",
experiment_id="0", # Default experiment
tags={
"version": "1.0",
"purpose": "regression_testing",
"model": "gpt-4",
"team": "ml-platform",
},
)
print(f"Created dataset: {dataset.dataset_id}")
# Output: Created dataset: d-1a2b3c4d5e6f7890abcdef1234567890
# Create a dataset linked to multiple experiments
multi_exp_dataset = create_dataset(
name="cross_team_eval_dataset",
experiment_id=["1", "2", "5"], # Multiple experiment IDs
tags={
"coverage": "comprehensive",
"status": "development",
},
)
# Create a dataset without tags (minimal example)
simple_dataset = create_dataset(
name="quick_test_dataset",
experiment_id="3", # Specific experiment
)
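# Minimal sketch of experiment inference (assumes an active experiment has been
# set, e.g. via mlflow.set_experiment; the experiment name below is illustrative):
# when experiment_id is omitted, the dataset is associated with the current experiment
import mlflow
mlflow.set_experiment("genai_evaluation")
inferred_dataset = create_dataset(name="active_experiment_dataset")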
"""
if name is None:
raise ValueError("Parameter 'name' is required.")
experiment_ids = [experiment_id] if isinstance(experiment_id, str) else experiment_id
if is_databricks_default_tracking_uri(get_tracking_uri()):
if tags is not None:
raise NotImplementedError(
"Tags are not supported in Databricks environments. "
"Tags are managed through Unity Catalog."
)
try:
from databricks.agents.datasets import create_dataset as db_create
return EvaluationDataset(db_create(name, experiment_ids))
except ImportError as e:
raise ImportError(_ERROR_MSG) from e
else:
from mlflow.tracking.client import MlflowClient
if experiment_ids is None:
from mlflow.tracking.fluent import _get_experiment_id
current_exp_id = _get_experiment_id()
if current_exp_id:
experiment_ids = [current_exp_id]
mlflow_dataset = MlflowClient().create_dataset(
name=name,
experiment_id=experiment_ids,
tags=tags,
)
return EvaluationDataset(mlflow_dataset)
@deprecated_parameter("uc_table_name", "name")
def delete_dataset(
name: str | None = None,
dataset_id: str | None = None,
) -> None:
"""
Delete a dataset.
Args:
name: The name of the dataset (Databricks only). In Databricks, this is the UC table name.
dataset_id: The ID of the dataset.
Note:
- In Databricks environments: Use 'name' to specify the dataset.
- Outside of Databricks: Use 'dataset_id' to specify the dataset.
Examples:
.. code-block:: python
from mlflow.genai.datasets import delete_dataset, search_datasets
# Delete a specific dataset by ID (non-Databricks)
delete_dataset(dataset_id="d-4b5c6d7e8f9a0b1c2d3e4f5a6b7c8d9e")
# Clean up old test datasets
test_datasets = search_datasets(
filter_string="name LIKE 'test_%' AND tags.environment = 'development'",
order_by=["created_time ASC"],
)
# Delete datasets older than the most recent 5
if len(test_datasets) > 5:
for dataset in test_datasets[:-5]: # Keep the 5 most recent
print(f"Deleting old test dataset: {dataset.name}")
delete_dataset(dataset_id=dataset.dataset_id)
# Delete datasets with specific criteria
deprecated_datasets = search_datasets(filter_string="tags.status = 'deprecated'")
for dataset in deprecated_datasets:
delete_dataset(dataset_id=dataset.dataset_id)
print(f"Deleted deprecated dataset: {dataset.name}")
.. warning::
Deleting a dataset is permanent and cannot be undone. All associated
records, tags, and metadata will be permanently removed.
"""
if is_databricks_default_tracking_uri(get_tracking_uri()):
_validate_databricks_params(name, dataset_id)
try:
from databricks.agents.datasets import delete_dataset as db_delete
return db_delete(name)
except ImportError as e:
raise ImportError(_ERROR_MSG) from e
else:
_validate_non_databricks_params(name, dataset_id)
from mlflow.tracking.client import MlflowClient
MlflowClient().delete_dataset(dataset_id)
@deprecated_parameter("uc_table_name", "name")
def get_dataset(
name: str | None = None,
dataset_id: str | None = None,
) -> "EvaluationDataset":
"""
Get the dataset with the given name or ID.
Args:
name: The name of the dataset (Databricks only). In Databricks, this is the UC table name.
dataset_id: The ID of the dataset.
Returns:
An EvaluationDataset object representing the retrieved dataset.
Note:
- In Databricks environments: Use 'name' to specify the dataset.
- Outside of Databricks: Use 'dataset_id' to specify the dataset.
Examples:
.. code-block:: python
from mlflow.genai.datasets import get_dataset
# Get a dataset by ID (non-Databricks)
dataset = get_dataset(dataset_id="d-7f2e3a9b8c1d4e5f6a7b8c9d0e1f2a3b")
# Access dataset properties
print(f"Dataset name: {dataset.name}")
print(f"Number of records: {len(dataset.records)}")
print(f"Tags: {dataset.tags}")
print(f"Created by: {dataset.created_by}")
# Work with the dataset
df = dataset.to_df() # Convert to pandas DataFrame
schema = dataset.schema # Get auto-computed schema
profile = dataset.profile # Get dataset statistics
# Add new records to the dataset
new_test_cases = [
{
"inputs": {"question": "What is MLflow?"},
"expectations": {"accuracy": 0.95, "contains_tracking": True},
}
]
dataset.merge_records(new_test_cases)
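# In Databricks, look the dataset up by its UC table name instead of an ID
# (sketch only; the three-level name below is illustrative)
uc_dataset = get_dataset(name="main.eval_schema.customer_support_qa_v1")
print(f"Loaded {len(uc_dataset.records)} records from UC table")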
"""
if is_databricks_default_tracking_uri(get_tracking_uri()):
_validate_databricks_params(name, dataset_id)
try:
from databricks.agents.datasets import get_dataset as db_get
return EvaluationDataset(db_get(name))
except ImportError as e:
raise ImportError(_ERROR_MSG) from e
else:
_validate_non_databricks_params(name, dataset_id)
from mlflow.tracking.client import MlflowClient
mlflow_dataset = MlflowClient().get_dataset(dataset_id)
return EvaluationDataset(mlflow_dataset)
@experimental(version="3.4.0")
def search_datasets(
experiment_ids: str | list[str] | None = None,
filter_string: str | None = None,
max_results: int | None = None,
order_by: list[str] | None = None,
) -> list[EvaluationDataset]:
"""
Search for datasets (non-Databricks only).
.. warning::
Calling ``search_datasets()`` without any parameters applies only the default
7-day ``created_time`` filter, which can still return a very large number of
datasets on a busy tracking server and be slow or exhaust memory. Use
``experiment_ids``, ``filter_string``, or ``max_results`` to limit the results.
Args:
experiment_ids: Single experiment ID (str) or list of experiment IDs to filter by.
If None, searches across all experiments.
filter_string: SQL-like filter string for dataset attributes. If not specified,
defaults to filtering for datasets created in the last 7 days. Supports filtering by:
- name: Dataset name
- created_by: User who created the dataset
- last_updated_by: User who last updated the dataset
- created_time: Creation timestamp (milliseconds since epoch)
- tags.<key>: Tag values
max_results: Maximum number of results to return. If not specified, all matching datasets are returned.
order_by: List of columns to order by. Each entry can include an optional
"DESC" or "ASC" suffix (default is "ASC"). If not specified, defaults to
["created_time DESC"]. Supported columns:
- name
- created_time
- last_update_time
Returns:
List of EvaluationDataset objects matching the search criteria
.. list-table:: Common Search Patterns
:widths: 40 60
:header-rows: 1
* - Search Pattern
- Example Code
* - **Find datasets by name**
- .. code-block:: python
# Exact match
datasets = search_datasets(
filter_string="name = 'production_qa_v2'"
)
# Pattern matching
datasets = search_datasets(
filter_string="name LIKE 'qa_%'"
)
* - **Find datasets by experiment**
- .. code-block:: python
# Single experiment
datasets = search_datasets(
experiment_ids="1"
)
# Multiple experiments
datasets = search_datasets(
experiment_ids=["0", "1", "2", "5"]
)
* - **Find datasets by tags**
- .. code-block:: python
# Single tag
datasets = search_datasets(
filter_string="tags.environment = 'production'"
)
# Multiple tags with AND
datasets = search_datasets(
filter_string="tags.status = 'validated' AND tags.version = '2.0'"
)
* - **Find datasets by creator**
- .. code-block:: python
datasets = search_datasets(
filter_string="created_by = 'alice@company.com'"
)
* - **Find recent datasets**
- .. code-block:: python
# Last 10 datasets created
datasets = search_datasets(
order_by=["created_time DESC"],
max_results=10
)
* - **Complex search**
- .. code-block:: python
# Production-ready datasets from specific team
datasets = search_datasets(
experiment_ids="1",
filter_string="tags.status = 'production' AND "
"tags.team = 'ml-platform' AND "
"name LIKE '%customer%'",
order_by=["last_update_time DESC"],
max_results=20
)
Examples:
.. code-block:: python
from mlflow.genai.datasets import search_datasets
# With no arguments, only the default 7-day created_time filter is applied;
# this can still be a very large result set on a busy tracking server
# last_week_datasets = search_datasets()
# Better: combine filters with an explicit limit
recent_datasets = search_datasets(max_results=100)
# Search in specific experiments
exp_datasets = search_datasets(experiment_ids=["1", "2", "3"])
# Find production datasets
prod_datasets = search_datasets(
filter_string="tags.environment = 'production'", order_by=["name ASC"]
)
# Iterate through results (pagination handled automatically)
for dataset in prod_datasets:
print(f"{dataset.name} (ID: {dataset.dataset_id})")
print(f" Records: {len(dataset.records)}")
print(f" Tags: {dataset.tags}")
Note:
This API is not available in Databricks environments. Use Unity Catalog
search capabilities in Databricks instead.
"""
if is_databricks_default_tracking_uri(get_tracking_uri()):
raise NotImplementedError(
"Dataset search is not available in Databricks. "
"Use Unity Catalog search capabilities instead."
)
if isinstance(experiment_ids, str):
experiment_ids = [experiment_ids]
# Set default filter to return datasets created in the last 7 days if no filter provided
# Also handle empty list/string cases where user might pass [] or ""
if not filter_string:
# 7 days ago in milliseconds
seven_days_ago = int((time.time() - 7 * 24 * 60 * 60) * 1000)
filter_string = f"created_time >= {seven_days_ago}"
# Set default order by creation time DESC if no order provided
if order_by is None:
order_by = ["created_time DESC"]
from mlflow.tracking.client import MlflowClient
from mlflow.utils import get_results_from_paginated_fn
def pagination_wrapper_func(number_to_get, next_page_token):
return MlflowClient().search_datasets(
experiment_ids=experiment_ids,
filter_string=filter_string,
max_results=number_to_get,
order_by=order_by,
page_token=next_page_token,
)
mlflow_datasets = get_results_from_paginated_fn(
pagination_wrapper_func,
SEARCH_EVALUATION_DATASETS_MAX_RESULTS,
max_results,
)
return [EvaluationDataset(dataset) for dataset in mlflow_datasets]
@experimental(version="3.4.0")
def delete_dataset_tag(
dataset_id: str,
key: str,
) -> None:
"""
Delete a tag from a dataset.
Args:
dataset_id: The ID of the dataset.
key: The tag key to delete.
Examples:
.. code-block:: python
from mlflow.genai.datasets import delete_dataset_tag, get_dataset
# Get your dataset
dataset = get_dataset(dataset_id="d-9e8f7c6b5a4d3e2f1a0b9c8d7e6f5a4b")
# Remove a single tag
delete_dataset_tag(dataset_id=dataset.dataset_id, key="deprecated")
# Remove outdated tags during cleanup
outdated_tags = ["old_version", "temp_flag", "development_only"]
for tag_key in outdated_tags:
delete_dataset_tag(dataset_id=dataset.dataset_id, key=tag_key)
# Check remaining tags
updated_dataset = get_dataset(dataset_id=dataset.dataset_id)
print(f"Remaining tags: {updated_dataset.tags}")
Note:
This API is not available in Databricks environments yet.
Tags in Databricks are managed through Unity Catalog.
"""
if is_databricks_default_tracking_uri(get_tracking_uri()):
raise NotImplementedError(
"Dataset tag operations are not available in Databricks yet. "
"Tags are managed through Unity Catalog."
)
from mlflow.tracking.client import MlflowClient
MlflowClient().delete_dataset_tag(dataset_id, key)
def _validate_association_operation():
"""Validate that dataset association operations can be performed."""
from mlflow.store.tracking.file_store import FileStore
from mlflow.tracking._tracking_service.utils import _get_store
if is_databricks_default_tracking_uri(get_tracking_uri()):
raise NotImplementedError(
"Dataset association operations are not available in Databricks yet. "
"Associations are managed through Unity Catalog."
)
store = _get_store()
if isinstance(store, FileStore):
raise NotImplementedError(
"Dataset association operations are not supported with FileStore backend. "
"Please use a database-backed tracking store."
)
def add_dataset_to_experiments(dataset_id: str, experiment_ids: list[str]) -> "EvaluationDataset":
"""
Add a dataset to additional experiments.
This allows reusing datasets across multiple experiments for evaluation purposes.
Args:
dataset_id: The ID of the dataset to update.
experiment_ids: List of experiment IDs to associate with the dataset.
Returns:
The updated EvaluationDataset with new experiment associations.
Example:
.. code-block:: python
import mlflow
from mlflow.genai.datasets import add_dataset_to_experiments
# Add dataset to new experiments
dataset = add_dataset_to_experiments(
dataset_id="d-abc123", experiment_ids=["1", "2", "3"]
)
print(f"Dataset now associated with {len(dataset.experiment_ids)} experiments")
"""
_validate_association_operation()
from mlflow.tracking.client import MlflowClient
client = MlflowClient()
mlflow_dataset = client.add_dataset_to_experiments(dataset_id, experiment_ids)
return EvaluationDataset(mlflow_dataset)
def remove_dataset_from_experiments(
dataset_id: str, experiment_ids: list[str]
) -> "EvaluationDataset":
"""
Remove a dataset from experiments.
This operation is idempotent - removing non-existent associations will not raise an error
but will issue a warning.
Args:
dataset_id: The ID of the dataset to update.
experiment_ids: List of experiment IDs to disassociate from the dataset.
Returns:
The updated EvaluationDataset after removing experiment associations.
Example:
.. code-block:: python
import mlflow
from mlflow.genai.datasets import remove_dataset_from_experiments
# Remove dataset from experiments
dataset = remove_dataset_from_experiments(
dataset_id="d-abc123", experiment_ids=["1", "2"]
)
print(f"Dataset now associated with {len(dataset.experiment_ids)} experiments")
"""
_validate_association_operation()
from mlflow.tracking.client import MlflowClient
client = MlflowClient()
mlflow_dataset = client.remove_dataset_from_experiments(dataset_id, experiment_ids)
return EvaluationDataset(mlflow_dataset)
__all__ = [
"EvaluationDataset",
"add_dataset_to_experiments",
"create_dataset",
"delete_dataset",
"delete_dataset_tag",
"get_dataset",
"remove_dataset_from_experiments",
"search_datasets",
"set_dataset_tags",
]