Tensor.scatter()

 
1
2
3
4
5
6
labels = torch.tensor([[3],
                       [5],
                       [4],
                       [2]])
onehot = torch.zeros(len(labels), max(labels)+1)
onehot
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
1
onehot.scatter(dim=1, index=labels, src=1)
tensor([[0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0.]])

대상 tensor(onehot)에 axis(1)와 index tensor(labels)로 지정된 위치에 src의 값(1)을 채워넣으라는 의미이다.

위의 예시의 경우,

  1. 각 행에 대하여 (dim=1)
  2. 첫 번째 행은 [3], 두 번째 행은 [5], 세 번째 행은 [4], 네 번째 행은 [2]의 위치에
  3. src=1의 값을 채워넣어라

이를 통해 onehot encoding을 간단히 구현할 수 있다.