cs231n K最近邻(KNN)总结

虽然很久之前就知道CS231N这门课程了,但是一直没学习,估计是自己太忙了吧 :^^:。最近深入之后,发现这门课真的很棒,授课视频+笔记+作业,完成这一整套,必然能极大地提升自己的能力。

在这个系列博客里,计划把每一小节的作业内容都进行一下总结,并以此作为自己近期的学习线索,然后根据里面涉及的知识点扩散开来,构成一个完整的学习系统,一点一滴地积累,作为自己今后快速复习巩固的宝贵资料。各个小节的内容不求非常全面,只希望将自己不太熟悉的部分记录下来,尤其是常用的API。

KNN原理

K最近邻(K Nearest Neighbor, KNN)是最简单、最基础的一个图像分类算法,基本过程是:

  • 训练阶段:简单地保存训练集的图片内容和对应的标签,不做其它任何处理。
  • 预测阶段:对于测试集中的每一张图片,分别计算其与训练集中的每张图片的距离(Distance),然后挑选K个距离最近的训练集图片,并将这K个图所对应的标签次数最多的图片(所对应的标签)作为预测结果。

计算图片的距离也非常简单,通常包含两种计算方法:L1距离L2距离

计算两张图片的L1距离可以可以拆分为:先计算两张图片中对应像素点的像素值之差;然后对这些像素差取绝对值;最后对所有这些像素点之差的绝对值进行累计求和。具体公式是:

L2距离与L1距离非常相似,具体计算过程可以拆分为:先计算两张图片中对应像素点的像素值之差;然后对这些像素差求平方值;然后对所有这些像素差的平方值进行累加求和;最后再对其进行开平方根。具体公式是:

KNN缺点

KNN最大的也是唯一的优点是简单易于理解,但是其缺点非常致命:

  • 非常浪费存储空间。因为分类器会记住所有训练集的数据和标签,而大型图片训练集都是GB为单位的。
  • 预测时计算量大/慢。因为每一张测试/预测图片都会与训练集中的所有图片计算距离。

抛开上面这两条不说(这两条主要是考虑实际的部署时的情况),其分类的效果也很差。因此,在实践中很少使用KNN进行图像分类。

pickle模块

最开始的几个作业所使用的数据集都是CIFAR-10,其格式是pickle。pickle是python提供的一个可以将python的数据结构与文件相互转换的模块。pickle最常用的两个API是dumpload

举例:使用pickle.dump()将python结构保存到文件中:

1
2
3
4
5
6
7
8
9
10
11
12
import pickle

# An arbitrary collection of objects supported by pickle.
data = {
'a': [1, 2.0, 3, 4+6j],
'b': ("character string", b"byte string"),
'c': {None, True, False}
}

with open('data.pickle', 'wb') as f:
# Pickle the 'data' dictionary using the highest protocol available.
pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

举例:使用pickle.load()将pickle文件中的内容恢复为python数据结构:

1
2
3
4
5
6
import pickle

with open('data.pickle', 'rb') as f:
# The protocol version used is detected automatically, so we do not
# have to specify it.
data = pickle.load(f)

相关链接:pickle — Python object serialization

训练图片可视化

需求:对于每一类图片,从训练集中随机取出7个,然后画出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):
# np.flatnonzero(a)的作用是返回a中非0元素的index,因此
# np.flatnonzero(y_train == y)的作用就是返回训练集中所有标签等于y的
# 图片的index,这一点比较巧妙
# 总结:我们可以用np.flatnonzero(a==x)返回数组中的特定元素的index
idxs = np.flatnonzero(y_train == y)
# np.random.choice(idxs, size)的作用是从idxs中随机选择并返回size个元素
idxs = np.random.choice(idxs, samples_per_class, replace=False)
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype('uint8'))
plt.axis('off')
if i == 0:
plt.title(cls)
plt.show()

相关链接:

计算测试精度

1
2
3
num_correct = np.sum(y_test_pred == y_test)
accuracy = float(num_correct) / num_test
print('Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))

数组拆分/组合

英文词汇