torch.cat与torch.chunk的使用

本文转载于知乎上Anthony Eden的pytorch专栏,链接点这里。感谢作者Anthony Eden。


torch.cat ( (A, B), dim=0)接受一个由两个(或多个)tensor组成的元组,按行拼接,所以两个(多个)tensor的列数要相同:
在这里插入图片描述
在这里插入图片描述
torch.cat ( (A, B), dim=1)是按列拼接,所以两个tensor的行数要相同:
在这里插入图片描述
在这里插入图片描述
torch.chunk(tensor, chunk_num, dim)与torch.cat()原理相反,它是将tensor按dim(行或列)分割成chunk_num个tensor块,返回的是一个元组。

a = torch.Tensor([[1,2,4]])
b = torch.Tensor([[4,5,7], [3,9,8], [9,6,7]])
c = torch.cat((a,b), dim=0)
print(c)
print(c.size())
print('********************')
d = torch.chunk(c,4,dim=0)
print(d)
print(len(d))

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/u011913417/article/details/111297300