1
2
3
4
5
6 """ handles doing cross validation with naive bayes models
7 and evaluation of individual models
8
9 """
10 from __future__ import print_function
11
12 from rdkit.ML.Data import SplitData
13 from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier
14
15 try:
16 from rdkit.ML.FeatureSelect import CMIM
17 except ImportError:
18 CMIM = None
19
20
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0, useSigs=False, ensemble=None, useCMIM=0,
                              **kwargs):
    """ builds and trains a naive bayes classifier from a set of training examples

    **Arguments**

      - trainExamples: the examples handed to the classifier for training

      - attrs: the attributes the classifier is built on; replaced by
        `ensemble` whenever that ends up non-empty

      - nPossibleValues, nQuantBounds: passed through unchanged to
        NaiveBayesClassifier

      - mEstimateVal: passed to NaiveBayesClassifier as `mEstimateVal`

      - useSigs: passed to NaiveBayesClassifier as `useSigs`; must also be
        true for CMIM feature selection to run

      - ensemble: optional pre-selected attribute list used instead of `attrs`

      - useCMIM: when > 0 (and CMIM imported successfully, `useSigs` is set
        and no `ensemble` was given) it is passed as the second argument to
        CMIM.SelectFeatures -- presumably the number of features to keep;
        confirm against the CMIM documentation

      - kwargs: accepted so callers can pass extra options; not used here

    **Returns**

      the trained NaiveBayesClassifier

    """
    # Feature selection is only attempted when the optional CMIM module is
    # available and the caller asked for it on signature-based examples.
    cmimRequested = CMIM is not None and useCMIM > 0 and useSigs
    if cmimRequested and not ensemble:
        ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

    # A non-empty ensemble (given or just selected) overrides the attribute list.
    modelAttrs = ensemble if ensemble else attrs

    classifier = NaiveBayesClassifier(modelAttrs, nPossibleValues, nQuantBounds,
                                      mEstimateVal=mEstimateVal, useSigs=useSigs)
    classifier.SetTrainingExamples(trainExamples)
    classifier.trainModel()
    return classifier
33
34
36
    # NOTE(review): the `def` line that owns this body is missing at this point
    # in the file.  The calls further down use
    # CrossValidate(NBmodel, testExamples, appendExamples=...), which matches
    # the names read here -- confirm the exact signature before relying on it.
    #
    # Classifies every test example with the model and returns the 2-tuple
    # (fraction of examples misclassified, list of the misclassified examples).
    nTest = len(testExamples)
    # Guards the division by nTest at the end of the body.
    assert nTest, 'no test examples: %s' % str(testExamples)
    badExamples = []
    nBad = 0
    preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
    # The loop below indexes preds and testExamples in lockstep.
    assert len(preds) == nTest

    for i in range(nTest):
        testEg = testExamples[i]
        # The last entry of each example is used as its true result.
        trueRes = testEg[-1]
        res = preds[i]

        if (trueRes != res):
            badExamples.append(testEg)
            nBad += 1
    return float(nBad) / nTest, badExamples
53
54
    # NOTE(review): the `def` line that owns this body is missing at this point
    # in the file.  The body reads examples, attrs, nPossibleValues,
    # nQuantBounds, mEstimateVal, holdOutFrac, modelBuilder, silent,
    # calcTotalError and kwargs -- presumably its parameters; their order and
    # defaults cannot be determined from here.
    #
    # Splits the examples into a training and a test set, builds a model on the
    # training set with modelBuilder, measures its error with CrossValidate and
    # returns the 2-tuple (model, error).
    nTot = len(examples)
    # The optional 'replacementSelection' keyword picks which SplitIndices call
    # is made: legacy=1/replacement=0 when it is absent or false,
    # legacy=0/replacement=1 otherwise.
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                           replacement=1)

    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    # All of kwargs (including 'replacementSelection', if present) is forwarded
    # to the model builder.
    NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds, mEstimateVal,
                           **kwargs)

    # By default the error is measured on the held-out test examples only;
    # with calcTotalError set it is measured over every example instead.
    if not calcTotalError:
        xValError, _ = CrossValidate(NBmodel, testExamples, appendExamples=1)
    else:
        xValError, _ = CrossValidate(NBmodel, examples, appendExamples=0)

    if not silent:
        # '%%' emits a literal percent sign in front of the formatted number.
        print('Validation error was %%%4.2f' % (100 * xValError))
    # Record which examples were used for training on the model itself.
    NBmodel._trainIndices = trainIndices
    return NBmodel, xValError
81