qml.math.grad

grad(f, argnums=0)

Compute the gradient in a JAX-like manner for any supported interface (Autograd, JAX, Torch, or TensorFlow).

Parameters:
  • f (Callable) – a function with a single 0-D scalar output

  • argnums (Sequence[int] | int) – the index or indices of the arguments to differentiate with respect to (default 0)

Returns:

a function with the same signature as f that returns the gradient with respect to the arguments selected by argnums.

Return type:

Callable

This function follows the same design as jax.grad: by default, it returns the gradient with respect to the first argument only, regardless of whether the other arguments are trainable.
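To make the argnums convention concrete without any ML framework, here is a minimal pure-Python sketch. fd_grad is a hypothetical helper, not part of PennyLane, and it approximates derivatives with central finite differences instead of automatic differentiation. It only mimics the calling convention of qml.math.grad: the returned function takes the same positional arguments as f, argnums=0 selects the first argument by default, and a sequence of indices yields a tuple of gradients.

```python
from typing import Callable, Sequence, Union


def fd_grad(
    f: Callable[..., float],
    argnums: Union[int, Sequence[int]] = 0,
    h: float = 1e-6,
) -> Callable[..., object]:
    """Hypothetical finite-difference stand-in mimicking the argnums convention.

    Only positional float arguments are supported in this sketch.
    """
    single = isinstance(argnums, int)
    indices = (argnums,) if single else tuple(argnums)

    def grad_fn(*args: float) -> object:
        grads = []
        for i in indices:
            # Central difference: perturb only argument i, hold the others fixed.
            plus = list(args)
            minus = list(args)
            plus[i] += h
            minus[i] -= h
            grads.append((f(*plus) - f(*minus)) / (2 * h))
        # An int argnums gives a single value; a sequence gives a tuple.
        return grads[0] if single else tuple(grads)

    return grad_fn


def f(x, y):
    return x * y


print(fd_grad(f)(2.0, 3.0))                  # approximately 3.0 (d/dx only)
print(fd_grad(f, argnums=(0, 1))(2.0, 3.0))  # approximately (3.0, 2.0)
```

Finite differences are used here only so the sketch runs with the standard library; the real function relies on each framework's own autodiff.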

>>> import jax, torch, tensorflow as tf
>>> import pennylane as qml
>>> def f(x, y):
...     return x * y
>>> qml.math.grad(f)(qml.numpy.array(2.0), qml.numpy.array(3.0))
tensor(3., requires_grad=True)
>>> qml.math.grad(f)(jax.numpy.array(2.0), jax.numpy.array(3.0))
Array(3., dtype=float32, weak_type=True)
>>> qml.math.grad(f)(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
tensor(3.)
>>> qml.math.grad(f)(tf.Variable(2.0), tf.Variable(3.0))
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>

To differentiate with respect to multiple arguments, pass a sequence of indices as argnums; the result is then a tuple with one gradient per selected argument.

>>> qml.math.grad(f, argnums=(0,1))(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
(tensor(3.), tensor(2.))

The selected arguments must be of a trainable type for their interface (for example, a Torch tensor created with requires_grad=True); otherwise an error may be raised.

>>> qml.math.grad(f)(torch.tensor(1.0), torch.tensor(2.))
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
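For comparison, here is a minimal sketch of the Torch case written directly against PyTorch's torch.autograd.grad. It is not a description of how qml.math.grad is implemented, and it assumes PyTorch is installed. Creating the inputs with requires_grad=True makes them differentiable, and requesting gradients with respect to both inputs returns one gradient per input, like the argnums=(0, 1) example. Plain tensors without requires_grad raise a RuntimeError, as in the failing call.

```python
import torch


def f(x, y):
    return x * y


# Trainable inputs: autograd records operations on them.
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# torch.autograd.grad returns a tuple with one gradient per input.
dx, dy = torch.autograd.grad(f(x, y), (x, y))
print(dx, dy)  # tensor(3.) tensor(2.)

# Non-trainable inputs: the output has no grad_fn, so Torch raises RuntimeError.
a = torch.tensor(1.0)
b = torch.tensor(2.0)
raised = False
try:
    torch.autograd.grad(f(a, b), (a,))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```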