使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比
分类是机器学习中最简单,最常见的任务之一。 例如,在计算机视觉中,您希望能够微调普通卷积神经网络(CNN)的最后一层,以将样本正确分类为某些类别(类)。 但是,有几种根本不同的方法可以实现这一目标。
Metric learning(度量学习)是其中之一,今天我想与大家分享如何正确使用它。 为了使事情变得实用,我们将研究监督式对比学习(SupCon),它是对比学习的一部分,而后者又是度量学习的一部分,但稍后会介绍更多。
通常如何进行分类
在进行度量学习之前,首先了解通常如何解决分类任务。 卷积神经网络是当今实用计算机视觉最重要的思想之一,它由两部分组成:编码器和头部(在这种情况下为分类器)。
首先-拍摄图像并计算一组特征,这些特征可以捕获该图像的重要信息。这是通过卷积和池化操作完成的(这就是为什么它被称为卷积神经网络)。之后,将这些特征解压缩到单个向量中,并使用常规的全连接神经网络执行分类。在实践中,您采用在大型数据集(例如ImageNet)上预先训练的某种模型(例如ResNet,DenseNet,EfficientNet等),并根据您的任务(仅最后一层或整个模型)进行微调)。
然而,这里有几点需要注意。首先,通常只关心网络FC部分的输出。也就是说,你取它的输出,并把它们提供给损失函数,以保持模型学习。换句话说,您并不真正关心网络中间发生了什么(例如,来自编码器的特性)。其次,通常你用一些基本的损失函数来训练这些东西,比如交叉熵。
为了更好地理解这个2步过程(encoder + FC),你可以这样想:encoder将图像映射到一些高维空间(例如,在ResNet18的情况下,我们讨论的是512维,而对于Resnet101 - 2048)。在此之后,FC的目标是在这些代表样本的点之间画一条线,以便将它们映射到类。这两种东西是同时训练的。因此,你试图优化特征,同时“在高维空间中画线”。
这种方法有什么问题吗?嗯,没什么,真的。它实际上运行得很好。但这并不意味着没有别的办法。
度量学习 Metric learning
现代机器学习中最有趣的想法之一(至少对我来说是这样)叫做度量学习(或深度度量学习)。简单地说:如果我们不去关注FC层的输出,而是更仔细地研究编码器生成的特性会怎样?如果我们设法用一些损耗函数来优化这些特性,而不是使用网络输出进行优化,会怎么样呢?这就是度量学习的意义所在:用编码器生成好的特性(嵌入)。
“好”是什么意思呢?好吧,如果你想一下,在计算机视觉的例子中,你想对相似的图像有相似的特征,而对截然不同的图像有截然不同的特征。
监督对比学习 Supervised Contrastive Learning
好的,假设在度量学习中,我们关心的只是“好”特征。 但是监督式对比学习有什么意义呢? 老实说,这种特定方法没有什么特别之处。 这是最近的一篇论文,提出了一些不错的技巧,以及一个有趣的2步方法
- 训练一个好的编码器,该编码器能够为图像生成良好的特征。
- 冻结编码器,添加FC层,然后进行训练。
您可能想知道常规分类器训练有什么区别。 不同之处在于,在常规培训中,您需要同时训练编码器和FC。 另一方面,在这里,您首先训练一个不错的编码器,然后将其冻结(不再训练),然后仅训练FC。 这种逻辑背后的想法是,如果我们设法首先为图像生成真正好的特征,则应该很容易优化FC(正如我们前面提到的,其目标是优化分离样本的行)。
训练过程的细节
让我们深入了解SupCon实施的细节。
在查看训练循环之前,您应该了解的一件事是要训练哪种模型。 这非常简单:编码器(例如ResNet,DenseNet,EffNet等),但没有常规的FC层进行分类。
这里不是分类头,而是投影头。投影头是一个由2个FC层组成的序列,它将编码器的特征映射到一个较低的维度空间(通常是128维度,你甚至可以在上面的图片中看到这个值)。使用投影头的原因是,与来自编码器的几千个特征相比,使用128个精心选择的特征更容易让模型学习。
- 构造一批N个图像。与其他度量学习方法不同,您不需要太关心这些样本的选择。能拿多少就拿多少,剩下的由损失来处理。
- 将这些图像以成对的方式转发给网络,其中一对图像被构造为[augmentation(imagei), augmentation(imagei)],得到embeddings。并进行标准化。
- 以某个图像做为锚点。在批处理中找到同一个类的所有图像。把它们作为正样本。找到所有不同类的图像。把他们当作负样本。
- 将SupCon损失应用于第二步归一化嵌入,使正样本彼此靠近,同时使负样本更远离。
- 第一阶段训练完成后,删除投影头,并在编码器顶部添加FC(就像在常规分类训练中一样)。 开始第二阶段训练的冻结编码器,并微调FC的训练。
这里要记住几件事。首先,在训练完成后,去掉投影头,使用投影头之前的特征是会获得更好的效果。作者解释说,由于我们降低了嵌入的大小,导致信息丢失。其次,增强的选择很重要。作者提出了裁剪和色彩抖动的组合。Supcon一次处理批处理中的所有图像(因此,无需构造对或三元组)。而且批处理中的图像越多,模型学习起来就越容易(因为SupCon具有隐式的正负硬挖掘质量)。第四,你可以在第4步停止。这意味着可以通过嵌入来进行分类,而不需要任何FC层。为了做到这一点,计算所有训练样本的嵌入。然后,在验证时,对每个样本计算一个嵌入,将其与每个训练嵌入进行比较(例如余弦距离),采用其类别。
PyTorch实现
实际上,在PyTorch中有一个SupCon的半官方实现。不幸的是,它包含了非常恼人的隐藏bug。最严重的一个问题是:repo的创造者使用了他自己的resnet实现,由于其中的一些bug,批量大小比普通的torchvision模型低两倍。最重要的是,repo没有验证或可视化,所以你不知道什么时候停止训练。在我的repo中,我修复了所有这些问题,并为稳定的训练增加了更多的技巧。
更准确地说,在我的实现包含了以下功能:
- 使用albumentations进行扩增
- Yaml配置
- t-SNE可视化
- 使用AMI、NMI、mAP、precisionat1等PyTorch度量学习进行2步验证(用于投影头前后的特性)。
- 指数移动平均更稳定的训练,随机移动平均更好的泛化和整体性能。
- 自动混合精度训练,以便能够训练更大的批大小(大约是2的倍数)。
- 标签平滑损失,LRFinder为第二阶段的训练(FC)。
- 支持timm模型和jettify优化器
- 固定种子,使训练具有确定性。
- 保存基于验证的权重,日志-定期。txt文件,以及TensorBoard日志。
例子是使用Cifar10和Cifar100数据集来进行测试的,但是添加自己的数据集非常简单。为了运行整个数据处理管道,请执行以下操作:
python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage1.yml python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage2.yml
之后,你可以检查可视化t-SNE结果。例如,对于Cifar10和Cifar100,大概是下面这样:
Cifar10 t-SNE, SupCon 损失
Cifar10 t-SNE, Cross Entropy 损失
Cifar100 t-SNE, SupCon 损失
Cifar10 t-SNE, Cross Entropy 损失
总结
度量学习是一个非常强大的东西。但是要达到常规CE / LabelSmoothing可以提供的准确性水平非常困难。此外,在训练期间它在计算上也可能是昂贵的并且不稳定的。我在各种任务(分类,超出分布的预测,对新类的泛化等)上测试了SupCon和其他度量指标损失,使用诸如SupCon之类的优势尚不确定。
那有什么意义?我个人认为有两件事。第一,SupCon(和其他度量学习方法)仍然可以提供比CE更结构化的集群,因为它直接优化了该属性。第二,多一个你可以尝试的技能/工具仍然是非常有益的。因此,通过更好的扩展集或不同的数据集(可能使用更细粒度的类),SupCon 可能会产生更好的结果,而不仅仅是与常规分类训练相当。
本文代码:github/ivanpanshin/SupCon-Framework