Pytorch进阶
(1) 拼接与拆分
torch.cat(tensors, dim=0)
torch.stack(tensors, dim=0)
- 注意区别
torch.cat
和torch.stack
,前者是不会产生新的维度的,后者是会产生新的维度,而且注意cat
只有在一个维度上不同时才能合并。比如 (2,3,3) 可以和 (2,3,4) 在 dim=2 上进行合并,但是 (2,3,3) 是不能和 (2,2,4) 进行合并的,stack
只能是 size 相同的之间才能操作。


torch.split(tensor, split_size_or_sections, dim=0)
torch.chunk(input, chunks, dim=0)

- 注意采用spilt方法输入时分开后的长度,而chunk方法的输入是分开的个数。
(2) 基本运算
-
加减乘除,乘方,取对数,取 e指数,进行取整。
-
矩阵相乘
torch.matmul(input, other)
(还有一个torch.mm
只能在二维状态下用,不推荐)- 当 input 和 other 都是二维的时候就按正常的矩阵相乘的方式进行相乘。
- 当 input 和 other 是高维的时候,只有最后两个维度按照矩阵相乘的法则进行相乘,其余维度不变。
- 注意使用
torch.matmul
会触发 broadcast。
-
完成clipping
torch.clamp(input, min, max)
- 可以用在当发生梯度爆炸时候的梯度裁剪工作。
(3) 统计属性
-
取范数
xxx.norm(p,dim)
输入一个整数表示取第几类范数。
-
mean, sum, min, max, prod
-
argmin, argmax
取最小/最大的数的索引。 -
一个重要的操作
keepdim
保持计算完成的 dim。 -
选取最大的k个或者第k大的
torch.topk(input,k,largest)
k 表示选出最大的 k 个,将 largest 设为 False 就是选取最小的 k 个。torch.kthvalue(input, k, largest)
k 表示选出第 k 大的。
-
->, >=, <, <=, !=, ==
注意有一个torch.equal(a,b)
可以用来比较a,b两个tensor中的元素是否都相等。
(4) 高阶操作
-
torch.where(condition,x,y)->Tensor
这是一个类似 if-else 的操作,返回一个 tensor,依据 condition 从对应的 x 或者 y 中选出数据,相比于 if-else 结构的好处在于用到 torch ,可以更好的利用 GPU。 -
torch.gather
可以根据输入的index进行元素的筛取。