knowledge_distillation
Table of Contents
知识蒸馏(Knowledge Distillation)是一种机器学习技术,它通过将大型、复杂的模型(称为教师 模型,Teacher Model)的知识“蒸馏”到小型、简洁的模型(称为学生模型,Student Model)中,从 而实现模型压缩和加速,同时尽可能保持原始模型的性能。这一技术使得模型可以在资源有限的设备 上高效运行,如手机或嵌入式设备。
The method works by incorporating an additional loss into the traditional cross entropy loss, which is based on the softmax output of the teacher network. The assumption is that the output activations of a properly trained teacher network carry additional information that can be leveraged by a student network during training.
#
基本原理
知识转移:核心思想是通过让学生模型模仿教师模型的输出行为,不仅包括硬分类标签,还有软概率 分布(softmax概率),这样可以传递更多关于数据分布的信息。软标签相比硬标签含有更多关于数 据不确定性及类间关系的信息,有助于学生模型学习更细腻的决策边界。
特征蒸馏:除了输出层的知识外,泛化知识蒸馏还可以涉及中间层特征的学习,即学生模型试图学习 教师模型的高层特征表示。这有助于提升学生模型的泛化能力,因为它学会了如何从输入数据中提取 更有用的特征。
关系蒸馏:强调保持教师和学生模型对于输入样本间关系的理解一致性。这意味着学生模型不仅要学 会单个样本的处理,还要理解样本之间的相对关系,这对于一些需要理解复杂上下文的任务尤为重要。
知识蒸馏过程:
训练教师模型:首先,使用大量数据和计算资源训练一个高性能的深度神经网络(教师模 型)。这个模型可能包含数百万甚至数十亿个参数,但它在分类任务上的表现非常出色。
生成软标签:教师模型在对输入数据进行预测时,不仅仅给出最终的分类结果,还会给出各 类别的概率分布(通常通过softmax层获得)。这些概率分布被称为“软标签”,它们包含了额外的 信息,比如类别的不确定性。
训练学生模型:接下来,使用教师模型的软标签和实际的硬标签(即数据的真实类别)来训 练学生模型。学生模型的架构设计得更简单,参数量远小于教师模型。训练过程中,学生模型不 仅要学习模仿硬标签,还要通过损失函数(如KL散度或交叉熵)尽量接近教师模型的软标签输出。
温度参数调整:在生成软标签时,有时会引入一个“温度”参数来调整概率分布的平滑程度。 高温可以使软标签更加平滑,促进学生模型学习到教师模型的决策边界;低温则使得软标签接近 硬标签,但可能会丢失教师模型的一些细微决策信息。
#
知识蒸馏的三种代码实现:
- 知识转移:基于 softmax output 软标签
# ...
optimizer.zero_grad()
# Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
with torch.no_grad():
teacher_logits = teacher(inputs)
# Forward pass with the student model
student_logits = student(inputs)
#Soften the student logits by applying softmax first and log() second
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
# Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)
# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)
# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
# ...
- 特征蒸馏:基于 hidden state 余弦相似度
# ...
optimizer.zero_grad()
# Forward pass with the teacher model and keep only the hidden representation
with torch.no_grad():
_, teacher_hidden_representation = teacher(inputs)
# Forward pass with the student model
student_logits, student_hidden_representation = student(inputs)
# Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))
# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)
# Weighted sum of the two losses
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
# ...
- 特征蒸馏:基于 Intermediate regressor MSE损失
# ...
optimizer.zero_grad()
# Again ignore teacher logits
with torch.no_grad():
_, teacher_feature_map = teacher(inputs)
# Forward pass with the student model
student_logits, regressor_feature_map = student(inputs)
# Calculate the loss
hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)
# Calculate the true label loss
label_loss = ce_loss(student_logits, labels)
# Weighted sum of the two losses
loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
# ...
具体参考:https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html?highlight=distill