catalyst.grad
grad(fn=None, *, method=None, h=None, argnums=None)
A `qjit()`-compatible gradient transformation for PennyLane/Catalyst.

This function allows the gradient of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, it simply dispatches to its JAX counterpart, `jax.grad`. The function `fn` can return any pytree-like shape.

Warning
Currently, higher-order differentiation is only supported by the finite-difference method.
- Parameters
fn (Callable) – a function or a function object to differentiate
method (str) – the method used for differentiation, which can be any of `["auto", "fd"]`, where:
  - `"auto"` defers the quantum differentiation to the method specified by the QNode, while the classical computation is differentiated using traditional auto-diff. Catalyst supports `"parameter-shift"` and `"adjoint"` on internal QNodes. Notably, QNodes with `diff_method="finite-diff"` are not supported with `"auto"`.
  - `"fd"` applies first-order finite differences to the entire hybrid function.
h (float) – the step size for the finite-difference (`"fd"`) method
argnums (Tuple[int, List[int]]) – the argument indices to differentiate
- Returns
A callable object that computes the gradient of the wrapped function for the given arguments.
- Return type
Callable
- Raises
ValueError – Invalid method or step size parameters.
DifferentiableCompilerError – Called on a function that doesn’t return a single scalar.
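To make the `"fd"` method and the `h` and `argnums` parameters concrete, here is a minimal plain-Python sketch of a first-order forward-difference gradient. It is not Catalyst's implementation: `fd_grad` is a hypothetical helper that only handles scalar float arguments, its default `h=1e-7` is an arbitrary choice for the sketch rather than Catalyst's default, and returning a scalar for an integer `argnums` but a tuple for a list or tuple is an assumption borrowed from the `jax.grad` convention.

```python
from typing import Callable, Sequence, Tuple, Union


def fd_grad(
    fn: Callable[..., float],
    h: float = 1e-7,  # arbitrary default for this sketch, not Catalyst's
    argnums: Union[int, Sequence[int]] = 0,
) -> Callable[..., Union[float, Tuple[float, ...]]]:
    """Hypothetical helper: first-order forward-difference gradient of a scalar fn.

    Only scalar (float) positional arguments are supported in this sketch.
    """
    single = isinstance(argnums, int)
    indices = [argnums] if single else list(argnums)

    def grad_fn(*args: float):
        f0 = fn(*args)  # one unshifted evaluation, reused for every partial
        partials = []
        for i in indices:
            shifted = list(args)
            shifted[i] = shifted[i] + h  # perturb only argument i
            partials.append((fn(*shifted) - f0) / h)
        # Integer argnums -> scalar, sequence argnums -> tuple (jax.grad-style assumption)
        return partials[0] if single else tuple(partials)

    return grad_fn


def f(x: float, y: float) -> float:
    return x * y ** 2


# Exact values at (2, 3): df/dx = y**2 = 9, df/dy = 2*x*y = 12
print(fd_grad(f, h=1e-6)(2.0, 3.0))
print(fd_grad(f, h=1e-6, argnums=[0, 1])(2.0, 3.0))
```

Because each partial derivative costs one extra evaluation of the whole function, this approach treats the hybrid function as a black box, which is consistent with `"fd"` being described as applying to the entire hybrid function rather than to individual QNodes.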
Note
Any JAX-compatible optimization library, such as Optax, can be used alongside `grad` for JIT-compatible variational workflows. See the Quick Start for examples.
Example 1 (Classical preprocessing)
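```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
def workflow(x):
    @qml.qnode(dev)
    def circuit(x):
        # Classical preprocessing: the rotation angle is pi * x
        qml.RX(jnp.pi * x, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)
```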
```pycon
>>> workflow(2.0)
Array(-3.14159265, dtype=float64)
```
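The value above can be checked by hand. Assuming the qubit starts in $|0\rangle$, $R_X(\theta)|0\rangle = \cos(\theta/2)|0\rangle - i\sin(\theta/2)|1\rangle$, so $\langle Y \rangle = -\sin\theta$. Substituting the preprocessed angle $\theta = \pi x$ and applying the chain rule gives:

$$
f(x) = \langle Y \rangle = -\sin(\pi x), \qquad
f'(x) = -\pi \cos(\pi x), \qquad
f'(2) = -\pi \cos(2\pi) = -\pi \approx -3.14159265
$$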
Example 2 (Classical preprocessing and postprocessing)
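```python
import catalyst
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
def grad_loss(theta):
    @qml.qnode(dev, diff_method="adjoint")
    def circuit(theta):
        # Classical preprocessing of the rotation angle
        qml.RX(jnp.exp(theta ** 2) / jnp.cos(theta / 4), wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    def loss(theta):
        # Classical postprocessing of the expectation value
        return jnp.pi / jnp.tanh(circuit(theta))

    return catalyst.grad(loss, method="auto")(theta)
```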
```pycon
>>> grad_loss(1.0)
Array(-1.90958669, dtype=float64)
```
Example 3 (Multiple QNodes with their own differentiation methods)
```python
import catalyst
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
def grad_loss(theta):
    # This QNode is differentiated with the parameter-shift rule
    @qml.qnode(dev, diff_method="parameter-shift")
    def circuit_A(params):
        qml.RX(jnp.exp(params[0] ** 2) / jnp.cos(params[1] / 4), wires=0)
        return qml.probs()

    # This QNode is differentiated with the adjoint method
    @qml.qnode(dev, diff_method="adjoint")
    def circuit_B(params):
        qml.RX(jnp.exp(params[1] ** 2) / jnp.cos(params[0] / 4), wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    def loss(params):
        return jnp.prod(circuit_A(params)) + circuit_B(params)

    return catalyst.grad(loss)(theta)
```
```pycon
>>> grad_loss(jnp.array([1.0, 2.0]))
Array([ 0.57367285, 44.4911605 ], dtype=float64)
```
Example 4 (Purely classical functions)
```python
import catalyst
from catalyst import qjit

def square(x: float):
    return x ** 2

@qjit
def dsquare(x: float):
    return catalyst.grad(square)(x)
```
```pycon
>>> dsquare(2.3)
Array(4.6, dtype=float64)
```
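The result of Example 4 can also be cross-checked against the first-order forward difference that `method="fd"` is described as using. The sketch below is plain Python, not Catalyst; `forward_difference` is a hypothetical helper. For `x ** 2` the forward difference equals `2 * x + h` exactly, so the error shrinks linearly with the step size, which is what "first-order" refers to.

```python
def square(x: float) -> float:
    return x ** 2


def forward_difference(fn, x: float, h: float) -> float:
    """Hypothetical helper: first-order forward difference (fn(x + h) - fn(x)) / h."""
    return (fn(x + h) - fn(x)) / h


x = 2.3
for h in (1e-1, 1e-2, 1e-3):
    approx = forward_difference(square, x, h)
    # For x**2 the forward difference is exactly 2*x + h, so the error is h
    # (up to floating-point rounding).
    print(f"h={h:g}  approx={approx:.6f}  error={approx - 2 * x:.2e}")
```

Shrinking `h` reduces this truncation error, but very small steps eventually amplify floating-point cancellation in `fn(x + h) - fn(x)`, which is why the step size is exposed as a parameter.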