假设vit的输入图片尺寸是224,将图片分为固定大小的patch,patch大小为patch_size=16x16
则每张图像会生成224x224/16x16=14*14=196个patch, 每个patch的长度是16*16*3=768
这里还需要加上一个特殊字符cls,因此最终的维度是197x768
相当于NLP中一个句子有197个单词,每个单词的embedding dim是768
现在,保持patch size不变,将输入图片尺寸改成288*288
当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。表现在上述特征图中,就是特征图的尺寸发生了变化。这样一来,我们原本位置编码图的尺寸就和图像特征图的尺寸对不上了,无法进行后续的计算。
找到了问题所在,解决的方法也就顺理成章了。位置编码代表的是 patch 所在位置的附加信息,那么如果和图像特征图的尺寸不匹配,只需要使用双三次插值法(Bicubic)对位置编码图进行插值缩放,缩放到与图像特征图一致的尺寸,就同样可以表现每个 patch 在图片中的位置信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import torch import torch.nn.functional as F
pos_embed = torch.rand(1, 197, 768)
src_shape = (14, 14)
dst_shape = (18, 18)
num_extra_tokens = 1
_, L, C = pos_embed.shape src_h, src_w = src_shape
assert L == src_h * src_w + num_extra_tokens
extra_tokens = pos_embed[:, :num_extra_tokens] src_weight = pos_embed[:, num_extra_tokens:] print(pos_embed.shape) print(extra_tokens.shape) print(src_weight.shape)
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) print(src_weight.shape)
dst_weight = F.interpolate(src_weight, size=dst_shape, mode='bicubic') print(dst_weight.shape)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) print(dst_weight.shape)
pos_embed = torch.cat((extra_tokens, dst_weight), dim=1) print(pos_embed.shape)
|
参考: