qml.workflow.interfaces.jax_jit

This module contains functions for binding JVPs or VJPs to JAX when using JIT.

For information on registering VJPs and JVPs, please see the module documentation for jax.py.

Under JAX JIT, arrays are abstract tracers. We therefore cannot convert them to NumPy arrays or act on their concrete values, except inside a host callback passed to jax.pure_callback.

For example:

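>>> import jax
>>> import pennylane as qml
>>> jax.config.update("jax_enable_x64", True)  # the float64 output below assumes 64-bit mode
>>> def f(x):
...     return qml.math.unwrap(x)
>>> x = jax.numpy.array(1.0)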
>>> jax.jit(f)(x)
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
>>> def g(x):
...     expected_output_shape = jax.ShapeDtypeStruct((), jax.numpy.float64)
...     return jax.pure_callback(f, expected_output_shape, x)
>>> jax.jit(g)(x)
Array(1., dtype=float64)

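Note that jax.pure_callback requires the expected output shape and dtype (here a jax.ShapeDtypeStruct), because JAX must know the abstract type of the result while tracing, before the callback ever runs.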
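The expected output does not have to be hard-coded. A minimal sketch, assuming a hypothetical host-side function host_square that uses NumPy, builds the jax.ShapeDtypeStruct from the traced input's own shape and dtype, so the same callback works for inputs of any size:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_square(x):
    # Hypothetical host function: runs outside the trace on a concrete array,
    # so plain NumPy calls are allowed here.
    x = np.asarray(x)
    return np.square(x).astype(x.dtype)


@jax.jit
def square(x):
    # Derive the declared result shape/dtype from the (abstract) input;
    # the host function must return an array matching this spec.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, result_spec, x)


y = square(jnp.arange(3.0))
```

Because the spec tracks the input, square(jnp.ones((2, 2))) works without changing host_square.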

Functions

jax_jit_jvp_execute(tapes, execute_fn, jpc, ...)

Execute a batch of tapes with JAX parameters using JVP derivatives.
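The general JVP pattern can be sketched with jax.custom_jvp combined with jax.pure_callback. This is a hypothetical, simplified illustration rather than PennyLane's implementation: host_execute and host_execute_and_jacobian stand in for device execution (here just an elementwise cos), whereas the real function operates on tapes, an execute_fn, and a jpc.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical host-side stand-ins for running circuits on a device.
def host_execute(params):
    return np.cos(np.asarray(params))


def host_execute_and_jacobian(params):
    params = np.asarray(params)
    # cos acts elementwise, so its Jacobian is diagonal; return only the diagonal.
    return np.cos(params), -np.sin(params)


@jax.custom_jvp
def execute(params):
    spec = jax.ShapeDtypeStruct(params.shape, params.dtype)
    return jax.pure_callback(host_execute, spec, params)


@execute.defjvp
def execute_jvp(primals, tangents):
    (params,), (dparams,) = primals, tangents
    spec = jax.ShapeDtypeStruct(params.shape, params.dtype)
    # One host call returns both the results and the Jacobian (a tuple of specs).
    results, jac_diag = jax.pure_callback(
        host_execute_and_jacobian, (spec, spec), params
    )
    # Contract with the tangent inside JAX, keeping the output linear in dparams.
    return results, jac_diag * dparams


grads = jax.jit(jax.grad(lambda p: execute(p).sum()))(jnp.array([0.0, 0.5]))
```

Returning the Jacobian from the callback and doing the tangent contraction in JAX is what lets jax.grad transpose the rule. A callback that consumed dparams directly would generally support forward mode only, since JAX cannot transpose a host callback.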

jax_jit_vjp_execute(tapes, execute_fn, jpc)

Execute a batch of tapes with JAX parameters using VJP derivatives.
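The VJP counterpart can be sketched with jax.custom_vjp, where the backward pass sends the cotangent to the host. As above, this is a hypothetical illustration, not PennyLane's code: host_execute and host_vjp are stand-ins for device execution and a host-side VJP computation (again for an elementwise cos), and _spec is a local helper introduced only for this sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical host-side stand-ins for device execution and a host-side VJP.
def host_execute(params):
    return np.cos(np.asarray(params))


def host_vjp(params, dy):
    params, dy = np.asarray(params), np.asarray(dy)
    # Cotangent times the diagonal Jacobian of cos.
    return -np.sin(params) * dy


def _spec(x):
    # Hypothetical helper: declared result shape/dtype matching x.
    return jax.ShapeDtypeStruct(x.shape, x.dtype)


@jax.custom_vjp
def execute(params):
    return jax.pure_callback(host_execute, _spec(params), params)


def execute_fwd(params):
    # Keep the parameters as residuals for the backward pass.
    return execute(params), params


def execute_bwd(params, dy):
    # custom_vjp expects a tuple with one cotangent per primal input.
    return (jax.pure_callback(host_vjp, _spec(params), params, dy),)


execute.defvjp(execute_fwd, execute_bwd)
```

A jax.custom_vjp function supports reverse mode only, so jax.grad works on execute while jax.jvp does not.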