torch.split()方法

一、方法详解
含义:将一个张量分为几个chunks

torch.split(tensor, split_size_or_sections, dim=0)
1
tensor:要分的张量
split_size_or_sections:
如果该项参数的值为一个int类型的value值,那么该方法会将tensor划分为同等数量的张量;如果tensor的size沿着给定的不能整除split_size,那么最后一个chunk相较于其它chunk小;
如果是一个list列表,该方法会将tensor划分为len(split_size_or_sections)的张量。
dim:划分张量所依据的维度
return:返回的是一个tuple
这样纯靠文字解释,是很抽象的,我们直接通过以下的案例来对这个方法进行彻底的掌握。

二、案例
案例1
    import torch
    # 创建一个张量
    x = torch.arange(10).reshape(5, 2)
    print(x)

y = torch.split(x, 2)
print(y, type(y))
 


案例2
x = torch.arange(10).reshape(5, 2)
print(x)
y = torch.split(x, [2, 3])
print(y)
 


案例3
x = torch.arange(10).reshape(5, 2)
print(x)
y = torch.split(x, [1, 4])
print(y)
 


原文链接:https://blog.csdn.net/dongjinkun/article/details/115375847

torch.split,用来划分tensor,可以从数量上划分,还有维度上划分。
torch.split(tensor,split_szie,dim),split_size有整数,也有列表,dim默认为0,自己也可以修改。
代码示例:

import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
                [[7,8,9],[10,11,12]]])
print("a的shape:",a.shape)
#在第0维上进行split
b=torch.split(a,1)
print("b:",b)
#在第1维上进行split
c=torch.split(a,[1,1],1)
print("c:",c)


输出:

a的shape: torch.Size([2, 2, 3])
b: (tensor([[[1, 2, 3],
         [4, 5, 6]]]),
          tensor([[[ 7,  8,  9],
         [10, 11, 12]]]))
c: (tensor([[[1, 2, 3]],

        [[7, 8, 9]]]), 
        tensor([[[ 4,  5,  6]],

        [[10, 11, 12]]]))
————————————————
版权声明:本文为CSDN博主「江南汪」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_47156261/article/details/116599161

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/125361745