Package rdkit :: Package ML :: Package Neural :: Module CrossValidate
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.Neural.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with neural nets 
  5   
  6  This is, perhaps, a little misleading.  For the purposes of this module, 
  7  cross validation == evaluating the accuracy of a net. 
  8   
  9  """ 
 10  from __future__ import print_function 
 11  from rdkit.ML.Neural import Network, Trainers 
 12  from rdkit.ML.Data import SplitData 
 13  import math 
 14   
 15   
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
  """ Determines the classification error of a net on a set of examples

    **Arguments**

      - net: a neural net (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a sequence of examples to be used for testing.  The last
        entry of each example is taken to be the true (numeric) result; the
        whole example (including that entry) is handed to _net.ClassifyExample()_

      - tolerance: an example counts as misclassified when the absolute
        difference between the net's prediction and the true result is
        strictly greater than this value

      - appendExamples: a toggle which is ignored, it's just here to maintain
        the same API as the decision tree code.

    **Returns**

      a 2-tuple consisting of:

        1) the fraction (0.0 - 1.0, not a percentage) of examples that were
           misclassified; 0.0 if _testExamples_ is empty

        2) a list of the misclassified examples, in input order

    **Note**
      At the moment, this is specific to nets with only one output

  """
  badExamples = []
  for testEx in testExamples:
    trueRes = testEx[-1]
    res = net.ClassifyExample(testEx)
    if math.fabs(trueRes - res) > tolerance:
      badExamples.append(testEx)

  nTest = len(testExamples)
  if not nTest:
    # nothing to evaluate (e.g. an empty hold-out set): report no error
    # instead of raising ZeroDivisionError
    return 0.0, badExamples
  return float(len(badExamples)) / nTest, badExamples
50 51
def CrossValidationDriver(examples, attrs=None, nPossibleVals=None, holdOutFrac=.3, silent=0,
                          tolerance=0.3, calcTotalError=0, hiddenSizes=None, **kwargs):
  """ Trains a net on part of a data set and measures its error on a hold-out set

    **Arguments**

      - examples: the full set of examples.  The last entry of each example is
        the result; the remaining entries determine the size of the input layer.

      - attrs: a list of attributes to consider in building the model
        *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
        *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: the tolerance for convergence of the net.  The same value is
        also used as the classification tolerance when computing the error
        (see _CrossValidate()_).

      - calcTotalError: if this is true the entire data set is used to calculate
        accuracy of the net

      - hiddenSizes: a sequence (list or tuple) containing the size(s) of the hidden
        layers in the network.  If _hiddenSizes_ is None, one hidden layer containing
        the same number of nodes as the input layer will be used

      - replacementSelection (keyword argument): if true, the train/hold-out split
        is requested from _SplitData.SplitIndices()_ with replacement=1, legacy=0;
        otherwise with replacement=0, legacy=1

    **Returns**

      a 2-tuple containing:

        1) the net (the training indices are stored on it as _net._trainIndices_)

        2) the cross-validation error of the net, as a fraction (0.0 - 1.0)

    **Note**
      At the moment, this is specific to nets with only one output

  """
  # attrs and nPossibleVals are accepted only for API compatibility with the
  # decision tree code; they default to None to avoid shared mutable defaults.
  nTot = len(examples)
  if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
  else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  if not silent:
    print('Training with %d examples' % (len(trainExamples)))

  # the last entry of each example is the result, not an input
  nInput = len(examples[0]) - 1
  nOutput = 1
  if hiddenSizes is None:
    netSize = [nInput, nInput, nOutput]
  else:
    # list() so that tuples (or other sequences) of layer sizes work too
    netSize = [nInput] + list(hiddenSizes) + [nOutput]
  net = Network.Network(netSize)
  trainer = Trainers.BackProp()
  trainer.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0, silent=silent)

  if not silent:
    print('Testing with %d examples' % len(testExamples))
  if not calcTotalError:
    xValError, _ = CrossValidate(net, testExamples, tolerance)
  else:
    xValError, _ = CrossValidate(net, examples, tolerance)
  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  net._trainIndices = trainIndices
  return net, xValError
127