1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
| import numpy as np import matplotlib.pyplot as plt from matplotlib.widgets import Button, Slider from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg import tkinter as tk from tkinter import simpledialog from pylab import mpl mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei'] mpl.rcParams['axes.unicode_minus'] = False
class ClassifierVisualizer:
    """Interactive demo of condensed-nearest-neighbour style prototype selection.

    Left panel: the training points, coloured by the k-NN prediction made from
    the prototype set ``S`` (prototypes drawn as large orange markers).  Each
    click on the "下一步" button adds one randomly chosen misclassified training
    point to ``S``.  Once no point outside ``S`` is misclassified, the right
    panel switches to a test mode in which user-specified random test points
    are classified by the same prototypes.

    Right panel (before test mode): the ground-truth labels, where class 1 means
    the point lies above ``y = sin(x)``.
    """

    def __init__(self, n_points=100):
        """Generate ``n_points`` random training points and build the figure.

        All state is initialised *before* the widgets are wired up, so every
        callback can rely on these attributes existing.
        """
        self.n_points = n_points
        self.S = set()            # indices into self.X chosen as prototypes
        self.test_mode = False    # True once prototype selection has converged
        self.k = 1                # neighbour count, kept in sync with the slider
        self.X_test = None        # test points; None until the user supplies them
        self.test_scatter = None  # artist for the test points on ax2 (if any)
        self.generate_data()
        self.setup_plot()

    def generate_data(self):
        """Draw points uniformly from [-5, 5]^2; label 1 iff the point is above sin(x)."""
        self.X = np.random.uniform(-5, 5, (self.n_points, 2))
        self.y = (self.X[:, 1] > np.sin(self.X[:, 0])).astype(int)

    def setup_plot(self):
        """Create both panels plus the "next step" button and the K slider."""
        self.fig = plt.figure(figsize=(15, 6))
        self.ax1 = self.fig.add_axes([0.05, 0.15, 0.4, 0.75])
        self.ax2 = self.fig.add_axes([0.55, 0.15, 0.4, 0.75])
        self.fig.suptitle('分类可视化')

        x = np.linspace(-6, 6, 100)

        # Right panel: ground-truth labels and the true decision boundary.
        self.scatter2 = self.ax2.scatter(self.X[:, 0], self.X[:, 1], c=self.y, cmap='coolwarm')
        self.ax2.set_title('原始数据')
        self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)

        # Left panel: points are grey until the first prediction; prototypes in orange.
        self.scatter1 = self.ax1.scatter(self.X[:, 0], self.X[:, 1], c='gray')
        self.s_scatter1 = self.ax1.scatter([], [], c='orange', s=100)
        self.ax1.set_title('当前分类')
        self.ax1.set_xlim(-6, 6)
        self.ax1.set_ylim(-6, 6)
        self.sine1, = self.ax1.plot(x, np.sin(x), 'g-', lw=2)

        # Widget axes are attached to self.fig explicitly (plt.axes would use
        # whichever figure happens to be current).  The slider ends at x=0.77,
        # so neither it nor its value label overlaps the button at x>=0.81.
        self.ax_button = self.fig.add_axes([0.81, 0.02, 0.1, 0.075])
        self.button = Button(self.ax_button, '下一步')
        self.button.on_clicked(self.step)

        self.ax_slider = self.fig.add_axes([0.55, 0.05, 0.22, 0.03])
        self.k_slider = Slider(self.ax_slider, 'K值', 1, 10, valinit=1, valstep=1)
        self.k_slider.on_changed(self.update_k)

    def update_k(self, val):
        """Slider callback: store the new k and redraw every view that depends on it."""
        self.k = int(val)
        if self.S:
            # The left-panel colours depend on k too; refresh them so they
            # never show predictions made with a stale k.
            self._refresh_training_view()
        if self.test_mode:
            self.update_test_classification()  # also schedules a redraw
        else:
            self.fig.canvas.draw_idle()

    def nearest_neighbors(self, point, k):
        """Predict the label of ``point`` by a k-NN vote over the prototypes in S.

        Args:
            point: coordinates, broadcast against the rows of ``self.X``.
            k: number of neighbours; clamped to ``[1, len(S)]``.

        Returns:
            ``True`` for class 1, ``False`` for class 0.  An empty ``S``
            yields ``False``.  A tie (exactly half of the votes for class 1)
            also yields ``False`` because the comparison is strict.
        """
        if not self.S:
            return False
        prototypes = np.array(sorted(self.S))
        sq_dist = np.sum((self.X[prototypes] - point) ** 2, axis=1)
        k = max(1, min(int(k), len(prototypes)))
        # A stable sort makes equidistant prototypes break ties by index.
        nearest = prototypes[np.argsort(sq_dist, kind='stable')[:k]]
        return bool(np.mean(self.y[nearest]) > 0.5)

    def _predict(self, points):
        """Return a boolean array holding the k-NN prediction for each row of ``points``."""
        return np.array([self.nearest_neighbors(p, self.k) for p in points], dtype=bool)

    def _refresh_training_view(self):
        """Recolour the left panel with the current predictions and show the prototypes."""
        predictions = self._predict(self.X)
        self.scatter1.set_facecolors(np.where(predictions, 'red', 'blue'))
        self.s_scatter1.set_offsets(self.X[sorted(self.S)])
        self.ax1.set_title(f'当前分类 (S集合大小: {len(self.S)})')

    def step(self, event):
        """Button callback: add one prototype, or (in test mode) ask for new test points."""
        if self.test_mode:
            self.get_test_points()
            self.fig.canvas.draw_idle()
            return

        if not self.S:
            # Seed S with one random training point.
            self.S.add(int(np.random.choice(self.n_points)))
        else:
            misclassified = [
                i for i in range(self.n_points)
                if i not in self.S
                and self.nearest_neighbors(self.X[i], self.k) != self.y[i]
            ]
            if not misclassified:
                # Every training point outside S is now classified correctly.
                self.prompt_test_mode()
                return
            self.S.add(int(np.random.choice(misclassified)))

        self._refresh_training_view()
        self.fig.canvas.draw_idle()

    def prompt_test_mode(self):
        """Switch the right panel to test mode and ask the user for test points."""
        self.test_mode = True
        self.ax2.clear()
        # clear() detached every artist on ax2, including any previous test scatter.
        self.test_scatter = None
        self.X_test = None
        self.ax2.set_title('测试分类')
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)
        x = np.linspace(-6, 6, 100)
        self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
        self.scatter2 = self.ax2.scatter([], [])
        self.s_scatter2 = self.ax2.scatter([], [], c='orange', s=100)
        self.button.ax.set_visible(False)
        # Show the training points and prototypes right away, even if the
        # user then cancels the test-point dialogs.
        self.update_test_classification()
        self.get_test_points()

    @staticmethod
    def _parse_range(text):
        """Parse ``"min max"`` (separated by spaces and/or commas) into an ascending float pair.

        Raises:
            ValueError: if ``text`` does not hold exactly two finite numbers.
        """
        parts = text.replace('，', ' ').replace(',', ' ').split()
        if len(parts) != 2:
            raise ValueError(f"需要两个数值 (min max)，实际输入: {text!r}")
        low, high = float(parts[0]), float(parts[1])
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError(f"范围必须是有限数值: {text!r}")
        return (low, high) if low <= high else (high, low)

    def get_test_points(self):
        """Ask for a test-point count and x/y ranges, then sample and classify the points.

        Cancelling any dialog aborts without changing the current test set.
        A malformed range is reported in an error box and also aborts.
        """
        from tkinter import messagebox  # only needed here, to report bad input

        # A throw-away hidden root parents the dialogs.  It is destroyed
        # afterwards so repeated calls do not leak Tk interpreters.
        root = tk.Tk()
        root.withdraw()
        try:
            n_test = simpledialog.askinteger("输入", "请输入测试点的数量:", parent=root, minvalue=1, maxvalue=1000)
            if n_test is None:
                return
            x_text = simpledialog.askstring("输入", "请输入x坐标范围 (min max):", parent=root)
            if x_text is None:
                return
            y_text = simpledialog.askstring("输入", "请输入y坐标范围 (min max):", parent=root)
            if y_text is None:
                return
            try:
                x_min, x_max = self._parse_range(x_text)
                y_min, y_max = self._parse_range(y_text)
            except ValueError as err:
                messagebox.showerror("输入错误", str(err), parent=root)
                return
        finally:
            root.destroy()

        # Symmetric view that always contains at least the default [-6, 6] window.
        x_extent = max(abs(x_min), abs(x_max), 6)
        y_extent = max(abs(y_min), abs(y_max), 6)
        self.ax2.set_xlim(-x_extent, x_extent)
        self.ax2.set_ylim(-y_extent, y_extent)

        xs = np.linspace(-x_extent, x_extent, 1000)
        self.sine2.set_data(xs, np.sin(xs))

        self.X_test = np.random.uniform(low=[x_min, y_min], high=[x_max, y_max], size=(n_test, 2))
        self.update_test_classification()

    def update_test_classification(self):
        """Redraw ax2: true labels, prototypes, and (if present) test points coloured by prediction.

        The previous test-point artist is removed first, so repeated calls
        (e.g. slider moves) do not stack up scatter artists.
        """
        self.scatter2.set_offsets(self.X)
        self.scatter2.set_facecolors(np.where(self.y, 'red', 'blue'))
        self.s_scatter2.set_offsets(self.X[sorted(self.S)])

        if self.test_scatter is not None:
            self.test_scatter.remove()
            self.test_scatter = None

        if self.X_test is not None:
            colors = np.where(self._predict(self.X_test), 'red', 'blue')
            self.test_scatter = self.ax2.scatter(
                self.X_test[:, 0], self.X_test[:, 1], c=colors, marker='s'
            )

        self.ax2.set_title(f'测试分类 (K={self.k})')
        self.fig.canvas.draw_idle()

    def run(self):
        """Show the figure and enter the GUI event loop."""
        plt.show()
# Only launch the interactive GUI when run as a script, not on import.
if __name__ == "__main__":
    visualizer = ClassifierVisualizer()
    visualizer.run()
|