Skip to content

论文链接

约 4544 个字 预计阅读时间 23 分钟

Distilling the Knowledge in a Neural Network

论文的核心思想

  • 问题背景:在机器学习中,集成多个模型(如多个神经网络)通常可以提高性能,但部署这些模型时计算成本过高。因此,论文提出了一种方法,将复杂模型的知识“蒸馏”到一个更小的模型中,以便更容易部署。
  • 知识蒸馏的核心:通过将复杂模型的输出(称为“软目标”或“soft targets”)作为小模型的训练目标,小模型可以学习到复杂模型的泛化能力

因为模型最终的落实主要成本在于推理和部署,小模型的推理和部署成本较低,通过小模型学习大模型提取的复杂特征,大模型的知识能够传递给小模型。

本文澄清了一个误区

大模型学习的知识并不是它本身的参数,而是输入向量和输出向量中的转换,是这种映射,我们希望小模型学习的是这种转换关系。因此,比如在分类模型中,尽管正确分类的概率占据最大部分,其他错误分类的占比仍然有所区分,这是有意义,这反映了模型的思考,尽管都是错的,但就应该有一些更接近正确。
即更重要的是学习大模型的泛化能力。小模型可以通过对大模型的模仿来获得一部分的泛化能力。这通常比直接在训练数据上训练的效果要好。

软目标

软目标(Soft Targets)的概念

  • 软目标是指复杂模型(如大型模型或集成模型)输出的概率分布。
  • 与硬目标(hard targets,即真实的类别标签)不同,软目标包含了关于数据相似性的丰富信息。例如,一个图像被分类为“3”的概率可能很低,但它被分类为“7”的概率可能比被分类为“胡萝卜”的概率高得多。这种概率分布反映了复杂模型对数据的泛化规律

使用软目标进行知识蒸馏

  • 为了将复杂模型的泛化能力迁移到小模型中,可以使用复杂模型的输出概率分布(软目标)作为小模型的训练目标。
  • 这种迁移可以通过在相同的训练集或一个单独的“迁移集”(transfer set)上进行。

集成模型的软目标

  • 如果复杂模型是一个由多个简单模型组成的集成模型,可以将这些模型的预测分布取算术平均或几何平均作为软目标。
  • 这种方法可以捕捉集成模型的集体智慧,从而提高小模型的性能。

4. 软目标的优势

  • 高熵:软目标的概率分布通常具有较高的熵,这意味着它们提供了比硬目标更多的信息。例如,硬目标只能告诉模型“这个样本属于类别 A”,而软目标可以告诉模型“这个样本有 70% 的概率属于类别 A,20% 的概率属于类别 B,10% 的概率属于类别 C”。
  • 梯度稳定性:软目标可以减少梯度的方差,使得训练过程更加稳定。
  • 数据效率:由于软目标提供了更多的信息,小模型通常可以在比复杂模型更少的数据上进行训练,并且可以使用更高的学习率。

知识蒸馏的实现

我们知道 softmax 常作为分类模型的输出层,用于把输入向量 z 转换成概率分布 q

\[ p(i) = \frac{\exp(z_i)}{\sum^{n}_{j=1} \exp(z_j)} \]

温度参数(Temperature)

在知识蒸馏中,通常会引入一个温度参数 T,用于调整 softmax 输出的概率分布的“软硬”程度。温度参数 T 的引入使得 softmax 函数变为

\[ p(i) = \frac{\exp(z_i / T)}{\sum^{n}_{j=1} \exp(z_j / T)} \]
  • 高温(T 较大):输出的概率分布更加平滑,类别之间的差异被缩小。
  • 低温(T 较小):输出的概率分布更加尖锐,类别之间的差异被放大。

最简单的形式,用高温训练大模型,用同样的温度训练小模型,然后再降温小模型为 1 让输出结果更尖锐
一种改进方式,使用软目标的交叉熵和正确数据标签的交叉熵,加权平均,由于软目标的梯度变为了 \(1/T^2\), 因此要额外乘一个 \(T^2\)
论文提到正确标签的交叉熵较小时效果更好

logits 匹配

  • logits 是神经网络输出层的原始值
    知识蒸馏的过程中,迁移集中的每个样本对蒸馏模型的每个 logit \(z_i\) 都提供 v 一个交叉熵梯度 \(\frac{\partial C}{\partial z_i}\),其中大模型的 logits 为 \(v_i\)。经过温度 T 生成的软目标概率为 \(P_i\),交叉上梯度可以表示为:
\[ \frac{\partial C}{\partial z_i} = \frac{1}{T} \left( q_i - p_i \right)= \frac{1}{T} \left( \frac{e^{{z_i}/T}}{\sum_{j} e^{z_j / T}} - \frac{e^{{v_i}/T}}{\sum_{j} e^{v_j / T}} \right) \]

一文详解Softmax函数 - 知乎这里是对梯度的详细计算

如果温度的数量级远大于 logits 的数量级,即 \(T \gg |z_i|\) 的时候,利用泰勒展开 \(e^{z_i/T}\approx 1+z_i/T\), 得到

\[ \frac{\partial C}{\partial z_i}\approx \frac{1}{T} \left( \frac{1 + z_i / T}{N + \sum_{j} z_j / T}+ \frac{1 + v_i / T}{N + \sum_{j} v_j / T}\right) \]

如果是零均值的情况下 \(\sum_{j} z_j = 0\) \(\sum_{j} v_j = 0\) 得到进一步的近似公式

\[ \frac{\partial C}{\partial z_i} \approx \frac{1}{N T^2} (z_i - v_i) \]

这就是高温极限蒸馏下的损失函数了
* 当温度 T 较低时,蒸馏过程对 logits 的匹配变得不那么敏感,尤其是对那些比平均值更负的 logits。这是因为低温使得 softmax 函数的输出更加尖锐,负 logits 的影响被缩小。

为什么忽略负 logits 可能有益

  • 噪声问题:负 logits 可能包含大量噪声,因为它们在复杂模型的训练过程中几乎不受损失函数的约束。忽略这些 logits 可以减少噪声对蒸馏模型的影响。
  • 信息价值:尽管负 logits 可能包含一些关于复杂模型知识的信息,但它们的实际价值可能较低,尤其是在蒸馏模型较小时。

中间温度的优势

当蒸馏模型过小,无法完全捕捉复杂模型的所有知识时,使用中间温度通常能取得最佳效果。这表明忽略较大的负 logits 可以提高蒸馏效果,因为这样可以减少噪声的影响,同时保留有用的信息。

论文其余内容

MNIST 上的实验概括

  • 验证知识蒸馏的有效性:通过在 MNIST 数据集上进行实验,验证知识蒸馏是否能够将复杂模型的泛化能力迁移到小模型中。
  • 探索温度参数的影响:研究不同温度参数对蒸馏效果的影响。

实验设置

  • 数据集:MNIST 数据集,包含 60,000 个训练样本和 10,000 个测试样本。
  • 复杂模型:一个大型神经网络,包含两层隐藏层,每层有 1200 个神经元,使用 ReLU 激活函数,并通过 dropout 和权重约束进行正则化。
  • 小模型:一个较小的神经网络,包含两层隐藏层,每层有 800 个神经元,没有正则化。
  • 蒸馏模型:通过知识蒸馏将复杂模型的知识迁移到小模型中。

实验过程

  1. 训练复杂模型
    • 复杂模型在 MNIST 数据集上进行训练,使用 dropout 和权重约束进行正则化。
    • 复杂模型在测试集上的错误数为 67。
  2. 训练小模型
    • 小模型在相同的训练集上进行训练,没有使用正则化。
    • 小模型在测试集上的错误数为 146。
  3. 蒸馏过程
    • 使用复杂模型的输出作为小模型的软目标。
    • 蒸馏模型在训练时使用高温 softmax 生成软目标,并在训练完成后将温度设为 1。
    • 蒸馏模型在测试集上的错误数为 74,显著优于未经过蒸馏的小模型。

实验结果

  • 复杂模型:测试错误数为 67。
  • 小模型:测试错误数为 146。
  • 蒸馏模型:测试错误数为 74,表明知识蒸馏能够显著提高小模型的性能。

温度参数的影响

  • 高温蒸馏:当蒸馏模型的隐藏层单元数足够多(300 或更多)时,温度参数 T 在 8 以上时结果相似。
  • 低温蒸馏:当蒸馏模型的隐藏层单元数非常少(30)时,温度参数在 2.5 到 4 之间效果最佳。

特殊情况实验

  • 缺失类别:在蒸馏过程中,如果蒸馏模型从未见过某些类别(例如数字 3),它仍然能够通过软目标学习到这些类别的特征。
    • 蒸馏模型在测试集上的错误数为 206,其中 133 个错误是关于数字 3 的。
    • 通过调整偏置项,蒸馏模型在数字 3 上的错误数减少到 14,总错误数减少到 109。
  • 极端情况:如果蒸馏模型只见过训练集中的 7 和 8,其测试错误率为 47.3%,但通过调整偏置项,错误率降低到 13.2%。

总结

  • 知识蒸馏的有效性:通过知识蒸馏,小模型能够学习到复杂模型的泛化能力,显著提高性能。
  • 温度参数的影响:高温适合隐藏层单元数较多的模型,而低温适合隐藏层单元数较少的模型。
  • 鲁棒性:即使蒸馏模型从未见过某些类别,通过调整偏置项,仍然能够取得较好的性能

语音识别中的应用

  • 验证知识蒸馏在语音识别中的有效性:通过在语音识别任务中应用知识蒸馏,验证蒸馏模型是否能够捕捉复杂模型的泛化能力。
  • 比较蒸馏模型与基线模型和集成模型的性能:评估蒸馏模型在帧分类准确率和词错误率(WER)上的表现。

实验设置

  • 数据集:使用了大约 2000 小时的英语语音数据,包含约 7 亿个训练样本。
  • 基线模型:一个深度神经网络(DNN),包含 8 层隐藏层,每层有 2560 个 ReLU 神经元,最终 softmax 层有 14,000 个标签。总参数量约为 85M。
  • 集成模型:训练了 10 个独立的 DNN 模型,每个模型的结构和训练过程与基线模型相同,但初始化参数不同。
  • 蒸馏模型:通过知识蒸馏将集成模型的知识迁移到单个 DNN 模型中。

实验过程

  1. 训练基线模型
    • 基线模型在训练集上进行训练,使用分布式随机梯度下降。
    • 基线模型在开发集上的帧分类准确率为 58.9%,词错误率(WER)为 10.9%。
  2. 训练集成模型
    • 训练了 10 个独立的 DNN 模型,每个模型的结构和训练过程与基线模型相同,但初始化参数不同。
    • 集成模型通过平均预测分布来提高性能。
    • 集成模型在开发集上的帧分类准确率为 61.1%,词错误率为 10.7%。
  3. 蒸馏过程
    • 使用集成模型的输出作为蒸馏模型的软目标。
    • 蒸馏模型在训练时使用高温 softmax 生成软目标,并在训练完成后将温度设为 1。
    • 蒸馏模型在开发集上的帧分类准确率为 60.8%,词错误率为 10.7%。

实验结果

  • 基线模型
    • 帧分类准确率:58.9%
    • 词错误率(WER):10.9%
  • 集成模型
    • 帧分类准确率:61.1%
    • 词错误率(WER):10.7%
  • 蒸馏模型
    • 帧分类准确率:60.8%
    • 词错误率(WER):10.7%

结论

  • 知识蒸馏的有效性:蒸馏模型在帧分类准确率和词错误率上接近集成模型的性能,表明知识蒸馏能够有效地将集成模型的泛化能力迁移到单个模型中。
  • 计算效率:蒸馏模型的性能接近集成模型,但计算成本显著降低,更适合部署。

大规模数据集上训练专家模型集合

  • 大规模数据集的挑战:当数据集非常大且类别数量非常多时,训练大型神经网络的计算成本极高。
  • 专家模型的优势:专家模型可以专注于数据集中的一小部分类别,从而减少计算成本并提高模型性能。

专家模型的设计

  • 通用模型:首先训练一个通用模型(generalist model),该模型在所有数据上进行训练。
  • 专家模型:每个专家模型专注于一个特定的、容易混淆的类别子集。专家模型的 softmax 层可以缩小到只包含其专注的类别和一个“垃圾箱”类别(dustbin class),用于合并所有不相关的类别。

训练专家模型

  1. 初始化:专家模型的权重从通用模型的权重初始化。
  2. 数据选择:每个专家模型在包含其专注类别和随机采样的其他类别的数据上进行训练。
  3. 调整偏置:为了纠正训练集的偏差,专家模型的垃圾箱类别的 logit 增加了一个偏置项,该偏置项反映了其专注类别在训练集中的过采样比例。

类别分配

  • 混淆矩阵:通过分析通用模型的预测混淆矩阵,找到经常被混淆的类别群组。
  • 聚类算法:使用聚类算法(如在线 K-means)对混淆矩阵的列进行聚类,以确定每个专家模型的专注类别。

推理过程

  • 两步分类
    1. 通用模型预测:使用通用模型预测每个测试样本的最可能类别。
    2. 专家模型预测:根据通用模型的预测结果,选择相关的专家模型进行进一步预测,并通过最小化 KL 散度融合通用模型和专家模型的预测结果

实验结果

  • 数据集:JFT 数据集,包含 1 亿张图像和 15,000 个类别。
  • 基线模型:一个深度卷积神经网络,训练了约 6 个月。
  • 专家模型:训练了 61 个专家模型,每个专家模型专注于 300 个类别(加上垃圾箱类别)。
  • 性能提升
    • 基线模型的测试准确率为 43.1%。
    • 结合专家模型后,测试准确率提升到 45.9%。
    • 条件测试准确率(仅考虑专家类别)从 25.0% 提升到 26.1%。

结论

  • 专家模型的有效性:专家模型能够显著提高大规模数据集上的分类性能,同时减少训练时间。
  • 并行化优势:专家模型可以独立训练,易于并行化,适合大规模数据集的处理。

总结部分

知识蒸馏的有效性

  • MNIST 实验:知识蒸馏在 MNIST 数据集上表现出色,即使蒸馏模型从未见过某些类别,仍然能够通过软目标学习到这些类别的特征。
  • 语音识别实验:知识蒸馏能够将集成模型的性能迁移到单个模型中,蒸馏模型在帧分类准确率和词错误率上接近集成模型的性能,同时减少了计算成本。
  • 大规模数据集:在 JFT 数据集上,通过训练专家模型集合,知识蒸馏显著提高了模型的性能,同时减少了训练时间。

温度参数的影响

  • 高温蒸馏:高温适合隐藏层单元数较多的模型,能够有效地传递复杂模型的泛化能力。
  • 低温蒸馏:低温适合隐藏层单元数较少的模型,能够减少噪声的影响。
  • 中间温度:当中间温度被使用时,蒸馏效果最佳,特别是在蒸馏模型较小时,这表明忽略较大的负 logits 是有益的。

专家模型的优势

  • 并行化训练:专家模型可以独立训练,易于并行化,适合大规模数据集的处理。
  • 性能提升:专家模型通过专注于特定类别并结合通用模型的预测,能够在推理时提供更准确的结果。

未来研究方向

  • 进一步蒸馏专家模型:探索如何将专家模型的知识进一步蒸馏到单个模型中。
  • 更复杂的数据集:在更复杂的数据集上验证知识蒸馏的有效性。
  • 不同的模型架构:研究知识蒸馏在不同模型架构中的应用,如卷积神经网络(CNN)和循环神经网络(RNN)。

阅读思考

虽然这篇论文已经是 2015 年的距今已经有十年了,但是其中的很多思想也是很有用的,deepseek 采用的 MOE 框架并不是新鲜的内容,主流的模型架构也从 RNN,CNN 转而现在的 Transformer,但是模型蒸馏仍然是有意义的,即:
* 用更少的资源来实现相同的推理能力
* 相同的计算消耗来获取更强的推理性能
* 缩小模型从而更有利于部署和降低推理成本

参考资料

模型蒸馏简单上手 - Ecank的小屋
Distilling the Knowledge in a Neural Network
浅谈蒸馏 Distilling the Knowledge in a Neural Network 及公式推导 - 知乎
一文详解 Softmax 函数 - 知乎