【百战GAN】如何使用GAN拯救你的低分辨率老照片

大家好,欢迎来到专栏《百战GAN》,在这个专栏里,我们会进行算法的核心思想讲解,代码的详解,模型的训练和测试等内容。

作者&编辑 | 言有三

本文资源与生成结果展示

本文篇幅:5200字

背景要求:会使用Python和Pytorch

附带资料:项目推荐,版本包括Pytorch+Tensorflow

同步平台:有三AI知识星球(一周内)

1 项目背景

了解GAN的同学都知道,GAN擅长于捕捉概率分布,因此非常适合图像生成类任务。我们在图片视频拍摄以及传输过程中,经常会进行图像的压缩,导致图像分辨率过低,另外早些年的设备拍摄出来的照片也存在分辨率过低的问题,比如10年前的320*240分辨率。要解决此问题,需要使用到图像超分辩技术。

本次我们使用GAN来完成图像超分辩任务,需要做的准备工作包括:

(1) Linux系统或者windows系统,使用Linux效率更高。

(2) 安装好的Pytorch,需要GPU进行训练。

2 原理简介

图像超分辩任务输入是一张低分辨率的图像,输出是一张对它进行分辨率增大的图片,下面是一个常用的框架示意图[1]:

该框架首先对输入图使用插值方法进行上采样,然后使用卷积层对输入进行学习,这种框架的劣势是计算代价比较大,因为整个网络是对高分辨率图操作。

随后研究者提出在网络的后端进行分辨率放大,通过扩充通道数,然后将其重新分布来获得高分辨率图,这套操作被称为(PixShuffle)[2],这样整个网络大部分计算量是对低分辨率图操作,如下图:

以上构成了图像超分辨的基本思路,之后研究者将GAN带入超分辩框架[3],实际上就是增加了对抗损失,同时使用了我们常说的感知损失替代了重建用的MSE损失。

关于各类超分辨率框架的具体原理,大家可以移步有三AI知识星球,或者自行学习。由于我们这是实战专栏,不对原理做完整的介绍。

3 模型训练

大多数超分重建任务的数据集都是通过从高分辨率图像进行下采样获得,论文中往往选择ImageNet数据集,由于我们这里打算专门对人脸进行清晰度恢复,因此选择了一个常用的高清人脸数据集,CelebA-HQ,它发布于 2019 年,包含30000张不同属性的高清人脸图,其中图像大小均为1024×1024,预览如下。

接下来我们对代码进行解读:

3.1 数据预处理

图像超分辨数据集往往都是从高分辨率图进行采样得到低分辨率图,然后组成训练用的图像对,下面是对训练集和验证集中数据处理的核心代码:

## 训练集高分辨率图预处理函数

def train_hr_transform(crop_size):

return Compose([

RandomCrop(crop_size),

ToTensor(),

])

## 训练集低分辨率图预处理函数

def train_lr_transform(crop_size, upscale_factor):

return Compose([

ToPILImage(),

Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),

ToTensor()

])

## 训练数据集类

class TrainDatasetFromFolder(Dataset):

def __init__(self, dataset_dir, crop_size, upscale_factor):

super(TrainDatasetFromFolder, self).__init__()

self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] ##获得所有图像

crop_size = calculate_valid_crop_size(crop_size, upscale_factor)##获得裁剪尺寸

self.hr_transform = train_hr_transform(crop_size) ##高分辨率图预处理函数

self.lr_transform = train_lr_transform(crop_size, upscale_factor) ##低分辨率图预处理函数

##数据集迭代指针

def __getitem__(self, index):

hr_image = self.hr_transform(Image.open(self.image_filenames[index])) ##随机裁剪获得高分辨率图

lr_image = self.lr_transform(hr_image) ##获得低分辨率图

return lr_image, hr_image

def __len__(self):

return len(self.image_filenames)

## 验证数据集类

class ValDatasetFromFolder(Dataset):

def __init__(self, dataset_dir, upscale_factor):

super(ValDatasetFromFolder, self).__init__()

self.upscale_factor = upscale_factor

self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

(0)

相关推荐