# Source code for mlflow.utils.environment

import yaml
import os
import logging


from mlflow.utils import PYTHON_VERSION
from mlflow.utils.requirements_utils import _parse_requirements, _infer_requirements
from packaging.requirements import Requirement, InvalidRequirement

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# YAML skeleton of a conda environment. `_mlflow_conda_env` parses it with `yaml.safe_load` and
# then fills in the dependencies (and any additional channels).
_conda_header = """\
name: mlflow-env
channels:
  - conda-forge
"""

# Default file names for environment-related files. `_REQUIREMENTS_FILE_NAME` is the default
# target of `_log_pip_requirements`; `_CONSTRAINTS_FILE_NAME` is referenced through a `-c` entry
# appended to the pip requirements whenever constraints are present.
_CONDA_ENV_FILE_NAME = "conda.yaml"
_REQUIREMENTS_FILE_NAME = "requirements.txt"
_CONSTRAINTS_FILE_NAME = "constraints.txt"


def _mlflow_conda_env(
    path=None,
    additional_conda_deps=None,
    additional_pip_deps=None,
    additional_conda_channels=None,
    install_mlflow=True,
):
    """
    Builds a Conda environment specification from the given channels and dependencies. Whenever
    at least one pip dependency is present (including ``mlflow`` added via ``install_mlflow``),
    ``pip`` itself is added to the conda dependencies so that the pip living inside the conda
    environment is the one used to install the pip dependencies.

    :param path: Local filesystem path where the conda env file is to be written. If unspecified,
                 nothing is written to the filesystem and the environment is returned in
                 dictionary format instead.
    :param additional_conda_deps: List of additional conda dependencies passed as strings.
    :param additional_pip_deps: List of additional pip dependencies passed as strings.
    :param additional_conda_channels: List of additional conda channels to search when resolving
                                      packages.
    :param install_mlflow: If ``True``, ``mlflow`` is listed first among the pip dependencies.
    :return: ``None`` if ``path`` is specified. Otherwise, a dictionary representation of the
             Conda environment.
    """
    pip_section = ["mlflow"] if install_mlflow else []
    pip_section = pip_section + (additional_pip_deps or [])

    # `pip` is only needed as a conda dependency when there is something for it to install
    conda_section = (additional_conda_deps or []) + (["pip"] if pip_section else [])

    env = yaml.safe_load(_conda_header)
    # The python pin comes first and the `{"pip": [...]}` entry always comes last
    env["dependencies"] = [f"python={PYTHON_VERSION}", *conda_section, {"pip": pip_section}]
    if additional_conda_channels is not None:
        env["channels"] += additional_conda_channels

    if path is None:
        return env

    with open(path, "w") as out:
        yaml.safe_dump(env, stream=out, default_flow_style=False)
    return None


def _mlflow_additional_pip_env(pip_deps, path=None):
    requirements = "\n".join(pip_deps)
    if path is not None:
        with open(path, "w") as out:
            out.write(requirements)
        return None
    else:
        return requirements


def _is_pip_deps(dep):
    """
    Returns True if `dep` is a dict representing pip dependencies
    """
    return isinstance(dep, dict) and "pip" in dep


def _get_pip_deps(conda_env):
    """
    :return: The pip dependencies from the conda env
    """
    if conda_env is not None:
        for dep in conda_env["dependencies"]:
            if _is_pip_deps(dep):
                return dep["pip"]
    return []


def _overwrite_pip_deps(conda_env, new_pip_deps):
    """
    Overwrites the pip dependencies section in the given conda env dictionary.

    {
        "name": "env",
        "channels": [...],
        "dependencies": [
            ...,
            "pip",
            {"pip": [...]},  <- Overwrite this
        ],
    }
    """
    deps = conda_env.get("dependencies", [])
    new_deps = []
    contains_pip_deps = False
    for dep in deps:
        if _is_pip_deps(dep):
            contains_pip_deps = True
            new_deps.append({"pip": new_pip_deps})
        else:
            new_deps.append(dep)

    if not contains_pip_deps:
        new_deps.append({"pip": new_pip_deps})

    return {**conda_env, "dependencies": new_deps}


def _log_pip_requirements(conda_env, path, requirements_file=_REQUIREMENTS_FILE_NAME):
    """
    Writes the pip dependencies of `conda_env` to a requirements file inside `path`.

    :param conda_env: A dictionary representation of a conda environment, or ``None``.
    :param path: Directory in which the requirements file is created.
    :param requirements_file: Name of the file to create inside ``path``.
    """
    destination = os.path.join(path, requirements_file)
    _mlflow_additional_pip_env(_get_pip_deps(conda_env), path=destination)


def _parse_pip_requirements(pip_requirements):
    """
    Parses an iterable of pip requirement strings or a pip requirements file.

    :param pip_requirements: Either an iterable of pip requirement strings
        (e.g. ``["scikit-learn", "-r requirements.txt"]``) or the string path to a pip requirements
        file on the local filesystem (e.g. ``"requirements.txt"``). If ``None``, an empty list will
        be returned.
    :return: A tuple of parsed requirements and constraints.
    """
    if pip_requirements is None:
        return [], []

    def _is_string(x):
        return isinstance(x, str)

    def _is_iterable(x):
        try:
            iter(x)
            return True
        except Exception:
            return False

    if _is_string(pip_requirements):
        with open(pip_requirements) as f:
            return _parse_pip_requirements(f.read().splitlines())
    elif _is_iterable(pip_requirements) and all(map(_is_string, pip_requirements)):
        requirements = []
        constraints = []
        for req_or_con in _parse_requirements(pip_requirements, is_constraint=False):
            if req_or_con.is_constraint:
                constraints.append(req_or_con.req_str)
            else:
                requirements.append(req_or_con.req_str)

        return requirements, constraints
    else:
        raise TypeError(
            "`pip_requirements` must be either a string path to a pip requirements file on the "
            "local filesystem or an iterable of pip requirement strings, but got `{}`".format(
                type(pip_requirements)
            )
        )


# %-style template for the warning that `infer_pip_requirements` logs when inference fails and a
# fallback value is returned. The placeholders are, in order: model URI, flavor name, and the
# fallback value.
_INFER_PIP_REQUIREMENTS_FALLBACK_MESSAGE = (
    "Encountered an unexpected error while inferring pip requirements (model URI: %s, flavor: %s),"
    " fall back to return %s. Set logging level to DEBUG to see the full traceback."
)


def infer_pip_requirements(model_uri, flavor, fallback=None):
    """
    Infers the pip requirements of the specified model by creating a subprocess and loading
    the model in it to determine which packages are imported.

    :param model_uri: The URI of the model.
    :param flavor: The flavor name of the model.
    :param fallback: If provided, an unexpected error during the inference procedure is swallowed
                     and the value of ``fallback`` is returned. Otherwise, the error is raised.
    :return: A list of inferred pip requirements (e.g. ``["scikit-learn==0.24.2", ...]``).
    """
    try:
        return _infer_requirements(model_uri, flavor)
    except Exception:
        if fallback is not None:
            # Best-effort mode: warn briefly and keep the full traceback at DEBUG level
            _logger.warning(_INFER_PIP_REQUIREMENTS_FALLBACK_MESSAGE, model_uri, flavor, fallback)
            _logger.debug("", exc_info=True)
            return fallback
        raise
def _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements): """ Validates that only one or none of `conda_env`, `pip_requirements`, and `extra_pip_requirements` is specified. """ args = [ conda_env, pip_requirements, extra_pip_requirements, ] specified = [arg for arg in args if arg is not None] if len(specified) > 1: raise ValueError( "Only one of `conda_env`, `pip_requirements`, and " "`extra_pip_requirements` can be specified" ) # PIP requirement parser inspired from https://github.com/pypa/pip/blob/b392833a0f1cff1bbee1ac6dbe0270cccdd0c11f/src/pip/_internal/req/req_file.py#L400 def _get_pip_requirement_specifier(requirement_string): tokens = requirement_string.split(" ") for idx, token in enumerate(tokens): if token.startswith("-"): return " ".join(tokens[:idx]) return requirement_string def _is_mlflow_requirement(requirement_string): """ Returns True if `requirement_string` represents a requirement for mlflow (e.g. 'mlflow==1.2.3'). """ try: # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` doesn't # conform to PEP 508 (https://www.python.org/dev/peps/pep-0508). return Requirement(requirement_string).name.lower() == "mlflow" except InvalidRequirement: # A local file path or URL falls into this branch. # `Requirement` throws an `InvalidRequirement` exception if `requirement_string` contains # per-requirement options (ex: package hashes) # GitHub issue: https://github.com/pypa/packaging/issues/488 # Per-requirement-option spec: https://pip.pypa.io/en/stable/reference/requirements-file-format/#per-requirement-options requirement_specifier = _get_pip_requirement_specifier(requirement_string) try: # Try again with the per-requirement options removed return Requirement(requirement_specifier).name.lower() == "mlflow" except InvalidRequirement: return False # TODO: Return True if `requirement_string` represents a project directory for MLflow # (e.g. '/path/to/mlflow') or git repository URL (e.g. 
'https://github.com/mlflow/mlflow'). return False def _contains_mlflow_requirement(requirements): """ Returns True if `requirements` contains a requirement for mlflow (e.g. 'mlflow==1.2.3'). """ return any(map(_is_mlflow_requirement, requirements)) def _process_pip_requirements( default_pip_requirements, pip_requirements=None, extra_pip_requirements=None ): """ Processes `pip_requirements` and `extra_pip_requirements` passed to `mlflow.*.save_model` or `mlflow.*.log_model`, and returns a tuple of (conda_env, pip_requirements, pip_constraints). """ constraints = [] if pip_requirements is not None: pip_reqs, constraints = _parse_pip_requirements(pip_requirements) elif extra_pip_requirements is not None: extra_pip_requirements, constraints = _parse_pip_requirements(extra_pip_requirements) pip_reqs = default_pip_requirements + extra_pip_requirements else: pip_reqs = default_pip_requirements if not _contains_mlflow_requirement(pip_reqs): pip_reqs.insert(0, "mlflow") if constraints: pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}") # Set `install_mlflow` to False because `pip_reqs` already contains `mlflow` conda_env = _mlflow_conda_env(additional_pip_deps=pip_reqs, install_mlflow=False) return conda_env, pip_reqs, constraints def _process_conda_env(conda_env): """ Processes `conda_env` passed to `mlflow.*.save_model` or `mlflow.*.log_model`, and returns a tuple of (conda_env, pip_requirements, pip_constraints). 
""" if isinstance(conda_env, str): with open(conda_env, "r") as f: conda_env = yaml.safe_load(f) elif not isinstance(conda_env, dict): raise TypeError( "Expected a string path to a conda env yaml file or a `dict` representing a conda env, " "but got `{}`".format(type(conda_env).__name__) ) # User-specified `conda_env` may contain requirements/constraints file references pip_reqs = _get_pip_deps(conda_env) pip_reqs, constraints = _parse_pip_requirements(pip_reqs) if not _contains_mlflow_requirement(pip_reqs): pip_reqs.insert(0, "mlflow") if constraints: pip_reqs.append(f"-c {_CONSTRAINTS_FILE_NAME}") conda_env = _overwrite_pip_deps(conda_env, pip_reqs) return conda_env, pip_reqs, constraints