Объясните простым языком функцию Pytorch - torch.no_grad

Категория: Общий каталог "Простые функции Pytorch"


Менеджер контекста, который отключает вычисление градиента. Tensor.backward()Отключение вычисления градиента полезно для логического вывода, когда мы уверены, что оно не будет вызываться. Это уменьшит потребление памяти вычислениями, иначе нам нужно установить requires_grad=True. В этом режиме результатом каждого вычисления будет , даже если requires_gradна входе . Этот диспетчер контекста является локальным для потока и не влияет на вычисления в других потоках. В то же время этот класс может выступать и в роли декоратора.Truerequires_grad=False

грамматика

torch.no_grad()

пример

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad
.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad

реализация функции

class no_grad(_DecoratorContextManager):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parenthesis.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

Guess you like

Origin blog.csdn.net/hy592070616/article/details/132029988