import sys
from contextlib import suppress
from typing import Union
from mlflow.data import dataset_registry
from mlflow.data import sources as mlflow_data_sources
from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.dataset_source_registry import get_dataset_source_from_json, get_registered_sources
from mlflow.entities import Dataset as DatasetEntity
from mlflow.entities import DatasetInput
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
with suppress(ImportError):
    # Suppressing ImportError to pass mlflow-skinny testing.
    from mlflow.data import meta_dataset  # noqa: F401
def get_source(dataset: Union[DatasetEntity, DatasetInput, Dataset]) -> DatasetSource:
    """Obtains the source of the specified dataset or dataset input.

    Args:
        dataset:
            An instance of :py:class:`mlflow.data.dataset.Dataset <mlflow.data.dataset.Dataset>`,
            :py:class:`mlflow.entities.Dataset`, or :py:class:`mlflow.entities.DatasetInput`.

    Returns:
        An instance of :py:class:`DatasetSource <mlflow.data.dataset_source.DatasetSource>`.

    Raises:
        MlflowException: If ``dataset`` is not one of the supported types listed above
            (raised with error code ``INVALID_PARAMETER_VALUE``).
    """
    # A DatasetInput wraps a dataset entity; unwrap it so the entity branch below handles it.
    if isinstance(dataset, DatasetInput):
        dataset = dataset.dataset

    if isinstance(dataset, DatasetEntity):
        # Entities carry the source as JSON plus a source-type name, so the concrete
        # DatasetSource has to be reconstructed via the dataset source registry.
        return get_dataset_source_from_json(
            source_json=dataset.source,
            source_type=dataset.source_type,
        )

    if isinstance(dataset, Dataset):
        # In-memory mlflow.data Dataset objects already hold a DatasetSource instance.
        return dataset.source

    raise MlflowException(
        f"Unrecognized dataset type {type(dataset)}. Expected one of: "
        "`mlflow.data.dataset.Dataset`,"
        " `mlflow.entities.Dataset`, `mlflow.entities.DatasetInput`.",
        INVALID_PARAMETER_VALUE,
    )
__all__ = ["get_source"]


def _define_dataset_constructors_in_current_module():
    """Bind every registered dataset constructor onto this module and export it.

    Each ``(name, constructor)`` pair returned by
    ``dataset_registry.get_registered_constructors()`` is set as an attribute of this
    module under ``name`` and appended to the module-level ``__all__``, so it is
    included in a star-import of this module.
    """
    this_module = sys.modules[__name__]
    registered = dataset_registry.get_registered_constructors()
    for name, fn in registered.items():
        setattr(this_module, name, fn)
        __all__.append(name)


# Runs once at import time so that registered constructors become module attributes.
_define_dataset_constructors_in_current_module()
def _define_dataset_sources_in_sources_module():
    """Attach every registered dataset source to the ``sources`` submodule and export it.

    For each entry yielded by ``get_registered_sources()``, the object is set as an
    attribute of ``mlflow.data.sources`` under its ``__name__``, and that name is
    appended to the submodule's ``__all__``. This assumes the submodule defines a
    mutable ``__all__`` list; verify in mlflow/data/sources.
    """
    for source_cls in get_registered_sources():
        cls_name = source_cls.__name__
        setattr(mlflow_data_sources, cls_name, source_cls)
        mlflow_data_sources.__all__.append(cls_name)


# Runs once at import time so registered sources are reachable from mlflow.data.sources.
_define_dataset_sources_in_sources_module()