Output gradient

metatrain.utils.output_gradient.compute_gradient(target: Tensor, inputs: List[Tensor], is_training: bool) → List[Tensor]

Calculates the gradient of a target tensor with respect to a list of input tensors.

target must be a single torch.Tensor object. If target contains more than one value, the gradient of the sum of all its values is computed.
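The summation rule can be checked in plain PyTorch. The sketch below assumes that differentiating a multi-valued target is equivalent to seeding torch.autograd.grad with a tensor of ones as grad_outputs. That is an assumption about the internals, not something stated on this page. It compares that result with explicitly differentiating target.sum().

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = x ** 2  # three values, not a scalar

# Seeding backpropagation with ones differentiates sum(target) (assumed equivalence)
(grad_ones,) = torch.autograd.grad(
    target, [x], grad_outputs=torch.ones_like(target), retain_graph=True
)

# The same gradient, obtained by summing the target explicitly first
(grad_sum,) = torch.autograd.grad(target.sum(), [x])

# Both equal d/dx sum(x**2) = 2 * x
```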

Parameters:
  • target (Tensor) – The tensor for which the gradient is to be computed.

  • inputs (List[Tensor]) – A list of tensors with respect to which the gradient is computed.

  • is_training (bool) – A boolean indicating whether the model is in training mode. If True, the computation graph is retained for further gradient computations. If False, the graph is not retained, which saves memory.
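What retaining the graph buys in practice can be sketched with torch.autograd.grad directly. This example assumes is_training corresponds to the retain_graph and create_graph arguments of torch.autograd.grad. That mapping is an assumption, not confirmed here. In the training-like call, the gradient stays differentiable, so a loss defined on it can be backpropagated. In the evaluation-like call, the graph is freed, and a second pass over it fails.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor([1.0, 2.0], requires_grad=True)

# Training-like call: keep the graph and build a differentiable gradient
energy = w * (x ** 2).sum()
(dx,) = torch.autograd.grad(energy, [x], retain_graph=True, create_graph=True)
loss = (dx ** 2).sum()  # a loss defined on the gradient itself
loss.backward()         # backpropagates through dx into w (and x)

# Evaluation-like call: the graph is released after the gradient is computed
energy = w * (x ** 2).sum()
(dx_eval,) = torch.autograd.grad(energy, [x], retain_graph=False, create_graph=False)
try:
    torch.autograd.grad(energy, [x])  # second pass over the freed graph
    graph_freed = False
except RuntimeError:
    graph_freed = True
```

Not keeping the graph is what saves memory at evaluation time. The price is that dx_eval cannot be differentiated further.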

Returns:

A list of tensors containing the gradient of target with respect to each tensor in inputs.

Return type:

List[Tensor]
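For orientation, a function with the documented signature can be sketched on top of torch.autograd.grad. compute_gradient_sketch is a hypothetical stand-in written for illustration, not metatrain's source. Two details are assumptions: the use of ones as grad_outputs and the mapping of is_training to retain_graph and create_graph. The usage part shows one gradient returned per input, each with the shape of its input.

```python
from typing import List

import torch


def compute_gradient_sketch(
    target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool
) -> List[torch.Tensor]:
    """Hypothetical stand-in with the documented signature (not metatrain's code)."""
    gradients = torch.autograd.grad(
        outputs=[target],
        inputs=inputs,
        # Ones as grad_outputs: a multi-valued target is differentiated as its sum
        grad_outputs=[torch.ones_like(target)],
        # Assumed mapping: keep the graph (and make it differentiable) only in training
        retain_graph=is_training,
        create_graph=is_training,
    )
    return list(gradients)


# Usage: one target, two inputs of different shapes
a = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], requires_grad=True)
b = torch.eye(3, requires_grad=True)
target = ((a @ b) ** 2).sum()

grads = compute_gradient_sketch(target, [a, b], is_training=False)
# grads[0] has the shape of a, grads[1] has the shape of b
```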