Source code for mlflow.shap

from contextlib import contextmanager
import os
import tempfile

import numpy as np

import mlflow
from mlflow.utils.annotations import experimental
from mlflow.utils.uri import append_to_uri_path


# Upper bound on the number of background rows requested from ``shap.kmeans``
# in ``log_explanation``.
_MAXIMUM_BACKGROUND_DATA_SIZE = 100
# Artifact path used by ``log_explanation`` when the caller passes ``artifact_path=None``.
_DEFAULT_ARTIFACT_PATH = "model_explanations_shap"
# File names of the individual artifacts logged by ``log_explanation``.
_SUMMARY_BAR_PLOT_FILE_NAME = "summary_bar_plot.png"
_BASE_VALUES_FILE_NAME = "base_values.npy"
_SHAP_VALUES_FILE_NAME = "shap_values.npy"


@contextmanager
def _log_artifact_contextmanager(out_file, artifact_path=None):
    """
    Yield a path inside a throwaway directory and log the file found there afterwards.

    The caller is expected to write ``out_file`` at the yielded path inside the ``with`` body.
    When the body finishes without raising, that file is passed to ``mlflow.log_artifact``.
    The temporary directory is removed when the context exits.

    :param out_file: Name of the file to create inside the temporary directory.
    :param artifact_path: Forwarded unchanged as the second argument of ``mlflow.log_artifact``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        staged_file = os.path.join(scratch_dir, out_file)
        # Control returns here only after the caller's ``with`` body has run.
        yield staged_file
        # Not reached if the body raised: the exception is re-raised at the ``yield`` above.
        mlflow.log_artifact(staged_file, artifact_path)


def _log_numpy(numpy_obj, out_file, artifact_path=None):
    """
    Save a numpy object in ``.npy`` format under the name ``out_file`` and log it as an artifact.

    :param numpy_obj: Object to serialize with ``numpy.save``.
    :param out_file: File name of the logged artifact. Used exactly as given.
    :param artifact_path: Forwarded unchanged to ``_log_artifact_contextmanager``.
    """
    with _log_artifact_contextmanager(out_file, artifact_path) as tmp_path:
        # ``np.save`` appends ".npy" to a *path* argument that lacks the suffix, which would
        # leave ``tmp_path`` itself nonexistent when it gets logged. Saving through an open
        # file handle keeps the file name exactly as requested.
        with open(tmp_path, "wb") as f:
            np.save(f, numpy_obj)


def _log_matplotlib_figure(fig, out_file, artifact_path=None):
    """
    Write a matplotlib figure to ``out_file`` and log the resulting file as an artifact.

    :param fig: Figure object; its ``savefig`` method is called with the temporary file path.
    :param out_file: File name of the logged artifact.
    :param artifact_path: Forwarded unchanged to ``_log_artifact_contextmanager``.
    """
    with _log_artifact_contextmanager(out_file, artifact_path) as figure_path:
        fig.savefig(figure_path)


@experimental
def log_explanation(predict_function, features, artifact_path=None):
    r"""
    Given a ``predict_function`` capable of computing ML model output on the provided ``features``,
    computes and logs explanations of an ML model's output. Explanations are logged as a directory
    of artifacts containing the following items generated by `SHAP`_ (SHapley Additive
    exPlanations).

        - Base values
        - SHAP values (computed using `shap.KernelExplainer`_)
        - Summary bar plot (shows the average impact of each feature on model output)

    :param predict_function:
        A function to compute the output of a model (e.g. ``predict_proba`` method of
        scikit-learn classifiers). Must have the following signature:

        .. code-block:: python

            def predict_function(X) -> pred:
                ...

        - ``X``: An array-like object whose shape should be (# samples, # features).
        - ``pred``: An array-like object whose shape should be (# samples) for a regressor or
          (# samples, # classes) for a classifier. For a classifier, the values in ``pred``
          should correspond to the predicted probability of each class.

        Acceptable array-like object types:

            - ``numpy.array``
            - ``pandas.DataFrame``
            - ``shap.common.DenseData``
            - ``scipy.sparse matrix``

    :param features:
        A matrix of features to compute SHAP values with. The provided features should
        have shape (# samples, # features), and can be either of the array-like object
        types listed above.

        .. note::
            Background data for `shap.KernelExplainer`_ is generated by subsampling ``features``
            with `shap.kmeans`_. The background data size is limited to 100 rows for performance
            reasons.

    :param artifact_path:
        The run-relative artifact path to which the explanation is saved.
        If unspecified, defaults to "model_explanations_shap".

    :return: Artifact URI of the logged explanations.

    .. _SHAP: https://github.com/slundberg/shap

    .. _shap.KernelExplainer: https://shap.readthedocs.io/en/latest/generated
        /shap.KernelExplainer.html#shap.KernelExplainer

    .. _shap.kmeans: https://github.com/slundberg/shap/blob/v0.36.0/shap/utils/_legacy.py#L9

    .. code-block:: python
        :caption: Example

        import os

        import numpy as np
        import pandas as pd
        from sklearn.datasets import load_boston
        from sklearn.linear_model import LinearRegression

        import mlflow

        # prepare training data
        dataset = load_boston()
        X = pd.DataFrame(dataset.data[:50, :8], columns=dataset.feature_names[:8])
        y = dataset.target[:50]

        # train a model
        model = LinearRegression()
        model.fit(X, y)

        # log an explanation
        with mlflow.start_run() as run:
            mlflow.shap.log_explanation(model.predict, X)

        # list artifacts
        client = mlflow.tracking.MlflowClient()
        artifact_path = "model_explanations_shap"
        artifacts = [x.path for x in client.list_artifacts(run.info.run_id, artifact_path)]
        print("# artifacts:")
        print(artifacts)

        # load back the logged explanation
        dst_path = client.download_artifacts(run.info.run_id, artifact_path)
        base_values = np.load(os.path.join(dst_path, "base_values.npy"))
        shap_values = np.load(os.path.join(dst_path, "shap_values.npy"))

        print("\n# base_values:")
        print(base_values)
        print("\n# shap_values:")
        print(shap_values[:3])

    .. code-block:: text
        :caption: Output

        # artifacts:
        ['model_explanations_shap/base_values.npy',
         'model_explanations_shap/shap_values.npy',
         'model_explanations_shap/summary_bar_plot.png']

        # base_values:
        20.502000000000002

        # shap_values:
        [[ 2.09975523  0.4746513   7.63759026  0.        ]
         [ 2.00883109 -0.18816665 -0.14419184  0.        ]
         [ 2.00891772 -0.18816665 -0.14419184  0.        ]]

    .. figure:: ../_static/images/shap-ui-screenshot.png

        Logged artifacts
    """
    # Imported lazily so that importing this module requires neither matplotlib nor shap.
    import matplotlib.pyplot as plt
    import shap

    artifact_path = _DEFAULT_ARTIFACT_PATH if artifact_path is None else artifact_path

    # ``len()`` raises ``TypeError`` on scipy sparse matrices, so take the row count from
    # ``shape`` when the input exposes one and fall back to ``len()`` otherwise.
    # NOTE(review): whether the shap calls below accept every listed input type is up to shap.
    num_rows = features.shape[0] if hasattr(features, "shape") else len(features)

    # Summarize ``features`` into at most ``_MAXIMUM_BACKGROUND_DATA_SIZE`` background rows.
    background_data = shap.kmeans(features, min(_MAXIMUM_BACKGROUND_DATA_SIZE, num_rows))
    explainer = shap.KernelExplainer(predict_function, background_data)
    shap_values = explainer.shap_values(features)

    _log_numpy(explainer.expected_value, _BASE_VALUES_FILE_NAME, artifact_path)
    _log_numpy(shap_values, _SHAP_VALUES_FILE_NAME, artifact_path)

    # ``show=False`` leaves the plot on the current pyplot figure so it can be saved below.
    shap.summary_plot(shap_values, features, plot_type="bar", show=False)
    fig = plt.gcf()
    try:
        fig.tight_layout()
        _log_matplotlib_figure(fig, _SUMMARY_BAR_PLOT_FILE_NAME, artifact_path)
    finally:
        # Close the figure even if layout or logging fails, so it is not leaked in pyplot.
        plt.close(fig)

    return append_to_uri_path(mlflow.active_run().info.artifact_uri, artifact_path)