Source code for explainy.utils.surrogate_text
import numpy as np
from sklearn.base import is_classifier
from explainy.utils.utils import join_text_with_comma_and_and
[docs]
class SurrogateText:
    """Generate a text explanation from a fitted decision tree.

    The tree structure is read from ``model.tree_`` (``children_left``,
    ``children_right``, ``feature``, ``threshold`` and ``value``). For every
    leaf reached by the samples in ``X``, one sentence is produced by filling
    the ``text`` template with the leaf's value and the decision rules found
    on the path from the root to that leaf.
    """

    def __init__(self, text: str, model: object, X: np.ndarray, feature_names: list):
        """Store the inputs and cache the tree arrays used to build the text.

        Args:
            text (str): Template filled via ``text.format(score, rules)``,
                i.e. it needs two positional format fields.
            model (object): Fitted tree model exposing ``tree_`` and
                ``apply``.
            X (np.ndarray): Samples handed to ``model.apply`` to select the
                leaves that get described.
            feature_names (list): Feature names, indexed by the entries of
                ``model.tree_.feature``. A name containing ``" - "`` is
                treated as a one-hot column named ``"<feature> - <value>"``.
        """
        self.text = text
        self.model = model
        self.X = X
        self.feature_names = feature_names

        tree = self.model.tree_
        self.children_left = tree.children_left
        self.children_right = tree.children_right
        self.feature = tree.feature
        self.threshold = tree.threshold

        n_nodes = tree.value.shape[0]
        if is_classifier(self.model):
            # Keep, per node, the index of the largest entry along the last
            # axis of ``tree_.value`` (presumably the majority class index,
            # not the class label -- confirm against callers).
            self.values = np.argmax(tree.value, axis=2).reshape(n_nodes, 1)
        else:
            self.values = tree.value.reshape(n_nodes, 1)

    def get_text(self):
        """Build the full explanation, one sentence per reached leaf.

        Returns:
            str: The formatted sentences, each terminated with ``"."`` and
            joined by single spaces.
        """
        texts = []
        for leaf, path in self.get_paths().items():
            rules = join_text_with_comma_and_and(self.get_rule(path))
            score = self.values[leaf][0]
            texts.append(self.text.format(score, rules))
        return " ".join(f"{text}." for text in texts)

    def get_paths(self) -> dict:
        """Collect the root-to-leaf path for every leaf reached by ``X``.

        Returns:
            dict: Maps each distinct leaf id returned by ``model.apply(X)``
            to an array with the node ids on its path, in ascending order.
        """
        leaf_ids = self.model.apply(self.X)
        paths = {}
        for leaf in np.unique(leaf_ids):
            path_leaf = []
            self.find_path(0, path_leaf, leaf)
            # np.unique returns the ids sorted ascending. ``get_rule`` reads
            # the result as root -> leaf, which assumes node ids increase
            # along a path (presumably guaranteed by scikit-learn's node
            # numbering -- TODO confirm).
            paths[leaf] = np.unique(path_leaf)
        return paths

    def find_path(self, node_numb, path, x):
        """Depth-first search for node ``x`` in the subtree at ``node_numb``.

        Args:
            node_numb (int): Id of the node to start searching from.
            path (list): Mutated in place; on success it is extended with the
                node ids from ``node_numb`` down to ``x``, on failure it is
                left as it was passed in.
            x (int): Id of the node to look for.

        Returns:
            bool: ``True`` if ``x`` was found in this subtree.
        """
        path.append(node_numb)
        if node_numb == x:
            return True

        # A child id of -1 marks "no child" on that side.
        for child in (self.children_left[node_numb], self.children_right[node_numb]):
            # Stop at the first hit: the other subtree cannot contain ``x``.
            if child != -1 and self.find_path(child, path, x):
                return True

        # Every descendant already removed itself, so this node is last.
        path.pop()
        return False

    def get_rule(self, path):
        """Translate a root-to-leaf path into one rule sentence per split.

        Args:
            path (sequence of int): Node ids ordered from the root to the
                leaf; the last entry is the leaf and produces no rule.

        Returns:
            list: One string per decision node on the path. One-hot features
            (names containing ``" - "``) read ``'<feature>' was [not]
            '<value>'``; all others state the threshold comparison.
        """
        sentences = []
        # Pair every decision node with its successor on the path.
        for node, next_node in zip(path[:-1], path[1:]):
            name = self.feature_names[self.feature[node]]
            # The left child is the "<= threshold" branch.
            went_left = self.children_left[node] == next_node

            if " - " in name:
                # Split on the first separator only so that values which
                # themselves contain " - " do not break the unpacking.
                feature_name, feature_value = name.split(" - ", 1)
                if went_left:
                    text = f"'{feature_name}' was not '{feature_value}'"
                else:
                    text = f"'{feature_name}' was '{feature_value}'"
            elif went_left:
                text = f"'{name}' was less or equal than {self.threshold[node]:.2f}"
            else:
                text = f"'{name}' was greater than {self.threshold[node]:.2f}"
            sentences.append(text)
        return sentences