Pytorch学习笔记--torch.autograd.Function的使用

目录

1--前言

2--代码实例

2-1--e^x函数实现

2-2--linear函数实现

3--参考


1--前言

        构建可连续求导的神经网络时,往往会继承 nn.Module 类,此时只需重写 __init__ 和 forward 函数即可,pytorch会自动求导;

        构建不可连续求导的神经网络时,可以继承 torch.autograd.Function 类,此时需要重写 forward 函数和 backward 函数,其中 backward 函数的作用是返回求导结果。

2--代码实例

2-1--e^x函数实现

import torch
from torch.autograd import gradcheck

# forward 和 backward函数 都必须声明为静态方法
class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i): # 必须有ctx,ctx表示上下文管理器,一般用来保存在backward阶段会用到的tensor
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output): # grad_output为上一层累积的梯度(求导结果)
        result, = ctx.saved_tensors # 调用ctx获取forward的tensor
        return grad_output * result # 返回求导结果,本函数实现e^i,对i求导的结果还是e^i,即求导结果为result

if __name__ == "__main__":
    test_input = torch.ones(1, requires_grad = True)
    output = Exp.apply(test_input) # 调用
    print(output) # tensor([2.7183]
    test_grad = gradcheck(Exp.apply, test_input, eps = 1e-3) # 检查求导是否正确,正确则返回true
    print(test_grad) # True

2-2--linear函数实现

import torch
from torch.autograd import gradcheck

# forward 和 backward函数 都必须声明为静态方法
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias) # 记录需要传递给backward函数的tensor
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output): # grad_output表示上一层累积的梯度
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]: # 当input需要求导
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]: # 当weight需要求导
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]: #当grad_bias需要求导
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias # 返回对input, weight, bias求导的结果

if __name__ == "__main__":
    test_input = torch.ones(1, 1, dtype = torch.double, requires_grad = True)
    weight_input = torch.ones(1, 1, dtype = torch.double,requires_grad = True)
    output = LinearFunction.apply(test_input, weight_input)
    print(output) # 1
    print(output.shape) # [1, 1]
    test_grad = gradcheck(LinearFunction.apply, (test_input, weight_input), eps=1e-6, atol=1e-4)
    print(test_grad) # True
    

3--参考

PyTorch 74.自定义操作torch.autograd.Function

猜你喜欢

转载自blog.csdn.net/weixin_43863869/article/details/134130973