Distillation论文总结(1)Do Deep Nets Really Need to be Deep?

论文地址: NIPS2014; arxiv.

Do Deep Nets Really Need to be Deep?

目前,深度神经网络在许多领域都得到了出色的,但本文我们要说明浅层网络通过学习深层网络也能达到相似的结果。我们在TIMIT与CIFAR-10两个数据集上进行了测试。

引言

传说中如果我用一个深的网络去训练数据,会得到一个比浅网络更好的结果。这个“更好”是从哪里来的呢?

以前的研究(我觉得是很早的研究)表明一个深网络加上逐层预训练效果不错,还有很多研究表示很难将浅网络训出和深网络一样好的效果。本文就想要说明其实浅网络也有一样的表达能力,方法是先训一个好的深网络,再用一个浅网络去模仿它。虽然不能直接用数据训出相媲美的结果,但是可以用model compression的方法。如果一个有着相同参数个数的模型能有一样好的效果,不就可以说明深度网络并不需要深了么。

模仿学习训练浅层网络

模型压缩

Model Compression,用未标注的数据去训练小模型,使它产生与大模型相似的输出,从而有相近的准确率。现有的算法不足以利用源数据将小模型训到那种精度,所以必须要用大模型来指导教学。这也说明了模型的复杂度,与模型的结构、大小是两码事。(大模型不一定是复杂的模型)

用L2Loss回归来模仿学习

训练方法:deep model还是老方法:用softmax输出交叉熵训练;然而小模型要用softmax之前的logits。因为一来softmax之后会使原本比较小的数便得更小学不到细节;二来不同的logits可能会导致一样的概率,SoftMax内有信息损失。

除了用logits的L2Loss以外 ,文章还考虑了KL散度、pred的L2等多种训练方法,但是结果都没前面好;文章还试过对logits作规一化,效果会更好,但不是必须的。

引入线性层加速模仿学习

为了有一样的参数量,浅模型在一层当中就需要有更多的神经元,这会便训练变得很慢——一方面,大的Matrix当中会有很多相关的参数;另一方面,大矩阵的乘法很消耗资源。我们就引入了一个线性瓶颈层(其实就是一个矩阵的低秩分解),将一个矩阵分解成两个。从而既加速了收敛又减少了存储占用。以前这个只在最后一层使用,现在我们把它用到了输入层。

这个矩阵分解可以只在训练时使用,然后训练完之后可以将中间的非线性层去掉,又变成了同一层。

TIMIT Phoneme Recognition

目标检测: CIFAR-10

CIFAR-10数据集中有10类物品,是tiny images datasets中的一个标注子集,每张图片是32x32x3共3072个通道。我们对数据的处理包括减均值并除以每张图片的方差以做到global contrast normalization,然后再对图片做ZCA白化。

因为CIFAR-10数据太少了,所以我们用tiny images和cifar-10一起构成一个1.05M的数据集进行训练。因为非卷积网络之前被证实效果不好,所以我们就用尽可能浅的conv网络。事实证明小网络的acc是可以被蒸馏学习抬上去0.9%(怎么感觉有、少啊)

讨论

为什么蒸馏比直接训练原数据好

  1. 原数据可能有问题(???),T网络可能可以消除它;
  2. 直接学是有难度的,T网络可以理解成对target做了一个滤波;
  3. 不确定的logits相较于0-1标签有更多的信息;
  4. GT可能依赖于某些输入并没有给的特征,而T网络的输出是一定可以从输入得到的。

以上也可以看作是对overfit的预防措施。相较而言,浅网络更容易受到overfit的困扰,如果我们有更好的正则化手段的话,浅深网络之间的鸿沟就可能被消弥。Model Compression似乎是一种可靠的手段。

浅网络的表征能力

我们用两个不同大小的浅网络去学习深网络。在使用同样的Teacher的情况下自然大一些的网络学的更好一些,但是两个浅网络的学习曲线都是近似的45度斜线——即浅网络可以通过更好的模型或更多的数据与深网络的准确率同步增长。文章认为,没有任何证据表明浅网络的表征能力存在上限,而限制其acc的主要是我们的学习算法和正则项。

并行分布处理VS深度顺序处理

未来工作

结论

文章认为,当前深层网络的优势主要是在于其与目下的学习算法较好的匹配性,如果以后能设计出一种适用于浅层网络的学习算法就好了。一定数量的参数,网络深一点可以更容易学一点,但不是必须的。

猜你喜欢

转载自blog.csdn.net/volga_chen/article/details/84970178