pytorch学习笔记torch.mul() 和 torch.mm() 的区别

pytorch中torch.mul() 和 torch.mm() 的区别

torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵

torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

import torch

a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)

print(torch.mul(a, b))  # 返回 1*2 的tensor
print(torch.mm(a, c))   # 返回 1*3 的tensor
print(torch.mul(a, c))  # 由于a、b维度不同,报错

猜你喜欢

转载自blog.csdn.net/m0_45388819/article/details/109907583