# Sharp bits and debugging tips

Catalyst is designed to allow you to take the tools and code patterns you are familiar with when exploring quantum computing (such as Python, NumPy, JAX, and PennyLane), while unlocking faster execution and the ability to run hybrid quantum-classical workflows on accelerator devices.

Similar to JAX, Catalyst does this via the `@qjit` decorator, which captures hybrid programs written in Python, PennyLane, and JAX, and compiles them to native machine code, preserving control flow like conditional branches and loops.

With Catalyst, we aim to support as many idiomatic PennyLane and JAX hybrid workflow programs as possible; however, there will be **various restrictions and constraints that should be taken into account**.

Here, we aim to provide an overview of the restrictions and constraints (the ‘sharp bits’), as well as debugging tips and common patterns that are helpful when using Catalyst.

Note

For a more general overview of Catalyst, please see the quick start guide.

## Compile-time vs. runtime

An important distinction to make in Catalyst, which we typically don’t have to
worry about with standard PennyLane, is the concept of **compile time**
vs. **runtime**.

Very roughly, the following three processes occur when using the `@qjit` decorator with just-in-time (JIT) compilation.

1. **Program capture or tracing:** When the `@qjit` decorated function is first called (or when `@qjit` is first applied, if using function type hints and ahead-of-time mode), Catalyst will ‘capture’ the entire hybrid workflow with **placeholder variables of unknown value** used as the function arguments (the **runtime arguments**). These symbolic tracer objects represent **dynamic variables**, and are used to determine how the JIT compiled function transforms its inputs to outputs.

2. **Compilation:** The captured program is then compiled to a parametrized binary using the Catalyst compiler.

3. **Execution:** Finally, the compiled function is executed with the provided numerical function inputs, and the results are returned.

Once the function is first compiled, subsequent executions of the function will simply re-use the previous compiled binary, allowing steps (1) and (2) to be skipped. (Note: some cases, such as when the function argument types change, may trigger re-compilation.)

For example, consider the following, where we print out a variable in the middle of our `@qjit` compiled function:

```
>>> from catalyst import qjit
>>> @qjit
... def f(x):
...     print(f"x = {x}")
...     return x ** 2
>>> f(2.)
x = Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
array(4.)
>>> f(3.)
array(9.)
```

We can see that on the first execution, program capture/tracing occurs, and the dynamic variable is printed as a tracer (tracers capture *type* and *shape*, but not numeric value). This captured program is compiled, and then the binary is executed directly to return the function value. The print statement is never invoked with the numerical value of `x`.

When we execute the function again, steps (1) and (2) are skipped since we have already compiled a binary; the binary is called directly to get the function result, and again the print statement is never hit.

This allows us to distinguish between computations that happen at **compile-time** (steps 1 and 2), such as the `print` statement above, and those that happen at **runtime** (step 3).

Note

As a general rule of thumb, things that happen at compile-time are slow (or lead to slowdowns), while things that happen at runtime are fast (or lead to speedups).

However, suppose an expensive computation is repeated every time the compiled function is run, and its result is the same regardless of the inputs. In that case it may be worth doing the computation once in Python and using the result statically in the program.
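Here is a minimal sketch of this pattern. It uses plain `jax.jit`, which follows the same tracing model as `@qjit`, so that it runs without Catalyst installed. The matrix `H` and the function `energy` are hypothetical. The eigenvalues are computed once with NumPy outside the compiled function, and they become a constant when the function is traced:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical input-independent setup: computed once, in NumPy, before compiling.
H = np.diag([1.0, 2.0, 3.0])
eigvals = np.linalg.eigvalsh(H)  # ascending eigenvalues: [1., 2., 3.]

@jax.jit
def energy(weights):
    # `eigvals` is a concrete NumPy array, so it is embedded as a constant
    # during tracing; only the dot product is performed at runtime.
    return jnp.dot(weights, eigvals)

print(energy(jnp.array([1.0, 0.0, 1.0])))  # 4.0
```

Swapping `jax.jit` for `@qjit` should follow the same principle, because the closed-over array is a static value at compile time.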

However, computations at compile time cannot depend on the values of dynamic variables, since these are not yet known. They can only depend on **static variables**, whose values are known at compile time.

Note

A general guideline when working with JIT compilation and Catalyst:

- Python control flow and third-party libraries like NumPy and SciPy will be evaluated at compile-time, and can only accept static variables.

- JAX functions, such as `jax.numpy`, and Catalyst functions like `cond()` and `for_loop()` will be evaluated at runtime, and can accept dynamic variables.

Note that if AutoGraph is enabled, Catalyst will attempt to convert Python control flow to its Catalyst equivalent to support dynamic variables.
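The sketch below illustrates this guideline using plain `jax.jit` as a stand-in for `@qjit`. It assumes only JAX and NumPy are installed, and the function names are hypothetical. A `jax.numpy` function accepts a tracer. The corresponding NumPy function needs a concrete value, so it raises an error during tracing:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def with_jnp(x):
    return jnp.sin(x)  # JAX function: recorded during tracing, evaluated at runtime

@jax.jit
def with_np(x):
    return np.sin(x)  # NumPy needs a concrete value, but `x` is a tracer here

print(with_jnp(0.0))  # 0.0

try:
    with_np(0.0)
    error_name = None
except jax.errors.TracerArrayConversionError as err:
    error_name = type(err).__name__

print(error_name)  # TracerArrayConversionError
```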

For example, consider the following:

```
>>> @qjit
... def f(x):
...     if x > 5:
...         x = x / 2
...     return x ** 2
>>> f(2.)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at <ipython-input-15-2aa7bf60efbb>:1 for make_jaxpr.
This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

This function will fail, as the Python `if` statement cannot accept a dynamic variable (a JAX tracer) as an argument.

Instead, we can use Catalyst control flow `cond()` here:

```
>>> @qjit
... def f(x):
...     @cond(x > 5.)
...     def g():
...         return x / 2
...     @g.otherwise
...     def h():
...         return x
...     return g() ** 2
>>> f(2.)
array(4.)
>>> f(6.)
array(9.)
```

Here, both conditional branches are compiled, and only evaluated at runtime when the value of `x` is known.

Note that if the Python `if` statement depends only on values that are static (known at compile time), this is fine; the `if` statement will simply be evaluated at compile time rather than runtime.
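Array shapes are one example: they are static even when array values are dynamic, so branching on a shape is fine. This is a sketch using plain `jax.jit` as a stand-in for `@qjit`, and the function `reduce_fn` is hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def reduce_fn(x):
    # x.shape is known at compile time, so this Python `if` is resolved
    # during tracing and only the selected branch is compiled.
    if x.shape[0] > 2:
        return jnp.sum(x)
    return jnp.prod(x)

print(reduce_fn(jnp.array([1.0, 2.0, 3.0])))  # 6.0 (sum branch)
print(reduce_fn(jnp.array([2.0, 4.0])))       # 8.0 (product branch)
```

Each new input shape is traced separately, and each trace picks its own branch.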

Let’s consider an example where a for loop is evaluated at compile time:

```
>>> @qjit
... def f(x):
...     for i in range(2):
...         print(i, x)
...         x = x / 2
...     return x ** 2
>>> f(2.)
0 Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
1 Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
array(0.25)
```

Here, the for loop is evaluated at compile time rather than at runtime. Notice that a tracer is printed once per loop iteration during program capture.
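One way to see this unrolling is to inspect the captured program. The sketch below uses plain JAX (`jax.make_jaxpr` and `jax.lax.fori_loop`), with `fori_loop` playing a role similar to Catalyst's `for_loop()`. The helper `count_divs` is hypothetical. The Python loop leaves one division per iteration in the captured program, while the `fori_loop` version records the loop as a single primitive with the division inside its body:

```python
import jax
from jax import lax

def unrolled(x):
    # Python loop: executed during tracing, so every iteration is recorded.
    for _ in range(3):
        x = x / 2
    return x ** 2

def rolled(x):
    # lax.fori_loop: the loop is recorded once, with the division in its body.
    x = lax.fori_loop(0, 3, lambda i, val: val / 2, x)
    return x ** 2

def count_divs(fn, x):
    # Count `div` operations at the top level of the captured program.
    closed_jaxpr = jax.make_jaxpr(fn)(x)
    return sum(eqn.primitive.name == "div" for eqn in closed_jaxpr.jaxpr.eqns)

print(count_divs(unrolled, 2.0))  # 3
print(count_divs(rolled, 2.0))    # 0
print(jax.jit(unrolled)(2.0), jax.jit(rolled)(2.0))  # 0.0625 0.0625
```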

Note

AutoGraph is an experimental feature that converts Python control flow that depends on dynamic variables to Catalyst control flow behind the scenes:

```
>>> @qjit(autograph=True)
... def f(x):
...     if x > 5.:
...         print(x)
...         x = x / 2
...     return x ** 2
>>> f(2.)
Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
array(4.)
>>> f(6.)
array(9.)
```

For more details, see the AutoGraph guide.

## Printing at runtime

In the previous section, we saw that the Python `print` statement will only be executed during tracing/compilation, and in particular, will not print out the value of dynamic variables (since their values are only known at *runtime*).

If we wish to print the value of variables at *runtime*, we can instead use the `catalyst.debug.print()` function:

```
>>> from catalyst import debug
>>> @qjit
... def g(x):
...     debug.print(x)
...     return x ** 2
>>> g(2.)
[2.]
array(4.)
```

## Avoiding recompilation

In Catalyst, a qjit-compiled function will usually be recompiled when it is called with different **argument types** or **shapes**.

For example, consider the following:

```
>>> @qjit
... def f(x, y):
...     print("Tracing occurring")
...     return x ** 2 + y
>>> f(0.4, 1)
Tracing occurring
array(1.16)
>>> f(0.2, 3)
array(3.04)
```

However, if we change the argument types in a way where Catalyst can’t perform automatic type promotion before passing the arguments to the compiled function (e.g., passing a float where an integer was previously used), recompilation will occur:

```
>>> f(0.15, 0.65)
Tracing occurring
array(0.6725)
```

On the other hand, passing an integer where a float is now expected will not cause recompilation, since the integer can be promoted to a float:

```
>>> f(2, 4.65)
array(8.65)
```

Similarly, changing the shape of an array will also trigger recompilation:

```
>>> f(jnp.array([0.2]), jnp.array([0.6]))
Tracing occurring
array([0.64])
>>> f(jnp.array([0.8]), jnp.array([1.6]))
array([2.24])
>>> f(jnp.array([0.8, 0.1]), jnp.array([1.6, -2.0]))
Tracing occurring
array([ 2.24, -1.99])
```
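The same caching behaviour can be observed with plain `jax.jit`, which Catalyst builds on for program capture. In this sketch, the module-level counter `trace_count` is a hypothetical helper. It is only incremented while tracing, because Python side effects do not run inside the compiled program:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return x ** 2 + y

f(jnp.array([0.2]), jnp.array([0.6]))
f(jnp.array([0.8]), jnp.array([1.6]))             # same shape and dtype: cached
f(jnp.array([0.8, 0.1]), jnp.array([1.6, -2.0]))  # new shape: traced again
print(trace_count)  # 2
```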

This is something to be aware of, especially when porting existing PennyLane code to work with Catalyst. For example, consider the following, where the size of the input argument determines the number of qubits and gates used:

```
import pennylane as qml
from catalyst import for_loop, qjit

dev = qml.device("lightning.qubit", wires=4)

@qjit
@qml.qnode(dev)
def circuit(x):
    print("Tracing occurring")

    def loop_fn(i):
        qml.RX(x[i], wires=i)

    for_loop(0, x.shape[0], 1)(loop_fn)()
    return qml.expval(qml.PauliZ(0))
```

This will run correctly, but tracing and recompilation will occur every time the function is called with an input of a different shape:

```
>>> circuit(jnp.array([0.1, 0.2]))
Tracing occurring
array(0.99500417)
>>> circuit(jnp.array([0.1, 0.2, 0.3]))
Tracing occurring
array(0.99500417)
```

To be explicitly warned about recompilation, you can use ahead-of-time (AOT) mode, by specifying types and shapes in the function signature directly:

```
>>> @qjit
... @qml.qnode(dev)
... def circuit(x: jax.core.ShapedArray((3,), dtype=np.float64)):
...     print("Tracing occurring")
...     def loop_fn(i):
...         qml.RX(x[i], wires=i)
...     for_loop(0, x.shape[0], 1)(loop_fn)()
...     return qml.expval(qml.PauliZ(0))
Tracing occurring
```

Note that compilation now happens at **function definition**. We can execute the compiled function as long as the arguments match the specified shape and type:

```
>>> circuit(jnp.array([0.1, 0.2, 0.3]))
array(0.99500417)
>>> circuit(jnp.array([1.4, 1.4, 0.3]))
array(0.16996714)
```

However, deviating from this will result in recompilation and a warning message:

```
>>> circuit(jnp.array([1.4, 1.4, 0.3, 0.1]))
catalyst/compilation_pipelines.py:592:
UserWarning: Provided arguments did not match declared signature, recompiling...
Tracing occurring
array(0.16996714)
```

## Try and compile the full workflow

When porting your PennyLane code to work with Catalyst and `@qjit`, the biggest performance advantage comes from compiling your *entire* workflow, not just the QNodes. So consider putting everything inside your JIT-compiled function, including for loops (such as optimization loops), gradient calls, etc.

Consider the following PennyLane example, where we have a parametrized circuit, are measuring an expectation value, and are optimizing the result:

```
dev = qml.device("default.qubit", wires=4)
@qml.qnode(dev)
def cost(weights, data):
    qml.AngleEmbedding(data, wires=range(4))

    for x in weights:
        # each trainable layer
        for i in range(4):
            # for each wire
            if x[i] > 0:
                qml.RX(x[i], wires=i)
            elif x[i] < 0:
                qml.RY(x[i], wires=i)

        for i in range(4):
            qml.CNOT(wires=[i, (i + 1) % 4])

    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

weights = jnp.array(2 * np.random.random([5, 4]) - 1)
data = jnp.array(np.random.random([4]))

opt = jaxopt.GradientDescent(cost, stepsize=0.4, jit=False)

params = weights
state = opt.init_state(params)

for i in range(200):
    # keep the updated optimizer state as well as the parameters
    params, state = opt.update(params, state, data)
```

Using PennyLane v0.32 on Google Colab with the Python 3 Google Compute Engine backend, this optimization takes 3min 28s ± 2.05s to complete.

Let’s switch over to Lightning, our high-performance statevector simulator, alongside the adjoint differentiation method. To do so, we change the first two lines of the above code block to set the device as `"lightning.qubit"`, and specify `diff_method="adjoint"` in the QNode decorator. With this change, we have reduced the execution time down to 30.7s ± 1.8s.

We can rewrite this QNode to use Catalyst control flow, and compile it using Catalyst:

```
dev = qml.device("lightning.qubit", wires=4)

@qjit
@qml.qnode(dev)
def cost(weights, data):
    qml.AngleEmbedding(data, wires=range(4))

    def layer_loop(i):
        x = weights[i]

        def wire_loop(j):
            @cond(x[j] > 0)
            def trainable_gate():
                qml.RX(x[j], wires=j)

            @trainable_gate.else_if(x[j] < 0)
            def negative_gate():
                qml.RY(x[j], wires=j)

            trainable_gate.otherwise(lambda: None)
            trainable_gate()

        def cnot_loop(j):
            qml.CNOT(wires=[j, jnp.mod((j + 1), 4)])

        for_loop(0, 4, 1)(wire_loop)()
        for_loop(0, 4, 1)(cnot_loop)()

    for_loop(0, jnp.shape(weights)[0], 1)(layer_loop)()
    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))

opt = jaxopt.GradientDescent(cost, stepsize=0.4)

params = weights
state = opt.init_state(params)

for i in range(200):
    params, state = opt.update(params, state, data)
```

With the quantum function qjit-compiled, the optimization loop now takes 16.4s ± 1.51s.

However, while the quantum function is now compiled, and the compiled function is called to compute cost and gradient values, the optimization loop is still running in Python.

Instead, we can write the optimization loop itself as a function and decorate it with `@qjit`. This will compile the optimization loop, and allow the full optimization to take place within Catalyst:

```
@qjit
def optimize(init_weights, data, steps):
    def loss(x):
        dy = grad(cost, argnum=0)(x, data)[0]
        return (cost(x, data), dy)

    opt = jaxopt.GradientDescent(loss, stepsize=0.4, value_and_grad=True)

    update_step = lambda i, *args: tuple(opt.update(*args))

    params = init_weights
    state = opt.init_state(params)

    return for_loop(0, steps, 1)(update_step)(params, state)[0]
```

The optimization now takes 574ms ± 43.1ms to complete when using 200 steps. Note that, to compute hybrid quantum-classical gradients within a qjit-compiled function, the `catalyst.grad()` function must be used.
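The same idea of compiling the optimization loop itself can be sketched in plain JAX. Below, a hypothetical classical `cost` stands in for the QNode, and `jax.lax.fori_loop` plays the role of Catalyst's `for_loop()`. Inside `@qjit` with a QNode, `catalyst.grad()` would take the place of `jax.grad`, as noted above. The entire gradient-descent loop is traced once and compiled into a single program:

```python
import jax
import jax.numpy as jnp
from jax import lax

def cost(w):
    # Hypothetical classical cost with its minimum at w = [1, 1, 1].
    return jnp.sum((w - 1.0) ** 2)

@jax.jit
def optimize(init_w):
    def step(i, w):
        # One gradient-descent update with a fixed step size of 0.1.
        return w - 0.1 * jax.grad(cost)(w)

    # The full 100-step loop is part of the compiled program.
    return lax.fori_loop(0, 100, step, init_w)

w_opt = optimize(jnp.zeros(3))
print(w_opt)  # approximately [1. 1. 1.]
```

Each step shrinks the distance to the minimum by a factor of 0.8, so after 100 steps the result agrees with `[1, 1, 1]` to within float32 precision.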

## JAX support and restrictions

Catalyst utilizes JAX for program capture, which means you are able to leverage the many functions accessible in `jax` and `jax.numpy` to write code that supports `@qjit` and dynamic variables.

Currently, we are aiming to support as many JAX functions as possible; however, there may be cases where there is missing coverage. Known JAX functionality that doesn’t work with Catalyst includes:

- `jax.numpy.polyfit`
- `jax.numpy.fft`
- `jax.debug`
- `jax.numpy.ndarray.at[index]` when `index` corresponds to all array indices.

If you come across any other JAX functions that don’t work with Catalyst (and don’t already have a Catalyst equivalent), please let us know by opening a GitHub issue.

While leveraging `jax.numpy` makes it easy to port over NumPy-based PennyLane workflows to Catalyst, we also inherit various restrictions and ‘gotchas’ from JAX. These include:

- **Pure functions**: Compilation is primarily designed to work only on pure functions, that is, functions that do not have any side effects and whose output depends only on the function inputs.

- **In-place array updates**: Rather than using in-place array updates, the syntax `new_array = jax_array.at[index].set(value)` should be used. For more details, see `jax.numpy.ndarray.at`.

- **Lack of stateful random number generators**: In JAX, random number generators are stateless, and the key state must be explicitly updated each time you want to compute a random number. For more details, see the JAX documentation.

- **Dynamic-shaped arrays:** Functions that create or return arrays with dynamic shape (that is, arrays whose shape is determined by a dynamic variable at runtime) are not supported by JAX, and have only experimental support in Catalyst (see *Dynamically-shaped arrays* below). Typically, workarounds involve rewriting the code to utilize `jnp.where` where possible.

For more details, please see the JAX documentation.
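The following sketch shows these patterns with plain `jax.jit`. The function names `update_first` and `clip_negatives` are hypothetical, and whether each pattern is accepted inside `@qjit` may depend on the Catalyst version:

```python
import jax
import jax.numpy as jnp

# In-place updates: `.at[...].set(...)` returns a new array; `x` is unchanged.
@jax.jit
def update_first(x):
    return x.at[0].set(10.0)

x = jnp.array([1.0, 2.0, 3.0])
y = update_first(x)
print(x)  # [1. 2. 3.]
print(y)  # [10.  2.  3.]

# Stateless random numbers: the same key always gives the same sample,
# so a fresh subkey must be split off for each new draw.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (2,))
same_sample = jax.random.normal(subkey, (2,))  # identical to `sample`

# Dynamic shapes: boolean masking like x[x > 0] has a value-dependent shape,
# so keep the shape fixed and replace unwanted entries instead.
@jax.jit
def clip_negatives(x):
    return jnp.where(x > 0, x, 0.0)

print(clip_negatives(jnp.array([-1.0, 2.0, -3.0])))  # [0. 2. 0.]
```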

## JAX integration

Compiled functions remain JAX compatible, and you can call JAX transformations on them, such as `jax.grad` and `jax.vmap`. You can even call `jax.jit` on functions that call qjit-compiled functions:

```
>>> dev = qml.device("lightning.qubit", wires=2)
>>> @qjit
... @qml.qnode(dev)
... def circuit(x):
...     qml.RX(x, wires=0)
...     return qml.expval(qml.PauliZ(0))
>>> @jax.jit
... def workflow(y):
...     return jax.grad(circuit)(jnp.sin(y))
>>> workflow(0.6)
Array(-0.53511382, dtype=float64, weak_type=True)
```

However, a `jax.jit` function calling a `qjit` function will always result in a callback to Python. It will therefore be slower than if the function was compiled purely with `jax.jit` or purely with `qjit`.

If you want to compile some functionality that is not currently Catalyst compatible, or you want to make use of JAX-supported hardware such as TPUs for classical processing, mixing `jax.jit` and `qjit` will allow this. However, if possible, try to always use `qjit` to compile your entire workflow.

## Internal QJIT transformations

Inside of a qjit-compiled function, JAX transformations (`jax.grad`, `jax.jacobian`, `jax.vmap`, etc.) can be used **as long as they are not applied to quantum processing**.

```
>>> @qjit
... def f(x):
...     def g(y):
...         return -jnp.sin(y) ** 2
...     return jax.grad(g)(x)
>>> f(0.4)
array(-0.71735609)
```

If they are applied to quantum processing, an error will occur:

```
>>> @qjit
... def f(x):
...     @qml.qnode(dev)
...     def g(y):
...         qml.RX(y, wires=0)
...         return qml.expval(qml.PauliX(0))
...     return jax.grad(lambda y: g(y) ** 2)(x)
>>> f(0.4)
NotImplementedError: must override
```

Instead, only Catalyst transformations will work when applied to hybrid quantum-classical processing:

```
>>> @qjit
... def f(x):
...     @qml.qnode(dev)
...     def g(y):
...         qml.RX(y, wires=0)
...         return qml.expval(qml.PauliZ(0))
...     return grad(lambda y: g(y) ** 2)(x)
>>> f(0.4)
array(-0.71735609)
```

Always use the equivalent Catalyst transformation (`catalyst.grad()`, `catalyst.jacobian()`, `catalyst.vjp()`, `catalyst.jvp()`) inside of a qjit-compiled function.

## Inspecting and drawing circuits

A useful tool for debugging quantum algorithms is the ability to draw them. Currently, `@qjit` compiled QNodes can be used as input to `qml.draw`, with the following caveats:

- The `qml.draw()` function will only accept plain QNodes as input, *or* QNodes that have been qjit-compiled. It will not accept arbitrary hybrid functions (that may contain QNodes).

- The `catalyst.measure()` function is not supported in drawn QNodes.

- Catalyst control flow functions, such as `cond()` and `for_loop()`, will be ‘unrolled’. That is, the drawn circuit will be a straight-line circuit, without any of the control flow represented explicitly.

For example,

```
@qjit
@qml.qnode(dev)
def circuit(x):
    def measurement_loop(i, y):
        qml.RX(y, wires=0)
        qml.RY(y ** 2, wires=1)
        qml.CNOT(wires=[0, 1])

        @cond(y < 0.5)
        def cond_gate():
            qml.CRX(y * jnp.exp(- y ** 2), wires=[0, 1])

        cond_gate()
        return y * 2

    for_loop(0, 3, step=1)(measurement_loop)(x)
    return qml.expval(qml.PauliZ(0))
```

```
>>> print(qml.draw(circuit)(0.3))
0: ──RX(0.30)─╭●─╭●─────────RX(0.60)─╭●──RX(1.20)─╭●─┤ <Z>
1: ──RY(0.09)─╰X─╰RX(0.27)──RY(0.36)─╰X──RY(1.44)─╰X─┤
```

At the moment, additional PennyLane circuit inspection functions are not supported with Catalyst.

## Conditional debugging

Note

See our AutoGraph guide for seamless conversion of native Python control flow to QJIT compatible control flow.

There are various constraints and restrictions that should be kept in mind when working with classical control in Catalyst.

The return values of all branches of `cond()` do not have to be the same type. Catalyst will perform automatic type promotion (for example, converting integers to floats) where possible.

```
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2  # float
...     @cond_fn.otherwise
...     def else_branch():
...         return 6  # int
...     return cond_fn()
>>> f(1.5)
array(6.)
```

There may be some cases where automatic type promotion cannot be applied. One example is omitting a return value in one branch (which in Python is equivalent to returning `None`) but not in the others. This will result in an error; if the other branches return values, an else branch must be specified.

```
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     return cond_fn()
TypeError: Conditional requires consistent return types across all branches, got:
- Branch at index 0: [ShapedArray(float64[], weak_type=True)]
- Branch at index 1: []
Please specify an else branch if none was specified.
```

Adding an else branch that returns a value resolves the error:

```
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     @cond_fn.otherwise
...     def else_branch():
...         return x
...     return cond_fn()
>>> f(1.6)
array(2.56)
```

Finally, a reminder that conditional functions provided to `cond()` cannot accept any arguments.

## Compatibility with PennyLane transforms

PennyLane provides a wide variety of transforms that convert a circuit to one or more circuits.

Currently, most PennyLane transforms will work with Catalyst as long as:

- The circuit does not include any Catalyst-specific features, such as Catalyst control flow or measurement,
- The QNode returns only lists of measurement processes,
- AutoGraph is disabled, and
- The transformation does not require or depend on the numeric value of dynamic variables.

This includes transforms that generate many circuits,

```
@qjit
@qml.transforms.split_non_commuting
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [qml.expval(qml.PauliY(0)), qml.expval(qml.PauliZ(0))]
```

```
>>> circuit(0.4)
[array(-0.38941834), array(0.92106099)]
```

as well as transforms that simply map the circuit to another:

```
@qjit
@qml.transforms.merge_rotations()
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    qml.RX(x ** 2, wires=0)
    return qml.expval(qml.PauliZ(0))
```

```
>>> circuit(0.5)
array(0.73168887)
```

We can inspect the jaxpr representation of the compiled program, to verify that only a single RX gate is being applied due to the rotation gate merger:

```
>>> circuit.jaxpr
{ lambda ; a:f64[]. let
b:f64[] = func[
call_jaxpr={ lambda ; c:f64[]. let
d:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] c
e:f64[] = integer_pow[y=2] c
f:f64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
g:f64[1] = add d f
h:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] g
i:f64[] = squeeze[dimensions=(0,)] h
= qdevice[
rtd_kwargs={'shots': 0, 'mcmc': False}
rtd_lib=/usr/local/lib/python3.10/dist-packages/catalyst/utils/../lib/librtd_lightning.so
rtd_name=LightningSimulator
]
j:AbstractQreg() = qalloc 2
k:AbstractQbit() = qextract j 0
l:AbstractQbit() = qinst[op=RX qubits_len=1] k i
m:AbstractObs(num_qubits=None,primitive=None) = namedobs[kind=PauliZ] l
n:f64[] = expval[shots=None] m
o:AbstractQreg() = qinsert j 0 l
= qdealloc o
in (n,) }
fn=<QNode: wires=2, device='lightning.qubit', interface='auto', diff_method='best'>
] a
in (b,) }
```

Note that PennyLane transforms currently **cannot** be applied when `autograph=True`.

## Function argument restrictions

Compiled functions can accept arbitrary function arguments, as long as the
inputs can be represented as Pytrees — tree-like
structures built out of Python container objects such as lists, dictionaries,
and tuples — where the *values* (leaf nodes) are compatible types.

Compatible types include Booleans, Python numeric types, JAX arrays, and PennyLane quantum operators.

Note

Non-numeric types, such as strings, are generally not supported as arguments to compiled functions.

For example, consider the following, where we pass arbitrarily nested lists or dictionaries as input to the compiled function:

```
>>> f = qjit(lambda *args: args)
>>> x = qml.RX(0.4, wires=0)
>>> y = {"apple": (True, jnp.array([0.1, 0.2, 0.3]))}
>>> f(x, y)
(RX(array(0.4), wires=[0]), {'apple': (array(True), array([0.1, 0.2, 0.3]))})
```

Arbitrary objects cannot be passed as function arguments, unless they are registered as Pytrees with compatible data types.

```
>>> class MyObject:
...     def __init__(self, x, name):
...         self.x = x
...         self.name = name
>>> obj = MyObject(jnp.array(0.4), "test")
>>> f(obj)
TypeError: Unsupported argument type: <class '__main__.MyObject'>
```

By registering it as a Pytree (that is, telling JAX which parts of the object are dynamic and which are static compile-time information), we make this object compatible with Catalyst:

```
>>> from jax.tree_util import register_pytree_node
>>> def flatten_fn(my_object):
...     data = (my_object.x,)  # dynamic variables
...     aux = {"name": my_object.name}  # static compile-time data
...     return (data, aux)
>>> def unflatten_fn(aux, data):
...     return MyObject(data[0], **aux)
>>> register_pytree_node(MyObject, flatten_fn, unflatten_fn)
>>> f(obj)
<__main__.MyObject at 0x7c061434b820>
```

Note that the function will only be re-compiled if the custom object’s static compile-time data changes (in this case, `MyObject.name`), and **not** if the dynamic part of the custom object (`MyObject.x`) changes:

```
>>> @qjit
... def f(my_object):
...     print(f"Compiling: name={my_object.name}")
...     return my_object.x
>>> f(MyObject(jnp.array(0.1), name="test1"))
Compiling: name=test1
array(0.1)
>>> f(MyObject(jnp.array(0.2), name="test1"))
array(0.2)
>>> f(MyObject(jnp.array(0.2), name="test2"))
Compiling: name=test2
array(0.2)
```
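For comparison, here is a similar sketch with plain `jax.jit` and `jax.tree_util.register_pytree_node`. The class, the flatten/unflatten functions, and the `trace_count` counter are hypothetical. One difference to keep in mind: `jax.jit` uses the static auxiliary data as part of its cache key, so that data should be hashable. For that reason, this sketch passes the name string directly rather than a dictionary:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class MyObject:
    def __init__(self, x, name):
        self.x = x        # dynamic data (a leaf of the pytree)
        self.name = name  # static compile-time data

def flatten_fn(obj):
    # Return (dynamic leaves, static auxiliary data).
    return (obj.x,), obj.name

def unflatten_fn(name, data):
    return MyObject(data[0], name)

register_pytree_node(MyObject, flatten_fn, unflatten_fn)

trace_count = 0

@jax.jit
def f(obj):
    global trace_count
    trace_count += 1  # only runs while tracing
    return obj.x * 2

f(MyObject(jnp.array(0.1), "test1"))
f(MyObject(jnp.array(0.2), "test1"))  # same static data: no retracing
f(MyObject(jnp.array(0.2), "test2"))  # new static data: retraced
print(trace_count)  # 2
```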

Note

JAX provides a `static_argnums` argument for the `jax.jit` function, which allows you to specify which arguments to the compiled function should be treated as static compile-time arguments. Changes to these arguments will trigger re-compilation.

The Catalyst `@qjit` decorator doesn’t yet support this functionality.
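For reference, this is how `static_argnums` is used with `jax.jit`. This is plain JAX only, since `@qjit` does not accept this argument. The hypothetical `power` function loops over a Python `range`, which only works because `n` is a concrete Python integer during tracing:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def power(x, n):
    # `n` is static: a plain Python int at trace time, so range(n) works.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

print(power(2.0, 3))  # 8.0 (compiled for n=3)
print(power(3.0, 2))  # 9.0 (new value of n: recompiled)
```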

## Dynamically-shaped arrays

Catalyst provides experimental support for compiling functions that accept or contain tensors whose dimensions are not known at compile time, without needing to recompile the function when tensor shapes change.

For example, one might consider a case where a dynamic variable specifies the shape of a tensor created within (or returned by) the compiled function:

```
>>> @qjit
... def func(size: int):
...     print("Compiling")
...     return jax.numpy.ones([size, size], dtype=float)
>>> func(3)
Compiling
array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])
>>> func(4)
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])
```

We can also pass tensors of variable shape directly as arguments to compiled functions. However, we need to provide the `abstracted_axes` argument to specify which axes of the tensors should be considered dynamic during compilation.

```
>>> @qjit(abstracted_axes={0: "n"})
... def sum_fn(x):
...     print("Compiling")
...     return jnp.sum(x)
>>> sum_fn(jnp.array([1., 0.5]))
Compiling
array(1.5)
>>> sum_fn(jnp.array([1., 0.5, 0.6]))
array(2.1)
```

Note that failure to specify this argument will cause re-compilation each time input tensor arguments change shape:

```
>>> @qjit
... def sum_fn(x):
...     print("Compiling")
...     return jnp.sum(x)
>>> sum_fn(jnp.array([1., 0.5]))
Compiling
array(1.5)
>>> sum_fn(jnp.array([1., 0.5, 0.6]))
Compiling
array(2.1)
```

For more details on using `abstracted_axes`, please see the `qjit()` documentation.

Note that indexing of dynamically-shaped arrays is not currently supported:

```
>>> @qjit(abstracted_axes={0: "n"})
... def almost_sum_fn(x):
...     print("Compiling")
...     return jnp.sum(x[0:-1])  # indexing into dynamic array x
>>> almost_sum_fn(jnp.array([1., 0.5]))
IndexError: Cannot use NumPy slice indexing on an array dimension whose
size is not statically known (Traced<ShapedArray(int64[], weak_type=True)>with<
DynamicJaxprTrace(level=1/0)>). Try using lax.dynamic_slice/dynamic_update_slice
```

Similarly, using dynamically-shaped arrays within for loops, while loops, and conditional statements is not currently supported:

```
>>> @qjit
... def f(size):
...     a = jnp.ones([size], dtype=float)
...     for i in range(10):
...         a = a
...     @for_loop(0, 10, 2)
...     def loop(_, a):
...         return a
...     return loop(a)
KeyError: 137774138140016
```

## Returning multiple measurements

A common pattern in PennyLane is to have multiple return statements within a single QNode, allowing the measurement type to change based on some condition:

```
dev = qml.device("default.qubit", wires=2, shots=10)

@qml.qnode(dev)
def circuit(x, sample=False):
    qml.RX(x, wires=0)
    if sample:
        return qml.sample(wires=0)
    return qml.expval(qml.PauliZ(0))
```

This pattern is currently not supported in Catalyst, and will lead to an error:

```
dev = qml.device("lightning.qubit", wires=2, shots=10)

@qjit
@qml.qnode(dev)
def circuit(x, sample=False):
    qml.RX(x, wires=0)

    @cond(sample)
    def measure_fn():
        return qml.sample(wires=0)

    @measure_fn.otherwise
    def expval():
        return qml.expval(qml.PauliZ(0))

    return measure_fn()
```

```
>>> circuit(3)
TypeError: Value sample(wires=[0]) with type <class 'pennylane.measurements.sample.SampleMP'> is not a valid JAX type
```

For now, it is recommended to create separate QNodes if different measurement statistics need to be returned, or alternatively to use a single return statement with multiple measurements:

```
>>> @qjit
... @qml.qnode(dev)
... def circuit(x):
...     qml.RX(x, wires=0)
...     return {"samples": qml.sample(), "expval": qml.expval(qml.PauliZ(0))}
>>> circuit(3)
{'expval': array(-0.9899925),
 'samples': array([[1., 0.],
                   [1., 0.],
                   [1., 0.],
                   [1., 0.],
                   [1., 0.],
                   [1., 0.],
                   [1., 0.],
                   [0., 0.],
                   [1., 0.],
                   [1., 0.]])}
```

## Recursion

Recursion is not currently supported, and will result in errors. For example,

```
@qjit(autograph=True)
def fibonacci(n: int):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
```

```
>>> fibonacci(10)
RecursionError: maximum recursion depth exceeded in comparison
```

This is because, during compilation, Catalyst tries to evaluate both branches of the conditional statement, including the recursive calls. Since `n` is a dynamic variable, it has no concrete value at compile time, so the base case is never selected during tracing and tracing can never complete.

Instead, try to write your program without recursion. For example, in this case we can use a while loop:

```
import catalyst
from catalyst import qjit

@qjit
def fibonacci(n):
    # loop-carried values: a = F(count), b = F(count + 1)
    @catalyst.while_loop(lambda count, a, b: count < n)
    def loop_fn(count, a, b):
        return count + 1, b, a + b

    _, result, _ = loop_fn(0, 0, 1)
    return result
```

```
>>> fibonacci(10)
array(55)
```
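For readers who want to experiment without Catalyst installed, the same iterative approach can be sketched in plain JAX with `jax.lax.while_loop`. The carry layout `(count, a, b)` is an illustrative choice:

```python
import jax
from jax import lax

@jax.jit
def fibonacci(n):
    # Carry (count, a, b) where a = F(count) and b = F(count + 1).
    def cond_fn(carry):
        count, _, _ = carry
        return count < n

    def body_fn(carry):
        count, a, b = carry
        return count + 1, b, a + b

    _, result, _ = lax.while_loop(cond_fn, body_fn, (0, 0, 1))
    return result

print(fibonacci(10))  # 55
```

Because `n` only appears in the loop condition, the same compiled program is reused for every integer input.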

## Compatibility with broadcasting

Catalyst does not currently support passing multi-dimensional arrays as quantum operator parameters (‘parameter broadcasting’):

```
>>> @qml.qnode(dev)
... def circuit(x):
...     qml.RX(x, wires=0)
...     qml.RY(0.1, wires=0)
...     return qml.expval(qml.PauliZ(0))
>>> circuit(jnp.array([0.1, 0.2]))
Array([0.99003329, 0.97517033], dtype=float64)
>>> qjit(circuit)(jnp.array([0.1, 0.2]))
UnboundLocalError: local variable 'baseType' referenced before assignment
```

While not as flexible as true vectorized quantum operations, `jax.vmap` can be used as a workaround to allow for multi-dimensional **function** arguments:

```
>>> jax.vmap(qjit(circuit))(jnp.array([0.1, 0.2]))
Array([0.99003329, 0.97517033], dtype=float64)
```

Note that `jax.vmap` cannot be used within a qjit-compiled function:

```
>>> qjit(jax.vmap(circuit))(jnp.array([0.1, 0.2]))
NotImplementedError: Batching rule for 'qinst' not implemented
```

## Functionality differences from PennyLane

The ultimate aim with Catalyst will be the ability to prototype quantum algorithms in Python with PennyLane, and easily scale up prototypes by simply adding `@qjit`. This will require that all PennyLane functionality behaves identically whether or not the `@qjit` decorator is applied.

Currently, however, this is not the case for measurements.

**Measurement behaviour**. `catalyst.measure()` currently behaves differently from its PennyLane counterpart `pennylane.measure()`. In particular:

- Final measurement statistics occurring after `pennylane.measure()` will average over all potential measurements, weighted by their likelihood.
- Final measurement statistics occurring after `catalyst.measure()` will be post-selected on the outcome that was measured. The post-selected measurement will change with every execution.