1
2
3
4
5
6
7
8
9
10
11 """ contains the Cluster class for representing hierarchical cluster trees
12
13 """
14 from __future__ import print_function
15
16 from rdkit.six import cmp
17
18 CMPTOL = 1e-6
19
20
22 """a class for storing clusters/data
23
24 **General Remarks**
25
26 - It is assumed that the bottom of any cluster hierarchy tree is composed of
27 the individual data points which were clustered.
28
29 - Cluster objects store the following pieces of data; most are
30 accessible via standard Setters/Getters:
31
32 - Children: *Not Settable*, the list of children. You can add children
33 with the _AddChild()_ and _AddChildren()_ methods.
34
35 **Note** this can be of arbitrary length,
36 but the current algorithms I have only produce trees with two children
37 per cluster
38
39 - Metric: the metric for this cluster (i.e. how far apart its children are)
40
41 - Index: the order in which this cluster was generated
42
43 - Points: *Not Settable*, the list of original points in this cluster
44 (calculated recursively from the children)
45
46 - PointsPositions: *Not Settable*, the list of positions of the original
47 points in this cluster (calculated recursively from the children)
48
49 - Position: the location of the cluster **Note** for a cluster this
50 probably means the location of the average of all the Points which are
51 its children.
52
53 - Data: a data field. This is used with the original points to store their
54 data value (i.e. the value we're using to classify)
55
56 - Name: the name of this cluster
57
58 """
59
def __init__(self, metric=0.0, children=None, position=None, index=-1, name=None, data=None):
    """Sets up a new cluster

    **Arguments**

      - metric: stored as the metric of this cluster

      - children: the list of child clusters; a fresh empty list is used
        when this is None

      - position: stored as the position of this cluster; a fresh empty list
        is used when this is None

      - index: stored as the index of this cluster

      - name: stored as the name of this cluster

      - data: stored as the data field of this cluster

    """
    self.metric = metric
    self.children = [] if children is None else children
    # the stored length is computed from the children, so refresh it as soon
    # as they are in place
    self._UpdateLength()
    self.pos = [] if position is None else position
    self.index, self.name, self.data = index, name, data
    # the point lists are generated on demand; None marks them as not built yet
    self._points = None
    self._pointsPositions = None
83
86
89
92
95
98
101
def GetPointsPositions(self):
    """Returns the list of point positions, generating it first if needed

    """
    if self._pointsPositions is None:
        self._GenPoints()
    return self._pointsPositions
108
def GetPoints(self):
    """Returns the list of points, generating it first if needed

    """
    if self._points is None:
        self._GenPoints()
    return self._points
115
def FindSubtree(self, index):
    """ finds and returns the subtree with a particular index

    **Arguments**

      - index: the index to search for

    **Returns**

      this cluster if its own index matches; otherwise the first match found
      by searching the children in order; None if nothing matches

    """
    if index == self.index:
        return self
    for child in self.children:
        found = child.FindSubtree(index)
        # compare against None explicitly: the truth value of a cluster is
        # driven by its length, which is not what "found a match" means
        if found is not None:
            return found
    return None
128
def _GenPoints(self):
    """ Generates the _Points_ and _PointsPositions_ lists

    *intended for internal use*

    A cluster of length 1 is its own single point.  Otherwise the points are
    collected from the children, visiting the longest child first.

    **Returns**

      the newly generated list of points (in both cases)

    """
    if len(self) == 1:
        self._points = [self]
    else:
        children = self.GetChildren()
        # NOTE(review): this sorts the list handed back by GetChildren() in
        # place (unchanged behavior); if that is the cluster's own child list
        # the children end up reordered - confirm that this is intended
        children.sort(key=len, reverse=True)
        points = []
        for child in children:
            points += child.GetPoints()
        self._points = points
    self._pointsPositions = [pt.GetPosition() for pt in self._points]
    return self._points
147
def AddChild(self, child):
    """Adds a child to our list

    **Arguments**

      - child: a Cluster

    **Notes**

      - the stored length is refreshed *before* the point lists are rebuilt:
        _GenPoints() branches on len(self), so rebuilding first would make a
        cluster that receives its first child record itself as its only point

    """
    self.children.append(child)
    self._UpdateLength()
    self._GenPoints()
159
171
173 """Removes a child from our list
174
175 **Arguments**
176
177 - child: a Cluster
178
179 """
180 self.children.remove(child)
181 self._UpdateLength()
182
186
189
192
195
def GetName(self):
    """Returns the name of this cluster

    When no name has been set, a default built from the cluster's index is
    returned instead.

    """
    if self.name is not None:
        return self.name
    return 'Cluster(%d)' % (self.GetIndex())
201
def Print(self, level=0, showData=0, offset='\t'):
    """Prints this cluster and then, recursively, each of its children

    **Arguments**

      - level: nesting depth; the line is indented by this many spaces and
        children are printed one level deeper

      - showData: if true, the data value is included in the line (only when
        it is not None)

      - offset: text inserted between the name and the rest of the line

    """
    indent = ' ' * level
    withData = showData and self.GetData() is not None
    if withData:
        line = '%s%s%s Data: %f\t Metric: %f' % (indent, self.GetName(), offset, self.GetData(),
                                                self.GetMetric())
    else:
        line = '%s%s%s Metric: %f' % (indent, self.GetName(), offset, self.GetMetric())
    print(line)

    for kid in self.GetChildren():
        kid.Print(level=level + 1, showData=showData, offset=offset)
211
def Compare(self, other, ignoreExtras=1):
    """ not as choosy as self==other

    The checks run in a fixed order and the first one that finds a difference
    decides the result: type name, length, metric and name (only when
    ignoreExtras is false), position, number of children and finally the
    children themselves, pair by pair.

    **Arguments**

      - other: the cluster to compare against

      - ignoreExtras: if true (the default), metric and name are not compared

    **Returns**

      0 if no difference was found, otherwise the non-zero value produced by
      the first check that found one

    """
    res = cmp(str(type(self)), str(type(other)))
    if res:
        return res
    res = cmp(len(self), len(other))
    if res:
        return res

    if not ignoreExtras:
        m1, m2 = self.GetMetric(), other.GetMetric()
        # metrics closer than CMPTOL are treated as equal
        if abs(m1 - m2) > CMPTOL:
            return cmp(m1, m2)

        # NOTE(review): the name is treated as one of the "extras", together
        # with the metric - confirm
        res = cmp(self.GetName(), other.GetName())
        if res:
            return res

    sP = self.GetPosition()
    oP = other.GetPosition()
    try:
        res = cmp(len(sP), len(oP))
    except Exception:
        # positions without a usable length: skip the size check
        pass
    else:
        if res:
            return res

    try:
        res = cmp(sP, oP)
    except Exception:
        # The positions could not be compared as a whole (presumably
        # array-like positions - TODO confirm).  Compare them element by
        # element and stop at the first difference; adding up the differences
        # instead would let offsetting differences cancel out and report
        # unequal positions as equal.
        res = 0
        for a, b in zip(sP, oP):
            res = cmp(a, b)
            if res:
                break
    if res:
        return res

    c1, c2 = self.GetChildren(), other.GetChildren()
    res = cmp(len(c1), len(c2))
    if res:
        return res
    for kid1, kid2 in zip(c1, c2):
        res = kid1.Compare(kid2, ignoreExtras=ignoreExtras)
        if res:
            return res

    return 0
259
def _UpdateLength(self):
    """ recomputes and stores our length

    The length is 1 (for this cluster) plus the lengths of all its children.

    *intended for internal use*

    """
    total = 1
    for kid in self.children:
        total += len(kid)
    self._len = total
267
269 return self._len <= 1
270
def __len__(self):
    """ makes _len(cluster)_ work by returning the stored length

    """
    return self._len
276
278 """ allows _cluster1 == cluster2_ to work
279
280 """
281 if cmp(type(self), type(other)):
282 return cmp(type(self), type(other))
283
284 m1, m2 = self.GetMetric(), other.GetMetric()
285 if abs(m1 - m2) > CMPTOL:
286 return cmp(m1, m2)
287
288 if cmp(self.GetName(), other.GetName()):
289 return cmp(self.GetName(), other.GetName())
290
291 c1, c2 = self.GetChildren(), other.GetChildren()
292 return cmp(c1, c2)
293
294
if __name__ == '__main__':
    from rdkit.ML.Cluster import ClusterUtils

    # small demo tree: a root with two branches holding three and two leaves
    root = Cluster(index=1, metric=1000)

    c1 = Cluster(index=10, metric=100)
    for leafIdx in (30, 31, 32):
        c1.AddChild(Cluster(index=leafIdx, metric=10))

    c2 = Cluster(index=11, metric=100)
    for leafIdx in (40, 41):
        c2.AddChild(Cluster(index=leafIdx, metric=10))

    # the branches are attached only once both are fully populated
    for branch in (c1, c2):
        root.AddChild(branch)

    nodes = ClusterUtils.GetNodeList(root)

    indices = [node.GetIndex() for node in nodes]
    print('XXX:', indices)
314