JAX interface¶
Born out of the autograd package, JAX is the next generation of differentiable functional computation, adding support for powerful hardware accelerators like GPUs and TPUs via XLA. To use PennyLane in combination with JAX, we have to generate JAX-compatible quantum nodes. A basic QNode can be translated into a quantum node that interfaces with JAX by using the interface='jax' flag in the QNode decorator.
Note
When using diff_method="parameter-shift", diff_method="finite-diff", or diff_method="adjoint" with the JAX interface, some restrictions apply to the measurements in the QNode:

Sample and probability measurements cannot be mixed with other measurement types in QNodes;
Multiple probability measurements need to have the same number of wires specified.

However, when using diff_method="backprop", all QNode measurement statistics are supported.
Note
To use the JAX interface in PennyLane, you must first install jax and jaxlib. You can then import PennyLane and JAX as follows:
import pennylane as qml
import jax
import jax.numpy as jnp
Note
JAX uses single-precision numbers by default. To enable double precision, add the following code on startup:
jax.config.update("jax_enable_x64", True)
Construction via the decorator¶
The QNode decorator is the recommended way to create a JAX-capable QNode in PennyLane. Simply specify the interface='jax' keyword argument:
dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev, interface='jax')
def circuit1(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RY(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.PhaseShift(theta, wires=0)
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.Hadamard(1))
The QNode circuit1() is now a JAX-capable QNode, accepting jax.Array objects as input and returning jax.Array objects. It can now be used like any other JAX function:
>>> phi = jnp.array([0.5, 0.1])
>>> theta = jnp.array(0.2)
>>> circuit1(phi, theta)
(Array(0.87758256, dtype=float64), Array(0.68803733, dtype=float64))
The interface can also be determined automatically when the QNode is called: if you call the QNode with JAX parameters, you do not need to pass the interface argument explicitly.
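For instance, the following sketch (an illustrative QNode, here named circuit2) omits the interface argument entirely; the JAX interface is inferred from the jnp.array inputs when the QNode is called:

dev = qml.device('default.qubit', wires=2)

# Illustrative example: no interface argument is given; the interface is
# inferred from the JAX arrays passed at call time.
@qml.qnode(dev)
def circuit2(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RY(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

result = circuit2(jnp.array([0.5, 0.1]), jnp.array(0.2))  # returns a jax.Array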
Quantum gradients using JAX¶
Since a JAX-interfacing QNode acts like any other JAX-interfacing Python function, the standard JAX methods for calculating gradients can be used. For example:
dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev, interface='jax')
def circuit3(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RY(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.PhaseShift(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

phi = jnp.array([0.5, 0.1])
theta = jnp.array(0.2)

grads = jax.grad(circuit3, argnums=(0, 1))
phi_grad, theta_grad = grads(phi, theta)
This produces the output:
>>> phi_grad
Array([-0.47942555, 0. ], dtype=float32)
>>> theta_grad
Array(-3.4332792e-10, dtype=float32)
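If both the value and the gradient are required, jax.value_and_grad can be used in the same way; the snippet below is a short illustrative sketch reusing circuit3 from above:

# Compute the expectation value and its gradients in a single call.
value_and_grads = jax.value_and_grad(circuit3, argnums=(0, 1))
value, (phi_grad, theta_grad) = value_and_grads(phi, theta)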
Using jax.jit on QNodes¶
To fully utilize the power and speed of JAX, you'll need to just-in-time compile your functions, a process called "jitting". If only expectation values or variances are returned, the @jax.jit decorator can be applied directly to the QNode.
dev = qml.device('default.qubit', wires=2)

@jax.jit  # QNode calls will now be jitted, and should run faster.
@qml.qnode(dev, interface='jax')
def circuit4(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RZ(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))
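Calling the jitted QNode looks the same as before; the first call triggers compilation, and subsequent calls with the same argument shapes reuse the compiled program (the values below are illustrative):

phi = jnp.array([0.5, 0.1])
theta = jnp.array(0.2)

# The first call compiles the circuit; later calls with the same shapes are fast.
result = circuit4(phi, theta)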
Note
For differentiation methods other than backprop, when interface='jax' is specified, PennyLane will attempt to determine if the computation was just-in-time compiled. This is done by checking if any of the input parameters were subject to a JAX transformation. If so, a variant of the interface that supports the just-in-time compilation of QNodes will be used. This is equivalent to passing interface='jax-jit'.
Computing the Jacobian of vector-valued QNodes is not supported with the JAX JIT interface. The output of vector-valued QNodes can, however, be used in the definition of scalar-valued cost functions whose gradients can be computed.

Specify interface='jax-python' to enforce support for computing the backward pass of vector-valued QNodes (e.g., QNodes with probability, state, or multiple expectation value measurements). This option does not support just-in-time compilation.
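As a minimal sketch of this option (the circuit and diff_method below are illustrative), a probability-returning QNode created with interface='jax-python' can be differentiated with jax.jacobian:

dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev, interface='jax-python', diff_method='parameter-shift')
def prob_circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.probs(wires=[0, 1])

# Jacobian of the four output probabilities with respect to x.
jac = jax.jacobian(prob_circuit)(jnp.array(0.3))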
Randomness: Shots and Samples¶
In JAX, there is no such thing as stateful randomness: all random number generators must be explicitly seeded. (See the JAX random package documentation for more details.)
When simulations include randomness (i.e., if the device has a finite shots value, or the QNode returns qml.sample()), the JAX device requires a jax.random.PRNGKey. Usually, PennyLane automatically handles this for you. However, if you wish to use jitting with randomness, both the QNode and the device need to be created in the context of the jax.jit decorator. This can be achieved by wrapping device and QNode creation into a function decorated by @jax.jit:
Example:
import jax
import pennylane as qml

@jax.jit
def sample_circuit(phi, theta, key):
    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit', wires=2, prng_key=key, shots=100)

    @qml.qnode(dev, interface='jax', diff_method=None)
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta, wires=0)
        return qml.sample()  # Here, we take samples instead.

    return circuit(phi, theta)

# Get the samples from the jitted method.
samples = sample_circuit([0.2, 1.0], 5.2, jax.random.PRNGKey(0))
Note
If you don’t pass a PRNGKey when sampling with a jax.jit
, every call to the sample function
will return the same result.
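For example, passing distinct keys yields independent samples, while reusing a key reproduces the earlier result (this sketch assumes the sample_circuit defined above):

key1 = jax.random.PRNGKey(0)
key2 = jax.random.PRNGKey(1)

# Different keys give statistically independent samples.
samples_a = sample_circuit([0.2, 1.0], 5.2, key1)
samples_b = sample_circuit([0.2, 1.0], 5.2, key2)

# Reusing the same key reproduces the same samples.
samples_c = sample_circuit([0.2, 1.0], 5.2, key1)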
Optimization using JAXopt and Optax¶
To optimize your hybrid classical-quantum model using the JAX interface, you must make use of a package meant for optimizing JAX code (such as JAXopt or Optax) or your own custom JAX optimizer. The PennyLane optimizers cannot be used with the JAX interface.
As an example of using JAXopt, the GradientDescent optimizer may be used to optimize a QNode that is transformed by jax.jit:
import pennylane as qml
import jax
import jaxopt

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=1, shots=None)

@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

gd = jaxopt.GradientDescent(energy, maxiter=5)
res = gd.run(0.5)
optimized_params = res.params
>>> optimized_params
Array(3.1415861, dtype=float64, weak_type=True)
Alternatively, optimizers from Optax may also be used to optimize the same QNode:
import pennylane as qml
from jax import numpy as jnp
import jax
import optax

learning_rate = 0.15

dev = qml.device("default.qubit", wires=1, shots=None)

@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(learning_rate)
params = jnp.array(0.5)
opt_state = optimizer.init(params)

for _ in range(200):
    grads = jax.grad(energy)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
>>> params
Array(3.14159111, dtype=float64)