import json
import logging
from functools import cached_property
from typing import TYPE_CHECKING, Any, Optional, Union
from packaging.version import Version
from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.data.digest_utils import get_normalized_md5_digest
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.data.spark_dataset_source import SparkDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
from mlflow.types import Schema
from mlflow.types.utils import _infer_schema
if TYPE_CHECKING:
    import pyspark
# Module-level logger; used below to emit warnings when best-effort computations
# (dataset profile, schema inference) fail instead of raising.
_logger = logging.getLogger(__name__)
class SparkDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """
    Represents a Spark dataset (e.g. data derived from a Spark Table / file directory or Delta
    Table) for use with MLflow Tracking.
    """

    # Maximum number of rows collected to the driver when converting the Spark DataFrame to
    # pandas (``to_pyfunc`` / ``to_evaluation_dataset``), to bound driver memory usage.
    _MAX_PANDAS_ROWS = 10000

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",
        source: DatasetSource,
        targets: Optional[str] = None,
        name: Optional[str] = None,
        digest: Optional[str] = None,
        predictions: Optional[str] = None,
    ):
        """
        Args:
            df: The Spark DataFrame backing this dataset.
            source: The source the dataset was derived from.
            targets: Optional name of the column of ``df`` containing targets (labels).
            name: Optional name of the dataset.
            digest: Optional digest of the dataset. If omitted, ``_compute_digest`` is used.
            predictions: Optional name of the column of ``df`` containing model predictions.

        Raises:
            MlflowException: If ``targets`` or ``predictions`` is specified but is not a
                column of ``df``.
        """
        if targets is not None and targets not in df.columns:
            raise MlflowException(
                f"The specified Spark dataset does not contain the specified targets column"
                f" '{targets}'.",
                INVALID_PARAMETER_VALUE,
            )
        if predictions is not None and predictions not in df.columns:
            raise MlflowException(
                f"The specified Spark dataset does not contain the specified predictions column"
                f" '{predictions}'.",
                INVALID_PARAMETER_VALUE,
            )
        self._df = df
        self._targets = targets
        self._predictions = predictions
        super().__init__(source=source, name=name, digest=digest)

    def _compute_digest(self) -> str:
        """
        Computes a digest for the dataset. Called if the user doesn't supply
        a digest when constructing the dataset.
        """
        # Retrieve a semantic hash of the DataFrame's logical plan, which is much more efficient
        # and deterministic than hashing DataFrame records
        import numpy as np
        import pyspark

        # Spark 3.1.0+ has a semanticHash() method on DataFrame; older versions must reach the
        # analyzed logical plan through the underlying JVM DataFrame.
        if Version(pyspark.__version__) >= Version("3.1.0"):
            semantic_hash = self._df.semanticHash()
        else:
            semantic_hash = self._df._jdf.queryExecution().analyzed().semanticHash()
        return get_normalized_md5_digest([np.int64(semantic_hash)])

    def to_dict(self) -> dict[str, str]:
        """Create config dictionary for the dataset.

        Returns a string dictionary containing the following fields: name, digest, source, source
        type, schema, and profile. ``schema`` is ``None`` when no schema could be inferred, and
        ``profile`` is the JSON encoding of :py:attr:`profile` (``"null"`` if unavailable).
        """
        schema = json.dumps({"mlflow_colspec": self.schema.to_dict()}) if self.schema else None
        config = super().to_dict()
        config.update(
            {
                "schema": schema,
                "profile": json.dumps(self.profile),
            }
        )
        return config

    @property
    def df(self):
        """The Spark DataFrame instance.

        Returns:
            The Spark DataFrame instance.
        """
        return self._df

    @property
    def targets(self) -> Optional[str]:
        """The name of the Spark DataFrame column containing targets (labels) for supervised
        learning.

        Returns:
            The string name of the Spark DataFrame column containing targets.
        """
        return self._targets

    @property
    def predictions(self) -> Optional[str]:
        """
        The name of the predictions column. May be ``None`` if no predictions column
        was specified when the dataset was created.
        """
        return self._predictions

    @property
    def source(self) -> Union[SparkDatasetSource, DeltaDatasetSource]:
        """
        Spark dataset source information.

        Returns:
            An instance of
            :py:class:`SparkDatasetSource <mlflow.data.spark_dataset_source.SparkDatasetSource>` or
            :py:class:`DeltaDatasetSource <mlflow.data.delta_dataset_source.DeltaDatasetSource>`.
        """
        return self._source

    @property
    def profile(self) -> Optional[Any]:
        """
        A profile of the dataset. May be None if no profile is available.

        The profile is ``{"approx_count": <int or "unknown">}``, computed on every access with a
        Spark job bounded by a 5 second timeout at 90% confidence. Any failure is logged as a
        warning and yields ``None``.
        """
        try:
            from pyspark.rdd import BoundedFloat

            # Use Spark RDD countApprox to get approximate count since count() may be expensive.
            # Note that we call the Scala RDD API because the PySpark API does not respect the
            # specified timeout. Reference code:
            # https://spark.apache.org/docs/3.4.0/api/python/_modules/pyspark/rdd.html
            # #RDD.countApprox. This is confirmed to work in all Spark 3.x versions
            py_rdd = self.df.rdd
            # Per-partition row counts, as floats, so they can be summed by JavaDoubleRDD.
            drdd = py_rdd.mapPartitions(lambda it: [float(sum(1 for _ in it))])
            jrdd = drdd.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
            jdrdd = drdd.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
            timeout_millis = 5000
            confidence = 0.9
            approx_count_operation = jdrdd.sumApprox(timeout_millis, confidence)
            # initialValue() returns whatever estimate is available once the timeout elapses,
            # rather than blocking until the exact result is known.
            approx_count_result = approx_count_operation.initialValue()
            approx_count_float = BoundedFloat(
                mean=approx_count_result.mean(),
                confidence=approx_count_result.confidence(),
                low=approx_count_result.low(),
                high=approx_count_result.high(),
            )
            approx_count = int(approx_count_float)
            if approx_count <= 0:
                # An approximate count of zero likely indicates that the count timed
                # out before an estimate could be made. In this case, we use the value
                # "unknown" so that users don't think the dataset is empty
                approx_count = "unknown"
            return {
                "approx_count": approx_count,
            }
        except Exception as e:
            # Profiling is best-effort: never fail dataset logging because of it.
            _logger.warning(
                "Encountered an unexpected exception while computing Spark dataset profile."
                " Exception: %s",
                e,
            )
            return None

    @cached_property
    def schema(self) -> Optional[Schema]:
        """
        The MLflow ColSpec schema of the Spark dataset, or ``None`` if it could not be inferred.
        """
        try:
            return _infer_schema(self._df)
        except Exception as e:
            _logger.warning("Failed to infer schema for Spark dataset. Exception: %s", e)
            return None

    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """
        Converts the Spark DataFrame to pandas and splits the resulting
        :py:class:`pandas.DataFrame` into: 1. a :py:class:`pandas.DataFrame` of features and
        2. a :py:class:`pandas.Series` of targets.

        To avoid overuse of driver memory, only the first 10,000 DataFrame rows are selected.

        Raises:
            MlflowException: (``INTERNAL_ERROR``) if the pandas representation unexpectedly
                lacks the targets column.
        """
        df = self._df.limit(self._MAX_PANDAS_ROWS).toPandas()
        if self._targets is None:
            return PyFuncInputsOutputs(inputs=df, outputs=None)
        if self._targets not in df.columns:
            raise MlflowException(
                f"Failed to convert Spark dataset to pyfunc inputs and outputs because"
                f" the pandas representation of the Spark dataset does not contain the"
                f" specified targets column '{self._targets}'.",
                # This is an internal error because we should have validated the presence of
                # the target column in the Spark dataset at construction time
                INTERNAL_ERROR,
            )
        inputs = df.drop(columns=self._targets)
        outputs = df[self._targets]
        return PyFuncInputsOutputs(inputs=inputs, outputs=outputs)

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
        """
        Converts the dataset to an EvaluationDataset for model evaluation. Required
        for use with mlflow.evaluate().

        Only the first 10,000 DataFrame rows are collected to the driver.
        """
        return EvaluationDataset(
            data=self._df.limit(self._MAX_PANDAS_ROWS).toPandas(),
            targets=self._targets,
            path=path,
            feature_names=feature_names,
            predictions=self._predictions,
            name=self.name,
            digest=self.digest,
        )
def load_delta(
    path: Optional[str] = None,
    table_name: Optional[str] = None,
    version: Optional[str] = None,
    targets: Optional[str] = None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
) -> SparkDataset:
    """
    Loads a :py:class:`SparkDataset <mlflow.data.spark_dataset.SparkDataset>` from a Delta table
    for use with MLflow Tracking.

    Args:
        path: The path to the Delta table. Either ``path`` or ``table_name`` must be specified.
        table_name: The name of the Delta table. Either ``path`` or ``table_name`` must be
            specified.
        version: The Delta table version. If not specified, the latest version of the table is
            looked up; if that lookup yields nothing, the version is left unspecified.
        targets: Optional. The name of the Delta table column containing targets (labels) for
            supervised learning.
        name: The name of the dataset. E.g. "wiki_train". If unspecified and ``table_name`` is
            given, the name is ``"<table_name>@v<version>"`` (or just ``table_name`` when the
            version is unknown); otherwise a name is automatically generated.
        digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest
            is automatically computed.

    Returns:
        An instance of :py:class:`SparkDataset <mlflow.data.spark_dataset.SparkDataset>`.

    Raises:
        MlflowException: If not exactly one of ``path`` and ``table_name`` is specified.
    """
    from mlflow.data.spark_delta_utils import (
        _try_get_delta_table_latest_version_from_path,
        _try_get_delta_table_latest_version_from_table_name,
    )

    if (path, table_name).count(None) != 1:
        raise MlflowException(
            "Must specify exactly one of `table_name` or `path`.",
            INVALID_PARAMETER_VALUE,
        )

    # Compare against None explicitly so that an explicit version 0 is honored.
    if version is None:
        if path is not None:
            version = _try_get_delta_table_latest_version_from_path(path)
        else:
            version = _try_get_delta_table_latest_version_from_table_name(table_name)

    if name is None and table_name is not None:
        name = table_name + (f"@v{version}" if version is not None else "")

    source = DeltaDatasetSource(path=path, delta_table_name=table_name, delta_table_version=version)
    df = source.load()
    return SparkDataset(
        df=df,
        source=source,
        targets=targets,
        name=name,
        digest=digest,
    )
def from_spark(
    df: "pyspark.sql.DataFrame",
    path: Optional[str] = None,
    table_name: Optional[str] = None,
    version: Optional[str] = None,
    sql: Optional[str] = None,
    targets: Optional[str] = None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
    predictions: Optional[str] = None,
) -> SparkDataset:
    """
    Given a Spark DataFrame, constructs a
    :py:class:`SparkDataset <mlflow.data.spark_dataset.SparkDataset>` object for use with
    MLflow Tracking.

    Args:
        df: The Spark DataFrame from which to construct a SparkDataset.
        path: The path of the Spark or Delta source that the DataFrame originally came from. Note
            that the path does not have to match the DataFrame exactly, since the DataFrame may have
            been modified by Spark operations. This is used to reload the dataset upon request via
            :py:func:`SparkDataset.source.load()
            <mlflow.data.spark_dataset_source.SparkDatasetSource.load>`. If none of ``path``,
            ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used, which will source
            information from the run context.
        table_name: The name of the Spark or Delta table that the DataFrame originally came from.
            Note that the table does not have to match the DataFrame exactly, since the DataFrame
            may have been modified by Spark operations. This is used to reload the dataset upon
            request via :py:func:`SparkDataset.source.load()
            <mlflow.data.spark_dataset_source.SparkDatasetSource.load>`. If none of ``path``,
            ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used, which will source
            information from the run context.
        version: If the DataFrame originally came from a Delta table, specifies the version of the
            Delta table. This is used to reload the dataset upon request via
            :py:func:`SparkDataset.source.load()
            <mlflow.data.spark_dataset_source.SparkDatasetSource.load>`. If omitted for a Delta
            source, the latest version is looked up. ``version`` cannot be specified if ``sql`` is
            specified.
        sql: The Spark SQL statement that was originally used to construct the DataFrame. Note that
            the Spark SQL statement does not have to match the DataFrame exactly, since the
            DataFrame may have been modified by Spark operations. This is used to reload the dataset
            upon request via :py:func:`SparkDataset.source.load()
            <mlflow.data.spark_dataset_source.SparkDatasetSource.load>`. If none of ``path``,
            ``table_name``, or ``sql`` are specified, a CodeDatasetSource is used, which will source
            information from the run context.
        targets: Optional. The name of the DataFrame column containing targets (labels) for
            supervised learning.
        name: The name of the dataset. E.g. "wiki_train". If unspecified, a name is automatically
            generated.
        digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest is
            automatically computed.
        predictions: Optional. The name of the column containing model predictions,
            if the dataset contains model predictions. If specified, this column
            must be present in the dataframe (``df``).

    Returns:
        An instance of :py:class:`SparkDataset <mlflow.data.spark_dataset.SparkDataset>`.

    Raises:
        MlflowException: If more than one of ``path``, ``table_name`` and ``sql`` is given, if
            ``version`` is combined with ``sql``, or if ``version`` is given for a non-Delta
            ``path`` / ``table_name``.
    """
    from mlflow.data.code_dataset_source import CodeDatasetSource
    from mlflow.data.spark_delta_utils import (
        _is_delta_table,
        _is_delta_table_path,
        _try_get_delta_table_latest_version_from_path,
        _try_get_delta_table_latest_version_from_table_name,
    )
    from mlflow.tracking.context import registry

    if (path, table_name, sql).count(None) < 2:
        raise MlflowException(
            "Must specify at most one of `path`, `table_name`, or `sql`.",
            INVALID_PARAMETER_VALUE,
        )

    if (sql, version).count(None) == 0:
        raise MlflowException(
            "`version` may not be specified when `sql` is specified. `version` may only be"
            " specified when `table_name` or `path` is specified.",
            INVALID_PARAMETER_VALUE,
        )

    if sql is not None:
        source = SparkDatasetSource(sql=sql)
    elif path is not None:
        if _is_delta_table_path(path):
            # Explicit None check (not `version or ...`): version 0 is a valid Delta version and
            # must not be silently replaced by the latest one.
            if version is None:
                version = _try_get_delta_table_latest_version_from_path(path)
            source = DeltaDatasetSource(path=path, delta_table_version=version)
        elif version is None:
            source = SparkDatasetSource(path=path)
        else:
            raise MlflowException(
                f"Version '{version}' was specified, but the path '{path}' does not refer"
                f" to a Delta table.",
                INVALID_PARAMETER_VALUE,
            )
    elif table_name is not None:
        if _is_delta_table(table_name):
            # See above: an explicit version 0 must be preserved.
            if version is None:
                version = _try_get_delta_table_latest_version_from_table_name(table_name)
            source = DeltaDatasetSource(
                delta_table_name=table_name,
                delta_table_version=version,
            )
        elif version is None:
            source = SparkDatasetSource(table_name=table_name)
        else:
            raise MlflowException(
                f"Version '{version}' was specified, but could not find a Delta table with name"
                f" '{table_name}'.",
                INVALID_PARAMETER_VALUE,
            )
    else:
        # No explicit source: record where the code ran from via the run-context tags.
        context_tags = registry.resolve_tags()
        source = CodeDatasetSource(tags=context_tags)

    return SparkDataset(
        df=df,
        source=source,
        targets=targets,
        name=name,
        digest=digest,
        predictions=predictions,
    )