qml.workflow.interfaces.jax
This module contains functions for binding JVP’s or VJP’s to the JAX interface.
See the JAX documentation on custom derivative rules for a full description of this process.
Basic examples:
import jax

jax.config.update("jax_enable_x64", True)  # the outputs below assume float64 precision

def f(x):
    return x**2

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = tangents[0]
    print("in custom jvp function: ", x, dx)
    return x**2, 2*x*dx

registered_f_jvp = jax.custom_jvp(f)
registered_f_jvp.defjvp(f_and_jvp)
>>> jax.grad(registered_f_jvp)(jax.numpy.array(2.0))
in custom jvp function: 2.0 Traced<ShapedArray(float64[], weak_type=True):JaxprTrace(level=1/0)>
Array(4., dtype=float64, weak_type=True)
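We can also call the registered function through jax.jvp directly, which evaluates both the primal and the tangent outputs (a minimal sketch using the setup above):

>>> jax.jvp(registered_f_jvp, (jax.numpy.array(2.0),), (jax.numpy.array(1.0),))
in custom jvp function: 2.0 1.0
(Array(4., dtype=float64, weak_type=True), Array(4., dtype=float64, weak_type=True))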
We can do something similar for the VJP as well:
def f_fwd(x):
    print("in forward pass: ", x)
    return f(x), x

def f_bwd(residual, dy):
    print("in backward pass: ", residual, dy)
    return (dy*2*residual,)

registered_f_vjp = jax.custom_vjp(f)
registered_f_vjp.defvjp(f_fwd, f_bwd)
>>> jax.grad(registered_f_vjp)(jax.numpy.array(2.0))
in forward pass: 2.0
in backward pass: 2.0 1.0
Array(4., dtype=float64, weak_type=True)
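The registered VJP can also be applied in two explicit steps via jax.vjp, which returns the primal output together with a function that performs the backward pass (a minimal sketch):

>>> out, vjp_fun = jax.vjp(registered_f_vjp, jax.numpy.array(2.0))
in forward pass: 2.0
>>> vjp_fun(jax.numpy.array(1.0))
in backward pass: 2.0 1.0
(Array(4., dtype=float64, weak_type=True),)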
JVP versus VJP:
When JAX can trace the product between the Jacobian and the cotangents, it can turn the JVP calculation into a VJP calculation. Through this process, JAX can support both JVP and VJP calculations by registering only the JVP.
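This is exactly what happened in the first example: jax.grad is a reverse-mode transform, yet registering only the JVP sufficed because the tangent expression 2*x*dx is traceable and linear in dx. The transposition JAX performs internally can be illustrated with jax.linear_transpose, shown here on the tangent map at x = 2.0 (a minimal sketch; the helper f_tangent is ours, not part of this module):

f_tangent = lambda dx: 2 * jax.numpy.array(2.0) * dx  # linear tangent map at x = 2.0

>>> transpose_fn = jax.linear_transpose(f_tangent, jax.numpy.array(1.0))
>>> transpose_fn(jax.numpy.array(1.0))  # apply the transposed map to a cotangent
(Array(4., dtype=float64, weak_type=True),)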
Unfortunately, compute_jvp() uses pure numpy to perform the Jacobian product, so it cannot be traced by JAX.
For example, suppose we replace the definition of f_and_jvp from above with one that breaks tracing:

import pennylane as qml

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = qml.math.unwrap(tangents[0])  # converting to numpy breaks tracing
    return x**2, 2*x*dx

registered_f_jvp.defjvp(f_and_jvp)  # re-register the new rule
>>> jax.grad(registered_f_jvp)(jax.numpy.array(2.0))
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
Note that while the error message mentions JIT, the underlying problem is more general: JAX is unable to trace the code.
But if we used the VJP instead:
def f_bwd(residual, dy):
    dy = qml.math.unwrap(dy)  # safe under jax.grad: the cotangent dy is concrete here
    return (dy*2*residual,)
We would be able to calculate the gradient without error.
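Re-registering the VJP with this backward pass reproduces the gradient without error (a sketch; the redefined f_bwd no longer prints, and wrapping the call in jax.jit would fail again, since the cotangents would then be traced):

registered_f_vjp = jax.custom_vjp(f)
registered_f_vjp.defvjp(f_fwd, f_bwd)

>>> jax.grad(registered_f_vjp)(jax.numpy.array(2.0))
in forward pass: 2.0
Array(4., dtype=float64)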
Since the VJP calculation offers access to both jax.grad and jax.jacobian, we register the VJP when we have to choose between the two.
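For example, jax.jacobian (an alias of the reverse-mode jax.jacrev) works out of the box on the VJP-registered function from above:

>>> jax.jacobian(registered_f_vjp)(jax.numpy.array(2.0))
in forward pass: 2.0
Array(4., dtype=float64)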
Pytrees and Non-diff argnums:
The trainable arguments for the registered functions can be any valid pytree.
def f(x):
    return x['a']**2

def f_and_jvp(primals, tangents):
    x = primals[0]
    dx = tangents[0]
    print("in custom jvp function: ", x, dx)
    return x['a']**2, 2*x['a']*dx['a']

registered_f_jvp = jax.custom_jvp(f)
registered_f_jvp.defjvp(f_and_jvp)
>>> jax.grad(registered_f_jvp)({'a': jax.numpy.array(2.0)})
in custom jvp function: {'a': Array(2., dtype=float64, weak_type=True)} {'a': Traced<ShapedArray(float64[], weak_type=True):JaxprTrace(level=1/0)>}
{'a': Array(4., dtype=float64, weak_type=True)}
As we can see here, the tangents are packed into the same pytree structure as the trainable arguments.
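While not used in the examples above, non-differentiable arguments can be declared at registration time with nondiff_argnums; they are then passed to the rule ahead of the primals and tangents and receive no tangents. A minimal sketch with a hypothetical two-argument function g:

def g(x, n):
    return x**n

def g_and_jvp(n, primals, tangents):
    # n is marked non-differentiable, so it arrives first and has no tangent
    x = primals[0]
    dx = tangents[0]
    return x**n, n * x**(n - 1) * dx

registered_g_jvp = jax.custom_jvp(g, nondiff_argnums=(1,))
registered_g_jvp.defjvp(g_and_jvp)

>>> jax.grad(registered_g_jvp)(jax.numpy.array(2.0), 3)
Array(12., dtype=float64, weak_type=True)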
Currently, QuantumScript is a valid pytree most of the time. Once it is a valid pytree all of the time and can store tangents in place of the variables, we can use a batch of tapes as our trainable argument. Until then, the tapes must be a non-pytree, non-differentiable argument that accompanies the tree leaves.
Functions

get_jax_interface_name: Check all parameters in each tape and output the name of the suitable JAX interface.

jax_jvp_execute: Execute a batch of tapes with JAX parameters using JVP derivatives.

set_parameters_on_copy_and_unwrap: Copy a set of tapes with operations and set parameters.