字节跳动提出面向GAN压缩的在线多粒度蒸馏算法,算力降至1/46

机器之心专栏

字节跳动-智能创作团队

字节跳动 - 智能创作团队提出了一种用于学习轻量级 GAN 的在线多粒度蒸馏算法 OMGD。该算法能够把 GAN 模型的计算量减少到最低 1/46、参数量减少到最低 1/82 的程度,并保持原来的图像生成质量。

近年来,生成对抗网络(GAN)在图像生成、图像翻译等多种视觉应用中取得了显著成果。尽管 GAN 模型给图像生成带来了不同程度的提升,但大部分模型的部署都涉及巨大的计算资源和内存消耗。这成为在资源受限的移动设备或其他轻量级物联网设备上部署 GAN 的一个关键瓶颈。

GAN 压缩方向已经成为业界的挑战之一,不少高校和科技公司对此投入研究力量。但当前的 GAN 压缩算法主要存在两个方面的问题:一方面,当前研究倾向于直接采用成熟的模型压缩技术来进行压缩,而这些技术不是面向 GAN 定制的,缺乏对 GAN 复杂特性和结构的探索;另一方面,GAN 压缩通常被规划为一个多阶段的任务,多阶段设置中对时间和计算资源的要求较高。

为了解决上述问题,字节跳动 - 智能创作团队提出了一种面向 GAN 压缩的在线多粒度蒸馏算法(Online Multi-Granularity Distillation,简称 OMGD)。该算法能够把 GAN 模型的计算量减少到最低 1/46、参数量减少到最低 1/82 的程度,并保持原来的图像生成质量。这为在资源受限的设备上部署实时图像翻译的 GAN 模型提供了一个可行的解决方案。

OMGD 研究论文已入选 ICCV2021,相关代码也已开源。

论文链接:https://arxiv.org/abs/2108.06908

GitHub 链接:https://github.com/bytedance/OMGD

引言

论文提出了一种新的在线多粒度蒸馏(OMGD)方案来获得轻量级的 GAN,以较低的计算成本生成高保真图像。OMGD 放弃了复杂的多级压缩过程,设计了一种面向 GAN 的在线蒸馏策略,可以一步获得压缩模型。OMGD 还从多个层次和粒度挖掘潜在的图像信息,以帮助优化压缩模型。这些概念可以看作是辅助监督线索,这对于突破低计算成本模型的容量瓶颈至关重要。

方法

1. 在线蒸馏

论文提出了一种面向 GAN 的在线蒸馏算法来解决离线蒸馏中的三个关键问题:

第一,传统离线蒸馏方法中的学生生成器应保持一定的容量,以保持与鉴别器的动态平衡,避免模型崩溃和消失梯度。然而,OMGD 的学生生成器仅利用教师网络的输出信息来进行优化,并且在无判别器的设定中进行训练。学生生成器不再与鉴别器紧密绑定,它可以更灵活地训练并获得进一步的压缩。具体来说,在每个迭代步骤中反向传播教师生成器和学生生成器之间的蒸馏损失。这样学生生成器可以模仿教师生成器的训练过程以逐步学习。其中蒸馏的损失函数由结构相似化损失函数和感知损失函数构成。

第二,预先训练的教师生成器无法引导学生逐步学习信息,并且容易导致在训练阶段过度拟合。而 OMGD 的教师生成器有助于渐进地引导学生生成器的优化方向。

第三,对于 GAN 任务来说,评估指标是主观的。因此选择合适的预训练的教师生成器并非易事。而在线策略不需要一个预先训练好的教师生成器,可以避免这个问题;同时 OMGD 的学生生成器在优化过程中不需要使用真实标签,而仅学习教师生成器的输出,这大大降低了直接拟合真实标签的难度。

2. 多粒度蒸馏

OMGD 进一步从两个角度将在线蒸馏策略扩展为多粒度方案。

一方面,其采用基于不同结构的教师生成器来捕获更多的互补的信息,从更多样化的维度提高视觉逼真度。具体来说,从深度和宽度两个互补维度将学生模型扩展为教师模型。给定一个学生生成器,通过扩展学生生成器的通道来获得更宽的教师生成器。接着在学生生成器每个下采样层和上采样层之后插入几个 Resnet block 构建一个更深的教师生成器。该研究直接将互补教师生成器的两个蒸馏损失合并为多教师设置中的知识蒸馏损失。

另一方面,除了输出层的概念外,OMGD 还将中间通道的粒度信息作为辅助的监督信号进行蒸馏优化。具体地说,OMGD 计算在通道维度上的注意力权重来衡量特征图中每个通道的重要性,并将注意力信息迁移到学生模型中。通道蒸馏损失的定义如下所示:

综上所述,整个在线多粒度蒸馏的损失函数定义:

实验

在 Pix2Pix 的结果表明,OMGD 对 U-Net 和 Resnet 两种类型的 GAN 都显著优于当前最先进的方法,且计算成本减少许多。

在 CycleGAN 的结果表明,OMGD 以更小的计算成本获得了更好的性能表现。

可视化结果表示,OMGD 能够在压缩过程中保持图像的细节信息。

总结

论文提出了一种用于学习轻量级 GAN 的在线多粒度蒸馏算法 OMGD。大量实验表明,OMGD 能够将 Pix2Pix 和 CycleGAN 压缩到极低的计算成本,而不会造成明显的视觉保真度损失,这为在资源受限的设备上部署实时 GAN 提供了一个可行的解决方案。

(0)

相关推荐