# Source code for explainy.core.explanation_base

# -*- coding: utf-8 -*-
"""
Created on Tue Nov 24 21:15:30 2020

@author: mauro
"""

import csv
import os
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Union

import numpy as np
import pandas as pd
from sklearn.base import is_classifier

from explainy.core.explanation_mixin import ExplanationMixin
from explainy.logger import Logger
from explainy.utils.typing import ModelType
from explainy.utils.utils import create_folder


class ExplanationBase(ABC, ExplanationMixin):
    """Abstract base class shared by the explainy explanation methods.

    It stores the model and the configuration, prepares the output folders
    and turns values computed by a subclass into text, CSV rows and plots.

    Several attributes read by the methods below are not set in this class
    and are presumably provided by subclasses (or by ``ExplanationMixin``),
    e.g. ``X``, ``feature_values``, ``number_of_features``, ``prediction``,
    ``sentences``, ``explanation_name``, ``explanation_type``,
    ``explanation_style``, ``score_text``, ``method_text``,
    ``natural_language_text``, ``plot_name`` and ``fig``.
    """

    def __init__(
        self,
        model: ModelType,
        config: Union[Dict, None] = None,
    ) -> None:
        """Initialize the explanation base class.

        Args:
            model (ModelType): trained model that should be explained
            config (Dict, optional): config that contains explanation
                settings. Defaults to None, which is treated as ``{}``.
        """
        self.model = model
        # any falsy config (None, {}) falls back to a fresh empty dict
        self.config = config or {}
        self.is_classifier = is_classifier(self.model)

        self.folder = self.config.get("folder", "explanation")
        self.file_name = self.config.get("file_name", "explanations.csv")
        # creates the output folders on disk as a side effect
        self.set_paths()
        self.get_number_to_string_dict()

        # classifiers report a class label, regressors a continuous value
        if self.is_classifier:
            score_text_empty = (
                "The {} used {} features to produce the predictions. The class"
                " of this sample was {:.0f}."
            )
        else:
            score_text_empty = (
                "The {} used {} features to produce the predictions. The"
                " prediction of this sample was {:.1f}."
            )

        description_text_empty = (
            "This is a {} explanation, it creates {} and {} explanations."
        )

        # user-provided templates in the config take precedence over defaults
        self.description_text_empty = self.config.get(
            "description_text_empty", description_text_empty
        )
        self.score_text_empty = self.config.get("score_text_empty", score_text_empty)

    def define_explanation_placeholder(
        self,
        natural_language_text_empty: str,
        method_text_empty: str,
        sentence_text_empty: str,
    ) -> None:
        """Set the explanation templates, preferring values from the config.

        Args:
            natural_language_text_empty (str): natural language explanation
                placeholder
            method_text_empty (str): method placeholder
            sentence_text_empty (str): sentence text placeholder
        """
        self.natural_language_text_empty = self.config.get(
            "natural_language_text_empty", natural_language_text_empty
        )
        self.method_text_empty = self.config.get("method_text_empty", method_text_empty)
        self.sentence_text_empty = self.config.get(
            "sentence_text_empty", sentence_text_empty
        )

    def get_number_of_features(self, number_of_features: int) -> int:
        """Clip the requested number of features to the dataset width.

        Args:
            number_of_features (int): requested number of features

        Returns:
            int: number_of_features, capped at the number of columns of ``X``
        """
        n_dataset_features = self.X.shape[1]
        if number_of_features > n_dataset_features:
            warnings.warn(
                'The "number_of_features" is larger than the number of dataset'
                f' features. The value is set to {n_dataset_features}',
                stacklevel=2,
            )
        return min(number_of_features, n_dataset_features)

    def get_feature_names(self, X: Union[pd.DataFrame, np.ndarray]) -> List[str]:
        """Get the feature names of the given dataset.

        Args:
            X (Union[pd.DataFrame, np.ndarray]): features dataset

        Returns:
            List[str]: the column names for a DataFrame, otherwise generated
            names ``feature_0``, ``feature_1``, ...
        """
        if isinstance(X, pd.DataFrame):
            return list(X)
        return [f'feature_{index}' for index in range(X.shape[1])]

    def set_paths(self) -> None:
        """Set (and create) the folders where the output is saved.

        The root is ``reports/<folder>`` inside the *parent* of the current
        working directory, with ``plot``, ``results`` and ``logs`` subfolders.
        """
        self.path = os.path.join(os.path.dirname(os.getcwd()), "reports", self.folder)
        # create_folder is assumed to create the folder and return its path
        self.path_plot = create_folder(os.path.join(self.path, "plot"))
        self.path_result = create_folder(os.path.join(self.path, "results"))
        self.path_log = create_folder(os.path.join(self.path, "logs"))

    def setup_logger(self, logger_name: str) -> object:
        """Create a logger that writes into ``self.path_log``.

        Args:
            logger_name (str): name of the logger

        Returns:
            object: the logger returned by ``Logger.get_logger``
        """
        logger = Logger(logger_name, self.path_log)
        return logger.get_logger()

    @abstractmethod
    def _calculate_importance(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def plot(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def get_feature_values(self):
        raise NotImplementedError("Subclasses should implement this!")

    def importance(self) -> pd.DataFrame:
        """Return the feature importances as a table rounded to 2 decimals.

        Returns:
            pd.DataFrame: columns ``Feature`` and ``Importance``, built from
            ``self.feature_values``
        """
        return pd.DataFrame(
            self.feature_values, columns=['Feature', 'Importance']
        ).round(2)

    def get_prediction(self, sample_index: int) -> float:
        """Get the model prediction for a single sample.

        Only the requested row is passed to the model instead of predicting
        the whole dataset and discarding all but one value.

        Args:
            sample_index (int): positional index of the sample to predict

        Returns:
            float: prediction of the model for that sample
        """
        data = self.X.values if isinstance(self.X, pd.DataFrame) else np.asarray(self.X)
        # list indexing keeps the 2-D (1, n_features) shape models expect
        return self.model.predict(data[[sample_index]])[0]

    def get_method_text(self) -> str:
        """Generate the method explanation text.

        Returns:
            str: ``method_text_empty`` filled in with the number of features
        """
        return self.method_text_empty.format(self.num_to_str[self.number_of_features])

    def get_sentences(self) -> str:
        """Generate one sentence per top feature and join them.

        Returns:
            str: the sentences joined by ``join_text_with_comma_and_and``
        """
        values = [
            self.sentence_text_empty.format(feature_name, feature_value)
            for feature_name, feature_value in self.feature_values[
                : self.number_of_features
            ]
        ]
        return self.join_text_with_comma_and_and(values)

    def get_natural_language_text(self) -> str:
        """Generate the explanation in natural language.

        Returns:
            str: the natural_language_text explanation
        """
        return self.natural_language_text_empty.format(
            self.num_to_str[self.number_of_features], self.sentences
        )

    def get_description_text(self) -> str:
        """Describe the explanation method (work in progress).

        Example:
            This is a SHAP explanation, it creates local and non-contrastive
            explanations.

        Returns:
            str: the explanation method description
        """
        return self.description_text_empty.format(
            self.explanation_name, self.explanation_type, self.explanation_style
        )

    def get_score_text(self) -> str:
        """Generate the text explaining the prediction of the sample.

        Also stores the dataset width in ``self.number_of_dataset_features``.

        Returns:
            str: the score_text for the sample
        """
        self.number_of_dataset_features = self.X.shape[1]
        return self.score_text_empty.format(
            self.model.__class__.__name__,
            self.number_of_dataset_features,
            self.prediction,
        )

    def get_model_text(self) -> str:
        """Describe the machine learning model (work in progress).

        Returns:
            str: the ``str()`` representation of the model
        """
        return str(self.model)

    def get_plot_name(self, sample_name: str = None) -> str:
        """Build the file name of the plot.

        Args:
            sample_name (str, optional): name of the sample appended to the
                file name. Defaults to None (no sample suffix).

        Returns:
            str: the name of the plot, ending in ``.png``
        """
        prefix = f"{self.explanation_name}_features_{self.number_of_features}"
        if sample_name:
            return f"{prefix}_sample_{sample_name}.png"
        return f"{prefix}.png"

    def get_sample_name(self, sample_index: int, sample_name: str = None) -> str:
        """Determine the sample name, falling back to the sample index.

        Args:
            sample_index (int): index of the sample
            sample_name (str, optional): name of the sample. Defaults to None.

        Returns:
            str: ``sample_name`` if truthy, otherwise ``str(sample_index)``
        """
        if not sample_name:
            sample_name = str(sample_index)
        return sample_name

    def save(self, sample_index: int, sample_name: str = None) -> None:
        """Save the explanation row to the CSV file and the figure to disk.

        Args:
            sample_index (int): index of the sample, used as its name when
                no sample_name is given
            sample_name (str, optional): name of the sample. Defaults to None.
        """
        sample_name = self.get_sample_name(sample_index, sample_name)
        self.save_csv(sample_name)
        self.fig.savefig(
            os.path.join(self.path_plot, self.plot_name),
            bbox_inches="tight",
        )

    def save_csv(self, sample: Union[int, str]) -> None:
        """Append the explanation as one row to the results CSV.

        The columns contain the score_text, method_text,
        natural_language_text, plot name, number of features and the
        prediction. The index is the Entry ID. The header is written only
        when the file does not exist yet.

        Args:
            sample (Union[int, str]): identifier used as the row index

        Raises:
            AttributeError: if ``plot_name`` has not been set yet
        """
        # explicit check instead of ``assert`` so it survives ``python -O``
        if not hasattr(self, "plot_name"):
            raise AttributeError("instance lacks plot_name")

        output = {
            "score_text": self.score_text,
            "method_text": self.method_text,
            "natural_language_text": self.natural_language_text,
            "plot": self.plot_name,
            "number_of_features": self.number_of_features,
            "prediction": self.prediction,
        }
        df = pd.DataFrame(output, index=[sample])

        # escape line breaks so each explanation stays on one CSV line
        for column in ["method_text", "natural_language_text", "score_text"]:
            df[column] = df[column].astype(str).str.replace("\n", "\\n", regex=False)

        path_csv = os.path.join(self.path_result, self.file_name)
        file_exists = os.path.isfile(path_csv)
        # create the file with a header, or append a header-less row to it
        df.to_csv(
            path_csv,
            sep=";",
            encoding="utf-8-sig",
            index_label=["Entry ID"],
            mode="a" if file_exists else "w",
            header=not file_exists,
            escapechar="\\",
            quotechar='"',
            quoting=csv.QUOTE_NONNUMERIC,
        )