符号*,torch.max 和 torch.sum, item()方法

*的作用可以参考https://www.cnblogs.com/jony7/p/8035376.html

torch.max可以参考https://blog.csdn.net/Z_lbj/article/details/79766690

a.size()
# Out[134]: torch.Size([6, 4, 3])
torch.max(a, 0)[1].size()
# Out[135]: torch.Size([4, 3])
torch.max(a, 1)[1].size()
# Out[136]: torch.Size([6, 3])
torch.max(a, 2)[1].size()
# Out[137]: torch.Size([6, 4])

 具体怎么比较的可以看下面

b

tensor([[[  0.,   1.,   2.,   3.],
         [  4.,   5.,   6.,   7.],
         [  8.,   9.,  10.,  11.]],
        [[ 12.,  13.,  14.,  15.],
         [ 16.,  17.,  18.,  19.],
         [ 20.,  21.,  22.,  23.]]])

torch.max(b,0)[0]
 
tensor([[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]])

torch.max(b,1)[0]

tensor([[  8.,   9.,  10.,  11.],
        [ 20.,  21.,  22.,  23.]])

torch.max(b,2)[0]

tensor([[  3.,   7.,  11.],
        [ 15.,  19.,  23.]])

相应的下标可以得到

b

tensor([[[  0.,   1.,   2.,   3.],
         [  4.,   5.,   6.,   7.],
         [  8.,   9.,  10.,  11.]],
        [[ 12.,  13.,  14.,  15.],
         [ 16.,  17.,  18.,  19.],
         [ 20.,  21.,  22.,  23.]]])

torch.max(b,0)[1]

tensor([[ 1,  1,  1,  1],
        [ 1,  1,  1,  1],
        [ 1,  1,  1,  1]])

torch.max(b,1)[1]

tensor([[ 2,  2,  2,  2],
        [ 2,  2,  2,  2]])

torch.max(b,2)[1]

tensor([[ 3,  3,  3],
        [ 3,  3,  3]])

 torch.sum:

torch.sum(input) → Tensor
torch.sum(input, dim, out=None) → Tensor
参数:

    input (Tensor) – 输入张量
    dim (int) – 缩减的维度
    out (Tensor, optional) – 结果张量

函数的输出是一个tensor

match
out:
tensor([[[ 0,  0,  2,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]],
        [[ 0,  0,  0,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]]], dtype=torch.uint8)

torch.sum(match)
Out: 
tensor(2)

torch.sum(match,0)
Out: 
tensor([[ 0,  0,  2,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]])
torch.sum(match,1)
Out: 
tensor([[ 0,  0,  2,  0],
        [ 0,  0,  0,  0]])
torch.sum(match,2)
Out: 
tensor([[ 2,  0,  0],
        [ 0,  0,  0]])

还要补充一点的就是item方法的使用:如果tensor只有一个元素那么调用item方法的时候就是将tensor转换成python的scalars;如果tensor不是单个元素的话那就会引发ValueError,如下面

b.item()
Traceback (most recent call last):
    b.item()
ValueError: only one element tensors can be converted to Python scalars

torch.sum(b)
Out: tensor(276.)
torch.sum(b).item()
Out: 276.0

那么在python中的item方法一般是怎么样的呢?可参见https://blog.csdn.net/qq_34941023/article/details/78431376

猜你喜欢

转载自blog.csdn.net/zz2230633069/article/details/83092376