torch.max() 和 torch.min()

官网:https://pytorch.org/docs/stable/torch.html#torch.max 

torch.max()和torch.min()是比较tensor大小的函数。两者用法相同,所以就总结了一个。

(1)不指定比较维度:torch.max(input)

x = torch.rand(1,3)
print(x)
print(torch.min(x))

y = torch.rand(2,3)
print(y)
print(torch.min(y))

output:
tensor([[0.4094, 0.0262, 0.9132]])
tensor(0.0262)
tensor([[0.4712, 0.3108, 0.3703],
        [0.0609, 0.8676, 0.7341]])
tensor(0.0609)

(2)指定比较维度:torch.max(input,dim)  

output 返回tuple:tuple[0] -> 比较结果; tuple[1] ->所在索引

y = torch.rand(2,3)
print(y)
print(torch.max(y,0))

output:
tensor([[0.7573, 0.4121, 0.0922],
        [0.0562, 0.1346, 0.5164]])
(tensor([0.7573, 0.4121, 0.5164]), tensor([0, 0, 1]))

(3)两个tensor相比较:不一定是相同大小结构,若不是相同大小结构,必须满足可广播

相同结构:比较相同位置的返回结果

x = torch.rand(2,3)
y = torch.rand(2,3)
print(x)
print(y)
print(torch.max(x,y))

output:
tensor([[0.9054, 0.4904, 0.4252],
        [0.5209, 0.8509, 0.7347]])
tensor([[0.2347, 0.4457, 0.4466],
        [0.5157, 0.5463, 0.0814]])
tensor([[0.9054, 0.4904, 0.4466],
        [0.5209, 0.8509, 0.7347]])

不是相同结构的,按照广播原理将维度少的那个做一个数据复制再比较。

x = torch.rand(1,3)
y = torch.rand(2,3)
print(x)
print(y)
print(torch.max(x,y))

output:
tensor([[0.2240, 0.1759, 0.3040]])
tensor([[0.6603, 0.1693, 0.5366],
        [0.4192, 0.4316, 0.0386]])
tensor([[0.6603, 0.1759, 0.5366],
        [0.4192, 0.4316, 0.3040]])
发布了56 篇原创文章 · 获赞 29 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/foneone/article/details/103926319