The difference between .data and .detach() in PyTorch

"However, .data can be unsafe in some cases. Any changes on x.data wouldn’t be tracked by autograd, and the computed gradients would be incorrect if x is needed in a backward pass. A safer alternative is to use x.detach(), which also returns a Tensor that shares data with requires_grad=False, but will have its in-place changes reported by autograd if x is needed in backward."

Any in-place change on x.detach() is detected by autograd and raises an error if x is needed in the backward pass, which makes .detach() the safer way to exclude a subgraph from gradient computation. https://github.com/pytorch/pytorch/issues/6990

Here's an example. If you use .detach() instead of .data, the gradient computation is guaranteed to be correct: an in-place change is caught and reported instead of silently producing wrong gradients.

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> c = out.detach()
>>> c.zero_()  
tensor([ 0.,  0.,  0.])

>>> out  # modified by c.zero_() !!
tensor([ 0.,  0.,  0.])

>>> out.sum().backward()  # Requires the original value of out, but that was overwritten by c.zero_()
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
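
For contrast, a short sketch (not part of the quoted transcript) of the intended use of .detach(): read the detached tensor, e.g. to hand its values to NumPy, but never write to it in place. The backward pass then succeeds, and a.grad is the usual sigmoid derivative out * (1 - out); the printed values are rounded and the exact formatting depends on the PyTorch version.

>>> a = torch.tensor([1., 2., 3.], requires_grad=True)
>>> out = a.sigmoid()
>>> c = out.detach()         # shares memory with out, but sits outside the graph
>>> vals = c.numpy()         # read-only use (logging, plotting, ...) is fine
>>> out.sum().backward()     # out is untouched, so backward succeeds
>>> a.grad                   # = out * (1 - out)
tensor([0.1966, 0.1050, 0.0452])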

As opposed to using .data:

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> c = out.data
>>> c.zero_()
tensor([ 0.,  0.,  0.])

>>> out  # out  was modified by c.zero_()
tensor([ 0.,  0.,  0.])

>>> out.sum().backward()
>>> a.grad  # Silently wrong: backward used the zeroed `out` and raised no error
tensor([ 0.,  0.,  0.])
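
Why zeros rather than an error? Sigmoid's backward reuses the saved output: grad_a = grad_out * out * (1 - out). The tensor returned by .data does not share out's version counter the way the one from .detach() does, so autograd cannot tell that out was overwritten and simply plugs the zeroed values into that formula. A rough sketch of the resulting arithmetic (this is the standard sigmoid-derivative formula, not the actual autograd kernel):

>>> grad_out = torch.ones(3)                  # d(out.sum()) / d(out)
>>> saved_out = torch.zeros(3)                # the overwritten values autograd actually sees
>>> grad_out * saved_out * (1 - saved_out)    # grad_out * out * (1 - out)
tensor([0., 0., 0.])

With the original out values (0.7311, 0.8808, 0.9526), the same formula gives the 0.1966, 0.1050, 0.0452 shown in the .detach() sketch above.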

From the linked issue: "I'll leave this issue open: we should add an example to the migration guide and clarify that section."


Reposted from blog.csdn.net/xiaojiajia007/article/details/82183444