函数介绍
torch.bmm(input, mat2, *, out=None) → Tensor
-
输入:
-
函数在
input
和mat2
之间进行 batch 矩阵乘法 -
input
和mat2
都必须是 3-D tensors,他们包含的矩阵数量相同-
如果
input
是 shape 为[b, n, m]
的 tensor -
mat2
是 shape 为[b, m, p]
的 tensor -
那么函数的结果就是
shape
为[b, n, p]
的 tensor
-
例子
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])