1
2
3
4
5
6 """ functionality for drawing trees on sping canvases
7
8 """
9 import math
10
11 from rdkit.sping import pid as piddle
12
13
class VisOpts(object):
    """Drawing parameters (radii, colors, spacing, fonts) for rendering trees.

    The drawing code in this module reads these values through the
    module-level instance ``visOpts``.
    """

    # node radii: the default, plus the range used when leaves are scaled
    circRad = 10
    minCircRad = 4
    maxCircRad = 16

    # fill colors: internal nodes, and the two ends of the leaf-label blend
    circColor = piddle.Color(0.6, 0.6, 0.9)
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)

    # outline colors: normal nodes, and leaves that have no examples
    outlineColor = piddle.transparent
    terminalEmptyColor = piddle.Color(0.8, 0.8, 0.2)

    # parent-to-child connector lines
    lineColor = piddle.Color(0, 0, 0)
    lineWidth = 2

    # horizontal gap added to circRad per leaf slot; vertical distance between levels
    horizOffset = 10
    vertOffset = 50

    labelFont = piddle.Font(face='helvetica', size=10)

    # highlight style
    highlightColor = piddle.Color(1.0, 1.0, 0.4)
    highlightWidth = 2


visOpts = VisOpts()
33
34
def CalcTreeNodeSizes(node):
    """Recursively compute layout sizes for *node* and everything below it.

    Each visited node receives three attributes:

      - totNChildren: number of leaves (nodes with no children) in its
        subtree; a leaf counts itself as 1
      - nLevelsBelow: height of its subtree including the node itself,
        so a leaf has 1
      - nExamples: len(node.GetExamples())
    """
    children = node.GetChildren()
    if len(children) == 0:
        leafCount = 1
        deepestChild = 0
    else:
        leafCount = 0
        deepestChild = 0
        for child in children:
            # children must be sized first; we aggregate their results
            CalcTreeNodeSizes(child)
            leafCount += child.totNChildren
            deepestChild = max(deepestChild, child.nLevelsBelow)

    node.nExamples = len(node.GetExamples())
    node.totNChildren = leafCount
    node.nLevelsBelow = deepestChild + 1
57
58
60 if node.GetTerminal():
61 cnt = node.nExamples
62 if cnt < min:
63 min = cnt
64 if cnt > max:
65 max = cnt
66 else:
67 for child in node.GetChildren():
68 provMin, provMax = _ExampleCounter(child, min, max)
69 if provMin < min:
70 min = provMin
71 if provMax > max:
72 max = provMax
73 return min, max
74
75
77 if node.GetTerminal():
78 if max != min:
79 loc = float(node.nExamples - min) / (max - min)
80 else:
81 loc = .5
82 node._scaleLoc = loc
83 else:
84 for child in node.GetChildren():
85 _ApplyNodeScales(child, min, max)
86
87
93
94
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively displays the given tree node and all its children on the canvas

    Arguments:

      - node: the tree node to draw; it must provide GetChildren(),
        GetTerminal(), GetLabel() and GetExamples()

      - loc: (x, y) position of the node's center on the canvas

      - canvas: the sping canvas to draw on

      - nRes: number of possible result values. A terminal node's label is
        mapped to the blend factor label / (nRes - 1), which mixes
        visOpts.terminalOffColor (factor 0) and visOpts.terminalOnColor
        (factor 1). With nRes <= 1 the factor is 0, which avoids dividing
        by zero.

      - scaleLeaves: if true, terminal nodes get a radius between
        visOpts.minCircRad and visOpts.maxCircRad according to their
        _scaleLoc attribute (as set by _ApplyNodeScales). Leaves without
        that attribute use the midpoint radius.

      - showPurity: if true, each terminal node gets a white line from its
        center at angle purity * pi. Purity is the fraction of the node's
        examples whose last element equals the node's (integer) label, or
        1.0 if the node has no examples.

    Side effects: runs CalcTreeNodeSizes(node) when node.totNChildren is
    missing or None, and stores the node's bounding box (x1, y1, x2, y2) in
    node._bBox.
    """
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()

    if scaleLeaves and isTerminal:
        # leaves that were never scaled fall back to the midpoint radius
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1 = loc[0] - rad
    y1 = loc[1] - rad
    x2 = loc[0] + rad
    y2 = loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        nEx = len(examples)
        if nEx:
            tgtVal = int(node.GetLabel())
            nMatch = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            purity = float(nMatch) / nEx
        else:
            purity = 1.0
        angle = purity * math.pi  # radians
        pureX = loc[0] + rad * math.sin(angle)
        pureY = loc[1] + rad * math.cos(angle)

    # Each child gets a horizontal slot proportional to the number of leaves
    # beneath it; the children's slots together are centered under this node.
    slotWidth = visOpts.horizOffset + visOpts.circRad
    childY = loc[1] + visOpts.vertOffset
    childX = loc[0] - (slotWidth * node.totNChildren) / 2
    for child in node.GetChildren():
        halfWidth = (slotWidth * child.totNChildren) / 2
        childX = childX + halfWidth  # center of this child's slot
        childLoc = [childX, childY]
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1], visOpts.lineColor,
                        visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)
        childX = childX + halfWidth  # left edge of the next child's slot

    if isTerminal:
        if nRes > 1:
            cFac = float(node.GetLabel()) / float(nRes - 1)
        else:
            cFac = 0.0
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        # leaves without examples are outlined in terminalEmptyColor
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2, outlColor, visOpts.lineWidth, theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        canvas.drawEllipse(x1, y1, x2, y2, visOpts.outlineColor, visOpts.lineWidth,
                           visOpts.circColor)

    canvas.defaultFont = visOpts.labelFont

    labelStr = str(node.GetLabel())
    strX = loc[0] - canvas.stringWidth(labelStr) / 2
    strY = loc[1] + canvas.fontHeight() / 4
    canvas.drawString(labelStr, strX, strY)
    node._bBox = (x1, y1, x2, y2)
177
178
186
187
188 -def DrawTree(tree, canvas, nRes=2, scaleLeaves=False, allowShrink=True, showPurity=False):
205
206
def ResetTree(tree):
    """Clear cached layout data on *tree* and on every node beneath it.

    Sets ``_scales`` and ``totNChildren`` to None on each node. With
    ``totNChildren`` cleared, DrawTreeNode recomputes node sizes on its
    next call.
    """
    stack = [tree]
    while stack:
        current = stack.pop()
        current._scales = None
        current.totNChildren = None
        # push reversed so nodes are visited in the same preorder as recursion
        stack.extend(reversed(list(current.GetChildren())))
212
213
def _simpleTest(canv):
    """Build a small hand-made tree with three terminal leaves and draw it on *canv*."""
    from .Tree import TreeNode as Node
    root = Node(None, 'r', label='r')
    inner = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    inner.AddChild('l2_1', isTerminal=1, label=0)
    inner.AddChild('l2_2', isTerminal=1, label=1)

    DrawTreeNode(root, (150, visOpts.vertOffset), canv)
223
224
if __name__ == '__main__':
    # Demo: render the sample tree from _simpleTest() into test.png via the PIL backend.
    from rdkit.sping.PIL.pidPIL import PILCanvas
    canv = PILCanvas(name='test.png', size=(300, 300))
    _simpleTest(canv)
    canv.save()
230