将“softmax+交叉熵”推广到多标签分类问题

栏目: IT技术 · 发布时间: 4年前

内容简介:©PaperWeekly 原创 · 作者|苏剑林单位|追一科技

将“softmax+交叉熵”推广到多标签分类问题

©PaperWeekly 原创 · 作者|苏剑林

单位|追一科技

研究方向|NLP、神经网络

一般来说,在处理常规的多分类问题时,我们会在模型的最后用一个全连接层输出每个类的分数,然后用 softmax 激活并用交叉熵作为损失函数。在这篇文章里,我们尝试将 “softmax+交叉熵”方案推广到多标签分类场景,希望能得到用于多标签分类任务的、不需要特别调整类权重和阈值的 loss。

将“softmax+交叉熵”推广到多标签分类问题

▲ 类别不平衡

将“softmax+交叉熵”推广到多标签分类问题

单标签到多标签

一般来说,多分类问题指的就是单标签分类问题,即从 n 个候选类别中选 1 个目标类别。假设各个类的得分分别为,目标类为,那么所用的 loss 为:

将“softmax+交叉熵”推广到多标签分类问题

这个 loss 的优化方向是让目标类的得分变为中的最大值。关于 softmax 的相关内容,还可以参考《寻求一个光滑的最大值函数》 [1] 、《函数光滑化杂谈:不可导函数的可导逼近》 [2] 等文章。

现在我们转到多标签分类问题,即从 n 个候选类别中选 k 个目标类别。这种情况下我们一种朴素的做法是用 sigmoid 激活,然后变成 n 个二分类问题,用二分类的交叉熵之和作为 loss。

显然,当时,这种做法会面临着严重的类别不均衡问题,这时候需要一些平衡策略,比如手动调整正负样本的权重、focal loss [3] 等。训练完成之后,还需要根据验证集来进一步确定最优的阈值。

这时候,一个很自然的困惑就是: 为什么“n选k”要比“n选1”多做那么多工作?

笔者认为这是很不科学的事情,毕竟直觉上 n 选 k 应该只是 n 选 1 自然延伸,所以不应该要比 n 要多做那么多事情,就算 n 选 k 要复杂一些,难度也应该是慢慢过渡的,但如果变成多个二分类的话,n 选 1 反而是最难的,因为这时候类别最不均衡。

而从形式上来看,单标签分类比多标签分类要容易,就是因为单标签有 “softmax + 交叉熵”可以用,它不会存在类别不平衡的问题,而多标签分类中的 “sigmoid + 交叉熵”就存在不平衡的问题。

所以,理想的解决办法应该就是将 “softmax + 交叉熵”推广到多标签分类上去。

将“softmax+交叉熵”推广到多标签分类问题

众里寻她千百度

为了考虑这个推广,笔者进行了多次尝试,也否定了很多结果,最后确定了一个相对来说比较优雅的方案:构建组合形式的 softmax 来作为单标签 softmax 的推广。

在这部分内容中,我们会先假设 k 是一个固定的常数,然后再讨论一般情况下 k 的自动确定方案,最后确实能得到一种有效的推广形式。

2.1 组合softmax

首先,我们考虑 k 是一个固定常数的情景,这意味着预测的时候,我们直接输出得分最高的 k 个类别即可。那训练的时候呢?作为 softmax 的自然推广,我们可以考虑用下式作为 loss:

将“softmax+交叉熵”推广到多标签分类问题

其中是 k 个目标标签,是配分函数。

很显然,上式是以任何 k 个类别总得分为基本单位所构造的 softmax,所以它算是单标签 softmax 的合理推广。又或者理解为还是一个单标签分类问题,只不过这是选 1 问题。

在这个方案之中,比较困难的地方是的计算,它是项总得分的指数和。不过,我们可以利用牛顿恒等式 [4] 来帮助我们递归计算。设,那么:

将“softmax+交叉熵”推广到多标签分类问题

所以为了计算,我们只需要递归计算 k 步,这可以在合理的时间内计算出来。预测阶段,则直接输出分数最高的 k 个类就行。

2.2 自动确定阈值

上述讨论的是输出数目固定的多标签分类问题,但一般的多标签分类的目标标签数是不确定的。为此,我们确定一个最大目标标签数,并添加一个 0 标签作为填充标签,此时 loss 变为:

将“softmax+交叉熵”推广到多标签分类问题

而:

将“softmax+交叉熵”推广到多标签分类问题

看上去很复杂,其实很简单,还是以 K 个类别总得分为基本单位,但是允许且仅允许 0 类重复出现。预测的时候,仍然是输出分数最大的 K 个类,但允许重复输出 0 类,等价的效果是 为阈值,只输出得分大于 的类 。最后的式子显示也可以通过递归来计算,所以实现上是没有困难的。

将“softmax+交叉熵”推广到多标签分类问题

暮然回首阑珊处

看上去“众里寻她千百度”终究是有了结果:理论有了,实现也不困难,接下来似乎就应该做实验看效果了吧?效果好的话,甚至可以考虑发 paper 了吧?看似一片光明前景呢!然而~

幸运或者不幸,在验证了它的有效性的同时,笔者请教了一些前辈大神,在他们的提示下翻看了之前没细看的 Circle Loss [5] ,看到了它里边统一的 loss 形式(原论文的公式 (1)),然后意识到了这个统一形式蕴含了一个更简明的推广方案。

所以,不幸的地方在于,已经有这么一个现成的更简明的方案了,所以不管如何“众里寻她千百度”,都已经没有太大意义了;而幸运的地方在于,还好找到了这个更好的方案,要不然屁颠屁颠地把前述方案写成文章发出来,还不如现成的方案简单有效,那时候丢人就丢大发了。

3.1 统一的loss形式

让我们换一种形式看单标签分类的交叉熵 (1):

将“softmax+交叉熵”推广到多标签分类问题

为什么这个 loss 会有效呢?在文章《寻求一个光滑的最大值函数》 [1] 、《函数光滑化杂谈:不可导函数的可导逼近》 [2] 中我们都可以知道,实际上就是的光滑近似,所以我们有:

将“softmax+交叉熵”推广到多标签分类问题

这个 loss 的特点是,所有的非目标类得分跟目标类得分两两作差比较,它们的差的最大值都要尽可能小于零,所以实现了“目标类得分都大于每个非目标类的得分”的效果。

所以,假如是有多个目标类的多标签分类场景,我们也希望 “每个目标类得分都不小于每个非目标类的得分” ,所以下述形式的 loss 就呼之欲出了:

将“softmax+交叉熵”推广到多标签分类问题

其中分别是正负样本的类别集合。这个 loss 的形式很容易理解,就是我们希望,就往里边加入这么一项。如果补上缩放因子和间隔 m,就得到了 Circle Loss 论文里边的统一形式:

将“softmax+交叉熵”推广到多标签分类问题

说个题外话,上式就是 Circle Loss 论文的公式 (1),但原论文的公式 (1) 不叫 Circle Loss,原论文的公式 (4) 才叫 Circle Loss,所以不能把上式叫做 Circle Loss。但笔者认为,整篇论文之中最有意思的部分还数公式 (1)。

3.2 用于多标签分类

和 m 一般都是度量学习中才会考虑的,所以这里我们还是只关心式  (8) 。如果 n 选 k 的多标签分类中 k 是固定的话,那么直接用式  (8)  作为 loss 就行了,然后预测时候直接输出得分最大的 k 个类别。

对于 k 不固定的多标签分类来说,我们就需要一个阈值来确定输出哪些类。为此,我们同样引入一个额外的 0 类,希望目标类的分数都大于,非目标类的分数都小于,而前面已经已经提到过,“希望就往里边加入”,所以现在式  (8)  变成:

将“softmax+交叉熵”推广到多标签分类问题

如果指定阈值为 0,那么就简化为:

将“softmax+交叉熵”推广到多标签分类问题

这便是我们最终得到的 Loss 形式了——“softmax + 交叉熵”在多标签分类任务中的自然、简明的推广,它没有类别不均衡现象,因为它不是将多标签分类变成多个二分类问题,而是变成目标类别得分与非目标类别得分的两两比较,并且借助于的良好性质,自动平衡了每一项的权重。

这里给出 Keras 下的参考实现:

def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    说明:y_true和y_pred的shape一致,y_true的元素非0即1,
         1表示对应的类为目标类,0表示对应的类为非目标类。
    """
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = K.zeros_like(y_pred[..., :1])
    y_pred_neg = K.concatenate([y_pred_neg, zeros], axis=-1)
    y_pred_pos = K.concatenate([y_pred_pos, zeros], axis=-1)
    neg_loss = K.logsumexp(y_pred_neg, axis=-1)
    pos_loss = K.logsumexp(y_pred_pos, axis=-1)
    return neg_loss + pos_loss

将“softmax+交叉熵”推广到多标签分类问题

所以,结论就是

所以,最终结论就是式 (11),它就是本文要寻求的多标签分类问题的统一 loss,欢迎大家测试并报告效果。笔者也实验过几个多标签分类任务,均能媲美精调权重下的二分类方案。

要提示的是,除了标准的多标签分类问题外,还有一些常见的任务形式也可以认为是多标签分类,比如基于 0/1 标注的序列标注,典型的例子是笔者的 “半指针-半标注”标注设计

因此,从这个角度看,能被视为多标签分类来测试式  (11)  的任务就有很多了,笔者也确实在之前的三元组抽取例子 task_relation_extraction.py  [6] 中尝试了  (11) ,最终能取得跟这里 [7] 一致的效果。

当然,最后还是要说明一下,虽然理论上式  (11)  作为多标签分类的损失函数能自动地解决很多问题,但终究是不存在绝对完美、保证有提升的方案。

所以当你用它替换掉你原来多标签分类方案时,也不能保证一定会有提升,尤其是当你原来已经通过精调权重等方式处理好类别不平衡问题的情况下,式  (11)  的收益是非常有限的。毕竟式  (11)  的初衷,只是让我们在不需要过多调参的的情况下达到大部分的效果。

参考链接

[1] https://kexue.fm/archives/3290

[2] https://kexue.fm/archives/6620

[3] https://kexue.fm/archives/4733

[4] https://en.wikipedia.org/wiki/Newton's_identities

[5] https://arxiv.org/abs/2002.10857

[6] https://github.com/bojone/bert4keras/blob/master/examples/task_relation_extraction.py

[ 7 ] https://kexue.fm/archives/7161# 类别失衡

将“softmax+交叉熵”推广到多标签分类问题

点击以下标题查看更多往期内容:

将“softmax+交叉熵”推广到多标签分类问题

# 投 稿 通 道 #

让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是 最新论文解读 ,也可以是 学习心得技术干货 。我们的目的只有一个,让知识真正流动起来。

:memo:  来稿标准:

• 稿件确系个人 原创作品 ,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

:mailbox_with_mail:  投稿邮箱:

• 投稿邮箱: hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

:mag:

现在,在 「知乎」 也能找到我们了

进入知乎首页搜索 「PaperWeekly」

点击 「关注」 订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击 「交流群」 ,小助手将把你带入 PaperWeekly 的交流群里。

将“softmax+交叉熵”推广到多标签分类问题

将“softmax+交叉熵”推广到多标签分类问题


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Data Structures and Algorithm Analysis in Java

Data Structures and Algorithm Analysis in Java

Mark A. Weiss / Pearson / 2011-11-18 / GBP 129.99

Data Structures and Algorithm Analysis in Java is an “advanced algorithms” book that fits between traditional CS2 and Algorithms Analysis courses. In the old ACM Curriculum Guidelines, this course wa......一起来看看 《Data Structures and Algorithm Analysis in Java》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

随机密码生成器
随机密码生成器

多种字符组合密码

Markdown 在线编辑器
Markdown 在线编辑器

Markdown 在线编辑器