Quick Start

Catalyst enables just-in-time (JIT) and ahead-of-time (AOT) compilation of quantum programs and workflows, while taking into account both classical and quantum code, and ultimately leverages modern compilation tools to speed up quantum applications.

You can imagine compiling a function once in advance and then benefiting from faster execution on all subsequent calls, similar to the jax.jit functionality. However, unlike JAX, Catalyst can also compile the quantum code natively, without relying on callbacks to Python-based PennyLane devices. Entire workflows (such as variational algorithms) can thus be compiled and executed as a single program, without going back and forth between device execution and the Python interpreter.

Importing Catalyst and PennyLane

The first thing we need to do is import qjit() and the QJIT compatible functions used in this guide from Catalyst, as well as PennyLane, JAX, and the version of NumPy provided by JAX.

import jax
from catalyst import qjit, measure, cond, for_loop, while_loop, grad, jacobian
import pennylane as qml
from jax import numpy as jnp

Constructing the QNode

You should be able to express your quantum functions in the way you are accustomed to using PennyLane. However, some of PennyLane’s features may not be fully supported yet, such as optimizers.

Warning

The currently supported backend devices are lightning.qubit, lightning.kokkos, braket.local.qubit, and braket.aws.qubit; support for more devices is planned.

PennyLane tapes are still used internally by Catalyst and you can express your circuits in the way you are used to, as long as you ensure that all operations are added to the main tape.

Let’s start learning more about Catalyst by running a simple circuit.

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0,1])
    return qml.expval(qml.PauliZ(wires=1))

In PennyLane, the qml.qnode() decorator creates a device-specific quantum function (a QNode) by binding the function to a device. The number of wires is specified when the device is created with qml.device().

The qjit() decorator can be used to jit a workflow of quantum functions:

jitted_circuit = qjit(circuit)
>>> jitted_circuit(0.7)
array(0.)
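Why is the result 0 for every value of theta? After the Hadamard, wire 0 is in an equal superposition, so the CNOT flips wire 1 in exactly half of the branches and the two contributions to \(\langle Z_1 \rangle\) cancel. The plain-Python statevector sketch below (no PennyLane or Catalyst involved; the helper expval_z1 is hypothetical) reproduces that reasoning, assuming wire 0 is the most significant bit of the basis index.

```python
import math

def expval_z1(theta):
    """Return <Z> on wire 1 for H(0), RX(theta)(1), CNOT(0, 1) applied to |00>."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    h = 1 / math.sqrt(2)
    # H on wire 0 and RX on wire 1 act on |00> as a product state
    q0 = [h, h]                 # H|0>
    q1 = [c, -1j * s]           # RX(theta)|0>
    # basis order |q0 q1>: index = 2*q0 + q1
    state = [a * b for a in q0 for b in q1]
    # CNOT(0, 1) swaps the |10> and |11> amplitudes
    state[2], state[3] = state[3], state[2]
    # Z on wire 1 has eigenvalue +1 when q1 = 0 and -1 when q1 = 1
    signs = [1, -1, 1, -1]
    return sum(sg * abs(amp) ** 2 for sg, amp in zip(signs, state))

print(expval_z1(0.7))  # approximately 0.0, up to floating-point rounding
```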

In Catalyst, dynamic wire values are fully supported for operations, observables and measurements. For example, the following circuit can be jitted with wires as arguments:

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=5))
def circuit(arg0, arg1, arg2):
    qml.RX(arg0, wires=[arg1 + 1])
    qml.RY(arg0, wires=[arg2])
    qml.CNOT(wires=[arg1, arg2])
    return qml.probs(wires=[arg1 + 1])
>>> circuit(jnp.pi / 3, 1, 2)
array([0.625, 0.375])
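As a sanity check of this output: with the arguments (π/3, 1, 2), both rotations act on wire 2, and the CNOT controlled by wire 1 does nothing because wire 1 is still in \(|0\rangle\). The probabilities on wire 2 are therefore those of RY(π/3)RX(π/3)\(|0\rangle\), which the plain-Python sketch below (hypothetical helper probs_ry_rx, standard RX/RY matrix conventions) computes directly.

```python
import math

def probs_ry_rx(a):
    """Probabilities of RY(a) RX(a) |0> in the computational basis."""
    c, s = math.cos(a / 2), math.sin(a / 2)
    # RX(a)|0> = cos(a/2)|0> - i sin(a/2)|1>
    v0, v1 = c, -1j * s
    # RY(a) = [[cos(a/2), -sin(a/2)], [sin(a/2), cos(a/2)]]
    w0 = c * v0 - s * v1
    w1 = s * v0 + c * v1
    return [abs(w0) ** 2, abs(w1) ** 2]

print(probs_ry_rx(math.pi / 3))  # close to [0.625, 0.375]
```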

Operations

Catalyst allows you to use quantum operations available in PennyLane, either through native support in the runtime or via PennyLane’s decomposition rules. The qml.adjoint() and qml.ctrl() functions in PennyLane are also supported via the decomposition mechanism in Catalyst. For example:

@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
    qml.Rot(0.3, 0.4, 0.5, wires=0)
    qml.adjoint(qml.SingleExcitation(jnp.pi / 3, wires=[0, 1]))
    return qml.state()

In addition, you can jit most PennyLane templates to easily construct and evaluate more complex quantum circuits; see below for the list of currently supported operations and templates.

Important

Most decomposition logic will be equivalent to PennyLane’s decomposition. However, decomposition logic will differ in the following cases:

  1. All qml.Controlled operations will decompose to qml.QubitUnitary operations.

  2. qml.ControlledQubitUnitary operations will decompose to qml.QubitUnitary operations.

  3. The list of device-supported gates employed by Catalyst, as defined by the QJITDevice, is currently different from that of the lightning.qubit device.

Identity

The Identity operator

PauliX

The Pauli X operator

PauliY

The Pauli Y operator

PauliZ

The Pauli Z operator

Hadamard

The Hadamard operator

S

The single-qubit phase gate

T

The single-qubit T gate

PhaseShift

Arbitrary single qubit local phase shift

RX

The single qubit X rotation

RY

The single qubit Y rotation

RZ

The single qubit Z rotation

CNOT

The controlled-NOT operator

CY

The controlled-Y operator

CZ

The controlled-Z operator

SWAP

The swap operator

IsingXX

Ising XX coupling gate

IsingYY

Ising YY coupling gate

IsingXY

Ising (XX + YY) coupling gate

IsingZZ

Ising ZZ coupling gate

ControlledPhaseShift

A qubit controlled phase shift.

CRX

The controlled-RX operator

CRY

The controlled-RY operator

CRZ

The controlled-RZ operator

CRot

The controlled-Rot operator

CSWAP

The controlled-swap operator

MultiRZ

Arbitrary multi Z rotation.

QubitUnitary

Apply an arbitrary unitary matrix with a dimension that is a power of two.

AllSinglesDoubles

Builds a quantum circuit to prepare correlated states of molecules by applying all SingleExcitation and DoubleExcitation operations to the initial Hartree-Fock state.

AmplitudeEmbedding

Encodes \(2^n\) features into the amplitude vector of \(n\) qubits.

AngleEmbedding

Encodes \(N\) features into the rotation angles of \(n\) qubits, where \(N \leq n\).

ApproxTimeEvolution

Applies the Trotterized time-evolution operator for an arbitrary Hamiltonian, expressed in terms of Pauli gates.

ArbitraryStatePreparation

Implements an arbitrary state preparation on the specified wires.

BasicEntanglerLayers

Layers consisting of one-parameter single-qubit rotations on each qubit, followed by a closed chain or ring of CNOT gates.

BasisEmbedding

Encodes \(n\) binary features into a basis state of \(n\) qubits.

BasisStatePreparation

Prepares a basis state on the given wires using a sequence of Pauli-X gates.

broadcast

Applies a unitary multiple times to a specific pattern of wires.

FermionicDoubleExcitation

Circuit to exponentiate the tensor product of Pauli matrices representing the double-excitation operator entering the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.

FermionicSingleExcitation

Circuit to exponentiate the tensor product of Pauli matrices representing the single-excitation operator entering the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.

FlipSign

Flips the sign of a given basis state.

GateFabric

Implements a local, expressive, and quantum-number-preserving ansatz proposed by Anselmetti et al. (2021).

GroverOperator

Performs the Grover Diffusion Operator.

IQPEmbedding

Encodes \(n\) features into \(n\) qubits using diagonal gates of an IQP circuit.

kUpCCGSD

Implements the k-Unitary Pair Coupled-Cluster Generalized Singles and Doubles (k-UpCCGSD) ansatz.

MERA

The MERA template broadcasts an input circuit across many wires following the architecture of a multi-scale entanglement renormalization ansatz tensor network.

MottonenStatePreparation

Prepares an arbitrary state on the given wires using a decomposition into gates developed by Möttönen et al. (2004).

MPS

The MPS template broadcasts an input circuit across many wires following the architecture of a Matrix Product State tensor network.

Permute

Applies a permutation to a set of wires.

QAOAEmbedding

Encodes \(N\) features into \(n>N\) qubits, using a layered, trainable quantum circuit that is inspired by the QAOA ansatz proposed by Killoran et al. (2020).

QFT

Apply a quantum Fourier transform (QFT).

QuantumMonteCarlo

Performs the quantum Monte Carlo estimation algorithm.

QuantumPhaseEstimation

Performs the quantum phase estimation circuit.

RandomLayers

Layers of randomly chosen single qubit rotations and 2-qubit entangling gates, acting on randomly chosen qubits.

SimplifiedTwoDesign

Layers consisting of a simplified 2-design architecture of Pauli-Y rotations and controlled-Z entanglers proposed in Cerezo et al. (2021).

StronglyEntanglingLayers

Layers consisting of single qubit rotations and entanglers, inspired by the circuit-centric classifier design arXiv:1804.00633.

TTN

The TTN template broadcasts an input circuit across many wires following the architecture of a tree tensor network.

UCCSD

Implements the Unitary Coupled-Cluster Singles and Doubles (UCCSD) ansatz.

Observables

Catalyst supports PennyLane observables.

For example, the following circuit is a QJIT compatible function that calculates the expectation value of the tensor product of qml.PauliX, qml.Hadamard, and qml.Hermitian observables.

@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RX(y, 1)
    qml.CNOT([0, 2])
    qml.CNOT([1, 2])
    h_matrix = jnp.array(
        [[complex(1.0, 0.0), complex(2.0, 0.0)],
        [complex(2.0, 0.0), complex(-1.0, 0.0)]]
    )
    return qml.expval(qml.PauliX(0) @ qml.Hadamard(1) @ qml.Hermitian(h_matrix, 2))

Identity

The Identity operator

PauliX

The Pauli X operator

PauliY

The Pauli Y operator

PauliZ

The Pauli Z operator

Hadamard

The Hadamard operator

Hermitian

An arbitrary Hermitian observable.

Hamiltonian

Operator representing a Hamiltonian.

Measurements

Most PennyLane measurement processes are supported in Catalyst, although not all features are supported for all measurement types.

qml.expval()

The expectation value of observables is supported both analytically and with finite shots.

qml.var()

The variance of observables is supported both analytically and with finite shots.

qml.sample()

Only samples in the computational basis are supported.

qml.counts()

Only sample counts in the computational basis are supported.

qml.probs()

Probabilities are supported in the computational basis, both analytically and with finite shots.

qml.state()

Only the state in the computational basis is supported.

measure()

The projective mid-circuit measurement is supported via its own operation in Catalyst.

For both qml.sample() and qml.counts(), omitting the wires parameter produces samples on all declared qubits, in the same format as in PennyLane.

Counts are returned a bit differently: as a pair of arrays that together represent a dictionary from basis states to the number of observed samples. We thus have to do a bit of extra work to display them nicely. Note that each basis state is given as its integer (binary-encoded) value stored in a floating-point data type. This keeps the output compatible with eigenvalue sampling, but it may change in the future.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def counts():
    qml.Rot(0.1, 0.2, 0.3, wires=[0])
    return qml.counts(wires=[0])
basis_states, sample_counts = counts()
>>> {format(int(state), '01b'): count for state, count in zip(basis_states, sample_counts)}
{'0': 985, '1': 15}
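If you need this conversion in several places, it can be wrapped in a small helper. The function below, counts_to_dict, is hypothetical (it is not part of Catalyst or PennyLane); it assumes the basis states are integer-valued floats as described above, zero-pads the bit strings to num_wires, and drops outcomes with zero counts. The input lists in the usage line are made-up data.

```python
def counts_to_dict(basis_states, counts, num_wires):
    """Turn a (basis_states, counts) pair into {bitstring: count}.

    basis_states holds integer-valued floats (e.g. 0.0, 1.0, ...);
    num_wires sets the width of the zero-padded binary strings.
    Outcomes that were never observed are left out.
    """
    return {
        format(int(state), f"0{num_wires}b"): int(count)
        for state, count in zip(basis_states, counts)
        if count > 0
    }

# Made-up data standing in for the two returned arrays
print(counts_to_dict([0.0, 1.0, 2.0, 3.0], [480, 0, 20, 0], num_wires=2))
# {'00': 480, '10': 20}
```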

You can specify the number of shots to be used in sample-based measurements when you create a device. qml.sample() and qml.counts() will automatically use the device’s shots parameter when performing measurements. In the following example, the number of shots is set to \(500\) in the device instantiation.

Note

You can return any combination of measurement processes as a tuple from quantum functions. In addition, Catalyst allows you to return any classical values computed inside quantum functions as well.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3, shots=500))
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RX(params[1], wires=1)
    qml.RZ(params[2], wires=2)
    return (
        qml.sample(),
        qml.counts(),
        qml.expval(qml.PauliZ(0)),
        qml.var(qml.PauliZ(0)),
        qml.probs(wires=[0, 1]),
        qml.state(),
    )
>>> circuit([0.3, 0.5, 0.7])
[array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        ...,
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
array([0., 1., 2., 3., 4., 5., 6., 7.]),
array([458,   7,  35,   0,   0,   0,   0,   0]),
array(0.95533649),
array(0.08733219),
array([0.91782642, 0.05984182, 0.02096486, 0.0013669 ]),
array([ 0.89994966-0.32850727j,  0.        +0.j        ,
        -0.08388168-0.22979488j,  0.        +0.j        ,
        -0.04964902-0.13601409j,  0.        +0.j        ,
        -0.0347301 +0.01267748j,  0.        +0.j        ])]
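The analytic entries of this output can be checked by hand. Wire 2 only receives an RZ, which does not change computational-basis probabilities, and wires 0 and 1 are rotated independently by RX. For RX\((\theta)|0\rangle\) the probability of outcome 0 is \(\cos^2(\theta/2)\), so \(\langle Z \rangle = \cos\theta\) and, because \(Z^2 = I\), \(\mathrm{Var}(Z) = 1 - \cos^2\theta = \sin^2\theta\). A plain-Python sketch of these formulas, assuming wire 0 is the leading bit in the probs ordering:

```python
import math

def rx_p0(theta):
    """Probability of measuring 0 after RX(theta)|0>."""
    return math.cos(theta / 2) ** 2

params = [0.3, 0.5, 0.7]
p0_w0, p0_w1 = rx_p0(params[0]), rx_p0(params[1])

expval_z0 = 2 * p0_w0 - 1       # <Z> = P(0) - P(1) = cos(theta)
var_z0 = 1 - expval_z0 ** 2     # Z**2 = I, so Var(Z) = 1 - <Z>**2
# probs(wires=[0, 1]) in the order |00>, |01>, |10>, |11>
probs_01 = [
    p0_w0 * p0_w1,
    p0_w0 * (1 - p0_w1),
    (1 - p0_w0) * p0_w1,
    (1 - p0_w0) * (1 - p0_w1),
]
print(expval_z0, var_z0, probs_01)
```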

Projective mid-circuit measurements, as in PennyLane, are also supported in Catalyst. measure() is a QJIT compatible mid-circuit measurement for Catalyst that only requires the wire that the measurement process acts on.

Important

The qml.measure() function is not QJIT compatible and measure() from Catalyst should be used instead:

from catalyst import measure

In the following example, m will be equal to True if wire \(0\) is rotated by \(180\) degrees.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x):
    qml.RX(x, wires=0)
    m = measure(wires=0)
    return m
>>> circuit(jnp.pi)
True
>>> circuit(0.0)
False
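The outcomes are deterministic here only because of the chosen angles: after RX\((x)|0\rangle\), the probability of measuring 1 is \(\sin^2(x/2)\), which is exactly 1 for \(x = \pi\) and 0 for \(x = 0\). For other angles, measure() returns a random outcome. The hypothetical helper below just evaluates that probability in plain Python.

```python
import math

def prob_one_after_rx(x):
    """Probability that a computational-basis measurement returns 1 after RX(x)|0>."""
    return math.sin(x / 2) ** 2

for x in (0.0, math.pi / 2, math.pi):
    print(f"x = {x:.4f}: P(m = True) = {prob_one_after_rx(x):.4f}")
```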

Compilation Modes

In Catalyst, there are two ways of compiling quantum functions depending on when the compilation is triggered.

Just-in-time

In just-in-time (JIT) mode, compilation is triggered at the call site the first time the quantum function is executed. For example, circuit below is compiled on its first call.

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0,1])
    return qml.expval(qml.PauliZ(wires=1))
>>> circuit(0.5)  # the first call, compilation occurs here
array(0.)
>>> circuit(0.5)  # the precompiled quantum function is called
array(0.)
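Conceptually, JIT compilation behaves like a cache: the expensive compilation step runs on the first call and later calls reuse the result. The plain-Python toy below illustrates only that idea; the decorator toy_jit, its compile_count attribute, and the choice to key the cache on argument types are all hypothetical and do not describe how Catalyst is implemented.

```python
import functools

def toy_jit(fn):
    """Hypothetical decorator that 'compiles' once per argument-type signature."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        signature = tuple(type(a) for a in args)
        if signature not in cache:
            wrapper.compile_count += 1  # stands in for the costly compilation step
            cache[signature] = fn       # a real compiler would store compiled code here
        return cache[signature](*args)

    wrapper.compile_count = 0
    return wrapper

@toy_jit
def square(x):
    return x * x

square(0.5)  # first call: "compilation" happens here
square(0.7)  # same signature (float): the cached version is reused
print(square.compile_count)  # 1
```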

Ahead-of-time

An alternative is to trigger compilation without providing any concrete values for the function parameters. This works by specifying the argument signature directly in the function definition, which triggers compilation “ahead-of-time” (AOT), before the program is executed. Both built-in Python scalar types and the special ShapedArray type, which JAX uses to represent the shape and data type of a tensor, can be used:

from jax.core import ShapedArray

@qjit  # compilation happens at definition
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
    theta = jnp.abs(x)
    qml.RY(theta, wires=0)
    qml.Rot(z[0], z[1], z[2], wires=0)
    return qml.state()
>>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9]))  # calls precompiled function
array([0.75634905-0.52801002j, 0.        +0.j        ,
       0.35962678+0.14074839j, 0.        +0.j        ])

At this stage, compilation has already happened, so the first call to circuit executes the compiled function directly, resulting in faster initial execution. Note that implicit type promotion is allowed for most data types, as long as it does not lead to a loss of data.

Compiling with Control Flow

Catalyst supports natively compiled control flow as “first-class” components of any quantum program. This yields a much smaller program representation and shorter compilation times for large circuits, and also enables the compilation of arbitrarily parametrized circuits.

Catalyst-provided control flow operations:

cond

A qjit() compatible decorator for if-else conditionals in PennyLane/Catalyst.

for_loop

A qjit() compatible for-loop decorator for PennyLane/Catalyst.

while_loop

A qjit() compatible while-loop decorator for PennyLane/Catalyst.

Note

Catalyst supports automatic conversion of native Python control flow to the Catalyst control flow operations. For more details, see the AutoGraph guide.

Conditionals

cond() is a functional version of the traditional if-else conditional for Catalyst. This means that each execution path, a True branch and a False branch, is provided as a separate function. Both functions are traced during compilation, but only one of them is executed at runtime, depending on the value of a Boolean predicate. The JAX equivalent is the jax.lax.cond function, but this version is optimized to work with quantum programs in PennyLane.

Note that cond() can also be used outside of the qjit() context for better interoperability with PennyLane.

Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical. If no values are returned, the False branch is optional. Refer to the example below to learn more about the syntax of this decorator.

@cond(predicate)  # predicate: a Boolean value (may be a traced JAX value)
def conditional_fn():
    # do something when the predicate is True
    return true_value  # optional; must be a JAX compatible type

@conditional_fn.otherwise
def conditional_fn():
    # optionally define an alternative execution path
    return false_value  # if provided, return types must be identical in both branches

ret_val = conditional_fn()  # must invoke the defined function
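For intuition, the runtime behaviour of cond() on concrete values can be written as an ordinary Python function. The reference sketch below (cond_reference is a hypothetical name) executes only the selected branch; it does not model compilation, where Catalyst traces both branches. When no alternative branch is given it returns None, in line with the rule that the False branch may only be omitted when no values are returned.

```python
def cond_reference(predicate, true_fn, false_fn=None):
    """Plain-Python reference for an if-else conditional with functional branches."""
    if predicate:
        return true_fn()
    if false_fn is not None:
        return false_fn()
    # Without an alternative branch, the conditional produces no value
    return None

x = 1.7
ret_val = cond_reference(x > 1.4, lambda: x ** 2, lambda: x)
print(ret_val)  # x ** 2, since the predicate is True
```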

Warning

The conditional functions can only return JAX compatible data types.

Loops

for_loop() and while_loop() are functional versions of the traditional for- and while-loop for Catalyst. That is, any variables that are modified across iterations need to be provided as inputs and outputs to the loop body function. Input arguments contain the value of a variable at the start of an iteration, while output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration. The final iteration values are also returned from the transformed function.

for_loop() and while_loop() can also be used outside of qjit(), in which case they are interpreted directly, without compiling the surrounding context.

The for-loop statement:

The for_loop() executes a fixed number of iterations as indicated via the values specified in its header: a lower_bound, an upper_bound, and a step size.

The loop body function must always have the iteration index (in the below example i) as its first argument and its value can be used arbitrarily inside the loop body. As the value of the index across iterations is handled automatically by the provided loop bounds, it must not be returned from the body function.

@for_loop(lower_bnd, upper_bnd, step)
def loop_body(i, *args):
    # code to be executed over index i starting
    # from lower_bnd to upper_bnd - 1 by step
    return args

final_args = loop_body(*init_args)

The semantics of for_loop() are given by the following Python implementation:

for i in range(lower_bnd, upper_bnd, step):
    args = body_fn(i, *args)
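To experiment with these semantics outside Catalyst, a decorator with the same calling convention can be emulated in plain Python. The name for_loop_reference is hypothetical, and the sketch assumes that the body returns a tuple when several values are carried across iterations and a bare value when only one is.

```python
def for_loop_reference(lower_bnd, upper_bnd, step):
    """Plain-Python stand-in for a functional for-loop decorator."""
    def decorator(body_fn):
        def run(*args):
            for i in range(lower_bnd, upper_bnd, step):
                result = body_fn(i, *args)
                # normalize so that a single loop-carried value can be returned bare
                args = result if isinstance(result, tuple) else (result,)
            return args[0] if len(args) == 1 else args
        return run
    return decorator

@for_loop_reference(0, 5, 1)
def accumulate(i, total, count):
    # the loop-carried values are updated and returned; i itself is not returned
    return total + i * i, count + 1

print(accumulate(0, 0))  # (30, 5)
```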

The while-loop statement:

The while_loop(), on the other hand, is able to execute an arbitrary number of iterations, until the condition function specified in its header returns False.

The loop condition is evaluated every iteration and can be any callable with an identical signature as the loop body function. The return type of the condition function must be a Boolean.

@while_loop(lambda *args: "some condition")
def loop_body(*args):
    # perform some work and update (some of) the arguments
    return args

final_args = loop_body(*init_args)
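The while-loop semantics can be emulated the same way. In this plain-Python sketch (while_loop_reference is a hypothetical name), the condition receives the same loop-carried values as the body and is checked before every iteration, so the body may run zero times.

```python
def while_loop_reference(cond_fn):
    """Plain-Python stand-in for a functional while-loop decorator."""
    def decorator(body_fn):
        def run(*args):
            # the condition takes the same arguments as the body
            while cond_fn(*args):
                result = body_fn(*args)
                args = result if isinstance(result, tuple) else (result,)
            return args[0] if len(args) == 1 else args
        return run
    return decorator

@while_loop_reference(lambda x, n: x < 100.0)
def doubling(x, n):
    # double x and count the iterations until x reaches at least 100
    return 2 * x, n + 1

print(doubling(3.0, 0))  # (192.0, 6)
```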

Calculating Quantum Gradients

Catalyst-provided gradient operations:

grad

A qjit() compatible gradient transformation for PennyLane/Catalyst.

jacobian

A qjit() compatible Jacobian transformation for PennyLane/Catalyst.

vjp

A qjit() compatible Vector-Jacobian product for PennyLane/Catalyst.

jvp

A qjit() compatible Jacobian-vector product for PennyLane/Catalyst.

grad() is a QJIT compatible gradient decorator in Catalyst that can differentiate a hybrid quantum function using finite-difference, parameter-shift, or adjoint-Jacobian methods. See the documentation for more details.

This decorator accepts the function to differentiate, a differentiation strategy, and the indices of the arguments with respect to which to differentiate:

@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)
>>> workflow(2.0)
array(-3.14159265)

To specify the differentiation strategy, the method argument can be passed to the grad function:

  • method="auto": Quantum components of the hybrid function are differentiated according to the corresponding QNode diff_method, while the classical computation is differentiated using traditional autodiff.

    With this strategy, Catalyst currently only supports QNodes with diff_method="parameter-shift" and diff_method="adjoint".

  • method="fd": First-order finite-differences for the entire hybrid function. The diff_method argument for each QNode is ignored.

Currently, higher-order differentiation is only supported by the finite-difference method. The gradient of circuits with QJIT compatible control flow is supported for all methods in Catalyst.

You can also pass the finite-difference step size (h-value) to grad(). For example, differentiating circuit with respect to its second argument using finite differences with an h-value of \(0.1\) looks like this:

g_fd = grad(circuit, method="fd", argnum=1, h=0.1)
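To see why the step size matters, the sketch below applies a first-order forward difference, in plain Python, to the function from the first grad() example, \(f(x) = \langle Y \rangle = -\sin(\pi x)\), whose exact derivative at \(x = 2\) is \(-\pi\). It assumes a forward-difference scheme and is not Catalyst's implementation; it only shows how a large h such as 0.1 introduces a visible truncation error.

```python
import math

def f(x):
    # <Y> after RX(pi * x)|0>, as in the earlier grad() example
    return -math.sin(math.pi * x)

def forward_difference(fn, x, h):
    """First-order forward finite difference: (fn(x + h) - fn(x)) / h."""
    return (fn(x + h) - fn(x)) / h

for h in (0.1, 1e-3, 1e-6):
    approx = forward_difference(f, 2.0, h)
    print(f"h = {h:g}: {approx:.6f} (error {abs(approx + math.pi):.2e})")
```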

Gradients can also be calculated with respect to several arguments or with respect to tensor-valued parameters. For example, grad(circuit, argnum=[0, 1]) would calculate the gradient of circuit with respect to its first and second arguments using the finite-difference method. The gradient of the following circuit, which takes a tensor of parameters, can be computed as well:

@qjit
def workflow(params):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(params):
        qml.RX(params[0] * params[1], wires=0)
        return qml.expval(qml.PauliY(0))

    return grad(circuit, argnum=0)(params)
>>> workflow(jnp.array([2.0, 3.0]))
array([-2.88051099, -1.92034063])
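These values follow from the chain rule. The circuit returns \(f(p) = -\sin(p_0 p_1)\), so \(\partial f/\partial p_0 = -p_1\cos(p_0 p_1)\) and \(\partial f/\partial p_1 = -p_0\cos(p_0 p_1)\), which at \(p = (2, 3)\) gives \(-3\cos 6\) and \(-2\cos 6\). A plain-Python check (grad_f is a hypothetical helper) is below; the printed output above differs from these exact values only around the seventh decimal place, which is consistent with a small finite-difference step.

```python
import math

def grad_f(p0, p1):
    """Analytic gradient of f(p) = -sin(p0 * p1)."""
    c = math.cos(p0 * p1)
    return [-p1 * c, -p0 * c]

print(grad_f(2.0, 3.0))  # approximately [-2.8805109, -1.9203406]
```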

The grad() decorator works for functions that return a scalar value. You can also use the jacobian() decorator to compute Jacobian matrices of general hybrid functions with multiple or multivariate results.

@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x[0], wires=0)
        qml.RY(x[1], wires=0)
        return qml.probs()

    g = jacobian(circuit, method="auto")
    return g(x)
>>> workflow(jnp.array([2.0, 1.0]))
array([[ 3.48786850e-16, -4.20735492e-01],
       [-8.71967125e-17,  4.20735492e-01]])
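A hand check of this Jacobian: for \(x_0 = 2\), RX(2π) equals \(-I\), a global phase, so the probabilities are those of RY\((x_1)|0\rangle\), namely \([\cos^2(x_1/2), \sin^2(x_1/2)]\). Their derivatives with respect to \(x_1\) are \(\mp\sin(x_1)/2 \approx \mp 0.4207\) at \(x_1 = 1\) (the second column), while the first column vanishes up to rounding because the RX angle sits at a stationary point. The plain-Python sketch below reproduces these numbers; its helpers are hypothetical and it uses central differences, not Catalyst's differentiation methods.

```python
import math

def probs(x0, x1):
    """Probabilities of RY(x1) RX(pi * x0) |0>."""
    ca, sa = math.cos(math.pi * x0 / 2), math.sin(math.pi * x0 / 2)
    cb, sb = math.cos(x1 / 2), math.sin(x1 / 2)
    # amplitude of |0> is cb*ca + i*sb*sa, so P(0) = (cb*ca)**2 + (sb*sa)**2
    p0 = (cb * ca) ** 2 + (sb * sa) ** 2
    return [p0, 1 - p0]

def jacobian_central(fn, x, h=1e-6):
    """Jacobian via central differences; rows are outputs, columns are inputs."""
    cols = []
    for k in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[k] += h
        xm[k] -= h
        cols.append([(p - m) / (2 * h) for p, m in zip(fn(*xp), fn(*xm))])
    # transpose the list of columns into a list of rows
    return [list(row) for row in zip(*cols)]

print(jacobian_central(probs, [2.0, 1.0]))
```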

This decorator has the same methods and API as grad. See the documentation for more details.

Optimizers

You can develop your own optimization algorithm by combining the grad() function with QJIT compatible control-flow operators, or by using differentiable optimizers from JAXopt.

Warning

Catalyst currently does not provide any optimization tools and does not support the optimizers offered by PennyLane. However, this feature is planned for future implementation.

For example, you can use jaxopt.GradientDescent in a QJIT workflow to perform gradient-descent optimization. The following example shows a simple use case of this feature in Catalyst.

jaxopt.GradientDescent takes a smooth function of the form gd_fun(params, *args, **kwargs) that returns either just the value of the function or both its value and its gradient, depending on the value_and_grad argument. To optimize params iteratively, you then use jax.lax.fori_loop to loop over the gradient-descent steps.

import jaxopt
from jax.lax import fori_loop

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(param):
    qml.Hadamard(0)
    qml.RY(param, wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit
def workflow():
    def gd_fun(param):
        diff = grad(circuit, argnum=0)
        return circuit(param), diff(param)

    opt = jaxopt.GradientDescent(gd_fun, stepsize=0.4, value_and_grad=True)

    def gd_update(i, args):
        (param, state) = opt.update(*args)
        return (param, state)

    param = 0.1
    state = opt.init_state(param)
    (param, _) = fori_loop(0, 100, gd_update, (param, state))
    return param
>>> workflow()
array(1.57079633)
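The optimization itself is gradient descent on a simple function. The circuit applies RY(param) to \(|+\rangle\), so its expectation value is \(f(\theta) = -\sin\theta\) with derivative \(-\cos\theta\), and each update reads \(\theta \leftarrow \theta + 0.4\cos\theta\), which settles at the minimum \(\theta = \pi/2\). The plain-Python sketch below runs 100 such steps. It is a simplification: jaxopt.GradientDescent may additionally apply acceleration, so intermediate iterates can differ even though both approaches approach \(\pi/2\).

```python
import math

def cost(theta):
    # <Z> for RY(theta) applied after a Hadamard: cos(theta + pi/2) = -sin(theta)
    return -math.sin(theta)

def cost_grad(theta):
    return -math.cos(theta)

theta, stepsize = 0.1, 0.4
for _ in range(100):
    theta -= stepsize * cost_grad(theta)  # theta <- theta + 0.4 * cos(theta)

print(theta, cost(theta))  # theta close to pi/2, cost close to -1
```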

JAX Integration

Catalyst programs can also be used inside of a larger JAX workflow which uses JIT compilation, automatic differentiation, and other JAX transforms.

Note

In general, the best performance is achieved when the Catalyst @qjit decorator is used to JIT-compile the entire hybrid workflow. However, there may be cases where you want to delegate only the quantum part of your workflow to Catalyst and let JAX handle the classical components (for example, due to a missing feature or a compatibility issue in Catalyst).

For example, a Catalyst qjit-compiled function can be called from within a JAX jit-compiled function:

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x):
  qml.RX(jnp.pi * x[0], wires=0)
  qml.RY(x[1] ** 2, wires=0)
  qml.RX(x[1] * x[2], wires=0)
  return qml.probs(wires=0)

@jax.jit
def cost_fn(weights):
  x = jnp.sin(weights)
  return jnp.sum(jnp.cos(circuit(x)) ** 2)
>>> cost_fn(jnp.array([0.1, 0.2, 0.3]))
Array(1.32269195, dtype=float64)
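To make explicit what cost_fn computes, here is the same single-qubit calculation written in plain Python with explicit 2x2 matrices. This is a hypothetical reference sketch under the standard RX/RY conventions, not how Catalyst evaluates the circuit; it should agree with the value printed above to several decimal places.

```python
import math

def apply(gate, state):
    """Multiply a 2x2 gate (nested lists) into a single-qubit state vector."""
    return [gate[0][0] * state[0] + gate[0][1] * state[1],
            gate[1][0] * state[0] + gate[1][1] * state[1]]

def rx(t):
    c, s = math.cos(t / 2), math.sin(t / 2)
    return [[c, -1j * s], [-1j * s, c]]

def ry(t):
    c, s = math.cos(t / 2), math.sin(t / 2)
    return [[c, -s], [s, c]]

def circuit_probs(x):
    """Probabilities of the single-qubit circuit above, simulated directly."""
    state = [1 + 0j, 0j]
    state = apply(rx(math.pi * x[0]), state)
    state = apply(ry(x[1] ** 2), state)
    state = apply(rx(x[1] * x[2]), state)
    return [abs(a) ** 2 for a in state]

def cost_fn(weights):
    x = [math.sin(w) for w in weights]
    return sum(math.cos(p) ** 2 for p in circuit_probs(x))

print(cost_fn([0.1, 0.2, 0.3]))  # approximately 1.3227
```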

Catalyst-compiled functions can also be automatically differentiated via JAX, in both forward and reverse mode to first order,

>>> jax.grad(cost_fn)(jnp.array([0.1, 0.2, 0.3]))
Array([0.49249037, 0.05197949, 0.02991883], dtype=float64)

as well as vectorized using jax.vmap:

>>> jax.vmap(cost_fn)(jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]))
Array([1.32269195, 1.53905377], dtype=float64)

In particular, this allows for a reduction in boilerplate when using JAX-compatible optimizers such as jaxopt:

>>> opt = jaxopt.GradientDescent(cost_fn)
>>> params = jnp.array([0.1, 0.2, 0.3])
>>> (final_params, _) = jax.jit(opt.run)(params)
>>> final_params
Array([-0.00320799,  0.03475223,  0.29362844], dtype=float64)