Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法

作者丨皮特潘
编辑丨极市平台

极市导读

由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文将介绍两种在小batchsize也可以发挥BN性能的方法。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

BN(Batch Normalization)几乎是目前神经网络的必选组件,但是使用BN有两个前提要求:

  1. batchsize不能太小;
  2. 每一个minibatch和整体数据集同分布。

不然的话,非但不能发挥BN的优势,甚至会适得其反。但是由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文介绍两篇在小batchsize也可以发挥BN性能的方法。解决思路为:既然batchsize太小的情况下,无法保证当前minibatch收集到的数据和整体数据同分布。那么能否多收集几个batch的数据进行统计呢?这两篇工作分别分别是:

  • BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
  • CBN:Cross-Iteration Batch Normalization

另外,本文也会给出代码解析,帮助大家理解。

batchsize过小的场景

通常情况下,大家对CNN任务的研究一般为公开的数据集指标负责。分类任务为ImageNet数据集负责,其尺度为224X224。检测任务为coco数据集负责,其尺度为640X640左右。分割任务一般为coco或PASCAL VOC数据集负责,后者的尺度大概在500X500左右。再加上例如resize的前处理操作,真正送入网络的图片的分辨率都不算太大。一般性能的GPU也很容易实现大的batchsize(例如大于32)的支持。

但是实际的项目中,经常遇到需要处理的图片尺度过大的场景,例如我们使用500w像素甚至2000w像素的工业相机进行数据采集,500w的相机采集的图片尺度就是2500X2000左右。而对于微小的缺陷检测、高精度的关键点检测或小物体的目标检测等任务,我们一般不太想粗暴降低输入图片的分辨率,这样违背了我们使用高分辨率相机的初衷,也可能导致丢失有用特征。在算力有限的情况下,我们的batchsize就无法设置太大,甚至只能为1或2。小的batchsize会带来很多训练上的问题,其中BN问题就是最突出的。虽然大batchsize训练是一个共识,但是现实中可能无法具有充足的资源,因此我们需要一些处理手段。

BN回顾

首先Batch Normalization 中的Normalization被称为标准化,通过将数据进行平和缩放拉到一个特定的分布。BN就是在batch维度上进行数据的标准化。BN的引入是用来解决 internal covariate shift 问题,即训练迭代中网络激活的分布的变化对网络训练带来的破坏。BN通过在每次训练迭代的时候,利用minibatch计算出的当前batch的均值和方差,进行标准化来缓解这个问题。虽然How Does Batch Normalization Help Optimization 这篇文章探究了BN其实和Internal Covariate Shift (ICS)问题关系不大,本文不深入讨论,这个会在以后的文章中细说。

一般来说,BN有两个优点:

  • 降低对初始化、学习率等超参的敏感程度,因为每层的输入被BN拉成相对稳定的分布,也能加速收敛过程。
  • 应对梯度饱和和梯度弥散,主要是对于使用sigmoid和tanh的激活函数的网络。

当然,BN的使用也有两个前提:

  • minibatch和全部数据同分布。因为训练过程每个minibatch从整体数据中均匀采样,不同分布的话minibatch的均值和方差和训练样本整体的均值和方差是会存在较大差异的,在测试的时候会严重影响精度。
  • batchsize不能太小,否则效果会较差,论文给的一般性下限是32。

再来回顾一下BN的具体做法:

  • 训练的时候:使用当前batch统计的均值和方差对数据进行标准化,同时优化优化gamma和beta两个参数。另外利用指数滑动平均收集全局的均值和方差。
  • 测试的时候:使用训练时收集全局均值和方差以及优化好的gamma和beta进行推理。

可以看出,要想BN真正work,就要保证训练时当前batch的均值和方差逼近全部数据的均值和方差。

BRN

论文题目:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

论文地址: https://arxiv.org/pdf/1702.03275.pdf

代码地址: https://github.com/ludvb/batchrenorm

核心解析

本文的核心思想就是:训练过程中,由于batchsize较小,当前minibatch统计到的均值和方差与全部数据有差异,那么就对当前的均值和方差进行修正。修正的方法主要是利用到通过滑动平均收集到的全局均值和标准差。看公式:

上面公式中,i表示网络的第i层。μ和σ表示网络推理时的均值和标准差,也就是训练过程通过滑动平均收集的到均值和方差。μB和σb表示当前训练迭代过程中的实际统计到的均值和标准差。BN在小batch不work的根本原因就是这两组参数存在较大的差异。通过r和d对训练过程中数据进行线性变换,在该变化下,上公式左右两端就严格相等了。其实标准的BN就是r=1,d=0的一种情况。对于某一个特定的minibatch,其中r和d可以看成是固定的,是直接计算出来的,不需要梯度优化的。

具体流程

  • 统计当前batch数据的均值和标注差,和标准BN做法一致。

  • 根据当前batch的均值和标准差结合全局的均值和标准差利用上面的公式计算r和d;注意该运算是不参与梯度反向传播的。另外,r和d需要增加一个限制,直接clip操作就好。

  • 利用当前的均值和标准差对当前数据执行Normalization操作,利用上面计算得到的r和d对当前batch进行线性变换。

  • 滑动平均收集全局均值和标注差。

测试过程和标准BN一样。其实本质上,就是训练的过程中使用全局的信息进行更新当前batch的数据。间接利用了全局的信息,而非当前这一个batch的信息。

实验效果

在较大的batchsize(32)的时候,与标准BN相比,不会丢失效果,训练过程一如既往稳定高效。如下:

在小的batchsize(4)下, 本文做法依然接近batchsize为32的时候,可见在小batchsize下是work的。

代码解析

def forward(self, x): if x.dim() > 2: x = x.transpose(1, -1) if self.training: # 训练过程 dims = [i for i in range(x.dim() - 1) batch_mean = x.mean(dims) # 计算均值 batch_std = x.std(dims, unbiased=False) + self.eps # 计算标准差 # 按照公式计算r和d r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax) d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax) # 对当前数据进行标准化和线性变换 x = (x - batch_mean) / batch_std * r + d # 滑动平均收集全局均值和标注差 self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean) self.running_std += self.momentum * (batch_std.detach() - self.running_std) self.num_batches_tracked += 1 else: # 测试过程 x = (x - self.running_mean) / self.running_std return x


CBN

论文题目:Cross-Iteration Batch Normalization

论文地址:https://arxiv.org/abs/2002.05712

代码地址:https://github.com/Howal/Cross-iterationBatchNorm

本文认为BRN的问题在于它使用的全局均值和标准差不是当前网络权重下获取的,因此不是exactly正确的,所以batchsize再小一点,例如为1或2时就不太work了。本文使用泰勒多项式逼近原理来修正当前的均值和标准差,同样也是间接利用了全局的均值和方差信息。简述就是:当前batch的均值和方差来自之前的K次迭代均值和方差的平均,由于网络权重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估计前面的迭代在当前权重下的数值。

泰勒公式

泰勒公式是个用函数在某点的信息描述其附近取值的公式。如果函数满足定的条件,泰勒公式可以用函数在某点的各阶导数值做系数构建个多项式来近似表达这个函数。教科书介绍如下:

核心解析:

本文做法,由于网络一般使用SGD更新权重,因此网络权重的变化是平滑的,所以适用泰勒公式。如下,t为训练过程中当前迭代时刻,t-τ为t时刻向前τ时刻。θ为网络权重,权重下标代表该权重的时刻。μ为当前minibatch均值,v为当强minibatch平方的均值,是为了计算标准差。因此直接套用泰勒公式得到:

上面这两个公式就是为了估计在t-τ时刻,t时刻的权重下的均值和方差的参数估计。BRN可以看作没有进行该方法估计,使用的依然是t-τ时刻权重的参数估计。其中O为高阶项,因为该式主要由一阶项控制,因此高阶项目可以忽略。上面的公式还要进一步简化,主要是偏导项的求法。假设当前层为l,实际上∂μ/ ∂θ 和 ∂ν/∂θ依赖与所有l层之前层的权重,求导计算量极大。不过经验观察到,l层之前层的偏数下降很快,因此可以忽略掉,仅仅计算当前层的权重偏导。

因此化简为如下,可以看出,求偏导的部分,只考虑对当前层的偏导数,注意上标l表示网络层的意思。至此,之前时刻在当前权重下的均值和方差已经估计出来了。

下面穿插代码解析整个计算过程。

首先是统计计算当前batch的数据,和标准BN没有差别。代码为:

cur_mu = y.mean(dim=1) # 当前层的均值cur_meanx2 = torch.pow(y, 2).mean(dim=1) # 当前值平方的均值,计算标准差使用cur_sigma2 = y.var(dim=1) # 当前值的方差

对当前网络层求偏导,直接使用torch的内置函数。代码:

# 注意 grad_outputs = self.ones : 不同值的梯度对结果影响程度不同,类似torch.sum()的作用。dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]

使用公式(7)和(8)继续下面的计算,也就是向前累计K次估计数值,更新到当前batch的均值和方差的计算上,这里引入了一个超参就是k的大小,它表示当前的迭代向后回溯到多长的步长的迭代。实验探究k=8是一个比较折中的选择。k=1的时候,RBN退化成了原始的BN:

代码如下,其中这里的self.pre_mu, self.pre_dmudw, self.pre_weight是前面每次迭代收集到了窗口k大小的数值,分别代表均值、均值对权重的偏导、权重。self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight同理,是对应平方均值的。

# 利用泰勒公式估计mu_all = torch.stack \ ([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])
meanx2_all = torch.stack \  ([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])

上面所说的变量收集迭代过程如下:

# 动态维护buffer_num长度的均值、均值平方、偏导、权重self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]tmp_weight = torch.zeros_like(weight.data)tmp_weight.copy_(weight.data)self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]

计算获取当前batch的均值和方差,取修正后的K次迭代数据的平均即可。

# 利用收集到的一定窗口长度的均值和平方均值,计算当前均值和方差sigma2_all = meanx2_all - torch.pow(mu_all, 2)re_mu_all = mu_all.clone()re_meanx2_all = meanx2_all.clone()re_mu_all[sigma2_all < 0] = 0re_meanx2_all[sigma2_all < 0] = 0count = (sigma2_all >= 0).sum(dim=0).float()mu = re_mu_all.sum(dim=0) / count # 平均操作sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)

均值和方差使用过程,和标准BN没有区别。

# 标准化过程,和原始BN没有区别y = y - mu.view(-1, 1)if self.out_p: # 仅仅控制开平方的位置 y = y / (sigma2.view(-1, 1) + self.eps) ** .5else:  y = y / (sigma2.view(-1, 1) ** .5 + self.eps)

最后再理解一下

mu_0是当前batch统计获取的均值,mu_1是上一batch统计获取的均值。当前batch计算BN的时候也想利用到mu_1,但是统计mu_1的时候利用到网络的权重也是上一次的,直接使用肯定有问题,所以本文使用泰勒公式估计出mu_1在当前权重下应该是什么样子。方差估计同理。

实验效果:

这里的Naive CBN 是上一篇论文BRN的做法,可以认为是CBN不使用泰勒估计的一种特例。在batchsize下降的过程中,CBN指标依然坚挺,甚至超过了GN(不过也侧面反应了GN确实厉害)。而原始BN和其改进版BRN在batchsize更小的时候都不太work了。

◎作者档案
皮特潘,致力于AI落地而上下求索
欢迎大家联系极市小编(微信ID:fengcall19)加入极市原创作者行列
(0)

相关推荐