Pytorch 基础
文章目录
(1) 基本数据类型
1.Pytorch 用的比较多 tensor
torch.FloatTensor
(可以简写为torch.Tensor
),类型为torch.float32
。精度更高一点就到了torch.DoubleTensor
,类型为torch.float64
。torch.IntTensor
,类型为torch.int16
。 变为长整型就是torch.LongTensor
,类型为torch.int64
。torch.ByteTensor
一般用来作为布尔变量的,类型为torch.int8
。
2.类型的检测
- 一共有三种方式
- 直接采用
xxx.type()
的方式。 - 采用
type(xxx)
的方法 - 采用
isinstance(xxx,一个tensor的type不能是dtype)
- 直接采用
3.不同维度的数据及用途
-
Dimension 0
- 就是一个标量,它的shape,size()是空的。
- 一般在计算loss的时候会用标量作为loss的计算结果。
-
Dimension 1
- 就是一个一维向量(行向量),它的shape,size()仅有一个数。
- 对线性层 batch_size=1 的时候输入和bias是一个一维向量。
-
Dimension 2
- 是一个二维向量(可以理解成矩阵),它的shape,size()有两个数。
- 对线性层 batch_size 不为 1 的时候输入和bias是一个二维向量。
-
Dimension 3
- 是一个三维向量,它的shape,size()有三个数。
- 一般RNN使用到的是 dim=3 的 tensor,其中 (seq_len,batch_size,vocab_size) 分别表示一段话的单词个数,截取的段数,字典长度。
-
Dimension 4
- 是一个四维向量,它的shape,size()有四个数。
- 一般CNN使用到的是 dim=4 的 tensor,其中 (b,c,h,w) 分别表示 batch_size ,channel(RGB的话是3个channel),长,宽
(2) 创建Tensor
1.从numpy中创建Tensor
-
使用
torch.from_numpy(xxx)
。
2.从list中创建Tensor
- 首先要区分
torch.Tensor(xxx)
和torch.tensor(xxx)
,前者等价于torch.FloatTensor(xxx)
不能自定义dtype
。 - 其它的只要将一个 list 的数据输到里面就可以转换了。
3.设定默认type
- 一般在不指定type的情况下采用
torch.tensor
创建的 tensor会直接采用默认的type,注意这个只能在double 和 float这两种数据类型之间转换。
4.随机产生进行初始化
torch.rand(*size)
这是一个在 [0,1] 上的均匀分布。torch.rand_like(xxx)
相当于torch.rand(xxx.size())
torch.randint(low=0,high,size)
这个是在 [min,max) 中的均匀分布。torch.normal(mean,std,size)
torch.randn(size)
相当于torch.normal(0.0,1.0,size)
5.采用特殊生成进行初始化
-
torch.full(size,num)
-
torch.arange()
或者torch.range()
-
torch.linspace(start, end, steps=100)
或者torch.logspace(start, end, steps=100, base=10.0)
这里base表示从 b a s e s t a r t base^{start} basestart 到 b a s e e n d base^{end} baseend 取值。 -
torch.ones()
和torch.zeros()
和torch.eye()
-
对0维进行一个shuffle操作,首先采用
torch.randperm(n)
,然后返回之前的矩阵的取样。
(3) 索引与切片
-
采用索引方法(默认从第0维开始),没有写东西的维度相当于在那个维度上采用
:
。 -
通过
:
进行切片,选取前面k个或者选取后面k个。 -
通过step进行选取。
-
选取特殊的index,采用
torch.index_select(dim,tensor)
-
先把原来的tensor拉直之后再选取对应的index,采用
torch.take(input, index)
(4) 维度变换
- 常用的有四种方式
-
view
,缺陷在于使用view
会导致维度信息的丢失。 -
Squeeze/unsqueeze
torch.squeeze(input, dim=None)
当维度 i 可以被压缩的时候就会发生压缩,否则维持原状。
torch.unsqueeze(input, dim)
-
transpost/t/permute
torch.transpose(input, dim0, dim1)
用来交换 dim0 和 dim1上的数据。xxx.t
注意这个只能用在二维状态下。xxx.permute(*dims)
可以一次交换很多个维度。- 要注意一件事,就是使用这样的交换函数之后会导致数字在存储上变得不连续,因此在后面不能直接加view,而是要先用contiguous把存储整的连续起来。
-
expand/repeat
两者的区别在于,repeat
会进行数据的拷贝,而expand
仅是进行了一个 broadcast,注意在进行扩张操作的时候要保证这个维度上为1才能在这个维度上扩张,比如 (2,3,3) 就扩不成 (2,3,3,3)。
-