使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比

分类是机器学习中最简单,最常见的任务之一。 例如,在计算机视觉中,您希望能够微调普通卷积神经网络(CNN)的最后一层,以将样本正确分类为某些类别(类)。 但是,有几种根本不同的方法可以实现这一目标。

Metric learning(度量学习)是其中之一,今天我想与大家分享如何正确使用它。 为了使事情变得实用,我们将研究监督式对比学习(SupCon),它是对比学习的一部分,而后者又是度量学习的一部分,但稍后会介绍更多。

通常如何进行分类

在进行度量学习之前,首先了解通常如何解决分类任务。 卷积神经网络是当今实用计算机视觉最重要的思想之一,它由两部分组成:编码器和头部(在这种情况下为分类器)。

首先-拍摄图像并计算一组特征,这些特征可以捕获该图像的重要信息。这是通过卷积和池化操作完成的(这就是为什么它被称为卷积神经网络)。之后,将这些特征解压缩到单个向量中,并使用常规的全连接神经网络执行分类。在实践中,您采用在大型数据集(例如ImageNet)上预先训练的某种模型(例如ResNet,DenseNet,EfficientNet等),并根据您的任务(仅最后一层或整个模型)进行微调)。

然而,这里有几点需要注意。首先,通常只关心网络FC部分的输出。也就是说,你取它的输出,并把它们提供给损失函数,以保持模型学习。换句话说,您并不真正关心网络中间发生了什么(例如,来自编码器的特性)。其次,通常你用一些基本的损失函数来训练这些东西,比如交叉熵。

为了更好地理解这个2步过程(encoder + FC),你可以这样想:encoder将图像映射到一些高维空间(例如,在ResNet18的情况下,我们讨论的是512维,而对于Resnet101 - 2048)。在此之后,FC的目标是在这些代表样本的点之间画一条线,以便将它们映射到类。这两种东西是同时训练的。因此,你试图优化特征,同时“在高维空间中画线”。

这种方法有什么问题吗?嗯,没什么,真的。它实际上运行得很好。但这并不意味着没有别的办法。

度量学习 Metric learning

现代机器学习中最有趣的想法之一(至少对我来说是这样)叫做度量学习(或深度度量学习)。简单地说:如果我们不去关注FC层的输出,而是更仔细地研究编码器生成的特性会怎样?如果我们设法用一些损耗函数来优化这些特性,而不是使用网络输出进行优化,会怎么样呢?这就是度量学习的意义所在:用编码器生成好的特性(嵌入)。

“好”是什么意思呢?好吧,如果你想一下,在计算机视觉的例子中,你想对相似的图像有相似的特征,而对截然不同的图像有截然不同的特征。

监督对比学习 Supervised Contrastive Learning

好的,假设在度量学习中,我们关心的只是“好”特征。 但是监督式对比学习有什么意义呢? 老实说,这种特定方法没有什么特别之处。 这是最近的一篇论文,提出了一些不错的技巧,以及一个有趣的2步方法

  1. 训练一个好的编码器,该编码器能够为图像生成良好的特征。
  2. 冻结编码器,添加FC层,然后进行训练。

您可能想知道常规分类器训练有什么区别。 不同之处在于,在常规培训中,您需要同时训练编码器和FC。 另一方面,在这里,您首先训练一个不错的编码器,然后将其冻结(不再训练),然后仅训练FC。 这种逻辑背后的想法是,如果我们设法首先为图像生成真正好的特征,则应该很容易优化FC(正如我们前面提到的,其目标是优化分离样本的行)。

训练过程的细节

让我们深入了解SupCon实施的细节。

在查看训练循环之前,您应该了解的一件事是要训练哪种模型。 这非常简单:编码器(例如ResNet,DenseNet,EffNet等),但没有常规的FC层进行分类。

这里不是分类头,而是投影头。投影头是一个由2个FC层组成的序列,它将编码器的特征映射到一个较低的维度空间(通常是128维度,你甚至可以在上面的图片中看到这个值)。使用投影头的原因是,与来自编码器的几千个特征相比,使用128个精心选择的特征更容易让模型学习。

  1. 构造一批N个图像。与其他度量学习方法不同,您不需要太关心这些样本的选择。能拿多少就拿多少,剩下的由损失来处理。
  2. 将这些图像以成对的方式转发给网络,其中一对图像被构造为[augmentation(imagei), augmentation(imagei)],得到embeddings。并进行标准化。
  3. 以某个图像做为锚点。在批处理中找到同一个类的所有图像。把它们作为正样本。找到所有不同类的图像。把他们当作负样本。
  4. 将SupCon损失应用于第二步归一化嵌入,使正样本彼此靠近,同时使负样本更远离。
  5. 第一阶段训练完成后,删除投影头,并在编码器顶部添加FC(就像在常规分类训练中一样)。 开始第二阶段训练的冻结编码器,并微调FC的训练。

这里要记住几件事。首先,在训练完成后,去掉投影头,使用投影头之前的特征是会获得更好的效果。作者解释说,由于我们降低了嵌入的大小,导致信息丢失。其次,增强的选择很重要。作者提出了裁剪和色彩抖动的组合。Supcon一次处理批处理中的所有图像(因此,无需构造对或三元组)。而且批处理中的图像越多,模型学习起来就越容易(因为SupCon具有隐式的正负硬挖掘质量)。第四,你可以在第4步停止。这意味着可以通过嵌入来进行分类,而不需要任何FC层。为了做到这一点,计算所有训练样本的嵌入。然后,在验证时,对每个样本计算一个嵌入,将其与每个训练嵌入进行比较(例如余弦距离),采用其类别。

PyTorch实现

实际上,在PyTorch中有一个SupCon的半官方实现。不幸的是,它包含了非常恼人的隐藏bug。最严重的一个问题是:repo的创造者使用了他自己的resnet实现,由于其中的一些bug,批量大小比普通的torchvision模型低两倍。最重要的是,repo没有验证或可视化,所以你不知道什么时候停止训练。在我的repo中,我修复了所有这些问题,并为稳定的训练增加了更多的技巧。

更准确地说,在我的实现包含了以下功能:

  • 使用albumentations进行扩增
  • Yaml配置
  • t-SNE可视化
  • 使用AMI、NMI、mAP、precisionat1等PyTorch度量学习进行2步验证(用于投影头前后的特性)。
  • 指数移动平均更稳定的训练,随机移动平均更好的泛化和整体性能。
  • 自动混合精度训练,以便能够训练更大的批大小(大约是2的倍数)。
  • 标签平滑损失,LRFinder为第二阶段的训练(FC)。
  • 支持timm模型和jettify优化器
  • 固定种子,使训练具有确定性。
  • 保存基于验证的权重,日志-定期。txt文件,以及TensorBoard日志。

例子是使用Cifar10和Cifar100数据集来进行测试的,但是添加自己的数据集非常简单。为了运行整个数据处理管道,请执行以下操作:

python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage1.yml python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage2.yml

之后,你可以检查可视化t-SNE结果。例如,对于Cifar10和Cifar100,大概是下面这样:

Cifar10 t-SNE, SupCon 损失

Cifar10 t-SNE, Cross Entropy 损失

Cifar100 t-SNE, SupCon 损失

Cifar10 t-SNE, Cross Entropy 损失

总结

度量学习是一个非常强大的东西。但是要达到常规CE / LabelSmoothing可以提供的准确性水平非常困难。此外,在训练期间它在计算上也可能是昂贵的并且不稳定的。我在各种任务(分类,超出分布的预测,对新类的泛化等)上测试了SupCon和其他度量指标损失,使用诸如SupCon之类的优势尚不确定。

那有什么意义?我个人认为有两件事。第一,SupCon(和其他度量学习方法)仍然可以提供比CE更结构化的集群,因为它直接优化了该属性。第二,多一个你可以尝试的技能/工具仍然是非常有益的。因此,通过更好的扩展集或不同的数据集(可能使用更细粒度的类),SupCon 可能会产生更好的结果,而不仅仅是与常规分类训练相当。

本文代码:github/ivanpanshin/SupCon-Framework

(0)

相关推荐

  • EEG分类实验block设计的危险与陷阱

           最近的一篇论文声称对观看ImageNet刺激的受试者所诱发的大脑加工采用脑电(EEG)测量进行分类,并利用从这种加工中得到的表征来构造一种新的对象分类器.这篇论文,连同一系列后续论文,声 ...

  • CVPR 2021 | 神经网络如何进行深度估计?

    编者按:与深度神经网络相比,人类的视觉拥有更强的泛化能力,所以能够胜任各项视觉任务.结合人类视觉系统"通过观察结构信息获得感知能力"的特点,微软亚洲研究院的研究员们提出了一种新的深 ...

  • 无监督训练用堆叠自编码器是否落伍?ML博士对比了8个自编码器

    选自krokotsch.eu 作者:Tilman Krokotsch 机器之心编译 编辑:魔王 柏林工业大学深度学习方向博士生 Tilman Krokotsch 在多项任务中对比了 8 种自编码器的性 ...

  • 【图像分类】 标签噪声对分类性能会有什么样的影响?

    不同类型的噪声会对模型的分类性能产生什么样的影响呢,让我们一同进行实验,来探索那暗中作祟的标签噪声! 作者&编辑 | 郭冰洋 1 简介 在数据集制作过程中,由于主观.客观的原因,会导致标签噪声 ...

  • NICE-GAN:新的图像转换网络框架

    DrugAI 3天前 以下文章来源于深度奇点 ,作者AITA|于志勇 深度奇点DeepSingularity致力于将AI与生物.医疗.城市.海洋等各个领域结合,找到深度奇点. 摘要 在传统无监督的图像 ...

  • (3条消息) Learning Robust Low

    Learning Robust Low-Rank Representation (2012) 注释: 本篇主要学习LRR和online LRR理论.本文由RPCA的提出讲起:再叙述论文提出的onlin ...

  • VAE变分自编码器实现

    变分自编码器(VAE)组合了神经网络和贝叶斯推理这两种最好的方法,是最酷的神经网络,已经成为无监督学习的流行方法之一. 变分自编码器是一个扭曲的自编码器.同自编码器的传统编码器和解码器网络一起,具有附 ...

  • 肾小管末端磷酸钙沉积栓的自动检测。

    Healthc Technol Lett. 2019 Dec; 6(6): 271–274. Published online 2019 Dec 6. doi: 10.1049/htl.2019.00 ...

  • 【学术论文】基于先验知识的草莓机器手目标定位算法

    摘要: 研究开发了一种基于视觉识别的草莓机器手,重点是采用深度学习辅助草莓机器手进行目标定位.为了加快目标定位速度,先把草莓目标的特定颜色作为先验知识进行候选目标的分割筛选,再将得到的多个候选区输入预 ...

  • 使用计算机视觉来做异常检测

    作者:Mia Morton 编译:ronghuaiyang 导读 创建异常检测模型,实现生产线上异常检测过程的自动化.在选择数据集来训练和测试模型之后,我们能够成功地检测出86%到90%的异常. 介绍 ...