1
2
3
4
5
6
labels = torch.tensor([[3],
[5],
[4],
[2]])
onehot = torch.zeros(len(labels), max(labels)+1)
onehot
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
1
onehot.scatter(dim=1, index=labels, src=1)
tensor([[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 1., 0., 0., 0.]])
대상 tensor(onehot
)에 axis(1
)와 index tensor(labels
)로 지정된 위치에 src의 값(1
)을 채워넣으라는 의미이다.
위의 예시의 경우,
- 각 행에 대하여 (
dim=1
) - 첫 번째 행은
[3]
, 두 번째 행은[5]
, 세 번째 행은[4]
, 네 번째 행은[2]
의 위치에 src=1
의 값을 채워넣어라
이를 통해 onehot encoding을 간단히 구현할 수 있다.
PREVIOUSEtc