PyTorch one-hot encoding

Reposted from: "PyTorch one-hot encoding" by GXLiu (CSDN blog)

Option 1: convert labels to one-hot with scatter_

import torch

num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
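# scatter_(dim=1, index, value=1): in row i, write 1 at column label[i]; index must be int64 with shape (N, 1)
one_hot = torch.zeros((len(label), num_class)).scatter_(1, label.long().reshape(-1, 1), 1)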
print(one_hot)
"""
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.]])
"""

Option 2: let F.one_hot do it automatically

import torch.nn.functional as F
import torch

num_class = 5
label = torch.tensor([0, 2, 1, 4, 1, 3])
# one row per label; output shape is (len(label), num_class)
one_hot = F.one_hot(label, num_classes=num_class)
print(one_hot)
"""
tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0]])
"""


Reposted from blog.csdn.net/zjc910997316/article/details/121437348