from tkinter import *
from tkinter.messagebox import showinfo
from R2Graph import *
import math
import random
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

SCALEX = 40.
SCALEY = SCALEX
STEPX = 5./SCALEX   # 5 pixels
STEPY = 5./SCALEY

blueColor = "#0000FF"
redColor = "#BB0000"
greenColor = "#00AA44"

# Class colors
classColors = (
    "#EE0000", "#00AA00", "#0044FF", "#CCAA00",
    "#880088", "#008888", "#BB4400", "#224488"
)
# Lighter variants used to shade decision regions
alphaColors = (
    "#FF4444", "#00FF00", "#0088FF", "#FFCC00",
    "#AA00AA", "#00AAAA", "#EE8800", "#4488AA"
)
numClassColors = len(classColors)
undefinedClass = (-1)
undefinedClassColor = "#444444"
alpha = 0.25
pointRadius = 0.15

# Default values
numClasses_default = 3
maxClasses = 8
maxDepth_default = 3
maxDepthUpper = 10
randomPoints_default = 60
maxRandomPoints = 200


def pointColor(classNumber):
    """Return the fill color for a point of the given class.

    undefinedClass maps to a neutral gray; other class numbers cycle
    through classColors.
    """
    if classNumber == undefinedClass:
        return undefinedClassColor
    return classColors[classNumber % numClassColors]


def alphaColor(classNumber):
    """Return the light region-shading color for the given class."""
    return alphaColors[classNumber % numClassColors]


def _readClampedInt(textVar, default, upper):
    """Parse an integer from a Tk StringVar and clamp it to [1, upper].

    An unparsable value is replaced by `default`. Whenever the returned
    value differs from what the user typed, it is written back to
    `textVar` so the entry shows the value actually used.
    """
    try:
        value = int(textVar.get())
    except ValueError:
        value = default
    else:
        if 1 <= value <= upper:
            return value
        value = min(max(value, 1), upper)
    textVar.set(value=str(value))
    return value


def main():
    """Build the decision-tree demo window and run the Tk main loop.

    Releasing mouse button 1/2/3 on the canvas adds a point of class 0/1/2.
    "Classify" fits a DecisionTreeClassifier (depth taken from the
    "Max. depth" entry) to the points. Correctly classified points are
    drawn as squares, the rest as circles. "Paint" shades the visible plane
    by predicted class. In "Random" mode, "Generate random" fills the plane
    with randomly oriented Gaussian clusters.
    """
    points = []              # R2Point positions in world coordinates
    classNumbers = []        # predicted class per point (undefinedClass until classified)
    targetClassNumbers = []  # true class per point
    objectIDs = []           # canvas IDs of drawn points
    rectIDs = []             # canvas IDs of painted decision-region cells
    scaleX = SCALEX
    scaleY = SCALEY
    clf = None               # fitted DecisionTreeClassifier, or None

    root = Tk()
    root.title("Decision Trees")
    root.geometry("900x600")

    panel = Frame(root)
    panel2 = Frame(root)
    classifyButton = Button(panel, text="Classify")
    resetButton = Button(panel, text="Reset")
    clearButton = Button(panel, text="Clear")
    generateRandomButton = Button(
        panel, text="Generate random", state=DISABLED
    )
    paintButton = Button(panel2, text="Paint", state=DISABLED)
    drawArea = Canvas(root, bg="white")

    panel.pack(side=TOP, fill=X)
    panel2.pack(side=TOP, fill=X)
    classifyButton.pack(side=LEFT, padx=4, pady=4)
    resetButton.pack(side=LEFT, padx=4, pady=4)
    clearButton.pack(side=LEFT, padx=4, pady=4)
    generateRandomButton.pack(side=LEFT, padx=4, pady=4)
    paintButton.pack(side=LEFT, padx=4, pady=4)
    drawArea.pack(side=TOP, fill=BOTH, expand=True, padx=4, pady=4)

    inputMethodIdx = IntVar()  # Control variable for the group of radio buttons
    inputMethodIdx.set(value=0)

    manualRadio = Radiobutton(
        panel, text="Manual definition",
        variable=inputMethodIdx, value=0, fg=blueColor
    )
    randomRadio = Radiobutton(
        panel, text="Random",
        variable=inputMethodIdx, value=1, fg=redColor
    )
    manualRadio.pack(side=LEFT, padx=4, pady=4)
    randomRadio.pack(side=LEFT, padx=4, pady=4)

    randomLabel = Label(panel, text="Classes:", fg=redColor)
    numClassesText = StringVar(value=str(numClasses_default))
    numClassesEntry = Entry(
        panel, bg="white", textvariable=numClassesText, fg=redColor,
        width=5, state=DISABLED
    )
    randomPointsLabel = Label(panel, text="Points:", fg=redColor)
    randomPointsText = StringVar(value=str(randomPoints_default))
    randomPointsEntry = Entry(
        panel, bg="white", textvariable=randomPointsText, fg=redColor,
        width=5, state=DISABLED
    )
    randomLabel.pack(side=LEFT, padx=4, pady=4)
    numClassesEntry.pack(side=LEFT, padx=4, pady=4)
    randomPointsLabel.pack(side=LEFT, padx=4, pady=4)
    randomPointsEntry.pack(side=LEFT, padx=4, pady=4)

    maxDepthLabel = Label(panel2, text="Max. depth:")
    maxDepthText = StringVar(value=str(maxDepth_default))
    maxDepthEntry = Entry(
        panel2, bg="white", textvariable=maxDepthText, width=5
    )
    maxDepthLabel.pack(side=LEFT, padx=4, pady=4)
    maxDepthEntry.pack(side=LEFT, padx=4, pady=4)

    def onRadioChange():
        """Enable the random-generation controls only in "Random" mode."""
        randomMode = inputMethodIdx.get() != 0
        if not randomMode:
            # Back to manual input: restore the default class count.
            numClassesText.set(value=str(numClasses_default))
        state = NORMAL if randomMode else DISABLED
        for widget in (numClassesEntry, randomPointsEntry, generateRandomButton):
            widget.configure(state=state)

    manualRadio.configure(command=onRadioChange)
    randomRadio.configure(command=onRadioChange)

    def toCanvas(t):
        """Map a world point to canvas pixels (origin at the canvas center, y up)."""
        centerX = drawArea.winfo_width()/2.
        centerY = drawArea.winfo_height()/2.
        return (centerX + t.x*scaleX, centerY - t.y*scaleY)

    def fromCanvas(p):
        """Map canvas pixel coordinates (x, y) back to a world R2Point."""
        centerX = drawArea.winfo_width()/2.
        centerY = drawArea.winfo_height()/2.
        return R2Point((p[0] - centerX)/scaleX, (centerY - p[1])/scaleY)

    def xMin():
        return -(drawArea.winfo_width()/scaleX)/2.

    def xMax():
        return -xMin()

    def yMin():
        return -(drawArea.winfo_height()/scaleY)/2.

    def yMax():
        return -yMin()

    def drawGrid():
        """Draw unit grid lines in light gray and the two axes in black."""
        for x in range(int(xMin()), int(xMax()) + 1):
            if x != 0:
                drawArea.create_line(
                    toCanvas(R2Point(x, yMin())), toCanvas(R2Point(x, yMax())),
                    fill="lightGray", width=1
                )
        for y in range(int(yMin()), int(yMax()) + 1):
            if y != 0:
                drawArea.create_line(
                    toCanvas(R2Point(xMin(), y)), toCanvas(R2Point(xMax(), y)),
                    fill="lightGray", width=1
                )
        # Draw x-axis
        drawArea.create_line(
            toCanvas(R2Point(xMin(), 0.)), toCanvas(R2Point(xMax(), 0.)),
            fill="black", width=2
        )
        # Draw y-axis
        drawArea.create_line(
            toCanvas(R2Point(0., yMin())), toCanvas(R2Point(0., yMax())),
            fill="black", width=2
        )

    def onMouseRelease(e):
        """Add a point at the click position; mouse button N assigns class N-1."""
        t = fromCanvas((e.x, e.y))
        classNumber = e.num - 1
        points.append(t)
        targetClassNumbers.append(classNumber)
        classNumbers.append(undefinedClass)
        drawPoint(t, classNumber)

    def drawPoint(t, classNumber=undefinedClass, correctPoint=False):
        """Draw a point as a square if correctly classified, else as a circle."""
        vx = R2Vector(pointRadius, 0.)
        vy = R2Vector(0., pointRadius)
        create = drawArea.create_rectangle if correctPoint else drawArea.create_oval
        objectIDs.append(
            create(toCanvas(t - vx - vy), toCanvas(t + vx + vy),
                   fill=pointColor(classNumber))
        )

    def drawPoints():
        """Draw every point in its true-class color (square = correctly predicted)."""
        for t, target, predicted in zip(points, targetClassNumbers, classNumbers):
            drawPoint(t, target, target == predicted)

    def clearPicture():
        """Delete all drawn points and painted cells, keeping the grid."""
        for i in objectIDs:
            drawArea.delete(i)
        objectIDs.clear()
        for i in rectIDs:
            drawArea.delete(i)
        rectIDs.clear()

    def onClassify():
        """Fit a depth-limited decision tree to the points and show its predictions."""
        nonlocal clf
        maxDepth = _readClampedInt(maxDepthText, maxDepth_default, maxDepthUpper)
        if not points:
            # DecisionTreeClassifier.fit raises on an empty training set.
            showinfo("Decision Trees", "Add some points first.")
            return
        features = np.array([[p.x, p.y] for p in points])
        labels = np.array(targetClassNumbers)
        clf = DecisionTreeClassifier(max_depth=maxDepth)
        clf.fit(features, labels)
        print(export_text(clf))
        classNumbers[:] = [int(c) for c in clf.predict(features)]
        clearPicture()
        drawPoints()
        paintButton.configure(state=NORMAL)

    def onReset():
        """Forget the classification but keep the points."""
        nonlocal clf  # without this, the assignment below would only bind a local
        classNumbers[:] = [undefinedClass]*len(classNumbers)
        clearPicture()
        clf = None
        paintButton.configure(state=DISABLED)
        drawPoints()

    def onClear():
        """Remove all points and the classifier."""
        nonlocal clf
        clearPicture()
        points.clear()
        targetClassNumbers.clear()
        classNumbers.clear()
        clf = None
        paintButton.configure(state=DISABLED)

    def onGenerateRandom():
        """Replace the points with random elliptical Gaussian clusters, one per class."""
        numClasses = _readClampedInt(numClassesText, numClasses_default, maxClasses)
        randomPoints = _readClampedInt(
            randomPointsText, randomPoints_default, maxRandomPoints
        )
        onClear()

        # Cluster centers lie in the middle half (x) / middle third (y) of the view.
        x0 = xMin()*1/2
        x1 = xMax()*1/2
        y0 = yMin()*1/3
        y1 = yMax()*1/3
        centers = []
        ex = []
        ey = []
        for _ in range(numClasses):
            centers.append(R2Point(random.uniform(x0, x1), random.uniform(y0, y1)))
            # Random orientation and random axis lengths for the cluster.
            angle = random.uniform(0., math.pi)
            e_x = R2Vector(math.cos(angle), math.sin(angle))
            e_y = e_x.normal()
            e_x *= random.uniform(1., (x1 - x0)/10.)
            e_y *= random.uniform(1., (x1 - x0)/10.)
            ex.append(e_x)
            ey.append(e_y)

        for _ in range(randomPoints):
            c = random.randrange(numClasses)
            xxx = random.normalvariate(0., 1.5)
            yyy = random.normalvariate(0., 1.5)
            points.append(centers[c] + ex[c]*xxx + ey[c]*yyy)
            classNumbers.append(undefinedClass)
            targetClassNumbers.append(c)
        drawPoints()

    def onPaint():
        """Shade the visible plane by the fitted tree's predicted class."""
        if clf is None:
            return
        dx = 0.2
        dy = 0.2
        xs = np.arange(xMin(), xMax(), dx)
        ys = np.arange(yMin(), yMax(), dy)
        if xs.size == 0 or ys.size == 0:
            return
        gridX, gridY = np.meshgrid(xs, ys)
        cells = np.column_stack((gridX.ravel(), gridY.ravel()))
        # One vectorized predict at the cell centers instead of one call per cell.
        predicted = clf.predict(cells + [dx/2, dy/2])

        clearPicture()  # avoid stacking layers on repeated Paint clicks
        for (x, y), c in zip(cells, predicted):
            x = float(x)
            y = float(y)
            color = alphaColor(int(c))
            lt = toCanvas(R2Point(x, y + dy))
            rb = toCanvas(R2Point(x + dx, y))
            rectIDs.append(drawArea.create_rectangle(
                lt[0], lt[1], rb[0], rb[1], fill=color, outline=color
            ))
        drawPoints()

    def onConfigure(e):
        """Redraw the grid and points after the canvas is resized."""
        drawArea.delete("all")
        # Everything was deleted above, so the stored IDs are stale.
        objectIDs.clear()
        rectIDs.clear()
        drawGrid()
        drawPoints()

    classifyButton.configure(command=onClassify)
    resetButton.configure(command=onReset)
    clearButton.configure(command=onClear)
    generateRandomButton.configure(command=onGenerateRandom)
    paintButton.configure(command=onPaint)

    # An empty event sequence is rejected by Tk; bind the real events.
    drawArea.bind("<ButtonRelease-1>", onMouseRelease)
    drawArea.bind("<ButtonRelease-2>", onMouseRelease)
    drawArea.bind("<ButtonRelease-3>", onMouseRelease)
    drawArea.bind("<Configure>", onConfigure)

    drawGrid()
    root.mainloop()


if __name__ == "__main__":
    main()