Source Code for Module rdkit.ML.ScreenComposite

# $Id$
#
#  Copyright (C) 2000-2008 greg Landrum and Rational Discovery LLC
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
  11  """ command line utility for screening composite models 
  12   
  13  **Usage** 
  14   
  15    _ScreenComposite [optional args] modelfile(s) datafile_ 
  16   
  17  Unless indicated otherwise (via command line arguments), _modelfile_ is 
  18  a file containing a pickled composite model and _filename_ is a QDAT file. 
  19   
  20  **Command Line Arguments** 
  21   
  22    - -t *threshold value(s)*: use high-confidence predictions for the final 
  23       analysis of the hold-out data.  The threshold value can be either a single 
  24       float or a list/tuple of floats.  All thresholds should be between 
  25       0.0 and 1.0 
  26   
  27    - -D: do a detailed screen. 
  28   
  29    - -d *database name*: instead of reading the data from a QDAT file, 
  30       pull it from a database.  In this case, the _datafile_ argument 
  31       provides the name of the database table containing the data set. 
  32   
  33    - -N *note*: use all models from the database which have this note. 
  34                 The modelfile argument should contain the name of the table 
  35                 with the models. 
  36   
  37    - -H: screen only the hold out set (works only if a version of 
  38          BuildComposite more recent than 1.2.2 was used). 
  39   
  40    - -T: screen only the training set (works only if a version of 
  41          BuildComposite more recent than 1.2.2 was used). 
  42   
  43    - -E: do a detailed Error analysis.  This shows each misclassified 
  44       point and the number of times it was missed across all screened 
  45       composites.  If the --enrich argument is also provided, only compounds 
  46       that have true activity value equal to the enrichment value will be 
  47       used. 
  48   
  49    - --enrich *enrichVal*: target "active" value to be used in calculating 
  50       enrichments. 
  51   
  52    - -A: show All predictions. 
  53   
  54    - -S: shuffle activity values before screening 
  55   
  56    - -R: randomize activity values before screening 
  57   
  58    - -F *filter frac*: filters the data before training to change the 
  59       distribution of activity values in the training set.  *filter frac* 
  60       is the fraction of the training set that should have the target value. 
  61       **See note in BuildComposite help about data filtering** 
  62   
  63    - -v *filter value*: filters the data before training to change the 
  64       distribution of activity values in the training set. *filter value* 
  65       is the target value to use in filtering. 
  66       **See note in BuildComposite help about data filtering** 
  67   
  68    - -V: be verbose when screening multiple models 
  69   
  70    - -h: show this message and exit 
  71   
  72    - --OOB: Do out an "out-of-bag" generalization error estimate.  This only 
  73        makes sense when applied to the original data set. 
  74   
  75    - --pickleCol *colId*: index of the column containing a pickled value 
  76        (used primarily for cases where fingerprints are used as descriptors) 
  77   
  78    *** Options for making Prediction (Hanneke) Plots *** 
  79   
  80    - --predPlot=<fileName>: triggers the generation of a Hanneke plot and 
  81        sets the name of the .txt file which will hold the output data. 
  82        A Gnuplot control file, <fileName>.gnu, will also be generated. 
  83   
  84    - --predActTable=<name> (optional):  name of the database table 
  85        containing activity values.  If this is not provided, activities 
  86        will be read from the same table containing the screening data 
  87   
  88    - --predActCol=<name> (optional):  name of the activity column. If not 
  89        provided, the name of the last column in the activity table will 
  90        be used. 
  91   
  92    - --predLogScale (optional):  If provided, the x axis of the 
  93        prediction plot (the activity axis) will be plotted using a log 
  94        scale 
  95   
  96    - --predShow: launch a gnuplot instance and display the prediction 
  97        plot (the plot will still be written to disk). 
  98   
  99    *** The following options are likely obsolete *** 
 100   
 101    - -P: read pickled data.  The datafile argument should contain 
 102       a pickled data set. *relevant only to qdat files* 
 103   
 104    - -q: data are not quantized (the composite should take care of 
 105       quantization itself if it requires quantized data). *relevant only to 
 106       qdat files* 
 107   
 108   
 109   
 110  """ 
from __future__ import print_function

import os
import sys

import numpy

from rdkit import DataStructs
from rdkit.Dbase import DbModule
from rdkit.Dbase.DbConnection import DbConnect
from rdkit.ML import CompositeRun
from rdkit.ML.Data import DataUtils, SplitData
from rdkit.six.moves import cPickle
from rdkit.six.moves import input

try:
  from PIL import Image, ImageDraw
except ImportError:
  hasPil = 0
else:
  hasPil = 1

_details = CompositeRun.CompositeRun()

__VERSION_STRING = "3.3.0"

def message(msg, noRet=0):
  """ emits messages to _sys.stdout_
    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the string to be displayed

  """
  if noRet:
    sys.stdout.write('%s ' % (msg))
  else:
    sys.stdout.write('%s\n' % (msg))

def error(msg):
  """ emits messages to _sys.stderr_
    override this in modules which import this one to redirect output

    **Arguments**

      - msg: the string to be displayed

  """
  sys.stderr.write('ERROR: %s\n' % (msg))

def CalcEnrichment(mat, tgt=1):
  if tgt < 0 or tgt >= mat.shape[0]:
    return 0
  nPts = float(sum(sum(mat)))
  nTgtPred = float(sum(mat[:, tgt]))
  if nTgtPred:
    pctCorrect = mat[tgt, tgt] / nTgtPred
    nTgtReal = float(sum(mat[tgt, :]))
    pctOverall = nTgtReal / nPts
  else:
    return 0.0
  return pctCorrect / pctOverall

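The enrichment computed above is the precision for the target class divided by that class's base rate in the screened set. A minimal standalone sketch of the same arithmetic, using a hypothetical 2x2 vote table (rows = true class, columns = predicted class):

```python
import numpy as np

# Hypothetical vote table: 8 true inactives predicted inactive, 2 predicted
# active; 1 true active predicted inactive, 4 predicted active.
mat = np.array([[8, 2],
                [1, 4]], dtype=float)

tgt = 1                                # the "active" class
nPts = mat.sum()                       # 15 points screened
nTgtPred = mat[:, tgt].sum()           # 6 predicted active
pctCorrect = mat[tgt, tgt] / nTgtPred  # 4/6: precision for the active class
pctOverall = mat[tgt, :].sum() / nPts  # 5/15: base rate of actives
enrichment = pctCorrect / pctOverall
print(enrichment)  # 2.0: actives are twice as common among "active" predictions
```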
def CollectResults(indices, dataSet, composite, callback=None, appendExamples=0, errorEstimate=0):
  """ screens a set of examples through a composite and returns the
  results
#DOC

  **Arguments**

    - indices: the indices (into _dataSet_) of the examples to be screened

    - dataSet: the data set providing the examples (a sequence of sequences);
      it's assumed that the last element in each example is its "value"

    - composite: the composite model to be used

    - callback: (optional) if provided, this should be a function
      taking a single argument that is called after each example is
      screened with the number of examples screened so far as the
      argument.

    - appendExamples: (optional) this value is passed on to the
      composite's _ClassifyExample()_ method.

    - errorEstimate: (optional) calculate the "out of bag" error
      estimate for the composite using Breiman's definition.  This
      only makes sense when screening the original data set!
      [L. Breiman "Out-of-bag Estimation", UC Berkeley Dept of
      Statistics Technical Report (1996)]

  **Returns**

    a list of 3-tuples _nExamples_ long:

      1) answer: the value from the example

      2) pred: the composite model's prediction

      3) conf: the confidence of the composite

  """
  # for i in range(len(composite)):
  #   print('  ', i, 'TRAIN:', composite[i][0]._trainIndices)

  for j in range(len(composite)):
    tmp = composite.GetModel(j)
    if hasattr(tmp, '_trainIndices') and type(tmp._trainIndices) != dict:
      tis = {}
      if hasattr(tmp, '_trainIndices'):
        for v in tmp._trainIndices:
          tis[v] = 1
      tmp._trainIndices = tis

  nPts = len(indices)
  res = [None] * nPts
  for i in range(nPts):
    idx = indices[i]
    example = dataSet[idx]
    if errorEstimate:
      use = []
      for j in range(len(composite)):
        mdl = composite.GetModel(j)
        if not mdl._trainIndices.get(idx, 0):
          use.append(j)
    else:
      use = None
    # print('IDX:', idx, 'use:', use)
    pred, conf = composite.ClassifyExample(example, appendExample=appendExamples, onlyModels=use)
    if composite.GetActivityQuantBounds():
      answer = composite.QuantizeActivity(example)[-1]
    else:
      answer = example[-1]
    res[i] = answer, pred, conf
    if callback:
      callback(i)
  return res

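The out-of-bag branch above only lets a model vote on a point that was absent from its training set. A small self-contained sketch of that selection rule, using hypothetical `_trainIndices` dictionaries in place of real composite models:

```python
# Hypothetical stand-ins for composite models: each entry is the dict of
# training indices that CollectResults normalizes into _trainIndices.
trainIndices = [
    {0: 1, 1: 1, 2: 1},  # model 0 trained on points 0-2
    {2: 1, 3: 1, 4: 1},  # model 1 trained on points 2-4
    {0: 1, 3: 1, 5: 1},  # model 2 trained on points 0, 3, 5
]

def oob_models(idx):
    # Out-of-bag rule: only models that never saw `idx` during training vote.
    return [j for j, tis in enumerate(trainIndices) if not tis.get(idx, 0)]

print(oob_models(2))  # point 2 was used by models 0 and 1 -> [2]
print(oob_models(5))  # point 5 was used only by model 2 -> [0, 1]
```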
def DetailedScreen(indices, data, composite, threshold=0, screenResults=None, goodVotes=None,
                   badVotes=None, noVotes=None, callback=None, appendExamples=0, errorEstimate=0):
  """ screens a set of examples across a composite and breaks the
  predictions into *correct*, *incorrect* and *unclassified* sets.
#DOC

  **Arguments**

    - indices: the indices (into _data_) of the examples to be screened

    - data: the data set providing the examples (a sequence of sequences);
      it's assumed that the last element in each example is its "value"

    - composite: the composite model to be used

    - threshold: (optional) the threshold to be used to decide whether
      or not a given prediction should be kept

    - screenResults: (optional) the results of screening the examples
      (a sequence of 3-tuples in the format returned by
      _CollectResults()_).  If this is provided, the examples will not
      be screened again.

    - goodVotes,badVotes,noVotes: (optional) if provided these should
      be lists (or anything supporting an _append()_ method) which
      will be used to pass the screening results back.

    - callback: (optional) if provided, this should be a function
      taking a single argument that is called after each example is
      screened with the number of examples screened so far as the
      argument.

    - appendExamples: (optional) this value is passed on to the
      composite's _ClassifyExample()_ method.

    - errorEstimate: (optional) calculate the "out of bag" error
      estimate for the composite using Breiman's definition.  This
      only makes sense when screening the original data set!
      [L. Breiman "Out-of-bag Estimation", UC Berkeley Dept of
      Statistics Technical Report (1996)]

  **Notes**

    - since this function doesn't return anything, if one or more of
      the arguments _goodVotes_, _badVotes_, and _noVotes_ is not
      provided, there's not much reason to call it

  """
  if screenResults is None:
    screenResults = CollectResults(indices, data, composite, callback=callback,
                                   appendExamples=appendExamples, errorEstimate=errorEstimate)
  if goodVotes is None:
    goodVotes = []
  if badVotes is None:
    badVotes = []
  if noVotes is None:
    noVotes = []
  for i in range(len(screenResults)):
    answer, pred, conf = screenResults[i]
    if conf > threshold:
      if pred != answer:
        badVotes.append((answer, pred, conf, i))
      else:
        goodVotes.append((answer, pred, conf, i))
    else:
      noVotes.append((answer, pred, conf, i))

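The partition performed above can be sketched standalone. The tuples below are hypothetical screening results in the (answer, prediction, confidence) format produced by _CollectResults()_:

```python
# Hypothetical screening results: (answer, prediction, confidence).
screenResults = [(1, 1, 0.90), (0, 1, 0.75), (1, 0, 0.40), (0, 0, 0.55)]
threshold = 0.5

goodVotes, badVotes, noVotes = [], [], []
for i, (answer, pred, conf) in enumerate(screenResults):
    if conf > threshold:          # confident enough to count
        if pred != answer:
            badVotes.append((answer, pred, conf, i))
        else:
            goodVotes.append((answer, pred, conf, i))
    else:                         # low confidence: rejected
        noVotes.append((answer, pred, conf, i))

print(len(goodVotes), len(badVotes), len(noVotes))  # 2 1 1
```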
def ShowVoteResults(indices, data, composite, nResultCodes, threshold, verbose=1,
                    screenResults=None, callback=None, appendExamples=0, goodVotes=None,
                    badVotes=None, noVotes=None, errorEstimate=0):
  """ screens the results and shows a detailed workup

  The work of doing the screening and processing the results is
  handled by _DetailedScreen()_
#DOC

  **Arguments**

    - indices: the indices (into _data_) of the examples to be screened

    - data: the data set providing the examples (a sequence of sequences);
      it's assumed that the last element in each example is its "value"

    - composite: the composite model to be used

    - nResultCodes: the number of possible results the composite can
      return

    - threshold: the threshold to be used to decide whether or not a
      given prediction should be kept

    - screenResults: (optional) the results of screening the examples
      (a sequence of 3-tuples in the format returned by
      _CollectResults()_).  If this is provided, the examples will not
      be screened again.

    - callback: (optional) if provided, this should be a function
      taking a single argument that is called after each example is
      screened with the number of examples screened so far as the
      argument.

    - appendExamples: (optional) this value is passed on to the
      composite's _ClassifyExample()_ method.

    - goodVotes,badVotes,noVotes: (optional) if provided these should
      be lists (or anything supporting an _append()_ method) which
      will be used to pass the screening results back.

    - errorEstimate: (optional) calculate the "out of bag" error
      estimate for the composite using Breiman's definition.  This
      only makes sense when screening the original data set!
      [L. Breiman "Out-of-bag Estimation", UC Berkeley Dept of
      Statistics Technical Report (1996)]

  **Returns**

    a 7-tuple:

      1) the number of good (correct) predictions

      2) the number of bad (incorrect) predictions

      3) the number of predictions skipped due to the _threshold_

      4) the average confidence in the good predictions

      5) the average confidence in the bad predictions

      6) the average confidence in the skipped predictions

      7) the results table

  """
  nExamples = len(indices)
  if goodVotes is None:
    goodVotes = []
  if badVotes is None:
    badVotes = []
  if noVotes is None:
    noVotes = []
  DetailedScreen(indices, data, composite, threshold, screenResults=screenResults,
                 goodVotes=goodVotes, badVotes=badVotes, noVotes=noVotes, callback=callback,
                 appendExamples=appendExamples, errorEstimate=errorEstimate)
  nBad = len(badVotes)
  nGood = len(goodVotes)
  nClassified = nGood + nBad
  if verbose:
    print('\n\t*** Vote Results ***')
    print('misclassified: %d/%d (%%%4.2f)\t%d/%d (%%%4.2f)' %
          (nBad, nExamples, 100. * float(nBad) / nExamples, nBad, nClassified,
           100. * float(nBad) / nClassified))
  nSkip = len(noVotes)
  if nSkip > 0:
    if verbose:
      print('skipped: %d/%d (%%% 4.2f)' % (nSkip, nExamples, 100. * float(nSkip) / nExamples))
    noConf = numpy.array([x[2] for x in noVotes])
    avgSkip = sum(noConf) / float(nSkip)
  else:
    avgSkip = 0.

  if nBad > 0:
    badConf = numpy.array([x[2] for x in badVotes])
    avgBad = sum(badConf) / float(nBad)
  else:
    avgBad = 0.

  if nGood > 0:
    goodRes = [x[1] for x in goodVotes]
    goodConf = numpy.array([x[2] for x in goodVotes])
    avgGood = sum(goodConf) / float(nGood)
  else:
    goodRes = []
    goodConf = []
    avgGood = 0.

  if verbose:
    print()
    print('average correct confidence:   % 6.4f' % avgGood)
    print('average incorrect confidence: % 6.4f' % avgBad)

  voteTab = numpy.zeros((nResultCodes, nResultCodes), numpy.int)
  for res in goodRes:
    voteTab[res, res] += 1
  for ans, res, conf, idx in badVotes:
    voteTab[ans, res] += 1

  if verbose:
    print()
    print('\tResults Table:')
    vTab = voteTab.transpose()
    colCounts = numpy.sum(vTab, 0)
    rowCounts = numpy.sum(vTab, 1)
    message('')
    for i in range(nResultCodes):
      if rowCounts[i] == 0:
        rowCounts[i] = 1
      row = vTab[i]
      message('    ', noRet=1)
      for j in range(nResultCodes):
        entry = row[j]
        message(' % 6d' % entry, noRet=1)
      message('     | % 4.2f' % (100. * vTab[i, i] / rowCounts[i]))
    message('    ', noRet=1)
    for i in range(nResultCodes):
      message('-------', noRet=1)
    message('')
    message('    ', noRet=1)
    for i in range(nResultCodes):
      if colCounts[i] == 0:
        colCounts[i] = 1
      message(' % 6.2f' % (100. * vTab[i, i] / colCounts[i]), noRet=1)
    message('')

  return nGood, nBad, nSkip, avgGood, avgBad, avgSkip, voteTab

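The results table built above is an ordinary confusion matrix accumulated from the vote lists. A self-contained numpy sketch with hypothetical votes, mirroring the `voteTab` accumulation and the row ("purity") percentages:

```python
import numpy as np

# Hypothetical vote lists in the (answer, prediction, confidence, index) format.
goodVotes = [(1, 1, 0.9, 0), (0, 0, 0.8, 1), (1, 1, 0.7, 2)]
badVotes = [(0, 1, 0.6, 3)]  # a true 0 predicted as 1

nResultCodes = 2
voteTab = np.zeros((nResultCodes, nResultCodes), dtype=int)
for ans, pred, conf, idx in goodVotes:
    voteTab[ans, pred] += 1  # diagonal entries: correct predictions
for ans, pred, conf, idx in badVotes:
    voteTab[ans, pred] += 1  # off-diagonal entries: misclassifications

print(voteTab)
# Row percentages: fraction of each true class predicted correctly.
rowPct = 100.0 * np.diag(voteTab) / voteTab.sum(axis=1)
print(rowPct)  # class 0: 1 of 2 correct -> 50.0; class 1: 2 of 2 -> 100.0
```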
def ScreenIt(composite, indices, data, partialVote=0, voteTol=0.0, verbose=1, screenResults=None,
             goodVotes=None, badVotes=None, noVotes=None):
  """ screens a set of data using a composite model and prints out
  statistics about the screen.
#DOC
  The work of doing the screening and processing the results is
  handled by _DetailedScreen()_

  **Arguments**

    - composite: the composite model to be used

    - indices: the indices (into _data_) of the examples to be screened

    - data: the examples to be screened (a sequence of sequences);
      it's assumed that the last element in each example is its "value"

    - partialVote: (optional) toggles use of the threshold value in
      the screening.

    - voteTol: (optional) the threshold to be used to decide whether or not a
      given prediction should be kept

    - verbose: (optional) sets degree of verbosity of the screening

    - screenResults: (optional) the results of screening the examples
      (a sequence of 3-tuples in the format returned by
      _CollectResults()_).  If this is provided, the examples will not
      be screened again.

    - goodVotes,badVotes,noVotes: (optional) if provided these should
      be lists (or anything supporting an _append()_ method) which
      will be used to pass the screening results back.


  **Returns**

    a 7-tuple:

      1) the number of good (correct) predictions

      2) the number of bad (incorrect) predictions

      3) the number of predictions skipped due to the _threshold_

      4) the average confidence in the good predictions

      5) the average confidence in the bad predictions

      6) the average confidence in the skipped predictions

      7) None

  """
  if goodVotes is None:
    goodVotes = []
  if badVotes is None:
    badVotes = []
  if noVotes is None:
    noVotes = []

  if not partialVote:
    voteTol = 0.0

  DetailedScreen(indices, data, composite, voteTol, screenResults=screenResults,
                 goodVotes=goodVotes, badVotes=badVotes, noVotes=noVotes)

  nGood = len(goodVotes)
  goodAccum = 0.
  for res, pred, conf, idx in goodVotes:
    goodAccum += conf

  misCount = len(badVotes)
  badAccum = 0.
  for res, pred, conf, idx in badVotes:
    badAccum += conf

  nSkipped = len(noVotes)
  goodSkipped = 0
  badSkipped = 0
  skipAccum = 0.
  for ans, pred, conf, idx in noVotes:
    skipAccum += conf
    if ans != pred:
      badSkipped += 1
    else:
      goodSkipped += 1

  nData = nGood + misCount + nSkipped
  if verbose:
    print('Total N Points:', nData)
  if partialVote:
    nCounted = nData - nSkipped
    if verbose:
      print('Misclassifications: %d (%%%4.2f)' % (misCount, 100. * float(misCount) / nCounted))
      print('N Skipped: %d (%%%4.2f)' % (nSkipped, 100. * float(nSkipped) / nData))
      print('\tGood Votes Skipped: %d (%%%4.2f)' %
            (goodSkipped, 100. * float(goodSkipped) / nSkipped))
      print('\tBad Votes Skipped: %d (%%%4.2f)' % (badSkipped, 100. * float(badSkipped) / nSkipped))
  else:
    if verbose:
      print('Misclassifications: %d (%%%4.2f)' % (misCount, 100. * float(misCount) / nData))
      print('Average Correct Vote Confidence:   % 6.4f' % (goodAccum / (nData - misCount)))
      print('Average InCorrect Vote Confidence: % 6.4f' % (badAccum / misCount))

  avgGood = 0
  avgBad = 0
  avgSkip = 0
  if nGood:
    avgGood = goodAccum / nGood
  if misCount:
    avgBad = badAccum / misCount
  if nSkipped:
    avgSkip = skipAccum / nSkipped
  return nGood, misCount, nSkipped, avgGood, avgBad, avgSkip, None

def _processVoteList(votes, data):
  """ *Internal Use Only*

  converts a list of 4-tuples (answer, prediction, confidence, idx) into
  an alternate list: (answer, prediction, confidence, data point)

  **Arguments**

    - votes: a list of 4-tuples: (answer, prediction, confidence,
      index)

    - data: a _DataUtils.MLData.MLDataSet_


  **Note**: alterations are done in place in the _votes_ list

  """
  for i in range(len(votes)):
    ans, pred, conf, idx = votes[i]
    votes[i] = (ans, pred, conf, data[idx])

def PrepareDataFromDetails(model, details, data, verbose=0):
  if (hasattr(details, 'doHoldout') and details.doHoldout) or \
     (hasattr(details, 'doTraining') and details.doTraining):
    try:
      splitF = model._splitFrac
    except AttributeError:
      pass
    else:
      if verbose:
        message('s', noRet=1)

      if hasattr(details, 'errorEstimate') and details.errorEstimate and \
         hasattr(details, 'doHoldout') and details.doHoldout:
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
        message('****** WARNING: OOB screening should not be combined with doHoldout option.')
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
      trainIdx, testIdx = SplitData.SplitIndices(data.GetNPts(), splitF, silent=1)

    if hasattr(details, 'filterFrac') and details.filterFrac != 0.0:
      if verbose:
        message('f', noRet=1)
      trainFilt, temp = DataUtils.FilterData(data, details.filterVal, details.filterFrac, -1,
                                             indicesToUse=trainIdx, indicesOnly=1)
      testIdx += temp
      trainIdx = trainFilt
  elif hasattr(details, 'errorEstimate') and details.errorEstimate:
    # the OOB screening works by checking to see if a given index
    # is in the
    if hasattr(details, 'filterFrac') and details.filterFrac != 0.0:
      if verbose:
        message('f', noRet=1)
      testIdx, trainIdx = DataUtils.FilterData(data, details.filterVal, details.filterFrac, -1,
                                               indicesToUse=range(data.GetNPts()), indicesOnly=1)
      testIdx.extend(trainIdx)
    else:
      testIdx = list(range(data.GetNPts()))
      trainIdx = []
  else:
    testIdx = list(range(data.GetNPts()))
    trainIdx = []
  if hasattr(details, 'doTraining') and details.doTraining:
    testIdx, trainIdx = trainIdx, testIdx
  return trainIdx, testIdx

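A minimal sketch of the hold-out/training index preparation. The shuffle-and-slice splitter here is a hypothetical stand-in for `SplitData.SplitIndices`, not RDKit's implementation; the final swap mirrors what the function above does when `doTraining` is set:

```python
import random

def split_indices(nPts, splitFrac, seed=23):
    # Hypothetical splitter: shuffle the indices, then take the first
    # splitFrac of them as the training set and the rest as hold-out.
    idx = list(range(nPts))
    random.Random(seed).shuffle(idx)
    nTrain = int(splitFrac * nPts)
    return sorted(idx[:nTrain]), sorted(idx[nTrain:])

trainIdx, testIdx = split_indices(10, 0.7)
# With doTraining set, the roles swap, as at the end of PrepareDataFromDetails:
doTraining = True
if doTraining:
    testIdx, trainIdx = trainIdx, testIdx
print(len(trainIdx), len(testIdx))  # 3 7
```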
def ScreenFromDetails(models, details, callback=None, setup=None, appendExamples=0, goodVotes=None,
                      badVotes=None, noVotes=None, data=None, enrichments=None):
  """ Screens a set of data using a _CompositeRun.CompositeRun_
    instance to provide parameters

# DOC

    The actual data to be used are extracted from the database and
    table specified in _details_

    Aside from dataset construction, _ShowVoteResults()_ does most of
    the heavy lifting here.

  **Arguments**

    - models: a composite model or a sequence of composite models

    - details: a _CompositeRun.CompositeRun_ object containing details
      (options, parameters, etc.) about the run

    - callback: (optional) if provided, this should be a function
      taking a single argument that is called after each example is
      screened with the number of examples screened so far as the
      argument.

    - setup: (optional) a function taking a single argument which is
      called at the start of screening with the number of points to
      be screened as the argument.

    - appendExamples: (optional) this value is passed on to the
      composite's _ClassifyExample()_ method.

    - goodVotes,badVotes,noVotes: (optional) if provided these should
      be lists (or anything supporting an _append()_ method) which
      will be used to pass the screening results back.


  **Returns**

    a 7-tuple:

      1) the number of good (correct) predictions

      2) the number of bad (incorrect) predictions

      3) the number of predictions skipped due to the _threshold_

      4) the average confidence in the good predictions

      5) the average confidence in the bad predictions

      6) the average confidence in the skipped predictions

      7) the results table

  """
  if data is None:
    if hasattr(details, 'pickleCol'):
      data = details.GetDataSet(pickleCol=details.pickleCol,
                                pickleClass=DataStructs.ExplicitBitVect)
    else:
      data = details.GetDataSet()
  if details.threshold > 0.0:
    details.partialVote = 1
  else:
    details.partialVote = 0

  if type(models) not in [list, tuple]:
    models = (models, )

  nModels = len(models)

  if setup is not None:
    setup(nModels * data.GetNPts())

  nGood = numpy.zeros(nModels, numpy.float)
  nBad = numpy.zeros(nModels, numpy.float)
  nSkip = numpy.zeros(nModels, numpy.float)
  confGood = numpy.zeros(nModels, numpy.float)
  confBad = numpy.zeros(nModels, numpy.float)
  confSkip = numpy.zeros(nModels, numpy.float)
  voteTab = None
  if goodVotes is None:
    goodVotes = []
  if badVotes is None:
    badVotes = []
  if noVotes is None:
    noVotes = []
  if enrichments is None:
    enrichments = [0.0] * nModels
  badVoteDict = {}
  noVoteDict = {}

  for i in range(nModels):
    if nModels > 1:
      goodVotes = []
      badVotes = []
      noVotes = []
    model = models[i]

    try:
      seed = model._randomSeed
    except AttributeError:
      pass
    else:
      DataUtils.InitRandomNumbers(seed)

    if (hasattr(details, 'shuffleActivities') and details.shuffleActivities) or \
       (hasattr(details, 'randomActivities') and details.randomActivities):
      if hasattr(details, 'shuffleActivities') and details.shuffleActivities:
        shuffle = True
      else:
        shuffle = False
      randomize = True
      DataUtils.RandomizeActivities(data, shuffle=shuffle, runDetails=details)
    else:
      randomize = False
      shuffle = False

    if hasattr(model, '_shuffleActivities') and \
       model._shuffleActivities and \
       not shuffle:
      message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
      message('****** WARNING: Shuffled model being screened with unshuffled data.')
      message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
    if hasattr(model, '_randomizeActivities') and \
       model._randomizeActivities and \
       not randomize:
      message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
      message('****** WARNING: Random model being screened with non-random data.')
      message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')

    trainIdx, testIdx = PrepareDataFromDetails(model, details, data)

    nPossible = model.GetQuantBounds()[1]
    if callback:
      cb = lambda x, y=callback, z=i * data.GetNPts(): y(x + z)
    else:
      cb = None
    if not hasattr(details, 'errorEstimate') or not details.errorEstimate:
      errorEstimate = 0
    else:
      errorEstimate = 1
    g, b, s, aG, aB, aS, vT = ShowVoteResults(
      testIdx, data, model, nPossible[-1], details.threshold, verbose=0, callback=cb,
      appendExamples=appendExamples, goodVotes=goodVotes, badVotes=badVotes, noVotes=noVotes,
      errorEstimate=errorEstimate)
    if voteTab is None:
      voteTab = numpy.zeros(vT.shape, numpy.float)
    if hasattr(details, 'errorAnalysis') and details.errorAnalysis:
      for a, p, c, idx in badVotes:
        label = testIdx[idx]
        if hasattr(details, 'enrichTgt') and details.enrichTgt >= 0:
          if a == details.enrichTgt:
            badVoteDict[label] = badVoteDict.get(label, 0) + 1
        else:
          badVoteDict[label] = badVoteDict.get(label, 0) + 1
      for a, p, c, idx in noVotes:
        label = testIdx[idx]
        if hasattr(details, 'enrichTgt') and details.enrichTgt >= 0:
          if a == details.enrichTgt:
            noVoteDict[label] = noVoteDict.get(label, 0) + 1
        else:
          noVoteDict[label] = noVoteDict.get(label, 0) + 1

    voteTab += vT
    nGood[i] = g
    nBad[i] = b
    nSkip[i] = s
    confGood[i] = aG
    confBad[i] = aB
    confSkip[i] = aS

    if hasattr(details, 'enrichTgt') and details.enrichTgt >= 0:
      enrichments[i] = CalcEnrichment(vT, tgt=details.enrichTgt)

  if nModels == 1:
    return g, b, s, aG, aB, aS, vT
  else:
    voteTab /= nModels

    avgNBad = sum(nBad) / nModels
    devNBad = numpy.sqrt(sum((nBad - avgNBad)**2) / (nModels - 1))

    # bestIdx = numpy.argsort(nBad)[0]

    avgNGood = sum(nGood) / nModels
    devNGood = numpy.sqrt(sum((nGood - avgNGood)**2) / (nModels - 1))

    avgNSkip = sum(nSkip) / nModels
    devNSkip = numpy.sqrt(sum((nSkip - avgNSkip)**2) / (nModels - 1))

    avgConfBad = sum(confBad) / nModels
    devConfBad = numpy.sqrt(sum((confBad - avgConfBad)**2) / (nModels - 1))

    avgConfGood = sum(confGood) / nModels
    devConfGood = numpy.sqrt(sum((confGood - avgConfGood)**2) / (nModels - 1))

    avgConfSkip = sum(confSkip) / nModels
    devConfSkip = numpy.sqrt(sum((confSkip - avgConfSkip)**2) / (nModels - 1))
    return ((avgNGood, devNGood), (avgNBad, devNBad), (avgNSkip, devNSkip),
            (avgConfGood, devConfGood), (avgConfBad, devConfBad), (avgConfSkip, devConfSkip),
            voteTab)

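When several models are screened, the function above reports per-model means and sample standard deviations (N-1 denominator). The same arithmetic on hypothetical per-model misclassification counts, checked against numpy's `ddof=1` standard deviation:

```python
import numpy as np

# Hypothetical misclassification counts from screening four composites.
nBad = np.array([3.0, 5.0, 4.0, 4.0])
nModels = len(nBad)

avgNBad = nBad.sum() / nModels
# Sample standard deviation with an N-1 denominator, as computed above.
devNBad = np.sqrt(((nBad - avgNBad) ** 2).sum() / (nModels - 1))
print(avgNBad, devNBad)
```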
def GetScreenImage(nGood, nBad, nRej, size=None):
  if not hasPil:
    return None
  try:
    nTot = float(nGood) + float(nBad) + float(nRej)
  except TypeError:
    nGood = nGood[0]
    nBad = nBad[0]
    nRej = nRej[0]
    nTot = float(nGood) + float(nBad) + float(nRej)

  if not nTot:
    return None
  goodColor = (100, 100, 255)
  badColor = (255, 100, 100)
  rejColor = (255, 255, 100)

  pctGood = float(nGood) / nTot
  pctBad = float(nBad) / nTot
  pctRej = float(nRej) / nTot

  if size is None:
    size = (100, 100)
  img = Image.new('RGB', size, (255, 255, 255))
  draw = ImageDraw.Draw(img)
  box = (0, 0, size[0] - 1, size[1] - 1)

  startP = -90
  endP = int(startP + pctGood * 360)
  draw.pieslice(box, startP, endP, fill=goodColor)
  startP = endP
  endP = int(startP + pctBad * 360)
  draw.pieslice(box, startP, endP, fill=badColor)
  startP = endP
  endP = int(startP + pctRej * 360)
  draw.pieslice(box, startP, endP, fill=rejColor)

  return img

893 -def ScreenToHtml(nGood, nBad, nRej, avgGood, avgBad, avgSkip, voteTable, imgDir='.', fullPage=1, 894 skipImg=0, includeDefs=1):
  """ returns the text of a web page showing the screening details

  **Arguments**

    - nGood: number of correct predictions

    - nBad: number of incorrect predictions

    - nRej: number of rejected predictions

    - avgGood: average correct confidence

    - avgBad: average incorrect confidence

    - avgSkip: average rejected confidence

    - voteTable: vote table

    - imgDir: (optional) the directory to be used to hold the vote
      image (if constructed)

    - fullPage: (optional) toggles inclusion of the <html>/<body> wrapper

    - skipImg: (optional) if nonzero, no vote image is generated

    - includeDefs: (optional) toggles inclusion of the definitions block

  **Returns**

    a string containing HTML

  """
  if isinstance(nGood, tuple):
    multModels = 1
  else:
    multModels = 0

  if fullPage:
    outTxt = ["""<html><body>"""]
    outTxt.append('<center><h2>VOTE DETAILS</h2></center>')
  else:
    outTxt = []

  outTxt.append('<font>')

  # Get the image
  if not skipImg:
    img = GetScreenImage(nGood, nBad, nRej)
    if img:
      if imgDir:
        imgFileName = '/'.join((imgDir, 'votes.png'))
      else:
        imgFileName = 'votes.png'
      img.save(imgFileName)
      outTxt.append('<center><img src="%s"></center>' % (imgFileName))

  nPoss = len(voteTable)
  pureCounts = numpy.sum(voteTable, 1)
  accCounts = numpy.sum(voteTable, 0)
  pureVect = numpy.zeros(nPoss, numpy.float64)
  accVect = numpy.zeros(nPoss, numpy.float64)
  for i in range(nPoss):
    if pureCounts[i]:
      pureVect[i] = float(voteTable[i, i]) / pureCounts[i]
    if accCounts[i]:
      accVect[i] = float(voteTable[i, i]) / accCounts[i]

  outTxt.append('<center><table border=1>')
  outTxt.append('<tr><td></td>')
  for i in range(nPoss):
    outTxt.append('<th>%d</th>' % i)
  outTxt.append('<th>% Accurate</th>')
  outTxt.append('</tr>')
  for i in range(nPoss):
    outTxt.append('<tr><th>%d</th>' % (i))
    for j in range(nPoss):
      if i == j:
        if not multModels:
          outTxt.append('<td bgcolor="#A0A0FF">%d</td>' % (voteTable[j, i]))
        else:
          outTxt.append('<td bgcolor="#A0A0FF">%.2f</td>' % (voteTable[j, i]))
      else:
        if not multModels:
          outTxt.append('<td>%d</td>' % (voteTable[j, i]))
        else:
          outTxt.append('<td>%.2f</td>' % (voteTable[j, i]))
    outTxt.append('<td>%4.2f</td>' % (100.0 * accVect[i]))
    if i == 0:
      outTxt.append('<th rowspan=%d>Predicted</th></tr>' % (nPoss))
    else:
      outTxt.append('</tr>')
  outTxt.append('<tr><th>% Pure</th>')
  for i in range(nPoss):
    outTxt.append('<td>%4.2f</td>' % (100.0 * pureVect[i]))
  outTxt.append('</tr>')
  outTxt.append('<tr><td></td><th colspan=%d>Original</th></tr>' % (nPoss))
  outTxt.append('</table></center>')

  if not multModels:
    nTotal = nBad + nGood + nRej
    nClass = nBad + nGood
    if nClass:
      pctErr = 100. * float(nBad) / nClass
    else:
      pctErr = 0.0

    outTxt.append('<p>%d of %d examples were misclassified (%%%4.2f)' %
                  (nBad, nGood + nBad, pctErr))
    if nRej > 0:
      pctErr = 100. * float(nBad) / (nGood + nBad + nRej)
      outTxt.append('<p> %d of %d overall: (%%%4.2f)' % (nBad, nTotal, pctErr))
      pctRej = 100. * float(nRej) / nTotal
      outTxt.append('<p>%d of %d examples were rejected (%%%4.2f)' % (nRej, nTotal, pctRej))
    if nGood != 0:
      outTxt.append('<p>The correctly classified examples had an average confidence of %6.4f' %
                    avgGood)

    if nBad != 0:
      outTxt.append('<p>The incorrectly classified examples had an average confidence of %6.4f' %
                    avgBad)
    if nRej != 0:
      outTxt.append('<p>The rejected examples had an average confidence of %6.4f' % avgSkip)
  else:
    nTotal = nBad[0] + nGood[0] + nRej[0]
    nClass = nBad[0] + nGood[0]
    devClass = nBad[1] + nGood[1]
    if nClass:
      pctErr = 100. * float(nBad[0]) / nClass
      devPctErr = 100. * float(nBad[1]) / nClass
    else:
      pctErr = 0.0
      devPctErr = 0.0

    outTxt.append('<p>%.2f(%.2f) of %.2f(%.2f) examples were misclassified (%%%4.2f(%4.2f))' %
                  (nBad[0], nBad[1], nClass, devClass, pctErr, devPctErr))
    if nRej[0] > 0:
      pctErr = 100. * float(nBad[0]) / nTotal
      devPctErr = 100. * float(nBad[1]) / nTotal
      outTxt.append('<p> %.2f(%.2f) of %d overall: (%%%4.2f(%4.2f))' %
                    (nBad[0], nBad[1], nTotal, pctErr, devPctErr))
      pctRej = 100. * float(nRej[0]) / nTotal
      devPctRej = 100. * float(nRej[1]) / nTotal
      outTxt.append('<p>%.2f(%.2f) of %d examples were rejected (%%%4.2f(%4.2f))' %
                    (nRej[0], nRej[1], nTotal, pctRej, devPctRej))
    if nGood[0] != 0:
      outTxt.append(
        '<p>The correctly classified examples had an average confidence of %6.4f(%.4f)' % avgGood)

    if nBad[0] != 0:
      outTxt.append(
        '<p>The incorrectly classified examples had an average confidence of %6.4f(%.4f)' % avgBad)
    if nRej[0] != 0:
      outTxt.append('<p>The rejected examples had an average confidence of %6.4f(%.4f)' % avgSkip)

  outTxt.append('</font>')
  if includeDefs:
    txt = """
    <p><b>Definitions:</b>
    <ul>
    <li> <i>% Pure:</i> The percentage of, for example, known positives predicted to be positive.
    <li> <i>% Accurate:</i> The percentage of, for example, predicted positives that actually
    are positive.
    </ul>
    """
    outTxt.append(txt)

  if fullPage:
    outTxt.append("""</body></html>""")
  return '\n'.join(outTxt)

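The "% Pure" and "% Accurate" columns above are just the diagonal of the vote table normalized by row and column sums, respectively. A minimal standalone sketch of that computation, using a made-up 2-class vote table (not part of the module itself):

```python
import numpy as np

def pure_and_accurate(vote_table):
    """Per-class purity (diagonal / row sums) and accuracy (diagonal /
    column sums), mirroring the normalization loops in ScreenToHtml."""
    vote_table = np.asarray(vote_table, dtype=float)
    pure_counts = vote_table.sum(axis=1)  # examples per true class
    acc_counts = vote_table.sum(axis=0)   # examples per predicted class
    diag = vote_table.diagonal()
    # guard against empty classes, as the original's `if pureCounts[i]:` does
    pure = np.divide(diag, pure_counts, out=np.zeros(len(diag)), where=pure_counts != 0)
    acc = np.divide(diag, acc_counts, out=np.zeros(len(diag)), where=acc_counts != 0)
    return pure, acc

# hypothetical table: 8 true 0s (6 correct), 10 true 1s (9 correct)
pure, acc = pure_and_accurate([[6, 2], [1, 9]])
```

Here `pure` is `[6/8, 9/10]` and `acc` is `[6/7, 9/11]`.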
def MakePredPlot(details, indices, data, goodVotes, badVotes, nRes, idCol=0, verbose=0):
  """ constructs the data file and Gnuplot script for a prediction plot

  **Arguments**

    - details: a CompositeRun.RunDetails object

    - indices: a sequence of integer indices into _data_

    - data: the data set in question.  We assume that the ids for
      the data points are in the _idCol_ column

    - goodVotes/badVotes: predictions where the model was correct/incorrect.
      These are sequences of 4-tuples:
      (answer,prediction,confidence,index into _indices_)

    - nRes: the number of possible result codes

    - idCol: (optional) the column holding the data point ids

    - verbose: (optional) toggles progress messages

  """
  if not hasattr(details, 'predPlot') or not details.predPlot:
    return

  if verbose:
    message('\n-> Constructing Prediction (Hanneke) Plot')
  outF = open(details.predPlot, 'w+')
  gnuF = open('%s.gnu' % details.predPlot, 'w+')
  # first get the ids of the data points we screened:
  ptIds = [data[x][idCol] for x in indices]

  # get a connection to the database we'll use to grab the continuous
  # activity values:
  origConn = DbConnect(details.dbName, details.tableName, user=details.dbUser,
                       password=details.dbPassword)
  colNames = origConn.GetColumnNames()
  idName = colNames[idCol]
  if not hasattr(details, 'predActTable') or \
     not details.predActTable or \
     details.predActTable == details.tableName:
    actConn = origConn
  else:
    actConn = DbConnect(details.dbName, details.predActTable, user=details.dbUser,
                        password=details.dbPassword)
  if verbose:
    message('\t-> Pulling Activity Data')

  if not isinstance(ptIds[0], str):
    ptIds = [str(x) for x in ptIds]
  whereL = [DbModule.placeHolder] * len(ptIds)
  if hasattr(details, 'predActCol') and details.predActCol:
    actColName = details.predActCol
  else:
    actColName = actConn.GetColumnNames()[-1]

  whereTxt = "%s in (%s)" % (idName, ','.join(whereL))
  rawD = actConn.GetData(fields='%s,%s' % (idName, actColName), where=whereTxt, extras=ptIds)
  # order the data returned:
  if verbose:
    message('\t-> Creating Plot')
  acts = [None] * len(ptIds)
  for entry in rawD:
    ID, act = entry
    idx = ptIds.index(ID)
    acts[idx] = act
  outF.write('#ID Pred Conf %s\n' % (actColName))
  for ans, pred, conf, idx in goodVotes:
    act = acts[idx]
    if act != 'None':
      act = float(act)
    else:
      act = 0
    outF.write('%s %d %.4f %f\n' % (ptIds[idx], pred, conf, act))
  for ans, pred, conf, idx in badVotes:
    act = acts[idx]
    if act != 'None':
      act = float(act)
    else:
      act = 0
    outF.write('%s %d %.4f %f\n' % (ptIds[idx], pred, conf, act))
  outF.close()
  if not hasattr(details, 'predLogScale') or not details.predLogScale:
    actLabel = actColName
  else:
    actLabel = 'log(%s)' % (actColName)
  actLabel = actLabel.replace('_', ' ')
  gnuHdr = """# Generated by ScreenComposite.py version: %s
set size square 0.7
set yrange [:1]
set data styl points
set ylab 'confidence'
set xlab '%s'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term X
""" % (__VERSION_STRING, actLabel)
  gnuF.write(gnuHdr)
  plots = []
  for i in range(nRes):
    if not hasattr(details, 'predLogScale') or not details.predLogScale:
      plots.append("'%s' us 4:($2==%d?$3:0/0)" % (details.predPlot, i))
    else:
      plots.append("'%s' us (log10($4)):($2==%d?$3:0/0)" % (details.predPlot, i))
  gnuF.write("plot %s\n" % (','.join(plots)))
  gnuTail = """
# EOF
"""
  gnuF.write(gnuTail)
  gnuF.close()
  if hasattr(details, 'predShow') and details.predShow:
    try:
      try:
        from Gnuplot import Gnuplot
      except ImportError:
        raise ImportError('Functionality requires the Gnuplot module')
      p = Gnuplot()
      p('cd "%s"' % (os.getcwd()))
      p('load "%s.gnu"' % (details.predPlot))
      input('press return to continue...\n')
    except Exception:
      import traceback
      traceback.print_exc()

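The re-ordering step in MakePredPlot (rows come back from the database in arbitrary order and are matched back against `ptIds` via `list.index`) can be illustrated in isolation. A sketch with hypothetical ids, not taken from any real table:

```python
def order_by_ids(pt_ids, raw_rows):
    """Place (id, activity) rows into the order given by pt_ids, the way
    MakePredPlot fills its `acts` list before writing the plot data file.
    Ids with no matching row are left as None."""
    acts = [None] * len(pt_ids)
    for row_id, act in raw_rows:
        acts[pt_ids.index(row_id)] = act
    return acts

acts = order_by_ids(['mol1', 'mol2', 'mol3'],
                    [('mol3', 0.2), ('mol1', 1.5)])
# 'mol2' had no activity row, so its slot stays None
```

Note that `list.index` makes this O(n*m); for large screens a `{id: position}` dict built once would do the same job in linear time.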
def Go(details):
  pass

def SetDefaults(details=None):
  global _details
  if details is None:
    details = _details
  CompositeRun.SetDefaults(details)
  details.screenVoteTol = [0.]
  details.detailedScreen = 0
  details.doHoldout = 0
  details.doTraining = 0
  details.errorAnalysis = 0
  details.verbose = 0
  details.partialVote = 0
  return details

def Usage():
  """ prints a list of arguments for when this is used from the
  command line and then exits

  """
  print(__doc__)
  sys.exit(-1)

def ShowVersion(includeArgs=0):
  """ prints the version number of the program

  """
  print('This is ScreenComposite.py version %s' % (__VERSION_STRING))
  if includeArgs:
    print('command line was:')
    print(' '.join(sys.argv))

def ParseArgs(details):
  import getopt
  try:
    args, extras = getopt.getopt(sys.argv[1:], 'EDd:t:VN:HThSRF:v:AX',
                                 ['predPlot=', 'predActCol=', 'predActTable=', 'predLogScale',
                                  'predShow', 'OOB', 'pickleCol=', 'enrich='])
  except Exception:
    import traceback
    traceback.print_exc()
    Usage()

  details.predPlot = ''
  details.predActCol = ''
  details.predActTable = ''
  details.predLogScale = ''
  details.predShow = 0
  details.errorEstimate = 0
  details.pickleCol = -1
  details.enrichTgt = -1
  for arg, val in args:
    if arg == '-d':
      details.dbName = val
    elif arg == '-D':
      details.detailedScreen = 1
    elif arg == '-t':
      details.partialVote = 1
      voteTol = eval(val)
      if not isinstance(voteTol, (list, tuple)):
        voteTol = [voteTol]
      for tol in voteTol:
        if tol > 1 or tol < 0:
          error('Voting threshold must be between 0 and 1')
          sys.exit(-2)
      details.screenVoteTol = voteTol
    elif arg == '-N':
      details.note = val
    elif arg == '-H':
      details.doTraining = 0
      details.doHoldout = 1
    elif arg == '-T':
      details.doHoldout = 0
      details.doTraining = 1
    elif arg == '-E':
      details.errorAnalysis = 1
      details.detailedScreen = 1
    elif arg == '-A':
      details.showAll = 1
      details.detailedScreen = 1
    elif arg == '-S':
      details.shuffleActivities = 1
    elif arg == '-R':
      details.randomActivities = 1
    elif arg == '-h':
      Usage()
    elif arg == '-F':
      details.filterFrac = float(val)
    elif arg == '-v':
      details.filterVal = float(val)
    elif arg == '-V':
      details.verbose = 1
    elif arg == '--predPlot':
      details.detailedScreen = 1
      details.predPlot = val
    elif arg == '--predActCol':
      details.predActCol = val
    elif arg == '--predActTable':
      details.predActTable = val
    elif arg == '--predLogScale':
      details.predLogScale = 1
    elif arg == '--predShow':
      details.predShow = 1
    elif arg == '--OOB':
      details.errorEstimate = 1
    elif arg == '--pickleCol':
      details.pickleCol = int(val) - 1
    elif arg == '--enrich':
      details.enrichTgt = int(val)
    else:
      Usage()

  if len(extras) < 1:
    Usage()
  return extras
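The `-t` branch above accepts either a single float or a list/tuple of floats and range-checks each value. The same logic can be written with `ast.literal_eval` instead of `eval`, which only accepts Python literals; this is a standalone alternative sketch, not what the module itself does:

```python
import ast

def parse_vote_tol(val):
    """Parse a '-t' style argument: a float or a list/tuple of floats,
    each of which must lie in [0, 1]."""
    tol = ast.literal_eval(val)
    if not isinstance(tol, (list, tuple)):
        tol = [tol]
    for t in tol:
        if not 0 <= t <= 1:
            raise ValueError('Voting threshold must be between 0 and 1')
    return list(tol)

single = parse_vote_tol('0.5')          # a bare float becomes [0.5]
several = parse_vote_tol('(0.25,0.75)')  # a tuple literal becomes a list
```

`literal_eval` rejects anything that is not a literal, so a hostile command-line value cannot execute arbitrary code the way `eval` would allow.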

if __name__ == '__main__':
  details = SetDefaults()
  extras = ParseArgs(details)
  ShowVersion(includeArgs=1)

  models = []
  if details.note and details.dbName:
    tblName = extras[0]
    message('-> Retrieving models from database')
    conn = DbConnect(details.dbName, tblName)
    blobs = conn.GetData(fields='model', where="where note='%s'" % (details.note))
    for blob in blobs:
      blob = blob[0]
      try:
        models.append(cPickle.loads(str(blob)))
      except Exception:
        import traceback
        traceback.print_exc()
        message('Model load failed')

  else:
    message('-> Loading model')
    modelFile = open(extras[0], 'rb')
    models.append(cPickle.load(modelFile))
  if not len(models):
    error('No composite models found')
    sys.exit(-1)
  else:
    message('-> Working with %d models.' % len(models))

  extras = extras[1:]

  for fName in extras:
    if details.dbName != '':
      details.tableName = fName
      data = details.GetDataSet(pickleCol=details.pickleCol,
                                pickleClass=DataStructs.ExplicitBitVect)
    else:
      data = DataUtils.BuildDataSet(fName)
    descNames = data.GetVarNames()
    nModels = len(models)
    screenResults = [None] * nModels
    dataSets = [None] * nModels
    message('-> Constructing and screening data sets')
    testIdx = list(range(data.GetNPts()))
    trainIdx = testIdx

    for modelIdx in range(nModels):
      tmpD = data
      model = models[modelIdx]
      message('.', noRet=1)

      try:
        seed = model._randomSeed
      except AttributeError:
        pass
      else:
        DataUtils.InitRandomNumbers(seed)

      if details.shuffleActivities or details.randomActivities:
        shuffle = details.shuffleActivities
        randomize = 1
        DataUtils.RandomizeActivities(tmpD, shuffle=details.shuffleActivities, runDetails=details)
      else:
        randomize = False
        shuffle = False

      if hasattr(model, '_shuffleActivities') and \
         model._shuffleActivities and \
         not shuffle:
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
        message('****** WARNING: Shuffled model being screened with unshuffled data.')
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
      if hasattr(model, '_randomizeActivities') and \
         model._randomizeActivities and \
         not randomize:
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')
        message('****** WARNING: Random model being screened with non-random data.')
        message('*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*!*')

      trainIdx, testIdx = PrepareDataFromDetails(model, details, tmpD, verbose=1)
      screenResults[modelIdx] = CollectResults(testIdx, tmpD, model,
                                               errorEstimate=details.errorEstimate)
      dataSets[modelIdx] = testIdx
    for tol in details.screenVoteTol:
      if len(details.screenVoteTol) > 1:
        message('\n-----*****-----*****-----*****-----*****-----*****-----*****-----\n')
        message('Tolerance: %f' % tol)
      nGood = numpy.zeros(nModels, numpy.float64)
      nBad = numpy.zeros(nModels, numpy.float64)
      nSkip = numpy.zeros(nModels, numpy.float64)
      confGood = numpy.zeros(nModels, numpy.float64)
      confBad = numpy.zeros(nModels, numpy.float64)
      confSkip = numpy.zeros(nModels, numpy.float64)
      if details.enrichTgt >= 0:
        enrichments = numpy.zeros(nModels, numpy.float64)
      goodVoteDict = {}
      badVoteDict = {}
      noVoteDict = {}
      voteTab = None
      for modelIdx in range(nModels):
        model = models[modelIdx]
        model.SetInputOrder(descNames)
        testIdx = dataSets[modelIdx]
        screenRes = screenResults[modelIdx]
        if not details.detailedScreen:
          g, b, s, aG, aB, aS, vT = ScreenIt(model, testIdx, tmpD, details.partialVote, tol,
                                             verbose=details.verbose, screenResults=screenRes)
        else:
          if model.GetActivityQuantBounds():
            nRes = len(model.GetActivityQuantBounds()) + 1
          else:
            nRes = model.GetQuantBounds()[1][-1]
          badVotes = []
          noVotes = []
          if (hasattr(details, 'showAll') and details.showAll) or \
             (hasattr(details, 'predPlot') and details.predPlot):
            goodVotes = []
          else:
            goodVotes = None
          g, b, s, aG, aB, aS, vT = ShowVoteResults(
            testIdx, tmpD, model, nRes, tol, verbose=details.verbose, screenResults=screenRes,
            badVotes=badVotes, noVotes=noVotes, goodVotes=goodVotes,
            errorEstimate=details.errorEstimate)
          if voteTab is None:
            voteTab = numpy.zeros(vT.shape, numpy.float64)
          if details.errorAnalysis:
            for a, p, c, idx in badVotes:
              label = testIdx[idx]
              if hasattr(details, 'enrichTgt') and details.enrichTgt >= 0:
                if a == details.enrichTgt:
                  badVoteDict[label] = badVoteDict.get(label, 0) + 1
              else:
                badVoteDict[label] = badVoteDict.get(label, 0) + 1
            for a, p, c, idx in noVotes:
              label = testIdx[idx]
              if hasattr(details, 'enrichTgt') and details.enrichTgt >= 0:
                if a == details.enrichTgt:
                  noVoteDict[label] = noVoteDict.get(label, 0) + 1
              else:
                noVoteDict[label] = noVoteDict.get(label, 0) + 1

          if hasattr(details, 'showAll') and details.showAll:
            for a, p, c, idx in goodVotes:
              label = testIdx[idx]
              if details.enrichTgt >= 0:
                if a == details.enrichTgt:
                  goodVoteDict[label] = goodVoteDict.get(label, 0) + 1
              else:
                goodVoteDict[label] = goodVoteDict.get(label, 0) + 1

          if details.enrichTgt > -1:
            enrichments[modelIdx] = CalcEnrichment(vT, tgt=details.enrichTgt)

          voteTab += vT
        if details.detailedScreen and hasattr(details, 'predPlot') and details.predPlot:
          MakePredPlot(details, testIdx, tmpD, goodVotes, badVotes, nRes, verbose=1)

        if hasattr(details, 'showAll') and details.showAll:
          print('-v-v-v-v-v-v-v- All Votes -v-v-v-v-v-v-v-')
          print('id, prediction, confidence, flag(-1=skipped,0=wrong,1=correct)')
          for ans, pred, conf, idx in goodVotes:
            pt = tmpD[testIdx[idx]]
            assert model.GetActivityQuantBounds() or pt[-1] == ans, 'bad point?: %s != %s' % (
              str(pt[-1]), str(ans))
            print('%s, %d, %.4f, 1' % (str(pt[0]), pred, conf))
          for ans, pred, conf, idx in badVotes:
            pt = tmpD[testIdx[idx]]
            assert model.GetActivityQuantBounds() or pt[-1] == ans, 'bad point?: %s != %s' % (
              str(pt[-1]), str(ans))
            print('%s, %d, %.4f, 0' % (str(pt[0]), pred, conf))
          for ans, pred, conf, idx in noVotes:
            pt = tmpD[testIdx[idx]]
            assert model.GetActivityQuantBounds() or pt[-1] == ans, 'bad point?: %s != %s' % (
              str(pt[-1]), str(ans))
            print('%s, %d, %.4f, -1' % (str(pt[0]), pred, conf))
          print('-^-^-^-^-^-^-^- -^-^-^-^-^-^-^-')

        nGood[modelIdx] = g
        nBad[modelIdx] = b
        nSkip[modelIdx] = s
        confGood[modelIdx] = aG
        confBad[modelIdx] = aB
        confSkip[modelIdx] = aS
      print()

      if nModels > 1:
        print('-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*')
        print('AVERAGES:')

        avgNBad = sum(nBad) / nModels
        devNBad = numpy.sqrt(sum((nBad - avgNBad)**2) / (nModels - 1))

        bestIdx = numpy.argsort(nBad)[0]

        avgNGood = sum(nGood) / nModels
        devNGood = numpy.sqrt(sum((nGood - avgNGood)**2) / (nModels - 1))

        avgNSkip = sum(nSkip) / nModels
        devNSkip = numpy.sqrt(sum((nSkip - avgNSkip)**2) / (nModels - 1))

        avgConfBad = sum(confBad) / nModels
        devConfBad = numpy.sqrt(sum((confBad - avgConfBad)**2) / (nModels - 1))

        avgConfGood = sum(confGood) / nModels
        devConfGood = numpy.sqrt(sum((confGood - avgConfGood)**2) / (nModels - 1))

        avgConfSkip = sum(confSkip) / nModels
        devConfSkip = numpy.sqrt(sum((confSkip - avgConfSkip)**2) / (nModels - 1))

        nClassified = avgNGood + avgNBad
        nExamples = nClassified + avgNSkip
        print('Misclassifications: \t%%%5.2f(%%%5.2f) %4.1f(%4.1f) / %d' %
              (100 * avgNBad / nExamples, 100 * devNBad / nExamples, avgNBad, devNBad, nExamples))
        if avgNSkip > 0:
          print('\tthreshold: \t%%%5.2f(%%%5.2f) %4.1f(%4.1f) / %d' %
                (100 * avgNBad / nClassified, 100 * devNBad / nClassified, avgNBad, devNBad,
                 nClassified))
          print()
          print('Number Skipped: %%%4.2f(%%%4.2f) %4.2f(%4.2f)' %
                (100 * avgNSkip / nExamples, 100 * devNSkip / nExamples, avgNSkip, devNSkip))

        print()
        print('Confidences:')
        print('\tCorrect: \t%4.2f(%4.2f)' % (100 * avgConfGood, 100 * devConfGood))
        print('\tIncorrect: \t%4.2f(%4.2f)' % (100 * avgConfBad, 100 * devConfBad))
        if avgNSkip > 0:
          print('\tSkipped: \t%4.2f(%4.2f)' % (100 * avgConfSkip, 100 * devConfSkip))

        if details.detailedScreen:
          message('Results Table:')
          voteTab = numpy.transpose(voteTab) / nModels
          nResultCodes = len(voteTab)
          colCounts = numpy.sum(voteTab, 0)
          rowCounts = numpy.sum(voteTab, 1)
          print()
          for i in range(nResultCodes):
            if rowCounts[i] == 0:
              rowCounts[i] = 1
            row = voteTab[i]
            message(' ', noRet=1)
            for j in range(nResultCodes):
              entry = row[j]
              message(' % 6.2f' % entry, noRet=1)
            message(' | % 4.2f' % (100. * voteTab[i, i] / rowCounts[i]))
          message(' ', noRet=1)
          for i in range(nResultCodes):
            message('-------', noRet=1)
          message('')
          message(' ', noRet=1)
          for i in range(nResultCodes):
            if colCounts[i] == 0:
              colCounts[i] = 1
            message(' % 6.2f' % (100. * voteTab[i, i] / colCounts[i]), noRet=1)
          message('')
        if details.enrichTgt > -1:
          mean = sum(enrichments) / nModels
          enrichments -= mean
          dev = numpy.sqrt(sum(enrichments * enrichments) / (nModels - 1))
          message(' Enrichment of value %d: %.4f (%.4f)' % (details.enrichTgt, mean, dev))
      else:
        bestIdx = 0
      print('------------------------------------------------')
      print('Best Model: ', bestIdx + 1)
      bestBad = nBad[bestIdx]
      bestGood = nGood[bestIdx]
      bestSkip = nSkip[bestIdx]
      nClassified = bestGood + bestBad
      nExamples = nClassified + bestSkip
      print('Misclassifications: \t%%%5.2f %d / %d' % (100 * bestBad / nExamples, bestBad,
                                                       nExamples))
      if bestSkip > 0:
        print('\tthreshold: \t%%%5.2f %d / %d' % (100 * bestBad / nClassified, bestBad,
                                                  nClassified))
        print()
        print('Number Skipped: %%%4.2f %d' % (100 * bestSkip / nExamples, bestSkip))

      print()
      print('Confidences:')
      print('\tCorrect: \t%4.2f' % (100 * confGood[bestIdx]))
      print('\tIncorrect: \t%4.2f' % (100 * confBad[bestIdx]))
      if bestSkip > 0:
        print('\tSkipped: \t%4.2f' % (100 * confSkip[bestIdx]))

      if nModels == 1 and details.detailedScreen:
        message('')
        message('Results Table:')
        voteTab = numpy.transpose(vT)
        nResultCodes = len(vT)
        colCounts = numpy.sum(voteTab, 0)
        rowCounts = numpy.sum(voteTab, 1)
        message('')
        for i in range(nResultCodes):
          if rowCounts[i] == 0:
            rowCounts[i] = 1
          row = voteTab[i]
          message(' ', noRet=1)
          for j in range(nResultCodes):
            entry = row[j]
            message(' % 6.2f' % entry, noRet=1)
          message(' | % 4.2f' % (100. * voteTab[i, i] / rowCounts[i]))
        message(' ', noRet=1)
        for i in range(nResultCodes):
          message('-------', noRet=1)
        message('')
        message(' ', noRet=1)
        for i in range(nResultCodes):
          if colCounts[i] == 0:
            colCounts[i] = 1
          message(' % 6.2f' % (100. * voteTab[i, i] / colCounts[i]), noRet=1)
        message('')
      if details.errorAnalysis:
        message('\n*-*-*-*-*-*-*-*- ERROR ANALYSIS -*-*-*-*-*-*-*-*\n')
        ks = badVoteDict.keys()
        if len(ks):
          message(' ---> Bad Vote Counts')
          for k in ks:
            pt = data[k]
            message('%s,%d' % (str(pt[0]), badVoteDict[k]))
        ks = noVoteDict.keys()
        if len(ks):
          message(' ---> Skipped Compound Counts')
          for k in ks:
            pt = data[k]
            message('%s,%d' % (str(pt[0]), noVoteDict[k]))

      if hasattr(details, 'showAll') and details.showAll:
        ks = goodVoteDict.keys()
        if len(ks):
          message(' ---> Good Vote Counts')
          for k in ks:
            pt = data[k]
            message('%s,%d' % (str(pt[0]), goodVoteDict[k]))
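The per-model summary statistics in the main block all follow the same pattern: a mean over the models and a sample standard deviation with an n-1 denominator. A minimal sketch of that computation on hypothetical per-model misclassification counts:

```python
import numpy as np

def mean_and_dev(vals):
    """Mean and sample standard deviation (n-1 denominator), the same
    style of computation as the avgNBad/devNBad pairs in the main block."""
    vals = np.asarray(vals, dtype=float)
    n = len(vals)
    avg = vals.sum() / n
    dev = np.sqrt(((vals - avg) ** 2).sum() / (n - 1))
    return avg, dev

# e.g. three models misclassifying 4, 6, and 8 examples
avg, dev = mean_and_dev([4, 6, 8])
```

For `[4, 6, 8]` this gives a mean of 6.0 and a deviation of 2.0, the values that would be printed as `6.0(2.0)` in the AVERAGES report.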