PULSE:一种基于隐式空间的图像超分辨率算法

分享一篇 CVPR 2020 录用论文:PULSE: Self-Supervised Photo Upsampling via Latent Space Exploration of Generative Models,作者提出了一种新的图像超分辨率方法,区别于有监督的PSNR-based和GANs-based方法,该方法是一种无监督的方法,即只需要低分辨率的图片就可以恢复高质量、高分辨率的图片。

目前代码已经开源:

https://github.com/adamian98/pulse

论文信息:

作者均来自于杜克大学。

1. Motivation

图像超分辨率任务的基本目标就是把一张低分辨率的图像超分成其对应的高分辨率图像。无论是基于PNSR还是GAN的监督学习方法,或多或少都会用到pixel-wise误差损失函数,而这往往会导致生成的图像比较平滑,一些细节效果不是很好。于是作者换了一个思路:**以往的方法都是从LR,逐渐恢复和生成HR;如果能找到一个高分辨率图像HR的Manifold,并从该Manifold中搜寻到一张高分辨率的图像使其下采样能恢复到LR,那么搜寻到的那张图像就是LR超分辨率后的结果。**所以本篇文章主要解决了以下的两个问题:

  • 如何找到一个高分辨率图像的Manifold?
  • 如何在高分辨率图像的Manifold上搜寻到一张图片使其下采样能恢复LR?

2. Method

假设高分辨率图像的Manifold是,是M上的一个高分辨率图片,给定一个低分辨率图像,如果可以通过下采样操作DS恢复LR,那么就可以认为是LR的超分辨率结果,该问题定义如下:

即当两者的差值小于某个阈值时。令,那么本文任务其实就是找到一个如下图所示:

所以本文一个重要的损失就是下采样损失Downscaling loss(p=2)

于是,如何构建这样一个高分辨率图像的Manifold()成为了问题的关键。假如我们有一个参数可微的Manifold(),那么我们就可以利用Downscapling loss去指导这个搜寻过程。而且得到的结果也是在Manifold()上的,结果肯定是高分辨率的,且通过下采样之后也能恢复LR图片。

所以作者想到了用一个带latent space的生成模型(e.g. VAE, GANs)来近似这个Manifold()。这些方法可以通过latent code去生成一张图片,通过引入一个pre-trained的生成模型来近似Manifold(),且latent code是可微分的,可以利用Downscapling loss去指导搜寻。所以问题又转化为找到一个生成模型G,令latent space为,那么对于一个latent code 有:

但是,简单的要求并不能保证,所以大多数的生成模型需要对施加先验,让在基于此先验信息的区域有更大的概率被取到。一个简单的方法就是基于此先验添加一个负对数似然损失项。假如使用的Gaussian先验,那么就可以采用L2正则项,但是L2正则会使latent code vector趋近于0,这显然不能满足要求。但是高维高斯分布有个性质是大部分质量都分布在半径为的球面上,d是高斯先验的维度。于是就可以将高维高斯分布转化为球面均匀分布。于是问题就简化为在一个球面空间中执行梯度下降,而不是在整个latent space。

以上就是本篇文章的核心内容,下面我们结合代码来看一下具体是怎么实现的。

首先我们需要一个生成模型来近似高分辨率的Manifold,在本文中,作者采用的是StyleGAN的预训练模型:

StyleGAN的生成器网络中有两个部分,一个是Mapping Network用于将latent code映射为style code,一个Synthesis Network用于将映射后得到的style code用于指导图像的生成。这里需要注意的是,本篇文章只是使用了StyleGAN的预训练模型,并不训练更新其参数。加载两个部分的参数之后,随机构造100000个随机latent code,经过Mapping Network,用得到新的latent code计算均值与方差:

latent = torch.randn((1000000,512),dtype=torch.float32, device="cuda")
latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}

这个均值与方差就可以用来映射新的latent code。接下就是随机初始化latent code和noise(StyleGAN需要):

# 初始化latent code
latent = torch.randn((batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')
# 初始化noise
for i in range(18): # [?, 1, 4, 4] -> [?, 1, 1024, 1024]
    res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))
    new_noise = torch.randn(res, dtype=torch.float, device='cuda')
    if (i < num_trainable_noise_layers):  # num_trainable_noise_layers
        new_noise.requires_grad = True
    noise_vars.append(new_noise)
                   
noise.append(new_noise)

**从这里我们可以看出,模型优化的其实是latent code与noise的前5层,而不是模型参数。**初始化完成了之后就可以执行前向了:

# 根据之前的求得的均值和方差,映射latent code
latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])
# 加载Synthesis Network用于生产图像
# 把图像结果从[-1, 1]修改到[0, 1]
gen_im = (self.synthesis(latent_in, noise)+1)/2

根据原始的低分辨率图像和生成的高分辨率图像计算loss。在代码中,loss由两个部分组成:

其中L2损失是将生成的高分辨率图像gen_im通过bicubic下采样恢复LR,并与输入的LR计算pixel-wise误差,GEOCROSS是测地线距离。


最后优化器选择的是球面优化器:

# opt = SphericalOptimizer(torch.optim.Adam, [x], lr=0.01)

class SphericalOptimizer(Optimizer):
    def __init__(self, optimizer, params, **kwargs):
        self.opt = optimizer(params, **kwargs)
        self.params = params
        with torch.no_grad():
            self.radii = {param: (param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt() for param in params}

@torch.no_grad()
    def step(self, closure=None):
        loss = self.opt.step(closure)
        for param in self.params:
            param.data.div_((param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt())
            param.mul_(self.radii[param])

return loss

3. Result

从结果可以看出,PULSE生成的图像细节更丰富,包括头发丝、眼睛和牙齿这些比较细微的地方都能生成的很好。而且对于有噪声的LR,也能生成的很好,说明该算法有很强的鲁棒性:

最终的比较指标采用的是MOS:

4. Questions

PULSE是一个无监督的图像超分辨率模型,其图像的质量其实很大程度上取决于所选取的生成模型的好坏。另一方面,由于PULSE的基础原理就是找到一个高分辨率的图像,使其下采样之后能恢复LR,那么意味着结果不唯一,可能生成的图像很清楚,但是已经失去了身份信息:

5. Resource

  • Paper

    PULSE:https://arxiv.org/pdf/2003.03808.pdf

    StyleGAN: https://arxiv.org/abs/1812.04948

    Random Vectors in High Dimen- sions: https://www.sci-hub.ren/10.1017/9781108231596.006

  • Github: https://github.com/adamian98/pulse.git

备注:超分辨率

超分辨率交流群

图像视频超分辨率,可见光、红外、遥感超分辨率等技术,

若已为CV君其他账号好友请直接私信。

我爱计算机视觉

微信号:aicvml

QQ群:805388940

微博知乎:@我爱计算机视觉

投稿:amos@52cv.net

网站:www.52cv.net

在看,让更多人看到

(0)

相关推荐