torch.stack(), torch.cat()用法详解

torch.stack(), torch.cat()用法详解

if __name__ == '__main__':
    import torch
    x_dat = torch.tensor([[1, 2], [3,4], [5,6]], dtype=torch.float)
    y_dat = torch.tensor([[10, 20], [30,40], [50,60]], dtype=torch.float)

    res=torch.stack((x_dat,y_dat),0)

    print(res)
    res = torch.stack((x_dat, y_dat), 1)

    print(res)
    res = torch.stack((x_dat, y_dat),2)

    print(res)

    res = torch.cat((x_dat, y_dat), 0)

    print(res)
    res = torch.cat((x_dat, y_dat), 1)

    print(res)
    res = torch.cat((x_dat, y_dat), 2)

    print(res)

stack 是合并,但是内容单元不变。

cat是追加,内容尺寸会变化

发布了2853 篇原创文章 · 获赞 1112 · 访问量 581万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/105557559