从零开始构建 DINO:自监督视觉 Transformer

DINO模型输出的狗冲刺无标签自蒸馏(DINO)《从几个“补丁”中重建完整图像 | 构建可扩展学习器的掩模自编码器》这边文章讲了如何构建可扩展学习器,这是我对视觉变换器系列的继续,其中我解释了最重要的架构及其从零开始的实现。 自监督学习自监督学习(SSL)是一种机器学习类型,模型通过无需手动标记的示例来学习理解数据。 相反,它从数据本身生成其监督信号。

从零开始构建 DINO:自监督视觉 Transformer

DINO模型输出的狗冲刺

无标签自蒸馏(DINO)

《从几个“补丁”中重建完整图像 | 构建可扩展学习器的掩模自编码器》这边文章讲了如何构建可扩展学习器,这是我对视觉变换器系列的继续,其中我解释了最重要的架构及其从零开始的实现。

自监督学习

自监督学习(SSL)是一种机器学习类型,模型通过无需手动标记的示例来学习理解数据。相反,它从数据本身生成其监督信号。当标记数据有限且获取成本高昂时,这种方法非常有益。在SSL中,学习过程涉及创建任务,其中输入数据可以用来预测数据本身的某些部分。常见的技术包括:

  • 对比学习:模型通过区分相似和不相似的数据对来学习。
  • 预测任务:模型从其他部分预测输入数据的一部分,例如预测句子中的下一个词或从其周围环境中预测词的上下文。

DINO模型

DINO(无标签蒸馏)模型是一种应用于视觉变换器(ViTs)的尖端自监督学习方法。它代表了计算机视觉领域的一个重大进步,使模型能够在不需要任何标记数据的情况下学习有效的图像表示。由Facebook AI Research(FAIR)的研究人员开发,DINO利用学生-教师框架和创新的训练技术,在各种视觉任务上取得了卓越的性能。

学生-教师网络

从零开始构建 DINO:自监督视觉 Transformer

在DINO模型中,学生-教师网络是实现无需标记数据的自监督学习的核心机制。这个框架涉及两个网络:学生网络和教师网络。两个网络都是视觉变换器,它们被设计用来通过将图像处理为序列块来处理图像,类似于变换器处理文本序列的方式。

学生网络的任务是从输入图像中学习生成有意义的表示。另一方面,教师网络提供目标表示,学生网络旨在匹配这些表示。教师网络不是一个静态实体;它通过逐渐整合学生网络的参数随时间演变。这是通过一种称为指数移动平均的技术完成的,其中教师的参数被更新为其当前参数和学生参数的加权平均值。

目标是最小化学生表示和教师表示之间的差异,这些表示是针对相同增强图像视图的。这通常是通过使用一个损失函数来实现的,该函数鼓励学生和教师输出之间的对齐,同时确保不同图像的表示保持不同。

通过根据学生网络的学习进度不断更新教师网络,并训练学生网络以匹配教师的输出,DINO有效地利用了两个网络的优势。教师网络为学生提供了稳定和一致的目标,而学生网络推动了学习过程。这种协作设置允许模型在无需手动标签的情况下从数据中学习强大和不变的特征,从而实现有效的自监督学习。

学生和教师的增强输入

在DINO模型中,X1和X2(见上图)指的是同一原始图像X的不同增强视图。这些视图分别用作学生和教师网络的输入。目标是让学生网络学习在这些增强下产生一致的表示。学生和教师模型根据以下策略接收不同的增强:

  • 全局裁剪:从原始图像创建两个全局裁剪。这些是覆盖图像大部分的较大裁剪,通常与原始图像有很高的重叠。除了其他增强(如颜色抖动、高斯模糊、翻转等)之外。
  • 局部裁剪:除了全局裁剪外,教师网络还接收几个局部裁剪。这些是关注图像不同部分的较小裁剪,捕捉更多局部细节。

我们将如何为参数图像定义这些增强,这些图像包含我们在训练期间想要转换的一批图像。

# These augmentations are defined exactly as proposed in the paper
def global_augment(images):
    global_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),  # Larger crops
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Color jittering
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return torch.stack([global_transform(img) for img in images])

def multiple_local_augments(images, num_crops=6):
    size = 96  # Smaller crops for local
    local_transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.05, 0.4)),  # Smaller, more concentrated crops
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Same level of jittering
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Apply the transformation multiple times to the same image
    return torch.stack([local_transform(img) for img in images])

蒸馏损失

在这里,我们希望使用某种距离度量来计算学生输出和教师输出之间的损失。我们这样做:

  • 获取教师预测输出的中心化Softmax,然后应用锐化。
  • 获取学生的Softmax预测,然后应用锐化。
def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
        """
        Calculates distillation loss with centering and sharpening (function H in pseudocode).
        """
        # Detach teacher output to stop gradients.
        teacher_output = teacher_output.detach()

        # Center and sharpen teacher's outputs
        teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)

        # Sharpen student's outputs
        student_probs = F.log_softmax(student_output / tau_s, dim=1)

        # Calculate cross-entropy loss between students' and teacher's probabilities.
        loss = - (teacher_probs * student_probs).sum(dim=1).mean()
        return loss
  • 中心化:中心化教师的输出确保学生模型更多地关注教师输出分布中最显著的特征或区别。通过中心化分布,鼓励学生更多地关注对准确预测至关重要的显著特征,而不是受数据中的变化或偏差的影响。这有助于更有效的知识传递,并可能导致学生模型的性能提高。
  • 锐化:锐化涉及放大数据分布中的特定特征,旨在强调教师模型突出的区分。这个过程使学生模型能够专注于学习教师预测中存在的复杂细节,这对于在数据集上准确复制其输出至关重要。

训练DINO模型

从零开始构建 DINO:自监督视觉 Transformer

阐明DINO伪代码的图像,取自官方论文

有3个重要的步骤需要强调:

(1) 获取学生和教师架构的不同输入(x1,x2)的增强。

(2) 我们之前讨论的蒸馏损失函数,注意它是如何计算不同增强输入的架构的蒸馏损失的,即gs({x1, x2})和gt({x1, x2})。

(3) 更新(a)学生参数(b)教师参数和(c)中心。这里的关键是我们对更新教师参数执行指数移动平均更新。

  • 教师参数:EMA应用于教师模型的参数。而不是在每次训练迭代中直接更新教师参数,EMA随时间维护这些参数的移动平均值。这个移动平均值作为教师模型的更平滑、更稳定的表示,可以帮助指导学生模型的训练。
  • 中心:此外,在DINO的一些实现中,EMA也用于更新中心。中心代表教师输出分布的平均值,用于归一化目的。通过应用EMA更新中心,它在整个训练过程中逐渐演变,为归一化提供更稳定的参考点。

DINO模型

class DINO(nn.Module):
    def __init__(self, student_arch: Callable, teacher_arch: Callable, device: torch.device):
        """
        Args:
            student_arch (nn.Module): ViT Network for student_arch
            teacher_arch (nn.Module): ViT Network for teacher_arch
            device: torch.device ('cuda' or 'cpu')
        """
        super(DINO, self).__init__()
    
        self.student = student_arch().to(device)
        self.teacher = teacher_arch().to(device)
        self.teacher.load_state_dict(self.student.state_dict())

        # Initialize center as buffer to avoid backpropagation
        self.register_buffer('center', torch.zeros(1, student_arch().output_dim))

        # Ensure the teacher parameters do not get updated during backprop
        for param in self.teacher.parameters():
            param.requires_grad = False

    @staticmethod
    def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
        """
        Calculates distillation loss with centering and sharpening (function H in pseudocode).
        """
        # Detach teacher output to stop gradients.
        teacher_output = teacher_output.detach()

        # Center and sharpen teacher's outputs
        teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)

        # Sharpen student's outputs
        student_probs = F.log_softmax(student_output / tau_s, dim=1)

        # Calculate cross-entropy loss between student's and teacher's probabilities.
        loss = - (teacher_probs * student_probs).sum(dim=1).mean()
        return loss

    def teacher_update(self, beta: float):
        for teacher_params, student_params in zip(self.teacher.parameters(), self.student.parameters()):
            teacher_params.data.mul_(beta).add_(student_params.data, alpha=(1 - beta))

为了更新教师的参数,我们使用论文中提出公式,即gt.param = gt.param*beta + gs.param*(1 — beta),其中beta是移动平均衰减,gt、gs分别是相应的教师和学生架构。

进一步,我们在__init__下看到,教师的参数已设置为“required_grads = False”,因为我们不希望在反向传播期间更新它们,而是应用移动平均更新。

此外,在PyTorch中将变量初始化为bugger是一种常见方法,用于将其保持在梯度图之外,并不参与反向传播。

Dino模型进一步需要如下调用

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dino = DINO(ViT(), ViT(), device)

在这里,我们传递学生和教师架构,这不过是标准的视觉变换器,即ViT-B/16或ViT-L/16,正如第一篇论文中提出的。

最终训练

现在可以将整个实现放入训练循环中,正如论文中提出的。

def train_dino(dino: DINO,
               data_loader: DataLoader,
               optimizer: Optimizer,
               device: torch.device,
               num_epochs,
               tps=0.9,
               tpt= 0.04,
               beta= 0.9,
               m= 0.9,
               ):
        """
        Args:
        dino: DINO Module
        data_loader (nn.Module): Dataloader for training
        optimizer (nn.optimizer): Optimizer for optimization (SGD etc.)
        defice (torch.device): 'cuda', 'cpu'
        num_epochs: Number of Epochs
        tps (float): tau for sharpening student logits
        tpt: for sharpening teacher logits
        beta (float): moving average decay 
        m (float): center moveing average decay
        """
    
        for epoch in range(num_epochs):
            print(f"Epoch: {epoch+1}/{len(num_epochs)}")
            for x in data_loader:

                x1, x2 = global_augment(x), multiple_local_augments(x)  

                student_output1, student_output2 = dino.student(x1.to(device)), dino.student(x2.to(device))
                with torch.no_grad():
                    teacher_output1, teacher_output2 = dino.teacher(x1.to(device)), dino.teacher(x2.to(device))

                # Compute distillation loss
                loss = (dino.distillation_loss(teacher_output1, student_output2, dino.center, tps, tpt) +
                        dino.distillation_loss(teacher_output2, student_output1, dino.center, tps, tpt)) / 2

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Update the teacher network parameters
                dino.teacher_update(beta)
                
                # Update the center
                with torch.no_grad():
                    dino.center = m * dino.center + (1 - m) * torch.cat([teacher_output1, teacher_output2], dim=0).mean(dim=0)

(1) 我们用不同的全局和局部增强计算x1和x2。

(2) 之后,我们根据论文中提出的,为学生和教师模型获取输出,回想上面的算法循环图。

(3) 在这里,我们将torch设置为no_grad()函数,以确保教师的参数不会通过反向传播更新。

(4) 最后,我们再次根据论文中提出的方法计算蒸馏损失。

(5) 在蒸馏损失中,我们首先中心化教师模型的输出,这样学生模型就不容易崩溃,也不会只学习不重要的特征,或者比另一个特征更多地学习一个特征,而是专注于从教师模型中学习最独特和潜在的特征。

(6) 然后我们锐化特征,以便在计算损失时,我们现在能够比较两个特征(学生和教师的)具有非常不同的数据分布,这意味着锐化后,更重要的特征会被锐化,而不太重要的特征则不会,这将创建一个更独特的特征图,使学生更容易学习。

(7) 然后我们执行反向传播并执行optimizer.step(),更新学生模型并通过之前实现的指数移动平均更新教师网络。

(8) 作为最后一步,我们将再次将torch设置为no_grad()并通过移动平均更新中心。我们根据教师的输出更新中心,因此它与训练过程中输出数据分布的变化保持一致。

就这样,这就是如何从零开始训练DINO模型。到目前为止,在视觉变换器系列中,我们已经实现了标准的ViT、Swin、CvT、Mae和DINO(自监督)。希望你喜欢阅读这篇文章。

# Create your own CustomDataset and dataloader
dataloader = DataLoader(CustomDataset, batch_size=32, shuffle=True)
optimizer = torch.optim.AdamW(dino.parameters(), lr=1e-4)
train_dino(dino,
           DataLoader=dataloader,
           Optimizer=optimizer,
           device=device,
           num_epochs=300,
           tps=0.9,
           tpt= 0.04,
           beta= 0.9,
           m= 0.9)

相关资讯

提升 YOLO 模型:使用 Albumentations 进行高级数据增强

在计算机视觉领域迅速发展的今天,YOLO(You Only Look Once)模型已成为实时目标检测任务的热门选择。 从自动驾驶到视频监控,YOLO模型因其速度和准确性而表现出色。 然而,与任何机器学习模型一样,训练数据的质量极大地影响着它们的性能。

首个通用双向Adapter多模态目标追踪方法BAT,入选AAAI 2024

能够有效实现多模态交叉提示跟踪。目标跟踪是计算机视觉的一项基础视觉任务,由于计算机视觉的快速发展,单模态 (RGB) 目标跟踪近年来取得了重大进展。考虑到单一成像传感器的局限性,我们需要引入多模态图像 (RGB、红外等) 来弥补这一缺陷,以实现复杂环境下全天候目标跟踪。然而,现有的多模态跟踪任务也面临两个主要问题:由于多模态目标跟踪的数据标注成本高,大多数现有数据集规模有限,不足以支持构建有效的多模态跟踪器;因为不同的成像方式在变化的环境中对物体的敏感度不同,开放世界中主导模态是动态变化的,多模态数据之间的主导相关

ICLR 2024 Spotlight|厦门大学、Intel、大疆联合出品,从网络视频中学习零样本图像匹配大模型

图像匹配是计算机视觉的一项基础任务,其目标在于估计两张图像之间的像素对应关系。图像匹配是众多视觉应用如三维重建、视觉定位和神经渲染 (neural rendering) 等的基础和前置步骤,其精确度和效率对于后续处理十分重要。传统算法(SIFT)在面临长基线或极端天气等复杂场景时,其匹配的准确度和密度往往有限。为了解决这些问题,近年来,基于深度学习的匹配模型逐渐流行。然而,由于缺乏大规模且多样化的具有真值标签的训练数据,目前的匹配模型通常是在 ScanNet 和 MegaDepth 上分别训练室内和室外两个模型。这