深度学习大概率用到的Pytorch内容进阶

Pytorch进阶

(1) 拼接与拆分

  • torch.cat(tensors, dim=0)
  • torch.stack(tensors, dim=0)
  • 注意区别 torch.cattorch.stack,前者是不会产生新的维度的,后者是会产生新的维度,而且注意 cat 只有在一个维度上不同时才能合并。比如 (2,3,3) 可以和 (2,3,4) 在 dim=2 上进行合并,但是 (2,3,3) 是不能和 (2,2,4) 进行合并的,stack 只能是 size 相同的之间才能操作。
  • torch.split(tensor, split_size_or_sections, dim=0)
  • torch.chunk(input, chunks, dim=0)
  • 注意采用spilt方法输入时分开后的长度,而chunk方法的输入是分开的个数。

(2) 基本运算

  • 加减乘除,乘方,取对数,取 e指数,进行取整。

  • 矩阵相乘

    • torch.matmul(input, other) (还有一个 torch.mm 只能在二维状态下用,不推荐)
    • 当 input 和 other 都是二维的时候就按正常的矩阵相乘的方式进行相乘。
    • 当 input 和 other 是高维的时候,只有最后两个维度按照矩阵相乘的法则进行相乘,其余维度不变。
    • 注意使用 torch.matmul 会触发 broadcast。
  • 完成clipping

    • torch.clamp(input, min, max)
    • 可以用在当发生梯度爆炸时候的梯度裁剪工作。

(3) 统计属性

  • 取范数

    • xxx.norm(p,dim) 输入一个整数表示取第几类范数。
  • mean, sum, min, max, prod

  • argmin, argmax 取最小/最大的数的索引。

  • 一个重要的操作 keepdim 保持计算完成的 dim。

  • 选取最大的k个或者第k大的

    • torch.topk(input,k,largest) k 表示选出最大的 k 个,将 largest 设为 False 就是选取最小的 k 个。
    • torch.kthvalue(input, k, largest) k 表示选出第 k 大的。
  • ->, >=, <, <=, !=, == 注意有一个 torch.equal(a,b) 可以用来比较a,b两个tensor中的元素是否都相等。

(4) 高阶操作

  • torch.where(condition,x,y)->Tensor 这是一个类似 if-else 的操作,返回一个 tensor,依据 condition 从对应的 x 或者 y 中选出数据,相比于 if-else 结构的好处在于用到 torch ,可以更好的利用 GPU。

  • torch.gather 可以根据输入的index进行元素的筛取。

猜你喜欢

转载自blog.csdn.net/weixin_44618906/article/details/107364382