Source code for explainy.core.explanation_base

import csv
import logging
import os
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
import pandas as pd
from sklearn.base import is_classifier

from explainy.utils.typing import Config, ModelType
from explainy.utils.utils import create_folder, join_text_with_comma_and_and, num_to_str

logger = logging.getLogger(__name__)


[docs] class ExplanationBase(ABC): def __init__( self, model: ModelType, config: Optional[Config] = None, ) -> None: """Initialize the explanation base class Args: model (ModelType): trained model that should be explained config (Dict, optional): config file that contains explanation settings. Defaults to None. """ self.model = model self.config = config if config else {} self.is_classifier: bool = is_classifier(self.model) self.folder_name: str = self.config.get("folder_name", "output") self.file_name: str = self.config.get("file_name", "explanations.csv") self.explanation_name: str self.number_of_features: int self.set_paths() if self.is_classifier: score_text_empty = ( "The {} used {} features to produce the predictions. The class" " of this sample was {:.0f}." ) else: score_text_empty = ( "The {} used {} features to produce the predictions. The" " prediction of this sample was {:.1f}." ) description_text_empty: str = ( "This is a {} explanation, it creates {} and {} explanations." ) attribute_names = [ ("description_text_empty", description_text_empty), ("score_text_empty", score_text_empty), ] for attr_name, default_value in attribute_names: setattr(self, attr_name, self.config.get(attr_name, default_value))
[docs] def define_explanation_placeholder( self, natural_language_text_empty: str, method_text_empty: str, sentence_text_empty: str, ) -> None: """Set the explanation text, if defined else load it from defaults Args: natural_language_text_empty (str): natural language explanation placeholder method_text_empty (str): method placeholder sentence_text_empty (str): sentence text placeholder """ attribute_names = [ ("natural_language_text_empty", natural_language_text_empty), ("method_text_empty", method_text_empty), ("sentence_text_empty", sentence_text_empty), ] for attr_name, default_value in attribute_names: setattr(self, attr_name, self.config.get(attr_name, default_value))
[docs] def get_number_of_features(self, number_of_features: int) -> int: """Set the number of features based on the defined number and the max number of features Args: number_of_features (int): number_of_features as input Returns: int: number_of_features considering the max number of dataset features """ if number_of_features > self.X.shape[1]: warnings.warn( 'The "number_of_features" is larger than the number of dataset' f" features. The value is set to {self.X.shape[1]}" ) return min(number_of_features, self.X.shape[1])
[docs] def get_feature_names(self, X: Union[pd.DataFrame, np.ndarray]) -> List[str]: """Get the feature names based on the given dataset Args: X (Union[pd.DataFrame, np.array]): features dataset Returns: List[str]: list of feature names """ if isinstance(X, pd.DataFrame): feature_names = list(X) else: feature_names = [f"feature_{index}" for index in range(X.shape[1])] return feature_names
[docs] def set_paths(self) -> None: """Set the paths where the output should be saved""" self.path = os.path.join(os.getcwd(), "reports", self.folder_name) self.path_plot = create_folder(os.path.join(self.path, "plot")) self.path_result = create_folder(os.path.join(self.path, "results")) self.path_log = create_folder(os.path.join(self.path, "logs")) paths_dict = { "path": self.path, "path_plot": self.path_plot, "path_result": self.path_result, "path_log": self.path_log, } logger.debug(f"paths_dict:", extra=paths_dict)
@abstractmethod def _calculate_importance(self): """Calculate the feature importance""" raise NotImplementedError("Subclasses should implement this!")
[docs] @abstractmethod def plot(self, sample_index: int, kind: str) -> None: """Plot the feature importance""" raise NotImplementedError("Subclasses should implement this!")
[docs] @abstractmethod def get_feature_values(self): """Get the feature values""" raise NotImplementedError("Subclasses should implement this!")
[docs] def importance(self) -> pd.DataFrame: """Get the feature importance""" df = pd.DataFrame(self.feature_values, columns=["Feature", "Importance"]) return df.round(2)
[docs] def get_prediction(self, sample_index: int) -> float: """Get the model prediction Args: sample_index (int): sample_index for a which a predction shall be made Returns: float: predction of the model for that sample """ with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) prediction: float = self.model.predict(self.X)[sample_index] return prediction
[docs] def get_method_text(self) -> str: """Generate the output of the method explanation.""" return self.method_text_empty.format(num_to_str[self.number_of_features])
[docs] def get_sentences(self) -> str: """Generate the output sentences of the explanation.""" values = [] for feature_name, feature_value in self.feature_values[ : self.number_of_features ]: values.append(self.sentence_text_empty.format(feature_name, feature_value)) sentences = join_text_with_comma_and_and(values) return sentences
[docs] def get_natural_language_text(self) -> str: """Generate the output of the explanation in natural language. Returns: str: return the natural_language_text explanation """ return self.natural_language_text_empty.format( num_to_str[self.number_of_features], self.sentences )
[docs] def get_description_text(self) -> str: """Generate the description of the explanation method. Example: This is a SHAP explanation, it creates local and non-contrastive explanations. Returns: str: return the explanation method description """ return self.description_text_empty.format( self.explanation_name, self.explanation_type, self.explanation_style )
[docs] def get_score_text(self) -> str: """Generate the text explaining the prediction score of the sample Returns: str: return the score_text for the sample. """ self.number_of_dataset_features = self.X.shape[1] return self.score_text_empty.format( self.model.__class__.__name__, self.number_of_dataset_features, self.prediction, )
[docs] def get_model_text(self) -> str: """Generate text the explains the used machine learning model (wip) Returns: str: return the description of the machine learning model """ return str(self.model)
[docs] def get_plot_name(self, sample_name: Optional[str] = None) -> str: """Get the name of the plot Args: sample_name (str, optional): name of the sample. Defaults to None. Returns: str: return the name of the plot """ prefix = f"{self.explanation_name}_features_{self.number_of_features}" if sample_name: plot_name = f"{prefix}_sample_{sample_name}.png" else: plot_name = f"{prefix}.png" return plot_name
[docs] def get_sample_name( self, sample_index: int, sample_name: Optional[str] = None ) -> str: """Determine the name of the sample, if no sample_name provide, use the sample_index Args: sample_index (int): index of the sample sample_name (str, optional): name of the sample. Defaults to None. Returns: str: name of the sample """ if not sample_name: sample_name = str(sample_index) return sample_name
[docs] def save(self, sample_index: int, sample_name: Optional[str] = None) -> None: """Save the explanations to a csv file, save the plots Args: sample_index (int): [description] sample_name (str, optional): name of the sample. Defaults to None. """ assert hasattr(self, "fig"), "missing the figure object, call `plot()` first" sample_name = self.get_sample_name(sample_index, sample_name) self.save_csv(sample_name) self.fig.savefig( os.path.join(self.path_plot, self.plot_name), bbox_inches="tight", ) print(f"Saved the plot to {os.path.join(self.path_plot, self.plot_name)}")
[docs] def save_csv(self, sample_index: int) -> None: """Save the explanation to a csv. The columns contain the method_text, the natural_language_text, the name of the plot and the predicted value. The index is the Entry ID. Args: sample_index (int): index of the sample Returns: None. """ assert hasattr( self, "natural_language_text" ), "missing the `natural_language_text`, call `explain()` first" assert hasattr( self, "prediction" ), "missing the prediction, call `explain()` first" assert hasattr( self, "plot_name" ), "instance lacks plot_name, call `plot()` first" output = { "score_text": self.score_text, "method_text": self.method_text, "natural_language_text": self.natural_language_text, "plot": self.plot_name, "number_of_features": self.number_of_features, "prediction": self.prediction, } df = pd.DataFrame(output, index=[sample_index]) for column in ["method_text", "natural_language_text", "score_text"]: df[column] = df[column].astype(str) df[column] = df[column].str.replace("\n", "\\n") # check if the file is already there, if not, create it is_file = os.path.isfile(os.path.join(self.path_result, self.file_name)) if is_file: # append to the file header = False mode = "a" else: # create the file header = True mode = "w" df.to_csv( os.path.join(self.path_result, self.file_name), sep=";", encoding="utf-8-sig", index_label=["Entry ID"], mode=mode, header=header, escapechar="\\", quotechar='"', quoting=csv.QUOTE_NONNUMERIC, ) print(f"Saved the csv to {os.path.join(self.path_result, self.file_name)}")