torch.mean()维度解释

# Build a float tensor holding 0..5 and lay it out as shape (2, 1, 3).
a = torch.arange(6, dtype=torch.get_default_dtype()).reshape(2, 1, 3)
print(a)
print(a.size())
# Averaging over dim 0 collapses the leading axis: (2, 1, 3) -> (1, 3).
b = torch.mean(a, dim=0)
print(b)
print(b.size())

--------------------------------------------
res:
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]]])
torch.Size([2, 1, 3])
tensor([[1.5000, 2.5000, 3.5000]])
torch.Size([1, 3])

三维(m,n,q):

ls = [

[  [1,2] , [3,4]  ],

[  [5,6] , [7,8]  ]

]

ls.shape = (2, 2, 2)

dim = 0(沿第 0 维求均值):

固定行、列位置,把两个 2×2 矩阵中对应位置的元素相加后取平均:

(1+5)/2 = 3,

(2+6)/2 = 4,

(3+7)/2 =5,

(4+8)/2 = 6,

ls.mean(dim=0) =

[

[3,4],[5,6]

]  

ls.mean(dim=0).shape = (2, 2)

---------------------------------------------

dim = 1(沿第 1 维求均值):

在每个 2×2 矩阵内部固定列,把各行中对应位置的元素相加后取平均:

(1+ 3)/2 = 2,

(2+4)/2 = 3,

(5+7)/2 = 6,

(6+8)/2 =7,

ls.mean(dim=1)=

[

[2,3],

[6,7]

]

ls.mean(dim=1).shape = (2, 2)

下面用 torch 验证(注意:代码中张量的取值为 0~7,比上面手算示例的 1~8 各小 1,因此得到的均值也各小 1):

# Build a float tensor holding 0..7 and lay it out as shape (2, 2, 2).
a = torch.arange(8, dtype=torch.get_default_dtype()).reshape(2, 2, 2)
print(a)
print(a.size())
# Averaging over dim 1 collapses the rows inside each 2x2 matrix:
# (2, 2, 2) -> (2, 2).
b = torch.mean(a, dim=1)
print(b)
print(b.size())
-------------------------------

tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[1., 2.],
        [5., 6.]])
torch.Size([2, 2])

猜你喜欢

转载自blog.csdn.net/weixin_40823740/article/details/115252326