终于把神经网络中的知识蒸馏搞懂了!!!

大家好,我是小寒今天给大家分享神经网络中的一个关键知识点,知识蒸馏知识蒸馏是一种模型压缩方法,用于将大型神经网络(教师模型)中的知识转移到较小的神经网络(学生模型)中。 这一技术能够在保持或接近原始模型性能的情况下,显著减小模型的体积,从而提升推理效率。 知识蒸馏在很多场景中非常有用,尤其是在计算资源有限或需要部署到边缘设备的应用中。

大家好,我是小寒

今天给大家分享神经网络中的一个关键知识点,知识蒸馏

知识蒸馏是一种模型压缩方法,用于将大型神经网络(教师模型)中的知识转移到较小的神经网络(学生模型)中。

这一技术能够在保持或接近原始模型性能的情况下,显著减小模型的体积,从而提升推理效率。

知识蒸馏在很多场景中非常有用,尤其是在计算资源有限或需要部署到边缘设备的应用中。

知识蒸馏的背景和动机

在深度学习中,尤其是在计算机视觉和自然语言处理等任务中,深度神经网络(DNN)常常有非常庞大的参数量。尽管这些大型模型(如BERT、ResNet等)能够取得非常好的性能,但它们也面临着存储、计算和延迟等挑战。为了克服这一问题,知识蒸馏被提出作为一种方法,通过训练较小的学生模型来模拟大型教师模型的行为。

知识蒸馏的基本概念

  1. 教师模型(Teacher Model)通常是一个预训练的、复杂的深度神经网络,具有较高的精度,但计算和存储开销较大。
  2. 学生模型(Student Model)学生模型相对简单,参数较少,推理速度更快,目标是通过知识蒸馏从教师模型中获取知识,提升其性能。
  3. 软标签(Soft Labels)软标签是教师模型输出的概率分布,而非简单的类别标签。教师模型通常使用 softmax 层生成的概率分布作为软标签,这些分布包含了类别间的相对关系。
  4. 温度(Temperature)在蒸馏过程中,通常使用一个温度参数来调节教师模型输出的概率分布的“平滑程度”。较高的温度会使得输出分布更加平滑,从而让学生模型学习到更多的类间关系。

知识蒸馏的流程

  1. 训练教师模型首先训练一个大型的、高性能的教师模型。该模型在给定的训练数据集上表现非常好,具有高精度,但计算开销较大。
  2. 生成软标签用教师模型对训练数据进行预测,得到每个样本的类别概率分布(即软标签)。可以使用 softmax 函数将教师模型的原始输出转换为概率分布,并通过调节温度参数来控制这些概率分布的平滑度。
  3. 训练学生模型使用教师模型生成的软标签来训练一个较小的学生模型。学生模型的目标是模仿教师模型的输出,从而尽可能地学习到教师模型的知识。训练过程中,学生模型同时会使用真实标签(硬标签)和软标签进行监督学习。
  4. 损失函数设计知识蒸馏的损失函数通常由两个部分组成。传统的监督损失:计算学生模型输出与真实标签之间的交叉熵。蒸馏损失:计算学生模型输出与教师模型输出之间的差异,通常使用 KL 散度度量两个概率分布之间的差异。因此,知识蒸馏的损失函数通常是这两个损失的加权和:

终于把神经网络中的知识蒸馏搞懂了!!!

温度的作用

在知识蒸馏中,温度 T 控制了教师模型输出的“软标签”分布的平滑程度。

较高的温度会使得输出的概率分布更加平滑,减少类间的差异,使学生模型能够学习到更多的类之间的相似性。

  • 在高温度下,教师模型的输出概率分布更加平滑,类之间的概率差异较小。
  • 在低温度下,输出概率分布变得更加尖锐,教师模型的预测结果接近于硬标签。

通过调节温度,可以让学生模型更好地学习到教师模型的知识。

知识蒸馏的优点

  1. 模型压缩通过蒸馏,学生模型通常比教师模型更小,参数数量更少,可以大幅度降低计算和存储开销。
  2. 提高推理速度由于学生模型体积较小,推理速度较快,适合部署到移动设备或资源有限的边缘设备上。

案例分享

以下是一个基于 PyTorch 实现的简单示例代码,展示了如何进行神经网络中的知识蒸馏。

首先,定义教师模型和学生模型。

复制
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F

# 教师模型(较大网络)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(7*7*64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 7*7*64)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 学生模型(较小网络)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(7*7*32, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 7*7*32)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

接下来,定义蒸馏损失函数。

复制
def distillation_loss(y_student, y_teacher, T=2.0, alpha=0.7):
    # 计算软标签的交叉熵损失
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(y_student / T, dim=1),
        F.softmax(y_teacher / T, dim=1)
    )

    # 计算真实标签的交叉熵损失
    hard_loss = F.cross_entropy(y_student, torch.argmax(y_teacher, dim=1))

    # 综合蒸馏损失
    return alpha * soft_loss + (1 - alpha) * hard_loss

接下来定义一个训练函数,其中教师模型先训练好,然后使用蒸馏损失训练学生模型。

复制
def train(model, device, train_loader, optimizer, epoch, teacher_model=None, T=2.0, alpha=0.7):
    model.train()
    running_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # 教师模型和学生模型的输出
        output = model(data)
        with torch.no_grad():  # 教师模型在蒸馏时不更新参数
            teacher_output = teacher_model(data)

        # 计算蒸馏损失
        loss = distillation_loss(output, teacher_output, T, alpha)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Train Epoch: {epoch} \tLoss: {running_loss / len(train_loader):.6f}")
    
    
batch_size = 64
epochs = 10
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 初始化教师模型和学生模型
teacher_model = TeacherModel().to(device)
student_model = StudentModel().to(device)

# 教师模型训练(简单训练)
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=lr)
teacher_model.train()
for epoch in range(1, epochs + 1):
    train_teacher(teacher_model, device, train_loader, optimizer_teacher, epoch)

# 学生模型训练(蒸馏)
optimizer_student = optim.Adam(student_model.parameters(), lr=lr)
student_model.train()
for epoch in range(1, epochs + 1):
     train(student_model, device, train_loader, optimizer_student, epoch, teacher_mod

相关资讯

只需一行代码,即可轻松驱散基因组分析中DNN产生的数字噪音

编辑 | 白菜叶人工智能已经进入我们的日常生活。它可以是 ChatGPT,也可以是人工智能生成的比萨饼和啤酒广告。虽然我们不能相信人工智能是完美的,但事实证明,有些时候我们根本无法相信人工智能。冷泉港实验室(CSHL)西蒙斯定量生物学中心的助理教授 Peter Koo 发现,在分析 DNA 时,使用流行的计算工具来解释 AI 预测的科学家会收集到太多的「噪音」或额外信息。他找到了解决这个问题的方法。他的团队确定了一个以前被忽视的归因噪声源,该噪声源源于深度神经网络(DNN)如何处理单热编码 DNA。研究人员证明这种

NeurIPS 2023 | 「解释一切」图像概念解释器来了,港科大团队出品

Segment Anything Model(SAM)首次被应用到了基于增强概念的可解释 AI 上。你是否好奇当一个黑盒深度神经网络 (DNN) 预测下图的时候,图中哪个部分对于输出预测为「击球手」的帮助最大?香港科技大学团队最新的 NeurIPS2023 研究成果给出了他们的答案。论文:: Meta 的分割一切 (SAM) 后,港科大团队首次借助 SAM 实现了人类可解读的任意 DNN 模型图像概念解释器:Explain Any Concept (EAC)。你往往会看到传统的 DNN 图像概念解释器会给出这样的解

模型鲁棒性好不好,复旦大学一键式评测平台告诉你

复旦大学自然语言处理实验室发布模型鲁棒性评测平台 TextFlint。该平台涵盖 12 项 NLP 任务,囊括 80 余种数据变形方法,花费超 2 万 GPU 小时,进行了 6.7 万余次实验,验证约 100 种模型,选取约 10 万条变形后数据进行了语言合理性和语法正确性人工评测,为模型鲁棒性评测及提升提供了一站式解决方案。