mlflow.shap.log_explanation(predict_function, features, artifact_path=None)


Experimental: This method may change or be removed in a future release without warning.

Given a predict_function capable of computing ML model output on the provided features, computes and logs explanations of an ML model’s output. Explanations are logged as a directory of artifacts containing the following items generated by SHAP (SHapley Additive exPlanations).

  • Base values

  • SHAP values (computed using shap.KernelExplainer)

  • Summary bar plot (shows the average impact of each feature on model output)

Parameters:

  • predict_function

    A function that computes the output of a model (e.g. the predict_proba method of scikit-learn classifiers). It must have the following signature:

    def predict_function(X) -> pred:
    • X: An array-like object whose shape should be (# samples, # features).

    • pred: An array-like object whose shape should be (# samples) for a regressor or (# samples, # classes) for a classifier. For a classifier, the values in pred should correspond to the predicted probability of each class.

    Acceptable array-like object types:

    • numpy.array

    • pandas.DataFrame

    • shap.common.DenseData

    • scipy.sparse matrix

  • features

    A matrix of features to compute SHAP values with. The provided features should have shape (# samples, # features) and can be any of the array-like object types listed above.


    Background data for shap.KernelExplainer is generated by subsampling features with shap.kmeans. The background data size is limited to 100 rows for performance reasons.

  • artifact_path – The run-relative artifact path to which the explanation is saved. If unspecified, defaults to “model_explanations_shap”.
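As a sketch of a callable that satisfies the predict_function signature for a classifier, the hypothetical softmax model below returns class probabilities with shape (# samples, # classes), the same orientation as scikit-learn's predict_proba. The weights W and c are invented for illustration; with a fitted scikit-learn classifier you would simply pass clf.predict_proba.

```python
import numpy as np

# Hypothetical 3-class linear model; these weights are made up for illustration.
W = np.array([[1.0, -0.5, 0.0],
              [0.2, 0.3, -1.0]])  # shape (# features, # classes)
c = np.array([0.1, 0.0, -0.1])    # per-class intercepts


def predict_function(X):
    """Classifier-style predict_function: class probabilities, shape (# samples, # classes)."""
    X = np.asarray(X, dtype=float)  # accepts numpy arrays and pandas DataFrames alike
    logits = X @ W + c
    logits -= logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)  # each row sums to 1


X = np.array([[0.0, 1.0],
              [2.0, -1.0],
              [0.5, 0.5],
              [-1.0, 0.0]])
pred = predict_function(X)
print(pred.shape)  # (4, 3)
```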
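The background-data summarization described under features can be illustrated without SHAP. The sketch below is a plain-numpy stand-in for shap.kmeans, not its implementation: it caps the number of clusters at min(100, # samples), assuming that is how the 100-row limit is applied, runs a few iterations of Lloyd's k-means, and returns each centroid's share of the rows as a weight (shap.kmeans likewise weights its centroids by cluster size). The helper name kmeans_background and its parameters are hypothetical.

```python
import numpy as np


def kmeans_background(X, max_rows=100, n_iter=20, seed=0):
    """Hypothetical helper: summarize X into at most max_rows weighted centroids.

    Illustrative stand-in for shap.kmeans using plain Lloyd's k-means.
    Returns (centroids, weights), where the weights sum to 1.
    """
    X = np.asarray(X, dtype=float)
    k = min(max_rows, X.shape[0])  # never ask for more clusters than there are rows
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(X.shape[0], size=k, replace=False)].copy()
    for _ in range(n_iter):
        # Assign every row to its nearest centroid (squared Euclidean distance).
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its rows; leave empty clusters in place.
        for j in range(k):
            members = X[labels == j]
            if len(members):
                centroids[j] = members.mean(axis=0)
    # Final assignment so the weights match the returned centroids.
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    weights = np.bincount(labels, minlength=k) / X.shape[0]
    return centroids, weights


rng = np.random.default_rng(1)
features = rng.normal(size=(1000, 4))
bg, wts = kmeans_background(features)
print(bg.shape)  # (100, 4)
```

With fewer than 100 rows, the sketch keeps one cluster per row, so no rows are discarded.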


Returns:

  Artifact URI of the logged explanations.

Example:

import os

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; requires an older release
from sklearn.linear_model import LinearRegression

import mlflow

# prepare training data
dataset = load_boston()
X = pd.DataFrame(dataset.data[:50, :8], columns=dataset.feature_names[:8])
y = dataset.target[:50]

# train a model
model = LinearRegression().fit(X, y)

# log an explanation
with mlflow.start_run() as run:
    mlflow.shap.log_explanation(model.predict, X)

# list artifacts
client = mlflow.tracking.MlflowClient()
artifact_path = "model_explanations_shap"
artifacts = [x.path for x in client.list_artifacts(run.info.run_id, artifact_path)]
print("# artifacts:")
print(artifacts)

# load back the logged explanation
dst_path = client.download_artifacts(run.info.run_id, artifact_path)
base_values = np.load(os.path.join(dst_path, "base_values.npy"))
shap_values = np.load(os.path.join(dst_path, "shap_values.npy"))

print("\n# base_values:")
print(base_values)
print("\n# shap_values:")
print(shap_values[:3])
Output:

# artifacts:

# base_values:

# shap_values:
[[ 2.09975523  0.4746513   7.63759026  0.        ]
 [ 2.00883109 -0.18816665 -0.14419184  0.        ]
 [ 2.00891772 -0.18816665 -0.14419184  0.        ]]

[Figure: Logged artifacts]
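The base_values and shap_values loaded in the example are tied together by SHAP's local-accuracy property: for each explained row, the base value plus the sum of that row's SHAP values equals the model output. shap.KernelExplainer estimates SHAP values by sampling, but for a linear model with features treated as independent the exact values have a closed form, w_j * (x_j - E[x_j]), which makes the property easy to check. The numpy-only sketch below uses a made-up linear model (w, b, predict, and background are hypothetical names) and also computes the mean absolute SHAP value per feature, the quantity a SHAP summary bar plot displays.

```python
import numpy as np

# Hypothetical linear regressor f(x) = w . x + b; the weights are made up for illustration.
w = np.array([0.5, -2.0, 1.5])
b = 3.0


def predict(X):
    """Regressor-style predict_function returning shape (# samples,)."""
    return X @ w + b


rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))  # stands in for the summarized background data
X = rng.normal(size=(5, 3))             # rows to explain

# Closed-form SHAP values for a linear model with independent features:
# phi_j(x) = w_j * (x_j - E[x_j]), with the expectation taken over the background.
base_value = predict(background).mean()
shap_values = w * (X - background.mean(axis=0))

# Local accuracy: the base value plus a row's SHAP values reproduces the prediction.
reconstructed = base_value + shap_values.sum(axis=1)
print(np.allclose(reconstructed, predict(X)))  # True

# A bar-style summary ranks features by mean absolute SHAP value.
importance = np.abs(shap_values).mean(axis=0)
print(importance)
```

Under this closed form, a feature that takes the same value in every explained and background row would receive a SHAP value of exactly 0.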