假设vit的输入图片尺寸是224,将图片分为固定大小的patch,patch大小为patch_size=16x16

则每张图像会生成224x224/16x16=14*14=196个patch, 每个patch的长度是16*16*3=768

这里还需要加上一个特殊字符cls,因此最终的维度是197x768

相当于NLP中一个句子有197个单词,每个单词的embedding dim是768

现在,保持patch size不变,将输入图片尺寸改成288*288

当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。表现在上述特征图中,就是特征图的尺寸发生了变化。这样一来,我们原本位置编码图的尺寸就和图像特征图的尺寸对不上了,无法进行后续的计算。

找到了问题所在,解决的方法也就顺理成章了。位置编码代表的是 patch 所在位置的附加信息,那么如果和图像特征图的尺寸不匹配,只需要使用双三次插值法(Bicubic)对位置编码图进行插值缩放,缩放到与图像特征图一致的尺寸,就同样可以表现每个 patch 在图片中的位置信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch 
import torch.nn.functional as F

# 原始位置编码(in vit):从1到196,即总共有196个patch,每个patch对应一条维度为768的embedding向量
pos_embed = torch.rand(1, 197, 768)
# 原始图像尺寸下(vit用的224),长和宽方向的 patch 数
src_shape = (14, 14) # vit中的patch数目=14*14=196

# 当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。
# 输入图像尺寸下(假设输入尺寸是288),长和宽方向的 patch 数,这里假设是18*18=324个
dst_shape = (18, 18) # 输入图片的patch数是18*18=324

# 额外编码数,在 ViT 中,为 1,指 class embedding;在 DeiT 中为 2
num_extra_tokens = 1

# 原始图像尺寸下(vit用的224)
_, L, C = pos_embed.shape #位置编码:[L,C] == [target_word_num,embedding_dim]
src_h, src_w = src_shape #patch数:长和宽方向的 patch 数
# 位置编码第二个维度大小应当等于 patch 数 + 额外编码数
assert L == src_h * src_w + num_extra_tokens

# 拆分额外编码和纯位置编码
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
print(pos_embed.shape)# torch.Size([1, 197, 768])
print(extra_tokens.shape)# torch.Size([1, 1, 768])
print(src_weight.shape)# torch.Size([1, 196, 768])

# 将位置编码组织成 (1, C, H, W) 形式,其中 C 为通道数 ,以便使用torch中的插值interpolate API
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) # [bs,embedding_dim,scr_h,src_w]
print(src_weight.shape)# torch.Size([1, 768, 14, 14])

# 进行双三次插值 ,将vit的位置编码尺寸(14*14)插值到自定义输入图片的位置编码尺寸(18,18)
dst_weight = F.interpolate(src_weight, size=dst_shape, mode='bicubic')
print(dst_weight.shape)# torch.Size([1, 768, 18, 18])

# 重组位置编码为(1,H*W, C)形式,再拼接上额外编码,即获得新的位置编码
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
print(dst_weight.shape)# torch.Size([1, 324, 768])

pos_embed = torch.cat((extra_tokens, dst_weight), dim=1) # 自定义输入图片尺寸对应的位置编码
print(pos_embed.shape)# torch.Size([1, 325, 768])# 共324+1个patch,每个patch对应一条维度为768的embedding向量

参考: