从ViT到Swin:手把手教你理解Transformer在CV中的进化之路(附PyTorch代码解读)

张开发
2026/4/18 21:52:46 15 分钟阅读

分享文章

从ViT到Swin:手把手教你理解Transformer在CV中的进化之路(附PyTorch代码解读)
从ViT到SwinTransformer在计算机视觉中的架构革新与实战解析当Vision TransformerViT首次将自然语言处理领域的Transformer成功迁移到计算机视觉任务时整个CV社区为之振奋。但很快研究者们发现这种暴力移植存在明显的效率瓶颈——尤其是处理高分辨率图像和密集预测任务时平方级增长的计算复杂度让ViT难以成为通用视觉骨干网络。2021年ICCV最佳论文Swin Transformer的诞生通过巧妙的层级窗口注意力设计不仅解决了ViT的核心痛点更开创了视觉Transformer的新范式。1. ViT的突破与局限为何需要架构革新ViT的核心思想简单而优雅将输入图像分割为16×16的图像块patch每个patch经过线性投影后视为一个视觉词元然后直接套用原始Transformer的编码器结构。这种设计在ImageNet分类任务上证明了自注意力机制在视觉领域的潜力但面临三个关键挑战计算复杂度问题标准自注意力机制需要计算所有patch之间的两两关系。对于h×w的patch序列其内存和计算成本为O((h×w)²)。当处理512×512分辨率图像时约1024个patch这会导致难以承受的计算负担。多尺度特征缺失ViT采用单一尺度特征表示缺乏CNN固有的层级结构。这使得它在目标检测、语义分割等需要多尺度特征的任务中表现欠佳。位置编码僵化ViT使用固定的绝对位置编码当测试分辨率与训练不一致时需要通过插值调整位置编码这会导致性能下降。# ViT的patch嵌入层典型实现 class PatchEmbed(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.proj nn.Conv2d(in_chans, embed_dim, kernel_sizepatch_size, stridepatch_size) def forward(self, x): x self.proj(x) # [B, C, H, W] - [B, E, H/P, W/P] x x.flatten(2).transpose(1, 2) # [B, E, N] where N (H*W)/P² return x提示ViT的全局注意力机制虽然理论上有更大的感受野但在实际应用中远距离像素间的注意力权重往往趋近于零造成计算资源浪费。2. Swin Transformer的核心创新窗口与移位窗口注意力Swin Transformer的突破在于将局部性先验重新引入视觉Transformer通过两个关键设计实现计算效率与模型性能的平衡2.1 非重叠窗口注意力W-MSA将图像划分为M×M默认7×7的非重叠窗口每个窗口内独立计算自注意力。这种设计将计算复杂度从全局的O(N²)降低到窗口级的O(N×M²)其中N为总patch数。模型类型计算复杂度内存消耗适合任务ViTO(N²)高图像分类Swin-W-MSAO(N×M²)中分类/检测Swin-SW-MSAO(2N×M²)中通用视觉2.2 移位窗口注意力SW-MSA为解决窗口间信息隔离问题Swin在相邻Transformer块间交替使用常规窗口和移位窗口向右下角偏移⌊M/2⌋个像素。这种设计实现了跨窗口连接同时保持计算效率。# Swin的窗口注意力实现关键代码 def window_partition(x, window_size): B, H, W, C x.shape x x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows x.permute(0, 1, 3, 2, 4, 5).contiguous() return windows.view(-1, window_size, window_size, C) def shifted_window_attention(x, shift_size0): if shift_size 0: x torch.roll(x, shifts(-shift_size, -shift_size), dims(1, 2)) # 后续进行常规窗口注意力计算 ...注意移位窗口会生成不规则的子窗口Swin通过巧妙的掩码机制和循环位移cyclic shift保持窗口形状一致避免计算复杂度增加。3. 层级特征金字塔视觉任务的通用骨干网络Swin Transformer的另一大创新是构建了类似CNN的层级特征金字塔通过四个阶段的处理逐步下采样并扩展通道维度Patch Merging在每阶段开始时将相邻2×2的小patch合并为一个大patch同时将特征维度翻倍实现4倍下采样。深度可扩展每个阶段包含多个Swin Transformer块不同规模的模型Tiny/Small/Base/Large通过调整块数量和头数实现能力扩展。# Patch Merging操作示例实现 class PatchMerging(nn.Module): def __init__(self, dim): super().__init__() self.reduction nn.Linear(4 * dim, 2 * dim, biasFalse) def forward(self, x): B, H, W, C x.shape x0 x[:, 0::2, 0::2, :] # 左上 x1 x[:, 1::2, 0::2, :] # 左下 x2 x[:, 0::2, 1::2, :] # 右上 x3 x[:, 1::2, 1::2, :] # 右下 x torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C] return self.reduction(x) # [B, H/2, W/2, 2*C]这种设计带来的优势非常明显在COCO目标检测任务中Swin-T比ResNet-50获得4.1 box AP提升在ADE20K语义分割中Swin-S比DeiT-S提升5.3 mIoU在ImageNet-1K分类中Swin-B达到85.2% top-1准确率4. 工程实践Swin Transformer的PyTorch实现要点在实际项目中部署Swin Transformer时有几个关键技术细节需要特别注意4.1 相对位置偏置的高效实现Swin采用可学习的相对位置偏置代替绝对位置编码每个头维护一个(2M-1)×(2M-1)的偏置矩阵通过查表方式应用到注意力得分class RelativePositionBias(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.relative_position_bias_table nn.Parameter( torch.zeros((2*window_size-1)**2, num_heads)) # 初始化相对位置索引 coords torch.stack(torch.meshgrid( [torch.arange(window_size), torch.arange(window_size)])) coords_flatten torch.flatten(coords, 1) relative_coords coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] window_size - 1 relative_coords[:, :, 1] window_size - 1 relative_coords[:, :, 0] * 2 * window_size - 1 relative_position_index relative_coords.sum(-1) self.register_buffer(relative_position_index, relative_position_index) def forward(self): bias self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size**2, self.window_size**2, -1) return bias.permute(2, 0, 1).unsqueeze(0)4.2 内存优化技巧共享Key矩阵同一窗口内所有查询共享相同的Key集合大幅减少内存访问梯度检查点在训练深层模型时对部分块启用梯度检查点节省显存混合精度训练使用AMP自动混合精度加速训练过程4.3 自定义数据加载策略对于高分辨率视觉任务如检测/分割建议实现智能批处理动态填充Dynamic Padding将同一批次图像填充到相同尺寸批内重缩放Intra-batch Rescaling保持长宽比的同时缩放图像# 示例Swin兼容的数据增强管道 train_transform transforms.Compose([ transforms.RandomResizedCrop(224, scale(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]), ])在COCO数据集上的实际测试表明Swin Transformer相比传统CNN骨干网络在保持相似计算量的情况下能够获得显著的性能提升模型参数量FLOPsAP^boxAP^maskResNet-5044M260G41.037.1Swin-T48M264G45.141.5ResNet-10163M336G43.539.7Swin-S69M354G48.544.2这种架构革新使得Transformer真正成为计算机视觉领域的通用骨干网络为后续的许多工作如MViT、CSWin等奠定了基础。在实际项目中选择Swin Transformer时需要根据任务复杂度从Tiny到Large不同规模中进行权衡——对于实时性要求高的应用Swin-T往往能在速度和精度间取得良好平衡而对于追求极致性能的场景Swin-L或更大的变体可能更为合适。

更多文章