搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(二)
极市导读
本文为详细解读Vision Transformer的第二篇,主要包括三个方向的分类:可变形的Transformer ,用于分类任务的Transformer ,用于底层视觉任务的Transformer,分别对应了三篇相关论文。附有超详细的代码解读。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(一)
目录
(每篇文章对应一个Section,目录持续更新。)
Section 1
1 一切从Self-attention开始
1.1 处理Sequence数据的模型
1.2 Self-attention
1.3 Multi-head Self-attention
1.4 Positional Encoding2 Transformer的实现和代码解读 (NIPS2017)
(来自Google Research, Brain Team)
2.1 Transformer原理分析
2.2 Transformer代码解读3 Transformer+Detection:引入视觉领域的首创DETR (ECCV2020)
(来自Facebook AI)
3.1 DETR原理分析
3.2 DETR代码解读
Section 2
4 Transformer+Detection:Deformable DETR:可变形的Transformer (ICLR2021) (来自商汤代季峰老师组)
4.1 Deformable Convolution原理分析
4.2 Deformable Convolution代码解读
4.3 Deformable DETR原理分析
4.4 Deformable DETR代码解读5 Transformer+Classification:用于分类任务的Transformer (ICLR2021)
(来自Google Research, Brain Team)
5.1 ViT原理分析
5.2 ViT代码解读6 Transformer+Image Processing:IPT:用于底层视觉任务的Transformer
(来自北京华为诺亚方舟实验室)
6.1 IPT原理分析
Section 3
7 Transformer+Distillation:DeiT:高效图像Transformer
(来自Facebook AI)
9.1 DeiT原理分析8 Transformer+GAN:VQGAN:实现高分辨率的图像生成
(来自德国海德堡大学)
8.1 VQGAN原理分析
8.2 VQGAN代码解读9 Transformer+多模态:CLIP
(来自OpenAI)
7.1 CLIP原理分析
4 Transformer+Detection:Deformable DETR:可变形的Transformer (ICLR2021)
论文名称:Deformable DETR: Deformable Transformer For End-To-End Object Detection
4.1 Deformable Convolution原理分析:
式4.3是怎么推导出的呢?先拿最简单的例子来做说明,如下图2所示,假设feature map只有4个点,则其中插入一个点 的值可以用式4.3来得到,这就是双线形插值的标准公式,对于相邻的点来说 。
对物体的形变和尺度建模的能力比较强。
感受野比一般卷积大很多,因为有偏移的原因,实际上相关实验已经表明了DNN网络很多时候受感受野不足的条件制约;但是一般的空洞卷积空洞是固定的,对不同的数据集不同情况可能最适合的空洞大小是不同的,但是可形变卷积的偏移是可以根据具体数据的情况进行学习的。
4.2 Deformable DETR代码解读:
Function的定义很直接:
定义DeformConvFunction这个函数。
import DCN
class DeformConvFunction(Function):
@staticmethod
def forward(ctx, input, offset, weight, bias,
stride, padding, dilation, group, deformable_groups, im2col_step):
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.kernel_size = _pair(weight.shape[2:4])
ctx.group = group
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
output = DCN.deform_conv_forward(input, weight, bias,
offset,
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.stride[0], ctx.stride[1],
ctx.padding[0], ctx.padding[1],
ctx.dilation[0], ctx.dilation[1],
ctx.group,
ctx.deformable_groups,
ctx.im2col_step)
ctx.save_for_backward(input, offset, weight, bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight, bias = ctx.saved_tensors
grad_input, grad_offset, grad_weight, grad_bias = \
DCN.deform_conv_backward(input, weight,
bias,
offset,
grad_output,
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.stride[0], ctx.stride[1],
ctx.padding[0], ctx.padding[1],
ctx.dilation[0], ctx.dilation[1],
ctx.group,
ctx.deformable_groups,
ctx.im2col_step)
return grad_input, grad_offset, grad_weight, grad_bias,\
None, None, None, None, None, None
注意这里最重要的是:import DCN,那DCN这个包的内部是什么?
setup(
name="DCN",
version="1.0",
author="xvjiarui",
url="https://github.com/charlesshang/DCNv2",
description="deformable convolutional networks",
packages=find_packages(exclude=("configs", "tests",)),
# install_requires=requirements,
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
CUDAExtension(
"DCN",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
source: 所有的 /src下的.cpp文件,/src/cpu下的.cpp文件,/src/cuda下的.cpp文件的并集。
include_dirs:/src。
具体的操作定义在了src\cuda\deform_conv_cuda.cu的deform_conv_cuda_forward和deform_conv_cuda_backward里面。
定义好了Function:DeformConvFunction以后,接下来是封装nn.module:
变成一个卷积层的形式。
from functions.deform_conv_func import DeformConvFunction
class DeformConv(nn.Module):
def __init__(self, in_channels, out_channels,
kernel_size, stride, padding, dilation=1, groups=1, deformable_groups=1, im2col_step=64, bias=True):
super(DeformConv, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels {} must be divisible by groups {}'.format(in_channels, groups))
if out_channels % groups != 0:
raise ValueError('out_channels {} must be divisible by groups {}'.format(out_channels, groups))
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.im2col_step = im2col_step
self.use_bias = bias
self.weight = nn.Parameter(torch.Tensor(
out_channels, in_channels//groups, *self.kernel_size))
self.bias = nn.Parameter(torch.Tensor(out_channels))
self.reset_parameters()
if not self.use_bias:
self.bias.requires_grad = False
def reset_parameters(self):
n = self.in_channels
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input, offset):
assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
offset.shape[1]
return DeformConvFunction.apply(input, offset,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.deformable_groups,
self.im2col_step)
_DeformConv = DeformConvFunction.apply
注意,在DeformConvFunction.apply过程中,所需要的input和offset是forward()方法提供的,而其他的诸如weight,bias,stride等的参数均来自类内定义的变量。
把偏移量offset打包进去:
class DeformConvPack(DeformConv):
def __init__(self, in_channels, out_channels,
kernel_size, stride, padding,
dilation=1, groups=1, deformable_groups=1, im2col_step=64, bias=True, lr_mult=0.1):
super(DeformConvPack, self).__init__(in_channels, out_channels,
kernel_size, stride, padding, dilation, groups, deformable_groups, im2col_step, bias)
out_channels = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1]
self.conv_offset = nn.Conv2d(self.in_channels,
out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
bias=True)
self.conv_offset.lr_mult = lr_mult
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, input):
offset = self.conv_offset(input)
return DeformConvFunction.apply(input, offset,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.deformable_groups,
self.im2col_step)
4.3 Deformable DETR原理分析:
训练时间极长: 相比于已有的检测器,DETR需要更久的训练才能达到收敛(500 epochs),比Faster R-CNN慢了10-20倍。 计算复杂度高:发现DETR对小目标的性能很差,现代许多种检测器通常利用多尺度特征,从高分辨率(High Resolution)的特征图中检测小物体。但是高分辨率的特征图会大大提高DETR复杂度。
在初始化阶段, 对于特征图中的所有pixel的权重是Uniform的,导致要学习的注意力权重集中在稀疏的有意义的位置这一过程需要很长时间,意思是 从Uniform到Sparse and meaningful需要很久。 是 的,在图像领域我们一般认为 所以里面的weights的计算是像素点数目的平方。因此,处理高分辨率特征图需要非常高的计算量,存储也很复杂。
图9的这2张图为 的2种形式,我们可以与之前的 做个比较看看二者有何不同,我们发现:
的 Attention 矩阵不是由 和 作内积得到的,而是由输入特征 直接通过 Linear Transformation 得到的。
以上面的图为例:
我们假设输入特征 的维度是 ,它与几个转移矩阵相乘,得到一个大的张量,我们在图中把这个东西分开画成了 ,它们的维度都是 。其中, 代表Attention,一会要与 做weighted sum得到输出。
代表相对参考点的偏移量, 可以看做是 个向量,对于Encoder来讲, ,即特征图的每个点都是一个向量, 表示的就是特征图上的某一点 对应的 的位置,因为 的维度是 维的,所以有 个位置可以对应,我们需要的就是其中的 个位置,那具体是哪 个位置,就是由 计算得到的。
同时,输入特征 再与转移矩阵 相乘得到 。结合刚才计算的 ,我们需要为 维的 中的每一个采样 个分量。
所以采样之后的 , 个head的 就是 。我们把它分成 组,每组的 。
接下来我们使用 中的每一行分别与每一组 做weighted sum,并把结果拼在一起,得到 的输出,最后作Linear Transformation把这些head的输出合并为一个输出 。
再以下面的图为例:先对计算 的输入特征 采样成 组,每组的采样后的输入特征为 。然后,对这 组的输入特征乘以transformation matrix得到 组,每组的 。所以 还是 。
计算 的过程与上方的图一致,最后得到的输出 。
我们再计算下 的计算复杂度,如上图3所示:
计算 的计算复杂度:
取二者的极小值。
计算 的计l算复杂度:
计算 的计算复杂度:
计算 的计算复杂度:
计算多个 变成一个 的 的计算复杂度:
计算 的计算复杂度: ,其中的 来自双线性插值。
所以一共是:
由于在实验中 ,所以 $5K+3MK<c$< p=""></c$<>
所以总的计算复杂度是:
当用在 里面时: ,最终复杂度是 。
当用在 里面时: ,最终复杂度是 ,与特征图大小 无关。
大多数目标检测框架受益于多尺度特征图,而 可以自然地扩展到多尺度特征图中。
令 为这 个feature map的特征, 。
为每个的参考点的归一化坐标, 和 分别代表左上和右下的点
可以写为:
式中, 代表head, 代表feature map的level。 代表每个的采样点。 和 表示 特征level和 head中 采样点的采样偏移和注意权重。且满足:
把归一化坐标重新缩放到 的输入特征图。
与 非常相似,只是它从多尺度特征图中采样 个点,而不是从单尺度特征图中采样 点,而且你会发现 也变小了。对于一个 ,每一层采集 个点作为 ,转换成,对一个 ,所有层均采 个点,融合了不同层的特征,故不再需要FPN。 这里正则化是针对一个 ,所有 个位置的贡献进行 。
特殊情况:
当 时, 的每个特征点只采样1次,并伴随一段位移,之后与 相乘求和,整个过程就退化为了 。而前者是查看来自多尺度输入的多个采样点。
所以有下图10的对应关系:
Deformable DETR的整体架构
如上图11所示为Deformable DETR的整体架构。
Deformable Transformer Encoder:
将transformer中处理特征的部分都做替换,即所有的attention模块都使用了 。Encoder的输入输出均为多尺度feature map,保持相同的分辨率。
首先使用上图12右侧所示的多尺度特征,一共是有4种scale的特征,表示为 ,并把它输入到Transformer的Encoder里面去,所有的特征都是256 channels。注意,Deformable DETR中没有使用FPN这样的自上而下的结构,因为多尺度可变形注意本身可以在多尺度特征图之间交换信息。对于每一个Query pixel,2-D坐标归一化的参考点就是它自己。
作者还给feature representation加了 ,随机初始化,并随网络一起训练。这里不使用FPN, 因为每一层的query聚合了所有层key的特征。
Deformable Transformer Decoder:
Decoder中有 和 这2种模块,它们的 都来自Object queries, 的 来自Encoder的输出,Object queries从encoder输出的feature map中提取特征。 的 来自Object queries,Object queries彼此交互, 与 相同。
因为 被设计用于处理卷积特征图作为关键元素,所以只把它用来替换掉 ,而不替换 。
对于 的每个Object queries,2-D坐标归一化的参考点 来自其Object queries经过一个 。
为了降低训练难度,作者令检测头 的输出是相对于参考点的偏移量。
比如参考点的坐标是 ,检测头 的输出是 ,则最终的检测结果可以表示为:
我们发现最终输出了一个检测框的归一化值 。
下图13为 和 的变量shape:
Deformable DETR后续改进1:Iterative Bounding Box Refinement
假设Transformer的Decoder Layer有6个,作者希望每个Decoder Layer都可以基于来自前一层的预测来refine 。
假设 Decoder Layer的输出为 ,要求 预测的结果是 ,这样,结合以上二者有:
初始值设置为:
对比式(4.5)可以发现差异:现在预测的不是相对于参考点的差值,而是相对于上一个Decoder Layer的输出的差值。
Deformable DETR后续改进2:Two-Stage Deformable DETR
在原版的DETR中,Object queries 与当前图片无关。作者设计了2阶段的Deformable DETR,第1个阶段输出一些 。
这些 作为Object Queries被送入 中,这是第2阶段,输出为Refinement之后的结果。
在第一阶段,为了使得 实现高召回率,多尺度特征图中的每个pixels都将作为Object Queries。但是这样做会导致计算和存储成本过大。
为了避免这个问题,作者删除了解码器,只使用 ,并将其输出作为 。其中,多尺度特征图中的每个Pixels都将作为Object Queries,它直接预测一个边界框。得分最高的边界框被选为 。在将 提交第2阶段之前,不应用NMS。
在得到第一阶段输出的特征以后,为了得到 ,就需要一个检测头(Detection head, 3层 ,输出回归坐标和1个二分类(前景/背景) )。
假设 代表Pixel的索引,即 ,参考点 ,则这个pixel对应的Bounding Box为:
式中 是由 的回归分支输出。
在第二阶段, 给定第一阶段中的 ,最高得分的被挑选作为 。在第二阶段,这些 作为被馈送到 ,作为Iterative Bounding Box Refinement的初始值。
Experiments:
Setting:默认 ,基本设置和DETR一致,Loss上面增加了Focal Loss,Object Queries的数量从100增加到了300。数据集COCO 2017。
作者对DETR-DC5也进行了上述变化,以进行公平比较,并命名为DETR-DC5+。
实验1:Deformable DETR结果
需要的epoches减少了10倍。 使用上iterative bounding box refinement 和 two-stage paradigm 之后可以进一步提升性能。 DETR-DC5的速度问题主要是由于Transformer attention的大量内存访问。可变形注意可以减轻这个问题。
实验2:在不同setting下Deformable DETR的一些对比实验
Multi-scale inputs:提升1.7% ,2.9% 。 Multi-scale attention:提升1.5% 。 FPN:无性能提升。 :退化为可变形卷积,性能大幅下降。
实验3:与SOTA对比
4.4 Deformable DETR代码解读:
代码来自:
fundamentalvision/Deformable-DETRgithub.com
对Deformable DETR代码与DETR的代码的实现做了一下对比:
Backbone、 Matcher 和 positional encoding 的实现和 DETR是一样的。
主要的修改在 deformable_detr.py 和 deformable_transformer.py 中,而这两个.py文件中调用的最重要的函数来自 models\ops\functions\ms_deform_attn_func.py 和 models\ops\modules\ms_deform_attn.py 这2个.py文件。
真正对性能提升有效的应该主要在于 ,毕竟 multi-scale / pyramid representation, deformable 和 attention 本身就是 CV 里最 work 的几类 idea。
下面详细介绍下 的实现。
的实现方法与 的比较相似,讲解的顺序还是:
Function → Module → deformable_transformer.py → deformable_detr.py
Function的定义很直接:
定义MSDeformAttnFunction这个函数。
class MSDeformAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = MSDA.ms_deform_attn_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
MSDA.ms_deform_attn_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
grad_value, grad_sampling_loc, grad_attn_weight分别为输入的偏导,对偏移量的偏导,对注意力权重的偏导。
定义Module:
类内变量定义:
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
这里值得注意的是sampling_offsets的维度是: ,因为偏移量有 和 这2个方向。
attention_weights的维度是: 。
value_proj就是 。
output_proj就是 。
参数初始化:
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
,attention的weights都使用xavier_uniform_初始化,bias都初始化为0。
sampling_offsets的weights初始化为0,bias使用sincos初始化。
前向传播:
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output
要给MSDeformAttnFunction这个Function传入:
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step
这几个变量,各个变量的定义如代码中注释所示。
value的维度是:(N, Len_in, self.n_heads, self.d_model // self.n_heads)。
偏移量的维度:(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)。
attention weight的维度:(N, Len_q, self.n_heads, self.n_levels, self.n_points)。
reference_points和sampling_offsets相加得到sampling_locations。
Backbone的定义没变:
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
super().__init__()
for name, parameter in backbone.named_parameters():
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
if return_interm_layers:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
self.strides = [8, 16, 32]
self.num_channels = [512, 1024, 2048]
else:
return_layers = {'layer4': "0"}
self.strides = [32]
self.num_channels = [2048]
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out
还是把Backbone和位置编码Sequential起来:
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
self.strides = backbone.strides
self.num_channels = backbone.num_channels
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in sorted(xs.items()):
out.append(x)
# position encoding
for x in out:
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
Transformer的一个Encoder Layer:
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
依旧遵循Attention,Add and Norm,FFN的流程。
注意需要传入的输入变量是:
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
有了一个Encoder Layer的定义,再看Transformer的整个Encoder:
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
output = src
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
return output
get_reference_points为获取所有的参考点:
Transformer的一个Decoder Layer:
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
# self attention
#query,key的输入是object queries(query_pos) + Decoder的输入(tgt),shape都是(300,b,256)
#value的输入是Decoder的输入(tgt),shape = (300,b,256) all 0
q = k = self.with_pos_embed(tgt, query_pos)
#q, k: (300,b,256)
#q,k除了tgt以外都有query_pos,只有v只有tgt。
#self.self_attn需要的维度是:(N, Length_{query}, C) = (b,300,256)
tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) # (300,b,256)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
# cross attention
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
reference_points,
src, src_spatial_shapes, level_start_index, src_padding_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
整体Decoder定义:
class DeformableTransformerDecoder(nn.Module):
def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
query_pos=None, src_padding_mask=None):
output = tgt
# reference_points (bs, \sum_{H_×W_}, 2)
# src_valid_ratios (bs, 4, 2)
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
# (bs, \sum_{H_×W_}, 1, 4) * (bs, 1, 4, 4)
reference_points_input = reference_points[:, :, None] \
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] #(bs, \sum_{H_×W_}, 2)
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[lid](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
return output, reference_points
5 Transformer+Classification:用于分类任务的Transformer(ICLR2021)
论文名称:An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale
论文地址:
https://arxiv.org/abs/2010.11929
5.1 ViT原理分析:
这个工作本着尽可能少修改的原则,将原版的Transformer开箱即用地迁移到分类任务上面。并且作者认为没有必要总是依赖于CNN,只用Transformer也能够在分类任务中表现很好,尤其是在使用大规模训练集的时候。同时,在大规模数据集上预训练好的模型,在迁移到中等数据集或小数据集的分类任务上以后,也能取得比CNN更优的性能。下面看具体的方法:
图片预处理:分块和降维
这个工作首先把的图像,变成一个 的sequence of flattened 2D patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 个展平的2D块,每个块的维度是 。其中 是块大小, 是channel数。
注意作者做这步变化的意图: 根据我们 之前的讲解,Transformer希望输入一个二维的矩阵 ,其中 是sequence的长度, 是sequence的每个向量的维度,常用256。
所以这里也要设法把 的三维图片转化成 的二维输入。
所以有: 。
其中, 是Transformer输入的sequence的长度。
代码是:
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
具体是采用了einops库实现,具体可以参考这篇博客。
科技猛兽:PyTorch 70.einops:优雅地操作张量维度
https://zhuanlan.zhihu.com/p/342675997
现在得到的向量维度是: ,要转化成 的二维输入,我们还需要做一步叫做Patch Embedding的步骤。
Patch Embedding
方法是对每个向量都做一个线性变换(即全连接层),压缩后的维度为 ,这里我们称其为 Patch Embedding。
这个全连接层就是上式(5.1)中的 ,它的输入维度大小是 ,输出维度大小是 。
# 将3072变成dim,假设是1024
注意这里的绿色字体 ,假设切成9个块,但是最终到Transfomer输入是10个向量,这是人为增加的一个向量。
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)
为什么要追加这个向量?
如果没有这个向量,假设 个向量输入Transformer Encoder,输出9个编码向量,然后呢?对于分类任务而言,我应该取哪个输出向量进行后续分类呢?
不知道。干脆就再来一个向量 ,这个向量是可学习的嵌入向量,它和那9个向量一并输入Transfomer Encoder,输出1+9个编码向量。然后就用第0个编码向量,即 的输出进行分类预测即可。
这么做的原因可以理解为:ViT其实只用到了Transformer的Encoder,而并没有用到Decoder,而 的作用有点类似于解码器中的 的作用,相对应的 就是其他9个编码向量的输出。
是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的 的类别。
代码为:
# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# forward前向代码
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块进行concat
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)Positional Encoding
按照Transformer的位置编码的习惯,这个工作也使用了位置编码。引入了一个 Positional encoding 来加入序列的位置信息,同样在这里也引入了pos_embedding,是用一个可训练的变量。
没有采用原版Transformer的 编码,而是直接设置为可学习的Positional Encoding,效果差不多。对训练好的pos_embedding进行可视化,如下图所示。我们发现,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构;同一行/列中的patch具有相似的位置编码。
# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
Transformer Encoder的前向过程
其中,第1个式子为上面讲到的Patch Embedding和Positional Encoding的过程。
第2个式子为Transformer Encoder的 的过程,重复 次。
第3个式子为Transformer Encoder的 的过程,重复 次。
作者采用的是没有任何改动的transformer。
最后是一个 的 ,整个的结构只有这些,如下图所示,为了方便读者的理解,我把变量的维度变化过程标注在了图中。
训练方法:
先在大数据集上预训练,再迁移到小数据集上面。做法是把ViT的 去掉,换成一个 的 。其中 为对应数据集的类别数。
当输入的图片是更大的shape时,patch size 保持不变,则 会增大。
ViT可以处理任意 的输入,但是Positional Encoding是按照预训练的输入图片的尺寸设计的,所以输入图片变大之后,Positional Encoding需要根据它们在原始图像中的位置做2D插值。
最后,展示下ViT的动态过程:
ViT的动态过程
https://www.zhihu.com/zvideo/1336723494219350016
Experiments:
预训练模型使用到的数据集有:
ILSVRC-2012 ImageNet dataset:1000 classes ImageNet-21k:21k classes JFT:18k High Resolution Images
将预训练迁移到的数据集有:
CIFAR-10/100 Oxford-IIIT Pets Oxford Flowers-102 VTAB
作者设计了3种不同答小的ViT模型,它们分别是:
DModel | Layers | Hidden size | MLP size | Heads | Params |
---|---|---|---|---|---|
ViT-Base | 12 | 768 | 3072 | 12 | 86M |
ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
ViT-L/16代表ViT-Large + 16 patch size
评价指标 :
结果都是下游数据集上经过finetune之后的Accuracy,记录的是在各自数据集上finetune后的性能。
实验1:性能对比
实验结果如下图所示,整体模型还是挺大的,而经过大数据集的预训练后,性能也超过了当前CNN的一些SOTA结果。对比的CNN模型主要是:
2020年ECCV的Big Transfer (BiT)模型,它使用大的ResNet进行有监督转移学习。
2020年CVPR的Noisy Student模型,这是一个在ImageNet和JFT300M上使用半监督学习进行训练的大型高效网络,去掉了标签。
All models were trained on TPUv3 hardware。
在JFT-300M上预先训练的较小的ViT-L/16模型在所有任务上都优于BiT-L(在同一数据集上预先训练的),同时训练所需的计算资源要少得多。 更大的模型ViT-H/14进一步提高了性能,特别是在更具挑战性的数据集上——ImageNet, CIFAR-100和VTAB数据集。 与现有技术相比,该模型预训练所需的计算量仍然要少得多。
下图为VTAB数据集在Natural, Specialized, 和Structured子任务与CNN模型相比的性能,ViT模型仍然可以取得最优。
实验2:ViT对预训练数据的要求
ViT对于预训练数据的规模要求到底有多苛刻?
作者分别在下面这几个数据集上进行预训练:ImageNet, ImageNet-21k, 和JFT-300M。
结果如下图所示:
我们发现: 当在最小数据集ImageNet上进行预训练时,尽管进行了大量的正则化等操作,但ViT-大模型的性能不如ViT-Base模型。
但是有了稍微大一点的ImageNet-21k预训练,它们的表现也差不多。
只有到了JFT 300M,我们才能看到更大的ViT模型全部优势。 图3还显示了不同大小的BiT模型跨越的性能区域。BiT CNNs在ImageNet上的表现优于ViT(尽管进行了正则化优化),但在更大的数据集上,ViT超过了所有的模型,取得了SOTA。
作者还进行了一个实验: 在9M、30M和90M的随机子集以及完整的JFT300M数据集上训练模型,结果如下图所示。 ViT在较小数据集上的计算成本比ResNet高, ViT-B/32比ResNet50稍快;它在9M子集上表现更差, 但在90M+子集上表现更好。ResNet152x2和ViT-L/16也是如此。这个结果强化了一种直觉,即:
残差对于较小的数据集是有用的,但是对于较大的数据集,像attention一样学习相关性就足够了,甚至是更好的选择。
实验3:ViT的注意力机制Attention
作者还给了注意力观察得到的图片块, Self-attention使得ViT能够整合整个图像中的信息,甚至是最底层的信息。作者欲探究网络在多大程度上利用了这种能力。
具体来说,我们根据注意力权重计算图像空间中整合信息的平均距离,如下图所示。
注意这里我们只使用了attention,而没有使用CNN,所以这里的attention distance相当于CNN的receptive field的大小。作者发现:在最底层, 有些head也已经注意到了图像的大部分,说明模型已经可以globally地整合信息了,说明它们负责global信息的整合。其他的head 只注意到图像的一小部分,说明它们负责local信息的整合。Attention Distance随深度的增加而增加。
整合局部信息的attention head在混合模型(有CNN存在)时,效果并不好,说明它可能与CNN的底层卷积有着类似的功能。
作者给出了attention的可视化,注意到了适合分类的位置:
5.2 ViT代码解读:
代码来自:
https://github.com/google-research/vision_transformergithub.com
首先是介绍使用方法:
安装:
pip install vit-pytorch
使用:
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
传入参数的意义:
image_size:输入图片大小。
patch_size:论文中 patch size:的大小。
num_classes:数据集类别数。
dim:Transformer的隐变量的维度。
depth:Transformer的Encoder,Decoder的Layer数。
heads:Multi-head Attention layer的head数。
mlp_dim:MLP层的hidden dim。
dropout:Dropout rate。
emb_dropout:Embedding dropout rate。
定义残差, 等:
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
Attention和Transformer,注释已标注在代码中:
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
# b, 65, 1024, heads = 8
b, n, _, h = *x.shape, self.heads
# self.to_qkv(x): b, 65, 64*8*3
# qkv: b, 65, 64*8
qkv = self.to_qkv(x).chunk(3, dim = -1)
# b, 65, 64, 8
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
# dots:b, 65, 64, 64
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, mask_value)
del mask
# attn:b, 65, 64, 64
attn = dots.softmax(dim=-1)
# 使用einsum表示矩阵乘法:
# out:b, 65, 64, 8
out = torch.einsum('bhij,bhjd->bhid', attn, v)
# out:b, 64, 65*8
out = rearrange(out, 'b h n d -> b n (h d)')
# out:b, 64, 1024
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
ViT整体:
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img, mask = None):
p = self.patch_size
# 图片分块
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
# 降维(b,N,d)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
# 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
# Positional Encoding:(b,N+1,d)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# Transformer的输入维度x的shape是:(b,N+1,d)
x = self.transformer(x, mask)
# (b,1,d)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
# (b,1,num_class)
6 Transformer+Image Processing:IPT:用于底层视觉任务的Transformer
论文名称:Pre-Trained Image Processing Transformer
论文地址:
Pre-Trained Image Processing Transformer
https://arxiv.org/abs/2012.00364
6.1 IPT原理分析:
这个工作第1个提出将Transformer应用于底层的视觉任务(Low-Level Computer Vision Task),具体任务是:去噪,去雨和超分。作者想构建一个基于Transformer的预训练模型,利用上它强大的表征能力,以Transformer为核心,配上不同的 和 ,以完成相对应的底层视觉任务。
但要想得到这样适合底层任务的预训练模型,需要克服2个困难:
特定任务的数据集很有限,尤其是涉及到付费数据或者隐私数据时。相机参数,光照,天气等多种因素会进一步影响数据的分布。 在测试图像出来之前,不知道将请求哪种类型的图像处理任务。因此,必须准备一系列的不同任务的图像处理模块,其中一些基本的操作可以共享。
这个工作所提出的IPT就是一个基于Transformer的,适用于3种Low-Level Computer Vision Task的预训练模型。要使模型同时与这3种任务兼容,就需要3个 和 ,加上一个共享的 。
数据集构建:
训练数据集:
基于ImageNet, 使用几个基于特定任务的不同的操作来生成多个损坏的副本,以服务于不同的任务。原始的ImageNet具有 1M 数据,把它们crop成 的patches。为了得到损坏的副本,作者进行了6种不同的退化(Degradation)操作。
对于超分任务,就对原始ImageNet图片分别进行 下采样。
对于去噪任务,就对原始ImageNet图片分别添加30, 50 noise level Gaussian noise。
对于去雨任务,就对原始ImageNet图片添加雨线。
这样一来,用来训练IPT的整个数据集包含大约1000多万幅图像。
测试数据集:
同样把特定任务对应的测试集crop成 的patches,overlap 10个pixels。为了公平地比较,CNN模型也进行了这样的操作。
具体方法:
IPT的整体结构可以分为4个部分:
:从corrupted images中提取特征。由3层卷积组成,得到 ,这一步可以表达为: ,其中 表达第几个任务。
:主要结构,用于恢复丢失的数据。
:最后映射特征。
然后把的特征,变成一个 的sequence of flattened 2D patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 个展平的2D块,每个块的维度是 。其中 是块大小, 是channel数。
注意作者做这步变化的意图: 根据我们 之前的讲解,Transformer希望输入一个二维的矩阵 ,其中 是sequence的长度, 是sequence的每个向量的维度,常用256。
所以有: 。
其中, 是Transformer输入的sequence的长度。
现在得到的向量维度是: 。
Positional Encoding
按照Transformer的位置编码的习惯,这个工作也使用了位置编码。引入了一个 Positional encoding 来加入序列的位置信息,同样在这里也引入了pos_embedding,是用一个可训练的变量。给每一个输入的 (一共 个)都进行了位置编码。
没有采用原版Transformer的 编码,而是直接设置为可学习的Positional Encoding,效果差不多。
Transformer Encoder的前向过程
其中,第1个式子为上面讲到的Positional Encoding的过程。
第2个式子为Transformer Encoder的 的过程。
第3个式子为Transformer Encoder的 的过程,整个过程重复 次。
作者采用的是没有任何改动的transformer Encoder。
最后是一个 的 ,整个的结构只有这些,如下图所示,为了方便读者的理解,我把变量的维度变化过程标注在了图中。
现在得到的向量维度是: 。
Transformer Decoder的前向过程
作者采用的是没有任何改动的transformer Decoder,包含2个 和1个 。
不同之处是:为了区分不同的任务,作者使用了基于特定任务的嵌入向量 Task-specific Embeddings ,这里的 代表3种不同的任务。它们都是可学习的向量,目的是学习到不同任务的不同表达。
整个Decoder的前向过程可以表达为:
其中,第1个式子为上面讲到的Positional Encoding的输出编码特征。
第2个式子为第1个 的 的计算过程,注意Task-specific Embeddings 的添加位置为 和 。
这里不同于DETR,IPT的第1个 的 就已经来自于Encoder的输出编码了。
第3个式子为Decoder的 的过程。
第4个式子为第2个 的 的计算过程,注意Task-specific Embeddings 的添加位置为 。
第5个式子为 Decoder的第2个 的过程。
第6个式子为Transformer Decoder的 的过程,整个过程重复 次。
现在得到的向量维度是: 。再将其reshape成大小为 的输出特征。
最后, 将输出的 映射为 的图片。
整个运行过程如下图所示,为了方便理解原理,我把变量的维度标在了图中。
成功训练一个优秀的IPT模型的关键因素之一是对大规模数据集的使用。与图像分类相比,用于图像处理任务的可用数据数量相对较少(例如,用于图像超分辨率任务的DIV2K数据集上只有2000幅图像),作者建议利用众所周知的ImageNet作为基线数据集来预训练我们的IPT模型。
之前也提过,作者为了得到大量的训练集,进行了不同的退化(Degradation)操作:
:Bicubic Interpolation。
,式中: 是高斯噪声。
,式中: 是人工添加的雨线。
目标函数:
这个式子的意思是所提出的框架是用多个图像处理任务同时训练的。
对于每一个batch的数据, 作者从 个supervised任务中随机选择一个任务进行训练,每个任务将同时使用相应的头、尾和任务嵌入进行处理。
在对IPT模型进行预训练之后,在各个任务对应的小数据集上进行finetune。
除此之外,在设计目标函数时,作者还结合了对比学习的方法:
以图片 作为输入,定义Transformer Decoder输出的Patches Features为 ,式中 。
作者的目标是最小化来自相同图像的Patches Features之间的距离,同时最大化来自不同图像的Patches Features之间的距离。对比学习的损失函数公式如下:
式中, 代表cosine similarity。
这样,总的目标函数可以归结为:
Experiments:
研究表明,在使用大规模数据集解决图像处理问题时,基于Transformer的模型比基于CNN的模型具有更好的性能。
数据集的构建:
如上面的介绍。
训练方法:
由于训练集里面包含了不同任务的数据,在每个iteration中,先随机选取一个任务,再从对应的数据集中抽1个batch的数据。 在对整个合成数据集进行预训练后,作者根据所需任务(例如,3倍单图像超分辨率)对IPT模型进行了30个epochs的finetune。
实验1:超分
比较的模型是近期的CNN模型,比较的数据集有Set5, Set14, BSD100, Urban100,超分类型 ,结果如下,IPT取得了SOTA。
Urban100中的一些可视化结果:
实验2:去噪
去噪实验结果的对比如下图所示,训练和测试数据是通过将 的高斯噪声加到干净的图像上生成的。下图显示了在数据集BSD68和Urban100上的结果,结果显示:无论采用哪种类型的噪声,在这2个数据集上面IPT都能取得SOTA的效果。
下图显示了去噪实验的可视化结果,IPT结果的视觉质量明显优于所有以前的模型。
实验3:去雨
对于去雨任务,作者在合成的Rain100L 数据集上评估模型,它是由BSD200数据集中选取一些图片添加雨线得到的,一共100张。 与最先进的方法相比,IPT获得了最好的性能(41.62dB),提高了1.62dB,结果如下图所示:
下图显示了去雨实验的可视化结果,IPT结果的视觉质量明显优于所有以前的模型。
实验4:泛化能力
一个好的预训练的模型应该有能力很好地适应其他任务,所以作者测试了噪声级别分别为10和70的噪声, 使用图像去噪任务的 和 作为预训练模型。结果如下图所示:
IPT模型优于其他传统方法,这表明预训练模型可以从大规模数据集捕获更多有用的信息和特征。
实验5:数据量的影响
作者也尝试训练了基于CNN的预训练模型,使用的模型是EDSR,在于IPT相同的超分数据集(ImageNet下采样数据集)进行预训练,并控制数据集的量分别为20%,40%,60%,80%,100%,结果如下图所示:
结果显示,在小数据集上进行预训练(60% ImageNet)时,CNN预训练模型性能优于基于Transformer的模型,而数据量变大时,基于Transformer的模型取得了更优的性能。