解锁KNN算法:机器学习的“近邻之道”
一、KNN 算法核心原理
KNN(K - 近邻算法)是机器学习中最简单直观的分类与回归方法,核心思想可概括为 “物以类聚”:对于未知样本,通过计算它与已知样本的距离,找出最近的 K 个样本(近邻),再根据这 K 个样本的类别(或数值)来判断未知样本的类别(或预测数值)。
生活中也有类似逻辑:若想判断一种陌生水果是否为苹果,可观察它与周边已知水果(苹果、梨、橘子)的相似度(大小、颜色、形状),找最像的几个,若多数是苹果,就可判断它是苹果。
二、关键要素解析
(一)K 值的选择
l 定义:K 是近邻的数量(正整数),是 KNN 算法唯一的超参数。
ll 影响:
¢ K 过小:模型易受噪声影响,可能过拟合(把局部特征当普遍规律)。
¢ K 过大:近邻中混入无关样本,可能欠拟合(模型过于笼统)。
选优方法:交叉验证(如将数据分成 5 份,用 4 份训练、1 份验证,测试不同 K 值的准确率,选表现最好的 K)。
(二)距离度量方法
距离是判断 “相似度” 的核心,常用以下 3 种:
1. 欧氏距离:最常用,适用于连续特征(如身高、体重),
2. 曼哈顿距离:适用于城市网格类场景(如计算两点间的街道距离),
3. 闵可夫斯基距离:前两种的推广,当参数 p=2 时为欧氏距离,p=1 时为曼哈顿距离。
(三)决策规则
l 分类问题:多数表决法(K 个近邻中占比最高的类别即为预测结果)。
ll 回归问题:平均值法(K 个近邻的数值平均值作为预测结果)。
三、实现步骤与代码
(一)基本流程
1. 准备数据:清洗(处理缺失值、异常值)、标准化(使各特征量级一致,避免距离被某一特征主导)。
2. 划分数据集:分为训练集(已知样本)和测试集(待预测样本)。
3. 对测试集样本,计算与所有训练集样本的距离。
4. 按距离排序,选前 K 个近邻。
5. 用决策规则预测结果,评估准确率(分类)或误差(回归)。
(二)代码实现(Python + Scikit-learn)
以鸢尾花分类为例:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
# import the iris data
iris = datasets.load_iris()
# Only use the first two features: sepal length, sepal width
X = iris.data[:, :2]
# Vector of labels
y = iris.target
# generate mesh
h = .02 # step size in the mesh
x1_min, x1_max = X[:, 0].min() - 0.2, X[:, 0].max() + 0.2
x2_min, x2_max = X[:, 1].min() - 0.2, X[:, 1].max() + 0.2
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, h),np.arange(x2_min, x2_max, h))
# Create color maps
rgb = [[255, 238, 255], # red
[219, 238, 244], # blue
[228, 228, 228]] # black
rgb = np.array(rgb)/255.
cmap_light = ListedColormap(rgb)
cmap_bold = [[255, 51, 0], [0, 153, 255], [138,138,138]]
cmap_bold = [np.array(c)/255. for c in cmap_bold]
# 近邻数量
k_neighbors = 4
# kNN分类器
clf = neighbors.KNeighborsClassifier(k_neighbors)
# 拟合数据
clf.fit(X, y)
# 查询点
q = np.c_[xx1.ravel(), xx2.ravel()];
# 预测
y_predict = clf.predict(q)
# 规整形状
y_predict = y_predict.reshape(xx1.shape)
# visualization
fig, ax = plt.subplots()
# plot decision regions
plt.contourf(xx1, xx2, y_predict, cmap=cmap_light)
# plot decision boundaries
plt.contour(xx1, xx2, y_predict, levels=[0,1,2], colors=[np.array([0, 68, 138])/255.],linewidths=1)
# Plot data points
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=iris.target_names[y],
palette=cmap_bold, alpha=1.0,
linewidth=1, edgecolor=[1,1,1])
# Figure decorations
plt.xlim(xx1.min(), xx1.max())
plt.ylim(xx2.min(), xx2.max())
plt.title("k-NN classifier (k = %i, weights = 'uniform')"
% (k_neighbors))
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
ax.grid(linestyle='--', linewidth=0.25, color=[0.5,0.5,0.5])
plt.tight_layout()
plt.axis('equal')
plt.show()四、KNN 在回归问题中的应用
(一)核心逻辑
回归任务中,KNN 通过计算未知样本的 K 个近邻的数值平均值(或加权平均值)进行预测。例如:预测某房屋价格时,找出与该房屋特征(面积、地段、房龄)最接近的 K 套已售房屋,用它们的均价作为预测结果。
(二)应用场景
l 房价预测(基于面积、地段等特征)
ll 气温预测(基于前几日温度、湿度等)
lll 产品销量预测(基于相似区域 / 时段的销量)
(三)代码实现(以加州房价预测为例)
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
# 设置matplotlib字体
plt.rcParams["font.family"] = ["Arial Unicode MS", "Heiti TC", "sans-serif"]
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
# 1. 数据准备
california = fetch_california_housing()
X, y = california.data, california.target
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 划分数据集(取部分样本用于可视化)
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=0.2, random_state=42
)
X_test_sample = X_test[:50]
y_test_sample = y_test[:50]
# 2. 模型训练与预测
knn_reg = KNeighborsRegressor(n_neighbors=5, weights='distance')
knn_reg.fit(X_train, y_train)
y_pred_sample = knn_reg.predict(X_test_sample)
# 3. 计算误差
mse = mean_squared_error(y_test, knn_reg.predict(X_test))
print(f"回归均方误差(MSE):{mse:.2f}")
# 4. 绘制预测对比图(标题设置为英文)
plt.figure(figsize=(12, 5))
# 图1:实际值 vs 预测值散点图
plt.subplot(1, 2, 1)
plt.scatter(y_test_sample, y_pred_sample, alpha=0.6, color='b', label='预测值')
plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', label='理想预测线')
plt.xlabel('实际房价(10万美元)')
plt.ylabel('预测房价(10万美元)')
plt.title('Actual vs Predicted Values') # 英文标题
plt.legend()
# 图2:前20个样本的预测细节
plt.subplot(1, 2, 2)
indices = np.arange(20)
plt.plot(indices, y_test_sample[:20], 'bo-', label='实际值')
plt.plot(indices, y_pred_sample[:20], 'ro--', label='预测值')
plt.xlabel('样本索引')
plt.ylabel('房价(10万美元)')
plt.title('Prediction Details (First 20 Samples)') # 英文标题
plt.legend()
plt.tight_layout()
plt.show()五、优缺点与适用场景
l 优点:简单易懂,无需训练过程,对非线性数据效果好。
ll 缺点:计算量大(预测时需与所有训练样本算距离),对高维数据(如文本、图像)效果差(维度灾难)。
lll 适用场景:小规模数据集、分类问题(如客户分群、疾病诊断)、简单回归(如房价初步预测)。
通过调整 K 值和距离度量,KNN 可在许多基础任务中发挥作用,是入门机器学习的理想起点。