import json
import logging
from functools import cached_property
from typing import Any, Optional, Union
import numpy as np
from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.digest_utils import (
    MAX_ROWS,
    compute_numpy_digest,
    get_normalized_md5_digest,
)
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.data.schema import TensorDatasetSchema
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
from mlflow.types.schema import Schema
from mlflow.types.utils import _infer_schema
_logger = logging.getLogger(__name__)


class TensorFlowDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """
    Represents a TensorFlow dataset for use with MLflow Tracking.
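
    A minimal usage sketch (illustrative; assumes TensorFlow is installed and an
    MLflow run is active; :func:`mlflow.data.from_tensorflow` is the typical
    constructor):

    .. code-block:: python

        import mlflow
        import tensorflow as tf

        x = tf.random.uniform([100, 5])
        y = tf.random.uniform([100])
        dataset = mlflow.data.from_tensorflow(features=x, targets=y, name="example")
        mlflow.log_input(dataset, context="training")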
    """
    def __init__(
        self,
        features,
        source: DatasetSource,
        targets=None,
        name: Optional[str] = None,
        digest: Optional[str] = None,
    ):
        """
        Args:
            features: A TensorFlow dataset or tensor of features.
            source: The source of the TensorFlow dataset.
            targets: A TensorFlow dataset or tensor of targets. Optional.
            name: The name of the dataset. E.g. "wiki_train". If unspecified, a name is
                automatically generated.
            digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest
                is automatically computed.
        """
        import tensorflow as tf
        if not isinstance(features, tf.data.Dataset) and not tf.is_tensor(features):
            raise MlflowException(
                f"'features' must be an instance of tf.data.Dataset or a TensorFlow Tensor."
                f" Found: {type(features)}.",
                INVALID_PARAMETER_VALUE,
            )
        if tf.is_tensor(features) and targets is not None and not tf.is_tensor(targets):
            raise MlflowException(
                f"If 'features' is a TensorFlow Tensor, then 'targets' must also be a TensorFlow"
                f" Tensor. Found: {type(targets)}.",
                INVALID_PARAMETER_VALUE,
            )
        if (
            isinstance(features, tf.data.Dataset)
            and targets is not None
            and not isinstance(targets, tf.data.Dataset)
        ):
            raise MlflowException(
                "If 'features' is an instance of tf.data.Dataset, then 'targets' must also be an"
                f" instance of tf.data.Dataset. Found: {type(targets)}.",
                INVALID_PARAMETER_VALUE,
            )
        self._features = features
        self._targets = targets
        super().__init__(source=source, name=name, digest=digest)
    def _compute_tensorflow_dataset_digest(
        self,
        dataset,
        targets=None,
    ) -> str:
        """Computes a digest for the given Tensorflow dataset.
        Args:
            dataset: A Tensorflow dataset.
        Returns:
            A string digest.
        """
        import pandas as pd
        import tensorflow as tf
        hashable_elements = []
        def hash_tf_dataset_iterator_element(element):
            if element is None:
                return
            # Flatten arbitrarily nested structures (tuples, dicts) into a flat
            # list of arrays, then into a single 1-D array capped at MAX_ROWS values.
            flat_element = tf.nest.flatten(element)
            flattened_array = np.concatenate([x.flatten() for x in flat_element])
            trimmed_array = flattened_array[0:MAX_ROWS]
            try:
                hashable_elements.append(pd.util.hash_array(trimmed_array))
            except TypeError:
                # Fall back to hashing the element count for dtypes that pandas
                # cannot hash directly.
                hashable_elements.append(np.int64(trimmed_array.size))
        for element in dataset.as_numpy_iterator():
            hash_tf_dataset_iterator_element(element)
        if targets is not None:
            for element in targets.as_numpy_iterator():
                hash_tf_dataset_iterator_element(element)
        return get_normalized_md5_digest(hashable_elements)
    def _compute_tensor_digest(
        self,
        tensor_data,
        tensor_targets,
    ) -> str:
        """Computes a digest for the given Tensorflow tensor.
        Args:
            tensor_data: A Tensorflow tensor, representing the features.
            tensor_targets: A Tensorflow tensor, representing the targets. Optional.
        Returns:
            A string digest.
        """
        if tensor_targets is None:
            return compute_numpy_digest(tensor_data.numpy())
        else:
            return compute_numpy_digest(tensor_data.numpy(), tensor_targets.numpy())
    def _compute_digest(self) -> str:
        """
        Computes a digest for the dataset. Called if the user doesn't supply
        a digest when constructing the dataset.
        """
        import tensorflow as tf
        if isinstance(self._features, tf.data.Dataset):
            return self._compute_tensorflow_dataset_digest(self._features, self._targets)
        return self._compute_tensor_digest(self._features, self._targets)

    def to_dict(self) -> dict[str, str]:
        """Create config dictionary for the dataset.
        Returns a string dictionary containing the following fields: name, digest, source, source
        type, schema, and profile.
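
        An illustrative (hypothetical) result::

            {
                "name": "example",
                "digest": "...",
                "source": "...",
                "source_type": "...",
                "schema": "...",
                "profile": '{"features_cardinality": 100}',
            }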
        """
        schema = json.dumps(self.schema.to_dict()) if self.schema else None
        config = super().to_dict()
        config.update(
            {
                "schema": schema,
                "profile": json.dumps(self.profile),
            }
        )
        return config
    @property
    def data(self):
        """
        The underlying TensorFlow data.
        """
        return self._features
    @property
    def source(self) -> DatasetSource:
        """
        The source of the dataset.
        """
        return self._source
    @property
    def targets(self):
        """
        The targets of the dataset.
        """
        return self._targets
    @property
    def profile(self) -> Optional[Any]:
        """
        A profile of the dataset. May be None if no profile is available.
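
        For this dataset type, the profile reports element counts, e.g.
        (illustrative)::

            {"features_cardinality": 100, "targets_cardinality": 100}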
        """
        import tensorflow as tf
        profile = {
            "features_cardinality": int(self._features.cardinality().numpy())
            if isinstance(self._features, tf.data.Dataset)
            else int(tf.size(self._features).numpy()),
        }
        if self._targets is not None:
            profile.update(
                {
                    "targets_cardinality": int(self._targets.cardinality().numpy())
                    if isinstance(self._targets, tf.data.Dataset)
                    else int(tf.size(self._targets).numpy()),
                }
            )
        return profile
    @cached_property
    def schema(self) -> Optional[TensorDatasetSchema]:
        """
        An MLflow TensorSpec schema representing the tensor dataset
        """
        try:
            features_schema = TensorFlowDataset._get_tf_object_schema(self._features)
            targets_schema = None
            if self._targets is not None:
                targets_schema = TensorFlowDataset._get_tf_object_schema(self._targets)
            return TensorDatasetSchema(features=features_schema, targets=targets_schema)
        except Exception as e:
            _logger.warning("Failed to infer schema for TensorFlow dataset. Exception: %s", e)
            return None
    @staticmethod
    def _get_tf_object_schema(tf_object) -> Schema:
        import tensorflow as tf
        if isinstance(tf_object, tf.data.Dataset):
            numpy_data = next(tf_object.as_numpy_iterator())
            if isinstance(numpy_data, np.ndarray):
                return _infer_schema(numpy_data)
            elif isinstance(numpy_data, dict):
                return TensorFlowDataset._get_schema_from_tf_dataset_dict_numpy_data(numpy_data)
            elif isinstance(numpy_data, tuple):
                return TensorFlowDataset._get_schema_from_tf_dataset_tuple_numpy_data(numpy_data)
            else:
                raise MlflowException(
                    f"Failed to infer schema for tf.data.Dataset due to unrecognized numpy iterator"
                    f" data type. Numpy iterator data types 'np.ndarray', 'dict', and 'tuple' are"
                    f" supported. Found: {type(numpy_data)}.",
                    INVALID_PARAMETER_VALUE,
                )
        elif tf.is_tensor(tf_object):
            return _infer_schema(tf_object.numpy())
        else:
            raise MlflowException(
                f"Cannot infer schema of an object that is not an instance of tf.data.Dataset or"
                f" a TensorFlow Tensor. Found: {type(tf_object)}",
                INTERNAL_ERROR,
            )
    @staticmethod
    def _get_schema_from_tf_dataset_dict_numpy_data(numpy_data: dict[Any, Any]) -> Schema:
        if not all(isinstance(data_element, np.ndarray) for data_element in numpy_data.values()):
            raise MlflowException(
                "Failed to infer schema for tf.data.Dataset. Schemas can only be inferred"
                " if the dataset consists of tensors. Ragged tensors, tensor arrays, and"
                " other types are not supported. Additionally, datasets with nested tensors"
                " are not supported.",
                INVALID_PARAMETER_VALUE,
            )
        return _infer_schema(numpy_data)
    @staticmethod
    def _get_schema_from_tf_dataset_tuple_numpy_data(numpy_data: tuple[Any]) -> Schema:
        if not all(isinstance(data_element, np.ndarray) for data_element in numpy_data):
            raise MlflowException(
                "Failed to infer schema for tf.data.Dataset. Schemas can only be inferred"
                " if the dataset consists of tensors. Ragged tensors, tensor arrays, and"
                " other types are not supported. Additionally, datasets with nested tensors"
                " are not supported.",
                INVALID_PARAMETER_VALUE,
            )
        return _infer_schema(
            {
                # MLflow Schemas currently require each tensor to have a name, if more than
                # one tensor is defined. Accordingly, use the index as the name
                str(i): data_element
                for i, data_element in enumerate(numpy_data)
            }
        )
    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """
        Converts the dataset to a collection of pyfunc inputs and outputs for model
        evaluation. Required for use with mlflow.evaluate().
        """
        return PyFuncInputsOutputs(self._features, self._targets)

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
        """
        Converts the dataset to an EvaluationDataset for model evaluation. Only supported if the
        dataset is a Tensor. Required for use with mlflow.evaluate().
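
        A minimal sketch (``dataset`` is assumed to have been constructed from
        tensors)::

            eval_ds = dataset.to_evaluation_dataset()
            # eval_ds can then be consumed by MLflow's evaluation machinery,
            # e.g. mlflow.evaluate()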
        """
        import tensorflow as tf
        # check that data and targets are Tensors
        if not tf.is_tensor(self._features):
            raise MlflowException("Data must be a Tensor to convert to an EvaluationDataset.")
        if self._targets is not None and not tf.is_tensor(self._targets):
            raise MlflowException("Targets must be a Tensor to convert to an EvaluationDataset.")
        return EvaluationDataset(
            data=self._features.numpy(),
            targets=self._targets.numpy() if self._targets is not None else None,
            path=path,
            feature_names=feature_names,
            name=self.name,
            digest=self.digest,
        )


def from_tensorflow(
    features,
    source: Optional[Union[str, DatasetSource]] = None,
    targets=None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
) -> TensorFlowDataset:
    """Constructs a TensorFlowDataset object from TensorFlow data, optional targets, and source.
    If the source is path like, then this will construct a DatasetSource object from the source
    path. Otherwise, the source is assumed to be a DatasetSource object.
    Args:
        features: A TensorFlow dataset or tensor of features.
        source: The source from which the data was derived, e.g. a filesystem
            path, an S3 URI, an HTTPS URL, a delta table name with version, or a
            Spark table. If the source is not a path-like string, pass in a
            DatasetSource object directly. If no source is specified, a
            CodeDatasetSource is used, which obtains source information from the
            run context.
        targets: A TensorFlow dataset or tensor of targets. Optional.
        name: The name of the dataset. If unspecified, a name is generated.
        digest: A dataset digest (hash). If unspecified, a digest is computed
            automatically.
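
    Example (a hypothetical sketch; assumes TensorFlow is installed and the path
    "data/train" is illustrative):

    .. code-block:: python

        import mlflow
        import tensorflow as tf

        x = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100, 5]))
        y = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100]))
        dataset = mlflow.data.from_tensorflow(
            features=x, targets=y, source="data/train", name="train"
        )
        print(dataset.digest)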
    """
    from mlflow.data.code_dataset_source import CodeDatasetSource
    from mlflow.data.dataset_source_registry import resolve_dataset_source
    from mlflow.tracking.context import registry
    if source is not None:
        if isinstance(source, DatasetSource):
            resolved_source = source
        else:
            resolved_source = resolve_dataset_source(source)
    else:
        context_tags = registry.resolve_tags()
        resolved_source = CodeDatasetSource(tags=context_tags)
    return TensorFlowDataset(
        features=features, source=resolved_source, targets=targets, name=name, digest=digest
    )