Quick Start¶
Catalyst enables justintime (JIT) and aheadoftime (AOT) compilation of quantum programs and workflows, while taking into account both classical and quantum code, and ultimately leverages modern compilation tools to speed up quantum applications.
You can imagine compiling a function once in advance and then benefit from faster
execution on all subsequent calls of the function, similar to the jax.jit
functionality.
However, compared to JAX we are also able to compile the quantum code natively without having
to rely on callbacks to any Pythonbased PennyLane devices. We can thus compile/execute entire workflows
(such as variational algorithms) as a single program or unit, without having to go back and forth between
device execution and the Python interpreter.
Importing Catalyst and PennyLane¶
The first thing we need to do is import qjit()
and QJIT compatible methods in Catalyst,
as well as PennyLane and the version of NumPy
provided by JAX.
from catalyst import qjit, measure, cond, for_loop, while_loop, grad
import pennylane as qml
from jax import numpy as jnp
Constructing the QNode¶
You should be able to express your quantum functions in the way you are accustomed to using PennyLane. However, some of PennyLane’s features may not be fully supported yet, such as optimizers.
Warning
Not all PennyLane devices currently work with Catalyst. Supported backend devices include
lightning.qubit
, lightning.kokkos
, and braket.aws.qubit
. For
a full of supported devices, please see Supported devices.
PennyLane tapes are still used internally by Catalyst and you can express your circuits in the way you are used to, as long as you ensure that all operations are added to the main tape.
Let’s start learning more about Catalyst by running a simple circuit.
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
qml.Hadamard(wires=0)
qml.RX(theta, wires=1)
qml.CNOT(wires=[0,1])
return qml.expval(qml.PauliZ(wires=1))
In PennyLane, the qml.qnode()
decorator creates a device specific quantum function. For each quantum
function, we can specify the number of wires.
The qjit()
decorator can be used to jit a workflow of quantum functions:
jitted_circuit = qjit(circuit)
>>> jitted_circuit(0.7)
Array(0., dtype=float64)
In Catalyst, dynamic wire values are fully supported for operations, observables and measurements. For example, the following circuit can be jitted with wires as arguments:
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=5))
def circuit(arg0, arg1, arg2):
qml.RX(arg0, wires=[arg1 + 1])
qml.RY(arg0, wires=[arg2])
qml.CNOT(wires=[arg1, arg2])
return qml.probs(wires=[arg1 + 1])
>>> circuit(jnp.pi / 3, 1, 2)
Array([0.625, 0.375], dtype=float64)
Operations¶
Catalyst allows you to use quantum operations
available in PennyLane either via native support by the runtime or PennyLane’s decomposition rules.
The qml.adjoint()
and qml.ctrl()
functions in
PennyLane are also supported via the decomposition mechanism in Catalyst. For example,
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
qml.Rot(0.3, 0.4, 0.5, wires=0)
qml.adjoint(qml.SingleExcitation(jnp.pi / 3, wires=[0, 1]))
return qml.state()
In addition, you can qjit most PennyLane templates to easily construct and evaluate more complex quantum circuits.
Important
Decomposition will generally happen in accordance with the specification provided by devices,
which can vary from device to device (e.g., default.qubit
and lightning.qubit
might
decompose quite differently).
However, Catalyst’s decomposition logic will differ in the following cases:
All
qml.Controlled
operations will decompose toqml.QubitUnitary
operations.The set of operations supported by Catalyst itself can in some instances lead to additional decompositions compared to the device itself.
Observables¶
The Catalyst has support for PennyLane observables.
For example, the following circuit is a QJIT compatible function that calculates the expectation value of
a tensor product of a qml.PauliX
, qml.Hadamard
and qml.Hermitian
observables.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y):
qml.RX(x, 0)
qml.RX(y, 1)
qml.CNOT([0, 2])
qml.CNOT([1, 2])
h_matrix = jnp.array(
[[complex(1.0, 0.0), complex(2.0, 0.0)],
[complex(2.0, 0.0), complex(1.0, 0.0)]]
)
return qml.expval(qml.PauliX(0) @ qml.Hadamard(1) @ qml.Hermitian(h_matrix, 2))
Measurements¶
Most PennyLane measurement processes are supported in Catalyst, although not all features are supported for all measurement types.
The expectation value of observables is supported analytically as well as with finiteshots. 

The variance of observables is supported analytically as well as with finiteshots. 

Samples in the computational basis only are supported. 

Sample counts in the computational basis only are supported. 

The probabilities is supported in the computational basis as well as with finiteshots. 

The state in the computational basis only is supported. 

The projective midcircuit measurement is supported via its own operation in Catalyst. 
For both qml.sample()
and qml.counts()
omitting the wires
parameters produces samples on all declared qubits in the same format as in PennyLane.
Counts are returned a bit differently, namely as a pair of arrays representing a dictionary from basis states to the number of observed samples. We thus have to do a bit of extra work to display them nicely. Note that the basis states are represented in their equivalent binary integer representation, inside of a float data type. This way they are compatible with eigenvalue sampling, but this may change in the future.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def counts():
qml.Rot(0.1, 0.2, 0.3, wires=[0])
return qml.counts(wires=[0])
basis_states, counts = counts()
>>> {format(int(state), '01b'): count for state, count in zip(basis_states, counts)}
{'0': 985, '1': 15}
You can specify the number of shots to be used in samplebased measurements when you create a device.
qml.sample()
and qml.counts()
will
automatically use the device’s shots
parameter when performing measurements.
In the following example, the number of shots is set to \(500\) in the device instantiation.
Note
You can return any combination of measurement processes as a tuple from quantum functions. In addition, Catalyst allows you to return any classical values computed inside quantum functions as well.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3, shots=500))
def circuit(params):
qml.RX(params[0], wires=0)
qml.RX(params[1], wires=1)
qml.RZ(params[2], wires=2)
return (
qml.sample(),
qml.counts(),
qml.expval(qml.PauliZ(0)),
qml.var(qml.PauliZ(0)),
qml.probs(wires=[0, 1]),
qml.state(),
)
>>> circuit([0.3, 0.5, 0.7])
(Array([[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]], dtype=int64),
(Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int64),
Array([453, 0, 31, 0, 16, 0, 0, 0], dtype=int64)),
Array(0.936, dtype=float64),
Array(0.138816, dtype=float64),
Array([0.926, 0.048, 0.026, 0. ], dtype=float64),
Array([ 0.899949660.32850727j, 0. +0.j ,
0.083881680.22979488j, 0. +0.j ,
0.049649020.13601409j, 0. +0.j ,
0.0347301 +0.01267748j, 0. +0.j ], dtype=complex128))
The PennyLane projective midcircuit measurement is also supported in Catalyst.
measure()
is a QJIT compatible midcircuit measurement for Catalyst that only
requires a list of wires that the measurement process acts on.
Important
The qml.measure()
function is not QJIT compatible and measure()
from Catalyst should be used instead:
from catalyst import measure
In the following example, m
will be equal to True
if wire \(0\) is rotated by \(180\) degrees.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x):
qml.RX(x, wires=0)
m = measure(wires=0)
return m
>>> circuit(jnp.pi)
Array(True, dtype=bool)
>>> circuit(0.0)
Array(False, dtype=bool)
Compilation Modes¶
In Catalyst, there are two ways of compiling quantum functions depending on when the compilation is triggered.
Justintime¶
In justintime (JIT), the compilation is triggered at the call site the first time
the quantum function is executed. For example, circuit
is compiled as early as the first call.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
qml.Hadamard(wires=0)
qml.RX(theta, wires=1)
qml.CNOT(wires=[0,1])
return qml.expval(qml.PauliZ(wires=1))
>>> circuit(0.5) # the first call, compilation occurs here
Array(0., dtype=float64)
>>> circuit(0.5) # the precompiled quantum function is called
Array(0., dtype=float64)
Aheadoftime¶
An alternative is to trigger the compilation without specifying any concrete values for the function
parameters. This works by specifying the argument signature right in the function definition, which
will trigger compilation “aheadoftime” (AOT) before the program is executed. We can use both builtin
Python scalar types, as well as the special ShapedArray
type that JAX uses to represent the shape
and data type of a tensor:
from jax.core import ShapedArray
@qjit # compilation happens at definition
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
theta = jnp.abs(x)
qml.RY(theta, wires=0)
qml.Rot(z[0], z[1], z[2], wires=0)
return qml.state()
>>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9])) # calls precompiled function
Array([0.756349050.52801002j, 0. +0.j,
0.35962678+0.14074839j, 0. +0.j], dtype=complex128)
At this stage the compilation already happened, so the execution of circuit
calls the compiled function directly on
the first call, resulting in faster initial execution. Note that implicit type promotion for most datatypes are allowed
in the compilation as long as it doesn’t lead to a loss of data.
Compiling with Control Flow¶
Catalyst has support for natively compiled control flow as “firstclass” components of any quantum program, providing a much smaller representation and compilation time for large circuits, and also enabling the compilation of arbitrarily parametrized circuits.
Catalystprovided control flow operations:
A 

A 

A 
Note
Catalyst supports automatic conversion of native Python control flow to the Catalyst control flow operations. For more details, see the AutoGraph guide.
Conditionals¶
cond()
is a functional version of the traditional ifelse conditional for Catalyst.
This means that each execution path, a True
branch and a False
branch, is provided as a
separate function. Both functions will be traced during compilation, but only one of them the will be
executed at runtime, depending of the value of a Boolean predicate. The JAX equivalent is the
jax.lax.cond
function, but this version is optimized to work with quantum programs in PennyLane.
Note that cond()
can also be used outside of the qjit()
context for better interoperability with PennyLane.
Values produced inside the scope of a conditional can be returned to the outside context, but
the return type signature of each branch must be identical. If no values are returned, the
False
branch is optional. Refer to the example below to learn more about the syntax of this
decorator.
@cond(predicate: bool)
def conditional_fn():
# do something when the predicate is true
return "optionally return some value"
@conditional_fn.otherwise
def conditional_fn():
# optionally define an alternative execution path
return "if provided, return types need to be identical in both branches"
ret_val = conditional_fn() # must invoke the defined function
Warning
The conditional functions can only return JAX compatible data types.
Loops¶
for_loop()
and while_loop()
are functional versions of the traditional for and
whileloop for Catalyst. That is, any variables that are modified across iterations need to be
provided as inputs and outputs to the loop body function. Input arguments contain the value of a
variable at the start of an iteration, while output arguments contain the value at the end of the
iteration. The outputs are then fed back as inputs to the next iteration. The final iteration values
are also returned from the transformed function.
for_loop()
and while_loop()
can also be interpreted without needing to compile its surrounding context.
The forloop statement:
The for_loop()
executes a fixed number of iterations as indicated via the values specified
in its header: a lower_bound
, an upper_bound
, and a step
size.
The loop body function must always have the iteration index (in the below example i
) as its
first argument and its value can be used arbitrarily inside the loop body. As the value of the index
across iterations is handled automatically by the provided loop bounds, it must not be returned from
the body function.
@for_loop(lower_bnd, upper_bnd, step)
def loop_body(i, *args):
# code to be executed over index i starting
# from lower_bnd to upper_bnd  1 by step
return args
final_args = loop_body(init_args)
The semantics of for_loop()
are given by the following Python implementation:
for i in range(lower_bnd, upper_bnd, step):
args = body_fn(i, *args)
The whileloop statement:
The while_loop()
, on the other hand, is able to execute an arbitrary number of iterations,
until the condition function specified in its header returns False
.
The loop condition is evaluated every iteration and can be any callable with an identical signature as the loop body function. The return type of the condition function must be a Boolean.
@while_loop(lambda *args: "some condition")
def loop_body(*args):
# perform some work and update (some of) the arguments
return args
final_args = loop_body(init_args)
Calculating Quantum Gradients¶
Catalystprovided gradient operations:
A 

A 

A 

A 

A 
grad()
is a QJIT compatible grad decorator in Catalyst that can differentiate a hybrid quantum function
using finitedifference, parametershift, or adjointjacobian methods. See the documentation for more details.
This decorator accepts the function to differentiate, a differentiation strategy, and the argument indices of the function with which to differentiate:
@qjit
def workflow(x):
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x):
qml.RX(jnp.pi * x, wires=0)
return qml.expval(qml.PauliY(0))
g = grad(circuit)
return g(x)
>>> workflow(2.0)
Array(3.14159265, dtype=float64)
To specify the differentiation strategy, the method
argument can be passed
to the grad
function:
method="auto"
: Quantum components of the hybrid function are differentiated according to the corresponding QNodediff_method
, while the classical computation is differentiated using traditional autodiff.With this strategy, Catalyst only currently supports QNodes with
diff_method="parametershift"
anddiff_method="adjoint"
.method="fd"
: Firstorder finitedifferences for the entire hybrid function. Thediff_method
argument for each QNode is ignored.
Currently, higherorder differentiation is only supported by the finitedifference method. The gradient of circuits with QJIT compatible control flow is supported for all methods in Catalyst.
You can further provide the step size (h
value) of finitedifference in the grad()
method.
For example, the gradient call to differentiate circuit
with respect to its second argument using
finitedifference and h
value \(0.1\) should be:
g_fd = grad(circuit, method="fd", argnums=1, h=0.1)
Gradients of quantum functions can be calculated for a range or tensor of parameters.
For example, grad(circuit, argnums=[0, 1])
would calculate the gradient of
circuit
using the finitedifference method for the first and second parameters.
In addition, the gradient of the following circuit with a tensor of parameters is
also feasible.
@qjit
def workflow(params):
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(params):
qml.RX(params[0] * params[1], wires=0)
return qml.expval(qml.PauliY(0))
return grad(circuit, argnums=0)(params)
>>> workflow(jnp.array([2.0, 3.0]))
Array([2.88051099, 1.92034063], dtype=float64)
The grad()
decorator works for functions that return a scalar value. You can also use the jacobian()
decorator to compute Jacobian matrices of general hybrid functions with multiple or multivariate results.
@qjit
def workflow(x):
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x):
qml.RX(jnp.pi * x[0], wires=0)
qml.RY(x[1], wires=0)
return qml.probs()
g = jacobian(circuit, method="auto")
return g(x)
>>> workflow(jnp.array([2.0, 1.0]))
Array([[ 3.48786850e16 4.20735492e01]
[8.71967125e17 4.20735492e01]], dtype=float64)
This decorator has the same methods and API as grad
. See the documentation for more details.
Optimizers¶
You can develop your own optimization algorithm using the grad()
method, controlflow operators that are
compatible with QJIT, or by utilizing differentiable optimizers in Optax.
Warning
Catalyst currently does not provide any optimization tools and does not support the optimizers offered by PennyLane. However, this feature is planned for future implementation.
For example, you can use optax.sgd
in a QJIT workflow to calculate
the gradient descent optimizer. The following example shows a simple use case of this
feature in Catalyst.
The optax.sgd
gets a smooth function of the form gd_fun(params, *args, **kwargs)
and calculates either just the value or both the value and gradient of the function depending on
the value of value_and_grad
argument. To optimize params iteratively, you later need to use
jax.lax.fori_loop
to loop over the gradient descent steps.
import optax
from jax.lax import fori_loop
dev = qml.device("lightning.qubit", wires=1)
@qml.qnode(dev)
def circuit(param):
qml.Hadamard(0)
qml.RY(param, wires=0)
return qml.expval(qml.PauliZ(0))
@qjit
def workflow():
def gd_fun(param):
diff = grad(circuit, argnums=0)
return circuit(param), diff(param)
opt = optax.sgd(learning_rate=0.4)
def gd_update(i, args):
param, state = args
_, gradient = gd_fun(param)
(updates, state) = opt.update(gradient, state)
param = optax.apply_updates(param, updates)
return (param, state)
param = 0.1
state = opt.init(param)
(param, _) = fori_loop(0, 100, gd_update, (param, state))
return param
>>> workflow()
Array(1.57079633, dtype=float64)
JAX Integration¶
Catalyst programs can also be used inside of a larger JAX workflow which uses JIT compilation, automatic differentiation, and other JAX transforms.
Note
Note that, in general, best performance will be seen when the Catalyst
@qjit
decorator is used to JIT the entire hybrid workflow. However, there
may be cases where you may want to delegate only the quantum part of your
workflow to Catalyst, and let JAX handle classical components (for example,
due to missing a feature or compatibility issue in Catalyst).
For example, call a Catalyst qjitcompiled function from within a JAX jitcompiled function:
dev = qml.device("lightning.qubit", wires=1)
@qjit
@qml.qnode(dev)
def circuit(x):
qml.RX(jnp.pi * x[0], wires=0)
qml.RY(x[1] ** 2, wires=0)
qml.RX(x[1] * x[2], wires=0)
return qml.probs(wires=0)
@jax.jit
def cost_fn(weights):
x = jnp.sin(weights)
return jnp.sum(jnp.cos(circuit(x)) ** 2)
>>> cost_fn(jnp.array([0.1, 0.2, 0.3]))
Array(1.32269195, dtype=float64)
Catalystcompiled functions can now also be automatically differentiated via JAX, both in forward and reverse mode to firstorder,
>>> jax.grad(cost_fn)(jnp.array([0.1, 0.2, 0.3]))
Array([0.49249037, 0.05197949, 0.02991883], dtype=float64)
as well as vectorized using jax.vmap
:
>>> jax.vmap(cost_fn)(jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]))
Array([1.32269195, 1.53905377], dtype=float64)