训练大规模对比学习的一些小笔记

在笔者对于对比学习的认识中,主要有2个维度的事情需要考虑:

如何选取合适的负样本如何选取合适的损失函数以下结合一些训练经验,简要笔记下。

如何构造合适的负样本之前在[1]中简单介绍过一些构造负样本的方法,总体来说,基于用户行为数据我们可以通过batch negative和无点数据进行负样本构建。

batch negative在搜索过程中,用户行为存在很大的随机性,比如有展现但没有点击的数据并不一定就是负样本,为了获取更可靠的用户数据,我们可以选择在用户点击过的Doc之间组成负样本。没错,我们认为用户点击过的行为是更为可靠的,虽然即便是点击行为也可能只是因为用户的好奇行为或者误操作等等,但是对比于无点行为总归是更为可靠的。假设用户的query i ii和点击过的Doc组成二元组,其中的C \mathcal{C}C表示所有有点行为的集合,那么我们认为其负样本就是。当数据足够庞大是,有点数据的规模也会非常庞大,我们无法一次将所有负样本都列举出来(同时,也没有必要),我们通常会在一个batch内对所有用户点击二元组进行组合。也即是将的规模限制在一个batch内,如Fig 1.1所示,其中的对角线都是二元组正样本,而其他元素都是负样本。通过一个矩阵乘法,我们就可以实现这个操作。如式子(1.1)所示。

Fig 1.1 Batch Negative的方式从一个batch中构造负样本。

无点击样本无点数据也不是一无是处,在某些搜索产品中,如果排序到前面的结果本身就不够好,那么用户的点击数据和无点击数据就具有足够的区分度,无点数据拿来视为负样本就是合理的,这个和产品具体的设计,或者呈现UI形式等等有关,需要在实践中才能实验出来。

在实践中,通常还会去进行batch negative和无点击数据的混合以达到获取足够多的负样本的目的。

使用何种损失函数常用在对比学习中的损失函数主要有两种,hinge loss[2]和交叉熵损失。其中的hinge loss形式如(2.1)所示:

hinge loss和SVM一样[3],存在一个margin,一旦正样本和负样本打分的差距超过这个margin,那么损失就变为0,通过这种手段可以让hinge loss学习到正样本和负样本之间的表征区别,而且又可以更好地控制训练过程。而交叉熵损失是我们的老朋友了,如式子(2.2)所示

其中的N为样本数量,M为分类类别数量,而则是预测的logit经过softmax之后的概率分布。注意到,正如Fig 1.1所示,对于每个而言,其每一行都有个负样本;对于每个 而言,其每一列都有个负样本,那么就可以组织双向的损失函数计算。这种方式对于双塔模型结构来说特别地“划算”,因为对于双塔模型而言只需要计算一次矩阵计算就可以得到的打分矩阵,然后通过双向计算损失,可以实现更高效地对模型的训练。

在hinge loss计算过程中,还可以通过在这每一行(或者每一列)的N − 1个负样本中选择一个最难的负样本,也就是打分最高的负样本。这一点很容易理解,负样本的打分如果打得很高,那么就可以认为模型很大程度地将其误认为正样本了,如果能将其分开,那么模型的表征能力应该是更上一层楼的,因此将最难负样本作为式子(2.1)的进行训练。

训练过程在对比学习训练过程中,我们暂时只考虑双塔模型(因为交互式模型的负样本选取策略不同),虽然理论上hinge loss这种基于pairwise样本选取策略的损失,可以很好地对比正负样本的表征区别,但是如果模型并没有进行很好地训练就拿去用hinge loss进行训练,有可能因为负样本太难导致训练出现“损失坍缩”(loss collapse)的现象,此时模型对正样本和负样本没有区分能力,因此对两者的打分都极为相似,有,此时loss坍缩到margin并且恒等于margin不再变化,如Fig 3.1所示。我们可以认为模型陷入了平凡解。

Fig 3.1 采用hinge loss导致损失坍缩的现象。图省事就直接ipad上画了,有点丑见谅:_)这个现象也不一定就会出现,如果采用的模型已经进行过合适的初始化,就不一定会出现这个问题。另外,采用交叉熵损失进行一开始的训练是一种比较稳定的方法。在CLIP模型中[4],作者采用了batch size=32,768的配置,在进行过allgather机制,对所有特征进行汇聚后[5],甚至可以实现32768 × 32768 大小的打分矩阵,这意味着有着海量的负样本可供学习,这也同时意味着对模型学习的巨大挑战。因此CLIP文章的作者没有采用hinge loss训练,而是采用了双向的交叉熵损失进行训练。

然而在巨大的batch size中训练是有着非常大的诱惑的,在[1]中我们就曾经讨论过对于对比学习中,负样本增多意味着表征词典的词表的增大,有着巨大的效果增益。那么要如何去训练这种超大规模的batch size下的对比学习任务呢?笔者个人认为需要进行阶段式地训练,一步步提高batch size大小。笔者曾经试验过,如果一开始就采用很大的batch size进行训练,在hinge loss的情况下,将会非常不稳定,很容易出现损失坍塌的现象。而如果循序渐进则不会出现这个问题,那么是否可以通过这种方法将batch size增加到很大呢(不考虑硬件的约束情况),这个笔者也还在实验,希望后续能有个比较正向的结论。

同时,在超大规模的对比学习过程中,如何结合交叉熵损失和hinge loss损失也是一个值得思考的问题。交叉熵损失稳定,但是学习速度较慢(笔者实验发现,不一定准确),hinge loss不稳定,但是学习速度更快,如何进行平衡是一个值得尝试的方向。对比学习在大规模数据上的训练的确还有很多值得探索的呢,公开的论文提供的细节也不多。

Reference

[1]. https://fesian.blog.csdn.net/article/details/119515146

[2]. https://blog.csdn.net/LoseInVain/article/details/103995962

[3]. https://blog.csdn.net/LoseInVain/article/details/78636176

[4]. https://fesian.blog.csdn.net/article/details/119516894

[5]. https://medium.com/@cresclux/example-on-torch-distributed-gather-7b5921092

声明:本内容为作者独立观点,不代表电子星球立场。未经允许不得转载。授权事宜与稿件投诉,请联系:editor@netbroad.com
觉得内容不错的朋友,别忘了一键三连哦!
赞 2
收藏 3
关注 49
成为作者 赚取收益
全部留言
0/200
成为第一个和作者交流的人吧