qml.workflow.interfaces.jax_jit
This module contains functions for binding JVPs or VJPs to JAX when using JIT.
For information on registering VJPs and JVPs, please see the module documentation for jax.py.
When using JAX-JIT, we cannot convert arrays to NumPy or act on their concrete values without using jax.pure_callback.
For example:
>>> import jax
>>> import pennylane as qml
>>> jax.config.update("jax_enable_x64", True)  # needed for the float64 output below
>>> def f(x):
...     return qml.math.unwrap(x)
>>> x = jax.numpy.array(1.0)
>>> jax.jit(f)(x)
Traceback (most recent call last):
...
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
>>> def g(x):
...     expected_output_shape = jax.ShapeDtypeStruct((), jax.numpy.float64)
...     return jax.pure_callback(f, expected_output_shape, x)
>>> jax.jit(g)(x)
Array(1., dtype=float64)
Note that jax.pure_callback must be given the expected output shape and dtype (here a jax.ShapeDtypeStruct), because JAX cannot run the callback while tracing to discover them.
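pure_callback lets a host-side computation run inside jax.jit, but on its own it is not differentiable; a derivative rule has to be bound to it. The following is a minimal sketch of binding a JVP under JIT with jax.custom_jvp, not this module's implementation (which works on tapes and device-computed derivatives). The names host_sin, host_cos and f are hypothetical: np.sin stands in for a host-side execution and np.cos for its derivative. 64-bit mode is enabled so the NumPy results match the declared float64 dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 so NumPy's float64 results match the declared output dtype.
jax.config.update("jax_enable_x64", True)


def host_sin(x):
    # Hypothetical host-side "execution": runs on concrete NumPy values, outside the trace.
    return np.sin(np.asarray(x))


def host_cos(x):
    # Hypothetical host-side derivative of host_sin.
    return np.cos(np.asarray(x))


@jax.custom_jvp
def f(x):
    # The output shape/dtype must be declared, since the callback is not traced.
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, shape, x)


@f.defjvp
def f_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sin, shape, x)
    # Elementwise JVP: derivative evaluated at x, times the tangent (linear in dx).
    dy = jax.pure_callback(host_cos, shape, x) * dx
    return y, dy


x = jnp.array([0.0, 1.0])
print(jax.jit(f)(x))
print(jax.jit(jax.grad(lambda v: f(v).sum()))(x))
```

Because the tangent output is a primal-dependent factor times dx, JAX can also transpose the rule, so jax.grad works on top of the JVP.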
Functions
Execute a batch of tapes with JAX parameters using JVP derivatives.
Execute a batch of tapes with JAX parameters using VJP derivatives.
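The VJP variant binds a vector-Jacobian product instead of a JVP. As a rough sketch under stated assumptions (not this module's code), the same idea can be expressed with jax.custom_vjp and pure_callback. Here host_execute and host_vjp are hypothetical stand-ins for a host-side execution of x**2 and its VJP, both computed with NumPy outside the trace.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 so NumPy's float64 results match the declared output dtype.
jax.config.update("jax_enable_x64", True)


def host_execute(x):
    # Hypothetical host-side "execution": elementwise x**2 computed with NumPy.
    return np.asarray(x) ** 2


def host_vjp(x, dy):
    # Hypothetical host-side VJP of x**2: the Jacobian is diag(2x), so VJP = 2x * dy.
    return 2 * np.asarray(x) * np.asarray(dy)


@jax.custom_vjp
def execute(x):
    return jax.pure_callback(host_execute, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


def execute_fwd(x):
    # Forward pass: return the output and keep x as the residual for the backward pass.
    return execute(x), x


def execute_bwd(x, dy):
    # Backward pass: compute the VJP on the host; dx has the same shape/dtype as x.
    dx = jax.pure_callback(host_vjp, jax.ShapeDtypeStruct(x.shape, x.dtype), x, dy)
    # custom_vjp expects one cotangent per primal input, as a tuple.
    return (dx,)


execute.defvjp(execute_fwd, execute_bwd)

x = jnp.array([1.0, 2.0, 3.0])
loss = lambda v: jnp.sum(execute(v))
print(jax.jit(loss)(x))
print(jax.jit(jax.grad(loss))(x))
```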