交叉熵的数学推导和手撕代码

交叉熵的数学推导和手撕代码

数学推导

在这里插入图片描述

手撕代码

import torch
import torch.nn.functional as F

# 二元交叉熵损失函数
def binary_cross_entropy(predictions, targets):
    # predictions应为sigmoid函数的输出,即概率值
    # targets应为0或1的二进制标签
    loss = -torch.mean(targets * torch.log(predictions) + (1 - targets) * torch.log(1 - predictions))
    return loss

# 多元交叉熵损失函数(使用softmax处理predictions)
def categorical_cross_entropy(predictions, targets):
    # predictions应为softmax函数的输出,即各类的概率分布
    # targets应为类别的索引(整数),通常使用torch.nn.functional.cross_entropy直接计算更为简便
    # 但为了演示,这里手动实现
    predictions = F.softmax(predictions, dim=1)  # 确保predictions是经过softmax处理的
    targets = F.one_hot(targets, num_classes=predictions.shape[1]).float()  # 将targets转换为one-hot编码
    loss = -torch.mean(torch.sum(targets * torch.log(predictions + 1e-9), dim=1))  # 加上小常数防止log(0)
    return loss

# 示例
if __name__ == "__main__":
    # 假设有10个样本,每个样本预测为二分类问题的概率
    predictions_binary = torch.randn(10, 1, requires_grad=True)
    targets_binary = torch.randint(0, 2, (10, 1)).float()
    
    # 计算二元交叉熵
    loss_binary = binary_cross_entropy(torch.sigmoid(predictions_binary), targets_binary)
    print(f"Binary Cross Entropy Loss: {
      
      loss_binary.item()}")
    
    # 假设有10个样本,每个样本预测为3分类问题的原始分数
    predictions_categorical = torch.randn(10, 3, requires_grad=True)
    targets_categorical = torch.randint(0, 3, (10,))
    
    # 计算多元交叉熵
    loss_categorical = categorical_cross_entropy(predictions_categorical, targets_categorical)
    print(f"Categorical Cross Entropy Loss: {
      
      loss_categorical.item()}")

    # 注意:在实际应用中,通常直接使用torch.nn.CrossEntropyLoss()来计算多元交叉熵,因为它内部已经包含了softmax操作

猜你喜欢

转载自blog.csdn.net/GamBleout/article/details/142731711