# JAX integration¶

Catalyst allows you to write hybrid quantum-classical functions in Python that are just-in-time
compiled with the `qjit()`

decorator, and ultimately leverages modern compilation tools to
speed up quantum applications.

To support this, the Catalyst frontend leverages PennyLane for representing quantum instructions,
and utilizes JAX for classical processing and program capture, which means you are able to leverage
the many functions accessible in `jax`

and `jax.numpy`

to write code that supports
`@qjit`

and dynamic variables.

Here, we aim to provide an overview of the JAX integration, including the existing support and limitations.

## JAX ‘sharp bits’¶

While leveraging `jax.numpy`

makes it easy to port over NumPy-based
PennyLane workflows to Catalyst, we also inherit various restrictions
and ‘gotchas’ from JAX.
This includes:

**Pure functions**: Compilation is primarily designed to only work on pure functions. That is, functions that do not have any side-effects; the output is purely dependent only on function inputs.**Lack of stateful random number generators**: In JAX, random number generators are stateless, and the key state must be explicitly updated each time you want to compute a random number. For more details, see the JAX documentation.**In-place array updates**: Rather than using in-place array updates, the syntax`new_array = jax_array.at[index].set(value)`

should be used. For more details, see jax.numpy.ndarray.at.Note

Support is being added for automatically capturing native Python in-place array update syntax, and automatically converting it to JAX-compatible syntax via our AutoGraph feature.

For more details, please see the JAX documentation.

## JAX control flow¶

It is recommended to always use Catalyst control flow functions `for_loop()`

, `cond()`

,
and `while_loop()`

(or our experimental AutoGraph feature).

However, JAX control flow functions, such as `jax.lax.cond`

and `jax.lax.fori_loop`

, will work
inside qjit-compiled functions **as long as they are not applied directly to quantum instructions**
and only apply outside of QNodes:

```
dev = qml.device("lightning.qubit", wires=4, shots=10)
@qml.qnode(dev)
def circuit(x):
N = x.shape[0]
@catalyst.for_loop(0, N, 1)
def loop_fn(i):
qml.RX(x[i], wires=i)
loop_fn()
return [qml.expval(qml.PauliZ(i)) for i in range(N)]
@qjit
def fn(x):
def cost(j, x):
return jnp.stack(circuit(x))
return jax.lax.fori_loop(0, 10, cost, x)
```

```
>>> fn(jnp.array([0.1, 0.2, 0.3, 0.5]))
Array([0.6, 0.6, 0.8, 1. ], dtype=float64)
```

## Function support¶

Currently, we are aiming to support as many JAX functions as possible, however there may be cases where there is missing coverage. Known JAX functionality that doesn’t work with Catalyst includes:

`jax.numpy.polyfit`

`jax.numpy.fft`

`jax.numpy.ndarray.at[index]`

when`index`

corresponds to all array indices.

If you come across any other JAX functions that don’t work with Catalyst (and don’t already have a Catalyst equivalent), please let us know by opening a GitHub issue.

Note that there is certain JAX functionality we do not expect to or plan to support in Catalyst qjit-compiled functions. This includes:

`jax.debug`

. Please use instead the Catalyst provided`print()`

,`callback()`

, and`pure_callback()`

functions.JAX device placement. Please use instead the

`accelerate()`

decorator.Certain functions in the jax.lax.debug module which are direct wrappers of XLA functionality with no LLVM/MLIR equivalent.

## Dynamically-shaped arrays¶

One common ‘gotcha’ of JAX jit-compiled functions is that they cannot create or return arrays with
dynamic shape — that is, arrays where their shape is determined by a dynamic variable at runtime.
Typically, workarounds involve rewriting the code to utilize `jnp.where`

where possible.

In Catalyst, however, we have enabled support for dynamically-shaped arrays; qjit-compiled functions can accept, create, and return arrays of dynamic shape without triggering re-compilation:

```
>>> @qjit
... def func(size: int):
... print("Compiling")
... return jax.numpy.ones([size, size], dtype=float)
>>> func(3)
Compiling
Array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=float64)
>>> func(4)
Array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float64)
```

Dynamic arrays can be created using `jnp.ones`

and `jnp.zeros`

. Note that `jnp.arange`

and `jnp.linspace`

do not currently support generating dynamically-shaped arrays (however, unlike
`jnp.arange`

, `jnp.linspace`

*does* support dynamic variables for its `start`

and `stop`

arguments).

For more details, see Dynamically-shaped arrays.

## JAX transforms on QJIT functions¶

Compiled functions remain JAX compatible, and you can call JAX transformations
on them, such as `jax.grad`

and `jax.vmap`

. You can even call `jax.jit`

on functions that call qjit-compiled functions:

```
>>> dev = qml.device("lightning.qubit", wires=2)
>>> @qjit
... @qml.qnode(dev)
... def circuit(x):
... qml.RX(x, wires=0)
... return qml.expval(qml.PauliZ(0))
>>> @jax.jit
... def workflow(y):
... return jax.grad(circuit)(jnp.sin(y))
>>> workflow(0.6)
Array(-0.53511382, dtype=float64, weak_type=True)
>>> jax.vmap(circuit)(jnp.array([0.1, 0.2, 0.3]))
Array([0.99500417, 0.98006658, 0.95533649], dtype=float64)
```

However, a `jax.jit`

function calling a `qjit`

function will always result
in a callback to Python, so will be slower than if the function was purely compiled
using `jax.jit`

or `qjit`

.

Note

Best performance will be seen when the Catalyst
`@qjit`

decorator is used to JIT the entire hybrid workflow. However, there
may be cases where you may want to delegate only the quantum part of your
workflow to Catalyst, and let JAX handle classical components.

## Internal QJIT JAX transformations¶

Inside of a qjit-compiled function, JAX transformations
(`jax.grad`

, `jax.jacobian`

, `jax.vmap`

, etc.)
can be used **as long as they are not applied to quantum processing**.

```
>>> @qjit
... def f(x):
... def g(y):
... return -jnp.sin(y) ** 2
... return jax.grad(g)(x)
>>> f(0.4)
Array(-0.71735609, dtype=float64)
```

If they are applied to quantum processing, an error will occur:

```
>>> @qjit
... def f(x):
... @qml.qnode(dev)
... def g(y):
... qml.RX(y, wires=0)
... return qml.expval(qml.PauliX(0))
... return jax.grad(lambda y: g(y) ** 2)(x)
>>> f(0.4)
NotImplementedError: must override
```

Instead, only Catalyst transformations will work when applied to hybrid quantum-classical processing:

```
>>> @qjit
... def f(x):
... @qml.qnode(dev)
... def g(y):
... qml.RX(y, wires=0)
... return qml.expval(qml.PauliZ(0))
... return grad(lambda y: g(y) ** 2)(x)
>>> f(0.4)
Array(-0.71735609, dtype=float64)
```

Always use the equivalent Catalyst transformation
(`catalyst.grad()`

, `catalyst.jacobian()`

, `catalyst.vjp()`

, `catalyst.jvp()`

)
inside of a qjit-compiled function.