Package rdkit :: Package ML :: Package Composite :: Module Composite
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.Composite.Composite

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  4  #   All Rights Reserved 
  5  # 
  6  """ code for dealing with composite models 
  7   
  8  For a model to be usable here, it should support the following API: 
  9   
 10    - _ClassifyExample(example)_, returns a classification 
 11   
 12  Other compatibility notes: 
 13   
 14   1) To use _Composite.Grow_ there must be some kind of builder 
 15      functionality which returns a 2-tuple containing (model, fractional error). 
 16   
 17   2) The models should be pickleable 
 18   
 19   3) Ideally the models should support equality comparison (__eq__, or __cmp__ under Python 2) so that 
 20      the membership tests used to make sure models are unique work correctly. 
 21   
 22   
 23   
 24  """ 
 25  from __future__ import print_function 
 26  import numpy 
 27  from rdkit.six.moves import cPickle 
 28  from rdkit.ML.Data import DataUtils 
 29   
 30   
class Composite(object):
    """a composite model

    **Notes**

    - adding a model which is already present just results in its count
      field being incremented and its error being added to a running total
      (call _AverageErrors()_ to turn the totals into averages).

    - typical usage:

      1) grow the composite with AddModel until happy with it

      2) call AverageErrors to calculate the average error values

      3) call SortModels to put things in order by either error or count

    - Composites can support individual models requiring either quantized or
      nonquantized data. This is done by keeping a set of quantization bounds
      (_QuantBounds_) in the composite and quantizing data passed in when required.
      Quantization bounds can be set and interrogated using the
      _Get/SetQuantBounds()_ methods. When models are added to the composite,
      it can be indicated whether or not they require quantization.

    - Composites are also capable of extracting relevant variables from longer lists.
      This is accessible using _SetDescriptorNames()_ to register the descriptors about
      which the composite cares and _SetInputOrder()_ to tell the composite what the
      ordering of input vectors will be. **Note** there is a limitation on this: each
      model needs to take the same set of descriptors as inputs. This could be changed.

    """

    def __init__(self):
        # modelList, errList, countList and quantizationRequirements are
        # parallel lists: entry i of each describes model i.
        self.modelList = []
        self.errList = []
        self.countList = []
        # per-model votes from the most recent ClassifyExample() call
        self.modelVotes = []
        self.quantBounds = None
        self.nPossibleVals = None
        self.quantizationRequirements = []
        self._descNames = []
        self._mapOrder = None
        self.activityQuant = []

    def SetModelFilterData(self, modelFilterFrac=0.0, modelFilterVal=0.0):
        """ configures the training-data filtering used by _Grow()_

        **Arguments**

          - modelFilterFrac: passed to _DataUtils.FilterData_ as the target
            fraction; a value of zero disables filtering

          - modelFilterVal: passed to _DataUtils.FilterData_ as the value
            to filter on

        """
        self._modelFilterFrac = modelFilterFrac
        self._modelFilterVal = modelFilterVal

    def SetDescriptorNames(self, names):
        """ registers the names of the descriptors this composite uses

        **Arguments**

          - names: a list of descriptor names (strings).

        **NOTE**

          the _names_ list is not copied, so if you modify it later,
          the composite itself will also be modified.

        """
        self._descNames = names

    def GetDescriptorNames(self):
        """ returns the names of the descriptors this composite uses

        """
        return self._descNames

    def SetQuantBounds(self, qBounds, nPossible=None):
        """ sets the quantization bounds that the composite will use

        **Arguments**

          - qBounds: a list of quantization bounds, each quantbound is a
            list of boundaries

          - nPossible: a list of integers indicating how many possible values
            each descriptor can take on.

        **NOTE**

          - if the two lists are of different lengths, this will assert out

          - neither list is copied, so if you modify it later, the composite
            itself will also be modified.

        """
        if nPossible is not None:
            assert len(qBounds) == len(nPossible), 'qBounds/nPossible mismatch'
        self.quantBounds = qBounds
        self.nPossibleVals = nPossible

    def GetQuantBounds(self):
        """ returns the quantization bounds

        **Returns**

          a 2-tuple consisting of:

            1) the list of quantization bounds

            2) the nPossibleVals list

        """
        return self.quantBounds, self.nPossibleVals

    def GetActivityQuantBounds(self):
        """ returns the activity quantization bounds (an empty list if unset)

        """
        # objects unpickled from older versions may lack this attribute
        if not hasattr(self, 'activityQuant'):
            self.activityQuant = []
        return self.activityQuant

    def SetActivityQuantBounds(self, bounds):
        """ sets the bounds used to quantize the activity (last) column

        """
        self.activityQuant = bounds

    @staticmethod
    def _QuantizeValue(value, bounds):
        """ returns the index of the first bound that _value_ lies below,
        or len(bounds) if it is not below any of them (internal helper) """
        for box, bound in enumerate(bounds):
            if value < bound:
                return box
        return len(bounds)

    @staticmethod
    def _CopyExample(example):
        """ returns a mutable copy of an example (internal helper)

        numpy arrays are copied explicitly because slicing them returns a
        view; any other sequence (list, tuple, ...) becomes a new list.
        """
        if isinstance(example, numpy.ndarray):
            return example.copy()
        return list(example)

    def QuantizeActivity(self, example, activityQuant=None, actCol=-1):
        """ quantizes the activity value of an example

        **Arguments**

          - example: a data point (list, tuple or numpy array)

          - activityQuant: the activity bounds to use; defaults to the
            composite's own activity quantization bounds

          - actCol: index of the activity column

        **Returns**

          a copy of _example_ with the activity column replaced by its bin
          index, or _example_ itself if no activity bounds are set.
          The caller's data is never modified.

        """
        if activityQuant is None:
            activityQuant = self.GetActivityQuantBounds()
        if activityQuant:
            example = self._CopyExample(example)
            example[actCol] = self._QuantizeValue(example[actCol], activityQuant)
        return example

    def QuantizeExample(self, example, quantBounds=None):
        """ quantizes an example

        **Arguments**

          - example: a data point (list, tuple or numpy array)

          - quantBounds: a list of quantization bounds, each quantbound is a
            list of boundaries. If this argument is not provided, the composite
            will use its own quantBounds

        **Returns**

          the quantized example as a list

        **Notes**

          - If _example_ is different in length from _quantBounds_, this will
            assert out.

          - columns with empty bounds are converted with int(), except
            column 0, which is passed through unchanged

          - This is primarily intended for internal use

        """
        if quantBounds is None:
            quantBounds = self.quantBounds
        assert len(example) == len(quantBounds), 'example/quantBounds mismatch'
        quantExample = [None] * len(example)
        for i, bounds in enumerate(quantBounds):
            value = example[i]
            if len(bounds):
                value = self._QuantizeValue(value, bounds)
            elif i != 0:
                value = int(value)
            quantExample[i] = value
        return quantExample

    def MakeHistogram(self, eps=0.001):
        """ creates a histogram of error/count pairs

        Consecutive models whose errors differ from the start of the current
        bin by no more than _eps_ are merged into a single bin. The models
        are expected to be sorted by error (see _SortModels()_).

        **Arguments**

          - eps: tolerance used when merging errors into a bin

        **Returns**

          the histogram as a series of (error, count) 2-tuples
          (an empty list for an empty composite)

        """
        histo = []
        if not self.modelList:
            return histo
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for i in range(1, len(self.modelList)):
            err = self.errList[i]
            if err - lastErr > eps:
                histo.append((lastErr, countHere))
                lastErr = err
                countHere = self.countList[i]
            else:
                countHere += self.countList[i]
        # the bin still being accumulated when the loop ends must be emitted too
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example, quantExample, appendExample=0, onlyModels=None):
        """ collects votes across every member of the composite for the given example

        **Arguments**

          - example: the example to be voted upon

          - quantExample: the quantized form of the example

          - appendExample: toggles saving the example on the models

          - onlyModels: if provided, this should be a sequence of model
            indices. Only the specified models will be used in the
            prediction.

        **Returns**

          a list with a vote from each member; models not consulted get -1

        """
        if not onlyModels:
            onlyModels = list(range(len(self)))

        votes = [-1] * len(self)
        for i in onlyModels:
            data = quantExample if self.quantizationRequirements[i] else example
            votes[i] = int(round(self.modelList[i].ClassifyExample(data, appendExamples=appendExample)))
        return votes

    def ClassifyExample(self, example, threshold=0, appendExample=0, onlyModels=None):
        """ classifies the given example using the entire composite

        **Arguments**

          - example: the data to be classified

          - threshold: if this is a number greater than zero, then a
            classification will only be returned if the confidence is
            above _threshold_. Anything lower is returned as -1.

          - appendExample: toggles saving the example on the models

          - onlyModels: if provided, this should be a sequence of model
            indices. Only the specified models will be used in the
            prediction.

        **Returns**

          a (result,confidence) tuple; (-1, 0.0) if no votes were cast

        **Notes**

          This is treated as a voting problem: each model's vote is weighted
          by its count, and the confidence is the fraction of the weighted
          votes that went to the winning result. Votes are tallied into
          self.nPossibleVals[-1] buckets.

        """
        if self._mapOrder is not None:
            example = self._RemapInput(example)
        if self.GetActivityQuantBounds():
            example = self.QuantizeActivity(example)
        if self.quantBounds is not None and 1 in self.quantizationRequirements:
            quantExample = self.QuantizeExample(example, self.quantBounds)
        else:
            quantExample = []

        if not onlyModels:
            onlyModels = list(range(len(self)))
        self.modelVotes = self.CollectVotes(example, quantExample, appendExample=appendExample,
                                            onlyModels=onlyModels)

        votes = [0] * self.nPossibleVals[-1]
        for i in onlyModels:
            votes[self.modelVotes[i]] += self.countList[i]

        totVotes = sum(votes)
        if not totVotes:
            # nothing voted (e.g. empty composite): avoid dividing by zero
            return -1, 0.0
        res = int(numpy.argmax(votes))
        conf = float(votes[res]) / float(totVotes)
        if conf > threshold:
            return res, conf
        return -1, conf

    def GetVoteDetails(self):
        """ returns the votes from the last classification

        This is an empty list if nothing has been classified yet; models
        excluded from the last classification have a vote of -1.
        """
        return self.modelVotes

    def _RemapInput(self, inputVect):
        """ remaps the input so that it matches the expected internal ordering

        **Arguments**

          - inputVect: the input to be reordered

        **Returns**

          - a list with the reordered (and possibly shorter) data

        **Note**

          - you must call _SetDescriptorNames()_ and _SetInputOrder()_ for this to work

          - if the activity column was not found by _SetInputOrder()_, the
            last entry of the result is 0

          - this is primarily intended for internal use

        """
        order = self._mapOrder
        if order is None:
            return inputVect

        remappedInput = [inputVect[idx] for idx in order[:-1]]
        # -1 marks "activity column not present in the input"
        remappedInput.append(0 if order[-1] == -1 else inputVect[order[-1]])
        return remappedInput

    def GetInputOrder(self):
        """ returns the input order (used in remapping inputs)

        """
        return self._mapOrder

    def SetInputOrder(self, colNames):
        """ sets the input order

        **Arguments**

          - colNames: a sequence of the names of the data columns that will be passed in

        **Note**

          - you must call _SetDescriptorNames()_ first for this to work

          - name matching is case-insensitive

          - if a local descriptor name (other than the first and last, which
            are treated as the label and activity columns) does not appear
            in _colNames_, this will raise a _ValueError_ exception.
        """
        descs = [x.upper() for x in self.GetDescriptorNames()]
        colNames = [x.upper() for x in colNames]
        self._mapOrder = [None] * len(descs)

        # FIX: I believe that we're safe assuming that field 0
        # is always the label, and therefore safe to ignore errors,
        # but this may not be the case
        try:
            self._mapOrder[0] = colNames.index(descs[0])
        except ValueError:
            self._mapOrder[0] = 0

        for i in range(1, len(descs) - 1):
            try:
                self._mapOrder[i] = colNames.index(descs[i])
            except ValueError:
                raise ValueError('cannot find descriptor name: %s in set %s' %
                                 (repr(descs[i]), repr(colNames)))
        try:
            self._mapOrder[-1] = colNames.index(descs[-1])
        except ValueError:
            # no obvious match for the final (activity) column; -1 tells
            # _RemapInput() to fill that slot with 0
            self._mapOrder[-1] = -1

    def Grow(self, examples, attrs, nPossibleVals, buildDriver, pruner=None, nTries=10, pruneIt=0,
             needsQuantization=1, progressCallback=None, **buildArgs):
        """ Grows the composite

        **Arguments**

          - examples: a list of examples to be used in training

          - attrs: a list of the variables to be used in training

          - nPossibleVals: this is used to provide a list of the number
            of possible values for each variable. It is used if the
            local quantBounds have not been set (for example for when you
            are working with data which is already quantized).

          - buildDriver: the function to call to build the new models. It is
            called as buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)
            and must return a (model, error) 2-tuple

          - pruner: a function used to "prune" (reduce the complexity of)
            the resulting model; it must return a (model, error) 2-tuple

          - nTries: the number of new models to add

          - pruneIt: toggles whether or not pruning is done

          - needsQuantization: used to indicate whether or not this type of model
            requires quantized data

          - progressCallback: if provided, called with the cycle index after
            each model is added

          - **buildArgs: all other keyword args are passed to _buildDriver_
            ('silent' and 'calcTotalError' are forced to 1; the caller's
            'silent' value controls progress printing here)

        **Note**

          - new models are *added* to the existing ones

          - the caller's _examples_ list is not modified, but if activity
            quantization bounds are set, nPossibleVals[-1] is overwritten
            in place

        """
        silent = buildArgs.get('silent', 0)
        buildArgs['silent'] = 1
        buildArgs['calcTotalError'] = 1

        # build new lists (never assign into the caller's list)
        if self._mapOrder is not None:
            examples = [self._RemapInput(ex) for ex in examples]
        if self.GetActivityQuantBounds():
            examples = [self.QuantizeActivity(ex) for ex in examples]
            nPossibleVals[-1] = len(self.GetActivityQuantBounds()) + 1
        if self.nPossibleVals is None:
            self.nPossibleVals = nPossibleVals[:]
        if needsQuantization:
            nPossibleVals = self.nPossibleVals
            trainExamples = [self.QuantizeExample(ex, self.quantBounds) for ex in examples]
        else:
            trainExamples = examples

        useFilter = getattr(self, '_modelFilterFrac', 0) != 0
        # integer step so the modulus below is well-defined on Python 3
        progressStep = max(nTries // 10, 1)
        for i in range(nTries):
            trainIdx = None
            if useFilter:
                trainIdx, _ = DataUtils.FilterData(trainExamples, self._modelFilterVal,
                                                   self._modelFilterFrac, -1, indicesOnly=1)
                trainSet = [trainExamples[x] for x in trainIdx]
            else:
                trainSet = trainExamples

            model, frac = buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)
            if pruneIt:
                model, frac = pruner(model, model.GetTrainingExamples(), model.GetTestExamples(),
                                     minimizeTestErrorOnly=0)
            if useFilter and hasattr(model, '_trainIndices'):
                # map the model's indices (into trainSet) back onto trainExamples
                model._trainIndices = [trainIdx[x] for x in model._trainIndices]

            self.AddModel(model, frac, needsQuantization)
            if not silent and (nTries < 10 or i % progressStep == 0):
                print('Cycle: % 4d' % (i))
            if progressCallback is not None:
                progressCallback(i)

    def ClearModelExamples(self):
        """ asks every model that supports it to discard its stored examples

        """
        for model in self.modelList:
            clearExamples = getattr(model, 'ClearExamples', None)
            if clearExamples is not None:
                clearExamples()

    def Pickle(self, fileName='foo.pkl', saveExamples=0):
        """ Writes this composite off to a file so that it can be easily loaded later

        **Arguments**

          - fileName: the name of the file to be written

          - saveExamples: if this is zero, the individual models will have
            their stored examples cleared.

        **Note**

          the file is written with pickle protocol 1; only unpickle such
          files from trusted sources.

        """
        if not saveExamples:
            self.ClearModelExamples()

        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile, 1)

    def AddModel(self, model, error, needsQuantization=1):
        """ Adds a model to the composite

        **Arguments**

          - model: the model to be added

          - error: the model's error

          - needsQuantization: a toggle to indicate whether or not this model
            requires quantized inputs

        **NOTE**

          - this can be used as an alternative to _Grow()_ if you already have
            some models constructed

          - the errList is run as an accumulator,
            you probably want to call _AverageErrors_ after finishing the forest

          - a model equal to one already present only increments that
            model's count and adds to its error total

        """
        try:
            idx = self.modelList.index(model)
        except ValueError:
            self.modelList.append(model)
            self.errList.append(error)
            self.countList.append(1)
            self.quantizationRequirements.append(needsQuantization)
        else:
            self.errList[idx] += error
            self.countList[idx] += 1

    def AverageErrors(self):
        """ convert local summed error to average error

        """
        self.errList = [err / count for err, count in zip(self.errList, self.countList)]

    def SortModels(self, sortOnError=True):
        """ sorts the list of models (ascending)

        **Arguments**

          sortOnError: toggles sorting on the models' errors rather than their counts

        **Note**

          the per-model quantization requirements are reordered along with
          the models so that they stay aligned.

        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # plain lists (rather than numpy arrays) are kept for pickling compatibility
        self.modelList = [self.modelList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]
        quantReqs = getattr(self, 'quantizationRequirements', None)
        if quantReqs is not None and len(quantReqs) == len(order):
            self.quantizationRequirements = [quantReqs[x] for x in order]

    def GetModel(self, i):
        """ returns a particular model

        """
        return self.modelList[i]

    def SetModel(self, i, val):
        """ replaces a particular model

        **Note**

          This is included for the sake of completeness, but you need to be
          *very* careful when you use it.

        """
        self.modelList[i] = val

    def GetCount(self, i):
        """ returns the count of the _i_th model

        """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count of the _i_th model

        """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error of the _i_th model

        """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error of the _i_th model

        """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular model

        **Arguments**

          i: an integer indicating which model should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the model

            2) its count

            3) its error
        """
        return (self.modelList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

        **Arguments**

          - i: an integer indicating which model should be set

          - tup: a 3-tuple consisting of:

            1) the model

            2) its count

            3) its error

        **Note**

          This is included for the sake of completeness, but you need to be
          *very* careful when you use it.

        """
        self.modelList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

        **Returns**

          a 3-tuple consisting of:

            1) our list of models

            2) our list of model counts

            3) our list of model errors

        """
        return (self.modelList, self.countList, self.errList)

    def __len__(self):
        """ allows len(composite) to work

        """
        return len(self.modelList)

    def __getitem__(self, which):
        """ allows composite[i] to work, returns the data tuple

        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ returns a string representation of the composite

        """
        pieces = ['Composite\n']
        for i in range(len(self.modelList)):
            pieces.append(' Model %4d: %5d occurances %%%5.2f average error\n' %
                          (i, self.countList[i], 100. * self.errList[i]))
        return ''.join(pieces)
if __name__ == '__main__':  # pragma: nocover
    # Small hand-run demo of AddModel/SortModels/QuantizeExample.
    # It is intentionally switched off; change False to True to run it.
    if False:
        from rdkit.ML.DecTree import DecTree
        demo = Composite()
        node = DecTree.DecTreeNode(None, 'foo')
        for _ in range(2):
            demo.AddModel(node, 0.5)
        demo.AverageErrors()
        demo.SortModels()
        print(demo)

        qBounds = [[], [.5, 1, 1.5]]
        samples = [['foo', val] for val in (0, .4, .6, 1.1, 2.0)]
        print('quantBounds:', qBounds)
        for sample in samples:
            quantized = demo.QuantizeExample(sample, qBounds)
            print(sample, quantized)