1
2
3
""" Contains functionality for doing tree pruning

"""
7 from __future__ import print_function
8
9 import copy
10
11 import numpy
12
13 from rdkit.ML.DecTree import CrossValidate, DecTree
14 from rdkit.six.moves import range
15
# Module-wide debugging switch: when truthy, the recursive pruner prints a
# trace of every step to stdout.  It is switched on by the ``__main__`` guard
# at the bottom of the file.
_verbose = 0
17
18
def MaxCount(examples):
    """ given a set of examples, returns the most common result code

    The result code of an example is its last element (``example[-1]``).
    Only the codes ``0 .. max(code)`` are tallied; when several codes are
    equally common the smallest one wins, because ``numpy.argmax`` reports
    the first maximum.

    **Arguments**

       examples: a list of examples to be counted

    **Returns**

       the most common result code (the index returned by ``numpy.argmax``)

    **Raises**

       ValueError: if _examples_ is empty

    """
    resList = [x[-1] for x in examples]
    if not resList:
        # max() would raise ValueError here anyway; say why explicitly.
        raise ValueError('MaxCount() requires at least one example')

    maxVal = max(resList)
    # counts[i] is the number of examples whose result code equals i
    counts = [resList.count(i) for i in range(maxVal + 1)]

    return numpy.argmax(counts)
38
39
def _GetLocalError(node):
    """ counts the examples stored on a node that the node misclassifies

    **Arguments**

       node: the tree node to evaluate.  It must provide GetExamples() and
         ClassifyExample(); the last element of each example is taken to be
         its true result code.

    **Returns**

       the number of examples for which node.ClassifyExample() disagrees
       with the example's own result code

    """
    # appendExamples=0 is passed on every call; presumably this keeps the
    # classification pass from storing the examples on the tree again.
    return sum(1 for example in node.GetExamples()
               if node.ClassifyExample(example, appendExamples=0) != example[-1])
48
49
def _Pruner(node, level=0):
    """Recursively finds and removes the nodes whose removals improve classification

    **Arguments**

      - node: the tree to be pruned.  The pruning data should already be contained
        within node (i.e. node.GetExamples() should return the pruning data)

      - level: (optional) the level of recursion, used only in _verbose printing


    **Returns**

       the pruned version of node


    **Notes**

      - This uses a greedy algorithm which basically does a DFS traversal of the tree,
        removing nodes whenever possible.

      - If removing a node does not affect the accuracy, it *will be* removed.  We
        favor smaller trees.

      - All replacements are applied to deep copies of _node_ (the working and
        best-so-far trees), never to _node_ directly.

    """
    # NOTE(review): the default level=0 is inferred from the call sites
    # (top-level callers pass no level, recursion passes level + 1) -- confirm.

    def _trace(*parts):
        # every verbose line shares the same indentation/level prefix
        print(' ' * level, '<%d> ' % level, *parts)

    if _verbose:
        _trace('>>> Pruner')

    bestTree = copy.deepcopy(node)
    # Large starting value: the first candidate tree is accepted as long as
    # its error count does not exceed it.
    bestErr = 1e6

    # Walk a copy of the child list, removing children whenever doing so
    # improves the local error or leaves it unchanged (bias toward simpler
    # trees).
    for i, child in enumerate(node.GetChildren()[:]):
        examples = child.GetExamples()
        if _verbose:
            _trace(' Child:', i, child.GetLabel())
            bestTree.Print()
            print()

        if not len(examples):
            if _verbose:
                _trace(' No Examples', len(examples))
            # FIX: children that receive no pruning examples are currently
            # left untouched; it is undecided what should happen to them.
            continue

        if _verbose:
            _trace(' Examples', len(examples))
        if child.GetTerminal():
            # nothing below a terminal node can be pruned
            if _verbose:
                _trace(' Terminal')
            continue

        if _verbose:
            _trace(' Nonterminal')

        workTree = copy.deepcopy(bestTree)

        # Step 1: recurse on the child (try removing things below it).
        newNode = _Pruner(child, level=level + 1)
        workTree.ReplaceChildIndex(i, newNode)
        tempErr = _GetLocalError(workTree)
        if tempErr <= bestErr:
            bestErr = tempErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _trace('>->->->->->')
                _trace('replacing:', i, child.GetLabel())
                child.Print()
                _trace('with:')
                newNode.Print()
                _trace('<-<-<-<-<-<')
        else:
            workTree.ReplaceChildIndex(i, child)

        # Step 2: try replacing the child entirely with a terminal node that
        # predicts the most common result code among the child's examples.
        childExamples = child.GetExamples()
        bestGuess = MaxCount(childExamples)
        newNode = DecTree.DecTreeNode(workTree, 'L:%d' % (bestGuess), label=bestGuess,
                                      isTerminal=1)
        newNode.SetExamples(childExamples)
        workTree.ReplaceChildIndex(i, newNode)
        if _verbose:
            _trace('ATTEMPT:')
            workTree.Print()
        newErr = _GetLocalError(workTree)
        if _verbose:
            _trace('---> ', newErr, bestErr)
        if newErr <= bestErr:
            bestErr = newErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _trace('PRUNING:')
                workTree.Print()
        else:
            if _verbose:
                _trace('FAIL')
            # the replacement made things worse: put the child back in
            workTree.ReplaceChildIndex(i, child)

    if _verbose:
        _trace('<<< out')
    return bestTree
161
162
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    Pruning can be done using just the set of testExamples (the validation set)
    or both the testExamples and the trainExamples by setting minimizeTestErrorOnly
    to 0.

    **Arguments**

      - tree: the initial tree to be pruned.  Its stored examples are cleared
        (tree.ClearExamples()) and then replaced by the pruning data.

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
        _trainExamples_ + _testExamples_) will be used to evaluate the error.

    **Returns**

      a 2-tuple containing:

         1) the best tree

         2) the best error (the one which corresponds to that tree)

    """
    if minimizeTestErrorOnly:
        testSet = testExamples
    else:
        testSet = trainExamples + testExamples

    # Remove any examples the tree is already holding on to.
    tree.ClearExamples()

    # Screen the pruning data through the tree with appendExamples=1 so that
    # the pruner finds the data it needs via GetExamples() on each node.  The
    # error reported by this pass is not needed.
    CrossValidate.CrossValidate(tree, testSet, appendExamples=1)

    # Prune.
    newTree = _Pruner(tree)

    # Recalculate the error for the pruned tree and record its failures on it.
    totErr, badEx = CrossValidate.CrossValidate(newTree, testSet)
    newTree.SetBadExamples(badEx)

    return newTree, totErr
218
219
220
221
222
def _testRandom():
    """ demo: grows a tree on random examples, prunes it and prints both errors

    Generates 200 random examples over 10 variables, builds a tree with
    CrossValidate.CrossValidationDriver, then prunes it with PruneTree using
    the tree's own training and test examples.  Both trees are printed and
    handed to Pickle() under the names 'orig.pkl' and 'prune.pkl'.

    """
    # NOTE(review): the ``def`` line of this helper was missing from the
    # reviewed listing; the name _testRandom is reconstructed -- confirm.
    from rdkit.ML.DecTree import randomtest

    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nVars=10, randScale=0.5,
                                                                  nExamples=200)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals)
    tree.Print()
    tree.Pickle('orig.pkl')
    print('original error is:', frac)

    print('----Pruning')
    newTree, frac2 = PruneTree(tree, tree.GetTrainingExamples(), tree.GetTestExamples())
    newTree.Print()
    print('pruned error is:', frac2)
    newTree.Pickle('prune.pkl')
239
240
def _testSpecific():
    """ demo: prunes a small ID3 tree against a holdout set with conflicting points

    The tree is trained on five 3-attribute points.  The holdout set is the
    training set plus two copies of [0, 1, 1, 0], which carry the same
    attribute values as the training point [0, 1, 1, 1] but a different
    result code.  Errors before and after pruning are printed.

    """
    # NOTE(review): the ``def`` line of this helper was missing from the
    # reviewed listing; the name _testSpecific is reconstructed -- confirm.
    from rdkit.ML.DecTree import ID3
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]

    tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
    tree.Print()
    err, _ = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, _ = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    print('best error of pruned tree:', frac2)
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)
265
266
267
268
def _testChain():
    """ demo: prunes an ID3 tree using the training data itself as holdout set

    Thirteen 4-attribute points are used both to train the tree and (as the
    very same list) to prune it.  The tree and its error are printed before
    and after pruning, followed by the examples the pruned tree still gets
    wrong.

    """
    # NOTE(review): the ``def`` line of this helper was missing from the
    # reviewed listing; it is taken to be the _testChain that the __main__
    # guard calls -- confirm.
    from rdkit.ML.DecTree import ID3
    oPts = [
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    tPts = oPts

    # one attribute per column except the last, which holds the result code
    nCols = len(oPts[0])
    tree = ID3.ID3Boot(oPts, attrs=range(nCols - 1), nPossibleVals=[2] * nCols)
    tree.Print()
    err, _ = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, _ = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    print('best error of pruned tree:', frac2)
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)
301
302
if __name__ == '__main__':
    # When run as a script, switch on the pruner's verbose trace and run the
    # chain demo.
    _verbose = 1

    _testChain()
307