catalyst.qjit
- qjit(fn=None, *, autograph=False, autograph_include=(), async_qnodes=False, target='binary', keep_intermediate=False, verbose=False, logfile=None, pipelines=None, static_argnums=None, static_argnames=None, abstracted_axes=None, disable_assertions=False, seed=None, experimental_capture=False, circuit_transform_pipeline=None, pass_plugins=None, dialect_plugins=None)
A just-in-time decorator for PennyLane and JAX programs using Catalyst.
This decorator enables both just-in-time and ahead-of-time compilation, depending on whether function argument type hints are provided.
Note

Not all PennyLane devices currently work with Catalyst. Supported backend devices include lightning.qubit, lightning.kokkos, lightning.gpu, and braket.aws.qubit. For a full list of supported devices, please see Supported devices.

- Parameters
fn (Callable) – the quantum or classical function
autograph (bool) – Experimental support for automatically converting Python control flow statements to Catalyst-compatible control flow. Currently supports Python if, elif, else, and for statements. Note that this feature requires an available TensorFlow installation. For more details, see the AutoGraph guide.
autograph_include – A list of (sub)modules to be allow-listed for autograph conversion.
async_qnodes (bool) – Experimental support for automatically executing QNodes asynchronously, if supported by the device runtime.
target (str) – the compilation target
keep_intermediate (bool) – Whether or not to store the intermediate files throughout the compilation. If True, intermediate representations are available via the mlir, jaxpr, and qir attributes, representing different stages in the optimization process.
verbose (bool) – If True, the tools and flags used by Catalyst behind the scenes are printed out.
logfile (Optional[TextIOWrapper]) – File object to write verbose messages to (default: sys.stderr).
pipelines (Optional[List[Tuple[str, List[str]]]]) – A list of pipelines to be executed. The elements of this list are named sequences of MLIR passes to be executed. A None value (the default) results in the execution of the default pipeline. This option is intended for advanced users, for low-level debugging purposes.
static_argnums (int or Sequence[int]) – an index or a sequence of indices that specifies the positions of static arguments.
static_argnames (str or Sequence[str]) – a string or a sequence of strings that specifies the names of static arguments.
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]) – An experimental option to specify dynamic tensor shapes. This option affects the compilation of the annotated function. Function arguments with abstracted_axes specified will be compiled to ranked tensors with dynamic shapes. For more details, please see the Dynamically-shaped arrays section below.
disable_assertions (bool) – If set to True, runtime assertions included in fn via debug_assert() will be disabled during compilation.
seed (Optional[int]) – The seed for circuit readout results when the qjit-compiled function is executed on simulator devices including lightning.qubit, lightning.kokkos, and lightning.gpu. The default value is None, which means no seeding is performed, and all processes are random. A seed is expected to be an unsigned 32-bit integer. Currently, the following measurement processes are seeded: measure(), qml.sample(), qml.counts(), qml.probs(), qml.expval(), qml.var().
experimental_capture (bool) – If set to True, the qjit decorator will use PennyLane’s experimental program capture capabilities to capture the decorated function for compilation.
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]) – A dictionary that specifies the quantum circuit transformation pass pipeline order, and optionally arguments for each pass in the pipeline. Keys of this dictionary should correspond to names of passes found in the catalyst.passes module, values should either be empty dictionaries (for default pass options) or dictionaries of valid keyword arguments and values for the specific pass. The order of keys in this dictionary will determine the pass pipeline. If not specified, the default pass pipeline will be applied.
pass_plugins (Optional[List[Path]]) – List of paths to pass plugins.
dialect_plugins (Optional[List[Path]]) – List of paths to dialect plugins.
- Returns
QJIT object.
- Raises
FileExistsError – Unable to create temporary directory
PermissionError – Problems creating temporary directory
OSError – Problems while creating folder for intermediate files
AutoGraphError – Raised if there was an issue converting the given function(s).
ImportError – Raised if AutoGraph is turned on and TensorFlow could not be found.
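To make the circuit_transform_pipeline convention concrete, here is a plain-Python sketch that does not use Catalyst. It shows how such a dictionary reads as an ordered pipeline. Keys give the pass order, because Python dictionaries preserve insertion order, and an empty value means "use the pass defaults". The helper describe_pipeline is hypothetical. The names cancel_inverses and merge_rotations are used only as example strings, and my_custom_pass with its option is made up for illustration.

```python
from typing import Dict, List, Tuple

# Same shape as the documented option: pass name -> keyword options for that pass.
PipelineSpec = Dict[str, Dict[str, str]]


def describe_pipeline(spec: PipelineSpec) -> List[Tuple[str, Dict[str, str]]]:
    """Hypothetical helper: list (pass name, options) in the order the passes would run.

    Dictionaries preserve insertion order (Python 3.7+), so the key order written
    by the user is the pipeline order; an empty options dict means pass defaults.
    """
    return [(name, dict(options)) for name, options in spec.items()]


# "my_custom_pass" and its "mode" option are purely hypothetical.
spec: PipelineSpec = {
    "cancel_inverses": {},
    "my_custom_pass": {"mode": "aggressive"},
    "merge_rotations": {},
}

for step, (name, options) in enumerate(describe_pipeline(spec), start=1):
    print(step, name, options or "(default options)")
```

Swapping two keys in the dictionary swaps the order of the corresponding passes. Only real pass names should be used when passing such a dictionary to qjit.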
Example
In just-in-time (JIT) mode, compilation is triggered at the call site the first time the quantum function is executed. For example, circuit is compiled when it is first called:

    import pennylane as qml
    from catalyst import qjit

    @qjit
    @qml.qnode(qml.device("lightning.qubit", wires=2))
    def circuit(theta):
        qml.Hadamard(wires=0)
        qml.RX(theta, wires=1)
        qml.CNOT(wires=[0, 1])
        return qml.expval(qml.PauliZ(wires=1))

>>> circuit(0.5)  # the first call, compilation occurs here
Array(0., dtype=float64)
>>> circuit(0.5)  # the precompiled quantum function is called
Array(0., dtype=float64)
Alternatively, if argument type hints are provided, compilation can occur ‘ahead of time’ when the function is decorated.
    import jax
    import jax.numpy as jnp
    import pennylane as qml
    from catalyst import qjit

    dev = qml.device("lightning.qubit", wires=2)

    @qjit
    @qml.qnode(dev)
    def circuit(x: complex, z: jax.ShapeDtypeStruct((3,), jnp.float64)):
        theta = jnp.abs(x)
        qml.RY(theta, wires=0)
        qml.Rot(z[0], z[1], z[2], wires=0)
        return qml.state()

>>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9]))  # calls precompiled function
Array([0.75634905-0.52801002j, 0.        +0.j        ,
       0.35962678+0.14074839j, 0.        +0.j        ], dtype=complex128)
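The choice between the two modes can be modelled in a few lines of plain Python. This sketch is a simplified toy, not Catalyst's implementation. The hypothetical class ToyQJIT "compiles" at decoration time only when every argument carries a type hint and no argument is static, which matches the note on static_argnums further down. Otherwise it defers compilation to the first call.

```python
import inspect
from typing import Callable, Sequence


class ToyQJIT:
    """Hypothetical stand-in for a qjit-style wrapper; it only records 'compilation'."""

    def __init__(self, fn: Callable, static_argnums: Sequence[int] = ()):
        self.fn = fn
        self.static_argnums = tuple(static_argnums)
        self.compiled = False
        params = inspect.signature(fn).parameters.values()
        fully_annotated = all(p.annotation is not inspect.Parameter.empty for p in params)
        # Ahead of time: every argument has a type hint and none is static,
        # so everything needed to compile is known at decoration time.
        if fully_annotated and not self.static_argnums:
            self._compile()

    def _compile(self):
        self.compiled = True  # a real compiler would trace and lower fn here

    def __call__(self, *args):
        if not self.compiled:  # just in time: compile on the first call
            self._compile()
        return self.fn(*args)


def annotated(x: float, y: float):
    return x * y


def unannotated(x, y):
    return x * y


aot = ToyQJIT(annotated)
jit = ToyQJIT(unannotated)
static = ToyQJIT(annotated, static_argnums=(1,))
print(aot.compiled, jit.compiled, static.compiled)  # True False False
jit(2.0, 3.0)
print(jit.compiled)  # True
```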
For more details on compilation and debugging, please see Sharp bits and debugging tips.
AutoGraph and Python control flow
Catalyst also supports capturing imperative Python control flow in compiled programs. You can enable this feature via the autograph=True parameter. Note that it does come with some restrictions, in particular whenever global state is involved. Refer to the AutoGraph guide for a complete discussion of the supported and unsupported use cases.

    @qjit(autograph=True)
    @qml.qnode(qml.device("lightning.qubit", wires=2))
    def circuit(x: int):
        if x < 5:
            qml.Hadamard(wires=0)
        else:
            qml.T(wires=0)
        return qml.expval(qml.PauliZ(0))
>>> circuit(3)
Array(0., dtype=float64)
>>> circuit(5)
Array(1., dtype=float64)
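Conceptually, AutoGraph makes the branch in the circuit above compilable by rewriting the Python if/else into a functional conditional. That conditional captures both branches and defers the choice to run time. A plain Python if would instead need a concrete value while tracing. This stdlib-only toy illustrates the idea. The Traced class and the cond function are hypothetical stand-ins, not Catalyst or JAX APIs. Catalyst's own functional form is catalyst.cond.

```python
class TracerBoolError(TypeError):
    """Raised when Python control flow asks for the truth value of a traced value."""


class Traced:
    """Hypothetical placeholder for a value whose content is unknown at trace time."""

    def __init__(self, evaluate):
        # evaluate: maps the concrete runtime input to this node's value
        self.evaluate = evaluate

    def __lt__(self, other):
        # Comparisons build a new symbolic node instead of producing a bool.
        return Traced(lambda env: self.evaluate(env) < other)

    def __bool__(self):
        raise TracerBoolError("cannot branch on a traced value with a plain Python `if`")


def cond(pred, true_fn, false_fn):
    """Record both branches; pick one only when the 'compiled' program runs."""
    return lambda env: true_fn() if pred.evaluate(env) else false_fn()


x = Traced(lambda env: env)  # the symbolic input argument

# Without AutoGraph: a plain `if` forces bool(x < 5) during tracing and fails.
try:
    if x < 5:
        pass
    plain_if_failed = False
except TracerBoolError:
    plain_if_failed = True

# AutoGraph-style rewrite: both branches are captured, the choice is deferred.
program = cond(x < 5, lambda: "Hadamard", lambda: "T")
print(plain_if_failed)  # True
print(program(3), program(5))  # Hadamard T
```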
Note that imperative control flow will still work in Catalyst even when the AutoGraph feature is turned off; it just won’t be captured in the compiled program and cannot involve traced values. The circuit example above would then raise a tracing error, as there is no value for x yet that can be compared in the if statement. A loop like for i in range(5) would be unrolled during tracing, “copy-pasting” the body 5 times into the program rather than appearing as a loop.

In-place JAX array updates with Autograph
To update array values when using JAX, the JAX syntax for array modification (which uses methods like at, set, multiply, etc.) must be used:

    import jax.numpy as jnp
    from catalyst import qjit

    @qjit(autograph=True)
    def f(x):
        first_dim = x.shape[0]
        result = jnp.empty((first_dim,), dtype=x.dtype)
        for i in range(first_dim):
            result = result.at[i].set(x[i])
            result = result.at[i].multiply(10)
            result = result.at[i].add(5)
        return result
However, if updating a single index or slice of the array, AutoGraph supports conversion of Python’s standard arithmetic array assignment operators to the equivalent in-place expressions listed in the JAX documentation for jax.numpy.ndarray.at:

    @qjit(autograph=True)
    def f(x):
        first_dim = x.shape[0]
        result = jnp.empty((first_dim,), dtype=x.dtype)
        for i in range(first_dim):
            result[i] = x[i]
            result[i] *= 10
            result[i] += 5
        return result
Under the hood, Catalyst converts code written in the latter notation into the former.
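The equivalence between the two notations can be illustrated without JAX. In this stdlib-only sketch, the hypothetical ToyArray class only mimics the shape of JAX's .at[i].set/.multiply/.add interface; it is not JAX. Each update returns a new array, which is why the functional form must rebind result. The mapping shown is the rewrite described above: result[i] *= 10 becomes result = result.at[i].multiply(10), and so on.

```python
from typing import List


class ToyArray:
    """Hypothetical immutable array mimicking the shape of JAX's `.at[...]` updates."""

    def __init__(self, data: List[float]):
        self._data = tuple(data)

    @property
    def at(self):
        return _AtIndexer(self)

    def __getitem__(self, i):
        return self._data[i]

    def tolist(self):
        return list(self._data)


class _AtIndexer:
    def __init__(self, arr):
        self._arr = arr

    def __getitem__(self, i):
        return _AtUpdate(self._arr, i)


class _AtUpdate:
    """Each method returns a NEW ToyArray; the original is never modified."""

    def __init__(self, arr, i):
        self._arr, self._i = arr, i

    def _replace(self, value):
        data = self._arr.tolist()  # fresh copy of the elements
        data[self._i] = value
        return ToyArray(data)

    def set(self, v):
        return self._replace(v)

    def add(self, v):
        return self._replace(self._arr[self._i] + v)

    def multiply(self, v):
        return self._replace(self._arr[self._i] * v)


x = [1.0, 2.0, 3.0]

# Functional form, as required by JAX: every update rebinds `result`.
result = ToyArray([0.0] * len(x))
for i in range(len(x)):
    result = result.at[i].set(x[i])
    result = result.at[i].multiply(10)
    result = result.at[i].add(5)

# Assignment-operator form on a mutable list, for comparison.
expected = [0.0] * len(x)
for i in range(len(x)):
    expected[i] = x[i]
    expected[i] *= 10
    expected[i] += 5

print(result.tolist())  # [15.0, 25.0, 35.0]
assert result.tolist() == expected
```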
The list of supported operators includes: =, +=, -=, *=, /=, and **=.

Static arguments
static_argnums defines which positional arguments should be treated as static. If it is an integer, the argument at that index is static. If it is an iterable of integers, the arguments whose indices are contained in the iterable are static. Passing a new value for a static argument triggers re-compilation.

static_argnames defines which named function arguments should be treated as static.
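Here is a simplified mental model of this behaviour, written in plain Python rather than with Catalyst. The hypothetical StaticArgCache wrapper keeps one "compiled" version per distinct combination of static-argument hashes and dynamic-argument types. It counts how often it would have to compile. It is only a model of the documented semantics, not Catalyst's actual cache. Keying on the hash is also why a mutable static argument needs a __hash__ that tracks its attributes.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Sequence, Tuple


class StaticArgCache:
    """Hypothetical model: one 'compiled' version per distinct set of static values."""

    def __init__(self, fn: Callable, static_argnums: Sequence[int]):
        self.fn = fn
        self.static_argnums = tuple(static_argnums)
        self.versions: Dict[Tuple, Callable] = {}
        self.compilations = 0

    def __call__(self, *args):
        # Key on the hash of static values plus the *types* of dynamic arguments;
        # changing a dynamic value alone does not trigger re-compilation.
        key = tuple(
            ("static", hash(a)) if i in self.static_argnums else ("dynamic", type(a))
            for i, a in enumerate(args)
        )
        if key not in self.versions:
            self.compilations += 1  # stand-in for tracing and compiling fn
            self.versions[key] = self.fn
        return self.versions[key](*args)


@dataclass
class MyClass:
    val: int

    def __hash__(self):
        return hash(str(self))  # changes whenever `val` changes


f = StaticArgCache(lambda x, y: x + y.val, static_argnums=(1,))
f(1, MyClass(5))  # compiles (1)
f(1, MyClass(6))  # new static value: compiles (2)
f(2, MyClass(5))  # same static value as the first call: reused
obj = MyClass(5)
f(3, obj)  # reused
obj.val = 7
f(3, obj)  # the mutation changes the hash: compiles (3)
print(f.compilations)  # 3
```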
A valid static argument must be hashable, and its __hash__ method must reflect any changes to its attributes.

    from dataclasses import dataclass

    @dataclass
    class MyClass:
        val: int

        def __hash__(self):
            return hash(str(self))

    @qjit(static_argnums=1)
    def f(
        x: int,
        y: MyClass,
    ):
        return x + y.val

    f(1, MyClass(5))
    f(1, MyClass(6))  # re-compilation
    f(2, MyClass(5))  # no re-compilation
In the example above, y is static. Note that the second function call triggers re-compilation, since the static input differs from the previous one. However, the third function call directly reuses the previously compiled version and does not trigger re-compilation.

    @dataclass
    class MyClass:
        val: int

        def __hash__(self):
            return hash(str(self))

    @qjit(static_argnums=(1, 2))
    def f(
        x: int,
        y: MyClass,
        z: MyClass,
    ):
        return x + y.val + z.val

    my_obj_1 = MyClass(5)
    my_obj_2 = MyClass(6)
    f(1, my_obj_1, my_obj_2)
    my_obj_1.val = 7
    f(1, my_obj_1, my_obj_2)  # re-compilation
In the example above, y and z are static. The second function call causes f to be re-compiled because my_obj_1 has been mutated. This requires that the mutation is properly reflected in the hash value.

Note that even when static_argnums is used in conjunction with type hints, ahead-of-time compilation is not possible, since the static argument values are not yet available. Instead, compilation will be just-in-time.

Dynamically-shaped arrays
There are three ways to use abstracted_axes: by passing a sequence of tuples, a dictionary, or a sequence of dictionaries. Passing a sequence of tuples:

    abstracted_axes=((), ('n',), ('m', 'n'))
Each tuple in the sequence corresponds to one of the arguments in the annotated function. Empty tuples can be used and correspond to parameters with statically known shapes. Non-empty tuples correspond to parameters with dynamically known shapes.
In the example above:

- the first argument will have a statically known shape,
- the second argument will have its zeroth axis with dynamic shape n, and
- the third argument will have its zeroth axis with dynamic shape m and its first axis with dynamic shape n.
Passing a dictionary:

    abstracted_axes={0: 'n'}

This approach allows a concise expression of the relationships between axes for different function arguments. In this example, it specifies that for all function arguments, the zeroth axis will have dynamic shape n.

Passing a sequence of dictionaries:
    abstracted_axes=({}, {0: 'n'}, {0: 'm', 1: 'n'})

The example here is a more verbose version of the tuple example. This convention allows axes to be omitted from the list of abstracted axes; for example, {1: 'm'} abstracts only the first axis of an argument.
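The three forms can be compared by rewriting them into a common shape. This stdlib-only sketch uses a hypothetical helper, normalize_abstracted_axes, which is not part of Catalyst. It follows the semantics described above and turns each form into one axis-index-to-name dictionary per argument. Under that reading, the tuple form ((), ('n',), ('m', 'n')) corresponds to the dictionaries ({}, {0: 'n'}, {0: 'm', 1: 'n'}).

```python
from typing import Dict, List, Sequence, Union

AxesSpec = Union[Sequence[Sequence[str]], Dict[int, str], Sequence[Dict[int, str]]]


def normalize_abstracted_axes(spec: AxesSpec, num_args: int) -> List[Dict[int, str]]:
    """Hypothetical helper: express any of the three forms as one dict per argument.

    Each dict maps an axis index to the name of its dynamic size; axes that are
    absent keep a statically known size.
    """
    if isinstance(spec, dict):
        # A single dictionary applies to every function argument.
        return [dict(spec) for _ in range(num_args)]
    per_arg = list(spec)
    if len(per_arg) != num_args:
        raise ValueError(f"expected {num_args} entries, got {len(per_arg)}")
    normalized = []
    for entry in per_arg:
        if isinstance(entry, dict):
            normalized.append(dict(entry))
        else:
            # Tuple form: the position inside the tuple is the axis index.
            normalized.append({axis: name for axis, name in enumerate(entry)})
    return normalized


tuple_form = ((), ("n",), ("m", "n"))
dicts_form = ({}, {0: "n"}, {0: "m", 1: "n"})
print(normalize_abstracted_axes(tuple_form, 3))  # [{}, {0: 'n'}, {0: 'm', 1: 'n'}]
print(normalize_abstracted_axes(tuple_form, 3) == normalize_abstracted_axes(dicts_form, 3))  # True
print(normalize_abstracted_axes({0: "n"}, 3))  # [{0: 'n'}, {0: 'n'}, {0: 'n'}]
```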
Using abstracted_axes can help avoid the cost of recompilation. By using abstracted_axes, a more general version of the compiled function will be generated. This more general version is parametrized over the abstracted axes and allows results to be computed over tensors independently of their axis lengths.

For example:
    import jax.numpy as jnp
    from catalyst import qjit

    @qjit
    def sum(arr):
        return jnp.sum(arr)

    sum(jnp.array([1]))     # Compilation happens here.
    sum(jnp.array([1, 1]))  # And here!
The sum function would recompile each time an array of a different size is passed as an argument.

    @qjit(abstracted_axes={0: "n"})
    def sum_abstracted(arr):
        return jnp.sum(arr)

    sum_abstracted(jnp.array([1]))     # Compilation happens here.
    sum_abstracted(jnp.array([1, 1]))  # No need to recompile.
The sum_abstracted function would only compile once, and its compiled definition would be reused for subsequent function calls.
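The difference in compilation counts can be modelled in plain Python. The hypothetical ShapeKeyedCache is a simplified model, not Catalyst's implementation. It keys compiled versions on argument shapes and replaces each abstracted axis with its symbolic name. Arrays that differ only in the length of an abstracted axis therefore reuse one version. Python's builtin sum stands in for jnp.sum, and 1-D lists stand in for arrays.

```python
from typing import Callable, Dict, Optional, Sequence, Tuple


class ShapeKeyedCache:
    """Hypothetical model of compile caching keyed on argument shapes.

    Axes listed in `abstracted` are replaced by their symbolic name in the key,
    so arrays that differ only along those axes share one compiled version.
    """

    def __init__(self, fn: Callable, abstracted: Optional[Dict[int, str]] = None):
        self.fn = fn
        self.abstracted = abstracted or {}
        self.cache: Dict[Tuple, Callable] = {}
        self.compilations = 0

    def key(self, shape: Sequence[int]) -> Tuple:
        return tuple(self.abstracted.get(axis, size) for axis, size in enumerate(shape))

    def __call__(self, arr: list):
        shape = (len(arr),)  # 1-D lists only, to keep the toy small
        k = self.key(shape)
        if k not in self.cache:
            self.compilations += 1  # stand-in for tracing and compiling
            self.cache[k] = self.fn
        return self.cache[k](arr)


total = ShapeKeyedCache(sum)  # builtin sum stands in for jnp.sum
total_abstracted = ShapeKeyedCache(sum, abstracted={0: "n"})
for arr in ([1], [1, 1], [1, 1, 1]):
    total(arr)
    total_abstracted(arr)
print(total.compilations, total_abstracted.compilations)  # 3 1
```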