forked from motefly/DeepGBM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_model_interpreter.py
111 lines (101 loc) · 3.97 KB
/
tree_model_interpreter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import lightgbm as lgb
import random
import math
def countSplitNodes(tree):
    """Count the internal (split) nodes of one dumped LightGBM tree.

    Args:
        tree: one entry of ``booster.dump_model()['tree_info']``. It must hold a
            ``'tree_structure'`` node dict in which split nodes carry
            ``'split_index'``, ``'left_child'`` and ``'right_child'``.

    Returns:
        Number of nodes containing ``'split_index'``. A single-leaf tree gives 0.
    """
    # Iterative DFS: a recursive walk hits Python's recursion limit
    # (~1000 frames) on very deep, unbalanced trees.
    pending = [tree['tree_structure']]
    n_splits = 0
    while pending:
        node = pending.pop()
        if 'split_index' in node:
            n_splits += 1
            pending.append(node['left_child'])
            pending.append(node['right_child'])
    return n_splits
def getItemByTree(tree, item='split_feature'):
    """Flatten one per-node attribute of a dumped tree into a numpy array.

    Array layout:
        * Slots ``0 .. split_nodes-1`` hold internal nodes, indexed by
          ``split_index``.
        * Slots ``split_nodes .. split_nodes+num_leaves-1`` hold leaves, at
          ``split_nodes + leaf_index``.

    Args:
        tree: object exposing ``raw`` and ``split_nodes``.
            ``raw`` is one entry of ``dump_model()['tree_info']``.
            ``split_nodes`` is the number of internal nodes of that tree
            (e.g. a ``TreeInterpreter``).
        item: the attribute to extract.
            * contains 'child' (e.g. 'left_child' / 'right_child'): the array
              slot of that child for internal nodes, and -1 on leaf slots.
            * contains 'value' (e.g. 'value'): ``internal_<item>`` for
              internal nodes, and ``leaf_<item>`` for leaves.
            * anything else (default 'split_feature'; also 'threshold',
              'split_gain'): ``node[item]`` for internal nodes, and -2 on
              leaf slots.

    Returns:
        Array of length ``split_nodes + num_leaves``. It is float64 when
        ``item`` contains 'value', 'threshold' or 'split_gain', and int32
        otherwise.
    """
    split_nodes = tree.split_nodes
    res = np.zeros(split_nodes + tree.raw['num_leaves'], dtype=np.int32)
    if 'value' in item or 'threshold' in item or 'split_gain' in item:
        res = res.astype(np.float64)
    # NOTE(review): LightGBM dumps the 'threshold' of categorical splits as a
    # string (e.g. '1||3'); storing that into a float array would fail.
    # Confirm that categorical features are not used with this helper.

    def position(node):
        # Array slot of a node. A single-leaf tree is dumped as just
        # {'leaf_value': ...} without 'leaf_index', so that index defaults to 0.
        if 'split_index' in node:
            return node['split_index']
        return node.get('leaf_index', 0) + split_nodes

    # Explicit stack instead of recursion, so deep trees cannot exceed the
    # interpreter's recursion limit. Each node writes only its own slot,
    # so the visiting order does not matter.
    pending = [tree.raw['tree_structure']]
    while pending:
        node = pending.pop()
        is_split = 'split_index' in node
        slot = position(node)
        if 'child' in item:
            res[slot] = position(node[item]) if is_split else -1
        elif 'value' in item:
            if is_split:
                res[slot] = node['internal_' + item]
            else:
                res[slot] = node['leaf_' + item]
        else:
            res[slot] = node[item] if is_split else -2
        if 'left_child' in node:
            pending.append(node['left_child'])
        if 'right_child' in node:
            pending.append(node['right_child'])
    return res
def getTreeSplits(model):
    """Interpret every tree of a dumped model.

    Args:
        model: the dict returned by ``booster.dump_model()``. Only its
            ``'tree_info'`` list is read.

    Returns:
        A ``(trees, featurelist, threhlist)`` tuple of parallel lists:
            * ``trees``: one ``TreeInterpreter`` per dumped tree.
            * ``featurelist``: each tree's ``feature`` array.
            * ``threhlist``: each tree's ``'threshold'`` array from
              ``getItemByTree``.
    """
    trees, featurelist, threhlist = [], [], []
    for raw_tree in model['tree_info']:
        interpreted = TreeInterpreter(raw_tree)
        trees.append(interpreted)
        featurelist.append(interpreted.feature)
        threhlist.append(getItemByTree(interpreted, 'threshold'))
    return trees, featurelist, threhlist
def getChildren(trees):
    """Collect left- and right-child slot arrays for each interpreted tree.

    Args:
        trees: iterable of objects accepted by ``getItemByTree``
            (e.g. ``TreeInterpreter`` instances).

    Returns:
        ``(listcl, listcr)``: lists holding, per tree, the
        ``getItemByTree(tree, 'left_child')`` and
        ``getItemByTree(tree, 'right_child')`` arrays.
    """
    left_arrays = []
    right_arrays = []
    for interpreted in trees:
        left_arrays.append(getItemByTree(interpreted, 'left_child'))
        right_arrays.append(getItemByTree(interpreted, 'right_child'))
    return (left_arrays, right_arrays)
class TreeInterpreter(object):
    """Array view over one entry of ``dump_model()['tree_info']``.

    Attributes:
        raw: the dumped tree dict, stored unchanged.
        split_nodes: number of internal (split) nodes.
        node_count: same value as ``split_nodes``; leaves are not added.
        value: ``getItemByTree(self, item='value')``.
        feature: ``getItemByTree(self)``, i.e. the 'split_feature' array.
        gain: ``getItemByTree(self, 'split_gain')``.
    """

    def __init__(self, tree):
        self.raw = tree
        internal_count = countSplitNodes(tree)
        self.split_nodes = internal_count
        # Only internal nodes are counted here; num_leaves is not included.
        self.node_count = internal_count
        # split_nodes must be set first: getItemByTree reads it from self.
        self.value = getItemByTree(self, item='value')
        self.feature = getItemByTree(self)
        self.gain = getItemByTree(self, 'split_gain')
class ModelInterpreter(object):
    """Decompose a trained booster into per-tree node arrays.

    Attributes set by ``__init__``:
        tree_model: the ``tree_model`` label passed in. It is only stored.
        n_features_: ``max_feature_idx + 1`` from the dumped model.
        trees, featurelist, threshlist: the result of ``getTreeSplits`` on
            the dump.
        listcl, listcr: the result of ``getChildren`` on ``trees``.
    """

    def __init__(self, model, tree_model='lightgbm'):
        """Dump ``model`` and interpret all of its trees.

        Args:
            model: object exposing ``dump_model()``, which must return a dict
                with ``'max_feature_idx'`` and ``'tree_info'`` keys
                (e.g. ``lightgbm.Booster``).
            tree_model: label stored on the instance.
        """
        print("Model Interpreting...")
        self.tree_model = tree_model
        model = model.dump_model()
        self.n_features_ = model['max_feature_idx'] + 1
        self.trees, self.featurelist, self.threshlist = getTreeSplits(model)
        self.listcl, self.listcr = getChildren(self.trees)

    def GetTreeSplits(self):
        """Return ``(trees, featurelist, threshlist)``."""
        return (self.trees, self.featurelist, self.threshlist)

    def GetChildren(self):
        """Return ``(listcl, listcr)``, the per-tree child-slot arrays."""
        return (self.listcl, self.listcr)

    def EqualGroup(self, n_clusters, args):
        """Randomly split the trees into ``n_clusters`` near-equal groups.

        Trees are shuffled with the ``random`` module. The first
        ``n_trees % n_clusters`` groups receive one extra tree, so that every
        tree is assigned. The group sizes are printed.

        Args:
            n_clusters: number of groups. Must be >= 1.
            args: unused; kept for call compatibility.

        Returns:
            float64 array of length ``len(self.featurelist)`` holding each
            tree's group id.

        Raises:
            ValueError: if ``n_clusters < 1``.
        """
        if n_clusters < 1:
            raise ValueError("n_clusters must be >= 1, got %r" % (n_clusters,))
        n_trees = len(self.featurelist)
        # random.sample() requires a sequence: dict key views were deprecated in
        # Python 3.9 and rejected with TypeError from 3.11. range(n_trees) holds
        # the same keys in the same order, so a given seed yields the same shuffle.
        keys = np.asarray(random.sample(range(n_trees), n_trees), dtype=np.intp)
        trees_per_cluster = n_trees // n_clusters
        mod_per_cluster = n_trees % n_clusters
        sizes = [trees_per_cluster + (idx < mod_per_cluster) for idx in range(n_clusters)]
        clusterIdx = np.zeros(n_trees)
        # The shuffled trees are filled group by group: the first sizes[0]
        # trees get id 0, the next sizes[1] get id 1, and so on.
        clusterIdx[keys] = np.repeat(np.arange(n_clusters), sizes)
        print([np.where(clusterIdx == i)[0].shape for i in range(n_clusters)])
        return clusterIdx