目录
1--前言
构建可连续求导的神经网络时,往往会继承 nn.Module 类,此时只需重写 __init__ 和 forward 函数即可,pytorch会自动求导;
构建不可连续求导的神经网络时,可以继承 torch.autograd.Function 类,此时需要重写 forward 函数和 backward 函数,其中 backward 函数的作用是返回求导结果。
2--代码实例
2-1--e^x函数实现
import torch
from torch.autograd import gradcheck
# forward 和 backward函数 都必须声明为静态方法
class Exp(torch.autograd.Function):
@staticmethod
def forward(ctx, i): # 必须有ctx,ctx表示上下文管理器,一般用来保存在backward阶段会用到的tensor
result = i.exp()
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output): # grad_output为上一层累积的梯度(求导结果)
result, = ctx.saved_tensors # 调用ctx获取forward的tensor
return grad_output * result # 返回求导结果,本函数实现e^i,对i求导的结果还是e^i,即求导结果为result
if __name__ == "__main__":
test_input = torch.ones(1, requires_grad = True)
output = Exp.apply(test_input) # 调用
print(output) # tensor([2.7183]
test_grad = gradcheck(Exp.apply, test_input, eps = 1e-3) # 检查求导是否正确,正确则返回true
print(test_grad) # True
2-2--linear函数实现
import torch
from torch.autograd import gradcheck
# forward 和 backward函数 都必须声明为静态方法
class LinearFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias) # 记录需要传递给backward函数的tensor
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output): # grad_output表示上一层累积的梯度
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]: # 当input需要求导
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]: # 当weight需要求导
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]: #当grad_bias需要求导
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias # 返回对input, weight, bias求导的结果
if __name__ == "__main__":
test_input = torch.ones(1, 1, dtype = torch.double, requires_grad = True)
weight_input = torch.ones(1, 1, dtype = torch.double,requires_grad = True)
output = LinearFunction.apply(test_input, weight_input)
print(output) # 1
print(output.shape) # [1, 1]
test_grad = gradcheck(LinearFunction.apply, (test_input, weight_input), eps=1e-6, atol=1e-4)
print(test_grad) # True