import torch
基本的数据类型如图所示
1.创建Tensor的基本操作
# Import from numpy
torch.from_numpy
# uninitialized
torch.empty
torch.Tensor
torch.IntTensor
torch.FloatTensor
# 生成矩阵
torch.rand(3,3) # 0 - 1均匀分布
torch.randn(3,3) # -1 - 1正太分布
torch.rand_like(a)
torch.randint(1,10,(3,3))
# 填充
torch.full([2,3],7)
torch.full([1],7)
torch.full([],7)
# 生成连续数或步长一致数组
torch.arange(0,10,2)
torch.range(0,10)
# 按步长切割数组
torch.linspace(0,10,steps=4)
torch.logspace(0,1,steps=10)
# 生成0,1矩阵
torch.ones
torch.zeros
torch.eye
# 生成随机矩阵
torch.randperm(4) # 输出为tensor([3, 1, 2, 0])
2.索引与切片操作
切片:符号操作的总结:
1):表示全部
2):n表示从开始直到n但不包含n,n:表示从n开始到结束
3)【a:b:c】表示从a到b,步长为c
4)::a表示直接步长为a
# 切片操作
a.shape # torch.Size([4, 3, 28, 28])
a[:2].shape # torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape # torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape # torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape # torch.Size([2, 1, 28, 28])
a[:,:,0:28:2,0:28:2].shape # torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape # torch.Size([4, 3, 14, 14])
# '...'的操作
a[...].shape # torch.Size([4, 3, 28, 28])
a[0,...].shape # torch.Size([3, 28, 28])
a[:,1,...].shape # torch.Size([4, 28, 28])
a[...,:2].shape # torch.Size([4, 3, 28, 2])
# 自主挑选采样
# index_select的使用方法
a.shape # torch.Size([4, 3, 28, 28])
a.index_select(0,torch.tensor([0,2])).shape # torch.Size([2, 3, 28, 28])
a.index_select(1,torch.tensor([1,2])).shape # torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape # torch.Size([4, 3, 28, 28])
a.index_select(2,torch.arange(8)).shape # torch.Size([4, 3, 8, 28])
# take的使用方法,缺点是会将原数组打平操作
src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,5])) # 输出为tensor([4, 5, 8])
3.维度变换操作
# view与reshape用法一致
a = torch.rand(4,1,28,28)
a.view(4,28*28)
a.view(4*28,28)
a.view(4,28,28,1) # 破坏了数据
# 增加维度操作unsqueeze,在前进的方向后增加一个维度
a.shape # torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape # torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(4).shape # torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-1).shape # torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-4).shape # torch.Size([4, 1, 1, 28, 28])
a.unsqueeze(-5).shape # torch.Size([1, 4, 1, 28, 28])
a = torch.tensor([1.2,2.3])
a.unsqueeze(-1) #tensor([[1.2000],[2.3000]])
# 去除维度squeeze
b.shape # torch.Size([1, 32, 1, 1])
b.squeeze().shape # torch.Size([32])
b.squeeze(0).shape # torch.Size([32, 1, 1])
b.squeeze(-4).shape # torch.Size([32, 1, 1])
# 增加维度
# expand
b.shape # torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape # torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape # torch.Size([1, 32, 1, 1])
# repeat
b.shape # torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape # torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape # torch.Size([4, 32, 1, 1])
# 维度的交换操作transpose(1,3) 表示交换1,3两个维度
a = torch.randn(4,3,32,32)
b = a.transpose(1,3) # torch.Size([4, 32, 32, 3])
a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
torch.all(torch.eq(a,a1)) # tensor(True)
# 按指定位置改变数据的维度permute
a = torch.rand(4,3,28,28)
b1 = a.transpose(1,3).transpose(1,2).shape # torch.Size([4, 28, 28, 3])
b2 = a.permute(0,2,3,1).shape # 与上式的变化是一直的,可以直接的指定位置,torch.Size([4, 28, 28, 3])
4.Broadcast广播机制
# 对于以下例子是不能自动扩展去相加的
# 第一种情况
a = torch.randn(4,3,3,4)
b = torch.randn(2,3,3,4)
c = a + b # The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
# 需要配对
# 第二种情况
a = torch.randn(4,3,3,4)
b = torch.randn(1,3,3,4)
c = a + b # No Error
# 第三种情况,自行觉得内容,但是还是要配对
a = torch.randn(3,28,8)
b = torch.tensor([0,0,0,5,0,0,0,0])
c = a + b
# 对于第一种情况来说如果需要相加,则可以进行手动的相加
a = torch.randn(4,3,3,4)
b = torch.randn(2,3,3,4)
c1 = a + b[0]
c2 = a + b[1]
# 总结:从最小维度开始匹配便可
5.拼接与拆分
# cat的使用例子
# 第一个例子:
# 示例意义为两个班级的成绩表相加变成了一张全年级的成绩表
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(5,3,32,32)
torch.cat([a1,a2],dim = 0).shape # torch.Size([9, 3, 32, 32]),其中的dim = 0指定了在第一个维度上面相加
# 第二个例子:
# 示例意义为两组不同通道数的照片进行通道的叠加
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(4,1,32,32)
torch.cat([a1,a2],dim = 1).shape # torch.Size([4, 4, 32, 32]),其中的dim = 1指定了在第二个维度上面相加
# 第三个例子:
# 行的相加,示例意义为两组半张的照片叠加成一张完整的图片
a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
torch.cat([a1,a2],dim = 2).shape # torch.Size([4, 3, 32, 32]),其中的dim = 2指定了在第三个维度上面相加
# stack的使用例子
# 第一个例子:
# 示例的意义为,两个老师为自己的班级[16,32]创建了两张表
# 去使用stack增加一个维度2来表示这两个班级,因为不可能让这两个班级的人去合并成一个64人的班级
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim = 0).shape # torch.Size([2, 32, 8])
# 类似的对于拥有更加高的维度
a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
torch.stack([a1,a2],dim = 2).shape # torch.Size([4, 3, 2, 16, 32])
# split的使用例子----按长度来拆分
# 将一个班级分成3个班级[2,2,1]的比例来分,最后一个是重点班
cls = torch.randn(5,32,8)
a,b,c = cls.split([2,2,1],dim=0)
a.shape,b.shape,c.shape # (torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
# 分成两块也是可以的
cls = torch.randn(5,32,8)
a,b = cls.split([4,1],dim=0)
a.shape,b.shape # (torch.Size([4, 32, 8]), torch.Size([1, 32, 8]))
# Chunk的使用例子----按数量来拆分
# 第一个例子:
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.stack([a,b],dim = 0) # torch.Size([2, 32, 8])
aa,bb = c.chunk(2,dim = 0)
aa.shape,bb.shape # (torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))
# 第二个例子:
cls = torch.randn(6,32,8)
a,b,c = cls.chunk(3,dim=0)
a.shape,b.shape, c.shape # (torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))
6.基本运算
# 四则运算
torch.all(torch.eq(a-b,torch.sub(a,b))) # tensor(True)
torch.all(torch.eq(a/b,torch.div(a,b))) # tensor(True)
torch.all(torch.eq(a*b,torch.mul(a,b))) # tensor(True)
# 矩阵乘法,torch.mm & @ & torch.matmul 一般使用后者,因为前者只使用与2D
# 第一个例子:高位的矩阵相乘
a = torch.rand(4,3,28,64)
b = torch.rand(4,3,64,32)
torch.matmul(a,b).shape # torch.Size([4, 3, 28, 32])
# 第二个例子:Broadcast与矩阵相结合
a = torch.rand(4,3,28,64)
b = torch.rand(4,1,64,32)
torch.matmul(a,b).shape # torch.Size([4, 3, 28, 32])
# **2 & pow(2) 为平方
# sqrt 为开方, rsqrt为开立方
a.pow(2) = a**2
a.sqrt() = a**0.5
a.floor() # 向下取整
a.ceil() # 向上取整
a.trunc() # 取整数部分
a.frac() # 取小数部分
a.round() # 四舍五入
# 常用的一个函数clamp,限制函数
grad.clamp(10) # 表示小于10的数据设置成 10
grad.clamp(5,15) # 表示小于5的数据设置为 5,而大于15的数据设置为 15
# 例子
a = torch.arange(0,20,2) # tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18])
a.clamp(10) # tensor([10, 10, 10, 10, 10, 10, 12, 14, 16, 18])
a.clamp(5,15) # tensor([ 5, 5, 5, 6, 8, 10, 12, 14, 15, 15])
7.统计属性
# norm范数的相关介绍:https://blog.csdn.net/a493823882/article/details/80569888
a1 = torch.full([8],1.).view(2,4)
a2 = torch.full([8],1.).view(2,2,2)
a1.norm(1),a1.norm(2),a1.norm(3) # (tensor(8.), tensor(2.8284), tensor(2.))
a2.norm(1),a2.norm(2),a2.norm(3) # (tensor(8.), tensor(2.8284), tensor(2.))
a1.norm(1,dim=1),a1.norm(2,dim=1) # (tensor([4., 4.]), tensor([2., 2.]))
a2.norm(1,dim=1),a2.norm(2,dim=1) # (tensor([[2., 2.],
# [2., 2.]]),
# tensor([[1.4142, 1.4142],
# [1.4142, 1.4142]]))
# 其他的统计属性
a.min() # 最小值
a.max() # 最大值
a.mean() # 均值
a.prod() # 累乘
a.sum() # 求和
a.argmax() # 最大值的索引
a.argmin() # 最小值的索引
a.argmax(dim = 1) # 每一行最大值的索引
a.max(dim = 1,keepdim = True) # 其中的keepdim是函数max中的一个参数,设置为True可以保持输出的结果维度与原维度一致
# topk的使用:可以返回前几大或者是前几小的数值,并返回索引
a = torch.rand(4,10)
a.topk(3,dim = 1) # 对第二个维度找出前3大的数值,也就是对每一列的元素进行判断
# 输出为:
# torch.return_types.topk(
# values=tensor([[0.8866, 0.8177, 0.7332],
# [0.9514, 0.8833, 0.6571],
# [0.9797, 0.7790, 0.7702],
# [0.8143, 0.8087, 0.8024]]),
# indices=tensor([[0, 8, 6],
# [4, 6, 0],
# [9, 4, 5],
# [5, 9, 3]]))
a.topk(3,dim = 1,largest = False) # 设置为找最小的前3个数据与索引
# 输出为:
# torch.return_types.topk(
# values=tensor([[0.0204, 0.1908, 0.2236],
# [0.1053, 0.2193, 0.3290],
# [0.0178, 0.0706, 0.3166],
# [0.1955, 0.1956, 0.2229]]),
# indices=tensor([[1, 3, 7],
# [0, 5, 2],
# [7, 9, 1],
# [2, 1, 6]]))
# kthvalue的使用:可以返回第几小的数据与索引
a = torch.rand(4,6)
# a的数据为:
# tensor([[0.7352, 0.7930, 0.3943, 0.3258, 0.6368, 0.5204],
# [0.6069, 0.0058, 0.2128, 0.5723, 0.0749, 0.8241],
# [0.0063, 0.5670, 0.2957, 0.2370, 0.8618, 0.0283],
# [0.1310, 0.4823, 0.2898, 0.9954, 0.3406, 0.1469]])
a.kthvalue(3) # 输出每一行第三小的数据
# 输出为:
# torch.return_types.kthvalue(
# values=tensor([0.5204, 0.2128, 0.2370, 0.2898]),
# indices=tensor([5, 2, 3, 2]))
a.kthvalue(1,dim=0) # 输出第1个维度,也就是每一行的第1小数据,也就是最小数据
# 输出为:
# torch.return_types.kthvalue(
# values=tensor([0.0063, 0.0058, 0.2128, 0.2370, 0.0749, 0.0283]),
# indices=tensor([2, 1, 1, 2, 1, 2]))
8.高级用法
## where的用法
# torch.where(condition,x,y) ---> Tensor
# 意思是:如果满足condition,则来源于x;如果是不符合这个condition,则来源于y。
a = torch.ones(2,2)
b = torch.zeros(2,2)
c = torch.randn(2,2)
# tensor([[ 0.8717, -0.5480],
# [ 1.6908, -0.0894]])
torch.where(c<0,a,b)
# tensor([[0., 1.],
# [0., 1.]])
## gather用法
# gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor
# 注意:使用gather函数的时候,参数dim不能写出来否则会报错
# 正确示范torch.gather(a, 1, torch.tensor([[0,0,1],[1,1,0],[2,2,0]]))
# 错误示范torch.gather(a, dim = 1, torch.tensor([[0,0,1],[1,1,0],[2,2,0]]))
# 使用help(torch.gather)可以查看教程
# 例子
t = torch.arange(1,10).view(3,3)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
index = torch.randint(0,3,(3,3)) # 输出一个取值为0-2的3x3的矩阵
# tensor([[2, 2, 1],
# [0, 1, 2],
# [1, 1, 1]])
torch.gather(t,1,index) # 对行数据进行索引
# 输出为:
# tensor([[3, 3, 2],
# [4, 5, 6],
# [8, 8, 8]])
torch.gather(t,0,index) # 对列数据进行索引
# 输出为:
# tensor([[7, 8, 6],
# [1, 5, 9],
# [4, 5, 6]])