PyTorch【5】-Tensor 运算

Tensor API 较多,所以把 运算 单独列出来,方便查看

乘法

t.mul(input, other, out=None):矩阵乘以一个数

t.matmul(mat, mat, out=None):矩阵相乘

t.mm(mat, mat, out=None):基本上等同于 matmul

a=torch.randn(2,3)
b=torch.randn(3,2)
### 等价操作
print(torch.mm(a,b))        # mat x mat
print(torch.matmul(a,b))    # mat x mat
### 等价操作
print(torch.mul(a,3))       # mat 乘以 一个数
print(a * 3)

注意,乘法可以直接作用于单个数字

乘法需要符合 向量乘法 的规则,即尺寸匹配

a=torch.randn(2,3)
c = torch.randn(2, 3)
# print(torch.matmul(a, c))   # 尺寸不符合向量乘法,(2,3)x(2,3)
print(torch.matmul(a, c.t())) # t() 转置,正确 (2,3)x(3,2)

加法

加法有 3 种方式:+,add,add_

import torch as t
y = t.rand(2, 3)        ### 使用[0,1]均匀分布构建矩阵
z = t.ones(2, 3)        ### 2x3 的全 1 矩阵

#### 3 中加法操作等价
print(y + z)            ### 加法1
t.add(y, z)             ### 加法2
### 加法的第三种写法
result = t.Tensor(2, 3) ### 预先分配空间
t.add(y, z, out=result) ### 指定加法结果的输出目标
print(result)

add_ 与 add 的区别在于,add 不会改变原来的 tensor,而 add_会改变原来的 tensor;

在 pytorch 中,方法后面加  _ 都会改变原来的对象,相当于 in-place 的作用

print(y)
# tensor([[0.4083, 0.3017, 0.9511],
#         [0.4642, 0.5981, 0.1866]])
y.add(z)
print(y)                ### y 不变
# tensor([[0.4083, 0.3017, 0.9511],
#         [0.4642, 0.5981, 0.1866]])
y.add_(z)
print(y)                ### y 变了,相当于 inplace
# tensor([[1.4083, 1.3017, 1.9511],
#         [1.4642, 1.5981, 1.1866]])

可以作用于单个数字或者 尺寸为 (1,1) 的 Tensor

a = t.ones(3, 3)
print(a + 1)        ### 可以直接作用于单个数字

b = t.ones(1, 1)
print(a + b)

c = t.ones(2, 1)
# print(a + c)        ### 报错,如果尺寸不匹配,c 的尺寸只能是 (1, 1)

减法 

和加法一样,三种:-、sub、sub_

a = t.randn(2, 1)
b = t.randn(2, 1)
print(a)
### 等价操作
print(a - b)
print(t.sub(a, b))
print(a)        ### sub 后 a 没有变化

a.sub_(b)
print(a)        ### sub_ 后 a 也变了

c = 1
print(a - c)    ### 直接作用于单个数字

其他运算

t.div(input, other, out=None):除法

t.pow(input, other, out=None):指数

t.sqrt(input, out=None):开方

t.round(input, out=None):四舍五入到整数

t.abs(input, out=None):绝对值

t.ceil(input, out=None):向上取整

t.clamp(input, min, max, out=None):把 input 规范在 min 到 max 之间,超出用 min 和 max 代替,可理解为削尖函数

t.argmax(input, dim=None, keepdim=False):返回指定维度最大值的索引

t.sigmoid(input, out=None)

t.tanh(input, out=None)

参考资料:

猜你喜欢

转载自www.cnblogs.com/yanshw/p/12206849.html