Package rdkit :: Package ML :: Package DecTree :: Module Forest
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.DecTree.Forest

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum 
  3  # 
  4  """ code for dealing with forests (collections) of decision trees 
  5   
  6  **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running. 
  7   
  8  """ 
  9  from __future__ import print_function 
 10   
 11  import numpy 
 12   
 13  from rdkit.ML.DecTree import CrossValidate, PruneTree 
 14  from rdkit.six.moves import cPickle 
 15   
 16   
class Forest(object):
    """a forest of unique decision trees.

    adding an existing tree just results in its count field being incremented
    and its error being accumulated.

    typical usage:

      1) grow the forest with AddTree (or Grow) until happy with it

      2) call AverageErrors to calculate the average error values

      3) call SortTrees to put things in order by either error or count

    The forest is stored as three parallel lists (treeList, errList and
    countList); entry i of each list describes the same tree.

    """

    def __init__(self):
        self.treeList = []
        self.errList = []
        self.countList = []
        # per-tree results of the most recent ClassifyExample call
        self.treeVotes = []

    def MakeHistogram(self):
        """ creates a histogram of error/count pairs

        Trees are walked in their current order.  A new bin is opened whenever
        a tree's error exceeds the error that opened the current bin by more
        than 0.001; otherwise the tree's count is added to the current bin.
        Only increases in error open a bin, so the result is most useful once
        the forest has been ordered by error.

        **Returns**

          a list of (error, count) 2-tuples, one per bin.  The list is empty
          for an empty forest.

        """
        histo = []
        if not self.errList:
            return histo

        eps = 0.001
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr = err
                countHere = count
            else:
                countHere = countHere + count
        # the bin still being accumulated when the data runs out has to be
        # flushed as well, otherwise the last group of trees is lost
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """ collects votes across every member of the forest for the given example

        **Returns**

          a list of the results, one per tree, in treeList order

        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """ classifies the given example using the entire forest

        This is treated as a voting problem: every tree votes for the result
        it predicts, weighted by the tree's count.  The confidence measure is
        the fraction of the (weighted) votes which went to the winning result.

        The individual votes are kept and can be retrieved afterwards with
        GetVoteDetails.

        **Returns**

          a 2-tuple: the winning result (as returned by numpy.argmax) and the
          confidence in it.

        """
        self.treeVotes = self.CollectVotes(example)

        # The tally needs one slot per result index that was actually voted
        # for.  The length of the list handed to Grow is kept as a lower bound
        # (it may be absent if the forest was only built with AddTree).
        nSlots = len(getattr(self, '_nPossible', ()))
        if self.treeVotes:
            nSlots = max(nSlots, max(self.treeVotes) + 1)

        votes = [0] * nSlots
        for res, count in zip(self.treeVotes, self.countList):
            votes[res] = votes[res] + count

        totVotes = sum(votes)
        res = numpy.argmax(votes)
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """ Returns the details of the last vote the forest conducted

        this will be an empty list if no voting has yet been done

        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0, lessGreedy=0):
        """ Grows the forest by adding trees

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each variable
            (as well as the result) can take on

          - nTries: the number of new trees to add

          - pruneIt: a toggle for whether or not the tree should be pruned

          - lessGreedy: toggles the use of a less greedy construction algorithm where
            each possible tree root is used.  The best tree from each step is actually
            added to the forest.

        **Notes**

          - each tree comes from CrossValidate.CrossValidationDriver; the second
            value it returns is stored as the tree's error (replaced by the value
            returned by PruneTree.PruneTree when pruning is on).

          - progress is printed roughly every tenth of the way through.

        """
        self._nPossible = nPossibleVals
        # never let the reporting stride drop to zero (nTries < 10)
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                             silent=1, calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                print('prune: ', frac, frac2)
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """ Writes this forest off to a file so that it can be easily loaded later

        **Arguments**

          fileName is the name of the file to be written

        """
        # the context manager closes the file even if the dump fails
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile, 1)

    def AddTree(self, tree, error):
        """ Adds a tree to the forest

        If an identical tree is already present, its count is incremented

        **Arguments**

          - tree: the new tree

          - error: its error value

        **NOTE:** the errList is run as an accumulator,
          you probably want to call AverageErrors after finishing the forest

        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            # not seen before: start a new entry
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] = self.errList[idx] + error
            self.countList[idx] = self.countList[idx] + 1

    def AverageErrors(self):
        """ convert summed error to average error

        This does the conversion in place: every entry of errList is divided
        by the matching entry of countList.

        """
        self.errList = [x / y for x, y in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """ sorts the list of trees

        **Arguments**

          sortOnError: toggles sorting on the trees' errors rather than their counts

        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # the three lists are rebuilt as plain python lists (rather than being
        # indexed as arrays) because, at the time this code was written,
        # Numeric arrays didn't unpickle so well...
        self.treeList = [self.treeList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]

    def GetTree(self, i):
        """ returns tree i """
        return self.treeList[i]

    def SetTree(self, i, val):
        """ replaces tree i with val """
        self.treeList[i] = val

    def GetCount(self, i):
        """ returns the count of tree i """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count of tree i to val """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error of tree i """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error of tree i to val """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular tree in the forest

        **Arguments**

          i: an integer indicating which tree should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error

        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

        **Arguments**

          - i: an integer indicating which tree should be modified

          - tup: a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error

        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

        **Returns**

          a 3-tuple consisting of:

            1) our list of trees

            2) our list of tree counts

            3) our list of tree errors

          these are the forest's own lists, not copies.

        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """ allows len(forest) to work

        """
        return len(self.treeList)

    def __getitem__(self, which):
        """ allows forest[i] to work.  return the data tuple

        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ allows the forest to show itself as a string

        """
        pieces = ['Forest\n']
        for i, (count, err) in enumerate(zip(self.countList, self.errList)):
            pieces.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                          (i, count, 100. * err))
        return ''.join(pieces)
if __name__ == '__main__':
    # minimal smoke test: register the same node twice, average the
    # accumulated errors, sort the forest and print its summary
    from rdkit.ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)