Pytorch에서 분산() 및 분산_() 함수의 사용법과 차이점

Tensor.scatter_(dim, index, src, Reduce=None) → Tensor
index에 따라 src에 있는 값을 self에 쓰는 기능이고, im이 차원을 결정합니다.
여기서 주의할 점은 self의 dtype이 다음과 같아야 한다는 것입니다. src의 dtype과 동일 !!!예:

torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

여기서 self의 dtype은 src의 dtype과 동일해야 합니다.
함수의 역할은 3D 텐서의 예입니다.

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

구체적인 예를 들면 다음과 같습니다.

src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 2]])
# tensor([[0, 1, 2, 0],
#		  [1, 0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 7, 0, 4, 0],
#         [6, 2, 8, 0, 0],
#         [0, 0, 3, 9, 0]])

# 分析:index的i取值为0-1,j的取值从0-3都可以
# self[index[0][0]][0] = self[0][0] = src[0][0] = 1
# self[index[0][1]][1] = self[1][1] = src[0][1] = 2
# self[index[0][2]][2] = self[2][2] = src[0][2] = 3
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# self[index[1][0]][0] = self[1][0] = src[1][0] = 6
# self[index[1][1]][1] = self[0][1] = src[1][1] = 7
# self[index[1][2]][2] = self[1][2] = src[1][2] = 8
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9

여기서 또 재미있는 점은 위 상황은 겹치지 않는 경우인데, 인덱스의 위쪽 행과 아래쪽 행에 다음과 같이 겹치는 요소가 있다고 가정해 보겠습니다.

index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 0]])

첫 번째 행의 마지막 요소는 두 번째 행의 마지막 요소와 동일하며 둘 다 0입니다. (앞 두 번째 줄의 마지막 요소는 2입니다.)
이 경우 위의 값은

# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9
变为了
# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[0][3] = src[1][3] = 9

self[0][3]에는 2개의 할당이 있고, 하나는 i=0, j=3에 따라 4개가 할당되고, 다른 하나는 i=1, j=3에 따라 9개가 할당된 것을 알 수 있습니다. , 9는 4를 포함하므로 최종 결과는 다음과 같습니다.

tensor([[1, 7, 0, 9, 0],
        [6, 2, 8, 0, 0],
        [0, 0, 3, 0, 0]])

Scatter()와 Scatter_()의 차이점은 Scatter_()가 내부에서 작동한다는 것입니다.
예를 들어 b = a.scatter(dim, index, src) 이후에는 a의 값이 변경되지 않습니다
. 상대적으로 b = a.scatter_(dim, index, src) 이후에는 a의 값이 변경되어 b와 같아집니다.

추천

출처blog.csdn.net/qq_43666068/article/details/130860504