Source code for explainy.explanations.counterfactual_explanation

"""Counterfactual Explanation
--------------------------
Counterfactual explanations tell us how the values of an instance have to change to
significantly change its prediction. A counterfactual explanation of a prediction
describes the smallest change to the feature values that changes the prediction
to a predefined output. By creating counterfactual instances, we learn about how the
model makes its predictions and can explain individual predictions [1].

Characteristics
===============
- local
- contrastive

Source
======
[1] Molnar, Christoph. "Interpretable machine learning. A Guide for Making Black Box Models Explainable", 2019.
https://christophm.github.io/interpretable-ml-book/
"""

import warnings
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.font_manager import FontProperties
from mlxtend.evaluate import create_counterfactual
from sklearn.base import is_regressor

from explainy.core.explanation import Explanation
from explainy.core.explanation_base import ExplanationBase
from explainy.utils.logger import Logger
from explainy.utils.typing import Config, ModelType
from explainy.utils.utils import (
    NonConvergenceError,
    join_text_with_comma_and_and,
    num_to_str,
)

# Import-time side effect: tell NumPy to ignore floating-point "divide" and
# "invalid" errors instead of warning about them.
np.seterr(divide="ignore", invalid="ignore")

# Labels of the three value series collected per feature. They are used as the
# row index when the frame is built in ``get_feature_values`` and, after the
# transpose done there, as the column names read by the table and text methods.
COLUMN_REFERENCE = "Reference Values"
COLUMN_COUNTERFACTUAL = "Counterfactual Values"
COLUMN_DIFFERENCE = "Prediction Difference"


class CounterfactualExplanation(ExplanationBase):
    """Contrastive, local explanation based on a counterfactual example.

    For one sample, a counterfactual feature vector is searched whose
    prediction matches ``y_desired``. The features are then ranked by how much
    each counterfactual value changes the prediction on its own.
    """

    explanation_type: str = "local"
    explanation_style: str = "contrastive"
    explanation_name: str = "counterfactual"

    def __init__(
        self,
        X: pd.DataFrame,
        y: np.ndarray,
        model: ModelType,
        y_desired: float,
        number_of_features: int = 4,
        config: Optional[Config] = None,
        delta: Optional[float] = None,
    ) -> None:
        """Set up the counterfactual explanation.

        This implementation is a thin wrapper around
        `mlxtend.evaluate.create_counterfactual
        <http://rasbt.github.io/mlxtend/user_guide/evaluate/create_counterfactual>`_

        Args:
            X (df): (Test) samples and features to calculate the importance for
                (sample, features)
            y (np.array): (Test) target values of the samples (samples, 1)
            model (sklearn.base.BaseEstimator): trained (scikit-learn) model
                object
            y_desired (float): desired target value for the counterfactual
                example.
            number_of_features (int): number of features to consider in the
                explanation
            config (dict): configuration dictionary
            delta (float, optional): maximum allowed difference between the
                desired target value and the predicted value of the
                counterfactual. If not set, 5% of the absolute prediction of
                the explained sample is used for regressors and 0 for
                classifiers.

        Returns:
            None.
        """
        super().__init__(model, config)

        self.X = X
        self.y = y
        self.y_desired = y_desired
        # Only the user supplied tolerance is stored; the per-sample default is
        # derived in `_resolve_delta` so that it never leaks between samples.
        self.delta = delta

        self.feature_names = self.get_feature_names(self.X)
        self.number_of_features = self.get_number_of_features(number_of_features)
        self.sample_index: Optional[int] = None
        self.is_regressor = is_regressor(self.model)

        natural_language_text_empty = (
            "The sample would have had the desired prediction of '{}', {}."
        )
        method_text_empty = (
            "The feature importance is shown using a counterfactual example."
        )
        sentence_text_empty = "the '{}' was {}"

        self.define_explanation_placeholder(
            natural_language_text_empty, method_text_empty, sentence_text_empty
        )

        self.logger = Logger(
            name=self.explanation_name, path_log=self.path_log
        ).get_logger()

    def _resolve_delta(self) -> float:
        """Return the tolerance to use for the sample currently explained.

        Returns:
            float: the user supplied ``delta`` if one was given, otherwise 5%
            of the absolute prediction (regressor) or 0 (classifier).
        """
        # A delta of 0 can never satisfy the strict "<" comparison used for
        # regressors, so it is treated like "not set" (as None is).
        if self.delta:
            return self.delta

        if self.is_regressor:
            # abs() keeps the tolerance positive for negative predictions;
            # a negative tolerance could never be met.
            delta = abs(self.prediction) * 0.05
        else:
            # in case of a classifier, we are trying the find the right class
            delta = 0
        self.logger.info(f"No delta value set, therefore using the value '{delta}'")
        return delta

    def _calculate_importance(
        self, sample_index: int = 0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Create the counter factual explanation for the given sample.

        Up to 14 random seeds are tried, and for every seed the ``lammbda``
        values 1e-3 ... 1e2. The search stops at the first counterfactual that
        is accepted: for a regressor its prediction has to be closer than the
        tolerance to ``y_desired``, for a classifier it has to equal
        ``y_desired``. The prediction of the last evaluated counterfactual is
        stored in ``self.y_counter_factual``.

        Args:
            sample_index (int, optional): sample index. Defaults to 0.

        Raises:
            NonConvergenceError: if no seed/lambda combination produced an
                accepted counterfactual.

        Returns:
            x_ref (np.ndarray): reference feature values
            x_counter_factual (np.ndarray): counter factual feature values
        """
        x_ref = self.X.values[sample_index, :]

        if self.prediction == self.y_desired:
            warnings.warn(
                "The prediction is already equals to the desired value (y_desired), no"
                " counterfactual explanation needed.Are you sure you don't want to"
                " choose a different sample or desired value (y_desired)?"
            )

        delta = self._resolve_delta()

        # np.asarray works for ndarray, Series and list targets alike;
        # `self.y.values` would fail for a plain ndarray.
        y_label = np.asarray(self.y)[sample_index]

        start = -3
        stop = 2
        num = stop - start + 1
        # use exponential increase to search for the right lammbda value
        lammbdas = np.logspace(
            start=start,
            stop=stop,
            num=num,
            base=10,
            dtype="float",
        )

        self.logger.info(
            "Start to calculate the counterfactual example. This may take a while..."
        )
        is_value_found = False
        x_counter_factual = None
        # try different seed values
        for random_seed in range(14):
            for lammbda in lammbdas:
                # catch the "Maximum number of function evaluations has been
                # exceeded." warning
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=UserWarning)
                    x_counter_factual = create_counterfactual(
                        x_reference=x_ref,
                        y_desired=self.y_desired,
                        model=self.model,
                        X_dataset=self.X.values,
                        lammbda=lammbda,
                        random_seed=random_seed,
                    )
                    self.y_counter_factual = self.model.predict(
                        x_counter_factual.reshape(1, -1)
                    )[0]

                local_delta = np.abs(self.y_counter_factual - self.y_desired)

                self.logger.info(
                    "hyperparameters:"
                    f" y_pred: {self.prediction:.2f},"
                    f" y_counter_factual: {self.y_counter_factual:.2f}, lambda:"
                    f" {lammbda}, local_delta: {local_delta}, random_seed:"
                    f" {random_seed}"
                )
                self.logger.debug(
                    f" y_desired: {self.y_desired:.2f}, label:"
                    f" {y_label}, delta:"
                    f" {delta}, "
                )

                if self.is_regressor:
                    if local_delta < delta:
                        self.logger.debug("found value below delta!")
                        is_value_found = True
                        break
                else:
                    if self.y_counter_factual == self.y_desired:
                        self.logger.debug("found the right class!")
                        is_value_found = True
                        break
            if is_value_found:
                break

        if not is_value_found:
            # A larger tolerance is easier to reach, hence "increase".
            raise NonConvergenceError(
                "No counterfactual value found, try to increase the 'delta'"
                " value or adjust the desired prediction 'y_desired'"
            )

        self.logger.debug(f"Features of the sample: {x_ref}")
        self.logger.debug(f"Features of the countefactual: {x_counter_factual}")
        return x_ref, x_counter_factual
def get_prediction_from_new_feature_value(
    self,
    feature_index: int,
    x_ref: np.ndarray,
    x_counter_factual: np.ndarray,
) -> float:
    """Replace the value of the feature at the position of the feature_index
    and predict a new value for this new set of features.

    ``x_ref`` itself is not modified; the replacement happens on a copy.

    Args:
        feature_index (int): The index of the feature to replace with the
            counterfactual value.
        x_ref (np.ndarray): reference features.
        x_counter_factual (np.ndarray): counter factual features.

    Returns:
        prediction (float): predicted value with the updated features values.
    """
    new_value = x_counter_factual.reshape(1, -1)[0, feature_index]

    # Promote to a dtype that can hold both vectors. Without this, an integer
    # reference row would silently truncate a fractional counterfactual value
    # on assignment. `astype` always returns a copy here.
    common_dtype = np.result_type(x_ref.dtype, x_counter_factual.dtype)
    x_created = x_ref.reshape(1, -1).astype(common_dtype)

    old_value = x_created[0, feature_index]
    self.logger.debug(f"old_value: {old_value:.4f}, new_value: {new_value:.4f}")

    # assign new value
    x_created[0, feature_index] = new_value

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        pred_new = self.model.predict(x_created)[0]
    return pred_new
def get_feature_importance(
    self, x_ref: np.ndarray, x_counter_factual: np.ndarray
) -> list:
    """Calculate the importance of each feature.

    Take the reference features and replace every feature, one at a time,
    with its counterfactual value. The signed change of the prediction is
    stored per feature in ``self.differences``. Features are ranked by the
    absolute value of that change: a larger absolute value means a larger
    contribution and therefore a more important feature.

    Args:
        x_ref (np.ndarray): reference features.
        x_counter_factual (np.ndarray): counter factual features.

    Returns:
        list: list of the feature names sorted by importance (most important
        first); also stored in ``self.feature_sort``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        pred_ref = self.model.predict(x_ref.reshape(1, -1))[0]

    self.differences = []
    for feature_index in range(x_ref.shape[0]):
        pred_new = self.get_prediction_from_new_feature_value(
            feature_index, x_ref, x_counter_factual
        )
        difference = pred_new - pred_ref
        self.differences.append(difference)
        self.logger.debug(
            f"name: {self.feature_names[feature_index]}, difference:"
            f" {difference:.2f}"
        )

    # Rank by magnitude so that a strong negative contribution is not pushed
    # to the end of the list; the stable sort keeps the original feature
    # order for ties.
    magnitudes = np.abs(np.asarray(self.differences, dtype=float))
    order = np.argsort(-magnitudes, kind="stable")

    # get the sorted feature_names
    self.feature_sort = np.array(self.feature_names)[order].tolist()
    return self.feature_sort
def get_feature_values(
    self,
    x_ref: np.ndarray,
    x_counter_factual: np.ndarray,
    decimal: int = 2,
    debug: bool = False,
) -> None:
    """Arrange the reference and the counter factual features in a dataframe.

    The result is stored in ``self.df`` with one row per feature (ordered by
    ``self.feature_sort``) and one column each for the reference value, the
    counterfactual value and the prediction difference. Requires
    ``self.differences`` and ``self.feature_sort`` to be set beforehand.

    Args:
        x_ref (np.array): features of the sample
        x_counter_factual (np.array): features of the counter factual sample
            to achieve y_desired
        decimal (int): decimal number to round the values to in the plot
        debug (bool): if True, plot the dataframe

    Returns:
        None.
    """
    index = [
        COLUMN_REFERENCE,
        COLUMN_COUNTERFACTUAL,
        COLUMN_DIFFERENCE,
    ]
    # Build one row per value series, round, then transpose so that the
    # features become the rows.
    self.df = (
        pd.DataFrame(
            [x_ref, x_counter_factual, self.differences],
            index=index,
            columns=self.feature_names,
        )
        .round(decimal)
        .T
    )

    # reorder dataframe according the the feature importance
    self.df = self.df.loc[self.feature_sort, :]

    if debug:
        self.df.plot(kind="barh", figsize=(3, 5))
def importance(self) -> pd.DataFrame:
    """Return the feature importance.

    Returns:
        pd.DataFrame: a copy of ``self.df`` with all values rounded to two
        decimals.
    """
    rounded_frame = self.df.round(2)
    return rounded_frame
def plot(self, sample_index: int, kind: str = "table") -> None:
    """Create the plot of the counterfactual table.

    The created figure is stored in ``self.fig``.

    Args:
        sample_index (int): index of the sample to plot; has to be the index
            the explanation was last calculated for.
        kind (str, optional): kind of plot. Defaults to 'table'.

    Raises:
        ValueError: if ``sample_index`` differs from the explained sample, or
            if the "kind" of plot is not supported.
    """
    if sample_index != self.sample_index:
        raise ValueError(
            "sample_index is not the same as the index used to calculate the"
            " counterfactual explanation, re-run .explain(sample_index) to plot the"
            " correct sample"
        )
    if kind != "table":
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(f'Value of kind "{kind}" is not supported!')
    self.fig = self._plot_table()
def _plot_table(self) -> plt.Figure:
    """Plot the table comparing the reference and the counterfactual values.

    The table shows the first ``self.number_of_features`` rows of ``self.df``
    plus a final, bold "Prediction" row holding the prediction of the sample
    and of the counterfactual.

    Returns:
        plt.Figure: figure object
    """
    col_labels = ["Sample", "Counterfactual Sample"]
    columns = [COLUMN_REFERENCE, COLUMN_COUNTERFACTUAL]

    array_subset = self.df[columns].values[: self.number_of_features]
    row_labels = list(self.df.index)[: self.number_of_features]

    # append the predictions as the last row of the table
    score_row = np.array(
        [f"{self.prediction:.1f}", f"{self.y_counter_factual:.1f}"]
    ).reshape(1, -1)
    array_subset = np.append(array_subset, score_row, axis=0)
    row_labels = row_labels + ["Prediction"]

    fig, ax = plt.subplots()
    fig.patch.set_visible(False)
    ax.axis("off")
    ax.axis("tight")

    table = ax.table(
        cellText=array_subset,
        colLabels=col_labels,
        rowLabels=row_labels,
        loc="center",
        cellLoc="center",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.25, 2)

    # make the last row bold; row 0 is the header, so the last data row has
    # the index len(array_subset)
    for (row, _), cell in table.get_celld().items():
        if row == array_subset.shape[0]:
            cell.set_text_props(fontproperties=FontProperties(weight="bold"))

    plt.axis("off")
    # The string "off" is truthy and would switch the grid on.
    plt.grid(False)
    plt.tight_layout()
    plt.show()
    return fig
def get_method_text(self) -> str:
    """Define the method introduction text of the explanation type.

    The method text template is formatted with the spelled-out number of
    features and the prediction of the counterfactual.

    Returns:
        str: method text explanation
    """
    template = self.method_text_empty
    feature_count_text = num_to_str[self.number_of_features]
    counterfactual_prediction = self.y_counter_factual
    return template.format(feature_count_text, counterfactual_prediction)
def get_natural_language_text(self) -> str:
    """Define the natural language output using the feature names and its
    values for this explanation type.

    Only the first ``self.number_of_features`` rows of ``self.df`` are
    verbalised.

    Returns:
        str: natural language explanation
    """
    limit = self.number_of_features
    counterfactual_values = self.df[COLUMN_COUNTERFACTUAL].tolist()[:limit]
    names = list(self.df.index)[:limit]

    # one clause per feature, e.g. "the '<name>' was '<value>'"
    clauses = [
        self.sentence_text_empty.format(name, f"'{value}'")
        for name, value in zip(names, counterfactual_values)
    ]
    condition_text = "if " + join_text_with_comma_and_and(clauses)

    return self.natural_language_text_empty.format(
        self.y_counter_factual, condition_text
    )
def _setup(self, sample_index: int, sample_name: str) -> None:
    """Helper function to setup the counterfactual explanation.

    Runs the counterfactual search, ranks and tabulates the features and then
    fills ``natural_language_text``, ``method_text`` and ``plot_name``.

    Args:
        sample_index (int): index of sample in scope
        sample_name (str): name of the sample in scope

    Returns:
        None
    """
    reference, counterfactual = self._calculate_importance(sample_index)

    # The ranking has to run first: the table is built from its results.
    self.get_feature_importance(reference, counterfactual)
    self.get_feature_values(reference, counterfactual)

    self.natural_language_text = self.get_natural_language_text()
    self.method_text = self.get_method_text()
    self.plot_name = self.get_plot_name(sample_name)
def explain(
    self,
    sample_index: int,
    sample_name: Optional[str] = None,
    separator: str = "\n",
) -> Explanation:
    """Main function to create the explanation of the given sample.

    The method_text, natural_language_text and the plots are created per
    sample. The result is also stored in ``self.explanation``.

    Args:
        sample_index (int): number of the sample to create the explanation for
        sample_name (str, optional): name of the sample. Defaults to None.
        separator (str, optional): separator for the natural language text.
            Defaults to "\\n".

    Returns:
        Explanation: Explanation object containing the explanations
    """
    # Remember which sample is explained so that plot() can verify it.
    self.sample_index = sample_index
    resolved_name = self.get_sample_name(sample_index, sample_name)

    self.prediction = self.get_prediction(sample_index)
    self.score_text = self.get_score_text()

    self._setup(sample_index, resolved_name)

    explanation = Explanation(
        self.score_text,
        self.method_text,
        self.natural_language_text,
        separator=separator,
    )
    self.explanation = explanation
    return explanation