压缩近邻法的可视化Python实现

最近学习模式识别的时候看到压缩近邻法,然后对其进行了Python实现。这里分享一下代码。

压缩近邻法是为了降低近邻法的计算复杂度,其通过将数据集进行压缩,然后再进行近邻法的计算。这样可以大大降低计算复杂度。

本程序功能:

  1. 生成一个二维的数据集,以 $y=sinx$ 函数作非线性分割。
  2. 使用压缩近邻法构建比原数据集小的分类点集。并可视化了构建过程。
  3. 可以添加测试数据集并观察使用筛选后的分类点集在K-近邻法中对于不同K取值的分类效果。

构建过程如下:

首先设原数据集为$D$,算法尝试构建分类集$S$,使得$S$中的点在最近邻算法下可以对$D$进行分类。

$S$的构建过程如下:

  1. 从$D$中随机选取一个点加入$S$。
  2. 使用最近邻算法对$D$中的点进行分类,将所有分类错误的点归为错误集$E$。
  3. 从$E$中随机选取一个点加入$S$。(此步骤也可以使用其他策略,这里使用随机选择)
  4. 重复2-3步骤,直到$E$为空。这时$S$就是我们要找的分类集。

以下为代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, Slider
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
from tkinter import simpledialog
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 指定默认字体:解决plot不能显示中文问题
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题

class ClassifierVisualizer:
def __init__(self, n_points=100):
self.n_points = n_points
self.generate_data()
self.S = set()
self.setup_plot()
self.test_mode = False
self.k = 1

def generate_data(self):
self.X = np.random.uniform(-5, 5, (self.n_points, 2))
self.y = (self.X[:, 1] > np.sin(self.X[:, 0])).astype(int)

def setup_plot(self):
self.fig = plt.figure(figsize=(15, 6))
self.ax1 = self.fig.add_axes([0.05, 0.15, 0.4, 0.75])
self.ax2 = self.fig.add_axes([0.55, 0.15, 0.4, 0.75])
self.fig.suptitle('分类可视化')

# 原始数据图
self.scatter2 = self.ax2.scatter(self.X[:, 0], self.X[:, 1], c=self.y, cmap='coolwarm')
self.ax2.set_title('原始数据')
x = np.linspace(-6, 6, 100)
self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
self.ax2.set_xlim(-6, 6)
self.ax2.set_ylim(-6, 6)

# 当前分类图
self.scatter1 = self.ax1.scatter(self.X[:, 0], self.X[:, 1], c='gray')
self.s_scatter1 = self.ax1.scatter([], [], c='orange', s=100)
self.ax1.set_title('当前分类')
self.ax1.set_xlim(-6, 6)
self.ax1.set_ylim(-6, 6)
self.sine1, = self.ax1.plot(x, np.sin(x), 'g-', lw=2)

# 添加按钮
self.ax_button = plt.axes([0.81, 0.02, 0.1, 0.075])
self.button = Button(self.ax_button, '下一步')
self.button.on_clicked(self.step)

# 添加K值滑动条
self.ax_slider = plt.axes([0.55, 0.05, 0.3, 0.03])
self.k_slider = Slider(self.ax_slider, 'K值', 1, 10, valinit=1, valstep=1)
self.k_slider.on_changed(self.update_k)

def update_k(self, val):
self.k = int(val)
if self.test_mode:
self.update_test_classification()

def nearest_neighbors(self, point, k):
if not self.S:
return 0
distances = np.sum((self.X[list(self.S)] - point)**2, axis=1)
k = min(k, len(self.S))
nearest_indices = np.argsort(distances)[:k]
nearest_labels = [self.y[list(self.S)[i]] for i in nearest_indices]
return np.mean(nearest_labels) > 0.5

def step(self, event):
if not self.test_mode:
if not self.S:
self.S.add(np.random.choice(self.n_points))
else:
W = [i for i in range(self.n_points) if i not in self.S and
self.nearest_neighbors(self.X[i], self.k) != self.y[i]]
if W:
self.S.add(np.random.choice(W))
else:
self.prompt_test_mode()
return

predictions = np.array([self.nearest_neighbors(x, self.k) for x in self.X])
colors = np.where(predictions, 'red', 'blue')
self.scatter1.set_facecolors(colors)

s_points = self.X[list(self.S)]
self.s_scatter1.set_offsets(s_points)

self.ax1.set_title(f'当前分类 (S集合大小: {len(self.S)})')
else:
self.get_test_points()

self.fig.canvas.draw_idle()

def prompt_test_mode(self):
self.test_mode = True
self.ax2.clear()
self.ax2.set_title('测试分类')
self.ax2.set_xlim(-6, 6)
self.ax2.set_ylim(-6, 6)
x = np.linspace(-6, 6, 100)
self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
self.scatter2 = self.ax2.scatter([], [])
self.s_scatter2 = self.ax2.scatter([], [], c='orange', s=100)

# 隐藏"下一步"按钮
self.button.ax.set_visible(False)

self.fig.canvas.draw_idle()
self.get_test_points()

def get_test_points(self):
root = tk.Tk()
root.withdraw() # 隐藏主窗口

n_test = simpledialog.askinteger("输入", "请输入测试点的数量:", parent=root, minvalue=1, maxvalue=1000)
x_range = simpledialog.askstring("输入", "请输入x坐标范围 (min max):", parent=root)
y_range = simpledialog.askstring("输入", "请输入y坐标范围 (min max):", parent=root)

if n_test is None or x_range is None or y_range is None:
return

x_min, x_max = map(float, x_range.split())
y_min, y_max = map(float, y_range.split())

x_range = max(abs(x_min), abs(x_max), 6)
y_range = max(abs(y_min), abs(y_max), 6)

self.ax2.set_xlim(-x_range, x_range)
self.ax2.set_ylim(-y_range, y_range)

# 更新sin(x)函数的显示范围
x = np.linspace(-x_range, x_range, 1000)
self.sine2.set_data(x, np.sin(x))

self.X_test = np.random.uniform(low=[x_min, y_min], high=[x_max, y_max], size=(n_test, 2))
self.update_test_classification()

def update_test_classification(self):
# 显示原始数据
self.scatter2.set_offsets(self.X)
self.scatter2.set_facecolors(np.where(self.y, 'red', 'blue'))

# 显示S集合的点
s_points = self.X[list(self.S)]
self.s_scatter2.set_offsets(s_points)

# 显示并分类测试点
predictions = np.array([self.nearest_neighbors(x, self.k) for x in self.X_test])
colors = np.where(predictions, 'red', 'blue')
test_scatter = self.ax2.scatter(self.X_test[:, 0], self.X_test[:, 1], c=colors, marker='s')

self.ax2.set_title(f'测试分类 (K={self.k})')
self.fig.canvas.draw_idle()

def run(self):
plt.show()

# 创建并运行可视化器
visualizer = ClassifierVisualizer()
visualizer.run()