qml.jvp

jvp(f, params, tangents, method=None, h=None, argnum=None)

A qjit()-compatible Jacobian-vector product (JVP) of PennyLane programs.

This function allows the Jacobian-vector product of a hybrid quantum-classical function to be computed within the compiled program.
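For reference, this is the standard definition of the quantity being computed. For a function f : R^n → R^m with Jacobian J_f, the JVP at a point x along a tangent t is the directional derivative of f along t. The notation below is general and is not taken from the API itself. When f has several arguments, the Jacobian is taken with respect to the differentiable arguments only, and t collects their tangents.

```latex
\mathrm{jvp}_f(\mathbf{x}, \mathbf{t})
  = J_f(\mathbf{x})\,\mathbf{t}
  = \lim_{h \to 0} \frac{f(\mathbf{x} + h\,\mathbf{t}) - f(\mathbf{x})}{h},
\qquad
\bigl[J_f(\mathbf{x})\bigr]_{ij} = \frac{\partial f_i}{\partial x_j}(\mathbf{x}).
```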

Warning

jvp is intended to be used with qjit() only.

Note

When used with qjit(), this function only supports the Catalyst compiler; see catalyst.jvp() for more details.

Please see the Catalyst quickstart guide, as well as the sharp bits and debugging tips page, for an overview of the differences between Catalyst and PennyLane.

Parameters
  • f (Callable) – Function-like object whose JVP is computed

  • params (List[Array]) – List (or tuple) of the function arguments specifying the point at which the JVP is computed. A subset of these parameters is declared differentiable by listing their indices in the argnum parameter.

  • tangents (List[Array]) – List (or tuple) of tangent values to use in the JVP. Its length, and the shape of each entry, must match those of the differentiable params.

  • method (str) – Differentiation method to use, same as in grad().

  • h (float) – the step-size value for the finite-difference ("fd") method

  • argnum (Union[int, List[int]]) – the indices of the params to differentiate with respect to.

Returns

Return values of f paired with the JVP values.

Return type

Tuple[Array]

Raises
  • TypeError – invalid parameter types

  • ValueError – invalid parameter values
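To make the roles of params, tangents, argnum and h more concrete, here is a minimal plain-NumPy sketch of a one-sided finite-difference estimate of a JVP. It is an illustration only: fd_jvp is a hypothetical helper, not part of PennyLane or Catalyst, and Catalyst's own "fd" method may compute the result differently. Arguments whose indices are not listed in argnum are held fixed, and the helper returns the function value paired with the JVP estimate, loosely mirroring the return structure described above.

```python
import numpy as np


def fd_jvp(f, params, tangents, argnum=None, h=1e-7):
    # Hypothetical helper (not a PennyLane/Catalyst API): one-sided
    # finite-difference estimate of J_f(params) @ tangents.
    if argnum is None:
        argnum = list(range(len(params)))
    elif isinstance(argnum, int):
        argnum = [argnum]
    if len(tangents) != len(argnum):
        raise ValueError("expected one tangent per differentiable parameter")

    base = np.asarray(f(*params), dtype=float)
    shifted = list(params)  # non-differentiable entries are left untouched
    for idx, t in zip(argnum, tangents):
        p = np.asarray(params[idx], dtype=float)
        t = np.asarray(t, dtype=float)
        if p.shape != t.shape:
            raise ValueError(f"tangent shape {t.shape} != param shape {p.shape}")
        shifted[idx] = p + h * t  # step of size h along the tangent direction
    perturbed = np.asarray(f(*shifted), dtype=float)
    # Function value paired with (f(x + h*t) - f(x)) / h, which tends to J(x) @ t
    return base, (perturbed - base) / h


def f(x):
    return np.stack([np.sin(x[0]), x[1] ** 2, x[0] * x[1]])


x = np.array([0.1, 0.2])
tangent = np.array([0.3, 0.6])
value, jvp_est = fd_jvp(f, [x], [tangent])

# argnum usage: the integer 1 is a fixed, non-differentiable input
_, jvp_row = fd_jvp(lambda n, v: v[n] ** 2, [1, x], [tangent], argnum=[1])
```

Forward differences trade accuracy against round-off: a smaller h reduces truncation error but amplifies floating-point noise, which is why a step size is exposed as a parameter at all.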

See also

grad(), vjp(), jacobian()

Example 1 (basic usage)

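import jax.numpy as jnp
import pennylane as qml

@qml.qjit
def jvp(params, tangent):
    # Classical function from R^2 to R^3 whose JVP we want
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)

    # Returns f(params) together with J_f(params) @ tangent
    return qml.jvp(f, [params], [tangent])
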
>>> x = jnp.array([0.1, 0.2])
>>> tangent = jnp.array([0.3, 0.6])
>>> jvp(x, tangent)
[array([0.09983342, 0.04      , 0.02      ]),
array([0.29850125, 0.24000006, 0.12      ])]
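The printed values can be cross-checked by hand. Since f(x) = (sin x0, x1^2, x0 x1), its Jacobian has rows (cos x0, 0), (0, 2 x1) and (x1, x0), so J(x) t = (t0 cos x0, 2 x1 t1, x1 t0 + x0 t1). The short NumPy snippet below, which does not use PennyLane, evaluates this at the same point. It agrees with the output above; the small discrepancy in 0.24000006 is presumably single-precision rounding in the compiled result.

```python
import numpy as np

x = np.array([0.1, 0.2])
tangent = np.array([0.3, 0.6])

# f(x) = [sin(x0), x1**2, x0*x1], evaluated directly
value = np.array([np.sin(x[0]), x[1] ** 2, x[0] * x[1]])

# Hand-derived Jacobian: row i holds the partial derivatives of f_i
J = np.array([
    [np.cos(x[0]), 0.0],
    [0.0, 2.0 * x[1]],
    [x[1], x[0]],
])

# The JVP is the Jacobian applied to the tangent vector
jvp_value = J @ tangent
```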

Example 2 (argnum usage)

Here we show how to use argnum to exclude the non-differentiable parameter n of the target function. Note that the length of tangents, and the shape of each entry, must match those of the primal parameters marked as differentiable by passing their indices to argnum.

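import jax.numpy as jnp
import pennylane as qml

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(n, params):
    # Only row n of params is used, acting on wire n
    qml.RX(params[n, 0], wires=n)
    qml.RY(params[n, 1], wires=n)
    return qml.expval(qml.PauliZ(1))

@qml.qjit
def workflow(primals, tangents):
    # n = 1 is held fixed; only argument index 1 is differentiated,
    # so exactly one tangent array is supplied
    return qml.jvp(circuit, [1, primals], [tangents], argnum=[1])
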
>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> dy = jnp.array([[1.0, 1.0], [1.0, 1.0]])
>>> workflow(params, dy)
[array(0.78766064), array(-0.7011436)]
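For this circuit the result has a closed form, which gives a way to sanity-check the output. With n = 1, no gates act on wire 0, and applying RX(a) followed by RY(b) to |0⟩ on wire 1 gives ⟨Z⟩ = cos(a) cos(b), where a = params[1, 0] and b = params[1, 1]. Row 0 of params therefore has zero derivative, and with the all-ones tangent dy the JVP is -sin(a) cos(b) - cos(a) sin(b). The NumPy sketch below, independent of PennyLane and based only on this derivation, reproduces both printed values.

```python
import numpy as np

params = np.array([[0.54, 0.3154], [0.654, 0.123]])
dy = np.array([[1.0, 1.0], [1.0, 1.0]])
n = 1

a, b = params[n]
# <Z> on wire 1 after RX(a) then RY(b) applied to |0>
value = np.cos(a) * np.cos(b)

# Gradient with respect to the full params array; row 0 does not affect the output
grad = np.zeros_like(params)
grad[n, 0] = -np.sin(a) * np.cos(b)
grad[n, 1] = -np.cos(a) * np.sin(b)

# JVP: contract the gradient with the tangent array
jvp_value = np.sum(grad * dy)
```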