关于交叉熵的易混淆概念

0. 前言

交叉熵,二值交叉熵,softmax,sigmoid等概念经常一起出现,这里结合pytorch的具体实现来捋清楚这几个概念。

1. 概念

1.1 交叉熵

交叉熵用于描述两个概率分布的差异,也就是说处理的都是0-1之间的数值。

在实际应用中,一般用于计算分类损失。并且,为了避免对0取对数,会将GT的one-hot向量作为$p$,预测结果作为$q$。

1.2 二值交叉熵

二值交叉熵就是上述交叉熵公式在$x$只有两个取值时的特例。

1.3 sigmoid函数

一般模型直接输出的结果向量不保证其中的标量元素的数值(pytorch等官方术语称logits)在0-1之间,因此要用到sigmoid这种值域在0-1之间的激活函数将每一个logit单独作映射,然后再进行交叉熵计算。

1.4 softmax函数

跟sigmoid一样,也是一个logit归一化的工具,但是需要所有logits一起做联合映射,而不是独立映射。

2. pytorch实现

pytorch中的nn模块和nn.funtional模块中都有交叉熵损失的实现,这里只介绍nn下的3个模块。

2.1 nn.CrossEntropyLoss

主要的入参有input(N, C)和target(N,)两个,N是结果向量有几个,C是类别数量。

input的每一行都是一个结果向量,包含C个类别的各自的logit,这是因为这个模块内部集成了softmax函数,在进行损失计算之前会先对input做0-1的映射。

target每一行是一个真实标签序号,值在[0, C-1]之间。在进行算是计算之前会将target的每一行拓展成一个one-hot向量,hot的位置就是这个真实标签序号。

然后用1.1中的公式逐行计算这N个结果向量与N个one-hot向量的交叉熵,最后求和或求算数平均值(默认求均值)。

2.2 nn.BCEWithLogitsLoss

主要的入参还是input(N, C)和target(N, C)两个,但是target的形状不同,其每一行都直接是GT的one-hot向量。

也同样集成了sigmoid函数,计算损失之前会先对input做0-1的映射。

然后逐行选中一个C维的结果行向量和一个C维的GT one-hot行向量,按照1.2中的公式计算得到C个标量,求均值即可得到这一行的loss。

最后求各行的loss均值即可得到最终的loss。

2.3 nn.BCELoss

未集成sigmoid的二值交叉熵计算模块,有要求入参input的标量元素值在0-1之间,所以一般不使用,直接使用2.2。

3. 附录


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
import torch

logits = torch.tensor([[2, -3, 0.8, 1.2], [1.5, -0.1, 0.9, -1.3]])

sigmoid = torch.sigmoid(logits)
softmax = torch.softmax(logits, 1)

GT4bce = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0]], dtype=torch.float)
GT4ce = torch.tensor([2, 1])

ce = nn.CrossEntropyLoss(reduction='sum')
ce_loss = ce(logits, GT4ce)

bce = nn.BCEWithLogitsLoss(reduction='sum')
bce_loss = bce(logits, GT4bce)

print("CE loss is {} \nBCE loss is {}".format(ce_loss, bce_loss))