""" Code for working with forests (collections) of decision trees.

**NOTE** This code should be obsolete now that ML.Composite.Composite is up and running.

"""
9 from __future__ import print_function
10
11 import numpy
12
13 from rdkit.ML.DecTree import CrossValidate, PruneTree
14 from rdkit.six.moves import cPickle
15
16
class Forest(object):
    """A forest of unique decision trees.

    Adding a tree that is already present (as judged by ``==``) does not store
    a second copy: that tree's count is incremented and the new error is added
    to its running error total.

    Typical usage:

      1) grow the forest with AddTree (or Grow) until happy with it

      2) call AverageErrors to turn the accumulated errors into averages

      3) call SortTrees to put things in order by either error or count

    """
    # NOTE(review): the def/class lines were missing from the reviewed copy;
    # method names were restored from call sites and docstrings -- confirm
    # against external callers.

    def MakeHistogram(self, eps=0.001):
        """Creates a histogram of (error, count) pairs.

        Trees are binned by scanning errList in its current order.  A new bin
        starts whenever a tree's error exceeds the first error of the current
        bin by more than ``eps``.  The forest should therefore be sorted by
        ascending error first (e.g. ``SortTrees(sortOnError=1)``); otherwise
        decreasing errors are simply merged into the current bin.

        **Arguments**

          - eps: tolerance deciding whether an error starts a new bin

        **Returns**

          a list of (error, total count) tuples, one per bin, where error is
          the first error seen in that bin.  An empty forest gives [].

        """
        if not self.treeList:
            return []
        histo = []
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr = err
                countHere = count
            else:
                countHere += count
        # close out the last bin, which the loop itself never emits
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """Collects votes across every member of the forest for the given example.

        **Arguments**

          - example: passed unchanged to each tree's ClassifyExample method

        **Returns**

          a list with one result per tree, in the same order as treeList

        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """Classifies the given example using the entire forest.

        This is treated as a voting problem: every tree votes for the result
        it returns, weighted by its count, and the result with the largest
        total weight wins (ties go to the smallest result, as numpy.argmax
        returns the first maximum).  The per-tree votes are stored and can be
        retrieved with GetVoteDetails.

        **Arguments**

          - example: the example to classify

        **Returns**

          a 2-tuple (result, confidence), where confidence is the fraction of
          the total count-weighted vote received by the winning result
          (0.0 if the total weight is zero).

        **Raises**

          ValueError if the forest contains no trees

        **FIX:** the confidence is only a vote fraction; no statistically
        grounded confidence interval is computed.
        """
        self.treeVotes = self.CollectVotes(example)
        if not self.treeVotes:
            raise ValueError('cannot classify an example with an empty forest')
        # Size the tally from the largest result actually returned (results
        # are assumed to be non-negative integer indices) rather than from
        # self._nPossible, which only Grow sets.  int() keeps numpy integer
        # results from turning the list repetition into array arithmetic.
        votes = [0] * (int(max(self.treeVotes)) + 1)
        for res, count in zip(self.treeVotes, self.countList):
            votes[res] += count

        totVotes = sum(votes)
        res = numpy.argmax(votes)
        if not totVotes:
            return res, 0.0
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """Returns the per-tree results from the last vote the forest conducted.

        This is an empty list if no voting has been done yet.

        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0, lessGreedy=0):
        """Grows the forest by adding trees.

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each variable
            (as well as the result) can take on

          - nTries: the number of trees to grow; a tree identical to one already
            in the forest only increments that tree's count

          - pruneIt: a toggle for whether or not each tree should be pruned

          - lessGreedy: passed through to CrossValidate.CrossValidationDriver;
            toggles the use of a less greedy construction algorithm where each
            possible tree root is used and the best resulting tree is kept

        """
        self._nPossible = nPossibleVals
        # Report progress roughly every tenth of the run.  Integer division
        # plus max() keeps the interval a positive int, so nTries < 10 cannot
        # cause a modulo-by-zero.
        reportEvery = max(1, nTries // 10)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals, silent=1,
                                                             calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(), minimizeTestErrorOnly=0)
                print('prune: ', frac, frac2)
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """Writes this forest to a file so that it can easily be loaded later.

        **Arguments**

          - fileName: the name of the file to be written

        """
        # the with-block closes the file even if pickling fails
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile, 1)

    def AddTree(self, tree, error):
        """Adds a tree to the forest.

        If an identical tree (compared with ``==``) is already present, its
        count is incremented and ``error`` is added to its error total instead.

        **Arguments**

          - tree: the new tree

          - error: its error value

        **NOTE:** errList is run as an accumulator; you probably want to call
        AverageErrors after finishing the forest.

        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] += error
            self.countList[idx] += 1

    def AverageErrors(self):
        """Converts the summed errors into average errors, in place.

        Each entry of errList is divided by the matching entry of countList.
        """
        # float() avoids truncating integer division under Python 2
        self.errList = [float(err) / count for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """Sorts the trees (with their counts and errors) in ascending order.

        **Arguments**

          - sortOnError: if true, sort on the trees' errors; otherwise sort on
            their counts

        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # rebuild plain Python lists (rather than storing numpy arrays)
        self.treeList = [self.treeList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]

    def GetTree(self, i):
        """Returns tree ``i``."""
        return self.treeList[i]

    def SetTree(self, i, val):
        """Replaces tree ``i`` with ``val``."""
        self.treeList[i] = val

    def GetCount(self, i):
        """Returns the count of tree ``i``."""
        return self.countList[i]

    def SetCount(self, i, val):
        """Sets the count of tree ``i`` to ``val``."""
        self.countList[i] = val

    def GetError(self, i):
        """Returns the error of tree ``i``."""
        return self.errList[i]

    def SetError(self, i, val):
        """Sets the error of tree ``i`` to ``val``."""
        self.errList[i] = val

    def GetDataTuple(self, i):
        """Returns all relevant data about a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """Sets all relevant data for a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be updated

          - tup: a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """Returns everything we know.

        **Returns**

          a 3-tuple consisting of:

            1) our list of trees

            2) our list of tree counts

            3) our list of tree errors

        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """Allows len(forest) to work: the number of distinct trees."""
        return len(self.treeList)

    def __getitem__(self, which):
        """Allows forest[i] to work; returns the (tree, count, error) tuple."""
        return self.GetDataTuple(which)

    def __str__(self):
        """Allows the forest to show itself as a string, one line per tree."""
        lines = ['Forest\n']
        for i, (count, err) in enumerate(zip(self.countList, self.errList)):
            lines.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                         (i, count, 100. * err))
        return ''.join(lines)

    def __init__(self):
        self.treeList = []  # distinct trees
        self.errList = []  # summed (or, after AverageErrors, averaged) errors
        self.countList = []  # how many times each tree was added
        self.treeVotes = []  # per-tree results of the last ClassifyExample call
292
293
if __name__ == '__main__':
    # Smoke test: add the same single-node tree twice, then average, sort and
    # print the resulting forest.
    from rdkit.ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)
303