Package rdkit :: Package ML :: Package Data :: Module SplitData
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.Data.SplitData

  1  # 
  2  #  Copyright (C) 2003-2008 Greg Landrum and Rational Discovery LLC 
  3  #    All Rights Reserved 
  4  # 
  5  from __future__ import print_function 
  6   
  7  import random 
  8   
  9  from rdkit import RDRandom 
 10   
 11  SeqTypes = (list, tuple) 
 12   
 13   
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
  """ splits a set of indices into a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple containing the two sets of indices.

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_. This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:

    >>> from rdkit.ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5)
    >>> test
    [1, 5, 6, 4, 2]
    >>> train
    [3, 0, 7, 8, 9]

    >>> test,train = SplitIndices(10,.5)
    >>> test
    [5, 2, 9, 8, 7]
    >>> train
    [6, 0, 3, 1, 4]


    The legacy approach can return varying numbers, but still has no
    duplicates. Note the indices come back ordered:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [3, 5, 7, 8, 9]
    >>> train
    [0, 1, 2, 4, 6]

    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [0, 1, 2, 3, 5, 8, 9]
    >>> train
    [4, 6, 7]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [9, 9, 8, 0, 5]
    >>> train
    [1, 2, 3, 4, 6, 7]
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [4, 5, 1, 1, 4]
    >>> train
    [0, 2, 3, 6, 7, 8, 9]

  """
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

  if replacement:
    nTrain = int(nPts * frac)
    resData = [None] * nTrain
    for i in range(nTrain):
      val = int(RDRandom.random() * nPts)
      # defensive clamp in case random() ever returned exactly 1.0
      if val == nPts:
        val = nPts - 1
      resData[i] = val
    # use a set for membership testing: the original list scan was O(nPts**2)
    chosen = set(resData)
    resTest = [i for i in range(nPts) if i not in chosen]
  elif legacy:
    resData = []
    resTest = []
    for i in range(nPts):
      if RDRandom.random() < frac:
        resData.append(i)
      else:
        resTest.append(i)
  else:
    perm = list(range(nPts))
    # Inlined Fisher-Yates shuffle driven by random.random().  This reproduces
    # the exact sequence of random.shuffle(perm, random=random.random); the
    # 'random' parameter of random.shuffle was removed in Python 3.11, so the
    # old call crashes there while this inline keeps the doctests' expected
    # output stable on every version.
    for i in range(len(perm) - 1, 0, -1):
      j = int(random.random() * (i + 1))
      perm[i], perm[j] = perm[j], perm[i]
    nTrain = int(nPts * frac)

    resData = list(perm[:nTrain])
    resTest = list(perm[nTrain:])

  if not silent:
    print('Training with %d (of %d) points.' % (len(resData), nPts))
    print('\t%d points are in the hold-out set.' % (len(resTest)))
  return resData, resTest
def SplitDataSet(data, frac, silent=0):
  """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets.

  """
  # reject fractions outside [0, 1] before doing any work
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0')

  nTotal = len(data)
  # delegate the index selection, then materialize the two subsets
  trainIndices, holdoutIndices = SplitIndices(nTotal, frac, silent=1)
  trainSet = [data[idx] for idx in trainIndices]
  holdoutSet = [data[idx] for idx in holdoutIndices]

  if not silent:
    print('Training with %d (of %d) points.' % (len(trainSet), nTotal))
    print('\t%d points are in the hold-out set.' % (len(holdoutSet)))
  return trainSet, holdoutSet
def SplitDbData(conn, fracs, table='', fields='*', where='', join='', labelCol='', useActs=0,
                nActs=2, actCol='', actBounds=None, silent=0):
  """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction. This can optionally be specified as a
        sequence with a different fraction for each activity value.

      - table,fields,where,join: (optional) SQL query parameters

      - labelCol: (optional) currently unused; kept for interface compatibility

      - useActs: (optional) toggles splitting based on activities
        (ensuring that a given fraction of each activity class ends
        up in the hold-out set)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple of id lists: (training ids, hold-out ids)

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:

    >>> from rdkit.ML.Data import DataUtils
    >>> from rdkit.Dbase.DbConnection import DbConnect
    >>> from rdkit import RDConfig
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-7', 'id-6', 'id-2', 'id-8']

    ...take 50% of actives and 50% of inactives:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']

  """
  # None sentinel instead of a mutable default argument ([] shared across calls)
  if actBounds is None:
    actBounds = []
  if not table:
    table = conn.tableName
  if actBounds and len(actBounds) != nActs - 1:
    raise ValueError('activity bounds list length incorrect')
  if useActs:
    # normalize a scalar fraction into one fraction per activity class
    if not isinstance(fracs, SeqTypes):
      fracs = tuple([fracs] * nActs)
    for frac in fracs:
      if frac < 0.0 or frac > 1.0:
        raise ValueError('fractions must be between 0.0 and 1.0')
  else:
    if isinstance(fracs, SeqTypes):
      frac = fracs[0]
      if frac < 0.0 or frac > 1.0:
        raise ValueError('fractions must be between 0.0 and 1.0')
    else:
      frac = fracs
  # start by getting the name of the ID column (first column of the query):
  colNames = conn.GetColumnNames(table=table, what=fields, join=join)
  idCol = colNames[0]

  if not useActs:
    # get the IDs and split them as a single pool:
    d = conn.GetData(table=table, fields=idCol, join=join)
    ids = [x[0] for x in d]
    nRes = len(ids)
    train, test = SplitIndices(nRes, frac, silent=1)
    trainPts = [ids[x] for x in train]
    testPts = [ids[x] for x in test]
  else:
    trainPts = []
    testPts = []
    if not actCol:
      # default: the activity is the last column the query returns
      actCol = colNames[-1]
    # build the common SQL "where" prefix once; per-activity clauses are appended
    whereBase = where.strip()
    if not whereBase.startswith('where'):
      whereBase = 'where ' + whereBase
    if where:
      whereBase += ' and '
    for act in range(nActs):
      frac = fracs[act]
      if not actBounds:
        # quantized activities: select rows whose activity equals this class
        whereTxt = whereBase + '%s=%d' % (actCol, act)
      else:
        # unquantized activities: select the half-open interval for this class
        whereTxt = whereBase
        if act != 0:
          whereTxt += '%s>=%f ' % (actCol, actBounds[act - 1])
        if act < nActs - 1:
          if act != 0:
            whereTxt += 'and '
          whereTxt += '%s<%f' % (actCol, actBounds[act])
      d = conn.GetData(table=table, fields=idCol, join=join, where=whereTxt)
      ids = [x[0] for x in d]
      nRes = len(ids)
      # split each activity class independently so the fractions hold per class
      train, test = SplitIndices(nRes, frac, silent=1)
      trainPts.extend([ids[x] for x in train])
      testPts.extend([ids[x] for x in test])

  return trainPts, testPts
294 295 296 # ------------------------------------ 297 # 298 # doctest boilerplate 299 #
def _runDoctests(verbose=None):  # pragma: nocover
  """Run this module's doctests and exit with the number of failures."""
  import doctest
  import sys

  failureCount, _testCount = doctest.testmod(optionflags=doctest.ELLIPSIS, verbose=verbose)
  sys.exit(failureCount)


if __name__ == '__main__':  # pragma: nocover
  _runDoctests()