torch.flatten()和nn.Flatten()的区别

torch.flatten()和之间有什么区别torch.nn.Flatten()

回答

扁平化在 PyTorch 中以三种形式提供

  • 作为一种张量的方法(OOP样式torch.Tensor.flatten直接在一个张量施加:x.flatten()

  • 作为一个函数(函数形式torch.flatten应用为:torch.flatten(x)

  • 作为一个模块(nn.Modulenn.Flatten()。通常用于模型定义。

所有这三个是相同的并且共享相同的实施方式中,唯一的区别是nn.Flatten已经start_dim设置为1默认,以避免平坦化所述第一轴线(通常是分批轴)。而其他两个从axis=0to变平axis=-1-整个张量 - 如果没有给出参数。


以上是torch.flatten()和nn.Flatten()的区别的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>