4 """ handles doing cross validation with decision trees
5
6 This is, perhaps, a little misleading. For the purposes of this module,
7 cross validation == evaluating the accuracy of a tree.
8
9
10 """
from __future__ import print_function

import numpy

from rdkit.ML.Data import SplitData
from rdkit.ML.DecTree import ID3
from rdkit.ML.DecTree import randomtest


def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals, treeBuilder,
                      nQuantBounds=[], **kwargs):
22 """ loops through all possible tree roots and chooses the one which produces the best tree
23
24 **Arguments**
25
26 - examples: the full set of examples
27
28 - trainExamples: the training examples
29
30 - testExamples: the testing examples
31
32 - attrs: a list of attributes to consider in the tree building
33
34 - nPossibleVals: a list of the number of possible values each variable can adopt
35
36 - treeBuilder: the function to be used to actually build the tree
37
38 - nQuantBounds: an optional list. If present, it's assumed that the builder
39 algorithm takes this argument as well (for building QuantTrees)
40
41 **Returns**
42
43 The best tree found
44
45 **Notes**
46
47 1) Trees are built using _trainExamples_
48
49 2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
50 the entire set of data (i.e. all of _examples_)
51
52 3) _trainExamples_ is not used at all, which immediately raises the question of
53 why it's even being passed in
54
55 """
  attrs = attrs[:]
  if nQuantBounds:
    # variables whose quant-bound count is -1 are not tried as roots
    for i in range(len(nQuantBounds)):
      if nQuantBounds[i] == -1 and i in attrs:
        attrs.remove(i)
  nAttrs = len(attrs)
  trees = [None] * nAttrs
  errs = [0] * nAttrs
  # slot 0 is not filled by the loop below, so give it a large error
  # so that argmin() never selects it
  errs[0] = 1e6

  for i in range(1, nAttrs):
    argD = {'initialVar': attrs[i]}
    argD.update(kwargs)
    if nQuantBounds is None or nQuantBounds == []:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, **argD)
    else:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, **argD)
    if trees[i]:
      # score each candidate tree against the full data set
      errs[i], _ = CrossValidate(trees[i], examples, appendExamples=0)
    else:
      errs[i] = 1e6
  best = numpy.argmin(errs)

  return trees[best]
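# ChooseOptimalRoot is normally reached through CrossValidationDriver (with
# lessGreedy=1), but it can also be called directly. A rough sketch, using the
# same random data as TestRun() below (the 150/50 split here is illustrative):
#
#   examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
#   trainExamples, testExamples = examples[:150], examples[150:]
#   tree = ChooseOptimalRoot(examples, trainExamples, testExamples, attrs,
#                            nPossibleVals, ID3.ID3Boot)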


def CrossValidate(tree, testExamples, appendExamples=0):
  """ Determines the classification error for the testExamples

  **Arguments**

    - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

    - testExamples: a list of examples to be used for testing

    - appendExamples: a toggle which is passed along to the tree as it does
      the classification. The trees can use this to store the examples they
      classify locally.

  **Returns**

    a 2-tuple consisting of:

      1) the error rate of the tree (the fraction of testExamples misclassified)

      2) a list of the misclassified examples

  A brief usage sketch appears in a comment after this function.

  """
  nTest = len(testExamples)
  nBad = 0
  badExamples = []
  for i in range(nTest):
    testEx = testExamples[i]
    trueRes = testEx[-1]
    res = tree.ClassifyExample(testEx, appendExamples)
    # the example counts as misclassified if any part of the result disagrees
    # with the stored answer
    if (trueRes != res).any():
      badExamples.append(testEx)
      nBad += 1

  return float(nBad) / nTest, badExamples
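# A rough sketch of calling CrossValidate directly (tree and testExamples as
# produced elsewhere in this module; names are illustrative):
#
#   err, badExamples = CrossValidate(tree, testExamples, appendExamples=0)
#   print('%d of %d examples misclassified' % (len(badExamples), len(testExamples)))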


def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0, startAt=None,
                          nQuantBounds=[], maxDepth=-1, **kwargs):
  """ Driver function for building trees and doing cross validation

  **Arguments**

    - examples: the full set of examples

    - attrs: a list of attributes to consider in the tree building

    - nPossibleVals: a list of the number of possible values each variable can adopt

    - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
      (used to calculate the error)

    - silent: a toggle used to control how much visual noise this makes as it goes.

    - calcTotalError: a toggle used to indicate whether the classification error
      of the tree should be calculated using the entire data set (when true) or just
      the hold-out (test) set (when false)

    - treeBuilder: the function to call to build the tree

    - lessGreedy: toggles use of the less greedy tree growth algorithm (see
      _ChooseOptimalRoot_).

    - startAt: forces the tree to be rooted at this descriptor

    - nQuantBounds: an optional list. If present, it's assumed that the builder
      algorithm takes this argument as well (for building QuantTrees)

    - maxDepth: an optional integer. If present, it's assumed that the builder
      algorithm takes this argument as well

  **Returns**

    a 2-tuple containing:

      1) the tree

      2) the cross-validation error of the tree

  A brief usage sketch appears in a comment after this function.

  """
  nTot = len(examples)
  # split the data into training and hold-out (test) sets
  if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
  else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  nTrain = len(trainExamples)
  if not silent:
    print('Training with %d examples' % (nTrain))

  # build the tree from the training examples (or, in the lessGreedy case,
  # choose the best root using ChooseOptimalRoot)
  if not lessGreedy:
    if nQuantBounds is None or nQuantBounds == []:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, initialVar=startAt, maxDepth=maxDepth,
                         **kwargs)
    else:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, initialVar=startAt,
                         maxDepth=maxDepth, **kwargs)
  else:
    tree = ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                             treeBuilder, nQuantBounds, maxDepth=maxDepth, **kwargs)

  nTest = len(testExamples)
  if not silent:
    print('Testing with %d examples' % nTest)
  # evaluate on the hold-out set, or on the full data set if calcTotalError is set
  if not calcTotalError:
    xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
  else:
    xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  tree.SetBadExamples(badExamples)
  tree.SetTrainingExamples(trainExamples)
  tree.SetTestExamples(testExamples)
  tree._trainIndices = trainIndices
  return tree, xValError
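# A rough sketch of some of the driver's options (illustrative values; a
# QuantTree-style builder that accepts nQuantBounds can be supplied via
# treeBuilder, as described in the docstring above):
#
#   tree, err = CrossValidationDriver(examples, attrs, nPossibleVals,
#                                     holdOutFrac=0.5, calcTotalError=1, silent=1)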


def TestRun():
  """ testing code """
  examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
  tree, _ = CrossValidationDriver(examples, attrs, nPossibleVals)

  tree.Pickle('save.pkl')

  import copy
  t2 = copy.deepcopy(tree)
  print('t1 == t2', tree == t2)
  l = [tree]
  print('t2 in [tree]', t2 in l, l.index(t2))


if __name__ == '__main__':
  TestRun()