1
2
3
4 """ handles doing cross validation with neural nets
5
6 This is, perhaps, a little misleading. For the purposes of this module,
7 cross validation == evaluating the accuracy of a net.
8
9 """
10 from __future__ import print_function
11 from rdkit.ML.Neural import Network, Trainers
12 from rdkit.ML.Data import SplitData
13 import math
14
15
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - net: a neural net (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing. The last
        element of each example is taken as its true result.

      - tolerance: the largest absolute difference between the true result and
        the net's output for which an example still counts as correctly
        classified

      - appendExamples: a toggle which is ignored, it's just here to maintain
        the same API as the decision tree code.

    **Returns**

      a 2-tuple consisting of:

        1) the error of the net, expressed as the fraction of the test
           examples which were misclassified (0.0 if there are no test
           examples)

        2) a list of misclassified examples

    **Note**
      At the moment, this is specific to nets with only one output
    """
    nTest = len(testExamples)
    if not nTest:
        # nothing to test against: report no error instead of dividing by zero
        return 0.0, []

    badExamples = [testEx for testEx in testExamples
                   if math.fabs(testEx[-1] - net.ClassifyExample(testEx)) > tolerance]
    return float(len(badExamples)) / nTest, badExamples
50
51
def CrossValidationDriver(examples, attrs=[], nPossibleVals=[], holdOutFrac=.3, silent=0,
                          tolerance=0.3, calcTotalError=0, hiddenSizes=None, **kwargs):
    """ Trains a net on part of the examples and evaluates its accuracy

    **Arguments**

      - examples: the full set of examples. The last element of each example
        is its result, everything before it is input to the net.

      - attrs: a list of attributes to consider in the tree building
        *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
        *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: the tolerance for convergence of the net. The same value is
        also used as the classification tolerance when the error is calculated.

      - calcTotalError: if this is true the entire data set is used to calculate
        accuracy of the net

      - hiddenSizes: a sequence containing the size(s) of the hidden layers in the network.
        if _hiddenSizes_ is None, one hidden layer containing the same number of nodes
        as the input layer will be used

      - kwargs: only _replacementSelection_ is used; if it is true the data are
        split using selection with replacement

    **Returns**

      a 2-tuple containing:

        1) the net

        2) the cross-validation error of the net (as a fraction)

    **Notes**

      - At the moment, this is specific to nets with only one output

      - The indices of the training examples are stored on the net as
        _net._trainIndices_

    """
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    if not silent:
        print('Training with %d examples' % len(trainExamples))

    # every element of an example except the last one feeds an input node
    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        netSize = [nInput, nInput, nOutput]
    else:
        # list() lets callers pass a tuple (or any other sequence) of sizes
        netSize = [nInput] + list(hiddenSizes) + [nOutput]
    net = Network.Network(netSize)
    trainer = Trainers.BackProp()
    trainer.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0, silent=silent)

    # report the size of the set which is actually evaluated, which is the
    # full data set when calcTotalError is set
    evalExamples = examples if calcTotalError else testExamples
    if not silent:
        print('Testing with %d examples' % len(evalExamples))
    xValError, _ = CrossValidate(net, evalExamples, tolerance)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError
127