"""SHAP Explanation
----------------
A prediction can be explained by assuming that each feature value of the instance is a "player" in a game where
the prediction is the payout. Shapley values (a method from coalitional game theory) tell us how to fairly
distribute the "payout" among the features. The Shapley value is the average marginal contribution of a feature
value across all possible coalitions [1].
Characteristics
===============
- local
- non-contrastive
Source
======
[1] Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable", 2019.
https://christophm.github.io/interpretable-ml-book/
"""
import textwrap
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from explainy.core.explanation import Explanation
from explainy.core.explanation_base import ExplanationBase
from explainy.utils.logger import Logger
from explainy.utils.typing import Config, ModelType
class ShapExplanation(ExplanationBase):
    """Non-contrastive, local explanation based on SHAP values.

    Thin wrapper around `shap.TreeExplainer
    <https://shap-lrjball.readthedocs.io/en/docs_update/generated/shap.TreeExplainer.html>`_:
    the SHAP values of all samples in ``X`` are computed once at construction
    time, and ``explain(sample_index)`` turns the values of a single sample
    into a textual explanation.
    """

    explanation_type: str = "local"
    explanation_style: str = "non-contrastive"
    explanation_name: str = "shap"

    def __init__(
        self,
        X: pd.DataFrame,
        y: np.ndarray,
        model: ModelType,
        number_of_features: int = 4,
        config: Optional[Config] = None,
        **kwargs,
    ) -> None:
        """Create the explanation and compute the SHAP values of ``X``.

        Args:
            X (pd.DataFrame): (Test) samples and features to calculate the
                importance for (samples, features).
            y (np.ndarray): (Test) target values of the samples (samples, 1).
                A pandas Series is accepted as well.
            model (ModelType): trained (scikit-learn) model object.
            number_of_features (int): number of features to be displayed in
                the explanation. Defaults to 4.
            config (Config, optional): configuration dictionary.
                Defaults to None.
            **kwargs: passed unchanged to ``shap.TreeExplainer``.

        Returns:
            None.
        """
        super().__init__(model, config)
        self.X = X
        self.y = y
        self.feature_names = self.get_feature_names(self.X)
        self.number_of_features = self.get_number_of_features(number_of_features)
        self.kwargs = kwargs
        # Set by explain(); plot() refuses to draw any other sample.
        self.sample_index: Optional[int] = None

        natural_language_text_empty = (
            "The {} features which contributed most to the prediction of this"
            " particular sample were: {}."
        )
        method_text_empty = (
            "The feature importance was calculated using the SHAP method."
        )
        sentence_text_empty = "'{}' ({:.2f})"
        self.define_explanation_placeholder(
            natural_language_text_empty, method_text_empty, sentence_text_empty
        )

        self.logger = Logger(self.explanation_name, self.path_log).get_logger()
        self._calculate_importance()

    def _calculate_importance(self) -> None:
        """Build a ``shap.TreeExplainer`` and compute SHAP values for ``self.X``."""
        self.explainer = shap.TreeExplainer(self.model, **self.kwargs)
        self.shap_values = self.explainer.shap_values(self.X)

    def _get_class_shap_values(self) -> np.ndarray:
        """Return the SHAP values relevant for the current prediction.

        The result is indexed as ``[sample, feature]``.

        - Regressors: the values exactly as computed by shap.
        - Classifiers: the values of the predicted class. Depending on the shap
          version, multi-output models yield either a list with one
          (samples, features) array per class (shap < 0.45) or one
          (samples, features, classes) array (shap >= 0.45). Some single-output
          classifiers (e.g. binary boosting models explained in log-odds
          space) already yield a (samples, features) array, returned as-is.

        Assumes ``self.prediction`` (set in ``explain``) is the integer index
        of the predicted class for classifiers -- TODO confirm with the base
        class.
        """
        values = self.shap_values
        if not self.is_classifier:
            return values
        if isinstance(values, list):
            return values[self.prediction]
        values = np.asarray(values)
        if values.ndim == 3:
            return values[:, :, self.prediction]
        return values

    def _get_expected_value(self):
        """Return the explainer's base value for the current prediction.

        For classifiers with one base value per class, the one of the
        predicted class is returned; a scalar base value is returned as-is.
        """
        expected_value = self.explainer.expected_value
        if not self.is_classifier or np.ndim(expected_value) == 0:
            return expected_value
        return expected_value[self.prediction]

    def get_feature_values(self, sample_index: int = 0) -> List[Tuple[str, float]]:
        """Extract the feature names and their SHAP values for one sample.

        Features are ranked by the absolute SHAP value, so the strongest
        positive and negative contributions come first (highest to lowest).

        Args:
            sample_index (int, optional): sample for which the explanation
                should be returned. Defaults to 0.

        Returns:
            list(tuple(str, float)): one ``(feature name, SHAP value)`` tuple
            per feature, sorted by decreasing absolute SHAP value.
        """
        shap_values = self._get_class_shap_values()
        if self.is_classifier:
            self.logger.info(
                f"SHAP values are taken from predicted class '{self.prediction}'"
            )
        sample_values = shap_values[sample_index, :]
        # Rank by magnitude: a strong negative contribution matters as much as
        # a strong positive one.
        indexes = np.argsort(np.abs(sample_values))[::-1]
        return [
            (self.feature_names[index], sample_values[index])
            for index in indexes.tolist()
        ]

    def plot(self, sample_index: int, kind: str = "bar") -> None:
        """Plot the SHAP values of the sample explained last.

        Args:
            sample_index (int): sample to plot; must equal the index passed to
                the last ``explain`` call.
            kind (str, optional): type of plot, either "bar" or "shap".
                Defaults to "bar".

        Raises:
            ValueError: if ``sample_index`` differs from the explained sample,
                or if ``kind`` is not supported.

        Returns:
            None. The created figure is stored in ``self.fig``.
        """
        if sample_index != self.sample_index:
            raise ValueError(
                "the provided index sample does not match the index the importance is"
                " calculated for. re-run .explain(sample_index) to plot the correct"
                " sample"
            )
        if kind == "bar":
            self.fig = self._bar_plot(sample_index)
        elif kind == "shap":
            self.fig = self._shap_plot(sample_index)
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(f'Value of "kind = {kind}" is not supported!')

    def _bar_plot(self, sample_index: int) -> plt.Figure:
        """Create a horizontal bar plot of the top SHAP values of a sample.

        Args:
            sample_index (int): sample for which the plot is created.

        Returns:
            plt.Figure: the created figure.
        """
        sample_values = self._get_class_shap_values()[sample_index, :]
        sorted_idx = np.argsort(np.abs(sample_values))[::-1][
            : self.number_of_features
        ].tolist()
        width = sample_values[sorted_idx]
        # wrap labels longer than 40 characters
        labels = [textwrap.fill(self.feature_names[i], width=40) for i in sorted_idx]
        # Size by the bars actually drawn: fewer features than
        # number_of_features would otherwise make y and width mismatch.
        n_bars = len(sorted_idx)
        # Descending positions put the most important feature at the top.
        y = np.arange(n_bars, 0, -1)
        fig = plt.figure(figsize=(6, max(2, int(0.6 * n_bars))))
        plt.barh(y=y, width=width, height=0.5)
        plt.yticks(y, labels)
        plt.xlabel("Shap Values")
        plt.tight_layout()
        plt.show()
        return fig

    def _shap_plot(self, sample_index: int) -> plt.Figure:
        """Visualize the explanation of the selected sample as a SHAP force plot.

        Args:
            sample_index (int): sample for which the plot is created.

        Returns:
            plt.Figure: the created matplotlib figure.
        """
        base_value = self._get_expected_value()
        # Rounded to keep the value labels in the force plot readable.
        shap_value = np.around(
            self._get_class_shap_values()[sample_index, :], decimals=2
        )
        shap.force_plot(
            base_value=base_value,
            shap_values=shap_value,
            features=self.X.iloc[sample_index, :],
            matplotlib=True,
            show=False,
        )
        fig = plt.gcf()
        fig.set_figheight(4)
        fig.set_figwidth(8)
        plt.show()
        return fig

    def _log_output(self, sample_index: int) -> None:
        """Log the base value, true target and prediction of a sample.

        Args:
            sample_index (int): number of the sample to create the explanation for

        Returns:
            None.
        """
        self.logger.debug(f"The expected_value was: {self._get_expected_value()}")
        # np.asarray accepts the documented np.ndarray as well as pandas
        # objects (ndarray has no ``.values`` attribute).
        self.logger.debug(f"The y_value was: {np.asarray(self.y)[sample_index]}")
        self.logger.debug(f"The predicted value was: {self.prediction}")

    def _setup(self, sample_index: int, sample_name: str) -> None:
        """Run all steps that build the explanation of one sample.

        Args:
            sample_index (int): number of the sample to create the explanation for
            sample_name (str): name of the sample, used to build the plot name.

        Returns:
            None.
        """
        self.sample_index = sample_index
        self._log_output(sample_index)
        self.feature_values = self.get_feature_values(sample_index)
        self.sentences = self.get_sentences()
        self.natural_language_text = self.get_natural_language_text()
        self.method_text = self.get_method_text()
        self.plot_name = self.get_plot_name(sample_name)

    def explain(
        self,
        sample_index: int,
        sample_name: Optional[str] = None,
        separator: str = "\n",
    ) -> Explanation:
        """Create the explanation of the given sample.

        The method_text, natural_language_text and the plot name are created
        per sample.

        Args:
            sample_index (int): number of the sample to create the explanation for
            sample_name (str, optional): name of the sample. Defaults to None.
            separator (str, optional): separator between the parts of the
                explanation. Defaults to a newline.

        Returns:
            Explanation: Explanation object
        """
        sample_name = self.get_sample_name(sample_index, sample_name)
        self.prediction = self.get_prediction(sample_index)
        self.score_text = self.get_score_text()
        self._setup(sample_index, sample_name)
        self.explanation = Explanation(
            self.score_text, self.method_text, self.natural_language_text, separator
        )
        return self.explanation