Package rdkit :: Package ML :: Module AnalyzeComposite
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.AnalyzeComposite

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002-2008  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """ command line utility to report on the contributions of descriptors to 
 12  tree-based composite models 
 13   
 14  Usage:  AnalyzeComposite [optional args] <models> 
 15   
 16        <models>: file name(s) of pickled composite model(s) 
 17          (this is the name of the db table if using a database) 
 18   
 19      Optional Arguments: 
 20   
 21        -n number: the number of levels of each model to consider 
 22   
 23        -d dbname: the database from which to read the models 
 24   
 25        -N Note: the note string to search for to pull models from the database 
 26   
 27        -v: be verbose whilst screening 
 28  """ 
 29  from __future__ import print_function 
 30   
 31  import sys 
 32   
 33  import numpy 
 34   
 35  from rdkit.Dbase.DbConnection import DbConnect 
 36  from rdkit.ML import ScreenComposite 
 37  from rdkit.ML.Data import Stats 
 38  from rdkit.ML.DecTree import TreeUtils, Tree 
 39  from rdkit.six.moves import cPickle 
 40   
 41   
 42  __VERSION_STRING = "2.2.0" 
 43   
 44   
45 -def ProcessIt(composites, nToConsider=3, verbose=0):
46 composite = composites[0] 47 nComposites = len(composites) 48 ns = composite.GetDescriptorNames() 49 # nDesc = len(ns)-2 50 if len(ns) > 2: 51 globalRes = {} 52 53 nDone = 1 54 descNames = {} 55 for composite in composites: 56 if verbose > 0: 57 print('#------------------------------------') 58 print('Doing: ', nDone) 59 nModels = len(composite) 60 nDone += 1 61 res = {} 62 for i in range(len(composite)): 63 model = composite.GetModel(i) 64 if isinstance(model, Tree.TreeNode): 65 levels = TreeUtils.CollectLabelLevels(model, {}, 0, nToConsider) 66 TreeUtils.CollectDescriptorNames(model, descNames, 0, nToConsider) 67 for descId in levels.keys(): 68 v = res.get(descId, numpy.zeros(nToConsider, numpy.float)) 69 v[levels[descId]] += 1. / nModels 70 res[descId] = v 71 for k in res: 72 v = globalRes.get(k, numpy.zeros(nToConsider, numpy.float)) 73 v += res[k] / nComposites 74 globalRes[k] = v 75 if verbose > 0: 76 for k in res.keys(): 77 name = descNames[k] 78 strRes = ', '.join(['%4.2f' % x for x in res[k]]) 79 print('%s,%s,%5.4f' % (name, strRes, sum(res[k]))) 80 81 print() 82 83 if verbose >= 0: 84 print('# Average Descriptor Positions') 85 retVal = [] 86 for k in globalRes: 87 name = descNames[k] 88 if verbose >= 0: 89 strRes = ', '.join(['%4.2f' % x for x in globalRes[k]]) 90 print('%s,%s,%5.4f' % (name, strRes, sum(globalRes[k]))) 91 tmp = [name] 92 tmp.extend(globalRes[k]) 93 tmp.append(sum(globalRes[k])) 94 retVal.append(tmp) 95 if verbose >= 0: 96 print() 97 else: 98 retVal = [] 99 return retVal
100 101
def _ErrorSummary(errors):
  """Summarize a 1-D numpy array of per-run error fractions.

  Returns ``(mean, sampleStdDev, bestIdx, bestErr)``, where `bestIdx` is the
  position of the smallest error.  With a single value the sample standard
  deviation is undefined, and ``numpy.nan`` is returned for it.
  """
  nVals = len(errors)
  avg = sum(errors) / nVals
  bestIdx = numpy.argsort(errors)[0]
  if nVals > 1:
    dev = numpy.sqrt(sum((errors - avg)**2) / (nVals - 1))
  else:
    dev = numpy.nan
  return avg, dev, bestIdx, errors[bestIdx]


def ErrorStats(conn, where, enrich=1):
  """Collect error statistics for the screening runs returned by `conn`.

  Args:
    conn: object whose ``GetData(fields=..., where=...)`` returns rows of
      (overall_error, holdout_error, overall_result_matrix,
      holdout_result_matrix, overall_correct_conf, overall_incorrect_conf,
      holdout_correct_conf, holdout_incorrect_conf).
    where: where-clause text passed straight through to `conn.GetData`.
    enrich: target passed as `tgt` to `ScreenComposite.CalcEnrichment`.
      A negative value skips the enrichment statistics.

  Returns:
    A dict of statistics.  Error and confidence values are scaled by 100.
    Overall keys (always present):
      oAvg, oDev, oCorrectConf, oIncorrectConf, oResultMat, oBestIdx, oBestErr.
      If `enrich` >= 0: oAvgEnrich, oDevEnrich.
    Holdout keys (present when at least one row has a holdout error):
      the matching h* keys.  These are averaged over the rows that have
      holdout data, and hBestIdx is the row position in the query result.
    A result matrix entry is None if no row provided that matrix.
    Returns None if the query raises (the traceback is printed) or yields no rows.
  """
  fields = ('overall_error,holdout_error,overall_result_matrix,' +
            'holdout_result_matrix,overall_correct_conf,overall_incorrect_conf,' +
            'holdout_correct_conf,holdout_incorrect_conf')
  try:
    data = conn.GetData(fields=fields, where=where)
  except Exception:
    import traceback
    traceback.print_exc()
    return None
  nPts = len(data)
  if not nPts:
    sys.stderr.write('no runs found\n')
    return None

  # builtin float == float64; numpy.float was removed in NumPy 1.24
  overall = numpy.zeros(nPts, float)
  overallEnrich = numpy.zeros(nPts, float)
  holdout = numpy.zeros(nPts, float)
  holdoutEnrich = numpy.zeros(nPts, float)
  hasHoldout = numpy.zeros(nPts, bool)
  oCorConf = oInCorConf = 0.0
  hCorConf = hInCorConf = 0.0
  overallMats = []
  holdoutMats = []
  for i, row in enumerate(data):
    if row[0] is not None:
      overall[i] = row[0]
      oCorConf += row[4]
      oInCorConf += row[5]
    # NOTE(review): the result matrices are stored as Python expressions and
    # eval()'d here, so only run this against databases you trust.
    if row[2] is not None:
      tmpOverall = 1. * eval(row[2])
      overallMats.append(tmpOverall)
      if enrich >= 0:
        overallEnrich[i] = ScreenComposite.CalcEnrichment(tmpOverall, tgt=enrich)
    if row[1] is not None:
      hasHoldout[i] = True
      holdout[i] = row[1]
      hCorConf += row[6]
      hInCorConf += row[7]
      if row[3] is not None:
        tmpHoldout = 1. * eval(row[3])
        holdoutMats.append(tmpHoldout)
        if enrich >= 0:
          holdoutEnrich[i] = ScreenComposite.CalcEnrichment(tmpHoldout, tgt=enrich)

  res = {}
  avgOverall, devOverall, oBest, oMin = _ErrorSummary(overall)
  res['oAvg'] = 100 * avgOverall
  res['oDev'] = 100 * devOverall
  res['oCorrectConf'] = 100 * (oCorConf / nPts)
  res['oIncorrectConf'] = 100 * (oInCorConf / nPts)
  res['oResultMat'] = sum(overallMats) / len(overallMats) if overallMats else None
  res['oBestIdx'] = oBest
  res['oBestErr'] = 100 * oMin
  if enrich >= 0:
    mean, dev = Stats.MeanAndDev(overallEnrich)
    res['oAvgEnrich'] = mean
    res['oDevEnrich'] = dev

  if hasHoldout.any():
    # restrict holdout statistics to the rows that actually carry holdout data
    hIdx = numpy.nonzero(hasHoldout)[0]
    nHoldout = len(hIdx)
    avgHoldout, devHoldout, hBest, hMin = _ErrorSummary(holdout[hIdx])
    res['hAvg'] = 100 * avgHoldout
    res['hDev'] = 100 * devHoldout
    res['hCorrectConf'] = 100 * (hCorConf / nHoldout)
    res['hIncorrectConf'] = 100 * (hInCorConf / nHoldout)
    res['hResultMat'] = sum(holdoutMats) / len(holdoutMats) if holdoutMats else None
    res['hBestIdx'] = hIdx[hBest]
    res['hBestErr'] = 100 * hMin
    if enrich >= 0:
      mean, dev = Stats.MeanAndDev(holdoutEnrich[hIdx])
      res['hAvgEnrich'] = mean
      res['hDevEnrich'] = dev
  return res


def _ShowResultMatrix(resultMat):
  """Print a transposed result matrix together with its diagonal percentages.

  Each printed row ends with ``100 * diagonal / row total``.  A final line gives
  ``100 * diagonal / column total`` for every column.  Empty rows or columns
  are treated as having a total of 1 to avoid dividing by zero.  Nothing is
  printed if `resultMat` is None.
  """
  if resultMat is None:
    return
  tmp = numpy.transpose(resultMat)
  # numpy.sum with an explicit axis: the builtin sum(tmp, 1) used before treats
  # 1 as a *start value* and produced column totals + 1 instead of row totals.
  colCounts = numpy.sum(tmp, axis=0)
  rowCounts = numpy.sum(tmp, axis=1)
  for i, row in enumerate(tmp):
    if rowCounts[i] == 0:
      rowCounts[i] = 1
    cells = ''.join(['% 6.2f' % val for val in row])
    print('\t\t' + cells + '\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i]))
  print('\t\t' + '------' * len(tmp))
  colPcts = []
  for i in range(len(tmp)):
    if colCounts[i] == 0:
      colCounts[i] = 1
    colPcts.append('% 6.2f' % (100. * tmp[i, i] / colCounts[i]))
  print('\t\t' + ''.join(colPcts))


def ShowStats(statD, enrich=1):
  """Print the statistics dictionary produced by `ErrorStats`.

  Args:
    statD: dict containing at least the overall keys (oAvg, oDev, oCorrectConf,
      oIncorrectConf, oResultMat, oBestIdx, oBestErr).  The holdout section is
      printed when 'hAvg' / 'hResultMat' are present.  The caller's dict is
      not modified; best indices are shown 1-based.
    enrich: enrichment target used only for the "Enrich(%d)" label.  Values
      <= -1 suppress the enrichment lines.
  """
  statD = statD.copy()
  statD['oBestIdx'] = statD['oBestIdx'] + 1
  txt = """
# Error Statistics:
\tOverall: %(oAvg)6.3f%% (%(oDev)6.3f) %(oCorrectConf)4.1f/%(oIncorrectConf)4.1f
\t\tBest: %(oBestIdx)d %(oBestErr)6.3f%%""" % (statD)
  if 'hAvg' in statD:
    statD['hBestIdx'] = statD['hBestIdx'] + 1
    txt += """
\tHoldout: %(hAvg)6.3f%% (%(hDev)6.3f) %(hCorrectConf)4.1f/%(hIncorrectConf)4.1f
\t\tBest: %(hBestIdx)d %(hBestErr)6.3f%%
""" % (statD)
  print(txt)
  print()
  print('# Results matrices:')
  print('\tOverall:')
  _ShowResultMatrix(statD['oResultMat'])
  if enrich > -1 and 'oAvgEnrich' in statD:
    print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['oAvgEnrich'], statD['oDevEnrich']))

  if 'hResultMat' in statD:
    print('\tHoldout:')
    _ShowResultMatrix(statD['hResultMat'])
    if enrich > -1 and 'hAvgEnrich' in statD:
      print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['hAvgEnrich'], statD['hDevEnrich']))

  return


def Usage():
  """Show the command-line help (the module docstring), then exit with status -1."""
  helpText = __doc__
  print(helpText)
  sys.exit(-1)


if __name__ == "__main__":
  import getopt
  try:
    args, extras = getopt.getopt(sys.argv[1:], 'n:d:N:vX', ('skip', 'enrich=', ))
  except getopt.GetoptError:
    Usage()

  count = 3
  db = None
  note = ''
  verbose = 0
  skip = 0
  enrich = 1
  # '-X' is accepted by getopt but has no effect in this script.
  for arg, val in args:
    if arg == '-n':
      # ProcessIt receives one more than the value given on the command line.
      count = int(val) + 1
    elif arg == '-d':
      db = val
    elif arg == '-N':
      note = val
    elif arg == '-v':
      verbose = 1
    elif arg == '--skip':
      skip = 1
    elif arg == '--enrich':
      enrich = int(val)

  # NOTE(review): unpickling can execute arbitrary code; only load models from
  # files and databases you trust.
  composites = []
  if db is None:
    for arg in extras:
      with open(arg, 'rb') as pklF:
        composites.append(cPickle.load(pklF))
  else:
    if not extras:
      # the table name is required when reading from a database
      Usage()
    tbl = extras[0]
    conn = DbConnect(db, tbl)
    if note:
      # double embedded single quotes so the note cannot end the SQL literal early
      where = "where note='%s'" % (note.replace("'", "''"))
    else:
      where = ''
    if not skip:
      for pkl in conn.GetData(fields='model', where=where):
        blob = pkl[0]
        if not isinstance(blob, bytes):
          # drivers may hand back buffer/memoryview objects; pickle needs bytes
          # (str() of those yields their repr on Python 3)
          blob = bytes(blob)
        composites.append(cPickle.loads(blob))

  if composites:
    ProcessIt(composites, count, verbose=verbose)
  elif not skip:
    print('ERROR: no composite models found')
    sys.exit(-1)

  if db:
    res = ErrorStats(conn, where, enrich=enrich)
    if res:
      # pass the same target so the "Enrich(N)" label matches what was computed
      ShowStats(res, enrich=enrich)