torch.flatten()和nn.Flatten()的区别
torch.flatten()和之间有什么区别torch.nn.Flatten()?
回答
扁平化在 PyTorch 中以三种形式提供
-
作为一种张量的方法(OOP样式)
torch.Tensor.flatten直接在一个张量施加:x.flatten()。 -
作为一个函数(函数形式)
torch.flatten应用为:torch.flatten(x)。 -
作为一个模块(层
nn.Module)nn.Flatten()。通常用于模型定义。
所有这三个是相同的并且共享相同的实施方式中,唯一的区别是nn.Flatten已经start_dim设置为1默认,以避免平坦化所述第一轴线(通常是分批轴)。而其他两个从axis=0to变平axis=-1-即整个张量 - 如果没有给出参数。