"""
SHAP Explanation
----------------
A prediction can be explained by assuming that each feature value of the instance is a "player" in a game where
the prediction is the payout. Shapley values (a method from coalitional game theory) tell us how to fairly
distribute the "payout" among the features. The Shapley value is the average marginal contribution of a feature
value across all possible coalitions [1].
Characteristics
===============
- local
- non-contrastive
Source
======
[1] Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable", 2019.
https://christophm.github.io/interpretable-ml-book/
"""
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import sklearn
from explainy.core.explanation import Explanation
from explainy.core.explanation_base import ExplanationBase
class ShapExplanation(ExplanationBase):
    """
    Non-contrastive, local explanation based on SHAP values.

    The SHAP values are computed once for all samples in ``X`` when the
    object is created; ``explain`` and ``plot`` then report them for a
    single sample.
    """

    def __init__(
        self,
        X: pd.DataFrame,
        y: np.array,
        model: sklearn.base.BaseEstimator,
        number_of_features: int = 4,
        config: Dict = None,
        **kwargs,
    ) -> None:
        """
        This implementation is a thin wrapper around `shap.TreeExplainer
        <https://shap-lrjball.readthedocs.io/en/docs_update/generated/shap.TreeExplainer.html>`

        Args:
            X (df): (Test) samples and features to calculate the importance
                for (sample, features)
            y (np.array): (Test) target values of the samples (samples, 1)
            model (object): trained (scikit-learn) model object
            number_of_features (int): requested number of features to use
                in the explanation; passed through
                ``self.get_number_of_features`` before being stored.
            config (dict, optional): configuration forwarded unchanged to
                the base class constructor.
            **kwargs: additional keyword arguments forwarded to
                ``shap.TreeExplainer``.

        Returns:
            None.
        """
        super().__init__(model, config)
        self.X = X
        self.y = y
        self.feature_names = self.get_feature_names(self.X)
        self.number_of_features = self.get_number_of_features(number_of_features)
        self.kwargs = kwargs

        # Text templates that are filled in per sample when explaining.
        natural_language_text_empty = (
            "The {} features which contributed most to the prediction of this"
            " particular sample were: {}."
        )
        method_text_empty = (
            "The feature importance was calculated using the SHAP method."
        )
        sentence_text_empty = "'{}' ({:.2f})"
        self.define_explanation_placeholder(
            natural_language_text_empty, method_text_empty, sentence_text_empty
        )

        self.explanation_type = 'local'
        self.explanation_style = 'non-contrastive'
        self.explanation_name = "shap"
        self.logger = self.setup_logger(self.explanation_name)

        # SHAP values are computed eagerly for every sample in X.
        self._calculate_importance()

    def _calculate_importance(self) -> None:
        """
        Explain model predictions using SHAP library

        Builds a ``shap.TreeExplainer`` for ``self.model`` (forwarding the
        keyword arguments given to the constructor) and stores the SHAP
        values of all samples in ``self.X`` on ``self.shap_values``.

        Returns:
            None.
        """
        self.explainer = shap.TreeExplainer(self.model, **self.kwargs)
        self.shap_values = self.explainer.shap_values(self.X)

    def _prediction_shap_values(self):
        """
        Select the SHAP values that belong to the current prediction.

        For classifiers ``self.shap_values`` is indexed by the predicted
        class (``self.prediction``, which is set in ``explain``); otherwise
        it is returned unchanged. The result is indexed as
        ``[sample_index, feature_index]`` by the callers.

        Returns:
            the SHAP values to index by (sample, feature).
        """
        if self.is_classifier:
            return self.shap_values[self.prediction]
        return self.shap_values

    def get_feature_values(self, sample_index: int = 0) -> List[Tuple[str, float]]:
        """
        extract the feature name and its importance per sample
        - get absolute values to get the strongest positive and negative contribution
        - sort by importance -> highest to lowest

        Args:
            sample_index (int, optional): sample for which the explanation should
                be returned. Defaults to 0.

        Returns:
            feature_values (list(tuple(str, float))): list of tuples for each
                feature and its (signed) importance of a sample.
        """
        shap_matrix = self._prediction_shap_values()
        # argsort is ascending, so reverse to get the largest magnitude first
        order = np.argsort(np.abs(shap_matrix[sample_index, :]))[::-1]
        if self.is_classifier:
            self.logger.info(
                f'SHAP values are taken from predicted class: {self.prediction}'
            )
        return [
            (self.feature_names[index], shap_matrix[sample_index, index])
            for index in order.tolist()
        ]

    def plot(self, sample_index: int = 0, kind="bar") -> None:
        """
        Plot the shap values

        The created figure is stored on ``self.fig``.

        Args:
            sample_index (int, optional): sample for which the plot should
                be created. Defaults to 0.
            kind (str, optional): ``"bar"`` for a horizontal bar plot or
                ``"shap"`` for a SHAP force plot. Defaults to "bar".

        Raises:
            ValueError: if ``kind`` is neither ``"bar"`` nor ``"shap"``.

        Returns:
            None.
        """
        if kind == "bar":
            self.fig = self._bar_plot(sample_index)
        elif kind == "shap":
            self.fig = self._shap_plot(sample_index)
        else:
            raise ValueError(f'Value of "kind = {kind}" is not supported!')

    def _bar_plot(self, sample_index: int) -> plt.Figure:
        """
        Create a bar plot of the SHAP values for a selected sample

        Only the ``self.number_of_features`` features with the largest
        absolute SHAP value are shown, most important at the top.

        Args:
            sample_index (int): sample for which the explanation should
                be returned.

        Returns:
            plt.Figure: the matplotlib figure containing the plot
        """
        shap_matrix = self._prediction_shap_values()
        order = np.argsort(np.abs(shap_matrix[sample_index, :]))[::-1]
        top_idx = order.tolist()[: self.number_of_features]
        # Size everything from the features actually selected so positions
        # and widths always have matching lengths.
        n_bars = len(top_idx)

        width = shap_matrix[sample_index, top_idx]
        labels = [self.feature_names[i] for i in top_idx]
        # descending positions put the most important feature at the top
        positions = np.arange(n_bars, 0, -1)

        fig = plt.figure(figsize=(6, max(2, int(0.5 * n_bars))))
        plt.barh(y=positions, width=width, height=0.5)
        plt.yticks(positions, labels)
        plt.xlabel("Shap Values")
        plt.tight_layout()
        plt.show()
        return fig

    def _shap_plot(self, sample_index: int) -> plt.Figure:
        """
        Visualize the explanation of one sample as a SHAP force plot

        Args:
            sample_index (int): sample for which the explanation should
                be returned.

        Returns:
            plt.Figure: return a matplotlib figure containing the plot
        """
        if self.is_classifier:
            base_value = self.explainer.expected_value[self.prediction]
        else:
            base_value = self.explainer.expected_value
        # round to two decimals to keep the labels in the plot readable
        shap_value = np.around(
            self._prediction_shap_values()[sample_index, :], decimals=2
        )

        shap.force_plot(
            base_value=base_value,
            shap_values=shap_value,
            features=self.X.iloc[sample_index, :],
            matplotlib=True,
            show=False,
        )
        # force_plot draws on the current figure; grab it to resize/return it
        fig = plt.gcf()
        fig.set_figheight(4)
        fig.set_figwidth(8)
        plt.show()
        return fig

    def _log_output(self, sample_index: int) -> None:
        """
        Log the expected value, target value and prediction of the sample

        Args:
            sample_index (int): index of the sample to log the values for.

        Returns:
            None.
        """
        if self.is_classifier:
            expected_value = self.explainer.expected_value[self.prediction]
        else:
            expected_value = self.explainer.expected_value
        self.logger.debug(f"The expected_value was: {expected_value}")

        # Accept pandas objects as well as plain arrays, with shape
        # (samples, 1) or (samples,): take the first value of the sample row.
        y_value = np.ravel(np.asarray(self.y)[sample_index])[0]
        self.logger.debug(f"The y_value was: {y_value}")
        self.logger.debug(f"The predicted value was: {self.prediction}")

    def _setup(self, sample_index: int, sample_name: str):
        """
        Helper function to call all methods to create the explanations

        Populates ``feature_values``, ``sentences``,
        ``natural_language_text``, ``method_text`` and ``plot_name`` on the
        instance, in that order.

        Args:
            sample_index (int): index of the sample to explain.
            sample_name (str): name of the sample, used for the plot name.

        Returns:
            None.
        """
        self._log_output(sample_index)
        self.feature_values = self.get_feature_values(sample_index)
        self.sentences = self.get_sentences()
        self.natural_language_text = self.get_natural_language_text()
        self.method_text = self.get_method_text()
        self.plot_name = self.get_plot_name(sample_name)

    def explain(
        self, sample_index: int, sample_name: str = None, separator: str = "\n"
    ) -> Explanation:
        """
        main function to create the explanation of the given sample. The
        method_text, natural_language_text and the plots are created per sample.

        Args:
            sample_index (int): number of the sample to create the explanation for
            sample_name (str, optional): name of the sample; resolved via
                ``self.get_sample_name``. Defaults to None.
            separator (str, optional): not used by this method itself.
                Defaults to a newline.

        Returns:
            Explanation: built from the score text, method text and natural
                language text; also stored on ``self.explanation``.
        """
        sample_name = self.get_sample_name(sample_index, sample_name)
        # the prediction must be set first: the SHAP value lookup for
        # classifiers depends on it
        self.prediction = self.get_prediction(sample_index)
        self.score_text = self.get_score_text()
        self._setup(sample_index, sample_name)
        self.explanation = Explanation(
            self.score_text, self.method_text, self.natural_language_text
        )
        return self.explanation