
Source Code for Module rdkit.ML.BuildComposite

# $Id$
#
#  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
""" command line utility for building composite models

#DOC

**Usage**

  BuildComposite [optional args] filename

Unless indicated otherwise (via command line arguments), _filename_ is
a QDAT file.

**Command Line Arguments**

  - -o *filename*: name of the output file for the pickled composite

  - -n *num*: number of separate models to add to the composite

  - -p *tablename*: store persistence data in the database
     in table *tablename*

  - -N *note*: attach some arbitrary text to the persistence data

  - -b *filename*: name of the text file to hold examples from the
     hold-out set which are misclassified

  - -s: split the data into training and hold-out sets before building
     the composite

  - -f *frac*: the fraction of data to use in the training set when the
     data is split

  - -r: randomize the activities (for testing purposes).  This ignores
     the initial distribution of activity values and produces each
     possible activity value with equal likelihood.

  - -S: shuffle the activities (for testing purposes).  This produces
     a permutation of the input activity values.

  - -l: locks the random number generator to give consistent sets
     of training and hold-out data.  This is primarily intended
     for testing purposes.

  - -B: use a so-called Bayesian composite model.

  - -d *database name*: instead of reading the data from a QDAT file,
     pull it from a database.  In this case, the _filename_ argument
     provides the name of the database table containing the data set.

  - -D: show a detailed breakdown of the composite model performance
     across the training and, when appropriate, hold-out sets.

  - -P *pickle file name*: write out the pickled data set to the file

  - -F *filter frac*: filters the data before training to change the
     distribution of activity values in the training set.  *filter
     frac* is the fraction of the training set that should have the
     target value.  **See note below on data filtering.**

  - -v *filter value*: filters the data before training to change the
     distribution of activity values in the training set. *filter
     value* is the target value to use in filtering.  **See note below
     on data filtering.**

  - --modelFiltFrac *model filter frac*: similar to filter frac above,
     but the data is filtered separately for each model in the composite
     rather than once for the composite as a whole. *model
     filter frac* is the fraction of the training set for each model
     that should have the target value (*model filter value*).

  - --modelFiltVal *model filter value*: target value to use for
     filtering data before training each model in the composite.

  - -t *threshold value*: use high-confidence predictions for the
     final analysis of the hold-out data.

  - -Q *list string*: the values of quantization bounds for the
     activity value.  See the _-q_ argument for the format of *list
     string*.

  - --nRuns *count*: build *count* composite models

  - --prune: prune any models built

  - -h: print a usage message and exit.

  - -V: print the version number and exit

  *-*-*-*-*-*-*-*- Tree-Related Options -*-*-*-*-*-*-*-*

  - -g: be less greedy when training the models.

  - -G *number*: force trees to be rooted at descriptor *number*.

  - -L *limit*: provide an (integer) limit on individual model
     complexity

  - -q *list string*: Add QuantTrees to the composite and use the list
     specified in *list string* as the number of target quantization
     bounds for each descriptor.  Don't forget to include 0's at the
     beginning and end of *list string* for the name and value fields.
     For example, if there are 4 descriptors and you want 2 quant
     bounds apiece, you would use _-q "[0,2,2,2,2,0]"_.
     Two special cases:
       1) If you would like to ignore a descriptor in the model
          building, use '-1' for its number of quant bounds.
       2) If you have integer valued data that should not be quantized
          further, enter 0 for that descriptor.

  - --recycle: allow descriptors to be used more than once in a tree

  - --randomDescriptors=val: toggles growing random forests with val
      randomly-selected descriptors available at each node.


  *-*-*-*-*-*-*-*- KNN-Related Options -*-*-*-*-*-*-*-*

  - --doKnn: use K-Nearest Neighbors models

  - --knnK=*value*: the value of K to use in the KNN models

  - --knnTanimoto: use the Tanimoto metric in KNN models

  - --knnEuclid: use a Euclidean metric in KNN models

  *-*-*-*-*-*-*- Naive Bayes Classifier Options -*-*-*-*-*-*-*-*

  - --doNaiveBayes : use Naive Bayes classifiers

  - --mEstimateVal : the value to be used in the m-estimate formula.
      If this is greater than 0.0, we use it to compute the conditional
      probabilities by the m-estimate

  *-*-*-*-*-*-*-*- SVM-Related Options -*-*-*-*-*-*-*-*

  **** NOTE: THESE ARE DISABLED ****

# #   - --doSVM: use Support-vector machines

# #   - --svmKernel=*kernel*: choose the type of kernel to be used for
# #     the SVMs.  Options are:
# #     The default is:

# #   - --svmType=*type*: choose the type of support-vector machine
# #     to be used.  Options are:
# #     The default is:

# #   - --svmGamma=*gamma*: provide the gamma value for the SVMs.  If this
# #     is not provided, a grid search will be carried out to determine an
# #     optimal *gamma* value for each SVM.

# #   - --svmCost=*cost*: provide the cost value for the SVMs.  If this is
# #     not provided, a grid search will be carried out to determine an
# #     optimal *cost* value for each SVM.

# #   - --svmWeights=*weights*: provide the weight values for the
# #     activities.  If provided this should be a sequence of (label,
# #     weight) 2-tuples *nActs* long.  If not provided, a weight of 1
# #     will be used for each activity.

# #   - --svmEps=*epsilon*: provide the epsilon value used to determine
# #     when the SVM has converged.  Defaults to 0.001

# #   - --svmDegree=*degree*: provide the degree of the kernel (when
# #     sensible).  Defaults to 3

# #   - --svmCoeff=*coeff*: provide the coefficient for the kernel (when
# #     sensible).  Defaults to 0

# #   - --svmNu=*nu*: provide the nu value for the kernel (when sensible).
# #     Defaults to 0.5

# #   - --svmDataType=*float*: if the data contains only 1s and 0s, specify
# #     this by using binary.  Defaults to float

# #   - --svmCache=*cache*: provide the size of the memory cache (in MB)
# #     to be used while building the SVM.  Defaults to 40

**Notes**

  - *Data filtering*: When there is a large disparity between the
    numbers of points with various activity levels present in the
    training set it is sometimes desirable to train on a more
    homogeneous data set.  This can be accomplished using filtering.
    The filtering process works by selecting a particular target
    fraction and target value.  For example, in a case where 95% of
    the original training set has activity 0 and only 5% activity 1, we
    could filter (by randomly removing points with activity 0) so that
    30% of the data set used to build the composite has activity 1.


"""
from __future__ import print_function

import sys
import time

import numpy

from rdkit import DataStructs
from rdkit.Dbase import DbModule
from rdkit.ML import CompositeRun
from rdkit.ML import ScreenComposite
from rdkit.ML.Composite import Composite, BayesComposite
from rdkit.ML.Data import DataUtils, SplitData
from rdkit.utils import listutils
from rdkit.six.moves import cPickle

# # from ML.SVM import SVMClassificationModel as SVM
_runDetails = CompositeRun.CompositeRun()

__VERSION_STRING = "3.2.3"

_verbose = 1


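The data-filtering note in the module docstring can be checked with a little arithmetic. The sketch below is not RDKit code: `points_to_keep` is a hypothetical helper, and it assumes that every point with the target value is kept and only the other points are removed at random.

```python
def points_to_keep(n_target, n_other, target_frac):
    """Return how many non-target points to keep so that the target class
    makes up roughly `target_frac` of the filtered set.

    Hypothetical helper for illustration only; assumes all target points are
    kept and only non-target points are removed.
    """
    if not 0.0 < target_frac <= 1.0:
        raise ValueError('target_frac must be in (0, 1]')
    # n_target / (n_target + kept) = target_frac
    #   =>  kept = n_target * (1 - target_frac) / target_frac
    kept = int(round(n_target * (1.0 - target_frac) / target_frac))
    # we cannot keep more non-target points than we have
    return min(kept, n_other)


# 1000 points: 5% have activity 1 (the target value), 95% have activity 0.
n_active, n_inactive = 50, 950
kept = points_to_keep(n_active, n_inactive, 0.30)
print(kept, n_active / float(n_active + kept))
```

With these made-up numbers, most of the activity-0 points are discarded, which is why filtering is worth combining with a hold-out set.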
def message(msg):
  """ emits messages to _sys.stdout_
    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the string to be displayed

  """
  if _verbose:
    sys.stdout.write('%s\n' % (msg))


def testall(composite, examples, badExamples=None):
  """ screens a number of examples past a composite

    **Arguments**

      - composite: a composite model

      - examples: a list of examples (with results) to be screened

      - badExamples: (optional) a list to which misclassified examples
        are appended

    **Returns**

      a list of 2-tuples containing:

        1) a vote

        2) a confidence

      these are the votes and confidence levels for **misclassified** examples

  """
  # avoid a mutable default argument, which would be shared between calls
  if badExamples is None:
    badExamples = []
  wrong = []
  for example in examples:
    if composite.GetActivityQuantBounds():
      answer = composite.QuantizeActivity(example)[-1]
    else:
      answer = example[-1]
    res, conf = composite.ClassifyExample(example)
    if res != answer:
      wrong.append((res, conf))
      badExamples.append(example)

  return wrong


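A minimal sketch of how `testall` behaves, written so that it runs without RDKit. `StubComposite` is a hypothetical stand-in for a composite model, and `screen` mirrors the loop in `testall`; neither is part of the library.

```python
class StubComposite(object):
    """Hypothetical stand-in for a composite model (not an RDKit class).

    It predicts 1 when the single descriptor exceeds 0.5, with a fixed
    confidence of 0.75.
    """

    def GetActivityQuantBounds(self):
        return []  # no activity quantization in this sketch

    def ClassifyExample(self, example):
        # example layout assumed here: [name, descriptor, activity]
        return (1 if example[1] > 0.5 else 0), 0.75


def screen(composite, examples, badExamples=None):
    """Mirrors the loop in testall()."""
    if badExamples is None:
        badExamples = []
    wrong = []
    for example in examples:
        if composite.GetActivityQuantBounds():
            answer = composite.QuantizeActivity(example)[-1]
        else:
            answer = example[-1]
        res, conf = composite.ClassifyExample(example)
        if res != answer:
            wrong.append((res, conf))
            badExamples.append(example)
    return wrong


examples = [['a', 0.9, 1], ['b', 0.2, 1], ['c', 0.7, 0], ['d', 0.1, 0]]
bad = []
wrong = screen(StubComposite(), examples, bad)
print(wrong)                   # votes and confidences for the misclassified examples
print([ex[0] for ex in bad])   # names of the misclassified examples
```

The returned list and `badExamples` are filled in lock step, which is what lets `RunOnData` pair `badExamples[i]` with `wrong[i]` when it writes the bad-example file.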
def GetCommandLine(details):
  """ #DOC

  """
  args = ['BuildComposite']
  args.append('-n %d' % (details.nModels))
  if details.filterFrac != 0.0:
    args.append('-F %.3f -v %d' % (details.filterFrac, details.filterVal))
  if details.modelFilterFrac != 0.0:
    args.append('--modelFiltFrac=%.3f --modelFiltVal=%d' % (details.modelFilterFrac,
                                                            details.modelFilterVal))
  if details.splitRun:
    args.append('-s -f %.3f' % (details.splitFrac))
  if details.shuffleActivities:
    args.append('-S')
  if details.randomActivities:
    args.append('-r')
  if details.threshold > 0.0:
    args.append('-t %.3f' % (details.threshold))
  if details.activityBounds:
    args.append('-Q "%s"' % (details.activityBoundsVals))
  if details.dbName:
    args.append('-d %s' % (details.dbName))
  if details.detailedRes:
    args.append('-D')
  if hasattr(details, 'noScreen') and details.noScreen:
    args.append('--noScreen')
  if details.persistTblName and details.dbName:
    args.append('-p %s' % (details.persistTblName))
  if details.note:
    args.append('-N %s' % (details.note))
  if details.useTrees:
    if details.limitDepth > 0:
      args.append('-L %d' % (details.limitDepth))
    if details.lessGreedy:
      args.append('-g')
    if details.qBounds:
      shortBounds = listutils.CompactListRepr(details.qBounds)
      if details.qBounds:
        args.append('-q "%s"' % (shortBounds))
    else:
      if details.qBounds:
        args.append('-q "%s"' % (details.qBoundCount))

    if details.pruneIt:
      args.append('--prune')
    if details.startAt:
      args.append('-G %d' % details.startAt)
    if details.recycleVars:
      args.append('--recycle')
    if details.randomDescriptors:
      args.append('--randomDescriptors=%d' % details.randomDescriptors)
  if details.useSigTrees:
    args.append('--doSigTree')
    if details.limitDepth > 0:
      args.append('-L %d' % (details.limitDepth))
    if details.randomDescriptors:
      args.append('--randomDescriptors=%d' % details.randomDescriptors)

  if details.useKNN:
    args.append('--doKnn --knnK %d' % (details.knnNeighs))
    if details.knnDistFunc == 'Tanimoto':
      args.append('--knnTanimoto')
    else:
      args.append('--knnEuclid')

  if details.useNaiveBayes:
    args.append('--doNaiveBayes')
    if details.mEstimateVal >= 0.0:
      args.append('--mEstimateVal=%.3f' % details.mEstimateVal)

  # #   if details.useSVM:
  # #     args.append('--doSVM')
  # #     if details.svmKernel:
  # #       for k in SVM.kernels.keys():
  # #         if SVM.kernels[k]==details.svmKernel:
  # #           args.append('--svmKernel=%s'%k)
  # #           break
  # #     if details.svmType:
  # #       for k in SVM.machineTypes.keys():
  # #         if SVM.machineTypes[k]==details.svmType:
  # #           args.append('--svmType=%s'%k)
  # #           break
  # #     if details.svmGamma:
  # #       args.append('--svmGamma=%f'%details.svmGamma)
  # #     if details.svmCost:
  # #       args.append('--svmCost=%f'%details.svmCost)
  # #     if details.svmWeights:
  # #       args.append("--svmWeights='%s'"%str(details.svmWeights))
  # #     if details.svmDegree:
  # #       args.append('--svmDegree=%d'%details.svmDegree)
  # #     if details.svmCoeff:
  # #       args.append('--svmCoeff=%d'%details.svmCoeff)
  # #     if details.svmEps:
  # #       args.append('--svmEps=%f'%details.svmEps)
  # #     if details.svmNu:
  # #       args.append('--svmNu=%f'%details.svmNu)
  # #     if details.svmCache:
  # #       args.append('--svmCache=%d'%details.svmCache)
  # #     if details.svmDataType:
  # #       args.append('--svmDataType=%s'%details.svmDataType)
  # #     if not details.svmShrink:
  # #       args.append('--svmShrink')

  if details.replacementSelection:
    args.append('--replacementSelection')

  # this should always be last:
  if details.tableName:
    args.append(details.tableName)

  return ' '.join(args)


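The pattern used by `GetCommandLine` (append one fragment per active setting, table name last, then join with spaces) can be tried out without RDKit. In the sketch below, `sketch_command_line` is a hypothetical, heavily reduced version that handles only a few options, and `SimpleNamespace` stands in for a `CompositeRun.CompositeRun` object; the field values are made up.

```python
from types import SimpleNamespace


def sketch_command_line(details):
    """Rebuild a command line from a few fields, following GetCommandLine's pattern.

    Only a subset of the options is handled; this is an illustration, not a
    replacement for GetCommandLine.
    """
    args = ['BuildComposite', '-n %d' % details.nModels]
    if details.splitRun:
        args.append('-s -f %.3f' % details.splitFrac)
    if details.threshold > 0.0:
        args.append('-t %.3f' % details.threshold)
    if details.useKNN:
        args.append('--doKnn --knnK %d' % details.knnNeighs)
    # the table (or file) name always goes last
    if details.tableName:
        args.append(details.tableName)
    return ' '.join(args)


# SimpleNamespace is a stand-in for CompositeRun.CompositeRun; values are made up.
details = SimpleNamespace(nModels=10, splitRun=1, splitFrac=0.7, threshold=0.0,
                          useKNN=1, knnNeighs=5, tableName='data.qdat')
cmd = sketch_command_line(details)
print(cmd)
```

Because the threshold is 0.0 here, no `-t` fragment is emitted, matching the `> 0.0` test in the real function.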
def RunOnData(details, data, progressCallback=None, saveIt=1, setDescNames=0):
  if details.lockRandom:
    seed = details.randomSeed
  else:
    import random
    # randint() needs integer bounds; a float such as 1e6 is rejected by newer Pythons
    seed = (random.randint(0, 1000000), random.randint(0, 1000000))
  DataUtils.InitRandomNumbers(seed)
  testExamples = []
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

  namedExamples = data.GetNamedData()
  if details.splitRun == 1:
    trainIdx, testIdx = SplitData.SplitIndices(
      len(namedExamples), details.splitFrac, silent=not _verbose)

    trainExamples = [namedExamples[x] for x in trainIdx]
    testExamples = [namedExamples[x] for x in testIdx]
  else:
    testExamples = []
    testIdx = []
    trainIdx = list(range(len(namedExamples)))
    trainExamples = namedExamples

  if details.filterFrac != 0.0:
    # if we're doing quantization on the fly, we need to handle that here:
    if hasattr(details, 'activityBounds') and details.activityBounds:
      tExamples = []
      bounds = details.activityBounds
      for pt in trainExamples:
        pt = pt[:]
        act = pt[-1]
        placed = 0
        bound = 0
        while not placed and bound < len(bounds):
          if act < bounds[bound]:
            pt[-1] = bound
            placed = 1
          else:
            bound += 1
        if not placed:
          pt[-1] = bound
        tExamples.append(pt)
    else:
      bounds = None
      tExamples = trainExamples
    trainIdx, temp = DataUtils.FilterData(tExamples, details.filterVal, details.filterFrac, -1,
                                          indicesOnly=1)
    tmp = [trainExamples[x] for x in trainIdx]
    testExamples += [trainExamples[x] for x in temp]
    trainExamples = tmp

    counts = DataUtils.CountResults(trainExamples, bounds=bounds)
    # sorted() works on both Python 2 and 3; dict.keys() has no sort() in Python 3
    ks = sorted(counts.keys())
    message('Result Counts in training set:')
    for k in ks:
      message(str((k, counts[k])))
    counts = DataUtils.CountResults(testExamples, bounds=bounds)
    ks = sorted(counts.keys())
    message('Result Counts in test set:')
    for k in ks:
      message(str((k, counts[k])))
  nExamples = len(trainExamples)
  message('Training with %d examples' % (nExamples))

  nVars = data.GetNVars()
  attrs = list(range(1, nVars + 1))
  nPossibleVals = data.GetNPossibleVals()
  for i in range(1, len(nPossibleVals)):
    if nPossibleVals[i - 1] == -1:
      attrs.remove(i)

  if details.pickleDataFileName != '':
    pickleDataFile = open(details.pickleDataFileName, 'wb+')
    cPickle.dump(trainExamples, pickleDataFile)
    cPickle.dump(testExamples, pickleDataFile)
    pickleDataFile.close()

  if details.bayesModel:
    composite = BayesComposite.BayesComposite()
  else:
    composite = Composite.Composite()

  composite._randomSeed = seed
  composite._splitFrac = details.splitFrac
  composite._shuffleActivities = details.shuffleActivities
  composite._randomizeActivities = details.randomActivities

  if hasattr(details, 'filterFrac'):
    composite._filterFrac = details.filterFrac
  if hasattr(details, 'filterVal'):
    composite._filterVal = details.filterVal

  composite.SetModelFilterData(details.modelFilterFrac, details.modelFilterVal)

  composite.SetActivityQuantBounds(details.activityBounds)
  nPossibleVals = data.GetNPossibleVals()
  if details.activityBounds:
    nPossibleVals[-1] = len(details.activityBounds) + 1

  if setDescNames:
    composite.SetInputOrder(data.GetVarNames())
    composite.SetDescriptorNames(details._descNames)
  else:
    composite.SetDescriptorNames(data.GetVarNames())
  composite.SetActivityQuantBounds(details.activityBounds)
  if details.nModels == 1:
    details.internalHoldoutFrac = 0.0
  if details.useTrees:
    from rdkit.ML.DecTree import CrossValidate, PruneTree
    if details.qBounds != []:
      from rdkit.ML.DecTree import BuildQuantTree
      builder = BuildQuantTree.QuantTreeBoot
    else:
      from rdkit.ML.DecTree import ID3
      builder = ID3.ID3Boot
    driver = CrossValidate.CrossValidationDriver
    pruner = PruneTree.PruneTree

    composite.SetQuantBounds(details.qBounds)
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    composite.Grow(
      trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver, pruner=pruner,
      nTries=details.nModels, pruneIt=details.pruneIt, lessGreedy=details.lessGreedy,
      needsQuantization=0, treeBuilder=builder, nQuantBounds=details.qBounds,
      startAt=details.startAt, maxDepth=details.limitDepth, progressCallback=progressCallback,
      holdOutFrac=details.internalHoldoutFrac, replacementSelection=details.replacementSelection,
      recycleVars=details.recycleVars, randomDescriptors=details.randomDescriptors,
      silent=not _verbose)

  elif details.useSigTrees:
    from rdkit.ML.DecTree import CrossValidate
    from rdkit.ML.DecTree import BuildSigTree
    builder = BuildSigTree.SigTreeBuilder
    driver = CrossValidate.CrossValidationDriver
    nPossibleVals = data.GetNPossibleVals()
    if details.activityBounds:
      nPossibleVals[-1] = len(details.activityBounds) + 1
    if hasattr(details, 'sigTreeBiasList'):
      biasList = details.sigTreeBiasList
    else:
      biasList = None
    if hasattr(details, 'useCMIM'):
      useCMIM = details.useCMIM
    else:
      useCMIM = 0
    if hasattr(details, 'allowCollections'):
      allowCollections = details.allowCollections
    else:
      allowCollections = False
    composite.Grow(
      trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
      nTries=details.nModels, needsQuantization=0, treeBuilder=builder, maxDepth=details.limitDepth,
      progressCallback=progressCallback, holdOutFrac=details.internalHoldoutFrac,
      replacementSelection=details.replacementSelection, recycleVars=details.recycleVars,
      randomDescriptors=details.randomDescriptors, biasList=biasList, useCMIM=useCMIM,
      allowCollection=allowCollections, silent=not _verbose)

  elif details.useKNN:
    from rdkit.ML.KNN import CrossValidate
    from rdkit.ML.KNN import DistFunctions

    driver = CrossValidate.CrossValidationDriver
    dfunc = ''
    if (details.knnDistFunc == "Euclidean"):
      dfunc = DistFunctions.EuclideanDist
    elif (details.knnDistFunc == "Tanimoto"):
      dfunc = DistFunctions.TanimotoDist
    else:
      assert 0, "Bad KNN distance metric value"

    composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                   nTries=details.nModels, needsQuantization=0, numNeigh=details.knnNeighs,
                   holdOutFrac=details.internalHoldoutFrac, distFunc=dfunc)

  elif details.useNaiveBayes or details.useSigBayes:
    from rdkit.ML.NaiveBayes import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    if not (hasattr(details, 'useSigBayes') and details.useSigBayes):
      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                     nTries=details.nModels, needsQuantization=0, nQuantBounds=details.qBounds,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection,
                     mEstimateVal=details.mEstimateVal, silent=not _verbose)
    else:
      if hasattr(details, 'useCMIM'):
        useCMIM = details.useCMIM
      else:
        useCMIM = 0

      composite.Grow(trainExamples, attrs, nPossibleVals=[0] + nPossibleVals, buildDriver=driver,
                     nTries=details.nModels, needsQuantization=0, nQuantBounds=details.qBounds,
                     mEstimateVal=details.mEstimateVal, useSigs=True, useCMIM=useCMIM,
                     holdOutFrac=details.internalHoldoutFrac,
                     replacementSelection=details.replacementSelection, silent=not _verbose)

  # #   elif details.useSVM:
  # #     from rdkit.ML.SVM import CrossValidate
  # #     driver = CrossValidate.CrossValidationDriver
  # #     composite.Grow(trainExamples, attrs, nPossibleVals=[0]+nPossibleVals,
  # #                    buildDriver=driver, nTries=details.nModels,
  # #                    needsQuantization=0,
  # #                    cost=details.svmCost,gamma=details.svmGamma,
  # #                    weights=details.svmWeights,degree=details.svmDegree,
  # #                    type=details.svmType,kernelType=details.svmKernel,
  # #                    coef0=details.svmCoeff,eps=details.svmEps,nu=details.svmNu,
  # #                    cache_size=details.svmCache,shrinking=details.svmShrink,
  # #                    dataType=details.svmDataType,
  # #                    holdOutFrac=details.internalHoldoutFrac,
  # #                    replacementSelection=details.replacementSelection,
  # #                    silent=not _verbose)

  else:
    from rdkit.ML.Neural import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    composite.Grow(trainExamples, attrs, [0] + nPossibleVals, nTries=details.nModels,
                   buildDriver=driver, needsQuantization=0)

  composite.AverageErrors()
  composite.SortModels()
  modelList, counts, avgErrs = composite.GetAllData()
  counts = numpy.array(counts)
  avgErrs = numpy.array(avgErrs)
  composite._varNames = data.GetVarNames()

  for i in range(len(modelList)):
    modelList[i].NameModel(composite._varNames)

  # do final statistics
  weightedErrs = counts * avgErrs
  averageErr = sum(weightedErrs) / sum(counts)
  devs = (avgErrs - averageErr)
  devs = devs * counts
  # sqrt(x * x) is the element-wise absolute value
  devs = numpy.sqrt(devs * devs)
  avgDev = sum(devs) / sum(counts)
  message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
          (100. * averageErr, 100. * avgDev))

  if details.bayesModel:
    composite.Train(trainExamples, verbose=0)

  # blow out the saved examples and then save the composite:
  composite.ClearModelExamples()
  if saveIt:
    composite.Pickle(details.outName)
  details.model = DbModule.binaryHolder(cPickle.dumps(composite))

  badExamples = []
  # keep `wrong` defined even when the screening step below is skipped
  wrong = []
  if not details.detailedRes and (not hasattr(details, 'noScreen') or not details.noScreen):
    if details.splitRun:
      message('Testing all hold-out examples')
      wrong = testall(composite, testExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong), 100. * float(len(wrong)) /
                                                              float(len(testExamples))))
      _runDetails.holdout_error = float(len(wrong)) / len(testExamples)
    else:
      message('Testing all examples')
      wrong = testall(composite, namedExamples, badExamples)
      message('%d examples (%% %5.2f) were misclassified' % (len(wrong), 100. * float(len(wrong)) /
                                                              float(len(namedExamples))))
      _runDetails.overall_error = float(len(wrong)) / len(namedExamples)

  if details.detailedRes:
    message('\nEntire data set:')
    resTup = ScreenComposite.ShowVoteResults(
      range(data.GetNPts()), data, composite, nPossibleVals[-1], details.threshold)
    nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
    nPts = len(namedExamples)
    nClass = nGood + nBad
    _runDetails.overall_error = float(nBad) / nClass
    _runDetails.overall_correct_conf = avgGood
    _runDetails.overall_incorrect_conf = avgBad
    _runDetails.overall_result_matrix = repr(voteTab)
    # points that were not classified (the original computed nClass - nPts,
    # which can never be positive)
    nRej = nPts - nClass
    if nRej > 0:
      _runDetails.overall_fraction_dropped = float(nRej) / nPts

    if details.splitRun:
      message('\nHold-out data:')
      resTup = ScreenComposite.ShowVoteResults(
        range(len(testExamples)), testExamples, composite, nPossibleVals[-1], details.threshold)
      nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab = resTup
      nPts = len(testExamples)
      nClass = nGood + nBad
      _runDetails.holdout_error = float(nBad) / nClass
      _runDetails.holdout_correct_conf = avgGood
      _runDetails.holdout_incorrect_conf = avgBad
      _runDetails.holdout_result_matrix = repr(voteTab)
      nRej = nPts - nClass
      if nRej > 0:
        _runDetails.holdout_fraction_dropped = float(nRej) / nPts

  if details.persistTblName and details.dbName:
    message('Updating results table %s:%s' % (details.dbName, details.persistTblName))
    details.Store(db=details.dbName, table=details.persistTblName)

  if details.badName != '':
    badFile = open(details.badName, 'w+')
    for i in range(len(badExamples)):
      ex = badExamples[i]
      vote = wrong[i]
      outStr = '%s\t%s\n' % (ex, vote)
      badFile.write(outStr)
    badFile.close()

  composite.ClearModelExamples()
  return composite


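Two small computations inside `RunOnData` are easier to follow in isolation: the on-the-fly activity quantization (the `while` loop over `bounds`) and the final error statistics. The sketch below restates both in plain Python; `quantize_activity` and `error_summary` are hypothetical names, and the quantization shortcut assumes that the bounds are sorted in increasing order.

```python
import bisect


def quantize_activity(act, bounds):
    """Bin index for `act`: the first index i with act < bounds[i],
    or len(bounds) if there is none.

    Assumes `bounds` is sorted in increasing order, in which case this is
    exactly what bisect_right returns.
    """
    return bisect.bisect_right(bounds, act)


def error_summary(counts, avg_errs):
    """Count-weighted mean error and count-weighted mean absolute deviation."""
    total = float(sum(counts))
    average_err = sum(c * e for c, e in zip(counts, avg_errs)) / total
    # numpy.sqrt(devs * devs) in RunOnData is the element-wise absolute value
    avg_dev = sum(abs((e - average_err) * c) for c, e in zip(counts, avg_errs)) / total
    return average_err, avg_dev


bounds = [1.0, 5.0]
bins = [quantize_activity(a, bounds) for a in (0.2, 1.0, 4.9, 7.0)]
print(bins)

# two models: one seen 3 times with error 0.25, one seen once with error 0.5
summary = error_summary([3, 1], [0.25, 0.5])
print(summary)
```

A value equal to a bound falls into the bin above it, because the loop tests `act < bounds[bound]` with a strict inequality.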
def RunIt(details, progressCallback=None, saveIt=1, setDescNames=0):
  """ does the actual work of building a composite model

    **Arguments**

      - details:  a _CompositeRun.CompositeRun_ object containing details
        (options, parameters, etc.) about the run

      - progressCallback: (optional) a function which is called with a single
        argument (the number of models built so far) after each model is built.

      - saveIt: (optional) if this is nonzero, the resulting model will be pickled
        and dumped to the filename specified in _details.outName_

      - setDescNames: (optional) if nonzero, the composite's _SetInputOrder()_ method
        will be called using the results of the data set's _GetVarNames()_ method;
        it is assumed that the details object has a _descNames attribute which
        is passed to the composite's _SetDescriptorNames()_ method.  Otherwise
        (the default), _SetDescriptorNames()_ gets the results of _GetVarNames()_.

    **Returns**

      the composite model constructed


  """
  details.rundate = time.asctime()

  fName = details.tableName.strip()
  if details.outName == '':
    details.outName = fName + '.pkl'
  if not details.dbName:
    if details.qBounds != []:
      data = DataUtils.TextFileToData(fName)
    else:
      data = DataUtils.BuildQuantDataSet(fName)
  elif details.useSigTrees or details.useSigBayes:
    details.tableName = fName
    data = details.GetDataSet(pickleCol=0, pickleClass=DataStructs.ExplicitBitVect)
  elif details.qBounds != [] or not details.useTrees:
    details.tableName = fName
    data = details.GetDataSet()
  else:
    data = DataUtils.DBToQuantData(details.dbName,  # Function no longer defined
                                   fName,
                                   quantName=details.qTableName,
                                   user=details.dbUser,
                                   password=details.dbPassword)

  composite = RunOnData(details, data, progressCallback=progressCallback, saveIt=saveIt,
                        setDescNames=setDescNames)
  return composite


def ShowVersion(includeArgs=0):
  """ prints the version number

  """
  print('This is BuildComposite.py version %s' % (__VERSION_STRING))
  if includeArgs:
    print('command line was:')
    print(' '.join(sys.argv))


def Usage():
  """ provides a list of arguments for when this is used from the command line

  """
  print(__doc__)
  sys.exit(-1)


def SetDefaults(runDetails=None):
  """  initializes a details object with default values

      **Arguments**

        - runDetails:  (optional) a _CompositeRun.CompositeRun_ object.
          If this is not provided, the global _runDetails will be used.

      **Returns**

        the initialized _CompositeRun_ object.


  """
  if runDetails is None:
    runDetails = _runDetails
  return CompositeRun.SetDefaults(runDetails)


def ParseArgs(runDetails):
  """ parses command line arguments and updates _runDetails_

      **Arguments**

        - runDetails:  a _CompositeRun.CompositeRun_ object.

  """
  import getopt
  args, extra = getopt.getopt(
    sys.argv[1:],
    'P:o:n:p:b:sf:F:v:hlgd:rSTt:BQ:q:DVG:N:L:',
    ['nRuns=',
     'prune',
     'profile',
     'seed=',
     'noScreen',
     'modelFiltFrac=',
     'modelFiltVal=',
     'recycle',
     'randomDescriptors=',
     'doKnn',
     'knnK=',
     'knnTanimoto',
     'knnEuclid',
     'doSigTree',
     'allowCollections',
     'doNaiveBayes',
     'mEstimateVal=',
     'doSigBayes',

     # #                               'doSVM','svmKernel=','svmType=','svmGamma=',
     # #                               'svmCost=','svmWeights=','svmDegree=',
     # #                               'svmCoeff=','svmEps=','svmNu=','svmCache=',
     # #                               'svmShrink','svmDataType=',
     'replacementSelection', ])
  runDetails.profileIt = 0
  for arg, val in args:
    if arg == '-n':
      runDetails.nModels = int(val)
    elif arg == '-N':
      runDetails.note = val
    elif arg == '-o':
      runDetails.outName = val
    elif arg == '-Q':
      qBounds = eval(val)
      assert type(qBounds) in [type([]), type(
        ())], 'bad argument type for -Q, specify a list as a string'
      runDetails.activityBounds = qBounds
      runDetails.activityBoundsVals = val
    elif arg == '-p':
      runDetails.persistTblName = val
    elif arg == '-P':
      runDetails.pickleDataFileName = val
    elif arg == '-r':
      runDetails.randomActivities = 1
    elif arg == '-S':
      runDetails.shuffleActivities = 1
    elif arg == '-b':
      runDetails.badName = val
    elif arg == '-B':
      # RunOnData() reads details.bayesModel, so set that attribute here
      # (the original set bayesModels, which nothing in this module reads)
      runDetails.bayesModel = 1
    elif arg == '-s':
      runDetails.splitRun = 1
    elif arg == '-f':
      runDetails.splitFrac = float(val)
    elif arg == '-F':
      runDetails.filterFrac = float(val)
    elif arg == '-v':
      runDetails.filterVal = float(val)
    elif arg == '-l':
      runDetails.lockRandom = 1
    elif arg == '-g':
      runDetails.lessGreedy = 1
    elif arg == '-G':
      runDetails.startAt = int(val)
    elif arg == '-d':
      runDetails.dbName = val
    elif arg == '-T':
      runDetails.useTrees = 0
    elif arg == '-t':
      runDetails.threshold = float(val)
    elif arg == '-D':
      runDetails.detailedRes = 1
    elif arg == '-L':
      runDetails.limitDepth = int(val)
    elif arg == '-q':
      qBounds = eval(val)
      assert type(qBounds) in [type([]), type(
        ())], 'bad argument type for -q, specify a list as a string'
      runDetails.qBoundCount = val
      runDetails.qBounds = qBounds
    elif arg == '-V':
      ShowVersion()
      sys.exit(0)
    elif arg == '--nRuns':
      runDetails.nRuns = int(val)
    elif arg == '--modelFiltFrac':
      runDetails.modelFilterFrac = float(val)
    elif arg == '--modelFiltVal':
      runDetails.modelFilterVal = float(val)
    elif arg == '--prune':
      runDetails.pruneIt = 1
    elif arg == '--profile':
      runDetails.profileIt = 1

    elif arg == '--recycle':
      runDetails.recycleVars = 1
    elif arg == '--randomDescriptors':
      runDetails.randomDescriptors = int(val)

    elif arg == '--doKnn':
      runDetails.useKNN = 1
      runDetails.useTrees = 0
      # #       runDetails.useSVM=0
      runDetails.useNaiveBayes = 0
    elif arg == '--knnK':
      runDetails.knnNeighs = int(val)
    elif arg == '--knnTanimoto':
      runDetails.knnDistFunc = "Tanimoto"
    elif arg == '--knnEuclid':
      runDetails.knnDistFunc = "Euclidean"

    elif arg == '--doSigTree':
      # #       runDetails.useSVM=0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
      runDetails.useNaiveBayes = 0
      runDetails.useSigTrees = 1
    elif arg == '--allowCollections':
      runDetails.allowCollections = True

    elif arg == '--doNaiveBayes':
      runDetails.useNaiveBayes = 1
      # #       runDetails.useSVM=0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
      runDetails.useSigBayes = 0
    elif arg == '--doSigBayes':
      runDetails.useSigBayes = 1
      runDetails.useNaiveBayes = 0
      # #       runDetails.useSVM=0
      runDetails.useKNN = 0
      runDetails.useTrees = 0
    elif arg == '--mEstimateVal':
      runDetails.mEstimateVal = float(val)

    # #     elif arg == '--doSVM':
    # #       runDetails.useSVM=1
    # #       runDetails.useKNN=0
    # #       runDetails.useTrees=0
    # #       runDetails.useNaiveBayes=0
    # #     elif arg == '--svmKernel':
    # #       if val not in SVM.kernels.keys():
    # #         message('kernel %s not in list of available kernels:\n%s\n'%(val,SVM.kernels.keys()))
    # #         sys.exit(-1)
    # #       else:
    # #         runDetails.svmKernel=SVM.kernels[val]
    # #     elif arg == '--svmType':
    # #       if val not in SVM.machineTypes.keys():
    # #         message('type %s not in list of available machines:\n%s\n'%(val,
    # #                                                                     SVM.machineTypes.keys()))
    # #         sys.exit(-1)
    # #       else:
    # #         runDetails.svmType=SVM.machineTypes[val]
    # #     elif arg == '--svmGamma':
    # #       runDetails.svmGamma = float(val)
    # #     elif arg == '--svmCost':
    # #       runDetails.svmCost = float(val)
    # #     elif arg == '--svmWeights':
    # #       # FIX: this is dangerous
    # #       runDetails.svmWeights = eval(val)
    # #     elif arg == '--svmDegree':
    # #       runDetails.svmDegree = int(val)
    # #     elif arg == '--svmCoeff':
    # #       runDetails.svmCoeff = float(val)
    # #     elif arg == '--svmEps':
    # #       runDetails.svmEps = float(val)
    # #     elif arg == '--svmNu':
    # #       runDetails.svmNu = float(val)
    # #     elif arg == '--svmCache':
    # #       runDetails.svmCache = int(val)
    # #     elif arg == '--svmShrink':
    # #       runDetails.svmShrink = 0
    # #     elif arg == '--svmDataType':
    # #       runDetails.svmDataType=val

    elif arg == '--seed':
      # FIX: dangerous
      runDetails.randomSeed = eval(val)

    elif arg == '--noScreen':
      runDetails.noScreen = 1

    elif arg == '--replacementSelection':
      runDetails.replacementSelection = 1

    elif arg == '-h':
      Usage()

    else:
      Usage()
  runDetails.tableName = extra[0]

if __name__ == '__main__':
  if len(sys.argv) < 2:
    Usage()

  _runDetails.cmd = ' '.join(sys.argv)
  SetDefaults(_runDetails)
  ParseArgs(_runDetails)

  ShowVersion(includeArgs=1)

  if _runDetails.nRuns > 1:
    for i in range(_runDetails.nRuns):
      sys.stderr.write(
        '---------------------------------\n\tDoing %d of %d\n---------------------------------\n' %
        (i + 1, _runDetails.nRuns))
      RunIt(_runDetails)
  else:
    if _runDetails.profileIt:
      try:
        import hotshot
        import hotshot.stats
        prof = hotshot.Profile('prof.dat')
        prof.runcall(RunIt, _runDetails)
        stats = hotshot.stats.load('prof.dat')
        stats.strip_dirs()
        stats.sort_stats('time', 'calls')
        stats.print_stats(30)
      except ImportError:
        print('Profiling requires the hotshot module')
    else:
      RunIt(_runDetails)