Keras的CategoricalCrossEntropy到底在做什么?

我正在移植一个 keras 模型torch,但我无法'categorical_crossentropy'在 softmax 层之后复制 keras/tensorflow 的确切行为。我对这个问题有一些解决方法,所以我只对了解计算分类交叉熵时 tensorflow究竟计算什么感兴趣。

作为一个玩具问题,我设置了标签和预测向量

>>> import tensorflow as tf
>>> from tensorflow.keras import backend as K
>>> import numpy as np


>>> true = np.array([[0.0, 1.0], [1.0, 0.0]])
>>> pred = np.array([[0.0, 1.0], [0.0, 1.0]])

并计算分类交叉熵:

>>> loss = tf.keras.losses.CategoricalCrossentropy()
>>> print(loss(pred, true).eval(session=K.get_session()))
8.05904769897461

这与分析结果不同

>>> loss_analytical = -1*K.sum(true*K.log(pred))/pred.shape[0]
>>> print(loss_analytical.eval(session=K.get_session()))
nan

我深入研究了 keras/tf 的交叉熵的源代码(参见Tensorflow Github Source Code 中的 Softmax Cross Entropy implementation),并在https://github.com/tensorflow/tensorflow/blob/c903b4607821a03c36c17b0befa25306c6/compilerflow/compiler/找到了 c 函数tf2xla/kernels/softmax_op.cc第 116 行。在该函数中,有一条注释:

// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
// (The subtraction broadcasts along the batch dimension.)

并实现这一点,我试过:

>>> max_logits = K.max(pred, axis=0)
>>> max_logits = max_logits
>>> xent = K.sum(-true * ((pred - max_logits) - K.log(K.sum(K.exp(pred - max_logits)))))/pred.shape[0]

>>> print(xent.eval(session=K.get_session()))
1.3862943611198906

我还尝试打印 的跟踪xent.eval(session=K.get_session()),但跟踪长约 95000 行。所以它引出了一个问题:keras/tf 在计算时到底在做什么'categorical_crossentropy'?它不返回是有道理的,这nan会导致训练问题,但是 8 来自哪里?

回答

以下是我在您的代码中注意到的一些内容。

首先,您的预测显示两个数据实例,[0.0, 1.0]并且[0.0, 1.0].

pred = np.array([[0.0, 1.0], [0.0, 1.0]])

它们应该表示概率,但 softmax 之后的值通常不完全是 0.0 和 1.0。改为尝试 0.01 和 0.99。

其次,CateogoricalCrossEntropy() 调用的参数应该是true, pred,而不是pred, true

所以这就是我得到的:

import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np

true = np.array([[0.0, 1.0], [1.0, 0.0]])
pred = np.array([[0.01, 0.99], [0.01, 0.99]])

loss = tf.keras.losses.CategoricalCrossentropy()
print(loss(true, pred).numpy())
# 2.307610273361206

为了完整起见,让我们试试你所做的,使用pred, true

print(loss(pred, true).numpy())
# 8.05904769897461

那就是你神秘的 8.05 的来源。

我的回答2.307610273361206正确吗?让我们手动计算损失。按照StackOverflow 帖子中的解释,我们可以计算两个数据实例中每一个的损失,然后计算它们的平均值。

loss1 = -(0.0 * np.log(0.01) + 1.0 * np.log(0.99))
print(loss1) # 0.01005033585350145

loss2 = -(1.0 * np.log(0.01) + 0.0 * np.log(0.99))
print(loss2) # 4.605170185988091

# Total loss is the average of the per-instance losses.
loss = (loss1 + loss2) / 2
print(loss) # 2.307610260920796

所以看起来CategoricalCrossEntropy()正在产生正确的答案。


回答

问题是您在预测中使用了硬 0 和 1。这导致nan您的计算因为log(0)未定义(或无限)。

没有真正记录在案的是,Keras 交叉熵通过将值剪切到范围内来自动“保护”这一点[eps, 1-eps]。这意味着,在您的示例中,Keras 为您提供了不同的结果,因为它完全用其他值替换了预测。

如果您用软值替换您的预测,您应该能够重现结果。无论如何,这是有道理的,因为您的网络通常会通过 softmax 激活返回这些值;硬 0/1 只发生在数值下溢的情况下。

如果你想自己检查一下,剪裁发生在这里。这个函数最终被CategoricalCrossentropy函数调用。epsilon在别处定义,但似乎是0.0000001- 尝试手动计算,pred = np.clip(pred, 0.0000001, 1-0.0000001)您应该看到结果8.059047875479163


以上是Keras的CategoricalCrossEntropy到底在做什么?的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>