"""
Global Surrogate Model
----------------------
A global surrogate model is an interpretable model that is trained to approximate the
predictions of a black box model. We can draw conclusions about the black box model
by interpreting the surrogate model [1].
Characteristics
===============
- global
- contrastive
Source
======
[1] Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable", 2019.
https://christophm.github.io/interpretable-ml-book/
"""
import os
import subprocess
import warnings
from typing import Dict, Union
import graphviz
import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, export_text
from explainy.core.explanation import Explanation
from explainy.core.explanation_base import ExplanationBase
from explainy.utils.surrogate_plot import SurrogatePlot
from explainy.utils.surrogate_text import SurrogateText
from explainy.utils.typing import ModelType
class SurrogateModelExplanation(ExplanationBase):
    """Contrastive, global explanation based on a surrogate model.

    An interpretable surrogate (a shallow decision tree or a linear model) is
    fitted on the predictions of the wrapped model. The explanation text and
    the plot are derived from that surrogate.
    """

    def __init__(
        self,
        X: Union[pd.DataFrame, np.ndarray],
        y: Union[pd.DataFrame, np.ndarray],
        model: ModelType,
        number_of_features: int = 4,
        config: Union[Dict, None] = None,
        kind: str = "tree",
        **kwargs: dict,
    ):
        """Initialise the surrogate explanation and fit the surrogate model.

        Args:
            X: (Test) samples and features the surrogate is fitted on
                (samples, features).
            y: (Test) target values of the samples. Stored on the instance;
                the surrogate itself is fitted on the model predictions.
            model: Trained model object exposing ``predict``.
            number_of_features: Number of groups (tree leaves) the explanation
                should show. Defaults to 4.
            config: Optional configuration forwarded to the base class.
            kind: Surrogate type, either ``"tree"`` or ``"linear"``.
            **kwargs: Extra keyword arguments forwarded to the surrogate
                estimator constructor.

        Raises:
            AssertionError: If ``kind`` is not one of the supported options.
        """
        super().__init__(model, config)
        self.X = X
        self.y = y
        self.feature_names = self.get_feature_names(self.X)

        # Despite its name, ``number_of_features`` holds the tree depth needed
        # to display ``number_of_groups`` leaves: ceil(log2(groups)).
        # It is computed with integer arithmetic because scikit-learn requires
        # ``max_depth`` to be a positive integer; values below 2 fall back to
        # a depth of 1.
        groups = max(int(number_of_features), 1)
        self.number_of_features = max(1, (groups - 1).bit_length())
        self.number_of_groups = number_of_features
        self.kind = kind
        self.kwargs = kwargs

        kinds = ["tree", "linear"]
        if self.kind not in kinds:
            # Raised explicitly (instead of ``assert``) so the validation is
            # not stripped under ``python -O``; the exception type is kept
            # for backward compatibility.
            raise AssertionError(
                f"'{self.kind}' is not a valid option, select from {kinds}"
            )

        (
            natural_language_text_empty,
            method_text_empty,
            sentence_text_empty,
        ) = self.set_defaults()
        self.define_explanation_placeholder(
            natural_language_text_empty, method_text_empty, sentence_text_empty
        )

        self.explanation_name = "surrogate"
        self.logger = self.setup_logger(self.explanation_name)
        self._setup()

    def set_defaults(self):
        """Return the default text templates of this explanation type.

        Returns:
            tuple: ``(natural_language_text_empty, method_text_empty,
            sentence_text_empty)``. The sentence template depends on whether
            ``self.is_classifier`` is truthy.
        """
        natural_language_text_empty = (
            "The following thresholds were important for the predictions: {}"
        )
        method_text_empty = (
            "The feature importance was calculated using a {} surrogate model."
            " {} tree nodes are shown."
        )
        if self.is_classifier:
            sentence_text_empty = "\nThe sample is assigned class {} if {}"
        else:
            sentence_text_empty = "\nThe sample has a value of {:.2f} if {}"
        return (
            natural_language_text_empty,
            method_text_empty,
            sentence_text_empty,
        )

    def _calculate_importance(self) -> None:
        """Train a surrogate model on the predictions of the original model.

        Raises:
            ValueError: If ``self.kind`` is not a supported surrogate kind.
        """
        estimators = {
            ("tree", False): DecisionTreeRegressor,
            ("tree", True): DecisionTreeClassifier,
            ("linear", False): LinearRegression,
            ("linear", True): LogisticRegression,
        }
        estimator = estimators.get((self.kind, bool(self.is_classifier)))
        if estimator is None:
            raise ValueError(
                f'Value of "kind" is not supported: {self.kind}!'
            )

        # ``X`` may be a DataFrame or a plain array; only the former has
        # ``.values``, so fall back to the object itself otherwise.
        y_hat = self.model.predict(getattr(self.X, "values", self.X))

        self.surrogate_model = self.get_surrogate_model(estimator)
        self.surrogate_model.fit(self.X, y_hat)

        score = self.surrogate_model.score(self.X, y_hat)
        self.logger.info(f"Surrogate Model score: {score:.2f}")

    def get_surrogate_model(self, estimator: ModelType) -> ModelType:
        """Instantiate the surrogate estimator with its hyperparameters.

        Args:
            estimator (ModelType): Surrogate estimator class.

        Returns:
            ModelType: Unfitted surrogate estimator instance.

        Raises:
            ValueError: If ``self.kind`` is not a supported surrogate kind.
        """
        if self.kind == "tree":
            # The depth is capped so the tree stays small enough to explain.
            return estimator(max_depth=self.number_of_features, **self.kwargs)
        if self.kind == "linear":
            return estimator(**self.kwargs)
        raise ValueError(f'Value of "kind" is not supported: {self.kind}!')

    def get_feature_values(self):
        """Intentionally a no-op for this explanation type."""
        pass

    def importance(self) -> Union[str, None]:
        """Return the importance of the surrogate model.

        Returns:
            The decision rules of a tree surrogate as text, or ``None`` when
            the surrogate is not a decision tree.
        """
        if not isinstance(
            self.surrogate_model,
            (DecisionTreeClassifier, DecisionTreeRegressor),
        ):
            return None

        # Older ``export_text`` versions reject non-list feature names.
        feature_names = self.feature_names
        if feature_names is not None:
            feature_names = list(feature_names)
        return export_text(self.surrogate_model, feature_names=feature_names)

    def plot(self, index_sample: int = None) -> None:
        """Plot the surrogate model.

        Args:
            index_sample (int, optional): Index of the sample in scope.
                Defaults to None.

        Raises:
            ValueError: If ``self.kind`` is not a supported surrogate kind.
            NotImplementedError: For the ``"linear"`` kind, whose bar plot is
                not implemented yet.
        """
        if self.kind == "tree":
            self._plot_tree(index_sample)
        elif self.kind == "linear":
            self._plot_bar(index_sample)
        else:
            raise ValueError(
                f'Value of "kind" is not supported: {self.kind}!'
            )

    def _plot_bar(self, sample_index: int) -> None:
        """Plot a linear surrogate as bar chart (not implemented yet)."""
        raise NotImplementedError("to be done")

    def _plot_tree(
        self, sample_index: int = None, precision: int = 2, **kwargs: dict
    ) -> None:
        """Render the tree surrogate with graphviz and display it.

        The generated dot source is stored in ``self.dot_file``.

        Args:
            sample_index (int, optional): Currently unused.
            precision (int, optional): Precision forwarded to
                ``SurrogatePlot``. Defaults to 2.
            **kwargs: Extra keyword arguments forwarded to ``SurrogatePlot``.
        """
        surrogate_plot = SurrogatePlot(precision=precision, **kwargs)
        self.dot_file = surrogate_plot(
            model=self.surrogate_model,
            feature_names=self.feature_names,
        )

        # ``self.plot_name`` is set in ``explain``; its extension selects the
        # graphviz output format.
        name, extension = os.path.splitext(self.plot_name)
        graphviz_source = graphviz.Source(
            self.dot_file,
            filename=os.path.join(self.path_plot, name),
            format=extension.lstrip("."),
        )
        try:
            display(graphviz_source)
        except subprocess.CalledProcessError:
            warnings.warn("plot already open!")

    def save(self, sample_index: int, sample_name: str = None) -> None:
        """Save the explanation to a csv file and the tree plot as dot file.

        Args:
            sample_index (int): Index of the sample; used as name when no
                ``sample_name`` is given.
            sample_name (str, optional): Name of the sample. Defaults to None.

        Returns:
            None.
        """
        if not sample_name:
            sample_name = sample_index
        self.save_csv(sample_name)

        # The dot source only exists after a tree plot has been created.
        dot_source = getattr(self, "dot_file", None)
        if dot_source is None:
            warnings.warn(
                "No tree plot has been created, skipping the .dot export."
            )
            return

        # NOTE(review): ``plot_name`` appears to carry the plot extension, so
        # the file ends in e.g. ".png.dot" - kept as is; confirm intent.
        dot_path = os.path.join(self.path_plot, f"{self.plot_name}.dot")
        with open(dot_path, "w", encoding="utf-8") as file:
            file.write(dot_source)

    def get_method_text(self) -> str:
        """Define the method introduction text of the explanation type.

        Returns:
            str: method_text explanation
        """
        return self.method_text_empty.format(
            self.surrogate_model.__class__.__name__,
            self.num_to_str[self.number_of_groups].capitalize(),
        )

    def get_natural_language_text(self) -> str:
        """Define the natural language output of this explanation type.

        The sentences are generated by ``SurrogateText`` from the fitted
        surrogate model, the samples and the feature names.

        Returns:
            str: natural_language_text explanation
        """
        surrogate_text = SurrogateText(
            text=self.sentence_text_empty,
            model=self.surrogate_model,
            X=self.X,
            feature_names=self.feature_names,
        )
        sentences = surrogate_text.get_text()
        return self.natural_language_text_empty.format(sentences)

    def _setup(self) -> None:
        """Fit the surrogate model and create the explanation texts once.

        Returns:
            None.
        """
        self._calculate_importance()
        self.natural_language_text = self.get_natural_language_text()
        self.method_text = self.get_method_text()

    def explain(
        self, sample_index: int, sample_name: str = None, separator: str = '\n'
    ) -> Explanation:
        """Create the explanation of the given sample.

        The method text and the natural language text are created once in
        ``_setup``; the score text is created per sample.

        Args:
            sample_index (int): Number of the sample to create the
                explanation for.
            sample_name (str, optional): Name of the sample. Defaults to None.
            separator (str, optional): Separator for the string
                concatenation. Defaults to a newline. Currently unused.

        Returns:
            Explanation: Explanation object containing the explanations.
        """
        sample_name = self.get_sample_name(sample_index, sample_name)
        self.plot_name = self.get_plot_name(sample_name)
        self.prediction = self.get_prediction(sample_index)
        self.score_text = self.get_score_text()
        self.explanation = Explanation(
            self.score_text, self.method_text, self.natural_language_text
        )
        return self.explanation