topk()/eq( ) / gt( ) / lt( ) / t( )的用法

topk()/eq( ) / gt( ) / lt( ) / t( )的用法

eq( ) / gt( ) / lt( ) / t( )

import torch

x1 = torch.Tensor([0.2,0.8])
x2 = torch.Tensor([0,3])

print('x1等于x2:',x1.eq(x2))
print('x1大于x2:',x1.gt(x2))
print('x1小于x2:',x1.lt(x2))
# x1等于x2: tensor([False, False])
# x1大于x2: tensor([ True, False])
# x1小于x2: tensor([False,  True])
--------------------------------------------------------
x3 = torch.Tensor([[2,1],[3,4]])
print('c:',x3)
print('c转置:',x3.t())

# c: tensor([[2., 1.],
#         [3., 4.]])
# c转置: tensor([[2., 3.],
#         [1., 4.]])

topk()

import torch

a = torch.randn((4, 8))
print(a)

# tensor([[-0.6378,  0.4055, -1.1109, -0.2804, -0.5933, -0.8631, -0.7764, -0.2232],
#         [-2.1446,  0.4058,  1.1801,  1.5446,  0.7786,  0.0172, -2.2552,  0.2385],
#         [ 0.7129, -0.8664, -1.2198, -0.1463,  0.0565, -0.0409, -0.4247,  0.8256],
#         [-0.3058, -0.5409,  0.1872, -1.4345,  0.1649,  0.7080,  1.5167,  1.2903]])

max()

格式:

torch.max(input, dim)
#max:取最大值
maxk = max((1, 3))
print(maxk)  #3

设置keepdim=True,以防降维

_, indices_max = a.max(dim=1, keepdim=True)
print(_)
print(indices_max)  #对应索引
#tensor([[0.4055],
#         [1.5446],
#         [0.8256],
#         [1.5167]])
# tensor([[1],
#         [3],
#         [7],
#         [6]])

torch.topk()

格式:

  torch.max(input, k, dim, largest=True)   

input:tensor数据
k:得到前k个数据以及其index
dim: 指定在某个维度上排序, 默认为最后一个维度
largest:为True:按照从大到小排序; 为False:按照从小到大排序

#_返回的是前maxk个最大值,pred返回对应index
#是指定维度dim=0,按行取,dim=1,按列取。

_, pred = a.topk(maxk, 1, True, True)
print(_)
# tensor([[ 0.4055, -0.2232, -0.2804],
#         [ 1.5446,  1.1801,  0.7786],
#         [ 0.8256,  0.7129,  0.0565],
#         [ 1.5167,  1.2903,  0.7080]])
print(pred)
# tensor([[1, 7, 3],
#         [3, 2, 4],
#         [7, 0, 4],
#         [6, 7, 5]])
_, pred = a.topk(1, 1, True, True)
print(_)
# tensor([[0.4055],
#         [1.5446],
#         [0.8256],
#         [1.5167]])
print(pred)
# tensor([[1],
#         [3],
#         [7],
#         [6]])
_, pred = a.topk(1, 0, True, True)
print(_)
print(pred)
# tensor([[0.7129, 0.4058, 1.1801, 1.5446, 0.7786, 0.7080, 1.5167, 1.2903]])
# tensor([[2, 1, 1, 1, 1, 3, 3, 3]])

猜你喜欢

转载自blog.csdn.net/wahahaha116/article/details/126156140