I'm working with a decision tree classifier in sklearn, and it works well: I can visualize the tree and predict my class. But I'd like to create a column (in my pandas dataframe) that contains the path followed through the tree to reach the result. I mean, I'd like a concatenation of all the rules that lead to my result, like: White=False,Black=False,Weight=1,price=5. Do you have any idea how to do this, please?
            Asked
            
        
        
            Active
            
        
            Viewed 489 times
        
    0
            
            
        - 
                    Possible duplicate of [How to extract the decision rules from scikit-learn decision-tree?](https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree) – PV8 Jun 19 '19 at 11:05
- 
                    https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree – PV8 Jun 19 '19 at 11:05
- 
                    thanks, this link is helpful, but it's not exactly a duplicate because I don't want to get all rules of my decision tree, but i'd like to get the rules for one example. I mean a function which takes an index of my dataframe and returns the different rules of my decision tree to reach one leaf. – Corentin Moreau Jun 19 '19 at 11:25
1 Answer
1
            
            
        Based on the example here you can create your explanation of the applied rules.
- `estimator.decision_path` gives you the nodes which are followed to get to the result
- `is_leaves` is an array which stores, for each node, whether it is a leaf, i.e. terminal (`True`), or a branch/decision node (`False`)
- You can then iterate over `node_indicator` to get the nodes which were visited
- For each node you can get the `threshold` and the relevant `feature`
- Finally, apply the function to your dataframe and you are done.
def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
    if is_leaves is None:
        is_leaves = get_leaves(estimator)
    feature = estimator.tree_.feature
    threshold = estimator.tree_.threshold
    text = []
    node_indicator = estimator.decision_path([sample])
    node_index = node_indicator.indices[node_indicator.indptr[0]:
                                        node_indicator.indptr[1]]
    for node_id in node_index:
        if is_leaves[node_id]:
            break
        if sample[feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"
        text.append('{}: {} {} {}'.format(feature_names[feature[node_id]],
                                          sample[feature[node_id]],
                                          threshold_sign,
                                          round(threshold[node_id], precision)))
    return '; '.join(text)
def get_leaves(estimator):
    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True
    return is_leaves
Example
# Explain a single sample passed as a plain list; its values must be in the
# same order as the feature names, because both are indexed by feature position.
print(get_decision_path(estimator, 
                        ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 
                        [6.6, 3.0 , 4.4, 1.4]))
petal width (cm): 1.4 > 0.8; petal length (cm): 4.4 <= 4.95; petal width (cm): 1.4 <= 1.65
Full code
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from sklearn import tree
import pydotplus
from IPython.core.display import HTML, display
def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
    """Describe the split rules a fitted tree applies to a single sample.

    Parameters
    ----------
    estimator
        Fitted tree estimator exposing ``tree_`` (with ``feature`` and
        ``threshold`` arrays) and ``decision_path``.
    feature_names
        Sequence of feature names, indexed by feature position.
    sample
        Feature values of one sample, ordered like ``feature_names``.
        May be a list, a NumPy array or a pandas Series; values are always
        read by position, never by label.
    precision
        Number of decimals the threshold is rounded to in the output.
    is_leaves
        Optional boolean array flagging leaf nodes (see ``get_leaves``).
        Computed from ``estimator`` when omitted; pass it in when explaining
        many samples so it is not recomputed on every call.

    Returns
    -------
    str
        Rules of the form ``'name: value <sign> threshold'`` joined by
        ``'; '``, one per visited split node, in the order in which
        ``decision_path`` lists the nodes.
    """
    if is_leaves is None:
        is_leaves = get_leaves(estimator)
    feature = estimator.tree_.feature
    threshold = estimator.tree_.threshold

    # tree_.feature holds column *positions*. Plain ``series[int]`` is a label
    # lookup in pandas (wrong element or KeyError for integer labels, and a
    # deprecated positional fallback otherwise), so go through ``.iloc`` when
    # the sample offers it. Lists and arrays are already positional.
    values = getattr(sample, "iloc", sample)

    node_indicator = estimator.decision_path([sample])
    # Only one sample was passed, so its visited nodes are the first CSR row.
    node_index = node_indicator.indices[node_indicator.indptr[0]:
                                        node_indicator.indptr[1]]

    text = []
    for node_id in node_index:
        if is_leaves[node_id]:
            # A leaf carries no split rule; nothing more to report.
            break
        feature_id = feature[node_id]
        value = values[feature_id]
        threshold_sign = "<=" if value <= threshold[node_id] else ">"
        text.append('{}: {} {} {}'.format(feature_names[feature_id],
                                          value,
                                          threshold_sign,
                                          round(threshold[node_id], precision)))
    return '; '.join(text)
def get_leaves(estimator):
    """Return a boolean mask marking the leaf nodes of a fitted tree.

    Parameters
    ----------
    estimator
        Fitted tree estimator whose ``tree_`` exposes the ``children_left``
        and ``children_right`` arrays (one entry per node).

    Returns
    -------
    numpy.ndarray of bool
        ``mask[i]`` is True when node ``i`` is a leaf, i.e. when its left and
        right child ids are equal, and False when it is a split node.
    """
    tree_structure = estimator.tree_
    # A split node always has two distinct children, whereas a leaf stores the
    # same id for both. The per-node test therefore needs no traversal and can
    # be applied to all nodes at once.
    children_left = np.asarray(tree_structure.children_left)
    children_right = np.asarray(tree_structure.children_right)
    return children_left == children_right
# Prepare data: load iris into a DataFrame whose first four columns hold the
# features; the class label is appended as a fifth column named 'target'.
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
# X and y are positional slices of df, so they rely on 'target' being column 4.
X = df.iloc[:, 0:4].to_numpy()
y = df.iloc[:, 4].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Create the decision tree: limited to 5 leaves, fitted on the training split only.
estimator = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0)
estimator.fit(X_train, y_train)
# Visualize the decision tree: export_graphviz output -> pydotplus graph -> SVG,
# which is decoded to text and shown through IPython's HTML display.
dot_data = tree.export_graphviz(estimator, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
svg = graph.create_svg()
display(HTML(svg.decode('utf-8')))
# Add the explanation to the data frame. The leaf mask is computed once and
# passed in, otherwise get_decision_path would recompute it for every row.
is_leaves = get_leaves(estimator)
# Every row of df (training and test samples alike) gets an explanation;
# df.columns[0:4] / row[0:4] select the four feature names / feature values.
df['explanation'] = df.apply(lambda row: get_decision_path(estimator, df.columns[0:4], row[0:4], is_leaves=is_leaves), axis=1)
# Bare expression: presumably meant as the last line of a notebook cell, where
# the 5 sampled rows are displayed; in a plain script the result is discarded.
df.sample(5, axis=0, random_state=42)
 
    
    
        Maximilian Peters
        
- 30,348
- 12
- 86
- 99
 
    
