catalyst.grad

grad(fn=None, *, method=None, h=None, argnums=None)

A qjit()-compatible gradient transformation for PennyLane/Catalyst.

This function allows the gradient of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, this function will simply dispatch to its JAX counterpart jax.grad. The function fn must return a single scalar value (see Raises below).
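
For instance, when called eagerly outside of qjit(), grad() behaves like its JAX counterpart. A minimal sketch (not from the original documentation; the function and input are illustrative):

import jax
import jax.numpy as jnp
from catalyst import grad

def f(x):
    return jnp.sin(x) ** 2

# called outside of @qjit, catalyst.grad dispatches to jax.grad
assert jnp.allclose(grad(f)(0.5), jax.grad(f)(0.5))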

Warning

Currently, higher-order differentiation is only supported by the finite-difference method.
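
For example, a second derivative can be computed by nesting grad calls that both use the finite-difference method. A minimal sketch (not from the original documentation; the target function is illustrative):

import jax.numpy as jnp
from catalyst import qjit, grad

@qjit
def second_derivative(x: float):
    def f(x):
        return jnp.sin(x) ** 2

    # nest two finite-difference gradients to obtain d^2 f / dx^2
    return grad(grad(f, method="fd"), method="fd")(x)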

Parameters
  • fn (Callable) – a function or a function object to differentiate

  • method (str) –

    The method used for differentiation, which can be any of ["auto", "fd"], where:

    • "auto" represents deferring the quantum differentiation to the method specified by the QNode, while the classical computation is differentiated using traditional auto-diff. Catalyst supports "parameter-shift" and "adjoint" on internal QNodes. Notably, QNodes with diff_method="finite-diff" is not supported with "auto".

    • "fd" represents first-order finite-differences for the entire hybrid function.

  • h (float) – the step-size value for the finite-difference ("fd") method

  • argnums (Tuple[int, List[int]]) – the argument indices to differentiate (see the sketch after this parameter list)
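
As a sketch of how argnums selects which positional arguments are differentiated (the two-argument function below is illustrative and not part of the original documentation):

import jax.numpy as jnp
from catalyst import qjit, grad

def f(x, y):
    return jnp.sin(x) * y ** 2

@qjit
def df(x: float, y: float):
    # differentiate with respect to both positional arguments;
    # one derivative is returned per selected index
    return grad(f, argnums=[0, 1])(x, y)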

Returns

A callable object that computes the gradient of the wrapped function for the given arguments.

Return type

Callable

Raises
  • ValueError – Invalid method or step size parameters.

  • DifferentiableCompileError – Called on a function that doesn’t return a single scalar.

Note

Any JAX-compatible optimization library, such as Optax, can be used alongside grad for JIT-compatible variational workflows. See the Quick Start for examples.
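
A minimal sketch of such a workflow, assuming Optax's SGD optimizer and an illustrative single-parameter circuit (the specifics below are not from the original documentation):

import optax
import pennylane as qml
import jax.numpy as jnp
from catalyst import qjit, grad

dev = qml.device("lightning.qubit", wires=1)
opt = optax.sgd(learning_rate=0.1)

@qjit
def step(theta, opt_state):
    @qml.qnode(dev)
    def circuit(theta):
        qml.RX(theta, wires=0)
        return qml.expval(qml.PauliZ(0))

    g = grad(circuit)(theta)
    updates, opt_state = opt.update(g, opt_state)
    return optax.apply_updates(theta, updates), opt_state
>>> theta = jnp.array(0.5)
>>> opt_state = opt.init(theta)
>>> theta, opt_state = step(theta, opt_state)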

Example 1 (Classical preprocessing)

import pennylane as qml
import jax.numpy as jnp
import catalyst
from catalyst import qjit, grad

dev = qml.device("lightning.qubit", wires=1)

@qjit
def workflow(x):
    @qml.qnode(dev)
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)
>>> workflow(2.0)
Array(-3.14159265, dtype=float64)

Example 2 (Classical preprocessing and postprocessing)

dev = qml.device("lightning.qubit", wires=1)

@qjit
def grad_loss(theta):
    @qml.qnode(dev, diff_method="adjoint")
    def circuit(theta):
        qml.RX(jnp.exp(theta ** 2) / jnp.cos(theta / 4), wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    def loss(theta):
        return jnp.pi / jnp.tanh(circuit(theta))

    return catalyst.grad(loss, method="auto")(theta)
>>> grad_loss(1.0)
Array(-1.90958669, dtype=float64)

Example 3 (Multiple QNodes with their own differentiation methods)

dev = qml.device("lightning.qubit", wires=1)

@qjit
def grad_loss(theta):
    @qml.qnode(dev, diff_method="parameter-shift")
    def circuit_A(params):
        qml.RX(jnp.exp(params[0] ** 2) / jnp.cos(params[1] / 4), wires=0)
        return qml.probs()

    @qml.qnode(dev, diff_method="adjoint")
    def circuit_B(params):
        qml.RX(jnp.exp(params[1] ** 2) / jnp.cos(params[0] / 4), wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    def loss(params):
        return jnp.prod(circuit_A(params)) + circuit_B(params)

    return catalyst.grad(loss)(theta)
>>> grad_loss(jnp.array([1.0, 2.0]))
Array([ 0.57367285, 44.4911605 ], dtype=float64)

Example 4 (Purely classical functions)

def square(x: float):
    return x ** 2

@qjit
def dsquare(x: float):
    return catalyst.grad(square)(x)
>>> dsquare(2.3)
Array(4.6, dtype=float64)