qml.math.grad
- grad(f, argnums=0)
Compute the gradient in a JAX-like manner for any supported interface.
- Parameters
f (Callable) – a function with a single 0-D scalar output
argnums (Sequence[int] | int) – which arguments to differentiate
- Returns
a function with the same signature as f that returns the gradient.
- Return type
Callable
See also
Note that this function follows the same design as jax.grad: by default, it returns the gradient with respect to the first argument only, whether or not the other arguments are trainable.
>>> import pennylane as qml
>>> import jax, torch, tensorflow as tf
>>> def f(x, y):
...     return x * y
>>> qml.math.grad(f)(qml.numpy.array(2.0), qml.numpy.array(3.0))
tensor(3., requires_grad=True)
>>> qml.math.grad(f)(jax.numpy.array(2.0), jax.numpy.array(3.0))
Array(3., dtype=float32, weak_type=True)
>>> qml.math.grad(f)(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
tensor(3.)
>>> qml.math.grad(f)(tf.Variable(2.0), tf.Variable(3.0))
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
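The default described above (only the first argument is differentiated, whatever the other arguments are) can be illustrated without any ML framework. The following is a minimal sketch using forward-mode dual numbers, a technique swapped in purely for illustration: qml.math.grad itself delegates to the automatic differentiation of the active interface. `Dual` and `dual_grad` are hypothetical names, the sketch supports only `+` and `*`, and it accepts a single integer argument index.

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative with respect to one seeded input."""

    val: float
    der: float = 0.0

    @staticmethod
    def lift(other: Union["Dual", float]) -> "Dual":
        # Plain numbers are constants, so their derivative is zero
        return other if isinstance(other, Dual) else Dual(float(other))

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.der + o.der)

    def __mul__(self, other):
        o = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __radd__ = __add__
    __rmul__ = __mul__


def dual_grad(f: Callable[..., Dual], argnum: int = 0) -> Callable[..., float]:
    """Hypothetical jax-style gradient with respect to one argument."""

    def grad_fn(*args: float) -> float:
        # Seed only the selected argument with derivative 1; every other
        # argument enters as a constant (derivative 0).
        seeded = [Dual(float(a), 1.0 if j == argnum else 0.0) for j, a in enumerate(args)]
        return f(*seeded).der

    return grad_fn


def g(x, y):
    return x * x * y


print(dual_grad(g)(2.0, 3.0))            # 12.0  (dg/dx = 2xy)
print(dual_grad(g, argnum=1)(2.0, 3.0))  # 4.0   (dg/dy = x**2)
```

Because only the seeded argument carries a nonzero derivative, the second argument contributes to the value of g but not to the returned gradient, which mirrors the default behavior described for argnums=0.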
argnums can also be a sequence of argument indices to differentiate multiple arguments; the gradients are then returned as a tuple:
>>> qml.math.grad(f, argnums=(0, 1))(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
(tensor(3.), tensor(2.))
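The argnums return convention (a single gradient for an integer, a tuple in the same order for a sequence) can be sketched with central finite differences. `fd_grad` and its step size `h` are hypothetical and stand in for the interface's automatic differentiation; finite differences are only approximate, so the results are close to, not exactly equal to, the analytic values.

```python
from typing import Callable, Sequence, Union


def fd_grad(
    f: Callable[..., float],
    argnums: Union[int, Sequence[int]] = 0,
    h: float = 1e-6,
) -> Callable[..., object]:
    """Hypothetical jax-style gradient built from central finite differences.

    Returns a function with the same signature as ``f``. An integer
    ``argnums`` yields a single derivative; a sequence yields a tuple.
    """
    single = isinstance(argnums, int)
    indices = (argnums,) if single else tuple(argnums)

    def grad_fn(*args: float):
        grads = []
        for i in indices:
            plus, minus = list(args), list(args)
            plus[i] += h
            minus[i] -= h
            # Central difference: (f(x + h) - f(x - h)) / (2h)
            grads.append((f(*plus) - f(*minus)) / (2 * h))
        return grads[0] if single else tuple(grads)

    return grad_fn


def f(x: float, y: float) -> float:
    return x * y


print(fd_grad(f)(2.0, 3.0))                  # approximately 3.0
print(fd_grad(f, argnums=(0, 1))(2.0, 3.0))  # approximately (3.0, 2.0)
```

Returning a bare value for an integer index and a tuple for a sequence keeps the common single-argument case simple while preserving argument order when several arguments are requested.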
The selected arguments must have a trainable type for their interface (for example, a torch.tensor created with requires_grad=True, or a tf.Variable); otherwise an error may be raised:
>>> qml.math.grad(f)(torch.tensor(1.0), torch.tensor(2.))
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn