Package rdkit :: Package ML :: Package DecTree :: Module PruneTree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.PruneTree

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Contains functionality for doing tree pruning 
  5   
  6  """ 
  7  from __future__ import print_function 
  8   
  9  import copy 
 10   
 11  import numpy 
 12   
 13  from rdkit.ML.DecTree import CrossValidate, DecTree 
 14  from rdkit.six.moves import range 
 15   
# Module-level debug switch. When truthy, _Pruner() prints a trace of every
# pruning decision it considers; the __main__ block below sets it to 1.
_verbose = 0
 17   
 18   
def MaxCount(examples):
    """ given a set of examples, returns the most common result code

    The result code of an example is its last element (``example[-1]``).

    **Arguments**

      - examples: a non-empty sequence of examples to be counted; the last
        element of each example must be an integer result code

    **Returns**

      the most common result code, as returned by numpy.argmax().  When
      several codes are equally common the smallest one is returned.

    **Raises**

      ValueError if _examples_ is empty (or contains no non-negative
      result code at all)

    **Notes**

      - Negative result codes are never counted; only codes in the range
        0..max(code) compete.

    """
    resList = [x[-1] for x in examples]
    if not resList:
        # max() on an empty list would raise a far less helpful ValueError.
        raise ValueError('MaxCount() requires at least one example')

    # Tally every code in a single pass instead of rescanning the whole
    # list once per candidate code.
    counts = [0] * (max(resList) + 1)
    for res in resList:
        # A negative index would silently wrap around to the end of the
        # list, so negative codes are skipped rather than counted.
        if res >= 0:
            counts[res] += 1

    # argmax returns the first maximum, i.e. the smallest code on a tie.
    return numpy.argmax(counts)
38 39
def _GetLocalError(node):
    """ counts the examples stored at a node which the node misclassifies

    **Arguments**

      - node: the tree node to be evaluated; its stored examples
        (node.GetExamples()) are classified with node.ClassifyExample()

    **Returns**

      the number of stored examples whose prediction differs from the
      example's last element

    """
    # appendExamples=0 is passed on every call so that scoring does not ask
    # the node to store the examples again.
    return sum(1 for example in node.GetExamples()
               if node.ClassifyExample(example, appendExamples=0) != example[-1])
48 49
def _Pruner(node, level=0):
    """Recursively finds and removes the nodes whose removals improve classification

    **Arguments**

      - node: the tree to be pruned.  The pruning data should already be contained
        within node (i.e. node.GetExamples() should return the pruning data)

      - level: (optional) the level of recursion, used only in _verbose printing


    **Returns**

      the pruned version of node.  This is always a deep copy; every
      replacement made here is applied to a copy, never to _node_ itself.


    **Notes**

      - This uses a greedy algorithm which basically does a DFS traversal of the tree,
        removing nodes whenever possible.

      - If removing a node does not affect the accuracy, it *will be* removed.  We
        favor smaller trees.

      - Terminal children and children holding no examples are left untouched.

    """
    if _verbose:
        print(' ' * level, '<%d> ' % level, '>>> Pruner')
    # Shallow copy of the child list so the loop iterates over the original
    # children even while copies of the tree are being modified.
    children = node.GetChildren()[:]

    bestTree = copy.deepcopy(node)
    # Sentinel "worst" error: any error count up to 1e6 is accepted by the
    # first <= comparison below.
    bestErr = 1e6
    #
    # Loop over the children of this node, removing them when doing so
    # either improves the local error or leaves it unchanged (we're
    # introducing a bias for simpler trees).
    #
    for i in range(len(children)):
        child = children[i]
        examples = child.GetExamples()
        if _verbose:
            print(' ' * level, '<%d> ' % level, ' Child:', i, child.GetLabel())
            bestTree.Print()
            print()
        if len(examples):
            if _verbose:
                print(' ' * level, '<%d> ' % level, ' Examples', len(examples))
            if child.GetTerminal():
                # Terminal children are never modified here.
                if _verbose:
                    print(' ' * level, '<%d> ' % level, ' Terminal')
                continue

            if _verbose:
                print(' ' * level, '<%d> ' % level, ' Nonterminal')

            # Candidate changes are tried on a scratch copy of the best tree
            # found so far; bestTree is only replaced when a candidate wins.
            workTree = copy.deepcopy(bestTree)
            #
            # First recurse on the child (try removing things below it)
            #
            newNode = _Pruner(child, level=level + 1)
            workTree.ReplaceChildIndex(i, newNode)
            tempErr = _GetLocalError(workTree)
            # <= (not <): a pruned subtree that is no worse is preferred.
            if tempErr <= bestErr:
                bestErr = tempErr
                bestTree = copy.deepcopy(workTree)
                if _verbose:
                    print(' ' * level, '<%d> ' % level, '>->->->->->')
                    print(' ' * level, '<%d> ' % level, 'replacing:', i, child.GetLabel())
                    child.Print()
                    print(' ' * level, '<%d> ' % level, 'with:')
                    newNode.Print()
                    print(' ' * level, '<%d> ' % level, '<-<-<-<-<-<')
            else:
                # The pruned subtree was worse: restore the original child.
                workTree.ReplaceChildIndex(i, child)
            #
            # Now try replacing the child entirely
            #
            # The replacement is a terminal node labelled with the most common
            # result code among the examples stored at the child.
            bestGuess = MaxCount(child.GetExamples())
            newNode = DecTree.DecTreeNode(workTree, 'L:%d' % (bestGuess), label=bestGuess, isTerminal=1)
            newNode.SetExamples(child.GetExamples())
            workTree.ReplaceChildIndex(i, newNode)
            if _verbose:
                print(' ' * level, '<%d> ' % level, 'ATTEMPT:')
                workTree.Print()
            newErr = _GetLocalError(workTree)
            if _verbose:
                print(' ' * level, '<%d> ' % level, '---> ', newErr, bestErr)
            if newErr <= bestErr:
                bestErr = newErr
                bestTree = copy.deepcopy(workTree)
                if _verbose:
                    print(' ' * level, '<%d> ' % level, 'PRUNING:')
                    workTree.Print()
            else:
                if _verbose:
                    print(' ' * level, '<%d> ' % level, 'FAIL')
                # whoops... put the child back in:
                # (workTree is rebuilt from bestTree the next time a
                # non-terminal child is processed, so this only affects the
                # scratch copy.)
                workTree.ReplaceChildIndex(i, child)
        else:
            if _verbose:
                print(' ' * level, '<%d> ' % level, ' No Examples', len(examples))
            #
            # FIX: we need to figure out what to do here (nodes that contain
            # no examples in the testing set). I can concoct arguments for
            # leaving them in and for removing them. At the moment they are
            # left intact.
            #
            pass

    if _verbose:
        print(' ' * level, '<%d> ' % level, '<<< out')
    return bestTree
161 162
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    The pruning decisions can be scored against the held-out examples only
    (the default) or against the training and held-out examples together.

    **Arguments**

      - tree: the initial tree to be pruned.  Any examples stored in it are
        cleared and replaced by the pruning data.

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
        _trainExamples_ + _testExamples_) will be used to evaluate the error;
        otherwise only _testExamples_ are used.

    **Returns**

      a 2-tuple containing:

        1) the best tree

        2) the best error (the one which corresponds to that tree)

    """
    pruningSet = testExamples if minimizeTestErrorOnly else trainExamples + testExamples

    # Drop whatever examples the tree is still carrying around.
    tree.ClearExamples()

    # Screen the pruning data through the tree so that every node ends up
    # holding the examples that reach it.  The returned values are not used.
    _screenErr, _screenBad = CrossValidate.CrossValidate(tree, pruningSet, appendExamples=1)

    prunedTree = _Pruner(tree)

    # Re-evaluate the pruned tree and record its misclassified examples.
    prunedErr, misclassified = CrossValidate.CrossValidate(prunedTree, pruningSet)
    prunedTree.SetBadExamples(misclassified)

    return prunedTree, prunedErr
218 219 220 # ------- 221 # testing code 222 # -------
def _testRandom():
    """ demo: builds a tree from random data, prunes it and reports both errors

    The original tree is pickled to 'orig.pkl' and the pruned tree to
    'prune.pkl' in the current directory.

    """
    from rdkit.ML.DecTree import randomtest

    # Alternative settings that have also been used here:
    #   nVars=20, randScale=0.25, nExamples=200
    dataSet = randomtest.GenRandomExamples(nVars=10, randScale=0.5, nExamples=200)
    examples, attrs, nPossibleVals = dataSet

    origTree, origErr = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals)
    origTree.Print()
    origTree.Pickle('orig.pkl')
    print('original error is:', origErr)

    print('----Pruning')
    trainSet = origTree.GetTrainingExamples()
    holdoutSet = origTree.GetTestExamples()
    prunedTree, prunedErr = PruneTree(origTree, trainSet, holdoutSet)
    prunedTree.Print()
    print('pruned error is:', prunedErr)
    prunedTree.Pickle('prune.pkl')
239 240
def _testSpecific():
    """ demo: prunes an ID3 tree built from a small hand-written data set

    The hold-out set is the training set plus two extra points; the errors
    before and after pruning are printed.

    """
    from rdkit.ML.DecTree import ID3

    trainPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    extraPts = [[0, 1, 1, 0], [0, 1, 1, 0]]
    holdoutPts = trainPts + extraPts

    tree = ID3.ID3Boot(trainPts, attrs=range(3), nPossibleVals=[2] * 4)
    tree.Print()

    trainErr, _ = CrossValidate.CrossValidate(tree, trainPts)
    print('original error:', trainErr)

    holdoutErr, _ = CrossValidate.CrossValidate(tree, holdoutPts)
    print('original holdout error:', holdoutErr)

    prunedTree, prunedErr = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    print('best error of pruned tree:', prunedErr)

    holdoutErr, misclassified = CrossValidate.CrossValidate(prunedTree, holdoutPts)
    print('pruned holdout error is:', holdoutErr)
    print(misclassified)
265 266 # print(len(tree), len(newTree)) 267 268
def _testChain():
    """ demo: prunes an ID3 tree built from a small data set

    The same points are used for training and as the hold-out set; the
    errors before and after pruning are printed.

    """
    from rdkit.ML.DecTree import ID3

    # Seven separate copies of the same point, followed by six others.
    trainPts = [[1, 0, 0, 0, 1] for _ in range(7)]
    trainPts += [
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    # Deliberately the very same list object as the training points.
    holdoutPts = trainPts

    nCols = len(trainPts[0])
    tree = ID3.ID3Boot(trainPts, attrs=range(nCols - 1), nPossibleVals=[2] * nCols)
    tree.Print()

    trainErr, _ = CrossValidate.CrossValidate(tree, trainPts)
    print('original error:', trainErr)

    holdoutErr, _ = CrossValidate.CrossValidate(tree, holdoutPts)
    print('original holdout error:', holdoutErr)

    prunedTree, prunedErr = PruneTree(tree, trainPts, holdoutPts)
    prunedTree.Print()
    print('best error of pruned tree:', prunedErr)

    holdoutErr, misclassified = CrossValidate.CrossValidate(prunedTree, holdoutPts)
    print('pruned holdout error is:', holdoutErr)
    print(misclassified)
# Script entry point: run the chain demo with tracing switched on.
if __name__ == '__main__':  # pragma: nocover
    # Rebinds the module-level flag that _Pruner() checks before printing.
    _verbose = 1
    # _testRandom()
    _testChain()