What does "with torch.no_grad()" do in PyTorch?

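torch.no_grad() is a context manager that is used with a "with" statement to temporarily disable gradient tracking. It is not a loop: the indented block runs once, and while it runs, autograd does not record any operations. As a result, every tensor produced inside the block has requires_grad set to False and no grad_fn. It is not connected to the computation graph, so gradients cannot be computed through it. Tensors that already existed, such as a leaf tensor created with requires_grad=True, are not modified. Only the results of new operations are untracked.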
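The following is a minimal sketch of this behavior. It uses torch.is_grad_enabled(), which is not mentioned in the original post, to check the global gradient mode. The tensor names and the multiplication by 3 are only for illustration.

```python
import torch

x = torch.tensor(2., requires_grad=True)

with torch.no_grad():
    # Autograd recording is switched off for the duration of the block
    inside_mode = torch.is_grad_enabled()
    # x itself is not modified: its requires_grad flag is still True
    inside_flag = x.requires_grad
    # ...but the result of an operation on x is not attached to any graph
    y = x * 3

print("grad enabled inside block:", inside_mode)     # False
print("x.requires_grad inside block:", inside_flag)  # True
print("y.requires_grad:", y.requires_grad)           # False
print("y.grad_fn:", y.grad_fn)                       # None
```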

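The effect lasts only while execution is inside the block. When the block exits, the previous gradient mode is restored, so operations performed afterwards are tracked again. Tensors that were created inside the block stay untracked, however. They are not reattached to the graph when the block ends.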
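Here is a small sketch of the enter/exit behavior. It assumes the script starts in the default mode, where gradient tracking is enabled. The names made_inside and made_after are hypothetical and only used for illustration.

```python
import torch

x = torch.tensor(2., requires_grad=True)

before_mode = torch.is_grad_enabled()      # True by default
with torch.no_grad():
    inside_mode = torch.is_grad_enabled()  # False while inside the block
    made_inside = x * 2                    # created while tracking is off
after_mode = torch.is_grad_enabled()       # previous mode restored on exit
made_after = x * 2                         # created after the block: tracked again

print("grad enabled before/inside/after:", before_mode, inside_mode, after_mode)
print("made_inside.requires_grad:", made_inside.requires_grad)  # stays False
print("made_after.requires_grad:", made_after.requires_grad)    # True
```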

Let's take a few examples to better understand how it works.

Example 1

In this example, we create a tensor x with requires_grad=True. Next, we compute y = x ** 2 inside a with torch.no_grad() block. x itself still has requires_grad=True, but the operation that produces y is not recorded by autograd.

Because that operation was not recorded, the gradient of y with respect to x cannot be calculated, and y.requires_grad returns False.

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad=True)
print("x:", x)

# compute y with gradient tracking disabled
with torch.no_grad():
    y = x ** 2
print("y:", y)

# check whether y requires gradients
print("y.requires_grad:", y.requires_grad)

"""
Output
x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False
"""
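To see the practical consequence, you can try backpropagating from y. The sketch below catches the RuntimeError that PyTorch raises when backward() is called on a tensor that has no grad_fn. It then repeats the computation with tracking left on. The name y_tracked and the try/except structure are illustrative choices, not part of the original post.

```python
import torch

x = torch.tensor(2., requires_grad=True)

with torch.no_grad():
    y = x ** 2

# y was produced without a graph, so backpropagating through it fails
try:
    y.backward()
    backward_failed = False
except RuntimeError as err:
    backward_failed = True
    print("backward on y failed:", err)

# Same computation with tracking left on
y_tracked = x ** 2
y_tracked.backward()
print("x.grad:", x.grad)  # tensor(4.), since dy/dx = 2x = 4 at x = 2
```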

Example 2

In this example, we compute y outside the torch.no_grad() block and z inside it, using the same expression. So y.requires_grad returns True, while z.requires_grad returns False.

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad=False)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

print("x:", x)
print("w:", w)
print("b:", b)

# compute y with gradient tracking enabled
y = w * x + b
print("y:", y)

# compute z with gradient tracking disabled
with torch.no_grad():
    z = w * x + b

print("z:", z)

# check whether y and z require gradients
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)

"""
Output
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False
"""
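One way to confirm that only y is part of the graph is to call backward() on it and inspect the gradients of the leaf tensors. Here is a minimal sketch reusing the values from Example 2. The comments give the expected derivatives of y = w * x + b (dy/dw = x = 2, dy/db = 1). z holds the same value as y but has no grad_fn.

```python
import torch

x = torch.tensor(2.)
w = torch.tensor(3., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

y = w * x + b            # recorded by autograd
with torch.no_grad():
    z = w * x + b        # same value, but not recorded

y.backward()             # fills w.grad and b.grad
print("w.grad:", w.grad)        # tensor(2.)
print("b.grad:", b.grad)        # tensor(1.)
print("y.grad_fn:", y.grad_fn)  # an AddBackward0 node
print("z.grad_fn:", z.grad_fn)  # None
```

Because z has no grad_fn, calling z.backward() would raise the same RuntimeError as y.backward() did in Example 1.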

Source: blog.csdn.net/m0_52848925/article/details/131179576