catalyst.vjp
vjp(f: Callable, params, cotangents, *, method=None, h=None, argnums=None)
A qjit()-compatible vector-Jacobian product for PennyLane/Catalyst.

This function allows the vector-Jacobian product (VJP) of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, it simply dispatches to its JAX counterpart, jax.vjp. The function f can return any pytree-like shape.

- Parameters
f (Callable) – Function-like object to calculate the VJP for.
params (List[Array]) – List (or tuple) of f's arguments specifying the point at which to calculate the VJP. A subset of these parameters is declared differentiable by listing their indices in the argnums parameter.
cotangents (List[Array]) – List (or tuple) of cotangent values to use in the VJP. The list size and shapes must match the size and shapes of f's outputs.
method (str) – Differentiation method to use, same as in grad.
h (float) – The step-size value for the finite-difference ("fd") method.
argnums (Union[int, List[int]]) – The indices of the params to differentiate.
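To make the roles of cotangents, argnums, method="fd" and h more concrete, here is a minimal pure-Python sketch of a finite-difference VJP. It is an illustration under stated assumptions, not Catalyst's implementation: the helper name fd_vjp is hypothetical, the parameters are plain scalars rather than arrays, and the forward-difference stencil is an assumption (Catalyst's "fd" method may use a different scheme).

```python
import math
from typing import Callable, List, Sequence, Tuple


def fd_vjp(
    f: Callable[..., List[float]],
    params: Sequence[float],
    cotangents: Sequence[float],
    argnums: Sequence[int],
    h: float = 1e-7,
) -> Tuple[List[float], List[float]]:
    """Hypothetical forward-difference VJP over scalar arguments.

    Returns f(*params) paired with one VJP entry per index in argnums.
    """
    y = f(*params)
    if len(y) != len(cotangents):
        # The cotangents must match the shape of f's output.
        raise ValueError("cotangents must have the same length as f's output")
    vjp_vals = []
    for j in argnums:
        shifted = list(params)
        shifted[j] += h  # perturb only the j-th differentiable argument
        y_h = f(*shifted)
        # Approximate column j of the Jacobian with a forward difference,
        # then contract it with the cotangent vector.
        vjp_vals.append(
            sum(c * (yh - y0) / h for c, yh, y0 in zip(cotangents, y_h, y))
        )
    return y, vjp_vals


def f(x0: float, x1: float) -> List[float]:
    return [math.sin(x0), x1 ** 2, x0 * x1]


y, dx = fd_vjp(f, [0.1, 0.2], [-0.5, 0.1, 0.3], argnums=[0, 1])
print(y, dx)
```

Restricting argnums (for example argnums=[1]) returns only the entries for those arguments. With the default h, dx lands close to the exact values of the analytic VJP, up to finite-difference error.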
- Returns
Return values of f paired with the VJP values.
- Return type
Tuple[Any]
- Raises
TypeError – invalid parameter types
ValueError – invalid parameter values
Example
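import catalyst
import jax.numpy as jnp
from catalyst import qjit

@qjit
def vjp(params, cotangent):
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)

    # One differentiable argument (params) and one cotangent for f's single output
    return catalyst.vjp(f, [params], [cotangent])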
>>> x = jnp.array([0.1, 0.2])
>>> dy = jnp.array([-0.5, 0.1, 0.3])
>>> vjp(x, dy)
(Array([0.09983342, 0.04      , 0.02      ], dtype=float64),
 (Array([-0.43750208,  0.07      ], dtype=float64),))
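The numbers in this example can be checked by hand. For y = f(x) with Jacobian J[i][j] = ∂y_i/∂x_j, the VJP with cotangent dy is dyᵀJ. The following sketch uses only the Python standard library (it is not a Catalyst API). It builds the analytic Jacobian of the example's f and contracts it with dy.

```python
import math

# Point and cotangent from the example.
x = [0.1, 0.2]
dy = [-0.5, 0.1, 0.3]

# Forward pass: y = [sin(x0), x1**2, x0 * x1]
y = [math.sin(x[0]), x[1] ** 2, x[0] * x[1]]

# Analytic Jacobian, J[i][j] = d y_i / d x_j
J = [
    [math.cos(x[0]), 0.0],
    [0.0, 2.0 * x[1]],
    [x[1], x[0]],
]

# VJP: contract the cotangent with the Jacobian, dy^T J
vjp_x = [sum(dy[i] * J[i][j] for i in range(len(dy))) for j in range(len(x))]
print(y, vjp_x)
```

The first component is -0.5·cos(0.1) + 0.3·0.2 ≈ -0.43750208 and the second is 0.1·0.4 + 0.3·0.1 = 0.07. Both agree with the VJP values returned in the example.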