Pytorch学习笔记【3】 --tensor切片

Pytorch学习笔记【3】 --tensor切片

Pytorch笔记目录:点位进入

1. indexing 索引

类似于list的索引操作,tensor也可以使用类似的方法获取tensor中的数值

# create a 4-dim  tensor
a = torch.rand(4,3,28,28)
print(a[0].shape)
out:
torch.Size([3, 28, 28])
print(a[0][0].shape)
out:
torch.Size([28, 28])
print(a[0,0,2,4])
out:
tensor(0.2295)

2. 切片

可以通过切片的方式获取tensor中的某一段数据

# select first/last N
a[:2].shape
out:
torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape
out:
torch.Size([2, 1, 28, 28])
a.shape
torch.Size([4, 3, 28, 28])
# select by steps
a[:,:,0:28:2,0:28:2].shape
torch.Size([4, 3, 14, 14])

3. …

…可以表示获取当前位置的所有,下面看一个事例

# ... all
a[...].shape
out:
torch.Size([4, 3, 28, 28])
a[0,...].shape
out:
torch.Size([3, 28, 28])

4. 通过掩码来处理

要实现mask操作首先我们需要创建一个规格和你要处理的tensor相同的tensor,然后再进行处理, 程序会输出所有mask值为1的地方

# select by mask
x = torch.randn(3,4)
mask = x.ge(0.5)
print(torch.masked_select(x,mask))
out:
tensor([1.0916, 0.6544, 0.9824, 1.4880, 0.5094])

flatten index

通过散列的index来获取数值,这种方法的规则就是把多维的向量打平进行计算,当然不会经常用到,3维以上就很难计算了

# select by flatten index
src = torch.tensor([[4,3,5],
[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5])))
out:
tensor([4, 5, 8])
原创文章 113 获赞 80 访问量 3万+

猜你喜欢

转载自blog.csdn.net/python_LC_nohtyp/article/details/106016469