catalyst.jvp
- jvp(f: Callable, params, tangents, *, method=None, h=None, argnums=None)
A qjit() compatible Jacobian-vector product for PennyLane/Catalyst.

This function allows the Jacobian-vector product (JVP) of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, it simply dispatches to its JAX counterpart, jax.jvp. The function f can return any pytree-like shape.

- Parameters
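The paragraph above says that outside a compiled function the call dispatches to jax.jvp, and that f may return any pytree. The following minimal sketch uses plain jax.jvp (not Catalyst itself) with a hypothetical function f whose output is a dict. It shows that the tangent output mirrors the structure of the primal output and equals the Jacobian multiplied by the tangent vector.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function whose output is a pytree: a dict with a scalar and a vector.
    return {"norm_sq": jnp.sum(x ** 2), "scaled": 3.0 * x}

x = jnp.array([1.0, 2.0])   # primal point
v = jnp.array([0.5, -1.0])  # tangent direction

# Both outputs mirror the pytree structure of f's return value.
primals_out, tangents_out = jax.jvp(f, (x,), (v,))

# Cross-check against an explicit Jacobian-vector product, leaf by leaf.
jac = jax.jacfwd(f)(x)
explicit = {k: jac[k] @ v for k in jac}
```

Here tangents_out["norm_sq"] is 2 x · v = -3 and tangents_out["scaled"] is 3 v = [1.5, -3.0].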
f (Callable) – Function-like object for which to calculate the JVP.
params (List[Array]) – List (or tuple) of the function arguments specifying the point at which to calculate the JVP. A subset of these parameters is declared differentiable by listing their indices in the argnums parameter.
tangents (List[Array]) – List (or tuple) of tangent values to use in the JVP. Its length and the shapes of its entries must match those of the differentiable params.
method (str) – Differentiation method to use, same as in grad().
h (float) – The step size for the finite-difference ("fd") method.
argnums (Union[int, List[int]]) – The indices of the params to differentiate.
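To illustrate what the step size h controls, here is a generic forward-difference estimate of a JVP, (f(x + h·v) - f(x)) / h, compared against the exact jax.jvp result. This is a sketch under stated assumptions: fd_jvp is a hypothetical helper, and Catalyst's "fd" method may use a different difference scheme internally.

```python
import jax
import jax.numpy as jnp

# Use float64 so a small step size does not drown in rounding error.
jax.config.update("jax_enable_x64", True)

def f(x):
    # The same function as in Example 1.
    return jnp.stack([jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]])

def fd_jvp(fun, x, v, h=1e-7):
    """Hypothetical helper: forward-difference estimate of J(x) @ v.

    Illustrative only; this is not Catalyst's "fd" implementation.
    """
    y = fun(x)
    return y, (fun(x + h * v) - y) / h

x = jnp.array([0.1, 0.2])
v = jnp.array([0.3, 0.6])

_, approx = fd_jvp(f, x, v)          # finite-difference estimate
_, exact = jax.jvp(f, (x,), (v,))    # exact forward-mode JVP
```

Shrinking h reduces the truncation error of the difference quotient but increases floating-point cancellation, which is why the step size is exposed as a parameter.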
- Returns
Return values of f paired with the JVP values.
- Return type
Tuple[Any]
- Raises
TypeError – invalid parameter types
ValueError – invalid parameter values
Example 1 (basic usage)
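import catalyst
import jax.numpy as jnp
from catalyst import qjit

@qjit
def jvp(params, tangent):
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)
    # Returns (f(params), J_f(params) @ tangent).
    return catalyst.jvp(f, [params], [tangent])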
>>> x = jnp.array([0.1, 0.2])
>>> tangent = jnp.array([0.3, 0.6])
>>> jvp(x, tangent)
(Array([0.09983342, 0.04 , 0.02 ], dtype=float64), Array([0.29850125, 0.24 , 0.12 ], dtype=float64))
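The output of Example 1 can be checked by hand: the first element is f evaluated at x, and the second is the Jacobian of f at x multiplied by the tangent.

```latex
f(x) = \begin{pmatrix} \sin x_0 \\ x_1^2 \\ x_0 x_1 \end{pmatrix},
\qquad
J_f(x) = \begin{pmatrix} \cos x_0 & 0 \\ 0 & 2x_1 \\ x_1 & x_0 \end{pmatrix}

f(0.1,\,0.2) = \begin{pmatrix} \sin 0.1 \\ 0.2^2 \\ 0.1 \cdot 0.2 \end{pmatrix}
\approx \begin{pmatrix} 0.09983342 \\ 0.04 \\ 0.02 \end{pmatrix}

J_f(x)\,v \Big|_{x=(0.1,\,0.2),\; v=(0.3,\,0.6)}
= \begin{pmatrix} 0.3\cos 0.1 \\ 2 \cdot 0.2 \cdot 0.6 \\ 0.2 \cdot 0.3 + 0.1 \cdot 0.6 \end{pmatrix}
\approx \begin{pmatrix} 0.29850125 \\ 0.24 \\ 0.12 \end{pmatrix}
```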
Example 2 (argnums usage)
Here we show how to use argnums to ignore the non-differentiable parameter n of the target function. Note that the length of tangents and the shapes of its entries must match those of the primal parameters marked as differentiable by passing their indices to argnums.

import catalyst
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(n, params):
    qml.RX(params[n, 0], wires=n)
    qml.RY(params[n, 1], wires=n)
    return qml.expval(qml.PauliZ(1))

@qjit
def workflow(primals, tangents):
    # n = 1 is passed as a constant; argnums=[1] marks only params as differentiable,
    # so a single tangent with the shape of params is supplied.
    return catalyst.jvp(circuit, [1, primals], [tangents], argnums=[1])
>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> dy = jnp.array([[1.0, 1.0], [1.0, 1.0]])
>>> workflow(params, dy)
(Array(0.78766064, dtype=float64), Array(-0.7011436, dtype=float64))
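The numbers in Example 2 can be reproduced classically. Applying RX(a) and then RY(b) to |0⟩ gives ⟨Z⟩ = cos(a)·cos(b), so for n = 1 the JVP with an all-ones tangent is -sin(a + b). The sketch below assumes this analytic stand-in (expval_z1 is a hypothetical helper, not the QNode) and emulates argnums=[1] by closing over n so that plain jax.jvp only sees params. Catalyst's own handling of argnums is not shown here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def expval_z1(n, params):
    # Hypothetical classical stand-in for `circuit` in Example 2.
    # Gates act on wire n, but PauliZ is measured on wire 1, so only n == 1
    # changes the expectation value away from <Z> = 1.
    if n != 1:
        return jnp.ones((), dtype=params.dtype)
    a, b = params[n, 0], params[n, 1]
    return jnp.cos(a) * jnp.cos(b)

params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
dy = jnp.ones_like(params)

# Emulate argnums=[1]: n is fixed inside the closure, only params is differentiated.
value, jvp_out = jax.jvp(lambda p: expval_z1(1, p), (params,), (dy,))
```

Because the expectation depends only on row 1 of params, the tangent entries in row 0 have no effect on the result.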