[Pytorch] 前向传播和反向传播示例

目录

简介

神经网络训练基本步骤

1. 计算图

2. 前向传播 Forward

3. 计算损失Loss 【损失函数】

4. 反向传播 Backward

5. 使用学习率更新权重【优化器】

样例代码

样例结果

样例图解


简介

PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序。Pytorch提供了两个高级功能:

  1. 具有强大的GPU加速的张量计算(Numpy的替代品)
  2. 包含自动求导系统的深度神经网络

神经网络训练基本步骤

1. 计算图

组成:由节点和边组成,节点分为Tensor和Function(运算)

  • Tensor分为叶子节点和非叶节点
  • Pytorch计算图是动态图

2. 前向传播 Forward

操作:根据输入数据进行推测。创建Fucntion后可以立即执行,不需要等到计算图定义好之后再执行。

3. 计算损失Loss 【损失函数】

操作:计算前向推测结果与真实值之间的误差

4. 反向传播 Backward

操作:将Loss向输入侧进行反向传播,对所有需要进行梯度计算的所有变量 leaf node Tensor x(requires_grad=True),计算梯度 dLoss/dx,并将其积累到梯度x.grad中备用, 即:x.grad = x.grad + dLoss/dx

5. 使用学习率更新权重【优化器】

操作:使用优化器对x的值进行更新。优化器会根据用户设置的学习率以及x.grad来更新x。

如:随机梯度下降SGD,x = x - learning_rate * x.grad

样例代码

def test_training_pipeline():    
    # ============================================================================ 1.创建计算图
    # ============================================================================ 2.前向传播(即时计算)
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
    output = torch.sqrt(input)
    print("\n### 前向传播推测结果:\n", output)

    # ============================================================================ 3.计算Loss
    target_data = [1, 2, 3, 4]
    target = torch.tensor(target_data, dtype=torch.float32)
    
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(input=output, target=target)
    print("\n### loss:\n", loss)
       
    # ============================================================================ 4.反向传播
    loss.backward()
    print("\n### input_grad:\n", input.grad)
    
    # ============================================================================ 5.更新input
    optim = torch.optim.SGD([input], lr=0.001)
    print("\n### input before optim.step():\n", input)
    optim.step()
    print("\n### input after optim.step():\n", input)

样例结果

样例图解

图解和手动计算前向传播和反向传播。

参考

理解Pytorch的loss.backward()和optimizer.step() - 知乎

猜你喜欢

转载自blog.csdn.net/zmj1582188592/article/details/129321426