qml.jvp
- jvp(f, params, tangents, method=None, h=None, argnums=None, *, argnum=None)
A qjit()-compatible Jacobian-vector product of PennyLane programs.
This function allows the Jacobian-vector product (JVP) of a hybrid quantum-classical function to be computed within the compiled program.
Warning
jvp is intended to be used with qjit() only.
Note
When used with qjit(), this function only supports the Catalyst compiler; see catalyst.jvp() for more details.
Please see the Catalyst quickstart guide, as well as the sharp bits and debugging tips page, for an overview of the differences between Catalyst and PennyLane.
Warning
The argument argnum has been renamed to argnums to match Catalyst and JAX. The ability to use argnum will be removed in v0.45.
- Parameters:
f (Callable) – Function-like object to calculate the JVP for.
params (List[Array]) – List (or tuple) of the function arguments specifying the point at which to calculate the JVP. A subset of these parameters is declared differentiable by listing their indices in the argnums parameter.
tangents (List[Array]) – List (or tuple) of tangent values to use in the JVP. The list length and the array shapes must match those of the differentiable params.
method (str) – Differentiation method to use, same as in grad().
h (float) – The step-size value for the finite-difference ("fd") method.
argnums (Union[int, List[int]]) – The indices of the params to differentiate.
- Returns:
Return values of f paired with the JVP values.
- Return type:
Tuple[Array]
- Raises:
TypeError – invalid parameter types
ValueError – invalid parameter values
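To make the roles of params, tangents, and h concrete: for a function with Jacobian J at the point x, the JVP along a tangent t is the directional derivative J(x)·t. The sketch below approximates it with a forward finite difference in plain Python. It only illustrates the mathematics. The names `g` and `fd_jvp` are hypothetical, and Catalyst's actual "fd" implementation may differ in its details.

```python
def g(x):
    # Hypothetical test function R^2 -> R^2, chosen only for illustration.
    return [x[0] * x[1], x[0] ** 2]


def fd_jvp(func, x, tangent, h=1e-7):
    """Approximate (func(x), J(x) @ tangent) with a forward difference of step h."""
    # The tangent must have the same length as the point it perturbs.
    if len(x) != len(tangent):
        raise ValueError("tangent must match the differentiable parameter")
    base = func(x)
    # Evaluate the function at the point shifted by h along the tangent direction.
    shifted = func([xi + h * ti for xi, ti in zip(x, tangent)])
    directional = [(s - b) / h for s, b in zip(shifted, base)]
    return base, directional


primal_out, jvp_out = fd_jvp(g, [2.0, 3.0], [1.0, 0.5])
# The exact Jacobian at (2, 3) is [[3, 2], [4, 0]], so J @ t = [4, 4];
# the finite-difference result agrees up to an error of order h.
print(primal_out, jvp_out)
```

Like qml.jvp, the helper returns the function value paired with the JVP value.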
See also
Example 1 (basic usage)
import jax.numpy as jnp
import pennylane as qml

@qml.qjit
def jvp(params, tangent):
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)

    # One differentiable argument, so one tangent of the same shape.
    return qml.jvp(f, [params], [tangent])

>>> x = jnp.array([0.1, 0.2])
>>> tangent = jnp.array([0.3, 0.6])
>>> jvp(x, tangent)
(Array([0.09983342, 0.04      , 0.02      ], dtype=float64), Array([0.29850125, 0.24      , 0.12      ], dtype=float64))
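The printed values can be checked by hand. For f(x) = (sin x0, x1², x0·x1), the Jacobian is [[cos x0, 0], [0, 2·x1], [x1, x0]], and contracting it with the tangent gives the second array in the output. The following plain-Python sketch uses no PennyLane or Catalyst, and the name `analytic_jvp` is hypothetical. It reproduces both arrays up to floating-point rounding:

```python
import math


def analytic_jvp(x, t):
    """Return (f(x), J(x) @ t) for f(x) = [sin(x0), x1**2, x0 * x1]."""
    primal = [math.sin(x[0]), x[1] ** 2, x[0] * x[1]]
    # Each entry is one row of the Jacobian contracted with the tangent t.
    tangent_out = [
        math.cos(x[0]) * t[0],       # row for sin(x0)
        2 * x[1] * t[1],             # row for x1 ** 2
        x[1] * t[0] + x[0] * t[1],   # row for x0 * x1
    ]
    return primal, tangent_out


primal, tangent_out = analytic_jvp([0.1, 0.2], [0.3, 0.6])
print(primal)
print(tangent_out)
```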
Example 2 (argnums usage)
Here we show how to use argnums to ignore the non-differentiable parameter n of the target function. Note that the length and shapes of tangents must match the length and shape of primal parameters, which we mark as differentiable by passing their indices to argnums.

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(n, params):
    qml.RX(params[n, 0], wires=n)
    qml.RY(params[n, 1], wires=n)
    return qml.expval(qml.Z(1))

@qml.qjit
def workflow(primals, tangents):
    return qml.jvp(circuit, [1, primals], [tangents], argnums=[1])

>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> dy = jnp.array([[1.0, 1.0], [1.0, 1.0]])
>>> workflow(params, dy)
(Array(0.78766064, dtype=float64), Array(-0.70114352, dtype=float64))
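This result can also be verified analytically. With n = 1, both rotations act on wire 1, and for RX(a) followed by RY(b) applied to |0> the expectation value of Z is cos(a)·cos(b), where a = params[1, 0] and b = params[1, 1]. Row 0 of params never enters the circuit, so the matching row of the tangent contributes nothing, even though the tangent must still have the full shape of params. The plain-Python sketch below performs this check. The function name `circuit_jvp` is hypothetical, and the closed form is derived here rather than taken from the PennyLane documentation.

```python
import math


def circuit_jvp(params, tangents, n=1):
    """Closed-form value and JVP of <Z> after RX(a), RY(b) on |0>."""
    a, b = params[n]
    da, db = tangents[n]
    value = math.cos(a) * math.cos(b)
    # Only row n of the tangent matters; all other rows have zero derivative.
    jvp_value = (-math.sin(a) * math.cos(b)) * da + (-math.cos(a) * math.sin(b)) * db
    return value, jvp_value


value, jvp_value = circuit_jvp(
    [[0.54, 0.3154], [0.654, 0.123]],
    [[1.0, 1.0], [1.0, 1.0]],
)
print(value, jvp_value)
```

With unit tangents the JVP reduces to -sin(a + b), which agrees with the second value returned by workflow.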