对于numpy中的函数的参数dim的一点理解

对于numpy中的函数的参数dim的一点理解

经常被dim参数搞混。试着总结了一下。记忆瞬间清晰了

以.max(dim)方法为例:

>>> import numpy as np
>>> a = np.random.randint(1, 100, [2, 3, 4])
>>> a
array([[[26, 36, 31, 21],
        [74, 59, 79, 32],
        [77, 94, 81, 32]],

       [[72, 76, 85, 93],
        [66, 34, 80, 12],
        [99, 17, 98, 23]]])
>>> for i in range(3):
...     print(a.max(i))
... 
[[72 76 85 93]
 [74 59 80 32]
 [99 94 98 32]]
[[77 94 81 32]
 [99 76 98 93]]
[[36 79 94]
 [93 80 99]]

可以见得:
a是一个2x3x4的三维矩阵。
当a.max(0)时,max则在维度大小为2的方向上进行操作,所以
a.max(0)就是:
[[72 76 85 93]
[74 59 80 32]
[99 94 98 32]]
一个 1x3x4的矩阵。
以此类推,a.max(1)就是在维度大小为3的方向上进行操作
a.max(i)就是:
[[77 94 81 32]
[99 76 98 93]]
一个 1x2x4的矩阵。

由此很容易发现。
.max(dim)中的dim,并不是a上的维度。而是指a的shape上的顺序(可以这么理解),a的shape是2x3x4,也就是[2, 3, 4]。故可以这样一一对应以来。
而不用死记硬背那些0是对列操作还是对行操作了

猜你喜欢

转载自www.cnblogs.com/orangestar/p/12892410.html