随机打乱函数 torch.randperm的使用(类似tensorflow中的tf.random_shuffle)

例:一维 

x = torch.tensor([2,5,16,10,0,5618,81,8,18])
indices = torch.randperm(x.numel())
shuffled_x = x[indices]
print(shuffled_x)
# tensor([   2,   18, 5618,   81,    5,   10,    8,    0,   16])

indices = torch.randperm(x.numel())
shuffled_x = x.view(-1)[indices].view(x.size())

在这里,torch.randperm(x.numel()) 函数生成从 0 到 x.numel()-1 的整数序列随机排列,这个序列的长度等于 x 的元素个数。随机排列后的序列可以用来索引 x.view(-1) (将 x 张量展平为一个一维张量),以此来打乱 x 中的元素顺序。最后,通过 view(x.size()) 将张量恢复到原来的形状。

如果要打乱一个多维张量的元素顺序,需要在 view 函数中使用多个参数,如下所示:

indices = torch.randperm(x.numel())
shuffled_x = x.view(x.size()[0], -1)[indices].view(x.size())

猜你喜欢

转载自blog.csdn.net/djdjdhch/article/details/130633915