catalyst.jvp

jvp(f: Callable, params, tangents, *, method=None, h=None, argnums=None)

A qjit()-compatible Jacobian-vector product for PennyLane/Catalyst.

This function allows the Jacobian-vector product (JVP) of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, it simply dispatches to its JAX counterpart, jax.jvp. The function f can return any pytree-like shape.
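Since jvp falls back to jax.jvp outside a compiled function, the plain-JAX sketch below illustrates what a pytree-valued output looks like in that setting. It assumes only jax is installed and does not call Catalyst; the function f and its dict-shaped output are illustrative choices.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative pytree-valued output: a dict holding a vector and a scalar.
    return {"vec": jnp.stack([jnp.sin(x[0]), x[1] ** 2]), "prod": x[0] * x[1]}

x = jnp.array([0.1, 0.2])
t = jnp.array([0.3, 0.6])

# jax.jvp takes tuples of primals and tangents and returns (f(x), J @ t);
# both results share the pytree structure of f's output.
primal_out, tangent_out = jax.jvp(f, (x,), (t,))

print(primal_out["prod"], tangent_out["prod"])
```

The tangent output mirrors the dict structure of the primal output, which is what a pytree-like return shape means in practice.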

Parameters
  • f (Callable) – Function-like object to compute the JVP of.

  • params (List[Array]) – List (or tuple) of the function arguments specifying the point at which to compute the JVP. A subset of these parameters is declared differentiable by listing their indices in the argnums parameter.

  • tangents (List[Array]) – List (or tuple) of tangent values to use in the JVP. The list length and the shapes must match those of the differentiable params.

  • method (str) – Differentiation method to use, same as in grad().

  • h (float) – Step size for the finite-difference ("fd") method.

  • argnums (Union[int, List[int]]) – Indices of the params to differentiate with respect to.
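To make the role of h concrete, here is a minimal sketch of a forward-difference JVP estimate, (f(x + h*t) - f(x)) / h. It uses NumPy only, and fd_jvp is a hypothetical helper; Catalyst's actual "fd" scheme may differ in its stencil and default step size.

```python
import numpy as np

def fd_jvp(f, x, t, h=1e-6):
    """Hypothetical helper: forward-difference estimate of J_f(x) @ t.

    Truncation error is O(h); rounding error grows roughly like eps / h,
    so h should not be made arbitrarily small.
    """
    fx = f(x)
    return fx, (f(x + h * t) - fx) / h

def f(x):
    # The same function as in Example 1 below, written with NumPy (float64).
    return np.array([np.sin(x[0]), x[1] ** 2, x[0] * x[1]])

x = np.array([0.1, 0.2])
t = np.array([0.3, 0.6])
primal, jvp_estimate = fd_jvp(f, x, t)
```

A single step along t needs only one extra evaluation of f, which is why a finite-difference JVP is cheaper than building the full finite-difference Jacobian.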

Returns

Return values of f paired with the JVP values.

Return type

Tuple[Any]

Raises
  • TypeError – invalid parameter types

  • ValueError – invalid parameter values

Example 1 (basic usage)

import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit
def jvp(params, tangent):
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)

    # Returns (f(params), J_f(params) @ tangent)
    return catalyst.jvp(f, [params], [tangent])

>>> x = jnp.array([0.1, 0.2])
>>> tangent = jnp.array([0.3, 0.6])
>>> jvp(x, tangent)
(Array([0.09983342, 0.04      , 0.02      ], dtype=float64),
 Array([0.29850125, 0.24      , 0.12      ], dtype=float64))
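As a sanity check on Example 1, the JVP equals the Jacobian of f at x multiplied by the tangent. The plain-JAX sketch below does not use Catalyst and runs in JAX's default float32 rather than float64; it forms the 3x2 Jacobian explicitly and compares the product against jax.jvp.

```python
import jax
import jax.numpy as jnp

def f(x):
    y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
    return jnp.stack(y)

x = jnp.array([0.1, 0.2])
tangent = jnp.array([0.3, 0.6])

# Materialize the full 3x2 Jacobian, then contract it with the tangent.
J = jax.jacobian(f)(x)
jvp_via_jacobian = J @ tangent

# jax.jvp computes the same product in forward mode without building J.
_, jvp_direct = jax.jvp(f, (x,), (tangent,))
```

Building J is fine for a 3x2 example, but forward-mode JVPs avoid materializing the Jacobian, which is the reason to compute them directly.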

Example 2 (argnums usage)

Here we show how to use argnums to ignore the non-differentiable parameter n of the target function. Note that the length and shapes of tangents must match those of the primal parameters marked as differentiable by passing their indices to argnums.

import jax.numpy as jnp
import pennylane as qml
import catalyst
from catalyst import qjit

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(n, params):
    qml.RX(params[n, 0], wires=n)
    qml.RY(params[n, 1], wires=n)
    return qml.expval(qml.PauliZ(1))

@qjit
def workflow(primals, tangents):
    # n = 1 is a fixed, non-differentiable argument; argnums=[1] selects
    # primals, so tangents holds one entry shaped like primals.
    return catalyst.jvp(circuit, [1, primals], [tangents], argnums=[1])

>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> dy = jnp.array([[1.0, 1.0], [1.0, 1.0]])
>>> workflow(params, dy)
(Array(0.78766064, dtype=float64), Array(-0.7011436, dtype=float64))
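The numbers in Example 2 can be reproduced without a quantum device. For n = 1, the circuit applies RX(a) and then RY(b) to wire 1 starting from |0>, so <Z> on that wire is cos(a)cos(b). The sketch below uses a hypothetical closed-form stand-in, circuit_model, and fixes n with functools.partial as a plain-JAX analogue of argnums=[1]; it is not how Catalyst implements argnums.

```python
import functools

import jax
import jax.numpy as jnp

def circuit_model(n, params):
    # Hypothetical closed-form stand-in for the QNode above, valid for n = 1:
    # RX(a) then RY(b) applied to |0> gives <Z> = cos(a) * cos(b).
    a, b = params[n, 0], params[n, 1]
    return jnp.cos(a) * jnp.cos(b)

params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
dy = jnp.ones_like(params)

# Fixing n with functools.partial plays the role of argnums=[1]:
# only params is differentiated, and the tangent matches its shape.
value, jvp_value = jax.jvp(functools.partial(circuit_model, 1), (params,), (dy,))
```

With an all-ones tangent, the JVP is -(sin(a)cos(b) + cos(a)sin(b)) = -sin(a + b), which for a = 0.654 and b = 0.123 gives the -0.7011436 shown above. Row 0 of params does not affect the output, so its tangent entries contribute nothing.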