vjp

vjp(f: Union[catalyst.jax_tracer.Function, pennylane.workflow.qnode.QNode, Callable, catalyst.jit.QJIT], params, cotangents, *, method=None, h=None, argnum=None)

A qjit()-compatible vector-Jacobian product (VJP) for PennyLane/Catalyst.

This function allows the VJP of a hybrid quantum-classical function to be computed within the compiled program. Outside of a compiled function, it simply dispatches to its JAX counterpart, jax.vjp. The function f may return any pytree (for example, a single array or a tuple of arrays).
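Because the call dispatches to jax.vjp outside of qjit, a plain-JAX sketch gives a rough mental model of what is computed. The helper vjp_reference is hypothetical, and it assumes f returns a single array so that cotangents holds exactly one entry. With the inputs from the example on this page, it should yield the same numbers.

```python
import jax
import jax.numpy as jnp


def vjp_reference(f, params, cotangents):
    """Hypothetical helper mirroring the (outputs, vjps) result of catalyst.vjp
    with plain jax.vjp. Assumes f returns a single array, so cotangents has
    exactly one entry."""
    # jax.vjp evaluates f at params and returns a closure for the backward pass
    out, vjp_fn = jax.vjp(f, *params)
    # vjp_fn returns a tuple with one VJP per positional argument of f
    return out, vjp_fn(cotangents[0])


def f(x):
    return jnp.stack([jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]])


x = jnp.array([0.1, 0.2])
dy = jnp.array([-0.5, 0.1, 0.3])
out, (dx,) = vjp_reference(f, [x], [dy])
```

The VJP is dy^T J, where J is the Jacobian of f at x; the backward closure computes this product without materializing J.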

Parameters
  • f (Callable) – Function-like object for which to calculate the VJP

  • params (List[Array]) – List (or tuple) of f’s arguments specifying the point at which to calculate the VJP. A subset of these parameters is declared differentiable by listing their indices in the argnum parameter.

  • cotangents (List[Array]) – List (or tuple) of cotangent values to use in the VJP. The list size and the shapes of its entries must match the number and shapes of f’s outputs.

  • method (str) – Differentiation method to use; accepts the same values as in grad.

  • h (float) – The step size for the finite-difference ("fd") method.

  • argnum (Union[int, List[int]]) – The indices of params to differentiate with respect to.
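To illustrate what argnum selects, the sketch below emulates it in plain JAX by closing over the entries of params that are not differentiated and passing only the selected ones to jax.vjp. The helper vjp_with_argnum and the function g are hypothetical, this is not necessarily how Catalyst implements argnum, and f is again assumed to return a single array. It only shows which entries of params receive a VJP.

```python
import jax
import jax.numpy as jnp


def vjp_with_argnum(f, params, cotangents, argnum):
    """Hypothetical helper: emulate argnum by freezing the non-selected params.
    Assumes f returns a single array, so cotangents has exactly one entry."""
    argnum = [argnum] if isinstance(argnum, int) else list(argnum)

    def f_partial(*diff_args):
        # Rebuild the full argument list, swapping in the differentiated values
        args = list(params)
        for i, a in zip(argnum, diff_args):
            args[i] = a
        return f(*args)

    out, vjp_fn = jax.vjp(f_partial, *[params[i] for i in argnum])
    # One VJP per selected index, in the order given by argnum
    return out, vjp_fn(cotangents[0])


def g(x, w):
    return w * x  # elementwise scaling of a vector by a scalar


x = jnp.array([1.0, 2.0])
w = jnp.array(3.0)
dy = jnp.array([1.0, 1.0])

out1, (vjp_w,) = vjp_with_argnum(g, [x, w], [dy], argnum=1)
out2, (vjp_x, vjp_w2) = vjp_with_argnum(g, [x, w], [dy], argnum=[0, 1])
```

Here the VJP with respect to w is sum(dy * x) = 3.0, and with respect to x it is dy * w = [3.0, 3.0].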

Returns

A pair: the return values of f, followed by a tuple holding the VJP values.

Return type

Tuple[Any]
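The pairing of outputs and VJPs is easiest to see with a function that returns more than one output. This sketch uses plain jax.vjp (the function catalyst.vjp dispatches to outside qjit) with a made-up two-output function g. It assumes, as the cotangents description suggests, that one cotangent is supplied per output with matching shape; in jax.vjp the cotangents must be packed in the same tuple structure as the outputs.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    # Two outputs forming a tuple pytree: a vector and a scalar
    return x * y, jnp.sum(x)


x = jnp.array([1.0, 2.0])
y = jnp.array(3.0)

outs, vjp_fn = jax.vjp(g, x, y)
# One cotangent per output, matching each output's shape
ct = (jnp.array([1.0, 0.5]), jnp.array(2.0))
vjp_x, vjp_y = vjp_fn(ct)
```

Contributions from both outputs are summed: vjp_x is ct[0] * y + ct[1] = [5.0, 3.5], and vjp_y is sum(ct[0] * x) = 2.0.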

Raises
  • TypeError – invalid parameter types

  • ValueError – invalid parameter values

Example

import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit
def vjp(params, cotangent):
    def f(x):
        y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
        return jnp.stack(y)

    return catalyst.vjp(f, [params], [cotangent])

>>> x = jnp.array([0.1, 0.2])
>>> dy = jnp.array([-0.5, 0.1, 0.3])
>>> vjp(x, dy)
(array([0.09983342, 0.04      , 0.02      ]), (array([-0.43750208,  0.07      ]),))
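One way to sanity-check the output above, and to see what a step size such as h controls, is a finite-difference estimate of dy^T J. The sketch below uses NumPy in float64 to limit round-off; fd_vjp is a hypothetical helper, and its one-sided (forward) difference is an assumption, not necessarily the scheme used by Catalyst's "fd" method.

```python
import numpy as np


def fd_vjp(f, x, dy, h=1e-6):
    """Forward finite-difference estimate of dy^T J at x (illustrative only)."""
    fx = f(x)
    jac = np.empty((fx.size, x.size))
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        # Column i of the Jacobian: (f(x + h * e_i) - f(x)) / h
        jac[:, i] = (f(x + step) - fx) / h
    return dy @ jac


def f(x):
    return np.stack([np.sin(x[0]), x[1] ** 2, x[0] * x[1]])


x = np.array([0.1, 0.2])
dy = np.array([-0.5, 0.1, 0.3])
approx = fd_vjp(f, x, dy)  # close to [-0.43750208, 0.07]
```

A smaller h reduces truncation error but increases round-off error, which is why the step size is exposed as a parameter.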