Source code for explainy.utils.surrogate_plot

import re
from typing import List

from sklearn import tree

from explainy.utils.typing import ModelType


class GraphvizNotFoundError(Exception):
    """Error type for a missing Graphviz installation.

    NOTE(review): nothing in this module's visible code raises this exception;
    presumably callers raise it when Graphviz cannot be located -- confirm
    against the call sites.
    """
class SurrogatePlot:
    """Create the graphviz-based surrogate plot from a trained sklearn DecisionTree.

    The constructor arguments are stored unchanged and forwarded to
    ``sklearn.tree.export_graphviz`` by :meth:`get_plot`.

    Args:
        precision: Forwarded as ``precision`` to ``export_graphviz``.
        impurity: Forwarded as ``impurity`` to ``export_graphviz``.
        rounded: Forwarded as ``rounded`` to ``export_graphviz``.
        class_names: Forwarded as ``class_names`` to ``export_graphviz``.
    """

    # Matches a DOT attribute list that consists only of a label, i.e.
    # ``[label="..."]``. DOTALL lets a label span physical lines.
    _LABEL_PATTERN = re.compile(r'\[label="(.*?)"\]', re.DOTALL)

    def __init__(
        self,
        precision: int = 2,
        impurity: bool = False,
        rounded: bool = True,
        class_names: bool = True,
    ):
        self.precision = precision
        self.impurity = impurity
        self.rounded = rounded
        self.class_names = class_names

    def get_plot(self, model: ModelType, feature_names: List[str]) -> str:
        """Export ``model`` to DOT text and simplify the one-hot node labels.

        Args:
            model: Fitted estimator handed to ``sklearn.tree.export_graphviz``.
            feature_names: Feature names handed to ``export_graphviz``.

        Returns:
            The exported DOT text after :meth:`one_hot_encoding_text` has
            rewritten the labels of one-hot encoded split nodes.
        """
        tree_dot_format = tree.export_graphviz(
            model,
            feature_names=feature_names,
            impurity=self.impurity,
            rounded=self.rounded,
            precision=self.precision,
            class_names=self.class_names,
        )
        return self.one_hot_encoding_text(tree_dot_format)

    @staticmethod
    def one_hot_encoding_text(tree: str) -> str:
        """Customize the label text for one-hot encoded features.

        A node label whose first line is a split condition of the form
        ``"<feature> - <value> <= <threshold>"`` gets that first line replaced
        by ``"<feature> is not <value>"``. All other labels are returned
        unchanged, including leaves and labels where `` - `` only occurs
        outside the split condition (for example inside a class name).

        Args:
            tree: DOT source text. Note that inside this method the parameter
                shadows the imported ``sklearn.tree`` module.

        Returns:
            The DOT text with the matching labels rewritten. The lines of a
            rewritten label are joined with a real newline character, while
            untouched labels keep their original ``\\n`` escape sequences.
        """

        def _rewrite_label(match):
            # Labels separate their lines with the two-character escape "\n".
            label_lines = match.group(1).split("\\n")
            condition = label_lines[0]
            if "<=" not in condition:
                # Not a split node (e.g. a leaf): nothing to rewrite.
                return match.group(0)
            feature_text = condition.split("<=")[0].strip()
            # Split on the first separator only, so a value that itself
            # contains " - " is kept whole instead of being truncated.
            feature_name, separator, feature_value = feature_text.partition(" - ")
            if not separator:
                # Ordinary (not one-hot encoded) feature: leave it alone.
                return match.group(0)
            label_lines[0] = f"{feature_name} is not {feature_value}"
            return '[label="' + "\n".join(label_lines) + '"]'

        # One targeted pass over the labels; text outside a label is never
        # touched.
        return SurrogatePlot._LABEL_PATTERN.sub(_rewrite_label, tree)

    def __call__(self, model: ModelType, feature_names: List[str]):
        """Shortcut for :meth:`get_plot` with the same arguments."""
        return self.get_plot(model, feature_names)