Quick Start

Catalyst enables just-in-time (JIT) and ahead-of-time (AOT) compilation of quantum programs and workflows, while taking into account both classical and quantum code, and ultimately leverages modern compilation tools to speed up quantum applications.

You can imagine compiling a function once in advance and then benefiting from faster execution on all subsequent calls of the function, similar to the jax.jit functionality. However, compared to JAX, we are also able to compile the quantum code natively without having to rely on callbacks to any Python-based PennyLane devices. We can thus compile and execute entire workflows (such as variational algorithms) as a single program or unit, without having to go back and forth between device execution and the Python interpreter.

Importing Catalyst and PennyLane

The first thing we need to do is import qjit() and QJIT compatible methods in Catalyst, as well as PennyLane and the version of NumPy provided by JAX.

from catalyst import qjit, measure, cond, for_loop, while_loop, grad, jacobian
import pennylane as qml
from jax import numpy as jnp

Constructing the QNode

You should be able to express your quantum functions in the way you are accustomed to using PennyLane. However, some of PennyLane’s features may not be fully supported yet, such as optimizers.


The supported backend devices are currently lightning.qubit, lightning.kokkos, braket.local.qubit, braket.aws.qubit, and oqc.cloud; support for more devices is planned.

PennyLane tapes are still used internally by Catalyst and you can express your circuits in the way you are used to, as long as you ensure that all operations are added to the main tape.

Let’s start learning more about Catalyst by running a simple circuit.

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.RX(theta, wires=1)
    return qml.expval(qml.PauliZ(wires=1))

In PennyLane, the qml.qnode() decorator creates a device specific quantum function. For each quantum function, we can specify the number of wires.

The qjit() decorator can be used to jit a workflow of quantum functions:

jitted_circuit = qjit(circuit)
>>> jitted_circuit(0.7)

In Catalyst, dynamic wire values are fully supported for operations, observables and measurements. For example, the following circuit can be jitted with wires as arguments:

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=5))
def circuit(arg0, arg1, arg2):
    qml.RX(arg0, wires=[arg1 + 1])
    qml.RY(arg0, wires=[arg2])
    qml.CNOT(wires=[arg1, arg2])
    return qml.probs(wires=[arg1 + 1])
>>> circuit(jnp.pi / 3, 1, 2)
array([0.625, 0.375])


Catalyst allows you to use quantum operations available in PennyLane either via native support by the runtime or PennyLane’s decomposition rules. The qml.adjoint() and qml.ctrl() functions in PennyLane are also supported via the decomposition mechanism in Catalyst. For example,

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
    qml.Rot(0.3, 0.4, 0.5, wires=0)
    qml.adjoint(qml.SingleExcitation(jnp.pi / 3, wires=[0, 1]))
    return qml.state()

In addition, you can jit most of PennyLane templates to easily construct and evaluate more complex quantum circuits; see below for the list of currently supported operations and templates.


Most decomposition logic will be equivalent to PennyLane’s decomposition. However, decomposition logic will differ in the following cases:

  1. All qml.Controlled operations will decompose to qml.QubitUnitary operations.

  2. qml.ControlledQubitUnitary operations will decompose to qml.QubitUnitary operations.

  3. The list of device-supported gates employed by Catalyst currently differs from that of the lightning.qubit device, as defined by the QJITDevice.


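Supported operations:

  • qml.Identity: The Identity operator
  • qml.PauliX: The Pauli X operator
  • qml.PauliY: The Pauli Y operator
  • qml.PauliZ: The Pauli Z operator
  • qml.Hadamard: The Hadamard operator
  • qml.S: The single-qubit phase gate
  • qml.T: The single-qubit T gate
  • qml.PhaseShift: Arbitrary single qubit local phase shift
  • qml.RX: The single qubit X rotation
  • qml.RY: The single qubit Y rotation
  • qml.RZ: The single qubit Z rotation
  • qml.CNOT: The controlled-NOT operator
  • qml.CY: The controlled-Y operator
  • qml.CZ: The controlled-Z operator
  • qml.SWAP: The swap operator
  • qml.IsingXX: Ising XX coupling gate
  • qml.IsingYY: Ising YY coupling gate
  • qml.IsingXY: Ising (XX + YY) coupling gate
  • qml.IsingZZ: Ising ZZ coupling gate
  • qml.ControlledPhaseShift: A qubit controlled phase shift.
  • qml.CRX: The controlled-RX operator
  • qml.CRY: The controlled-RY operator
  • qml.CRZ: The controlled-RZ operator
  • qml.CRot: The controlled-Rot operator
  • qml.CSWAP: The controlled-swap operator
  • qml.MultiRZ: Arbitrary multi Z rotation.
  • qml.QubitUnitary: Apply an arbitrary unitary matrix with a dimension that is a power of two.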


Supported templates:

  • qml.AllSinglesDoubles: Builds a quantum circuit to prepare correlated states of molecules by applying all SingleExcitation and DoubleExcitation operations to the initial Hartree-Fock state.
  • qml.AmplitudeEmbedding: Encodes \(2^n\) features into the amplitude vector of \(n\) qubits.
  • qml.AngleEmbedding: Encodes \(N\) features into the rotation angles of \(n\) qubits, where \(N \leq n\).
  • qml.ApproxTimeEvolution: Applies the Trotterized time-evolution operator for an arbitrary Hamiltonian, expressed in terms of Pauli gates.
  • qml.ArbitraryStatePreparation: Implements an arbitrary state preparation on the specified wires.
  • qml.BasicEntanglerLayers: Layers consisting of one-parameter single-qubit rotations on each qubit, followed by a closed chain or ring of CNOT gates.
  • qml.BasisEmbedding: Encodes \(n\) binary features into a basis state of \(n\) qubits.
  • qml.BasisStatePreparation: Prepares a basis state on the given wires using a sequence of Pauli-X gates.
  • qml.broadcast: Applies a unitary multiple times to a specific pattern of wires.
  • qml.FermionicDoubleExcitation: Circuit to exponentiate the tensor product of Pauli matrices representing the double-excitation operator entering the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.
  • qml.FermionicSingleExcitation: Circuit to exponentiate the tensor product of Pauli matrices representing the single-excitation operator entering the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.
  • qml.FlipSign: Flips the sign of a given basis state.
  • qml.GateFabric: Implements a local, expressive, and quantum-number-preserving ansatz proposed by Anselmetti et al. (2021).
  • qml.GroverOperator: Performs the Grover Diffusion Operator.
  • qml.IQPEmbedding: Encodes \(n\) features into \(n\) qubits using diagonal gates of an IQP circuit.
  • qml.kUpCCGSD: Implements the k-Unitary Pair Coupled-Cluster Generalized Singles and Doubles (k-UpCCGSD) ansatz.
  • qml.MERA: The MERA template broadcasts an input circuit across many wires following the architecture of a multi-scale entanglement renormalization ansatz tensor network.
  • qml.MottonenStatePreparation: Prepares an arbitrary state on the given wires using a decomposition into gates developed by Möttönen et al. (2004).
  • qml.MPS: The MPS template broadcasts an input circuit across many wires following the architecture of a Matrix Product State tensor network.
  • qml.Permute: Applies a permutation to a set of wires.
  • qml.QAOAEmbedding: Encodes \(N\) features into \(n>N\) qubits, using a layered, trainable quantum circuit that is inspired by the QAOA ansatz proposed by Killoran et al. (2020).
  • qml.QFT: Apply a quantum Fourier transform (QFT).
  • qml.QuantumMonteCarlo: Performs the quantum Monte Carlo estimation algorithm.
  • qml.QuantumPhaseEstimation: Performs the quantum phase estimation circuit.
  • qml.RandomLayers: Layers of randomly chosen single qubit rotations and 2-qubit entangling gates, acting on randomly chosen qubits.
  • qml.SimplifiedTwoDesign: Layers consisting of a simplified 2-design architecture of Pauli-Y rotations and controlled-Z entanglers proposed in Cerezo et al. (2021).
  • qml.StronglyEntanglingLayers: Layers consisting of single qubit rotations and entanglers, inspired by the circuit-centric classifier design arXiv:1804.00633.
  • qml.TTN: The TTN template broadcasts an input circuit across many wires following the architecture of a tree tensor network.
  • qml.UCCSD: Implements the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.


Catalyst supports PennyLane observables.

For example, the following circuit is a QJIT compatible function that calculates the expectation value of a tensor product of qml.PauliX, qml.Hadamard, and qml.Hermitian observables.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RX(y, 1)
    qml.CNOT([0, 2])
    qml.CNOT([1, 2])
    h_matrix = jnp.array(
        [[complex(1.0, 0.0), complex(2.0, 0.0)],
        [complex(2.0, 0.0), complex(-1.0, 0.0)]]
    )
    return qml.expval(qml.PauliX(0) @ qml.Hadamard(1) @ qml.Hermitian(h_matrix, 2))


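Supported observables:

  • qml.Identity: The Identity operator
  • qml.PauliX: The Pauli X operator
  • qml.PauliY: The Pauli Y operator
  • qml.PauliZ: The Pauli Z operator
  • qml.Hadamard: The Hadamard operator
  • qml.Hermitian: An arbitrary Hermitian observable.
  • qml.Hamiltonian: alias of pennylane.ops.op_math.linear_combination.LinearCombination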


Most PennyLane measurement processes are supported in Catalyst, although not all features are supported for all measurement types.


  • qml.expval(): The expectation value of observables is supported analytically as well as with finite shots.
  • qml.var(): The variance of observables is supported analytically as well as with finite shots.
  • qml.sample(): Samples are supported in the computational basis only.
  • qml.counts(): Sample counts are supported in the computational basis only.
  • qml.probs(): Probabilities are supported in the computational basis, including with finite shots.
  • qml.state(): The state is supported in the computational basis only.
  • measure(): The projective mid-circuit measurement is supported via its own operation in Catalyst.

For both qml.sample() and qml.counts(), omitting the wires parameter produces samples on all declared qubits in the same format as in PennyLane.

Counts are returned a bit differently, namely as a pair of arrays representing a dictionary from basis states to the number of observed samples. We thus have to do a bit of extra work to display them nicely. Note that the basis states are represented in their equivalent binary integer representation, inside of a float data type. This way they are compatible with eigenvalue sampling, but this may change in the future.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def counts():
    qml.Rot(0.1, 0.2, 0.3, wires=[0])
    return qml.counts(wires=[0])
basis_states, counts = counts()
>>> {format(int(state), '01b'): count for state, count in zip(basis_states, counts)}
{'0': 985, '1': 15}
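
When more than one wire is measured, the same decoding works if each bitstring is zero-padded to the number of measured wires. The helper below is a plain-Python sketch of that idea. The name counts_to_dict, its num_wires parameter, and the sample arrays are all invented here for illustration and are not part of Catalyst.

```python
def counts_to_dict(basis_states, counts, num_wires):
    """Map each basis state (an integer stored as a float) to a bitstring key."""
    return {
        format(int(state), f"0{num_wires}b"): int(count)
        for state, count in zip(basis_states, counts)
    }


# Made-up arrays in the (basis_states, counts) layout described above.
example_states = [0.0, 1.0, 2.0, 3.0]
example_counts = [480, 20, 0, 500]
decoded = counts_to_dict(example_states, example_counts, num_wires=2)
print(decoded)
```

The sketch only handles formatting; it makes no claim about which bit corresponds to which wire.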

You can specify the number of shots to be used in sample-based measurements when you create a device. qml.sample() and qml.counts() will automatically use the device’s shots parameter when performing measurements. In the following example, the number of shots is set to \(500\) in the device instantiation.


You can return any combination of measurement processes as a tuple from quantum functions. In addition, Catalyst allows you to return any classical values computed inside quantum functions as well.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3, shots=500))
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RX(params[1], wires=1)
    qml.RZ(params[2], wires=2)
    return (
        qml.sample(),
        qml.counts(),
        qml.probs(wires=[0, 1]),
        qml.state(),
    )
>>> circuit([0.3, 0.5, 0.7])
[array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
array([0., 1., 2., 3., 4., 5., 6., 7.]),
array([458,   7,  35,   0,   0,   0,   0,   0]),
array([0.91782642, 0.05984182, 0.02096486, 0.0013669 ]),
array([ 0.89994966-0.32850727j,  0.        +0.j        ,
        -0.08388168-0.22979488j,  0.        +0.j        ,
        -0.04964902-0.13601409j,  0.        +0.j        ,
        -0.0347301 +0.01267748j,  0.        +0.j        ])]

The PennyLane projective mid-circuit measurement is also supported in Catalyst. measure() is a QJIT compatible mid-circuit measurement for Catalyst that only requires a list of wires that the measurement process acts on.


The qml.measure() function is not QJIT compatible and measure() from Catalyst should be used instead:

from catalyst import measure

In the following example, m will be equal to True if wire \(0\) is rotated by \(180\) degrees.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x):
    qml.RX(x, wires=0)
    m = measure(wires=0)
    return m
>>> circuit(jnp.pi)
>>> circuit(0.0)
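
The statement about m follows from the fact that RX(x) applied to a qubit in |0> yields the outcome 1 with probability sin^2(x/2), which is 1 at x = pi and 0 at x = 0. The plain-Python check below evaluates that formula directly. It does not run Catalyst, and prob_one is an illustrative name.

```python
import math


def prob_one(x):
    """Probability of measuring 1 after RX(x) acts on a qubit in |0>."""
    return math.sin(x / 2) ** 2


for angle in (0.0, math.pi / 2, math.pi):
    print(f"x = {angle:.4f} -> P(m is True) = {prob_one(angle):.4f}")
```

For intermediate angles such as pi / 2 the outcome is random, so m can differ from one execution to the next.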

Compilation Modes

In Catalyst, there are two ways of compiling quantum functions depending on when the compilation is triggered.


In just-in-time (JIT) mode, compilation is triggered at the call site the first time the quantum function is executed. For example, circuit is compiled on its first call.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.RX(theta, wires=1)
    return qml.expval(qml.PauliZ(wires=1))
>>> circuit(0.5)  # the first call, compilation occurs here
>>> circuit(0.5)  # the precompiled quantum function is called


An alternative is to trigger the compilation without specifying any concrete values for the function parameters. This works by specifying the argument signature right in the function definition, which will trigger compilation “ahead-of-time” (AOT) before the program is executed. We can use both builtin Python scalar types, as well as the special ShapedArray type that JAX uses to represent the shape and data type of a tensor:

from jax.core import ShapedArray

@qjit  # compilation happens at definition
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
    theta = jnp.abs(x)
    qml.RY(theta, wires=0)
    qml.Rot(z[0], z[1], z[2], wires=0)
    return qml.state()
>>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9]))  # calls precompiled function
array([0.75634905-0.52801002j, 0. +0.j,
   0.35962678+0.14074839j, 0. +0.j])

At this stage the compilation has already happened, so the first call to circuit runs the compiled function directly, resulting in faster initial execution. Note that implicit type promotion is allowed for most data types during compilation, as long as it doesn't lead to a loss of data.
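
The difference between the two modes comes down to when the compile step runs. The toy decorator below models only that timing in plain Python: it "compiles" at definition when every parameter is annotated, and otherwise on the first call. The names toy_jit, jit_style, and aot_style are hypothetical, and nothing here reflects how Catalyst is actually implemented.

```python
import inspect


def toy_jit(fn):
    """Toy model of JIT vs. AOT timing. NOT how Catalyst compiles anything."""
    params = list(inspect.signature(fn).parameters.values())
    fully_annotated = bool(params) and all(
        p.annotation is not inspect.Parameter.empty for p in params
    )
    state = {"compiled": False, "compile_count": 0}

    def compile_once():
        if not state["compiled"]:
            state["compiled"] = True
            state["compile_count"] += 1

    if fully_annotated:
        compile_once()  # ahead-of-time: the signature is known at definition

    def wrapper(*args):
        compile_once()  # just-in-time: the first call triggers compilation
        return fn(*args)

    wrapper.state = state
    return wrapper


@toy_jit
def jit_style(x):
    return 2 * x


@toy_jit
def aot_style(x: float):
    return 2 * x


print(jit_style.state["compiled"], aot_style.state["compiled"])  # False True
jit_style(1.0)
print(jit_style.state["compiled"])  # True
```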

Compiling with Control Flow

Catalyst has support for natively compiled control flow as “first-class” components of any quantum program, providing a much smaller representation and compilation time for large circuits, and also enabling the compilation of arbitrarily parametrized circuits.

Catalyst-provided control flow operations:


  • cond(): A qjit() compatible decorator for if-else conditionals in PennyLane/Catalyst.
  • for_loop(): A qjit() compatible for-loop decorator for PennyLane/Catalyst.
  • while_loop(): A qjit() compatible while-loop decorator for PennyLane/Catalyst.


Catalyst supports automatic conversion of native Python control flow to the Catalyst control flow operations. For more details, see the AutoGraph guide.


cond() is a functional version of the traditional if-else conditional for Catalyst. This means that each execution path, a True branch and a False branch, is provided as a separate function. Both functions will be traced during compilation, but only one of them will be executed at runtime, depending on the value of a Boolean predicate. The JAX equivalent is the jax.lax.cond function, but this version is optimized to work with quantum programs in PennyLane.

Note that cond() can also be used outside of the qjit() context for better interoperability with PennyLane.

Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical. If no values are returned, the False branch is optional. Refer to the example below to learn more about the syntax of this decorator.

@cond(predicate)  # predicate: a Boolean value
def conditional_fn():
    # do something when the predicate is true
    return "optionally return some value"

@conditional_fn.otherwise
def conditional_fn():
    # optionally define an alternative execution path
    return "if provided, return types need to be identical in both branches"

ret_val = conditional_fn()  # must invoke the defined function


The conditional functions can only return JAX compatible data types.
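
As a rough mental model, the runtime behaviour of a functional conditional can be written in plain Python as follows. The helper names cond_reference and pick_angle are made up for this sketch. Unlike cond(), it does not trace both branches and does not enforce matching return types.

```python
def cond_reference(predicate, true_fn, false_fn=None):
    """Plain-Python model of a functional if-else (illustrative only)."""
    if predicate:
        return true_fn()
    if false_fn is not None:
        return false_fn()
    return None  # the False branch was omitted, so nothing is returned


def pick_angle(flag):
    # Both branches return a float, mirroring the identical-signature rule.
    return cond_reference(flag, lambda: 3.14, lambda: 0.0)


print(pick_angle(True), pick_angle(False))  # 3.14 0.0
```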


for_loop() and while_loop() are functional versions of the traditional for- and while-loop for Catalyst. That is, any variables that are modified across iterations need to be provided as inputs and outputs to the loop body function. Input arguments contain the value of a variable at the start of an iteration, while output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration. The final iteration values are also returned from the transformed function.

for_loop() and while_loop() can also be interpreted without needing to compile their surrounding context.

The for-loop statement:

The for_loop() executes a fixed number of iterations as indicated via the values specified in its header: a lower_bound, an upper_bound, and a step size.

The loop body function must always have the iteration index (in the below example i) as its first argument and its value can be used arbitrarily inside the loop body. As the value of the index across iterations is handled automatically by the provided loop bounds, it must not be returned from the body function.

@for_loop(lower_bnd, upper_bnd, step)
def loop_body(i, *args):
    # code to be executed over index i starting
    # from lower_bnd to upper_bnd - 1 by step
    return args

final_args = loop_body(init_args)

The semantics of for_loop() are given by the following Python implementation:

for i in range(lower_bnd, upper_bnd, step):
    args = body_fn(i, *args)
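
Wrapped into a complete function, the same semantics can be tried out with two loop-carried values. This is a plain-Python sketch. The names for_loop_reference and accumulate are illustrative and not part of Catalyst.

```python
def for_loop_reference(lower_bnd, upper_bnd, step, body_fn, *init_args):
    """Plain-Python model: thread loop-carried values through each iteration."""
    args = init_args
    for i in range(lower_bnd, upper_bnd, step):
        result = body_fn(i, *args)
        # Normalize a single return value to a 1-tuple so it can be unpacked again.
        args = result if isinstance(result, tuple) else (result,)
    return args


def accumulate(i, total, count):
    # The index i is supplied by the loop and is not returned.
    return total + i, count + 1


total, count = for_loop_reference(0, 10, 2, accumulate, 0, 0)
print(total, count)  # 20 5
```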

The while-loop statement:

The while_loop(), on the other hand, is able to execute an arbitrary number of iterations, until the condition function specified in its header returns False.

The loop condition is evaluated every iteration and can be any callable with the same signature as the loop body function. The condition function must return a Boolean.

@while_loop(lambda *args: "some condition")
def loop_body(*args):
    # perform some work and update (some of) the arguments
    return args

final_args = loop_body(init_args)
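
By analogy with the for-loop, the semantics of while_loop() can be modelled in plain Python as shown below. This is a sketch rather than the library's implementation. The names while_loop_reference and halve are invented for the example.

```python
def while_loop_reference(cond_fn, body_fn, *init_args):
    """Plain-Python model: iterate until the condition function returns False."""
    args = init_args
    while cond_fn(*args):
        result = body_fn(*args)
        # Normalize a single return value to a 1-tuple so it can be unpacked again.
        args = result if isinstance(result, tuple) else (result,)
    return args


def halve(value, steps):
    # Both loop-carried values are passed in and returned on every iteration.
    return value / 2, steps + 1


value, steps = while_loop_reference(lambda value, steps: value > 1.0, halve, 20.0, 0)
print(value, steps)  # 0.625 5
```

Note that the condition receives exactly the same arguments as the body, which is why both callables share a signature.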

Calculating Quantum Gradients

Catalyst-provided gradient operations:


  • grad(): A qjit() compatible gradient transformation for PennyLane/Catalyst.
  • jacobian(): A qjit() compatible Jacobian transformation for PennyLane/Catalyst.
  • vjp(): A qjit() compatible Vector-Jacobian product for PennyLane/Catalyst.
  • jvp(): A qjit() compatible Jacobian-vector product for PennyLane/Catalyst.

grad() is a QJIT compatible grad decorator in Catalyst that can differentiate a hybrid quantum function using finite-difference, parameter-shift, or adjoint-jacobian methods. See the documentation for more details.

This decorator accepts the function to differentiate, a differentiation strategy, and the indices of the arguments with respect to which to differentiate:

@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)
>>> workflow(2.0)

To specify the differentiation strategy, the method argument can be passed to the grad function:

  • method="auto": Quantum components of the hybrid function are differentiated according to the corresponding QNode diff_method, while the classical computation is differentiated using traditional autodiff.

    With this strategy, Catalyst currently supports only QNodes with diff_method="parameter-shift" and diff_method="adjoint".

  • method="fd": First-order finite-differences for the entire hybrid function. The diff_method argument for each QNode is ignored.
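
For intuition about the parameter-shift method mentioned above: for a gate generated by a Pauli operator, such as RX, evaluating the expectation value at two shifted angles gives the exact derivative. The sketch below uses the closed form <Y> = -sin(theta) for RX(theta) acting on |0> in place of a circuit execution. The function names are illustrative, and the sketch says nothing about how Catalyst implements the rule internally.

```python
import math


def expval_y(theta):
    """Closed-form <Y> after RX(theta) acts on |0>; stands in for a circuit run."""
    return -math.sin(theta)


def parameter_shift(f, theta, shift=math.pi / 2):
    """Two-term parameter-shift rule for a gate generated by a Pauli operator."""
    return (f(theta + shift) - f(theta - shift)) / (2 * math.sin(shift))


theta = 0.7
print(parameter_shift(expval_y, theta), -math.cos(theta))
```

Both printed numbers agree up to floating-point rounding, because the rule is exact rather than an approximation.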

Currently, higher-order differentiation is only supported by the finite-difference method. The gradient of circuits with QJIT compatible control flow is supported for all methods in Catalyst.

You can further provide the step size (h-value) of finite-difference in the grad() method. For example, the gradient call to differentiate circuit with respect to its second argument using finite-difference and h-value \(0.1\) is:

g_fd = grad(circuit, method="fd", argnum=1, h=0.1)
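
The effect of the h-value can be seen with a plain-Python finite-difference sketch. It assumes a forward difference (this section does not say which variant Catalyst uses) and replaces the earlier single-argument circuit by its closed form, <Y> = -sin(pi * x). The function names are illustrative.

```python
import math


def expval_y(x):
    """Closed-form <Y> after RX(pi * x) acts on |0>; stands in for the circuit."""
    return -math.sin(math.pi * x)


def forward_difference(f, x, h):
    """First-order forward finite difference with step size h."""
    return (f(x + h) - f(x)) / h


exact = -math.pi * math.cos(math.pi * 2.0)
for h in (0.1, 1e-3, 1e-6):
    approx = forward_difference(expval_y, 2.0, h)
    print(f"h = {h:g}: approx = {approx:.6f}, error = {abs(approx - exact):.2e}")
```

A large step such as 0.1 gives a visibly biased estimate, while very small steps eventually run into floating-point rounding.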

Gradients of quantum functions can be calculated for a range or tensor of parameters. For example, grad(circuit, method="fd", argnum=[0, 1]) would calculate the gradient of circuit using the finite-difference method for the first and second parameters. In addition, the gradient of the following circuit with a tensor of parameters is also feasible.

@qjit
def workflow(params):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(params):
        qml.RX(params[0] * params[1], wires=0)
        return qml.expval(qml.PauliY(0))

    return grad(circuit, argnum=0)(params)
>>> workflow(jnp.array([2.0, 3.0]))
array([-2.88051099, -1.92034063])
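
This output can be cross-checked by hand. Since <Y> after RX(theta) on |0> equals -sin(theta), the circuit computes f(p0, p1) = -sin(p0 * p1), whose gradient is (-p1 * cos(p0 * p1), -p0 * cos(p0 * p1)). The short check below evaluates that closed form in plain Python. The function name is illustrative.

```python
import math


def analytic_gradient(p0, p1):
    """Gradient of f(p0, p1) = -sin(p0 * p1) with respect to (p0, p1)."""
    common = -math.cos(p0 * p1)
    return (p1 * common, p0 * common)


print(analytic_gradient(2.0, 3.0))
```

The closed-form values agree with the output above to about six decimal places.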

The grad() decorator works for functions that return a scalar value. You can also use the jacobian() decorator to compute Jacobian matrices of general hybrid functions with multiple or multivariate results.

@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x[0], wires=0)
        qml.RY(x[1], wires=0)
        return qml.probs()

    g = jacobian(circuit, method="auto")
    return g(x)
>>> workflow(jnp.array([2.0, 1.0]))
array([[ 3.48786850e-16, -4.20735492e-01],
       [-8.71967125e-17,  4.20735492e-01]])

This decorator has the same methods and API as grad. See the documentation for more details.


You can develop your own optimization algorithm using the grad() method, control-flow operators that are compatible with QJIT, or by utilizing differentiable optimizers in JAXopt.


Catalyst currently does not provide any optimization tools and does not support the optimizers offered by PennyLane. However, this feature is planned for future implementation.

For example, you can use jaxopt.GradientDescent in a QJIT workflow to run gradient descent. The following example shows a simple use case of this feature in Catalyst.

The jaxopt.GradientDescent optimizer takes a smooth function of the form gd_fun(params, *args, **kwargs) and calculates either just the value or both the value and gradient of the function, depending on the value of the value_and_grad argument. To optimize params iteratively, you then use jax.lax.fori_loop to loop over the gradient descent steps.

import jaxopt
from jax.lax import fori_loop

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(param):
    qml.RY(param, wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit
def workflow():
    def gd_fun(param):
        diff = grad(circuit, argnum=0)
        return circuit(param), diff(param)

    opt = jaxopt.GradientDescent(gd_fun, stepsize=0.4, value_and_grad=True)

    def gd_update(i, args):
        (param, state) = opt.update(*args)
        return (param, state)

    param = 0.1
    state = opt.init_state(param)
    (param, _) = fori_loop(0, 100, gd_update, (param, state))
    return param
>>> workflow()
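
To see what such an optimization loop does to this cost function, the following plain-Python sketch runs fixed-step gradient descent on the closed form <Z> = cos(param) for RY(param) acting on |0>. It is a simplified model with made-up helper names. jaxopt.GradientDescent may apply additional techniques such as acceleration, so its iterates need not match these.

```python
import math


def cost(param):
    """Closed-form <Z> after RY(param) acts on |0>."""
    return math.cos(param)


def cost_grad(param):
    """Derivative of the closed-form cost with respect to param."""
    return -math.sin(param)


def gradient_descent(param, stepsize=0.4, num_steps=100):
    """Plain fixed-step gradient descent, with no acceleration or line search."""
    for _ in range(num_steps):
        param = param - stepsize * cost_grad(param)
    return param


final_param = gradient_descent(0.1)
print(final_param, cost(final_param))
```

Starting from 0.1, the parameter moves towards pi, where the expectation value reaches its minimum of -1.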

JAX Integration

Catalyst programs can also be used inside of a larger JAX workflow which uses JIT compilation, automatic differentiation, and other JAX transforms.


Note that, in general, best performance will be seen when the Catalyst @qjit decorator is used to JIT the entire hybrid workflow. However, there may be cases where you may want to delegate only the quantum part of your workflow to Catalyst, and let JAX handle classical components (for example, due to missing a feature or compatibility issue in Catalyst).

For example, call a Catalyst qjit-compiled function from within a JAX jit-compiled function:

import jax

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x):
  qml.RX(jnp.pi * x[0], wires=0)
  qml.RY(x[1] ** 2, wires=0)
  qml.RX(x[1] * x[2], wires=0)
  return qml.probs(wires=0)

@jax.jit
def cost_fn(weights):
  x = jnp.sin(weights)
  return jnp.sum(jnp.cos(circuit(x)) ** 2)
>>> cost_fn(jnp.array([0.1, 0.2, 0.3]))
Array(1.32269195, dtype=float64)

Catalyst-compiled functions can now also be automatically differentiated via JAX, both in forward and reverse mode to first-order,

>>> jax.grad(cost_fn)(jnp.array([0.1, 0.2, 0.3]))
Array([0.49249037, 0.05197949, 0.02991883], dtype=float64)

as well as vectorized using jax.vmap:

>>> jax.vmap(cost_fn)(jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]))
Array([1.32269195, 1.53905377], dtype=float64)

In particular, this allows for a reduction in boilerplate when using JAX-compatible optimizers such as jaxopt:

>>> opt = jaxopt.GradientDescent(cost_fn)
>>> params = jnp.array([0.1, 0.2, 0.3])
>>> (final_params, _) = jax.jit(opt.run)(params)
>>> final_params
Array([-0.00320799,  0.03475223,  0.29362844], dtype=float64)