mlflow.shap
mlflow.shap.log_explanation(predict_function, features, artifact_path=None)
Note
Experimental: This method may change or be removed in a future release without warning.
Given a predict_function capable of computing ML model output on the provided features, computes and logs explanations of an ML model’s output. Explanations are logged as a directory of artifacts containing the following items generated by SHAP (SHapley Additive exPlanations):
Base values (the average model output over the background data)
SHAP values (computed using shap.KernelExplainer)
Summary bar plot (shows the average impact of each feature on model output)
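The following sketch shows what "base values" and "SHAP values" mean. For a model with only a few features, Shapley values can be computed exactly by enumerating every feature coalition. shap.KernelExplainer approximates the same quantity by fitting a weighted linear regression over sampled coalitions. The sketch assumes the common interventional definition, in which features outside a coalition are filled in from background rows. exact_shap_values and toy_model are hypothetical helpers for illustration and are not part of mlflow or shap. The base value is the mean model output over the background data, and adding the SHAP values to it reproduces the model output at x.

```python
from itertools import combinations
from math import factorial
from typing import Callable, FrozenSet, List, Sequence, Tuple


def exact_shap_values(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    background: List[Sequence[float]],
) -> Tuple[float, List[float]]:
    """Exact Shapley values of f at x by enumerating every feature coalition.

    Hypothetical helper for illustration; cost is exponential in len(x).
    Returns (base_value, shap_values).
    """
    n = len(x)

    def value(coalition: FrozenSet[int]) -> float:
        # v(S): mean model output when features in S come from x and the
        # remaining features come from each background row in turn.
        total = 0.0
        for row in background:
            z = [x[i] if i in coalition else row[i] for i in range(n)]
            total += f(z)
        return total / len(background)

    base_value = value(frozenset())  # mean output over the background data
    shap_values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        contribution = 0.0
        for size in range(n):
            # Shapley weight: |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                contribution += weight * (value(s | {i}) - value(s))
        shap_values.append(contribution)
    return base_value, shap_values


# Toy model with an interaction term (hypothetical, for illustration).
def toy_model(z):
    return 2.0 * z[0] - 1.0 * z[1] + 0.5 * z[0] * z[2]


background = [[0.0, 1.0, 2.0], [2.0, 3.0, 0.0]]
x = [1.0, 0.0, 4.0]
base, phi = exact_shap_values(toy_model, x, background)
# Local accuracy: base + sum(phi) reproduces toy_model(x) up to rounding.
print(base, phi, base + sum(phi), toy_model(x))
```

The cost grows exponentially with the number of features. For that reason, sampling-based approximations such as Kernel SHAP are used in practice.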
- Parameters
predict_function –
A function to compute the output of a model (e.g. the predict_proba method of scikit-learn classifiers). Must have the following signature:
def predict_function(X) -> pred: ...
X : An array-like object whose shape should be (# samples, # features).
pred : An array-like object whose shape should be (# samples) for a regressor or (# samples, # classes) for a classifier. For a classifier, the values in pred should correspond to the predicted probability of each class.
Acceptable array-like object types:
numpy.ndarray
pandas.DataFrame
shap.common.DenseData
scipy.sparse matrix
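The two functions below illustrate the predict_function contract. They are hypothetical stand-ins for model.predict and model.predict_proba, and their coefficients are made up:

- regressor_predict is a toy linear regressor that maps (# samples, # features) to (# samples).
- classifier_predict_proba is a toy binary classifier that returns one probability per class, using the (n_samples, n_classes) layout of scikit-learn's predict_proba.

```python
import numpy as np

# Hypothetical coefficients for illustration only; a fitted model supplies these.
WEIGHTS = np.array([0.8, -0.5, 0.3])
BIAS = 0.1


def regressor_predict(X):
    """Regressor-style predict_function: (# samples, # features) -> (# samples,)."""
    X = np.asarray(X, dtype=float)
    return X @ WEIGHTS + BIAS


def classifier_predict_proba(X):
    """Binary-classifier predict_function returning per-class probabilities.

    The output has shape (# samples, # classes), the layout of scikit-learn's
    predict_proba, and each row sums to 1.
    """
    p1 = 1.0 / (1.0 + np.exp(-regressor_predict(X)))  # logistic link
    return np.column_stack([1.0 - p1, p1])


X = np.array([[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]])
print(regressor_predict(X).shape)         # (2,)
print(classifier_predict_proba(X).shape)  # (2, 2)
```

Either function could be passed as the first argument, for example mlflow.shap.log_explanation(classifier_predict_proba, X).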
features –
A matrix of features to compute SHAP values with. The provided features should have shape (# samples, # features), and can be any of the array-like object types listed above.
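The sketch below checks that features have the expected (# samples, # features) layout before logging. as_feature_matrix is a hypothetical helper, not part of mlflow. It does three things:

- densifies scipy.sparse matrices through their toarray method;
- converts a pandas.DataFrame through its to_numpy method;
- promotes a single 1-D sample to shape (1, # features).

Converting a DataFrame this way drops its column names, so you may prefer to pass the original object to log_explanation. The helper only illustrates the shape requirement.

```python
import numpy as np


def as_feature_matrix(features):
    """Hypothetical helper: return `features` as a dense float array of shape
    (# samples, # features)."""
    if hasattr(features, "toarray"):      # scipy.sparse matrices
        features = features.toarray()
    elif hasattr(features, "to_numpy"):   # pandas.DataFrame
        features = features.to_numpy()
    arr = np.asarray(features, dtype=float)
    if arr.ndim == 1:
        # A single sample of shape (# features,) becomes (1, # features).
        arr = arr.reshape(1, -1)
    if arr.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {arr.shape}")
    return arr


print(as_feature_matrix([0.5, 1.5, 2.5]).shape)   # (1, 3)
print(as_feature_matrix(np.ones((4, 2))).shape)   # (4, 2)
```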
Note
Background data for shap.KernelExplainer is generated by summarizing features with shap.kmeans (k-means clustering). The background data size is limited to 100 rows for performance reasons.
artifact_path –
The run-relative artifact path to which the explanation is saved. If unspecified, defaults to “model_explanations_shap”.
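The background-data note on features can be pictured as plain Lloyd's k-means that compresses the feature matrix into at most 100 weighted rows. summarize_background is a hypothetical helper and only approximates shap.kmeans. shap.kmeans itself delegates to scikit-learn's KMeans and may post-process the centroids differently. The seed and iteration count are arbitrary choices.

```python
import numpy as np

MAX_BACKGROUND_ROWS = 100  # cap stated in the note on background data


def _nearest(X, centroids):
    # Index of the closest centroid (squared Euclidean distance) for each row.
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


def summarize_background(features, max_rows=MAX_BACKGROUND_ROWS, n_iter=20, seed=0):
    """Hypothetical k-means summary of `features` into at most `max_rows` rows.

    Returns (centroids, weights); weights are cluster sizes divided by the
    number of input rows, so they sum to 1.
    """
    X = np.asarray(features, dtype=float)
    k = min(max_rows, len(X))
    rng = np.random.default_rng(seed)
    # Start from k distinct input rows (fancy indexing copies; X is untouched).
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        labels = _nearest(X, centroids)
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[j] = members.mean(axis=0)
    labels = _nearest(X, centroids)
    weights = np.bincount(labels, minlength=k) / len(X)
    return centroids, weights


rng = np.random.default_rng(42)
data = rng.normal(size=(500, 4))
centroids, weights = summarize_background(data)
print(centroids.shape, weights.sum())
```

Fewer background rows mean fewer model evaluations per explained sample, which is the performance trade-off the note refers to.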
- Returns
Artifact URI of the logged explanations.
import os

import numpy as np
import pandas as pd
# NOTE: load_boston was removed in scikit-learn 1.2; run this example with an
# older scikit-learn release or substitute another regression dataset.
from sklearn.datasets import load_boston
from sklearn.linear_model import LinearRegression

import mlflow

# prepare training data: first 50 rows and first 8 features
dataset = load_boston()
X = pd.DataFrame(dataset.data[:50, :8], columns=dataset.feature_names[:8])
y = dataset.target[:50]

# train a model
model = LinearRegression()
model.fit(X, y)

# log an explanation
with mlflow.start_run() as run:
    mlflow.shap.log_explanation(model.predict, X)

# list artifacts
client = mlflow.tracking.MlflowClient()
artifact_path = "model_explanations_shap"
artifacts = [x.path for x in client.list_artifacts(run.info.run_id, artifact_path)]
print("# artifacts:")
print(artifacts)

# load back the logged explanation
dst_path = client.download_artifacts(run.info.run_id, artifact_path)
base_values = np.load(os.path.join(dst_path, "base_values.npy"))
shap_values = np.load(os.path.join(dst_path, "shap_values.npy"))

print("\n# base_values:")
print(base_values)
print("\n# shap_values:")
print(shap_values[:3])
# artifacts:
['model_explanations_shap/base_values.npy', 'model_explanations_shap/shap_values.npy', 'model_explanations_shap/summary_bar_plot.png']

# base_values:
20.502000000000002

# shap_values:
[[ 2.09975523 0.4746513 7.63759026 0. ]
 [ 2.00883109 -0.18816665 -0.14419184 0. ]
 [ 2.00891772 -0.18816665 -0.14419184 0. ]]
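The summary bar plot ranks features by their mean absolute SHAP value. The hypothetical helper below recomputes that ranking from a loaded shap_values array, so the ranking can be inspected without opening the image. With the example above, you could call it as mean_abs_shap(shap_values, list(X.columns)).

```python
import numpy as np


def mean_abs_shap(shap_values, feature_names):
    """Hypothetical helper: rank features by mean |SHAP value|, largest first.

    shap_values has shape (# samples, # features); this is the per-feature
    quantity a SHAP summary bar plot displays.
    """
    importance = np.abs(np.asarray(shap_values, dtype=float)).mean(axis=0)
    order = np.argsort(importance)[::-1]
    return [(feature_names[i], float(importance[i])) for i in order]


toy = np.array([[1.0, -3.0, 0.5], [-1.0, 1.0, 0.5]])
print(mean_abs_shap(toy, ["a", "b", "c"]))  # [('b', 2.0), ('a', 1.0), ('c', 0.5)]
```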