# dim=0 demo: average a (2, 1, 3) tensor over its leading axis.
# `* 1.` promotes the integer arange to floating point; mean() rejects integer tensors.
a = (torch.arange(6) * 1.).reshape(2, 1, 3)
print(a, a.shape, sep="\n")
# Reducing over dim 0 removes that axis: (2, 1, 3) -> (1, 3).
b = a.mean(dim=0)
print(b, b.shape, sep="\n")
--------------------------------------------
Output:
tensor([[[0., 1., 2.]],
[[3., 4., 5.]]])
torch.Size([2, 1, 3])
tensor([[1.5000, 2.5000, 3.5000]])
torch.Size([1, 3])
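三维张量,形状为 (m, n, q)。下面的 ls 是示意写法:Python 列表本身没有 shape / mean,实际需先转成浮点张量,例如 torch.tensor(ls, dtype=torch.float)(整数张量调用 mean() 会报错):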
ls = [
[ [1,2] , [3,4] ],
[ [5,6] , [7,8] ]
]
ls.shape = (2, 2, 2)
dim = 0:
固定后两个下标(行、列)不变,把第 0 维上对应位置的元素相加再除以 2(取平均):
(1+5)/2 = 3,
(2+6)/2 = 4,
(3+7)/2 =5,
(4+8)/2 = 6,
ls.mean(dim=0) =
[
[3,4],[5,6]
]
结果的 shape = (2, 2)(第 0 维被消去;若用 keepdim=True 则保留为 (1, 2, 2))
---------------------------------------------
dim = 1:
固定第 0 维下标和列下标不变,在每个 2×2 矩阵内沿行方向(第 1 维)把元素相加再除以 2(取平均):
(1+ 3)/2 = 2,
(2+4)/2 = 3,
(5+7)/2 = 6,
(6+8)/2 =7,
ls.mean(dim=1)=
[
[2,3],
[6,7]
]
结果的 shape = (2, 2)(第 1 维被消去)
# dim=1 demo: average each 2x2 slice of a (2, 2, 2) tensor down its rows.
# `* 1.` promotes the integer arange to floating point; mean() rejects integer tensors.
a = (torch.arange(8) * 1.).reshape(2, 2, 2)
print(a, a.shape, sep="\n")
# Reducing over dim 1 removes that axis: (2, 2, 2) -> (2, 2).
b = a.mean(dim=1)
print(b, b.shape, sep="\n")
-------------------------------
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
torch.Size([2, 2, 2])
tensor([[1., 2.],
[5., 6.]])
torch.Size([2, 2])