register_full_backward_hook
- UNet.register_full_backward_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor], prepend: bool = False) → RemovableHandle
Registers a backward hook on the module.
The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments, and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly, the caller will receive a view of each Tensor returned by the Module's forward function.
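The hook contract described above can be illustrated with a minimal sketch. It uses a plain torch.nn.Linear layer as a stand-in for the UNet, which is an assumption made only to keep the example small; the hook name halve_input_grad and the dictionary seen are hypothetical. The hook records the shape of grad_output and returns a scaled replacement for grad_input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # stand-in for any Module, e.g. a UNet
x = torch.randn(3, 4, requires_grad=True)

# Reference gradient with respect to x, computed without any hook.
layer(x).sum().backward()
reference = x.grad.clone()
x.grad = None

seen = {}

def halve_input_grad(module, grad_input, grad_output):
    # grad_output is a tuple with one entry per module output.
    seen["grad_output_shape"] = tuple(grad_output[0].shape)
    # Return a new tuple instead of modifying grad_input in place.
    return tuple(None if g is None else g * 0.5 for g in grad_input)

handle = layer.register_full_backward_hook(halve_input_grad)
layer(x).sum().backward()
hooked = x.grad.clone()
```

With the hook registered, the gradient that reaches x should be half of the reference gradient, because the returned tuple replaces grad_input in the rest of the backward pass. A hook that returns None leaves grad_input unchanged.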
Warning
Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.
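As a hedged illustration of this warning, the sketch below registers a no-op backward hook on torch.nn.ReLU(inplace=True) and passes it a tensor that requires gradients. The choice of ReLU and the variable names are assumptions for the example; the exact error message may differ between PyTorch versions, so only the exception type is checked.

```python
import torch
import torch.nn as nn

relu = nn.ReLU(inplace=True)  # modifies its input in place
relu.register_full_backward_hook(lambda module, grad_input, grad_output: None)

x = torch.randn(3, requires_grad=True)
hidden = x * 2  # non-leaf tensor that requires grad

try:
    relu(hidden)  # the in-place op hits the view created for the hook
    raised = False
except RuntimeError as err:
    raised = True
    message = str(err)
```

Constructing the module with inplace=False, or otherwise using an out-of-place operation, avoids the error.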
- Parameters:
hook (Callable) – The user-defined hook to be registered.
prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.
- Returns:
A handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
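The effect of prepend and of the returned handle can be sketched as follows. The example assumes a PyTorch version that supports the prepend argument and again uses torch.nn.Linear in place of the UNet; make_hook and the hook names are hypothetical helpers that only record the firing order.

```python
import torch
import torch.nn as nn

calls = []

def make_hook(name):
    # Build a hook that records its name and leaves grad_input unchanged.
    def hook(module, grad_input, grad_output):
        calls.append(name)
    return hook

layer = nn.Linear(2, 2)
first = layer.register_full_backward_hook(make_hook("first"))
second = layer.register_full_backward_hook(make_hook("second"))
prepended = layer.register_full_backward_hook(make_hook("prepended"), prepend=True)

x = torch.randn(1, 2, requires_grad=True)
layer(x).sum().backward()
order_with_all = list(calls)  # prepended hook fires before the others

second.remove()  # unregister one hook through its handle
calls.clear()
layer(x).sum().backward()
order_after_remove = list(calls)
```

Keeping the handle is the only way this example removes a hook, so it is worth storing the return value whenever a hook is meant to be temporary.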