# Source code for mlflow.xgboost

"""
The ``mlflow.xgboost`` module provides an API for logging and loading XGBoost models.
This module exports XGBoost models with the following flavors:

XGBoost (native) format
    This is the main flavor that can be loaded back into XGBoost.
:py:mod:`mlflow.pyfunc`
    Produced for use by generic pyfunc-based deployment tools and batch inference.

.. _xgboost.Booster:
    https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster
.. _xgboost.Booster.save_model:
    https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.save_model
.. _xgboost.train:
    https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.train
.. _scikit-learn API:
    https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn
"""
import os
import shutil
import json
import yaml
import tempfile
import logging
from copy import deepcopy

import mlflow
from mlflow import pyfunc
from mlflow.models import Model, ModelInputExample, infer_signature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import _save_example
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils import _get_fully_qualified_class_name
from mlflow.utils.environment import (
    _mlflow_conda_env,
    _validate_env_arguments,
    _process_pip_requirements,
    _process_conda_env,
    _CONDA_ENV_FILE_NAME,
    _REQUIREMENTS_FILE_NAME,
    _CONSTRAINTS_FILE_NAME,
    _PYTHON_ENV_FILE_NAME,
    _PythonEnv,
)
from mlflow.utils.class_utils import _get_class_from_string
from mlflow.utils.requirements_utils import _get_pinned_requirement
from mlflow.utils.file_utils import write_to
from mlflow.utils.model_utils import (
    _get_flavor_configuration,
    _validate_and_copy_code_paths,
    _add_code_from_conf_to_system_path,
    _validate_and_prepare_target_save_path,
)
from mlflow.utils.docstring_utils import format_docstring, LOG_MODEL_PARAM_DOCS
from mlflow.utils.arguments_utils import _get_arg_names
from mlflow.utils.autologging_utils import (
    autologging_integration,
    safe_patch,
    picklable_exception_safe_function,
    get_mlflow_run_params_for_fn_args,
    INPUT_EXAMPLE_SAMPLE_ROWS,
    resolve_input_example_and_signature,
    InputExampleInfo,
    ENSURE_AUTOLOGGING_ENABLED_TEXT,
    batch_metrics_logger,
    MlflowAutologgingQueueingClient,
    get_autologging_config,
)

from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS

# Name of this flavor. It is the key passed to `mlflow_model.add_flavor(...)` when saving,
# the key looked up via `_get_flavor_configuration(...)` when loading, and the integration
# name handed to `autologging_integration` / `safe_patch` for autologging.
FLAVOR_NAME = "xgboost"

# Module-level logger used by the autologging code paths in this module.
_logger = logging.getLogger(__name__)


def get_default_pip_requirements():
    """
    :return: A list of default pip requirements for MLflow Models produced by this flavor.
             Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
             that, at minimum, contains these requirements.
    """
    # The only baseline dependency of this flavor is xgboost itself, pinned by the helper.
    pinned_xgboost = _get_pinned_requirement("xgboost")
    return [pinned_xgboost]
def get_default_conda_env():
    """
    :return: The default Conda environment for MLflow Models produced by calls to
             :func:`save_model()` and :func:`log_model()`.
    """
    # The conda environment is built from the flavor's default pip requirements.
    pip_deps = get_default_pip_requirements()
    return _mlflow_conda_env(additional_pip_deps=pip_deps)
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
    xgb_model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
):
    """
    Save an XGBoost model to a path on the local file system.

    :param xgb_model: XGBoost model (an instance of `xgboost.Booster`_ or models that
                      implement the `scikit-learn API`_) to be saved.
    :param path: Local path where the model is to be saved.
    :param conda_env: {{ conda_env }}
    :param code_paths: A list of local filesystem paths to Python file dependencies (or directories
                       containing file dependencies). These files are *prepended* to the system
                       path when the model is loaded.
    :param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.

    :param signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
                      describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
                      The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
                      from datasets with valid model input (e.g. the training dataset with target
                      column omitted) and valid model output (e.g. model predictions generated on
                      the training dataset), for example:

                      .. code-block:: python

                        from mlflow.models.signature import infer_signature
                        train = df.drop_column("target_label")
                        predictions = ... # compute model predictions
                        signature = infer_signature(train, predictions)
    :param input_example: Input example provides one or several instances of valid
                          model input. The example can be used as a hint of what data to feed the
                          model. The given example will be converted to a Pandas DataFrame and then
                          serialized to json using the Pandas split-oriented format. Bytes are
                          base64-encoded.
    :param pip_requirements: {{ pip_requirements }}
    :param extra_pip_requirements: {{ extra_pip_requirements }}
    """
    import xgboost as xgb

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    # Prepare the target directory and copy any user code next to the model.
    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)
    code_subpath = _validate_and_copy_code_paths(code_paths, path)

    # Populate the MLflow model metadata.
    if mlflow_model is None:
        mlflow_model = Model()
    if signature is not None:
        mlflow_model.signature = signature
    if input_example is not None:
        _save_example(mlflow_model, input_example, path)

    # Serialize the XGBoost model itself.
    data_subpath = "model.xgb"
    xgb_model.save_model(os.path.join(path, data_subpath))
    model_class_name = _get_fully_qualified_class_name(xgb_model)

    # Register both the pyfunc flavor and the native xgboost flavor.
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.xgboost",
        data=data_subpath,
        env=_CONDA_ENV_FILE_NAME,
        code=code_subpath,
    )
    mlflow_model.add_flavor(
        FLAVOR_NAME,
        xgb_version=xgb.__version__,
        data=data_subpath,
        model_class=model_class_name,
        code=code_subpath,
    )
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if conda_env is not None:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
    else:
        baseline_reqs = None
        if pip_requirements is None:
            fallback_reqs = get_default_pip_requirements()
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path,
                FLAVOR_NAME,
                fallback=fallback_reqs,
            )
            baseline_reqs = sorted(set(inferred_reqs).union(fallback_reqs))
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            baseline_reqs,
            pip_requirements,
            extra_pip_requirements,
        )

    conda_env_path = os.path.join(path, _CONDA_ENV_FILE_NAME)
    with open(conda_env_path, "w") as env_file:
        yaml.safe_dump(conda_env, stream=env_file, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
    xgb_model,
    artifact_path,
    conda_env=None,
    code_paths=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    **kwargs,
):
    """
    Log an XGBoost model as an MLflow artifact for the current run.

    :param xgb_model: XGBoost model (an instance of `xgboost.Booster`_ or models that
                      implement the `scikit-learn API`_) to be saved.
    :param artifact_path: Run-relative artifact path.
    :param conda_env: {{ conda_env }}
    :param code_paths: A list of local filesystem paths to Python file dependencies (or directories
                       containing file dependencies). These files are *prepended* to the system
                       path when the model is loaded.
    :param registered_model_name: If given, create a model version under
                                  ``registered_model_name``, also creating a registered model if one
                                  with the given name does not exist.

    :param signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
                      describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
                      The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
                      from datasets with valid model input (e.g. the training dataset with target
                      column omitted) and valid model output (e.g. model predictions generated on
                      the training dataset), for example:

                      .. code-block:: python

                        from mlflow.models.signature import infer_signature
                        train = df.drop_column("target_label")
                        predictions = ... # compute model predictions
                        signature = infer_signature(train, predictions)
    :param input_example: Input example provides one or several instances of valid
                          model input. The example can be used as a hint of what data to feed the
                          model. The given example will be converted to a Pandas DataFrame and then
                          serialized to json using the Pandas split-oriented format. Bytes are
                          base64-encoded.
    :param await_registration_for: Number of seconds to wait for the model version to finish
                            being created and is in ``READY`` status. By default, the function
                            waits for five minutes. Specify 0 or None to skip waiting.
    :param pip_requirements: {{ pip_requirements }}
    :param extra_pip_requirements: {{ extra_pip_requirements }}
    :param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
    :return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
             metadata of the logged model.
    """
    # Arguments that describe the model and its environment, as opposed to where it is logged.
    flavor_args = dict(
        xgb_model=xgb_model,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
    )
    return Model.log(
        artifact_path=artifact_path,
        flavor=mlflow.xgboost,
        registered_model_name=registered_model_name,
        await_registration_for=await_registration_for,
        **flavor_args,
        **kwargs,
    )
def _load_model(path):
    """
    Load Model Implementation.

    :param path: Local filesystem path to
                 the MLflow Model with the ``xgboost`` flavor (MLflow < 1.22.0) or
                 the top-level MLflow Model directory (MLflow >= 1.22.0).
    """
    # A file path means the caller pointed at the model data file; the flavor configuration
    # is looked up in the directory that contains it.
    if os.path.isfile(path):
        model_dir = os.path.dirname(path)
    else:
        model_dir = path
    flavor_conf = _get_flavor_configuration(model_path=model_dir, flavor_name=FLAVOR_NAME)

    # XGBoost models saved in MLflow >=1.22.0 have `model_class`
    # in the XGBoost flavor configuration to specify its XGBoost model class.
    # When loading models, we first get the XGBoost model from
    # its flavor configuration and then create an instance based on its class.
    class_name = flavor_conf.get("model_class", "xgboost.core.Booster")
    data_path = os.path.join(model_dir, flavor_conf.get("data"))

    model_cls = _get_class_from_string(class_name)
    loaded = model_cls()
    loaded.load_model(data_path)
    return loaded


def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_model``.

    :param path: Local filesystem path to the MLflow Model with the ``xgboost`` flavor.
    """
    native_model = _load_model(path)
    return _XGBModelWrapper(native_model)
def load_model(model_uri, dst_path=None):
    """
    Load an XGBoost model from a local file or a run.

    :param model_uri: The location, in URI format, of the MLflow model. For example:

                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``
                      - ``s3://my_bucket/path/to/model``
                      - ``runs:/<mlflow_run_id>/run-relative/path/to/model``

                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
                      artifact-locations>`_.
    :param dst_path: The local filesystem path to which to download the model artifact.
                     This directory must already exist. If unspecified, a local output
                     path will be created.

    :return: An XGBoost model. An instance of either `xgboost.Booster`_ or XGBoost scikit-learn
             models, depending on the saved model class specification.
    """
    model_root = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    # Make any code saved alongside the model importable before the model is instantiated.
    _add_code_from_conf_to_system_path(
        model_root, _get_flavor_configuration(model_root, FLAVOR_NAME)
    )
    return _load_model(path=model_root)
class _XGBModelWrapper: def __init__(self, xgb_model): self.xgb_model = xgb_model def predict(self, dataframe): import xgboost as xgb if isinstance(self.xgb_model, xgb.Booster): return self.xgb_model.predict(xgb.DMatrix(dataframe)) else: return self.xgb_model.predict(dataframe)
@autologging_integration(FLAVOR_NAME)
def autolog(
    importance_types=None,
    log_input_examples=False,
    log_model_signatures=True,
    log_models=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    registered_model_name=None,
):  # pylint: disable=W0102,unused-argument
    """
    Enables (or disables) and configures autologging from XGBoost to MLflow. Logs the following:

    - parameters specified in `xgboost.train`_.
    - metrics on each iteration (if ``evals`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance as JSON files and plots.
    - trained model, including:
        - an example of valid input.
        - inferred signature of the inputs and outputs of the model.

    Note that the `scikit-learn API`_ is now supported.

    :param importance_types: Importance types to log. If unspecified, defaults to ``["weight"]``.
    :param log_input_examples: If ``True``, input examples from training datasets are collected and
                               logged along with XGBoost model artifacts during training. If
                               ``False``, input examples are not logged.
                               Note: Input examples are MLflow model attributes
                               and are only collected if ``log_models`` is also ``True``.
    :param log_model_signatures: If ``True``,
                                 :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
                                 describing model inputs and outputs are collected and logged along
                                 with XGBoost model artifacts during training. If ``False``,
                                 signatures are not logged.
                                 Note: Model signatures are MLflow model attributes
                                 and are only collected if ``log_models`` is also ``True``.
    :param log_models: If ``True``, trained models are logged as MLflow model artifacts.
                       If ``False``, trained models are not logged.
                       Input examples and model signatures, which are attributes of MLflow models,
                       are also omitted when ``log_models`` is ``False``.
    :param disable: If ``True``, disables the XGBoost autologging integration. If ``False``,
                    enables the XGBoost autologging integration.
    :param exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
                      If ``False``, autologged content is logged to the active fluent run,
                      which may be user-created.
    :param disable_for_unsupported_versions: If ``True``, disable autologging for versions of
                      xgboost that have not been tested against this version of the MLflow client
                      or are incompatible.
    :param silent: If ``True``, suppress all event logs and warnings from MLflow during XGBoost
                   autologging. If ``False``, show all events and warnings during XGBoost
                   autologging.
    :param registered_model_name: If given, each time a model is trained, it is registered as a
                                  new model version of the registered model with this name.
                                  The registered model is created if it does not already exist.
    """
    import functools
    import xgboost
    import numpy as np

    if importance_types is None:
        importance_types = ["weight"]

    # Patching this function so we can get a copy of the data given to DMatrix.__init__
    # to use as an input example and for inferring the model signature.
    # (there is no way to get the data back from a DMatrix object)
    # We store it on the DMatrix object so the train function is able to read it.
    def __init__(original, self, *args, **kwargs):
        data = args[0] if len(args) > 0 else kwargs.get("data")

        if data is not None:
            try:
                if isinstance(data, str):
                    raise Exception(
                        "cannot gather example input when dataset is loaded from a file."
                    )

                input_example_info = InputExampleInfo(
                    input_example=deepcopy(data[:INPUT_EXAMPLE_SAMPLE_ROWS])
                )
            except Exception as e:
                # Any failure is recorded rather than raised, so building the DMatrix itself
                # is never blocked by example collection.
                input_example_info = InputExampleInfo(error_msg=str(e))

            setattr(self, "input_example_info", input_example_info)

        original(self, *args, **kwargs)

    def train(_log_models, original, *args, **kwargs):
        """
        Patched replacement for ``train``: logs params, per-iteration metrics, early-stopping
        metrics, feature importances and (when ``_log_models`` is truthy) the trained model
        around a call to ``original``.
        """

        def record_eval_results(eval_results, metrics_logger):
            """
            Create a callback function that records evaluation results.
            """
            # TODO: Remove `replace("SNAPSHOT", "dev")` once the following issue is addressed:
            #       https://github.com/dmlc/xgboost/issues/6984
            # NOTE(review): no `replace(...)` call is visible in this function; presumably it
            # lives where `IS_TRAINING_CALLBACK_SUPPORTED` is computed -- confirm.
            from mlflow.xgboost._autolog import IS_TRAINING_CALLBACK_SUPPORTED

            if IS_TRAINING_CALLBACK_SUPPORTED:
                from mlflow.xgboost._autolog import AutologCallback

                # In xgboost >= 1.3.0, user-defined callbacks should inherit
                # `xgboost.callback.TrainingCallback`:
                # https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback
                return AutologCallback(metrics_logger, eval_results)
            else:
                from mlflow.xgboost._autolog import autolog_callback

                return picklable_exception_safe_function(
                    functools.partial(
                        autolog_callback, metrics_logger=metrics_logger, eval_results=eval_results
                    )
                )

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.
            """
            import matplotlib.pyplot as plt
            from cycler import cycler

            features = np.array(features)

            # Structure the supplied `importance` values as a `num_features`-by-`num_classes` matrix
            importances_per_class_by_feature = np.array(importance)
            if importances_per_class_by_feature.ndim <= 1:
                # In this case, the supplied `importance` values are not given per class. Rather,
                # one importance value is given per feature. For consistency with the assumed
                # `num_features`-by-`num_classes` matrix structure, we coerce the importance
                # values to a `num_features`-by-1 matrix
                indices = np.argsort(importance)
                # Sort features and importance values by magnitude during transformation to a
                # `num_features`-by-`num_classes` matrix
                features = features[indices]
                importances_per_class_by_feature = np.array(
                    [[importance] for importance in importances_per_class_by_feature[indices]]
                )
                # In this case, do not include class labels on the feature importance plot because
                # only one importance value has been provided per feature, rather than an
                # one importance value for each class per feature
                label_classes_on_plot = False
            else:
                importance_value_magnitudes = np.abs(importances_per_class_by_feature).sum(axis=1)
                indices = np.argsort(importance_value_magnitudes)
                features = features[indices]
                importances_per_class_by_feature = importances_per_class_by_feature[indices]
                label_classes_on_plot = True

            num_classes = importances_per_class_by_feature.shape[1]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            h = h + 0.1 * num_classes if num_classes > 1 else h
            fig, ax = plt.subplots(figsize=(w, h))
            # Everything after figure creation is guarded so the figure is always closed,
            # even when drawing, saving or artifact logging fails.
            try:
                # When importance values are provided for each class per feature, we want to
                # ensure that the same color is used for all bars in the bar chart that have
                # the same class
                colors_to_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"][:num_classes]
                color_cycler = cycler(color=colors_to_cycle)
                ax.set_prop_cycle(color_cycler)

                # The following logic operates on one feature at a time, adding a bar to the bar
                # chart for each class that reflects the importance of the feature to predictions
                # of that class
                feature_ylocs = np.arange(num_features)
                # Define offsets on the y-axis that are used to evenly space the bars for each
                # class around the y-axis position of each feature
                offsets_per_yloc = (
                    np.linspace(-0.5, 0.5, num_classes) / 2 if num_classes > 1 else [0]
                )
                for feature_idx, (feature_yloc, importances_per_class) in enumerate(
                    zip(feature_ylocs, importances_per_class_by_feature)
                ):
                    for class_idx, (offset, class_importance) in enumerate(
                        zip(offsets_per_yloc, importances_per_class)
                    ):
                        (bar,) = ax.barh(
                            feature_yloc + offset,
                            class_importance,
                            align="center",
                            # Set the bar height such that importance value bars for a particular
                            # feature are spaced properly relative to each other (no overlap or
                            # gaps) and relative to importance value bars for other features
                            height=(0.5 / max(num_classes - 1, 1)),
                        )
                        if label_classes_on_plot and feature_idx == 0:
                            # Only set a label the first time a bar for a particular class is
                            # plotted to avoid duplicate legend entries. If we were to set a
                            # label for every bar, the legend would contain `num_features`
                            # labels for each class.
                            bar.set_label("Class {}".format(class_idx))

                ax.set_yticks(feature_ylocs)
                ax.set_yticklabels(features)
                ax.set_xlabel("Importance")
                ax.set_title("Feature Importance ({})".format(importance_type))
                if label_classes_on_plot:
                    ax.legend()
                fig.tight_layout()

                with tempfile.TemporaryDirectory() as tmpdir:
                    # Name the file after this function's own `importance_type` argument
                    # rather than a variable captured from the caller's loop.
                    filepath = os.path.join(
                        tmpdir, "feature_importance_{}.png".format(importance_type)
                    )
                    fig.savefig(filepath)
                    mlflow.log_artifact(filepath)
            finally:
                plt.close(fig)

        autologging_client = MlflowAutologgingQueueingClient()
        run_id = mlflow.active_run().info.run_id

        # logging booster params separately to extract key/value pairs and make it easier to
        # compare them across runs.
        booster_params = args[0] if len(args) > 0 else kwargs["params"]
        autologging_client.log_params(run_id=run_id, params=booster_params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]

        params_to_log_for_fn = get_mlflow_run_params_for_fn_args(
            original, args, kwargs, unlogged_params
        )
        autologging_client.log_params(run_id=run_id, params=params_to_log_for_fn)

        param_logging_operations = autologging_client.flush(synchronous=False)

        all_arg_names = _get_arg_names(original)
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index("callbacks")

        with batch_metrics_logger(run_id) as metrics_logger:
            callback = record_eval_results(eval_results, metrics_logger)
            # Always build a NEW callbacks list: extending the caller's list in place would
            # leak the autolog callback into later `train` calls that reuse the same list,
            # and would fail for a `None` or tuple value.
            if num_pos_args > callbacks_index:
                positional = list(args)
                positional[callbacks_index] = list(positional[callbacks_index] or []) + [
                    callback
                ]
                args = tuple(positional)
            else:
                kwargs["callbacks"] = list(kwargs.get("callbacks") or []) + [callback]

            # training model
            model = original(*args, **kwargs)

            # If early_stopping_rounds is present, logging metrics at the best iteration
            # as extra metrics with the max step + 1.
            early_stopping_index = all_arg_names.index("early_stopping_rounds")
            # Look at the value that was actually supplied, so that a positional `None`
            # is not mistaken for "early stopping enabled".
            if num_pos_args > early_stopping_index:
                early_stopping = bool(args[early_stopping_index])
            else:
                early_stopping = bool(kwargs.get("early_stopping_rounds"))
            if early_stopping:
                extra_step = len(eval_results)
                autologging_client.log_metrics(
                    run_id=run_id,
                    metrics={
                        "stopped_iteration": extra_step - 1,
                        "best_iteration": model.best_iteration,
                    },
                )
                autologging_client.log_metrics(
                    run_id=run_id,
                    metrics=eval_results[model.best_iteration],
                    step=extra_step,
                )
                early_stopping_logging_operations = autologging_client.flush(synchronous=False)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = None
            try:
                imp = model.get_score(importance_type=imp_type)
                features, importance = zip(*imp.items())
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:
                # Deliberately best-effort: a plotting failure must not fail training.
                _logger.exception(
                    "Failed to log feature importance plot. XGBoost autologging "
                    "will ignore the failure and continue. Exception: "
                )

            if imp is not None:
                with tempfile.TemporaryDirectory() as tmpdir:
                    filepath = os.path.join(tmpdir, "feature_importance_{}.json".format(imp_type))
                    with open(filepath, "w") as f:
                        json.dump(imp, f)
                    mlflow.log_artifact(filepath)

        # dtrain must exist as the original train function already ran successfully
        dtrain = args[1] if len(args) > 1 else kwargs.get("dtrain")

        # it is possible that the dataset was constructed before the patched
        #   constructor was applied, so we cannot assume the input_example_info exists
        input_example_info = getattr(dtrain, "input_example_info", None)

        def get_input_example():
            if input_example_info is None:
                raise Exception(ENSURE_AUTOLOGGING_ENABLED_TEXT)
            if input_example_info.error_msg is not None:
                raise Exception(input_example_info.error_msg)
            return input_example_info.input_example

        def infer_model_signature(input_example):
            model_output = model.predict(xgboost.DMatrix(input_example))
            model_signature = infer_signature(input_example, model_output)
            return model_signature

        # Only log the model if the autolog() param log_models is set to True.
        if _log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )

            registered_model_name = get_autologging_config(
                FLAVOR_NAME, "registered_model_name", None
            )
            log_model(
                model,
                artifact_path="model",
                signature=signature,
                input_example=input_example,
                registered_model_name=registered_model_name,
            )

        param_logging_operations.await_completion()
        if early_stopping:
            early_stopping_logging_operations.await_completion()

        return model

    safe_patch(FLAVOR_NAME, xgboost, "train", functools.partial(train, log_models), manage_run=True)
    # The `train()` method logs XGBoost models as Booster objects. When using XGBoost
    # scikit-learn models, we want to save / log models as their model classes. So we turn
    # off the log_models functionality in the `train()` method patched to `xgboost.sklearn`.
    # Instead the model logging is handled in `fit_mlflow_sklearn()` in `mlflow.sklearn._autolog()`,
    # where models are logged as XGBoost scikit-learn models after the `fit()` method returns.
    safe_patch(
        FLAVOR_NAME, xgboost.sklearn, "train", functools.partial(train, False), manage_run=True
    )
    safe_patch(FLAVOR_NAME, xgboost.DMatrix, "__init__", __init__)

    # enable xgboost scikit-learn estimators autologging
    import mlflow.sklearn

    mlflow.sklearn._autolog(
        flavor_name=FLAVOR_NAME,
        log_input_examples=log_input_examples,
        log_model_signatures=log_model_signatures,
        log_models=log_models,
        disable=disable,
        exclusive=exclusive,
        disable_for_unsupported_versions=disable_for_unsupported_versions,
        silent=silent,
        max_tuning_runs=None,
        log_post_training_metrics=True,
    )