PyTorch中 detach() 、detach_()和 data 的区别

1.介绍

    在使用PyTorch的过程中,我们经常会遇到detach() 、detach_()和 data这三种类别,如果你不详细分析它们的使用场所,的确是很容易让人懵逼。

    1)detach()与detach_()

    在x->y->z传播中,如果我们对y进行detach(),梯度还是能正常传播的,但如果我们对y进行detach_(),就把x->y->z切成两部分:x和y->z,x就无法接受到后面传过来的梯度

    2)detach()和data

    共同点:x.data(x.detach()) 返回和 x 的相同数据 tensor, 这个新的tensor和原来的tensor(即x)是共用数据的,一者改变,另一者也会跟着改变,且require s_grad = False

    不同点: x.data 不能被 autograd 追踪求微分,即使被改了也能错误求导,而x.detach()则不行

2.代码

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a)
out = a.tanh()
print(out)
c = out.data  # 需要走注意的是,通过.data “分离”得到的的变量会和原来的变量共用同样的数据,而且新分离得到的张量是不可求导的,c发生了变化,原来的张量也会发生变化
c.zero_() # 改变c的值,原来的out也会改变
print(c.requires_grad)
print(c)
print(out.requires_grad)
print(out)
print("----------------------------------------------")

out.sum().backward()  # 对原来的out求导,
print(a.grad)  # 不会报错,但是结果却并不正确

#输出
tensor([1., 2., 3.], requires_grad=True)
tensor([0.7616, 0.9640, 0.9951], grad_fn=<TanhBackward>)
False
tensor([0., 0., 0.])
True
tensor([0., 0., 0.], grad_fn=<TanhBackward>)
----------------------------------------------
tensor([1., 1., 1.])
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a)
out = a.tanh()
print(out)
c = out.detach()  # 需要走注意的是,通过.detach() “分离”得到的的变量会和原来的变量共用同样的数据,而且新分离得到的张量是不可求导的,c发生了变化,原来的张量也会发生变化
c.zero_()  # 改变c的值,原来的out也会改变
print(c.requires_grad)
print(c)
print(out.requires_grad)
print(out)
print("----------------------------------------------")

out.sum().backward()  # 对原来的out求导,
print(a.grad)  # 此时会报错,错误结果参考下面,显示梯度计算所需要的张量已经被“原位操作inplace”所更改了。

# 输出
tensor([1., 2., 3.], requires_grad=True)
tensor([0.7616, 0.9640, 0.9951], grad_fn=<TanhBackward>)
False
tensor([0., 0., 0.])
True
tensor([0., 0., 0.], grad_fn=<TanhBackward>)
----------------------------------------------
Traceback (most recent call last):
  File "E:/python/TCL/entropy_coding_project/test_code/test27.py", line 15, in <module>
    out.sum().backward()  # 对原来的out求导,
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of TanhBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
发布了138 篇原创文章 · 获赞 141 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/u013289254/article/details/102557070