torch_scatter的scatter函数是什么意思

直接上例子

from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([0, 2, 0], dtype=torch.int64)
scatter(src, index, dim=0, reduce='mean')

以上整个操作的意思是,把src中的第0行和第2行做平均后,放在新tensor的第0行,
把src的第1行放在新tensor的第2行,最后,第一行用0补充空缺,最终输出tensor有三行。

例子二

from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([1, 0, 1], dtype=torch.int64)
scatter(src, index, dim=0, reduce='max')

以上整个操作的意思是,把src中的第0行和第2行取最大后,放在新tensor的第1行,
把src的第1行放在新tensor的第0行,最终输出tensor只有两行。

猜你喜欢

转载自blog.csdn.net/zmhzmhzm/article/details/131333252
今日推荐