如何正确使用交叉熵损失vsSoftmax进行分类?

我想使用 Pytorch 训练多类分类器。

按照官方 Pytorch 文档展示了如何nn.CrossEntropyLoss()在 type 的最后一层之后使用 a nn.Linear(84, 10)

但是,我记得这是 Softmax 所做的。

这让我很困惑。


  1. 如何以最佳方式训练“标准”分类网络?
  2. 如果网络有最后的线性层,如何推断每个类的概率?
  3. 如果网络有最后的 softmax 层,如何训练网络(哪个损失,以及如何训练)?

我在 Pytorch 论坛上找到了这个线程,它可能回答了所有这些问题,但我无法将它编译成工作和可读的 Pytorch 代码。


我假设的答案:

  1. 就像医生说的那样。
  2. 线性层输出的幂,这实际上是对数(对数概率)。
  3. 我不明白。

回答

我认为理解 softmax 和交叉熵很重要,至少从实践的角度来看是这样。一旦您掌握了这两个概念,就应该清楚如何在 ML 的上下文中“正确”使用它们。

交叉熵 H(p, q)

交叉熵是比较两个概率分布的函数。从实践的角度来看,可能不值得深入研究交叉熵的正式动机,但如果您有兴趣,我会推荐Cover 和 Thomas 的Elements of Information Theory作为介绍性文本。这个概念很早就被引入了(我相信是第 2 章)。这是我在研究生院使用的介绍文字,我认为它做得很好(当然我也有一位很棒的导师)。

要注意的关键是交叉熵是一个函数,它以两个概率分布作为输入:q 和 p,并在 q 和 p 相等时返回一个最小的值。q 表示估计分布,p 表示真实分布。

在 ML 分类的上下文中,我们知道训练数据的实际标签,因此真实/目标分布 p 对于真实标签的概率为 1,其他地方为 0,即 p 是一个单热向量。

另一方面,估计分布(模型的输出)q 通常包含一些不确定性,因此 q 中任何类别的概率将在 0 和 1 之间。 通过训练系统以最小化交叉熵,我们告诉系统我们希望它尝试使估计分布尽可能接近真实分布。因此,你的模型认为最有可能的类就是q的最大值对应的类。

软最大

同样,有一些复杂的统计方法来解释 softmax,我们不会在这里讨论。从实用的角度来看,关键是 softmax 是一个函数,它以无界值列表作为输入,并输出一个有效的概率质量函数,并保持相对顺序。重要的是要强调关于相对顺序的第二点。这意味着 softmax 输入中的最大元素对应于 softmax 输出中的最大元素。

考虑一个经过训练以最小化交叉熵的 softmax 激活模型。在这种情况下,在 softmax 之前,模型的目标是为正确的标签产生可能的最高值,为不正确的标签产生可能的最低值。

PyTorch 中的交叉熵损失

PyTorch中 CrossEntropyLoss的定义是 softmax 和交叉熵的结合。具体来说

CrossEntropyLoss(x, y) := H(one_hot(y), softmax(x))

请注意,one_hot 是一个函数,它采用索引 y,并将其扩展为 one-hot 向量。

同样,您可以将 CrossEntropyLoss 公式化为LogSoftmax和负对数似然损失(即PyTorch中的 NLLLoss)的组合

LogSoftmax(x) := ln(softmax(x))

CrossEntropyLoss(x, y) := NLLLoss(LogSoftmax(x), y)

由于 softmax 中的求幂,有一些计算“技巧”可以使直接使用 CrossEntropyLoss 比分阶段计算更稳定(更准确,不太可能得到 NaN)。

结论

基于以上讨论,您的问题的答案是

1. 如何以最好的方式训练一个“标准”的分类网络?

就像医生说的那样。

2. 如果网络有最后的线性层,如何推断每个类的概率?

将 softmax 应用于网络的输出以推断每个类别的概率。如果目标只是找到相对排序或最高概率类,那么只需将 argsort 或 argmax 直接应用于输出(因为 softmax 保持相对排序)。

3. 如果网络有最后的 softmax 层,如何训练网络(哪个损失,以及如何训练)?

通常,出于上述稳定性原因,您不想训练输出 softmaxed 输出的网络。

也就是说,如果您出于某种原因绝对需要,您可以获取输出日志并将它们提供给 NLLLoss

criterion = nn.NLLLoss()
...
x = model(data)    # assuming the output of the model is softmax activated
loss = criterion(torch.log(x), y)

这在数学上等同于将 CrossEntropyLoss 与使用 softmax 激活的模型一起使用。

criterion = nn.CrossEntropyLoss()
...
x = model(data)    # assuming the output of the model is NOT softmax activated
loss = criterion(x, y)


以上是如何正确使用交叉熵损失vsSoftmax进行分类?的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>