Package rdkit :: Package ML :: Package NaiveBayes :: Module CrossValidate
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.NaiveBayes.CrossValidate

 1  # $Id$ 
 2  # 
 3  #  Copyright (C) 2004-2005 Rational Discovery LLC. 
 4  #   All Rights Reserved 
 5  # 
 6  """ handles doing cross validation with naive bayes models 
 7  and evaluation of individual models 
 8   
 9  """ 
10  from __future__ import print_function 
11   
12  from rdkit.ML.Data import SplitData 
13  from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier 
14   
15  try: 
16    from rdkit.ML.FeatureSelect import CMIM 
17  except ImportError: 
18    CMIM = None 
19   
20   
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0, useSigs=False, ensemble=None, useCMIM=0,
                              **kwargs):
  """ builds and trains a NaiveBayesClassifier from a set of training examples

  Arguments:
    - trainExamples: the examples handed to the classifier for training
    - attrs: the attributes passed to the classifier (replaced by the
      ensemble, if there is one)
    - nPossibleValues, nQuantBounds: passed straight through to the
      NaiveBayesClassifier constructor
    - mEstimateVal, useSigs: passed through to the constructor as keywords
    - ensemble: if this is non-empty it is used in place of attrs
    - useCMIM: if this is > 0 (and CMIM could be imported, useSigs is set and
      no ensemble was provided) the ensemble is picked by calling
      CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)
    - kwargs: accepted so that callers can pass extra options; not used here

  Returns: the trained model

  """
  # feature selection only happens if the optional CMIM module is available
  # and the caller has not already provided an ensemble
  runFeatureSelection = CMIM is not None and useCMIM > 0 and useSigs and not ensemble
  if runFeatureSelection:
    ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

  modelAttrs = ensemble if ensemble else attrs
  classifier = NaiveBayesClassifier(modelAttrs, nPossibleValues, nQuantBounds,
                                    mEstimateVal=mEstimateVal, useSigs=useSigs)
  classifier.SetTrainingExamples(trainExamples)
  classifier.trainModel()
  return classifier
33 34
def CrossValidate(NBmodel, testExamples, appendExamples=0):
  """ determines the classification error of a model on a set of test examples

  Arguments:
    - NBmodel: the model to test; it must provide
      ClassifyExamples(examples, appendExamples), returning one prediction
      per example
    - testExamples: a non-empty sequence of examples, the last element of
      each example is used as its true result
    - appendExamples: passed unchanged to NBmodel.ClassifyExamples()

  Returns: a 2-tuple containing:
    1) the fraction of the test examples that were misclassified
    2) a list of the misclassified examples

  Raises: AssertionError if testExamples is empty or if the number of
    predictions does not match the number of test examples

  """
  nTest = len(testExamples)
  # The checks raise AssertionError explicitly instead of using assert
  # statements so that they survive "python -O"; without them an empty test
  # set would surface as a ZeroDivisionError in the error calculation below.
  if not nTest:
    raise AssertionError('no test examples: %s' % str(testExamples))

  preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
  if len(preds) != nTest:
    raise AssertionError('expected %d predictions, got %d' % (nTest, len(preds)))

  badExamples = [testEg for testEg, res in zip(testExamples, preds) if testEg[-1] != res]
  return float(len(badExamples)) / nTest, badExamples
53 54
def CrossValidationDriver(examples, attrs, nPossibleValues, nQuantBounds, mEstimateVal=0.0,
                          holdOutFrac=0.3, modelBuilder=makeNBClassificationModel, silent=0,
                          calcTotalError=0, **kwargs):
  """ splits the examples, builds a model on one part and measures its error

  Arguments:
    - examples: the full set of examples
    - attrs, nPossibleValues, nQuantBounds, mEstimateVal: passed positionally
      to modelBuilder (note that the default for mEstimateVal here is 0.0)
    - holdOutFrac: passed to SplitData.SplitIndices() (presumably the
      fraction of the examples held out for testing)
    - modelBuilder: called as
      modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
                   mEstimateVal, **kwargs) to construct the model
    - silent: if this is not set, the validation error is printed
    - calcTotalError: if this is set the error is calculated using all of
      the examples instead of just the hold-out set
    - kwargs: forwarded to modelBuilder; a true 'replacementSelection' entry
      switches the split to replacement=1, legacy=0

  Returns: a 2-tuple containing the model and its cross-validation error.
    The training indices are stored on the model as _trainIndices.

  """
  if kwargs.get('replacementSelection', 0):
    splitMode = dict(legacy=0, replacement=1)
  else:
    splitMode = dict(legacy=1, replacement=0)
  testIndices, trainIndices = SplitData.SplitIndices(len(examples), holdOutFrac, silent=1,
                                                     **splitMode)

  trainExamples = [examples[idx] for idx in trainIndices]
  testExamples = [examples[idx] for idx in testIndices]

  model = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds, mEstimateVal,
                       **kwargs)

  # either score the whole data set or just the hold-out examples
  if calcTotalError:
    evalExamples, appendFlag = examples, 0
  else:
    evalExamples, appendFlag = testExamples, 1
  xValError, _ = CrossValidate(model, evalExamples, appendExamples=appendFlag)

  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  model._trainIndices = trainIndices
  return model, xValError
81