The backward function

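import torch as t
from torch.autograd import Function


class Sigmoid(Function):

    @staticmethod
    def forward(ctx, x):
        # sigmoid(x) = 1 / (1 + exp(-x))
        output = 1 / (1 + t.exp(-x))
        # Save the output: the gradient can be written entirely in terms of it
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the deprecated saved_variables
        output, = ctx.saved_tensors
        # d sigmoid / dx = sigmoid(x) * (1 - sigmoid(x)), times the upstream gradient
        grad_x = output * (1 - output) * grad_output
        return grad_x


def f_sigmoid(x):
    # x must require grad, e.g. t.randn(3, requires_grad=True)
    y = Sigmoid.apply(x)
    # y is not a scalar, so backward() needs an explicit upstream gradient;
    # the result is accumulated into x.grad
    y.backward(t.ones_like(y))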
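To check that a hand-written backward is correct, one option is torch.autograd.gradcheck, which compares the analytical gradient with finite-difference estimates. The sketch below assumes a recent PyTorch (it uses ctx.saved_tensors and imports torch under its full name instead of the t alias); the extra comparison with the built-in torch.sigmoid is an added sanity check, not part of the original post.

```python
import torch
from torch.autograd import Function, gradcheck


class Sigmoid(Function):
    # Same custom Function as in the text, repeated so this block runs on its own.
    @staticmethod
    def forward(ctx, x):
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        return output * (1 - output) * grad_output


torch.manual_seed(0)

# gradcheck compares the analytical backward with finite-difference estimates.
# It expects double-precision inputs that require grad.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
ok = gradcheck(Sigmoid.apply, (x,))
print("gradcheck passed:", ok)

# Compare against the gradient of PyTorch's built-in torch.sigmoid.
x_custom = x.detach().clone().requires_grad_(True)
y = Sigmoid.apply(x_custom)
y.backward(torch.ones_like(y))

x_ref = x.detach().clone().requires_grad_(True)
y_ref = torch.sigmoid(x_ref)
y_ref.backward(torch.ones_like(y_ref))

print("matches torch.sigmoid:", torch.allclose(x_custom.grad, x_ref.grad))
```

Fresh leaf tensors are used for each backward call so that gradients from one check never accumulate into the other.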



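The backward call in f_sigmoid runs the custom backward defined above. Because forward saves its output σ(x), backward computes the gradient in one step as output * (1 - output) * grad_output, instead of letting autograd differentiate through each elementary operation of 1/(1 + exp(-x)) (negation, exp, addition, reciprocal) one at a time. This is what makes the backward pass simpler and cheaper.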
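Why saving only the output is enough: the sigmoid derivative can be expressed through σ(x) itself. A short derivation, where L is a scalar loss introduced here only to state the chain rule:

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\frac{d\sigma}{dx} = \frac{e^{-x}}{(1 + e^{-x})^{2}}
  = \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}}
  = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)

\frac{\partial L}{\partial x}
  = \frac{\partial L}{\partial y} \cdot \sigma(x)\,\bigl(1 - \sigma(x)\bigr),
  \qquad y = \sigma(x)
```

The second factor uses \(\frac{e^{-x}}{1 + e^{-x}} = 1 - \frac{1}{1 + e^{-x}}\). In the code, \(\partial L / \partial y\) is grad_output and \(\sigma(x)\) is the saved output, so grad_x = output * (1 - output) * grad_output.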


Reposted from blog.csdn.net/doublechenchenchen/article/details/80848986