import json
import logging
from functools import cached_property
from inspect import isclass
from typing import Any, Final, TypedDict

import polars as pl
from packaging.version import Version

if Version(pl.__version__).major < 1:
    raise ImportError(
        f"mlflow.data.polars_dataset requires polars>=1.0.0, found {pl.__version__}"
    )

from polars.datatypes.classes import DataType as PolarsDataType
from polars.datatypes.classes import DataTypeClass as PolarsDataTypeClass

from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema

_logger = logging.getLogger(__name__)


def hash_polars_df(df: pl.DataFrame) -> str:
    # This is probably not the most robust way to hash a DataFrame; see also:
    # https://github.com/pola-rs/polars/issues/9743
    # https://stackoverflow.com/q/76678160
    return str(df.hash_rows().sum())
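
# A minimal sketch of how the digest behaves (values are hypothetical; the exact
# digest depends on polars' row-hashing implementation and may vary across
# polars versions and platforms):
#
#     df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]})
#     hash_polars_df(df)  # -> e.g. "12345678901234567890"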


ColSpecType = DataType | Array | Object | str

TYPE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
    pl.Binary: DataType.binary,
    pl.Boolean: DataType.boolean,
    pl.Datetime: DataType.datetime,
    pl.Float32: DataType.float,
    pl.Float64: DataType.double,
    pl.Int8: DataType.integer,
    pl.Int16: DataType.integer,
    pl.Int32: DataType.integer,
    pl.Int64: DataType.long,
    pl.String: DataType.string,
    pl.Utf8: DataType.string,
}

CLOSE_MAP: Final[dict[PolarsDataTypeClass, DataType]] = {
    pl.Categorical: DataType.string,
    pl.Enum: DataType.string,
    pl.Date: DataType.datetime,
    pl.UInt8: DataType.integer,
    pl.UInt16: DataType.integer,
    pl.UInt32: DataType.long,
}

# Remaining types with no mapping:
# pl.Decimal
# pl.UInt64
# pl.Duration
# pl.Time
# pl.Null
# pl.Object
# pl.Unknown
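
# A few illustrative lookups through infer_dtype below (a sketch; logged
# warnings elided, exact reprs may differ):
#
#     infer_dtype(pl.Int32, "age", allow_unknown=True)          # -> DataType.integer
#     infer_dtype(pl.UInt32, "id", allow_unknown=True)          # -> DataType.long (close match, warns)
#     infer_dtype(pl.List(pl.Int64), "xs", allow_unknown=True)  # -> Array(DataType.long)
#     infer_dtype(pl.Time, "t", allow_unknown=True)             # -> "Time" (unknown, warns)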


def infer_schema(df: pl.DataFrame) -> Schema:
    return Schema([infer_colspec(df[col]) for col in df.columns])


def infer_colspec(col: pl.Series, *, allow_unknown: bool = True) -> ColSpec:
    return ColSpec(
        type=infer_dtype(col.dtype, col.name, allow_unknown=allow_unknown),
        name=col.name,
        required=col.count() > 0,
    )


def infer_dtype(
dtype: PolarsDataType | PolarsDataTypeClass, col_name: str, *, allow_unknown: bool
) -> ColSpecType:
cls: PolarsDataTypeClass = dtype if isinstance(dtype, PolarsDataTypeClass) else type(dtype)
mapped = TYPE_MAP.get(cls)
if mapped is not None:
return mapped
mapped = CLOSE_MAP.get(cls)
if mapped is not None:
        _logger.warning(
            "Column '%s' has dtype=%s, which will be mapped to %s."
            " This is not an exact match, but it is close enough.",
            col_name,
            dtype,
            mapped,
        )
return mapped
if not isinstance(dtype, PolarsDataType):
return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)
if isinstance(dtype, (pl.Array, pl.List)):
# cannot check inner if not instantiated
if isclass(dtype):
if not allow_unknown:
_raise_unknown_type(dtype)
return Array("Unknown")
inner = (
"Unknown"
if dtype.inner is None
else infer_dtype(dtype.inner, f"{col_name}.[]", allow_unknown=allow_unknown)
)
return Array(inner)
if isinstance(dtype, pl.Struct):
# cannot check fields if not instantiated
if isclass(dtype):
if not allow_unknown:
_raise_unknown_type(dtype)
return Object([])
return Object([
Property(
name=field.name,
dtype=infer_dtype(
field.dtype, f"{col_name}.{field.name}", allow_unknown=allow_unknown
),
)
for field in dtype.fields
])
return _handle_unknown_dtype(dtype=dtype, col_name=col_name, allow_unknown=allow_unknown)


def _handle_unknown_dtype(dtype: Any, col_name: str, *, allow_unknown: bool) -> str:
    if not allow_unknown:
        _raise_unknown_type(dtype)
    _logger.warning(
        "Column '%s' has dtype=%s, which cannot be mapped to any DataType;"
        " falling back to its string representation",
        col_name,
        dtype,
    )
    return str(dtype)


def _raise_unknown_type(dtype: Any) -> None:
msg = f"Unknown type: {dtype!r}"
raise ValueError(msg)
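

# End-to-end sketch of schema inference (reprs abbreviated and approximate):
#
#     df = pl.DataFrame({"name": ["a", "b"], "score": [1.0, 2.0]})
#     infer_schema(df)
#     # -> Schema([ColSpec(string, "name"), ColSpec(double, "score")])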


class PolarsDataset(Dataset, PyFuncConvertibleDatasetMixin):
"""A polars DataFrame for use with MLflow Tracking."""

    def __init__(
self,
df: pl.DataFrame,
source: DatasetSource,
targets: str | None = None,
name: str | None = None,
digest: str | None = None,
predictions: str | None = None,
) -> None:
"""
Args:
df: A polars DataFrame.
source: Source of the DataFrame.
targets: Name of the target column. Optional.
name: Name of the dataset. E.g. "wiki_train". If unspecified, a name is automatically
generated.
digest: Digest (hash, fingerprint) of the dataset. If unspecified, a digest is
automatically computed.
predictions: Name of the column containing model predictions, if the dataset contains
model predictions. Optional. If specified, this column must be present in ``df``.
"""
if targets is not None and targets not in df.columns:
raise MlflowException(
f"DataFrame does not contain specified targets column: '{targets}'",
INVALID_PARAMETER_VALUE,
)
if predictions is not None and predictions not in df.columns:
raise MlflowException(
f"DataFrame does not contain specified predictions column: '{predictions}'",
INVALID_PARAMETER_VALUE,
)
# _df needs to be set before super init, as it is used in _compute_digest
# see Dataset.__init__()
self._df = df
super().__init__(source=source, name=name, digest=digest)
self._targets = targets
self._predictions = predictions

    def _compute_digest(self) -> str:
        """Compute a digest for the dataset.

        Called if the user doesn't supply a digest when constructing the dataset.
        """
return hash_polars_df(self._df)

    class PolarsDatasetConfig(TypedDict):
name: str
digest: str
source: str
source_type: str
schema: str
profile: str

    def to_dict(self) -> PolarsDatasetConfig:
        """Create config dictionary for the dataset.

        Return a string dictionary containing the following fields: name, digest, source,
        source type, schema, and profile.
        """
schema = json.dumps({"mlflow_colspec": self.schema.to_dict()} if self.schema else None)
return {
"name": self.name,
"digest": self.digest,
"source": self.source.to_json(),
"source_type": self.source._get_source_type(),
"schema": schema,
"profile": json.dumps(self.profile),
}
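
    # Sketch of a resulting config dict (values abbreviated and hypothetical):
    #
    #     {
    #         "name": "dataset",
    #         "digest": "12345678",
    #         "source": '{"tags": {...}}',
    #         "source_type": "code",
    #         "schema": '{"mlflow_colspec": [...]}',
    #         "profile": '{"num_rows": 3, "num_elements": 12}',
    #     }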

    @property
def df(self) -> pl.DataFrame:
"""Underlying DataFrame."""
return self._df

    @property
def source(self) -> DatasetSource:
"""Source of the dataset."""
return self._source

    @property
    def targets(self) -> str | None:
        """Name of the target column.

        May be ``None`` if no target column is available.
        """
return self._targets

    @property
    def predictions(self) -> str | None:
        """Name of the predictions column.

        May be ``None`` if no predictions column is available.
        """
return self._predictions

    class PolarsDatasetProfile(TypedDict):
num_rows: int
num_elements: int

    @property
def profile(self) -> PolarsDatasetProfile:
"""Profile of the dataset."""
return {
"num_rows": self._df.height,
"num_elements": self._df.height * self._df.width,
}
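
    # For example, a 3-row, 4-column DataFrame yields:
    #     {"num_rows": 3, "num_elements": 12}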

    @cached_property
    def schema(self) -> Schema | None:
        """Instance of :py:class:`mlflow.types.Schema` representing the tabular dataset.

        May be ``None`` if the schema cannot be inferred from the dataset.
        """
try:
return infer_schema(self._df)
except Exception as e:
_logger.warning("Failed to infer schema for PolarsDataset. Exception: %s", e)
return None

    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """Convert dataset to a collection of pyfunc inputs and outputs for model evaluation."""
        if self._targets:
            # Pass the targets column name as-is; ``drop(*self._targets)`` would
            # unpack the string into single-character column names.
            inputs = self._df.drop(self._targets)
            outputs = self._df.select(self._targets).to_series()
            return PyFuncInputsOutputs([inputs.to_pandas()], [outputs.to_pandas()])
        else:
            return PyFuncInputsOutputs([self._df.to_pandas()])

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
"""Convert dataset to an EvaluationDataset for model evaluation."""
return EvaluationDataset(
data=self._df.to_pandas(),
targets=self._targets,
path=path,
feature_names=feature_names,
predictions=self._predictions,
name=self.name,
digest=self.digest,
)


def from_polars(
    df: pl.DataFrame,
    source: str | DatasetSource | None = None,
    targets: str | None = None,
    name: str | None = None,
    digest: str | None = None,
    predictions: str | None = None,
) -> PolarsDataset:
"""Construct a :py:class:`PolarsDataset <mlflow.data.polars_dataset.PolarsDataset>` instance.
Args:
df: A polars DataFrame.
source: Source from which the DataFrame was derived, e.g. a filesystem
path, an S3 URI, an HTTPS URL, a delta table name with version, or
spark table etc. ``source`` may be specified as a URI, a path-like string,
or an instance of
:py:class:`DatasetSource <mlflow.data.dataset_source.DatasetSource>`.
If unspecified, the source is assumed to be the code location
(e.g. notebook cell, script, etc.) where
:py:func:`from_polars <mlflow.data.from_polars>` is being called.
targets: An optional target column name for supervised training. This column
must be present in ``df``.
name: Name of the dataset. If unspecified, a name is generated.
digest: Dataset digest (hash). If unspecified, a digest is computed
automatically.
        predictions: An optional predictions column name for model evaluation. This column
            must be present in ``df``.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        import polars as pl

        x = pl.DataFrame(
            [["tom", 10, 1, 1], ["nick", 15, 0, 1], ["julie", 14, 1, 1]],
            schema=["Name", "Age", "Label", "ModelOutput"],
        )
        dataset = mlflow.data.from_polars(x, targets="Label", predictions="ModelOutput")
    """
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.dataset_source_registry import resolve_dataset_source
from mlflow.tracking.context import registry
if source is not None:
if isinstance(source, DatasetSource):
resolved_source = source
else:
resolved_source = resolve_dataset_source(source)
else:
context_tags = registry.resolve_tags()
resolved_source = CodeDatasetSource(tags=context_tags)
return PolarsDataset(
df=df,
source=resolved_source,
targets=targets,
name=name,
digest=digest,
predictions=predictions,
)
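

# Typical end-to-end usage (a sketch; assumes an active MLflow tracking setup
# and a hypothetical "train.parquet" file):
#
#     import mlflow
#     import polars as pl
#
#     df = pl.read_parquet("train.parquet")
#     dataset = mlflow.data.from_polars(df, targets="Label", predictions="ModelOutput")
#     with mlflow.start_run():
#         mlflow.log_input(dataset, context="training")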