torch.narrow()

import torch.tensor 
x=torch.rand(5,6)

tensor([[0.4606, 0.0850, 0.8009, 0.3972, 0.9548, 0.5982],
        [0.4821, 0.9446, 0.5145, 0.8125, 0.3122, 0.9756],
        [0.8747, 0.7186, 0.3945, 0.4090, 0.8398, 0.7494],
        [0.8129, 0.3084, 0.3856, 0.0044, 0.3022, 0.1679],
        [0.3270, 0.7481, 0.7058, 0.7362, 0.4007, 0.3604]])
x.narrow(0,1,3)
tensor([[0.4821, 0.9446, 0.5145, 0.8125, 0.3122, 0.9756],
        [0.8747, 0.7186, 0.3945, 0.4090, 0.8398, 0.7494],
        [0.8129, 0.3084, 0.3856, 0.0044, 0.3022, 0.1679]])
x.narrow(1,2,4)
tensor([[0.8009, 0.3972, 0.9548, 0.5982],
        [0.5145, 0.8125, 0.3122, 0.9756],
        [0.3945, 0.4090, 0.8398, 0.7494],
        [0.3856, 0.0044, 0.3022, 0.1679],
        [0.7058, 0.7362, 0.4007, 0.3604]])
x.narrow(2,2,2) 
Traceback (most recent call last):
  Debug Probe, prompt 177, line 1

总结 x.narrow(0,1,3) 0表示 第0维度就是行 1表示 索引  3表示取的行数 

x.narrow(1,2,4)1 表示 列 2 表示 索引 4表示取得列数 

猜你喜欢

转载自blog.csdn.net/candy134834/article/details/86603939