"""
APIs for interacting with artifacts in MLflow
"""
import json
import pathlib
import posixpath
import tempfile
from typing import Any, Optional
from mlflow.entities.file_info import FileInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE
from mlflow.tracking import _get_store
from mlflow.tracking.artifact_utils import (
    _download_artifact_from_uri,
    _get_root_uri_and_artifact_path,
    add_databricks_profile_info_to_artifact_uri,
    get_artifact_repository,
)
def download_artifacts(
    artifact_uri: Optional[str] = None,
    run_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
    dst_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """Fetch an artifact file or directory onto the local filesystem.

    Args:
        artifact_uri: URI pointing to the artifacts. Supported formats include:

            * ``runs:/<run_id>/<artifact_path>``
              Example: ``runs:/500cf58bee2b40a4a82861cc31a617b1/my_model.pkl``
            * ``models:/<model_name>/<stage>``
              Example: ``models:/my_model/Production``
            * ``models:/<model_name>/<version>/path/to/model``
              Example: ``models:/my_model/2/path/to/model``
            * ``models:/<model_name>@<alias>/path/to/model``
              Example: ``models:/my_model@staging/path/to/model``
            * Cloud storage URIs: ``s3://<bucket>/<path>`` or ``gs://<bucket>/<path>``
            * Tracking server artifact URIs: ``http://<host>/mlartifacts`` or
              ``mlflow-artifacts://<host>/mlartifacts``

            Exactly one of ``artifact_uri`` or ``run_id`` must be specified.
        run_id: ID of the MLflow Run containing the artifacts. Exactly one of ``run_id`` or
            ``artifact_uri`` must be specified.
        artifact_path: (For use with ``run_id``) If specified, a path relative to the MLflow
            Run's root directory containing the artifacts to download.
        dst_path: Path of the local filesystem destination directory to which to download the
            specified artifacts. If the directory does not exist, it is created. If
            unspecified, the artifacts are downloaded to a new uniquely-named directory on
            the local filesystem, unless the artifacts already exist on the local
            filesystem, in which case their local path is returned directly.
        tracking_uri: The tracking URI to be used when downloading artifacts.

    Returns:
        The location of the artifact file or directory on the local filesystem.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``run_id`` and ``artifact_uri``
            are both given or both omitted, or if ``artifact_path`` is combined with
            ``artifact_uri``.
    """
    uri_given = artifact_uri is not None
    run_id_given = run_id is not None

    # The two ways of naming the artifact source are mutually exclusive, and one is required.
    if uri_given == run_id_given:
        raise MlflowException(
            message="Exactly one of `run_id` or `artifact_uri` must be specified",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if uri_given and artifact_path is not None:
        raise MlflowException(
            message="`artifact_path` cannot be specified if `artifact_uri` is specified",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Create the destination up front so the download helpers can write into it.
    if dst_path is not None:
        pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True)

    if uri_given:
        return _download_artifact_from_uri(
            artifact_uri, output_path=dst_path, tracking_uri=tracking_uri
        )

    # Use `runs:/<run_id>/<artifact_path>` to download both run and model (if exists) artifacts
    if run_id and artifact_path:
        runs_uri = f"runs:/{posixpath.join(run_id, artifact_path)}"
        return _download_artifact_from_uri(
            runs_uri, output_path=dst_path, tracking_uri=tracking_uri
        )

    # No usable sub-path: resolve the run's artifact root through the tracking store and
    # download relative to it.
    relative_path = "" if artifact_path is None else artifact_path
    run = _get_store(store_uri=tracking_uri).get_run(run_id)
    repo_uri = add_databricks_profile_info_to_artifact_uri(run.info.artifact_uri, tracking_uri)
    repo = get_artifact_repository(repo_uri, tracking_uri=tracking_uri)
    return repo.download_artifacts(relative_path, dst_path=dst_path)
def list_artifacts(
    artifact_uri: Optional[str] = None,
    run_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> list[FileInfo]:
    """Enumerate the artifacts found at the given location.

    Args:
        artifact_uri: URI pointing to the artifacts, such as
            ``"runs:/500cf58bee2b40a4a82861cc31a617b1/my_model.pkl"``,
            ``"models:/my_model/Production"``, or ``"s3://my_bucket/my/file.txt"``.
            Exactly one of ``artifact_uri`` or ``run_id`` must be specified.
        run_id: ID of the MLflow Run containing the artifacts. Exactly one of ``run_id`` or
            ``artifact_uri`` must be specified.
        artifact_path: (For use with ``run_id``) If specified, a path relative to the MLflow
            Run's root directory containing the artifacts to list.
        tracking_uri: The tracking URI to be used when listing artifacts.

    Returns:
        List of artifacts as FileInfo listed directly under path.

    Raises:
        MlflowException: If ``run_id`` and ``artifact_uri`` are both given or both omitted,
            or if ``artifact_path`` is combined with ``artifact_uri``.
    """
    uri_given = artifact_uri is not None
    run_id_given = run_id is not None

    # The two ways of naming the artifact source are mutually exclusive, and one is required.
    if uri_given == run_id_given:
        raise MlflowException.invalid_parameter_value(
            message="Exactly one of `run_id` or `artifact_uri` must be specified",
        )
    if uri_given and artifact_path is not None:
        raise MlflowException.invalid_parameter_value(
            message="`artifact_path` cannot be specified if `artifact_uri` is specified",
        )

    if uri_given:
        # Split the URI into the repository root and the path to list inside it.
        root_uri, path_in_repo = _get_root_uri_and_artifact_path(artifact_uri)
        repo = get_artifact_repository(artifact_uri=root_uri, tracking_uri=tracking_uri)
        return repo.list_artifacts(path_in_repo)

    # Use `runs:/<run_id>/<artifact_path>` to list both run and model (if exists) artifacts
    if run_id and artifact_path:
        repo = get_artifact_repository(
            artifact_uri=f"runs:/{run_id}", tracking_uri=tracking_uri
        )
        return repo.list_artifacts(artifact_path)

    # No usable sub-path: resolve the run's artifact root through the tracking store.
    run = _get_store(store_uri=tracking_uri).get_run(run_id)
    repo_uri = add_databricks_profile_info_to_artifact_uri(run.info.artifact_uri, tracking_uri)
    repo = get_artifact_repository(repo_uri, tracking_uri=tracking_uri)
    return repo.list_artifacts(artifact_path)
def load_text(artifact_uri: str) -> str:
    """Loads the artifact contents as a string.

    The artifact is downloaded into a temporary directory, read as UTF-8 text, and the
    temporary directory is removed before returning.

    Args:
        artifact_uri: Artifact location.

    Returns:
        The contents of the artifact as a string.

    Raises:
        MlflowException: With ``BAD_REQUEST`` if the downloaded file cannot be read as text.

    .. code-block:: python
        :caption: Example

        import mlflow

        with mlflow.start_run() as run:
            artifact_uri = run.info.artifact_uri
            mlflow.log_text("This is a sentence", "file.txt")
            file_content = mlflow.artifacts.load_text(artifact_uri + "/file.txt")
            print(file_content)

    .. code-block:: text
        :caption: Output

        This is a sentence
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        # Decode explicitly as UTF-8 instead of the platform's locale encoding so the
        # result does not depend on the machine the artifact is loaded on.
        with open(local_artifact, encoding="utf-8") as local_artifact_fd:
            try:
                return str(local_artifact_fd.read())
            except Exception as exc:
                # Chain the original error so the underlying read/decode failure stays
                # visible as the direct cause.
                raise MlflowException(
                    "Unable to form a str object from file content", BAD_REQUEST
                ) from exc
def load_dict(artifact_uri: str) -> dict[str, Any]:
    """Loads the artifact contents as a dictionary.

    The artifact is downloaded into a temporary directory, parsed as UTF-8 encoded JSON,
    and the temporary directory is removed before returning.

    Args:
        artifact_uri: Artifact location.

    Returns:
        A dictionary.

    Raises:
        MlflowException: With ``BAD_REQUEST`` if the downloaded file is not valid UTF-8
            or not valid JSON.

    .. code-block:: python
        :caption: Example

        import mlflow

        with mlflow.start_run() as run:
            artifact_uri = run.info.artifact_uri
            mlflow.log_dict({"mlflow-version": "0.28", "n_cores": "10"}, "config.json")
            config_json = mlflow.artifacts.load_dict(artifact_uri + "/config.json")
            print(config_json)

    .. code-block:: text
        :caption: Output

        {'mlflow-version': '0.28', 'n_cores': '10'}
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        # Decode explicitly as UTF-8 instead of the platform's locale encoding so parsing
        # does not depend on the machine the artifact is loaded on.
        with open(local_artifact, encoding="utf-8") as local_artifact_fd:
            try:
                return json.load(local_artifact_fd)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                # Undecodable bytes are reported the same way as malformed JSON, and the
                # original error is chained as the direct cause.
                raise MlflowException(
                    "Unable to form a JSON object from file content", BAD_REQUEST
                ) from exc
def load_image(artifact_uri: str):
    """Loads artifact contents as a ``PIL.Image.Image`` object.

    The artifact is downloaded into a temporary directory and its pixel data is read into
    memory before the temporary directory is removed.

    Args:
        artifact_uri: Artifact location.

    Returns:
        A PIL.Image object.

    Raises:
        ImportError: If Pillow is not installed.
        MlflowException: With ``BAD_REQUEST`` if the downloaded file cannot be opened or
            loaded as an image.

    .. code-block:: python
        :caption: Example

        import mlflow
        from PIL import Image

        with mlflow.start_run() as run:
            image = Image.new("RGB", (100, 100))
            artifact_uri = run.info.artifact_uri
            mlflow.log_image(image, "image.png")
            image = mlflow.artifacts.load_image(artifact_uri + "/image.png")
            print(image)

    .. code-block:: text
        :caption: Output

        <PIL.PngImagePlugin.PngImageFile image mode=RGB size=100x100 at 0x11D2FA3D0>
    """
    try:
        from PIL import Image
    except ImportError as exc:
        raise ImportError(
            "`load_image` requires Pillow. Please install it via: pip install Pillow"
        ) from exc

    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        try:
            image_obj = Image.open(local_artifact)
            try:
                # `Image.open` is lazy; force the pixel data into memory now because the
                # backing file is deleted together with the temporary directory.
                image_obj.load()
            except Exception:
                # Release the still-open file handle so the temporary directory can be
                # removed cleanly while the error propagates.
                image_obj.close()
                raise
            return image_obj
        except Exception as exc:
            raise MlflowException(
                "Unable to form a PIL Image object from file content", BAD_REQUEST
            ) from exc