torch.sum 维度参数解析

a = torch.ones(2, 3)
print(a)
print(a.shape)
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=-1))

#res
tensor([[1., 1., 1.],
        [1., 1., 1.]])

torch.Size([2, 3])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor([3., 3.])

个人理解:对于二维数组,dim=0,就是固定行,把列相加;dim=1,固定列,把行相加。

值得注意:dim=-1和dim=2 结果相同。个人推测:dim取值为(0,1)就像数组一样ls=[0,1],python 语法中有ls[-1]=ls[1]=1。思想相同。

a = torch.ones((2,2,3))
print(a)
print(a.shape)
print(a.sum(dim=-1))
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=2))

result:

a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=0)
print(b)
print(b.shape)
----------------------------------
res:
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 4.,  6.],
        [ 8., 10.]])
torch.Size([2, 2])
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=1)
print(b)
print(b.shape)

--------------------------------------
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 2.,  4.],
        [10., 12.]])
torch.Size([2, 2])

三维数组:思路同上,dim=(0,1,2),dim=1固定行,列相加;dim=2,固定列,行相加。dim=-1和dim=2一样。至于dim=0,不清楚怎么相加。

猜你喜欢

转载自blog.csdn.net/weixin_40823740/article/details/114988513