Package rdkit :: Package ML :: Package DecTree :: Module Tree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.Tree

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Implements a class used to represent N-ary trees 
  5   
  6  """ 
  7  from __future__ import print_function 
  8   
  9  from rdkit.six.moves import cPickle 
 10   
 11   
 12  # FIX: the TreeNode class has not been updated to new-style classes 
 13  # (RD Issue380) because that would break all of our legacy pickled 
 14  # data. Until a solution is found for this breakage, an update is 
 15  # impossible. 
class TreeNode:
    """ A general purpose N-ary tree node.

    The root of a tree is simply a TreeNode like every other member; a tree
    is identified by its root node.

    **Notes**

      - the class is intentionally declared without an explicit base class:
        changing that would break previously pickled trees.

      - equality and ordering are structural (type, name, label and children
        are compared, the _data_ field is not).

    """

    def __init__(self, parent, name, label=None, data=None, level=0, isTerminal=0):
        """ constructor

        **Arguments**

          - parent: the parent of this node in the tree

          - name: the name of the node

          - label: the node's label (should be an integer)

          - data: an optional data field

          - level: an integer indicating the level of this node in the
            hierarchy (used for printing)

          - isTerminal: flags a node as being terminal (a leaf)

        """
        self.children = []
        self.parent = parent
        self.name = name
        self.data = data
        self.terminalNode = isTerminal
        self.label = label
        self.level = level
        self.examples = []

    def NameTree(self, varNames):
        """ Sets the names of the nodes in the tree from a list of variable names.

        **Arguments**

          - varNames: a list of names to be assigned

        **Notes**

          1) the tree is traversed recursively; terminal nodes are left
             untouched

          2) the varNames list must be indexable by the labels of the
             non-terminal nodes

        """
        if self.GetTerminal():
            return
        for child in self.GetChildren():
            child.NameTree(varNames)
        self.SetName(varNames[self.GetLabel()])

    # alternate name for NameTree
    NameModel = NameTree

    def AddChildNode(self, node):
        """ Adds an existing TreeNode to the local list of children

        **Arguments**

          - node: the node to be added

        **Note**

          the level of the node (used in printing) is set to one more than
          this node's level; the node's parent reference is not modified.

        """
        node.SetLevel(self.level + 1)
        self.children.append(node)

    def AddChild(self, name, label=None, data=None, isTerminal=0):
        """ Creates a new TreeNode and adds it as a child of this node

        **Arguments**

          - name: the name of the new node

          - label: the label of the new node (should be an integer)

          - data: the data to be stored in the new node

          - isTerminal: a toggle to indicate whether or not the new node is
            a terminal (leaf) node.

        **Returns**

          the _TreeNode_ which is constructed

        """
        child = TreeNode(self, name, label, data, level=self.level + 1, isTerminal=isTerminal)
        self.children.append(child)
        return child

    def PruneChild(self, child):
        """ Removes the child node

        **Arguments**

          - child: a TreeNode

        **Notes**

          - the exact object passed in is removed when it is one of the
            children.  Because node equality is structural (and ignores
            _data_), a purely equality-based removal could otherwise drop an
            earlier sibling that merely looks like _child_.

          - if the object itself is not a child, the first child comparing
            equal to it is removed; a ValueError is raised when there is no
            such child.

        """
        for idx, candidate in enumerate(self.children):
            if candidate is child:
                del self.children[idx]
                return
        # fall back to the equality-based removal
        self.children.remove(child)

    def ReplaceChildIndex(self, index, newChild):
        """ Replaces a given child with a new one

        **Arguments**

          - index: an integer, the position of the child to be replaced

          - newChild: the TreeNode to be stored at that position

        """
        self.children[index] = newChild

    def GetChildren(self):
        """ Returns the python list of the children of this node

        **Note**

          this is the node's own list, not a copy

        """
        return self.children

    def Destroy(self):
        """ Destroys this node and all of its children

        """
        for child in self.children:
            child.Destroy()
        self.children = []
        # clean up circular references
        self.parent = None

    def GetName(self):
        """ Returns the name of this node

        """
        return self.name

    def SetName(self, name):
        """ Sets the name of this node

        """
        self.name = name

    def GetData(self):
        """ Returns the data stored at this node

        """
        return self.data

    def SetData(self, data):
        """ Sets the data stored at this node

        """
        self.data = data

    def GetTerminal(self):
        """ Returns whether or not this node is terminal

        """
        return self.terminalNode

    def SetTerminal(self, isTerminal):
        """ Sets whether or not this node is terminal

        """
        self.terminalNode = isTerminal

    def GetLabel(self):
        """ Returns the label of this node

        """
        return self.label

    def SetLabel(self, label):
        """ Sets the label of this node (should be an integer)

        """
        self.label = label

    def GetLevel(self):
        """ Returns the level of this node

        """
        return self.level

    def SetLevel(self, level):
        """ Sets the level of this node

        **Note**

          only this node is changed, the levels of its children are not

        """
        self.level = level

    def GetParent(self):
        """ Returns the parent of this node

        """
        return self.parent

    def SetParent(self, parent):
        """ Sets the parent of this node

        """
        self.parent = parent

    def Print(self, level=0, showData=0):
        """ Pretty prints the tree

        **Arguments**

          - level: sets the number of spaces to be added at the beginning of
            the output

          - showData: if this is nonzero, the node's _data_ value will be
            printed as well

        **Note**

          this works recursively

        """
        if showData:
            print('%s%s: %s' % (' ' * level, self.name, str(self.data)))
        else:
            print('%s%s' % (' ' * level, self.name))

        for child in self.children:
            child.Print(level + 1, showData=showData)

    def Pickle(self, fileName='foo.pkl'):
        """ Pickles the tree and writes it to disk

        **Arguments**

          - fileName: the name of the file to be written, any existing
            contents are overwritten

        """
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile)

    def __str__(self):
        """ returns a string representation of the tree

        **Note**

          this works recursively; every node contributes one line which is
          indented by its level

        """
        pieces = ['%s%s\n' % (' ' * self.level, self.name)]
        pieces.extend(str(child) for child in self.children)
        return ''.join(pieces)

    def __cmp__(self, other):
        """ three-way comparison built on top of __lt__

        **Returns**

          -1 if self < other, 1 if other < self, 0 otherwise

        **Note**

          This works recursively
        """
        if self < other:
            return -1
        if other < self:
            return 1
        return 0

    def __lt__(self, other):
        """ allows tree1 < tree2

        Nodes are compared on, in this order: the string form of their type,
        their name, their label (a node without a label sorts before one
        with a label), their number of children and finally their children,
        position by position.

        **Notes**

          - the type, name, label and child count are decided by the first
            of them that differs, so for those keys _a < b_ and _b < a_ can
            never both be true.

          - children are only tested in one direction (child of self <
            child of other) so that a comparison stays linear in the size
            of the tree and still dispatches to the children's own _<_.

          - if _other_ lacks the attributes of a tree node, True is
            returned.

          - This works recursively
        """
        try:
            # fetched first so that a non-tree _other_ is caught immediately
            nChildren = len(self.children)
            oChildren = len(other.children)

            selfType = str(type(self))
            otherType = str(type(other))
            if selfType < otherType:
                return True
            if otherType < selfType:
                return False

            if self.name < other.name:
                return True
            if other.name < self.name:
                return False

            if self.label is None:
                if other.label is not None:
                    return True
            elif other.label is None:
                return False
            elif self.label < other.label:
                return True
            elif other.label < self.label:
                return False

            if nChildren < oChildren:
                return True
            if nChildren > oChildren:
                return False
            for mine, theirs in zip(self.children, other.children):
                if mine < theirs:
                    return True
        except AttributeError:
            return True
        return False

    def __eq__(self, other):
        """ allows tree1 == tree2

        Two nodes are equal when neither one is less than the other.

        **Note**

          This works recursively
        """
        return not (self < other or other < self)
309 310
def _exampleCode():
    """ Small demonstration of the TreeNode API.

    Builds a tree, prints it, pickles it to 'save.pkl', prunes a child and
    then exercises copying, equality and list membership before destroying
    the tree.
    """
    root = TreeNode(None, 'root')
    for idx in range(3):
        root.AddChild('child %d' % idx)
    print(root)

    # hang three grandchildren off the middle child
    for grandName in ('grandchild', 'grandchild2', 'grandchild3'):
        root.GetChildren()[1].AddChild(grandName)
    print(root)

    root.Pickle('save.pkl')
    print('prune')
    middle = root.GetChildren()[1]
    root.PruneChild(middle)
    print('done')
    print(root)

    import copy
    clone = copy.deepcopy(root)
    print('tree==tree2', root == clone)

    holder = [root]
    print('tree in [tree]:', root in holder, holder.index(root))
    print('tree2 in [tree]:', clone in holder, holder.index(clone))

    # make the copy differ from the original and compare again
    clone.GetChildren()[1].AddChild('grandchild4')
    print('tree==tree2', root == clone)
    root.Destroy()


# Run the demonstration only when this module is executed as a script.
if __name__ == '__main__':  # pragma: nocover
    _exampleCode()