
import tkinter as tk
from tkinter import messagebox, simpledialog

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.widgets import Button, Slider
from pylab import mpl

# Select a CJK-capable sans-serif font (presumably so the Chinese titles and
# labels used below render) and disable the Unicode minus glyph for tick labels.
mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei']
mpl.rcParams['axes.unicode_minus'] = False
class ClassifierVisualizer:
    """Interactive matplotlib demo that grows a k-NN reference set step by step.

    Random 2-D points are labelled by whether they lie above ``y = sin(x)``.
    Each click on the button adds one currently misclassified point to the
    reference set ``S``; when no point outside ``S`` is misclassified the
    visualiser switches to a test mode in which the user supplies a number of
    random test points and a coordinate range, and the K slider re-classifies
    those points.
    """

    def __init__(self, n_points=100):
        """Create the data set and build the figure.

        Args:
            n_points: Number of random training points to generate.
        """
        self.n_points = n_points
        # All state read by the widget callbacks is initialised *before*
        # setup_plot() registers those callbacks.
        self.S = set()            # indices into X forming the reference set
        self.test_mode = False    # True once no misclassified point remains
        self.k = 1                # neighbour count, driven by the slider
        self.X_test = None        # test points; None until the user enters them
        self.test_scatter = None  # artist showing the test points on ax2
        self.generate_data()
        self.setup_plot()

    def generate_data(self):
        """Draw ``n_points`` uniform points in [-5, 5)^2 and label them.

        The label is 1 where the point's y coordinate is greater than
        ``sin(x)`` and 0 otherwise.
        """
        self.X = np.random.uniform(-5, 5, (self.n_points, 2))
        self.y = (self.X[:, 1] > np.sin(self.X[:, 0])).astype(int)

    def setup_plot(self):
        """Build the figure: two data axes, the step button and the K slider."""
        self.fig = plt.figure(figsize=(15, 6))
        self.ax1 = self.fig.add_axes([0.05, 0.15, 0.4, 0.75])
        self.ax2 = self.fig.add_axes([0.55, 0.15, 0.4, 0.75])
        self.fig.suptitle('分类可视化')

        curve_x = np.linspace(-6, 6, 100)

        # Right axes: the raw data coloured by its true label.
        self.scatter2 = self.ax2.scatter(
            self.X[:, 0], self.X[:, 1], c=self.y, cmap='coolwarm')
        self.ax2.set_title('原始数据')
        self.sine2, = self.ax2.plot(curve_x, np.sin(curve_x), 'g-', lw=2)
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)

        # Left axes: the current classification; members of S are drawn
        # as larger orange markers.
        self.scatter1 = self.ax1.scatter(self.X[:, 0], self.X[:, 1], c='gray')
        self.s_scatter1 = self.ax1.scatter([], [], c='orange', s=100)
        self.ax1.set_title('当前分类')
        self.ax1.set_xlim(-6, 6)
        self.ax1.set_ylim(-6, 6)
        self.sine1, = self.ax1.plot(curve_x, np.sin(curve_x), 'g-', lw=2)

        self.ax_button = self.fig.add_axes([0.81, 0.02, 0.1, 0.075])
        self.button = Button(self.ax_button, '下一步')
        self.button.on_clicked(self.step)

        self.ax_slider = self.fig.add_axes([0.55, 0.05, 0.3, 0.03])
        self.k_slider = Slider(
            self.ax_slider, 'K值', 1, 10, valinit=1, valstep=1)
        self.k_slider.on_changed(self.update_k)

    def update_k(self, val):
        """Slider callback: store the new K and, in test mode, re-classify."""
        self.k = int(val)
        if self.test_mode:
            self.update_test_classification()

    def nearest_neighbors(self, point, k):
        """Classify ``point`` by majority vote of its nearest members of ``S``.

        Args:
            point: Length-2 array with the coordinates to classify.
            k: Number of neighbours to consult; clamped to ``len(S)``.

        Returns:
            ``True`` when the mean label of the consulted neighbours is
            strictly greater than 0.5, otherwise ``False``. A tie, or an
            empty ``S``, yields ``False``.
        """
        if not self.S:
            return False
        # Materialise the set once so positions in `distances` map back to
        # the same indices used to look up the labels.
        members = list(self.S)
        distances = np.sum((self.X[members] - point) ** 2, axis=1)
        k = min(k, len(members))
        nearest = np.argsort(distances)[:k]
        labels = self.y[members][nearest]
        return bool(np.mean(labels) > 0.5)

    def _predict(self, points):
        """Return a boolean array with the k-NN prediction for each point."""
        return np.array(
            [self.nearest_neighbors(p, self.k) for p in points], dtype=bool)

    def _refresh_training_view(self):
        """Recolour the left axes from the current ``S`` and update its title."""
        predictions = self._predict(self.X)
        self.scatter1.set_facecolors(np.where(predictions, 'red', 'blue'))
        self.s_scatter1.set_offsets(self.X[list(self.S)])
        self.ax1.set_title(f'当前分类 (S集合大小: {len(self.S)})')

    def step(self, event):
        """Button callback: add one point to ``S`` or move on to test mode.

        Outside test mode, the first click seeds ``S`` with a random index.
        Later clicks add a random point that is not in ``S`` and is currently
        misclassified; when no such point exists, test mode is entered.
        In test mode the call asks for a new set of test points.

        Args:
            event: The matplotlib event passed by the button (unused).
        """
        if self.test_mode:
            self.get_test_points()
            self.fig.canvas.draw_idle()
            return

        if not self.S:
            self.S.add(int(np.random.choice(self.n_points)))
        else:
            misclassified = [
                i for i in range(self.n_points)
                if i not in self.S
                and self.nearest_neighbors(self.X[i], self.k) != self.y[i]
            ]
            if not misclassified:
                self.prompt_test_mode()
                return
            self.S.add(int(np.random.choice(misclassified)))

        self._refresh_training_view()
        self.fig.canvas.draw_idle()

    def prompt_test_mode(self):
        """Switch to test mode: reset the right axes and ask for test points."""
        self.test_mode = True
        self.ax2.clear()
        # clear() dropped every artist on ax2, including any test scatter.
        self.test_scatter = None
        self.ax2.set_title('测试分类')
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)
        curve_x = np.linspace(-6, 6, 100)
        self.sine2, = self.ax2.plot(curve_x, np.sin(curve_x), 'g-', lw=2)
        self.scatter2 = self.ax2.scatter([], [])
        self.s_scatter2 = self.ax2.scatter([], [], c='orange', s=100)
        # NOTE(review): with the button hidden, step() presumably can no
        # longer be triggered from the UI in test mode - confirm intent.
        self.button.ax.set_visible(False)
        self.fig.canvas.draw_idle()
        self.get_test_points()

    @staticmethod
    def _parse_range(text):
        """Parse ``"min max"`` into an ordered pair of finite floats.

        Args:
            text: User input expected to hold two whitespace-separated numbers.

        Returns:
            ``(low, high)`` with ``low <= high``; the bounds are swapped when
            entered in reverse order.

        Raises:
            ValueError: If ``text`` does not hold exactly two finite numbers.
        """
        parts = text.split()
        if len(parts) != 2:
            raise ValueError(f"expected two numbers, got {len(parts)}")
        low, high = (float(part) for part in parts)
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError("range bounds must be finite")
        return (low, high) if low <= high else (high, low)

    def get_test_points(self):
        """Ask for the test-point count and ranges, then classify the points.

        Cancelling any dialog, or entering a malformed range, leaves the
        current plot and ``X_test`` unchanged.
        """
        root = tk.Tk()
        root.withdraw()
        try:
            n_test = simpledialog.askinteger(
                "输入", "请输入测试点的数量:", parent=root,
                minvalue=1, maxvalue=1000)
            if n_test is None:
                return
            x_text = simpledialog.askstring(
                "输入", "请输入x坐标范围 (min max):", parent=root)
            if x_text is None:
                return
            y_text = simpledialog.askstring(
                "输入", "请输入y坐标范围 (min max):", parent=root)
            if y_text is None:
                return
            try:
                x_min, x_max = self._parse_range(x_text)
                y_min, y_max = self._parse_range(y_text)
            except ValueError:
                messagebox.showerror(
                    "输入错误", "范围格式无效，请输入两个数字 (min max)。",
                    parent=root)
                return
        finally:
            # The hidden root only exists to parent the dialogs; destroy it
            # so repeated calls do not accumulate Tk roots.
            root.destroy()

        # Symmetric view limits that contain the requested range and never
        # shrink below the default [-6, 6] window.
        x_extent = max(abs(x_min), abs(x_max), 6)
        y_extent = max(abs(y_min), abs(y_max), 6)
        self.ax2.set_xlim(-x_extent, x_extent)
        self.ax2.set_ylim(-y_extent, y_extent)

        curve_x = np.linspace(-x_extent, x_extent, 1000)
        self.sine2.set_data(curve_x, np.sin(curve_x))

        self.X_test = np.random.uniform(
            low=[x_min, y_min], high=[x_max, y_max], size=(n_test, 2))
        self.update_test_classification()

    def update_test_classification(self):
        """Redraw the right axes: training data, ``S`` and classified tests."""
        self.scatter2.set_offsets(self.X)
        self.scatter2.set_facecolors(np.where(self.y, 'red', 'blue'))
        self.s_scatter2.set_offsets(self.X[list(self.S)])

        # Drop the previous test scatter so repeated calls (e.g. every slider
        # move) replace the markers rather than stacking new collections.
        if self.test_scatter is not None:
            self.test_scatter.remove()
            self.test_scatter = None

        # X_test stays None when the user cancelled the input dialogs.
        if self.X_test is not None:
            colors = np.where(self._predict(self.X_test), 'red', 'blue')
            self.test_scatter = self.ax2.scatter(
                self.X_test[:, 0], self.X_test[:, 1], c=colors, marker='s')

        self.ax2.set_title(f'测试分类 (K={self.k})')
        self.fig.canvas.draw_idle()

    def run(self):
        """Show the figure and enter the matplotlib event loop."""
        plt.show()
if __name__ == "__main__":
    # Launch the interactive window only when executed as a script, so that
    # importing this module does not open a GUI as a side effect.
    visualizer = ClassifierVisualizer()
    visualizer.run()