菜单

半监督分割 in CVPR2022

9月 10, 2022 - 机器学习&大数据, 算法

本文主要记录了BB酱最近调研的一些CVPR2022中的半监督语义分割方法。对其中的内容进行了简要的总结(不一定对),可供参考使用。

Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels

现有的半监督模型往往是先通过已经标注的数据训练网络,之后利用这个网络为没有标注的数据打标签,之后再进一步训练。在打伪标签的过程中,很多方法都用了通过置信度得分来过滤预测结果,防止错误预测影响训练。这就导致有一些置信度低的类别参与不到训练过程中,这也就导致很难为此类像素分配准确的伪标签,导致训练不足和分类不平衡。由于直接使用这些不准确的预测会造成性能下降,因此这篇文章尝试将其转换为负样本来构建方法。这篇文章中利用的一个先验是,对于不可靠的预测,往往只是在几类之间区分不出来,而不是在所有类之间混淆。就像这张图里面表示的一样,对于这张图上的不可靠区域来说,虽然motorbike和person的概率很接近,模型不好区分,但是模型可以很明确的告诉我们这个像素不是car或者train。基于这一点,这篇文章把不可靠的像素视为那些不太可能的类别的负样本。具体来说,在预测的过程中,将像素分为可靠区域和不可靠区域。所有可靠的预测都用于推导伪标签,而预测不可靠的像素被推入负样本的memory bank中,计算对比损失。

这张图展示的是整个网络的pipeline,可以看到其中监督信息为3个损失函数,L_s是有标签的数据产生的cross-entropy loss,L_u是无标签数据通过teacher网络产生的伪标签和student输出之间的cross-entropy loss,还有一个L_c就是之前提到的由不可靠标签产生的对比损失。咱们主要就是看L_c的计算方法。

L_c的计算过程中,首先牵扯到了可靠像素的选取,选取策略为:

\hat{y}_{i j}^u= \begin{cases}\arg \max _c p_{i j}(c),&\text { if } \mathcal{H}\left(\mathbf{p}_{i j}\right)<\gamma_t,\\\text { ignore },&\text { otherwise },\end{cases}

也就是对于网络的输出,考虑其中的每个像素,计算它的熵,小于阈值\gamma_t的视为可靠像素,当做标签。这个\gamma_t是随着训练过程变化的,是一个与\alpha_t相关的分位数,也就是区分前\alpha_t的位置对应的值,\alpha_t的迭代策略是

\alpha_t=\alpha_0 \cdot\left(1-\frac{t}{\text { total epoch }}\right)

\alpha_0的取值为20%。本质上就是一开始取熵小于前20%的像素点看做可靠像素,随着训练过程这个比例逐渐变小,也就是他们所谓的自适应阈值调整策略。

之后就是对于不可靠像素的处理,首先确定anchor,对于有label的数据,将其gt中的像素超过阈值的视为anchor,对于没有label的数据,得到预测结果之后,置信度超过阈值的视为anchor。最终的anchor集合就是两个集合的并集。所有anchor的平均值视为正样本。之后是负样本的构建,对于有标签样本,因们明确知道其所属的类别,因此除真值标签外的所有类别都可以作为该像素的负样本类别;而对于无标签样本,由于伪标签可能存在错误,因此我们并不完全确信标签的正确性,因而我们需要将预测概率最高的几个类别过滤掉,将该像素认作为剩下几个类别的负样本。当然,由于数据集中存在长尾问题,如果只使用一个 batch 的样本作为负样本可能会非常受限,因此采用对比学习中很常用的 MemoryBank 来维护一个逐类别的负样本库,存入的是由 teacher 生成的特征,以先进先出的队列结构维护。之后就是从中采样出负样本,计算对比损失函数。

总的来说,这篇工作就是利用了那些在预测时本身会被置信度卡掉的像素点,把它们视作其他类别的负样本进行对比学习的方法。

ST++: Make Self-training Work Better for Semi-supervised Semantic Segmentation

这篇文章主要是说了两个训练的pipeline,一个是ST,一个是ST++。核心内容就是这两段伪代码,首先是ST的过程,其实就是在重新训练学生模型的时候,对无标签的图像进行了强数据增广。从伪代码看更直观,就是先通过有标签的数据训练Teacher网络,之后根据teacher网络获得无标签数据的伪标签,之后合并有标签的数据和无标签的数据。在训练学生网络的时候,对于有标签的数据正常训练,无标签的数据在通常的数据增强之后,再经过了一次强数据增广。

之后就是ST++,是在ST基础上提出的一种渐进式选取可靠标签的方法。与一般方法里通过阈值选取高置信度的像素的方法不同,它是根据伪标签的稳定性选取图像的。看伪代码可能不太好理解,他具体流程是这样,给一个有标签的数据集和一个无标签的数据集,先从无标签数据中选取出一定数量的最可靠的无标签图像子集,之后合并有标签数据和这个子集,然后用刚说的ST方法来训练student模型。之后用学到的student模型对剩下的部分打上标签,最后再训练出最终的student模型。

其实就是把无标签的图像进行了一次分类,先训可靠的,再训不可靠的。所以关键问题在于怎么选取可靠无标签图像。这篇文章的想法是,对于简单的图像,网络会比较容易学习到正确的结果,对于困难的图像就不容易学到好的结果。由此他们进一步说容易被学到的简单图像在训练的后期伪标签的变化应该很小,而对于比较复杂的图像,模型在不同epoch之间的预测结果一般都有很大的差异。所以他们选择衡量不同epoch生成的伪标签之间的稳定性来衡量这个伪标签的可靠性。具体在算的时候,就是算最后一个epoch的输出和前面K个checkpoint的输出之间的MIoU,miou越大,说明伪标签预测重合度越高,训练就越稳定,质量也就越高。

这篇文章的主要创新点就是通过稳定性而非来筛选可靠标签,对无标注的数据集进行了进一步的划分,多次训练得到最终的结果,每次训练都是使用的这种ST策略来训的,也就是在训练过程中对无标注的数据有一个额外的强数据增广。

Perturbed and Strict Mean Teachers for Semi-supervised Semantic Segmentation

这篇文章从一致性角度出发考虑半监督学习。整篇文章的核心内容大概可以归结为这么三点:

总的来说,他们的训练流程是这样的。首先还是有监督的部分,用强数据增强后的有label的图片通过CE Loss来训练模型。之后是无监督部分,首先喂给两个teacher网络弱数据增强的图片,根据他们的输出生成伪标签。之后用强数据增强的图片喂给student模型,同时通过T-VAT来寻找一个噪音注入到特征图中,最终让decoder来做预测。之后根据confidence CE惩罚预测,完成student模型的参数更新。最后再用EMA的方式,将参数更新到其中一个teacher上,完成teacher模型的迭代。

CPS && n-CPS

从之前的几个工作也能看出来,半监督的工作大致可以总结为两个大类:self-training和一致性学习两种。self-training的过程就是先在有标签的数据上训一个模型,然后针对无标签的数据生成伪标签,最后利用有标注的标签和无标注的伪标签完成训练。

一致性训练的核心思想是,鼓励网络对经过不同变换的同一样本有相似的输出。一般来说,就是对样本进行扰动,改变其特征,但是仍然鼓励网络输出相同的结果。像mean-teacher这种模型就是代表方法。

这张图展示的就是CPS方法,针对两个同样结构但是初始化不同的网络,对于同一个输入图像X,经过网络f_1f_2,得到输出P_1P_2,之后得到对应的伪标签Y_1Y_2。CPS的方法是用Y_1监督P_2,用Y_2监督P_1。这就是CPS所谓的交叉伪监督。这种约束实际上就是希望这两个网络对同一样本的输出是一致的。

之后,在CVPR 22上,有一篇follow他们的工作,叫n-CPS。实际上就是把两个网络之间的CPS推广到了更多网络中,从这张图里可以看出来,基本思路还是一样的,每个网络收到其他n-1个网络的伪标签的约束,唯独不受自己生成的伪标签的约束。

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注