论文来源
【ICCV 2021 最佳论文】 Swin Transformer: Hierarchical Vision Transformer Using Shifted Windows [Paper] [Code]


研究动机和思路

“We seek to expand the applicability of Transformer such that it can serve as a general-purpose backbone for computer vision, as it does for NLP and as CNNs do in vision.” 我们试图扩展 Transformer 的适用性,使其可以作为计算机视觉任务的通用主干,就像它在 NLP 领域和 CNN 在视觉邻域中所起到的效果。

图像信息建模:如下图所示,ViT 在对图像进行自注意力时,始终在原图 1/16 大小的 patch 上进行,实现图像信息的全局建模。受限于此,**ViT 无法从局部层面提取图像特征,以及无法实现图像多尺度特征的表示**(在密集预测型任务中尤为重要,如图像分割和目标检测)。

时间复杂度:由于标准 Transformer 架构的自注意力计算过程是在 tokentoken 之间进行,因此复杂度极大程度上取决于 token 的数量。
“The global computation leads to quadratic complexity with respect to the number of tokens.” 全局计算复杂度是关于 token 数量的二次复杂度。

通过合并图像 patch 得到的多尺度特征图

如何兼顾局部和全局

Swin Transformer 的实现方式
(1)预处理:将输入图像取成 4×4 (pixel) 的小 patch
(2)Layer L:使用 7×7 (patch)windowpatch 块圈起来,在该 window 内对 7×7=49patch 进行自注意力,实现图像局部特征的建模
(3)Layer L+1:通过滑动 window 使得原本不在一个 window 内的 patch 处于一个 window 内,通过对其进行自注意力实现 cross-window connections
(4)通过步长为 spatch merging 将临近的 小 patch 合并成 patch, 使得整图分辨率下降 1/s,实现多尺度图像特征的提取
(5)当图像尺寸减少至一定程度时,一个 window 能够对整图进行处理,实现图像全局特征的建模

shifted window approach
patch merging(序号仅用于理解)

Swin Transformer

网络架构

以下结合网络架构图和代码推导一下(阅读文字时可将代码块折叠) 👇👇👇:

:区别于 ViT 的一点在于,Swin Transformer 在进行分类任务时没用引入 class token,而是在最后使用 global average pooling (GAP) 得到类别预测的结果,目的在于使得 Swin Transformer 能够很好地兼容到视觉的其他任务中,如图像分割和目标检测。

Patch Embedding

Patch Partition:不妨设输入图像尺寸为 ${H}\times{W}\times{C_{in}}$,Swin Transformerpatch size 设置为 ${4}\times{4}$,则一个 token 的大小为 ${4}\times{4}\times{C_{in}}$,token 序列的长度为 $\frac{H}{4}\times\frac{W}{4}$;因此,整幅图像被转化成了维度为 $(\frac{H}{4}\times\frac{W}{4})\times({4}\times{4}\times{C_{in}})$ 的 token 序列,以 224×224×3 的输入图像为例,其产生的 token 序列的长度为 (56×56)×(16×16×3)(以上过程通过 4×4 卷积层实现)

Linear Embedding:通过 Patch Partition 得到的 token 序列的长度对于 Transformer 模型而言是巨大的,因此需要减少其长度至设定的超参数 $C$;(以上过程通过 Linear 层实现)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Patch Partition + Linear Embedding

class PatchEmbed(nn.Module):

def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]

self.in_chans = in_chans
self.embed_dim = embed_dim

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None

def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x

Hierarchical Stage

stage2 为例(stage3、4 同理),推导一下网络:
(1)Layer Input:输入特征图维度为 $\frac{H}{4}\times\frac{W}{4}\times{C}$;
(2)Patch Merging:经上图右侧所示过程,合并 patch 之后的特征图尺寸减少 1/2 倍,通道数增加 4 倍,即经 patch merging 之后的输出特征图维度为 $\frac{H}{8}\times\frac{W}{8}\times{4C}$;
(3)Channel Reduction:为了保持与卷积神经网络拥有相同的层级表示,进一步通过 Linear 层或 1×1 卷积层(二者作用一致,原文代码用的 Linear 层)将通道数降为 2C,使得最终输出特征图维度为 $\frac{H}{8}\times\frac{W}{8}\times{2C}$;(注:本过程为 Patch Merging 中的步骤)
(3)Swin Transformer Block:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Patch Merging

class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)

def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

x = x.view(B, H, W, C)

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C

x = self.norm(x)
x = self.reduction(x)

return x

Swin Transformer Block

Shifted Window based Multi-head Self-Attention

前面提到,Swin Transformer 通过设置 window,对处于 window 内的 patch 做自注意力。以 stage1 为例,56×56 的 patch 数量,设置 7×7 的 window size,对整图运算则需要的 window 数量为 (56/7)×(56/7)=8×8=64。

计算复杂度分析
(1)标准 Multi-head Self-attention
$$3HWC^{2}+(HW)^{2}C+(HW)^{2}C+HWC^{2}=4HWC^{2}+2(HW)^{2}C, \tag{1}$$
(2)Swin Transformer 中的 Self-attention
$$(\frac{H}{M}\times\frac{W}{M})\times(4MMC^{2}+2(MM)^{2}C)=4HWC^{2}+2M^{2}HWC, \tag{2}$$
将 $(2)$ 式减 $(1)$ 式得 $(HW-M^{2})\times(2HWC)$,确实有 一定程度 的下降。