Mathematical Derivation
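Stated briefly, these are the two losses the code below implements. For $N$ samples with predicted probability $\hat{y}_i \in (0,1)$ and label $y_i \in \{0,1\}$, the binary cross-entropy is

$$
\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\Big[\,y_i \log \hat{y}_i + (1 - y_i)\log(1 - \hat{y}_i)\,\Big],
$$

and for $C$ classes with softmax probability $p_{i,c}$ and one-hot target $y_{i,c}$, the categorical cross-entropy is

$$
\mathcal{L}_{\text{CCE}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c}\,\log p_{i,c}.
$$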
Hand-Written Code
import torch
import torch.nn.functional as F

def binary_cross_entropy(predictions, targets):
    # predictions: probabilities in (0, 1) (e.g. sigmoid outputs); targets: 0/1 labels.
    # A small epsilon keeps log() finite if a probability saturates to 0 or 1.
    eps = 1e-9
    loss = -torch.mean(targets * torch.log(predictions + eps)
                       + (1 - targets) * torch.log(1 - predictions + eps))
    return loss

def categorical_cross_entropy(predictions, targets):
    # predictions: raw logits of shape (N, C); targets: class indices of shape (N,).
    predictions = F.softmax(predictions, dim=1)  # logits -> per-class probabilities
    targets = F.one_hot(targets, num_classes=predictions.shape[1]).float()
    loss = -torch.mean(torch.sum(targets * torch.log(predictions + 1e-9), dim=1))
    return loss

if __name__ == "__main__":
    # Binary case: 10 samples, 1 logit each; apply sigmoid before the loss.
    predictions_binary = torch.randn(10, 1, requires_grad=True)
    targets_binary = torch.randint(0, 2, (10, 1)).float()
    loss_binary = binary_cross_entropy(torch.sigmoid(predictions_binary), targets_binary)
    print(f"Binary Cross Entropy Loss: {loss_binary.item()}")

    # Categorical case: 10 samples, 3 classes; the loss applies softmax internally.
    predictions_categorical = torch.randn(10, 3, requires_grad=True)
    targets_categorical = torch.randint(0, 3, (10,))
    loss_categorical = categorical_cross_entropy(predictions_categorical, targets_categorical)
    print(f"Categorical Cross Entropy Loss: {loss_categorical.item()}")