Package rdkit :: Package ML :: Module GrowComposite
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.GrowComposite

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2003-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """ command line utility for growing composite models 
 12   
 13  **Usage** 
 14   
 15    _GrowComposite [optional args] filename_ 
 16   
 17  **Command Line Arguments** 
 18   
 19    - -n *count*: number of new models to build 
 20   
 21    - -C *pickle file name*:  name of file containing composite upon which to build. 
 22   
 23    - --inNote *note*: note to be used in loading composite models from the database 
 24        for growing 
 25   
 26    - --balTable *table name*:  table from which to take the original data set 
 27       (for balancing) 
 28   
 29    - --balWeight *weight*: (between 0 and 1) weighting factor for the new data 
 30       (for balancing). OR, *weight* can be a list of weights 
 31   
 32    - --balCnt *count*: number of individual models in the balanced composite 
 33       (for balancing) 
 34   
 35    - --balH: use only the holdout set from the original data set in the balancing 
 36       (for balancing) 
 37   
 38    - --balT: use only the training set from the original data set in the balancing 
 39       (for balancing) 
 40   
 41    - -S: shuffle the original data set 
 42       (for balancing) 
 43   
 44    - -r: randomize the activities of the original data set 
 45       (for balancing) 
 46   
 47    - -N *note*: note to be attached to the grown composite when it's saved in the 
 48       database 
 49   
 50    - --outNote *note*: equivalent to -N 
 51   
 52    - -o *filename*: name of an output file to hold the pickled composite after 
 53       it has been grown. 
 54       If multiple balance weights are used, the weights will be added to 
 55       the filenames. 
 56   
 57    - -L *limit*: provide an (integer) limit on individual model complexity 
 58   
 59    - -d *database name*: instead of reading the data from a QDAT file, 
 60       pull it from a database.  In this case, the _filename_ argument 
 61       provides the name of the database table containing the data set. 
 62   
 63    - -p *tablename*: store persistence data in the database 
 64       in table *tablename* 
 65   
 66    - -l: locks the random number generator to give consistent sets 
 67       of training and hold-out data.  This is primarily intended 
 68       for testing purposes. 
 69   
 70    - -g: be less greedy when training the models. 
 71   
 72    - -G *number*: force trees to be rooted at descriptor *number*. 
 73   
 74    - -D: show a detailed breakdown of the composite model performance 
 75       across the training and, when appropriate, hold-out sets. 
 76   
 77    - -t *threshold value*: use high-confidence predictions for the final 
 78       analysis of the hold-out data. 
 79   
 80    - -q *list string*:  Add QuantTrees to the composite and use the list 
 81       specified in *list string* as the number of target quantization 
 82       bounds for each descriptor.  Don't forget to include 0's at the 
 83       beginning and end of *list string* for the name and value fields. 
 84       For example, if there are 4 descriptors and you want 2 quant bounds 
 85       apiece, you would use _-q "[0,2,2,2,2,0]"_. 
 86       Two special cases: 
 87         1) If you would like to ignore a descriptor in the model building, 
 88            use '-1' for its number of quant bounds. 
 89         2) If you have integer valued data that should not be quantized 
 90            further, enter 0 for that descriptor. 
 91   
 92    - -V: print the version number and exit 
 93   
 94  """ 
 95  from __future__ import print_function 
 96   
 97  import sys 
 98  import time 
 99   
100  import numpy 
101   
102  from rdkit.Dbase.DbConnection import DbConnect 
103  from rdkit.ML import CompositeRun 
104  from rdkit.ML import ScreenComposite, BuildComposite 
105  from rdkit.ML.Composite import AdjustComposite 
106  from rdkit.ML.Data import DataUtils, SplitData 
107  from rdkit.six.moves import cPickle 
108   
# Module-level run details: SetDefaults() falls back to this object when no
# details are passed, GrowIt() records its summary statistics on it, and the
# __main__ driver uses it for the whole run.
_runDetails = CompositeRun.CompositeRun()

# Version string reported by ShowVersion().
__VERSION_STRING = "0.5.0"

# When truthy, message() writes to stdout; GrowIt() also consults it to decide
# whether to emit progress messages and whether model building is silent.
_verbose = 1
114   
115   
def message(msg):
    """ emits messages to _sys.stdout_
    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the message to be displayed; any object is accepted and is
        rendered with ``%s`` formatting, followed by a newline

    Nothing is written when the module-level ``_verbose`` flag is false.

    """
    if _verbose:
        # wrap msg in a 1-tuple so that a tuple message is printed as-is
        # instead of being unpacked as format arguments (which raises TypeError)
        sys.stdout.write('%s\n' % (msg,))
127 128
def GrowIt(details, composite, progressCallback=None, saveIt=1, setDescNames=0, data=None):
    """ does the actual work of growing (adding new models to) a composite model

    **Arguments**

      - details: a _CompositeRun.CompositeRun_ object containing details
        (options, parameters, etc.) about the run

      - composite: the composite model to grow

      - progressCallback: (optional) passed through unchanged to the
        composite's _Grow()_ method when trees are being built; it is not
        used when neural nets are built.

      - saveIt: (optional) accepted for backwards compatibility; it is not
        used by this function (saving is left to the caller).

      - setDescNames: (optional) if nonzero and trees are being built, the
        composite's _SetInputOrder()_ method is called with the results of
        the data set's _GetVarNames()_ method before the models are grown.

      - data: (optional) the data set to be used. If this is not provided, the
        data set described in details will be used.

    **Returns**

      the enlarged composite model

    **Raises**

      - ValueError: if the data must be read as pre-quantized data from a
        database (no _details.qBounds_) but _DataUtils.DBToQuantData_ is not
        available.

    **Notes**

      - summary statistics (_overall_error_ and, with detailed results,
        _overall_correct_conf_, _overall_incorrect_conf_,
        _overall_result_matrix_ and _overall_fraction_dropped_) are stored on
        the module-level _runDetails_ object, not on _details_.

    """
    details.rundate = time.asctime()

    if data is None:
        fName = details.tableName.strip()
        if details.outName == '':
            details.outName = fName + '.pkl'
        if details.dbName == '':
            data = DataUtils.BuildQuantDataSet(fName)
        elif details.qBounds != []:
            details.tableName = fName
            data = details.GetDataSet()
        else:
            # DataUtils.DBToQuantData is no longer defined; fail with an explicit
            # message rather than an opaque AttributeError when it is missing.
            dbToQuantData = getattr(DataUtils, 'DBToQuantData', None)
            if dbToQuantData is None:
                raise ValueError('reading pre-quantized data from database %s is no longer '
                                 'supported; provide quantization bounds (-q) or pass the '
                                 'data set in explicitly' % (details.dbName))
            data = dbToQuantData(details.dbName, fName, quantName=details.qTableName,
                                 user=details.dbUser, password=details.dbPassword)

    # reseed with the composite's own seed so randomization is reproducible
    DataUtils.InitRandomNumbers(composite._randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

    namedExamples = data.GetNamedData()
    trainExamples = namedExamples
    message('Training with %d examples' % (len(trainExamples)))
    # the first and last fields of each example are its name and value fields
    message('\t%d descriptors' % (len(trainExamples[0]) - 2))
    nVars = data.GetNVars()
    nPossibleVals = composite.nPossibleVals
    attrs = list(range(1, nVars + 1))

    if details.useTrees:
        from rdkit.ML.DecTree import CrossValidate, PruneTree
        if details.qBounds != []:
            from rdkit.ML.DecTree import BuildQuantTree
            builder = BuildQuantTree.QuantTreeBoot
        else:
            from rdkit.ML.DecTree import ID3
            builder = ID3.ID3Boot
        driver = CrossValidate.CrossValidationDriver
        pruner = PruneTree.PruneTree

        if setDescNames:
            composite.SetInputOrder(data.GetVarNames())
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals, buildDriver=driver, pruner=pruner,
                       nTries=details.nModels, pruneIt=details.pruneIt,
                       lessGreedy=details.lessGreedy, needsQuantization=0, treeBuilder=builder,
                       nQuantBounds=details.qBounds, startAt=details.startAt,
                       maxDepth=details.limitDepth, progressCallback=progressCallback,
                       silent=not _verbose)
    else:
        from rdkit.ML.Neural import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals, nTries=details.nModels,
                       buildDriver=driver, needsQuantization=0)

    composite.AverageErrors()
    composite.SortModels()
    modelList, counts, avgErrs = composite.GetAllData()
    counts = numpy.array(counts)
    avgErrs = numpy.array(avgErrs)
    composite._varNames = data.GetVarNames()

    for model in modelList:
        model.NameModel(composite._varNames)

    # do final statistics: count-weighted average error and the count-weighted
    # average absolute deviation of each model's error from it
    totalCount = sum(counts)
    averageErr = sum(counts * avgErrs) / totalCount
    avgDev = sum(numpy.abs((avgErrs - averageErr) * counts)) / totalCount
    if _verbose:
        message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
                (100. * averageErr, 100. * avgDev))

    if details.bayesModel:
        composite.Train(trainExamples, verbose=0)

    if not details.detailedRes:
        if _verbose:
            message('Testing all examples')
        badExamples = []
        wrong = BuildComposite.testall(composite, namedExamples, badExamples)
        if _verbose:
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong), 100. * float(len(wrong)) / float(len(namedExamples))))
        _runDetails.overall_error = float(len(wrong)) / len(namedExamples)
    else:
        if _verbose:
            message('\nEntire data set:')
        resTup = ScreenComposite.ShowVoteResults(
            range(data.GetNPts()), data, composite, nPossibleVals[-1], details.threshold)
        nGood, nBad, _, avgGood, avgBad, _, voteTab = resTup
        nPts = len(namedExamples)
        nClass = nGood + nBad
        if nClass:
            _runDetails.overall_error = float(nBad) / nClass
        _runDetails.overall_correct_conf = avgGood
        _runDetails.overall_incorrect_conf = avgBad
        _runDetails.overall_result_matrix = repr(voteTab)
        # examples left unclassified (presumably those below details.threshold)
        # count as dropped; nClass cannot exceed nPts, so the count of dropped
        # examples is nPts - nClass (the reverse difference is never positive)
        nRej = nPts - nClass
        if nRej > 0:
            _runDetails.overall_fraction_dropped = float(nRej) / nPts

    return composite
268 269
def GetComposites(details):
    """ loads the composite models that are to be grown

    **Arguments**

      - details: a _CompositeRun.CompositeRun_ object. If both
        _details.persistTblName_ and _details.inNote_ are set, every MODEL in
        table _details.persistTblName_ of database _details.dbName_ whose note
        equals _details.inNote_ is loaded; otherwise, if
        _details.composFileName_ is set, the composite pickled in that file
        is loaded.

    **Returns**

      a list of the unpickled composites (empty if no source was specified)

    **Notes**

      - models are read with pickle, which can execute arbitrary code while
        loading: only use databases and files you trust.

    """
    res = []
    if details.persistTblName and details.inNote:
        conn = DbConnect(details.dbName, details.persistTblName)
        # double embedded single quotes so the note is a valid SQL string literal
        note = details.inNote.replace("'", "''")
        mdls = conn.GetData(fields='MODEL', where="where note='%s'" % (note))
        for row in mdls:
            # pickle needs the raw bytes; str() of a Python 3 bytes/memoryview
            # value would give its repr instead of its contents
            res.append(cPickle.loads(bytes(row[0])))
    elif details.composFileName:
        with open(details.composFileName, 'rb') as inF:
            res.append(cPickle.load(inF))
    return res
281 282
def _RandomizeLikeComposite(details, composite, data):
    """ reseeds the random number generator from _composite_ and optionally randomizes _data_

    Copies the composite's split fraction and random seed onto _details_,
    re-initializes the random number generator with that seed and then, if
    requested by _details.shuffleActivities_ / _details.randomActivities_,
    passes _data_ to _DataUtils.RandomizeActivities()_ (its return value is
    not used, so the data set is presumably modified in place).

    """
    details.splitFrac = composite._splitFrac
    details.randomSeed = composite._randomSeed
    DataUtils.InitRandomNumbers(details.randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)


def BalanceComposite(details, composite, data1=None, data2=None):
    """ balances the composite using the parameters provided in details

    **Arguments**

      - details: a _CompositeRun.CompositeRun_ object. The balancing options
        (balCnt, balTable, balDb, balWeight, balDoHoldout, balDoTrain) and
        filterFrac/filterVal, shuffleActivities and randomActivities are read
        from it; its splitFrac and randomSeed are overwritten with the
        composite's values.

      - composite: the composite model to be balanced

      - data1: (optional) if provided, this should be the
        data set used to construct the original models

      - data2: (optional) if provided, this should be the
        data set used to construct the new individual models

    **Returns**

      - a list holding one balanced composite per weight in
        _details.balWeight_ (a single non-sequence weight is treated as a
        one-element list), or
      - the unmodified _composite_ itself (NOT wrapped in a list) when no
        balancing is done: _details.balCnt_ is unset or larger than the number
        of models in the composite, or a data set could not be read.

    """
    if not details.balCnt or details.balCnt > len(composite):
        return composite
    message("Balancing Composite")

    #
    # start by getting data set 1: which is the data set used to build the
    # original models
    #
    if data1 is None:
        message("\tReading First Data Set")
        # temporarily point details at the balancing table/database, and
        # restore the originals even if reading the data fails
        origTable, origDb = details.tableName, details.dbName
        details.tableName = details.balTable.strip()
        details.dbName = details.balDb
        try:
            data1 = details.GetDataSet()
        finally:
            details.tableName = origTable
            details.dbName = origDb
    if data1 is None:
        return composite
    _RandomizeLikeComposite(details, composite, data1)
    namedExamples = data1.GetNamedData()
    if details.balDoHoldout or details.balDoTrain:
        trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples), details.splitFrac, silent=1)
        trainExamples = [namedExamples[x] for x in trainIdx]
        testExamples = [namedExamples[x] for x in testIdx]
        if details.filterFrac != 0.0:
            keepIdx, rejectIdx = DataUtils.FilterData(trainExamples, details.filterVal,
                                                      details.filterFrac, -1, indicesOnly=1)
            # examples filtered out of the training set join the hold-out set
            testExamples += [trainExamples[x] for x in rejectIdx]
            trainExamples = [trainExamples[x] for x in keepIdx]
        if details.balDoHoldout:
            # balance using only the hold-out portion
            trainExamples = testExamples
    else:
        trainExamples = namedExamples
    dataSet1 = trainExamples
    cols1 = [x.upper() for x in data1.GetVarNames()]
    # drop the reference to the full data set before reading the second one
    data1 = None

    #
    # now grab data set 2: the data used to build the new individual models
    #
    if data2 is None:
        message("\tReading Second Data Set")
        data2 = details.GetDataSet()
    if data2 is None:
        return composite
    _RandomizeLikeComposite(details, composite, data2)
    dataSet2 = data2.GetNamedData()
    cols2 = [x.upper() for x in data2.GetVarNames()]
    data2 = None

    # and balance it, once per requested weight:
    weights = details.balWeight
    if not isinstance(weights, (tuple, list)):
        weights = (weights, )
    res = []
    for weight in weights:
        message("\tBalancing with Weight: %.4f" % (weight))
        res.append(
            AdjustComposite.BalanceComposite(composite, dataSet1, dataSet2, weight, details.balCnt,
                                             names1=cols1, names2=cols2))
    return res
375 376
def ShowVersion(includeArgs=0):
    """ writes this program's version banner to stdout

    **Arguments**

      - includeArgs: (optional) if nonzero, the full command line
        (_sys.argv_ joined with spaces) is echoed after the version line.

    """
    banner = 'This is GrowComposite.py version %s' % (__VERSION_STRING)
    print(banner)
    if not includeArgs:
        return
    print('command line was:')
    print(' '.join(sys.argv))
385 386
def Usage():
    """ shows the command-line help and terminates the program

    The help text is this module's docstring; afterwards _SystemExit_ is
    raised with exit status -1.

    """
    print(__doc__)
    raise SystemExit(-1)
393 394
def SetDefaults(runDetails=None):
    """ fills a run-details object with the default settings

    **Arguments**

      - runDetails: (optional) a _CompositeRun.CompositeRun_ object.
        When omitted (or None), the module-level _runDetails object is used.

    **Returns**

      the result of _CompositeRun.SetDefaults()_ for that object (the
      initialized _CompositeRun_ object).

    """
    target = _runDetails if runDetails is None else runDetails
    return CompositeRun.SetDefaults(target)
412 413
def ParseArgs(runDetails):
    """ parses the command line arguments in _sys.argv_ and updates _runDetails_

    **Arguments**

      - runDetails: a _CompositeRun.CompositeRun_ object; it is modified in place.

    **Notes**

      - the loading/balancing attributes (inNote, composFileName, balTable,
        balWeight, balCnt, balDoHoldout, balDoTrain, balDb) are reset to their
        defaults before the arguments are processed.

      - the first positional argument is stored in _runDetails.tableName_; if
        no --balDb was given, _runDetails.balDb_ defaults to _runDetails.dbName_.

      - unrecognized or malformed arguments, or a missing positional argument,
        print a message to stderr and call _Usage()_, which exits.

      - the values of --balWeight, -q and -Q are passed to eval(), so they can
        run arbitrary Python: only process command lines you trust.

    """
    import getopt
    try:
        args, extra = getopt.getopt(sys.argv[1:], 'P:o:n:p:b:sf:F:v:hlgd:rSTt:Q:q:DVG:L:C:N:',
                                    ['inNote=', 'outNote=', 'balTable=', 'balWeight=', 'balCnt=',
                                     'balH', 'balT', 'balDb=', ])
    except getopt.GetoptError as err:
        print('bad argument:', err, file=sys.stderr)
        Usage()
    runDetails.inNote = ''
    runDetails.composFileName = ''
    runDetails.balTable = ''
    runDetails.balWeight = (0.5, )
    runDetails.balCnt = 0
    runDetails.balDoHoldout = 0
    runDetails.balDoTrain = 0
    runDetails.balDb = ''
    for arg, val in args:
        if arg == '-n':
            runDetails.nModels = int(val)
        elif arg == '-C':
            runDetails.composFileName = val
        elif arg == '--balTable':
            runDetails.balTable = val
        elif arg == '--balWeight':
            runDetails.balWeight = eval(val)
            if not isinstance(runDetails.balWeight, (tuple, list)):
                runDetails.balWeight = (runDetails.balWeight, )
        elif arg == '--balCnt':
            runDetails.balCnt = int(val)
        elif arg == '--balH':
            runDetails.balDoHoldout = 1
        elif arg == '--balT':
            runDetails.balDoTrain = 1
        elif arg == '--balDb':
            runDetails.balDb = val
        elif arg == '--inNote':
            runDetails.inNote = val
        elif arg == '-N' or arg == '--outNote':
            runDetails.note = val
        elif arg == '-o':
            runDetails.outName = val
        elif arg == '-p':
            runDetails.persistTblName = val
        elif arg == '-r':
            runDetails.randomActivities = 1
        elif arg == '-S':
            runDetails.shuffleActivities = 1
        elif arg == '-h':
            Usage()
        elif arg == '-l':
            runDetails.lockRandom = 1
        elif arg == '-g':
            runDetails.lessGreedy = 1
        elif arg == '-G':
            runDetails.startAt = int(val)
        elif arg == '-d':
            runDetails.dbName = val
        elif arg == '-T':
            runDetails.useTrees = 0
        elif arg == '-t':
            runDetails.threshold = float(val)
        elif arg == '-D':
            runDetails.detailedRes = 1
        elif arg == '-L':
            runDetails.limitDepth = int(val)
        elif arg == '-q':
            qBounds = eval(val)
            # explicit check instead of assert, which is stripped under -O
            if not isinstance(qBounds, (tuple, list)):
                print('bad argument type for -q, specify a list as a string', file=sys.stderr)
                Usage()
            runDetails.qBoundCount = val
            runDetails.qBounds = qBounds
        elif arg == '-Q':
            qBounds = eval(val)
            if not isinstance(qBounds, (tuple, list)):
                print('bad argument type for -Q, specify a list as a string', file=sys.stderr)
                Usage()
            runDetails.activityBounds = qBounds
            runDetails.activityBoundsVals = val
        elif arg == '-V':
            ShowVersion()
            sys.exit(0)
        else:
            print('bad argument:', arg, file=sys.stderr)
            Usage()
    if not extra:
        print('missing argument: the name of the data file or database table', file=sys.stderr)
        Usage()
    runDetails.tableName = extra[0]
    if not runDetails.balDb:
        runDetails.balDb = runDetails.dbName
if __name__ == '__main__':
    if len(sys.argv) < 2:
        Usage()

    _runDetails.cmd = ' '.join(sys.argv)
    SetDefaults(_runDetails)
    ParseArgs(_runDetails)

    ShowVersion(includeArgs=1)

    initModels = GetComposites(_runDetails)
    nModels = len(initModels)
    if not nModels:
        message("No models found")
    for modelIdx, initModel in enumerate(initModels):
        if nModels > 1:
            sys.stderr.write(
                '---------------------------------\n\tDoing %d of %d\n---------------------------------\n' %
                (modelIdx + 1, nModels))
        composite = GrowIt(_runDetails, initModel, setDescNames=1)
        if _runDetails.balTable and _runDetails.balCnt:
            composites = BalanceComposite(_runDetails, composite)
            # BalanceComposite hands back the bare composite (not a list) when
            # it skips balancing
            if composites is composite:
                composites = [composite]
        else:
            composites = [composite]
        for mdl in composites:
            mdl.ClearModelExamples()
        if _runDetails.outName:
            if len(composites) == 1:
                composites[0].Pickle(_runDetails.outName)
            else:
                # one composite per balance weight: put the weight (as a
                # percentage) into each file name
                baseName = _runDetails.outName.split('.pkl')[0]
                for weight, mdl in zip(_runDetails.balWeight, composites):
                    mdl.Pickle('%s.%d.pkl' % (baseName, int(100 * weight)))
        if _runDetails.persistTblName and _runDetails.dbName:
            message('Updating results table %s:%s' % (_runDetails.dbName, _runDetails.persistTblName))
            if len(composites) > 1:
                message('WARNING: updating results table with models having different weights')
            # save the composites
            for mdl in composites:
                _runDetails.model = cPickle.dumps(mdl)
                _runDetails.Store(db=_runDetails.dbName, table=_runDetails.persistTblName)