AIM2020-ESR冠军方案解读:引入注意力模块ESA,实现高效轻量的超分网络(附代码实现)

作者丨Happy
编辑丨极市平台

极市导读

该文是南京大学提出的一种轻量&高效图像超分网络,它获得了AIM20-ESR竞赛的冠军。它在IMDN的基础上提出了两点改进,并引入RFANet的一种ESA注意力模块。如果从结果出发来看RFDN看上去很简单,但每一步的改进却能看到内在的一些思考与尝试。值得初入图像复原领域的同学仔细研究一下该文。

paper: https://arxiv.org/abs/2009.11551
code:  https://github.com/njulj/RFDN(预训练模型未开源)

Abstract

受益于CNN强大的拟合能力,图像超分取得了极大的进展。尽管基于CNN的方法取得了极好的性能与视觉效果,但是这些模型难以部署到端侧设备(高计算量、高内存占用)。
为解决上述问题,已有各种不同的快速而轻量型的CNN模型提出,IDN(Information Distillation Network, IDN)是其中的佼佼者,它采用通道分离的方式提取蒸馏特征。然而,我们并不是很清晰的知道这个操作是如何有益于高效SISR的。
该文提出一种等价于通道分离操作的特征蒸馏连接操作(Feature Distillation Connection, FDC),它更轻量且更灵活。基于FDC,作者对IMDN(Information Multi Distillation Network, IMDN)进行了重思考并提出了一种称之为RFDN(Residual Feature Distillation Network, RFDN)的轻量型图像超分模型,RFDN采用多个FDC学习更具判别能力的特征。与此同时,作者还提出一种浅层残差模块SRB(Shallow Residual Block, SRB)作为RFDB的构件模块,SRB即可得益于残差学习,又足够轻量。
最后作者通过实验表明:所提方法在性能与模型复杂度方面取得了更多的均衡。更进一步,增强型的RFDN(Enhanced RFDN, E-RFDN)获得了AIM2020 Efficient Super Resolution竞赛的冠军。
该文的主要贡献包含以下几点:
  • 提出一种轻量型残差特征蒸馏网络用于图像超分并取得了SOTA性能,同时具有更少的参数量;

  • 系统的分析了IDM并对IMDN进行重思考,基于思考发现提出了FDC;

  • 提出了浅层残差模块,它无需引入额外参数即可提升超分性能。

Method

上图a给出了IMDN的核心模块IMDB的网路架构图,它是一种渐进式模块(Progressive Refinement Module),PRM部分(图中灰色背景区域)采用卷积从输入特征通过多个蒸馏步骤提取特征。在每个步骤,采用通道分离操作将特征分成两部分:一部分保留,一个融入到下一阶段的蒸馏步骤。假设输入特征表示为,该过程可以描述为:
其中表示第j个卷积模块(包含激活单元),表示第j个通道分离操作。最后所有的蒸馏特征通过Concat进行融合得到输出:

Rethinking the IMDB

尽管PRM获得显著的提升,但不够高效且因为通道分离操作引入了某些不灵活性。通过卷积生成的特征存在许多冗余参数;而且特征蒸馏是通过通道分离达成,导致其无法有效利用恒等连接。作者对通道分离操作进行了重思考并提出了一种新的等价架构以避免上述问题。
以Fig2b为例,卷积后接通道分离可以拆解成两个卷积DL和RL,此时改进的架构可以描述如下:
也就是说每一次的通道分离操作可以视作两个卷积的协同作用,我们将这种改进称之为IMDB-R,它比IMDB更为灵活,且具有更好的解释性。

Residual Feature Distillation Block

基于前述思考,作者引入该文的核心RFDB(见Fig2c),一种比IMDB更轻量更强的模块。从Fig2可以看到:信息蒸馏操作是通过卷积以一定比例压缩特征通道实现。在诸多CNN模型中,采用卷积进行进行通道降维更为高效,故得到Fig2c中的卷积设计,采用这种替换还可以进一步降低参数量。
除了前面提到的改进外,作者还引入一种更细粒度的残差学习到RFDB。作者设计了一种浅层残差学习模块SRB(见Fig2d),它仅包含一个卷积核一个恒等连接分支以及一个激活单元。SRB在不引入额外参数的前提下,还可以从残差学习中受益。
原始的IMDN仅仅包含一个粗粒度的残差连接,网络从残差连接的受益比较有限;而SRB可以耕细粒度的残差连接,可以更好的利用参数学习的能力。

Framework

上图给出了RFDN的全局网络架构图,很明显这是一种RDN的网络结构。它包含四个部分:特征提取、堆叠RFDB,特征融合,重建上采样。
特征提取目前基本都是采用卷积提取初始特征,该过程可以描述如下:
而堆叠RFDB则是以链式方式逐渐更新提取特征,该过程可以描述如下:
再完成逐级特征计算后,最后通过卷积对前述所有中间特征进行集成融合,该过程可以描述:
最后,超分图像通过重建模块生成:
注:R一般采用卷积和PixelShuffle操作组合。在损失函数方面,RFDN采用了损失。

Experiments

训练数据:DIV2K;测试数据:Set5、Set14、BSD100、Urban100、Manga109。度量指标PSNR、SSIM。
优化器Adam,初始学习率,每200000次迭代折半,Batch=64,随机水平、随机90度旋转。x2模型从头开始训练,其他尺度模型则以x2模型参数进行初始化。
作者实现了两个尺寸的模型:RFDN和RFDN-L。RFDN的通道数为48,模块数为6;而RFDN-L的通道数为52,模块数为6。先来看所提方法与其他SOTA方法的对比,见下表。
接下来,我们看一下消融实验部分的一些对比。为更好的说明RFDB的优势,作者设计了三组对标模块,见下图。其中FDC则是前述在IMDB之处的改进的一个版本。
下表给出了上述四个模块构成的模型的性能与参数量对比。可以看到:(1)相比标准卷积,SRB可以提升0.12dB模型性能,且不会引入额外参数;(2)在FDC与RFDB的对比中也可以看到类似的性能提升;(3)FDC模块可以提升0.19dB模型性能;(4)FDC与SRB的组合得到了0.27dB的性能提升。
与此同时,作者探讨了蒸馏比例的影响,见下表。总而言之:从参数量与模型性能角度来看,蒸馏比例为0.5是一种比较的均衡。而这也是RFDN中采用的蒸馏比例,笔者在复现RFDN的过程中也曾疑惑过这个参数的设置,因为按照IMDN中的0.25设置的话无论如何都得不到竞赛中的那个参数量、FLOPs。
最后,作者在RFDN的基础上进行了又一次改进,引入了ESA模块,称之为Enhanced RFDN(E-RFDN)。该模型获得了AIM2020 Efficient Super Resolution竞赛的冠军,见下表。需要注意的是:E-RFDN训练数据集为DF2K,模块数为4。从表中数据可以看到:所提方法以较大的优势超越其他参赛方案。
全文到此结束,对该文感兴趣的同学建议去查看原文。

后记

事实上,在这篇论文放出之前,笔者已经在尝试进行RFDN的复现工作。当然实现方面有一点点的出入,见下面的笔者实现code,注:ESA模块是源自作者RFANet一文的代码。笔者参照E-RFDN的网络结构进行的复现,在DIV2K-val上训练200000次迭代达到了PSNR:30.47dB(YCbCr)。也就是说,这个方法看起来简单,复现起来也非常简单,关键是轻量&高效。为什么不去尝试一把呢?
class RFDB(nn.Module): """ A little difference with the official code. """ def __init__(self, in_channels, distillation_rate=0.25): super(RFDB, self).__init__() distilled_channels = int(in_channels * distillation_rate) remaining_channels = in_channels - distilled_channels * 3
self.d1 = nn.Conv2d(in_channels, distilled_channels, 1) self.c1 = SRB(in_channels)
self.d2 = nn.Conv2d(in_channels, distilled_channels, 1) self.c2 = SRB(in_channels)
self.d3 = nn.Conv2d(in_channels, distilled_channels, 1) self.c3 = SRB(in_channels)
self.d4 = nn.Conv2d(in_channels, remaining_channels, 3, 1, 1)
self.act = nn.LeakyReLU(negative_slope=0.05, inplace=True) self.fusion = nn.Conv2d(in_channels, in_channels, 1) self.esa = ESA(in_channels)
def forward(self, inputs): distilled_c1 = self.act(self.d1(inputs)) remaining_c1 = self.act(self.c1(inputs))
distilled_c2 = self.act(self.d2(remaining_c1)) remaining_c2 = self.act(self.c2(remaining_c1))
distilled_c3 = self.act(self.d3(remaining_c2)) remaining_c3 = self.act(self.c3(remaining_c2))
distilled_c4 = self.act(self.d4(remaining_c3))
out = torch.cat([distilled_c1, distilled_c2, distilled_c3, distilled_c4], dim=1) out_fused = self.esa(self.fusion(out)) + inputs return out_fused
◎作者档案
Happy,一个爱“胡思乱想”的AI行者
个人公众号:AIWalker
欢迎大家联系极市小编(微信ID:fengcall19)加入极市原创作者行列
(0)

相关推荐