
This module contains functions for binding JVPs or VJPs to the JAX interface.

See the JAX documentation on this process.

Basic examples:

import jax

def f(x):
    return x**2

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = tangents[0]
    print("in custom jvp function: ", x, dx)
    return x**2, 2*x*dx

registered_f_jvp = jax.custom_jvp(f)
registered_f_jvp.defjvp(f_and_jvp)

>>> jax.grad(registered_f_jvp)(jax.numpy.array(2.0))
in custom jvp function:  2.0 Traced<ShapedArray(float64[], weak_type=True):JaxprTrace(level=1/0)>
Array(4., dtype=float64, weak_type=True)
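The registered rule is not limited to jax.grad: jax.jvp calls it directly for forward mode. The following is a minimal self-contained sketch of that. It uses JAX's default float32 precision rather than the float64 output shown above, and the name f_custom and the input values are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**2

# Register a custom JVP rule for f; JAX calls it instead of differentiating f.
f_custom = jax.custom_jvp(f)

@f_custom.defjvp
def f_and_jvp(primals, tangents):
    (x,) = primals
    (dx,) = tangents
    return x**2, 2 * x * dx

# Forward mode: push the tangent 1.0 through f at x = 3.0.
primal_out, tangent_out = jax.jvp(f_custom, (jnp.array(3.0),), (jnp.array(1.0),))
print(primal_out, tangent_out)  # 9.0 6.0
```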

We can do something similar for the VJP as well:

def f_fwd(x):
    print("in forward pass: ", x)
    return f(x), x

def f_bwd(residual, dy):
    print("in backward pass: ", residual, dy)
    return (dy*2*residual,)

registered_f_vjp = jax.custom_vjp(f)
registered_f_vjp.defvjp(f_fwd, f_bwd)

>>> jax.grad(registered_f_vjp)(jax.numpy.array(2.0))
in forward pass:  2.0
in backward pass:  2.0 1.0
Array(4., dtype=float64, weak_type=True)
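For comparison, jax.vjp exposes the same forward/backward split explicitly. It runs the forward rule immediately and returns a pullback that calls the backward rule once it is given an output cotangent. This is a minimal sketch assuming default float32 precision; the name f_vjp and the input values are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**2

f_vjp = jax.custom_vjp(f)

def f_fwd(x):
    # Return the primal output and the residual needed by the backward pass.
    return f(x), x

def f_bwd(residual, dy):
    # Cotangent with respect to x, as a tuple with one entry per input.
    return (dy * 2 * residual,)

f_vjp.defvjp(f_fwd, f_bwd)

# jax.vjp runs f_fwd now and returns a pullback that calls f_bwd later.
y, pullback = jax.vjp(f_vjp, jnp.array(3.0))
(x_bar,) = pullback(jnp.array(1.0))
print(y, x_bar)  # 9.0 6.0
```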

JVP versus VJP:

When JAX can trace the product between the Jacobian and the tangents, it can transpose the JVP calculation into a VJP calculation. Through this process, JAX supports both JVP and VJP calculations even when only the JVP is registered.
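A minimal sketch of this transposition, using the illustrative function x * sin(x). Only a JVP rule is registered, and its Jacobian product is ordinary jax.numpy code. JAX can therefore still compute a reverse-mode Jacobian with jax.jacrev, and the result matches the forward-mode jax.jacfwd.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x) * x

@f.defjvp
def f_and_jvp(primals, tangents):
    (x,) = primals
    (dx,) = tangents
    # The Jacobian product is written in jax.numpy, so JAX can trace and transpose it.
    jac_diag = jnp.cos(x) * x + jnp.sin(x)
    return jnp.sin(x) * x, jac_diag * dx

x = jnp.array([0.5, 1.0, 2.0])
# Reverse mode works even though only the JVP was registered.
jac_rev = jax.jacrev(f)(x)
jac_fwd = jax.jacfwd(f)(x)
print(jnp.allclose(jac_rev, jac_fwd))  # True
```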

Unfortunately, compute_jvp() uses pure NumPy to perform the Jacobian product, so JAX cannot trace it.

For example, if we replace the definition of f_and_jvp from above with one that breaks tracing,

import pennylane as qml

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = qml.math.unwrap(tangents[0]) # This line breaks tracing
    return x**2, 2*x*dx

registered_f_jvp = jax.custom_jvp(f)
registered_f_jvp.defjvp(f_and_jvp)
>>> jax.grad(registered_f_jvp)(jax.numpy.array(2.0))
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

Note that the mention of JIT in this error message generally just means that the code cannot be traced.
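The failure can be reproduced with JAX alone. In this sketch numpy.asarray stands in for qml.math.unwrap, an assumption made so the example does not need PennyLane; both force the tangent into a concrete NumPy array. The exact exception type and message depend on the JAX version and may differ from the ValueError above.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return x**2

@f.defjvp
def f_and_jvp(primals, tangents):
    (x,) = primals
    # Stand-in for qml.math.unwrap: converts the (traced) tangent to NumPy.
    dx = np.asarray(tangents[0])
    return x**2, 2 * x * dx

try:
    jax.grad(f)(jnp.array(2.0))
    failed = False
except Exception as err:  # under jax.grad the tangent is a tracer, so conversion fails
    failed = True
    print(type(err).__name__)
```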

But if we used the VJP instead:

def f_bwd(residual, dy):
    dy = qml.math.unwrap(dy)
    return (dy*2*residual,)

registered_f_vjp = jax.custom_vjp(f)
registered_f_vjp.defvjp(f_fwd, f_bwd)

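We would be able to calculate the gradient without error, because outside of jax.jit the backward pass receives concrete values rather than tracers, as the printed output of f_bwd above shows.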
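The VJP counterpart can be checked the same way. Again numpy.asarray stands in for qml.math.unwrap as an assumption. Because jax.grad is not wrapped in jax.jit here, f_bwd receives a concrete cotangent and the conversion succeeds.

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x):
    return x**2

f_vjp = jax.custom_vjp(f)

def f_fwd(x):
    return f(x), x

def f_bwd(residual, dy):
    # Stand-in for qml.math.unwrap: outside jax.jit, dy is a concrete value,
    # so converting it to NumPy does not break differentiation.
    dy = np.asarray(dy)
    return (2 * residual * dy,)

f_vjp.defvjp(f_fwd, f_bwd)

grad = jax.grad(f_vjp)(jnp.array(2.0))
print(grad)  # 4.0
```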

Since registering the VJP gives access to jax.grad and jax.jacobian (an alias of jax.jacrev), we register the VJP whenever we have to choose between the VJP and the JVP.
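The flip side of this choice is that a function registered only through jax.custom_vjp has no forward-mode rule. In this minimal sketch jax.jacobian works, while jax.jvp raises a TypeError. jax.jacfwd, which is built on forward mode, is likewise unavailable for such a function.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**2

f_vjp = jax.custom_vjp(f)
f_vjp.defvjp(lambda x: (f(x), x), lambda res, dy: (2 * res * dy,))

x = jnp.array([1.0, 2.0, 3.0])
# Reverse mode works: jax.jacobian is reverse-mode (jax.jacrev).
jac = jax.jacobian(f_vjp)(x)
print(jac)

# Forward mode does not: a custom_vjp function has no JVP rule.
try:
    jax.jvp(f_vjp, (x,), (jnp.ones_like(x),))
    forward_ok = True
except TypeError as err:
    forward_ok = False
    print(err)
```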

Pytrees and Non-diff argnums:

The trainable arguments for the registered functions can be any valid pytree.

def f(x):
    return x['a']**2

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = tangents[0]
    print("in custom jvp function: ", x, dx)
    return x['a']**2, 2*x['a']*dx['a']

registered_f_jvp = jax.custom_jvp(f)
registered_f_jvp.defjvp(f_and_jvp)

>>> jax.grad(registered_f_jvp)({'a': jax.numpy.array(2.0)})
in custom jvp function:  {'a': Array(2., dtype=float64, weak_type=True)} {'a': Traced<ShapedArray(float64[], weak_type=True):JaxprTrace(level=1/0)>}
{'a': Array(4., dtype=float64, weak_type=True)}

As we can see here, the tangents are packed into the same pytree structure as the trainable arguments.
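The same correspondence holds when calling jax.jvp directly: the tangent argument must have the same pytree structure as the primal. This is a minimal sketch using default float32 precision; the dictionary key and values are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return x["a"] ** 2

@f.defjvp
def f_and_jvp(primals, tangents):
    (x,) = primals
    (dx,) = tangents
    # dx is a dict with the same keys as x.
    return x["a"] ** 2, 2 * x["a"] * dx["a"]

params = {"a": jnp.array(3.0)}
# The tangent passed to jax.jvp mirrors the structure of params.
out, tangent_out = jax.jvp(f, (params,), ({"a": jnp.array(1.0)},))
print(out, tangent_out)  # 9.0 6.0

# The gradient comes back with the same dict structure as params.
grad_params = jax.grad(f)(params)
print(grad_params)
```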

Currently, QuantumScript is a valid pytree most of the time. Once it is a valid pytree all of the time and can store tangents in place of the variables, we can use a batch of tapes as our trainable argument. Until then, the tapes must be passed as a non-pytree, non-differentiable argument that accompanies the tree leaves.
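One JAX mechanism for passing such an argument is the nondiff_argnums option of jax.custom_jvp, which keeps an argument out of the differentiable inputs and hands it to the JVP rule first. This is a minimal sketch, not this module's implementation: the frozen dataclass Spec and the function execute are hypothetical stand-ins for a batch of tapes and an execution function.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Spec:
    """Hypothetical non-pytree argument standing in for a batch of tapes."""
    power: int

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def execute(spec, params):
    return params["a"] ** spec.power

@execute.defjvp
def execute_jvp(spec, primals, tangents):
    # Non-differentiable arguments come first; only params gets a tangent.
    (params,) = primals
    (dparams,) = tangents
    a, da = params["a"], dparams["a"]
    return a**spec.power, spec.power * a ** (spec.power - 1) * da

grads = jax.grad(execute, argnums=1)(Spec(power=3), {"a": jnp.array(2.0)})
print(grads)  # dict whose "a" entry is 12.0
```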



Check all parameters in each tape and output the name of the suitable JAX interface.

jax_jvp_execute(tapes, execute_fn, jpc[, device])

Execute a batch of tapes with JAX parameters using JVP derivatives.

set_parameters_on_copy_and_unwrap(tapes, params)

Copy a set of tapes with operations and set parameters