让模型实现“终生学习”,佐治亚理工学院提出Data-Free的增量学习

0

写在前面

目前的计算机视觉模型在进行增量学习新的知识的时候,就会出现灾难性遗忘的问题。缓解这种遗忘的最有效的方法需要大量重播(replay)以前训练过的数据;但是,当内存限制或数据合法性问题存在时,这种方法就存在一定的局限性。

在本文中,作者研究了无数据类增量学习(DFCIL)的问题,也就是增量学习能够学习新的知识,而不存储生成器或过去任务的训练数据。目前,DFCIL的一种方法是通过倒置学习分类模型的冻结副本,来合成图像用于训练,使得模型能够不忘记以前任务的知识,也不用replay以前训练过的数据。但是,作者通过实验表明了当使用标准蒸馏策略时,这种方法对于常见的类增量benchmark都是无效的。
因此,在本文中,作者分析了这种方法失败的原因,并提出了一种新的DFCIL增量蒸馏策略,提供了一个改进的交叉熵训练和重要性加权特征蒸馏。最终作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。
01

论文和代码地址

Always Be Dreaming: A New Approach for Data-Free Class-Incremental Learning

论文地址:https://arxiv.org/abs/2106.09701

代码地址:尚未开源

02

Motivation

目前,计算机视觉的一个局限是,它们通常使用一个包含在部署过程中所有可能遇到的数据的大型数据集,进行脱机训练。然而,现实情况是许多应用程序需要在遇到新的情况和数据后不断更新模型。这就是类增量学习的范式,在学习新任务的时候忘记以前学习到的知识的问题被称为在灾难性遗忘 。目前,比较成功的增量学习方法有一个缺点:它们需要大量的内存来replay以前看到过的或建模的数据,以避免灾难性遗忘问题。

这在很多计算机视觉的应用中也是不现实的,因为:

1)许多计算机视觉应用程序都是在设备上的,因此内存有限;

2)在工业界,可能会存在很多不允许被存储的数据(比如用户的隐私信息)。

因此,作者就提出了这样一个问题:计算机视觉系统如何能在不存储数据的情况下增量地学习新信息?作者将这样的设置称为无数据类增量学习(DFCIL)。
DFCIL的一种直观方法是同时训练生成模型进行采样以进行replay,以防止忘记以前的知识。但是与分类模型相比,训练生成模型的计算和内存都更密集。
因此,作者探索了模型反演图像 合成的概念,就是通过反转已经提供的推理网络,来获得网络中与训练数据具有相似激活作用的图像。这样一来,就不需要训练额外的网络(因为它只需要现有的推理网络)。
(上图展示了当使用合成数据进行基于replay类增量学习时,特征嵌入的分布。图a展示了合成数据的直接应用使模型的学习特征更容易区分是真实数据还是合成数据,而不是任务1和2,这也是本文要解决的主要问题;图b展示了修改分类损失和添加正则化可以减轻真实和合成图像之间的特征位移;图c是理想的特性分布,使任务1和任务2更可分离。)
上图展示了DFCIL增量学习失败的原因(图a),用当前任务的真实图像和代表过去任务的合成图像训练模型时,特征提取模型提取的特征会变成:当前真实图像的特征分布与当前真实图像的特征分布(即使他们不属于同一个类)更接近,与合成图像的特征分布更不接近 ,这就导致了预测时候的偏差。
这一现象表明,当训练一个具有两种数据分布的网络时,同时包含语义位移和分布位移,分布位移对特征嵌入有更高的影响。因此,来自以前任务的的测试图像将被识别为新的类,因为模型会更关注于它们的分布,而不是它们的语义内容(这就与分类任务的目标背道而驰了)。
为了解决这个问题,作者提出了一种新的类增量学习方法,该方法学习了具有局部分类损失的新任务特征,依赖于重要性加权特征蒸馏和线性分类head微调来分离新任务和过去任务的特征嵌入。
作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。
03

方法

3.1. 先验知识-类增量学习

在类增量学习中,一个模型需要学习了对应于M个语义对象类、、、的数据,但这些数据是通过N个task依次暴露给模型的,每个任务中子类都不会重合。
我们用来表示任务n中引入的类集,其中表示任务n中对象类的数量。每个类只出现在单个任务中,模型目标就是逐步学习引入的新对象类,并对它们进行分类,同时保留之前学习过的类的知识。
为了描述推理模型,我们将θ,表示在i时刻使用任务n的类训练的模型。

3.2. Baseline Approach

在本节中,作者基于之前工作,提出了一个Data-Free的用于类增量学习的baseline。

3.2.1. Model-Inversion Image Synthesis

大多数模型反演图像合成方法都是通过直接对先验的鉴别模型θ进行优化来合成图像。然而,一次优化一个Batch的图像在计算上是效率低下的。因此作者选择使用卷积网络参数化函数φ用噪声生成合成图像进行近似优化。这就使每个任务只需要训练一次φ,当前任务结束时就可以直接丢弃。

首先,φ需要生成多样性的图片,因此作者优化合成了图像的类预测的多样性,以匹配均匀分布。将θ表示为模型θ对输入x产生的预测类分布,需要使合成样本的平均类预测向量的熵最大化,如下所是(label diversity loss):
其中为信息熵。
除了多样性之外,为了在DFCIL中合成有用的图像,图像还需要校准的类置信度、特征统计数据的一致性和局部平滑的潜在空间。
对于校准的类置信度 ,作者使用了Content Loss,通过对图像张量的类预测一致性最大化,这样θ就能对所有输入做出足够confident的预测了。Content Loss的具体计算表示如下所示:
通过将和相结合,就确保合成的图像将代表过去所有任务类的分布。
对于特征统计数据的一致性 ,先前的工作发现,模型反演的复杂性会导致θ特征的分布大大偏离合成图像的分布。因此,合成图像的Batch统计应该与θ中的Batch Norm层相匹配。基于此,作者进一步提出了stat alignment loss:
其中代表KL散度。
对于局部平滑的潜在空间 ,先验知识告诉我们,自然图像在像素空间中比初始噪声更局部平滑。因此作者又提出了一个损失函数smoothness prior loss,这个函数就是生成图像和高斯模糊版本的生成图像的L2距离:
最后,φ的损失函数为上面提到的损失函数之和:

3.2.2. Distilling Synthetic Data for Class-Incremental Learning

在类增量学习中,对合成图像的知识蒸馏通常被用于θ正则化,迫使它学习,学习的同时,将的知识遗忘减到最小。对于任务,我们从任务期间训练的θ的冻结副本中合成图像。这些合成图像帮助我们将任务中学习的知识提炼到我们当前的模型θ中。

在Baseline方法中,作者采用了DeepInversion中使用的蒸馏方法。具体表示为,给定当前的任务数据和合成的蒸馏数据,我们最小化:
其中是一种知识蒸馏正则化方法:

3.3. Diagnosis: Feature Embedding Prioritizes Domains Over Semantics

为了探究为什么DFCIL的Baseline方法会失败,作者使用度量(MID)分析了嵌入特征之间的表征距离,这种度量用于捕获两个分布样本的平均图像embedding之间的距离。作者将这种度量实例化为Mean Image Distance (MID) score,高分表示不同的特征,低分表示相似的特征。计算如下:

作者计算任务1真实数据的特征embedding与任务2真实数据之间的MID,然后计算任务1真实数据的特征embedding与任务1合成数据之间的MID,结果如上图所示。对于(a)DeepInversion,任务1真实数据与任务1合成数据之间的MID得分明显高于任务1真实数据与任务2真实数据之间的MID得分。
这表明embedding空间对domain有更高的优先级,而不是语义,但这不是模型想要的结果。对于作者提出的方法(b),任务1真实数据和任务1合成数据之间的MID分数明显低得多,这表明特性embedding的语义优先于domain。

3.4. A New Distillation Strategy for DFCIL

基于上面的分析,作者提出了持续的学习应该在以下几个方面保持平衡:(1)针对新任务的学习特征;(2)最小化超过上一个任务的特征偏移;(3)在embedding空间中分离新的类和以前的类之间的类重叠。

对于上面的三个平衡,(1)和(3)可以通过实现。但是作者认为,通过将其分成两种不同的损失,可以在学习新任务的时候,不区分真实图像和合成图像的特征。根据这个想法,作者提出了一种为DFCIL设计的新的类增量学习方法,该方法独立地解决这些目标。

(蓝色箭头表示之前合成的任务数据的计算路径,绿色箭头表示真实的当前任务数据的计算路径,黄色箭头表示真实数据和合成数据的计算路径。)

模型的overview如上图所示

3.4.1. Learning current task features

作者方法背后的intuition是需要学习当前task的特征的同时,绕过偏向最近task真实数据的特征表示。具体实现上,作者通过只计算在新的 线性分类head上的局部交叉熵分类损失来实现这一点。有了这种模式,作者阻止了模型学习通过domain分离新的和过去的类数据,损失函数如下:

3.4.2. Minimizing feature drift over previous task data

与真实的当前任务图像相比,蒸馏图像属于另一个domain,因此作者寻找了另一个损失函数,直接减轻遗忘的损失函数。要实现这个目标,一个选择是特性蒸馏:

虽然强化了过去任务数据的重要组成部分,但它的强正则性抑制了模型的学习新任务的能力。另一方面并不抑制新任务的学习,可能导致真实数据和合成数据的bias。
因此,作者提出了一种重要性加权特征蒸馏,它只强化了过去任务数据中最重要的组成部分,同时允许不那么重要的特性来适应新任务。表示如下:
W为重要性权重矩阵,W权重大的特征更为重要。

3.4.3. Separating Current and Past Decision Boundaries

最后,模型需要分离当前类和过去类的决策边界,而不允许特征空间来区分真实数据和合成数据。作者通过用交叉熵损失函数来fine-tuning线性分类head来实现。除了线性分类head之外,这个损失函数并不会更新θ,:中的任何参数:

3.4.4. Final Objective

最终模型的损失函数为上述损失函数之和,如下所示:

04

实验

4.1. DFCIL (CIFAR-100 )

从上表结果可以看出,本文的方法不仅优于DFCIL方法,甚至优于生成方法。

4.2. CIL with Replay Data (CIFAR-100 )

在上表中,作者将本文的方法(不存储回放数据)与其他存储回放数据的方法进行了比较。可以看出,本文方法的performance可以优于LwF和Rehersal,但是后者需要存储回放数据,这就意味着更高的内存消耗。

4.3. Ablation Study(CIFAR-100 )

从上表可以看出,文中对Data-Free增量学习专门设计的几个损失函数和蒸馏方法,对于整个模型性能的提高,都有着非常重要的作用。

4.4. DFCIL (ImageNet)

作者还使用ImageNet数据集来验证本文的方法在大规模图像数据集上的表现。可以看出,本文的方法在这个大规模图像数据集上的实验结果也没有比基于replay的方法落后太多。
05

总结

在本文中,作者表明现有的类增量学习方法在使用真实训练数据学习新任务和使用合成蒸馏数据保存过去的知识时,performance较差。因此,作者提出了一种新的方法来实现了无数据类增量学习的SOTA性能,并与基于replay的SOTA方法性能相当。

作者提出无数据类增量学习是希望消除在类增量学习中存储回放数据的需要,使计算机视觉的广泛和实际应用成为可能。不存储数据的增量学习解决方案,将对计算机视觉应用产生直接影响,进一步促进计算机视觉任务的落地应用。

▊ 作者简介

厦门大学人工智能系20级硕士

研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。

知乎/公众号:FightingCV

END,入群👇备注:CV

(0)

相关推荐