# Source code for catalyst.jit (rendered from the Catalyst documentation)

# Copyright 2022-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains classes and decorators for just-in-time and ahead-of-time
compilation of hybrid quantum-classical functions using Catalyst.
"""

import copy
import functools
import inspect
import logging
import os
import warnings

import jax
import jax.numpy as jnp
import pennylane as qml
from jax.interpreters import mlir
from jax.tree_util import tree_flatten, tree_unflatten
from malt.core import config as ag_config

import catalyst
from catalyst.autograph import ag_primitives, run_autograph
from catalyst.compiled_functions import CompilationCache, CompiledFunction
from catalyst.compiler import CompileOptions, Compiler
from catalyst.debug.instruments import instrument
from catalyst.from_plxpr import trace_from_pennylane
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import PipelineNameUniquer, _inject_transform_named_sequence
from catalyst.qfunc import QFunc
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
    filter_static_args,
    get_abstract_signature,
    get_type_annotations,
    merge_static_argname_into_argnum,
    merge_static_args,
    promote_arguments,
    verify_static_argnums,
)
from catalyst.utils.c_template import mlir_type_to_numpy_type
from catalyst.utils.callables import CatalystCallable
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import WorkspaceManager
from catalyst.utils.gen_mlir import inject_functions
from catalyst.utils.patching import Patcher

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))

# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)


## API ##
@debug_logger
def qjit(
    fn=None,
    *,
    autograph=False,
    autograph_include=(),
    async_qnodes=False,
    target="binary",
    keep_intermediate=False,
    verbose=False,
    logfile=None,
    pipelines=None,
    static_argnums=None,
    static_argnames=None,
    abstracted_axes=None,
    disable_assertions=False,
    seed=None,
    experimental_capture=False,
    circuit_transform_pipeline=None,
):  # pylint: disable=too-many-arguments,unused-argument
    """A just-in-time decorator for PennyLane and JAX programs using Catalyst.

    This decorator enables both just-in-time and ahead-of-time compilation,
    depending on whether function argument type hints are provided.

    .. note::

        Not all PennyLane devices currently work with Catalyst. Supported backend devices
        include ``lightning.qubit``, ``lightning.kokkos``, ``lightning.gpu``, and
        ``braket.aws.qubit``. For a full list of supported devices, please see
        :doc:`/dev/devices`.

    Args:
        fn (Callable): the quantum or classical function
        autograph (bool): Experimental support for automatically converting Python control
            flow statements to Catalyst-compatible control flow. Currently supports Python
            ``if``, ``elif``, ``else``, and ``for`` statements. Note that this feature requires
            an available TensorFlow installation. For more details, see the
            :doc:`AutoGraph guide </dev/autograph>`.
        autograph_include: A list of (sub)modules to be allow-listed for autograph conversion.
        async_qnodes (bool): Experimental support for automatically executing
            QNodes asynchronously, if supported by the device runtime.
        target (str): the compilation target
        keep_intermediate (bool): Whether or not to store the intermediate files throughout the
            compilation. If ``True``, intermediate representations are available via the
            :attr:`~.QJIT.mlir`, :attr:`~.QJIT.jaxpr`, and :attr:`~.QJIT.qir`, representing
            different stages in the optimization process.
        verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes
            are printed out.
        logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default -
            ``sys.stderr``).
        pipelines (Optional(List[Tuple[str,List[str]]])): A list of pipelines to be executed. The
            elements of this list are named sequences of MLIR passes to be executed. A ``None``
            value (the default) results in the execution of the default pipeline. This option is
            considered to be used by advanced users for low-level debugging purposes.
        static_argnums (int or Sequence[Int]): an index or a sequence of indices that specifies the
            positions of static arguments.
        static_argnames (str or Sequence[str]): a string or a sequence of strings that specifies
            the names of static arguments.
        abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
            An experimental option to specify dynamic tensor shapes. This option affects the
            compilation of the annotated function. Function arguments with ``abstracted_axes``
            specified will be compiled to ranked tensors with dynamic shapes. For more details,
            please see the Dynamically-shaped Arrays section below.
        disable_assertions (bool): If set to ``True``, runtime assertions included in ``fn``
            via :func:`~.debug_assert` will be disabled during compilation.
        seed (Optional[Int]): The seed for circuit readout results when the qjit-compiled function
            is executed on simulator devices including ``lightning.qubit``, ``lightning.kokkos``,
            and ``lightning.gpu``. The default value is None, which means no seeding is performed,
            and all processes are random. A seed is expected to be an unsigned 32-bit integer.
            Currently, the following measurement processes are seeded: :func:`~.measure`,
            :func:`qml.sample() <pennylane.sample>`, :func:`qml.counts() <pennylane.counts>`.
        experimental_capture (bool): If set to ``True``, the qjit decorator will use PennyLane's
            experimental program capture capabilities to capture the decorated function for
            compilation.
        circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]): A dictionary that
            specifies the quantum circuit transformation pass pipeline order, and optionally
            arguments for each pass in the pipeline. Keys of this dictionary should correspond
            to names of passes found in the `catalyst.passes
            <https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes>`_
            module, values should either be empty dictionaries (for default pass options) or
            dictionaries of valid keyword arguments and values for the specific pass. The order
            of keys in this dictionary will determine the pass pipeline. If not specified, the
            default pass pipeline will be applied.

    Returns:
        QJIT object.

    Raises:
        FileExistsError: Unable to create temporary directory
        PermissionError: Problems creating temporary directory
        OSError: Problems while creating folder for intermediate files
        AutoGraphError: Raised if there was an issue converting the given function(s).
        ImportError: Raised if AutoGraph is turned on and TensorFlow could not be found.

    **Example**

    In just-in-time (JIT) mode, the compilation is triggered at the call site the first time
    the quantum function is executed. For example, ``circuit`` is compiled as early as the
    first call.

    .. code-block:: python

        @qjit
        @qml.qnode(qml.device("lightning.qubit", wires=2))
        def circuit(theta):
            qml.Hadamard(wires=0)
            qml.RX(theta, wires=1)
            qml.CNOT(wires=[0,1])
            return qml.expval(qml.PauliZ(wires=1))

    >>> circuit(0.5)  # the first call, compilation occurs here
    Array(0., dtype=float64)
    >>> circuit(0.5)  # the precompiled quantum function is called
    Array(0., dtype=float64)

    Alternatively, if argument type hints are provided, compilation can occur 'ahead of time'
    when the function is decorated.

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=2)

        @qjit
        @qml.qnode(dev)
        def circuit(x: complex, z: jax.ShapeDtypeStruct((3,), jnp.float64)):
            theta = jnp.abs(x)
            qml.RY(theta, wires=0)
            qml.Rot(z[0], z[1], z[2], wires=0)
            return qml.state()

    >>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9]))  # calls precompiled function
    Array([0.75634905-0.52801002j, 0.        +0.j        ,
           0.35962678+0.14074839j, 0.        +0.j        ], dtype=complex128)

    For more details on compilation and debugging, please see :doc:`/dev/sharp_bits`.

    .. details::
        :title: AutoGraph and Python control flow

        Catalyst also supports capturing imperative Python control flow in compiled programs.
        You can enable this feature via the ``autograph=True`` parameter. Note that it does
        come with some restrictions, in particular whenever global state is involved. Refer
        to the :doc:`AutoGraph guide </dev/autograph>` for a complete discussion of the
        supported and unsupported use-cases.

        .. code-block:: python

            @qjit(autograph=True)
            @qml.qnode(qml.device("lightning.qubit", wires=2))
            def circuit(x: int):

                if x < 5:
                    qml.Hadamard(wires=0)
                else:
                    qml.T(wires=0)

                return qml.expval(qml.PauliZ(0))

        >>> circuit(3)
        Array(0., dtype=float64)
        >>> circuit(5)
        Array(1., dtype=float64)

        Note that imperative control flow will still work in Catalyst even when the AutoGraph
        feature is turned off, it just won't be captured in the compiled program and cannot
        involve traced values. The example above would then raise a tracing error, as there is
        no value for ``x`` yet that can be compared in the if statement. A loop like
        ``for i in range(5)`` would be unrolled during tracing, "copy-pasting" the body 5
        times into the program rather than appearing as is.

    .. details::
        :title: In-place JAX array updates with Autograph

        To update array values when using JAX, the JAX syntax for array modification
        (which uses methods like ``at``, ``set``, ``multiply``, etc) must be used:

        .. code-block:: python

            @qjit(autograph=True)
            def f(x):
                first_dim = x.shape[0]
                result = jnp.empty((first_dim,), dtype=x.dtype)

                for i in range(first_dim):
                    result = result.at[i].set(x[i])
                    result = result.at[i].multiply(10)
                    result = result.at[i].add(5)

                return result

        However, if updating a single index or slice of the array, Autograph supports
        conversion of Python's standard arithmetic array assignment operators to the
        equivalent in-place expressions listed in the JAX documentation for
        ``jax.numpy.ndarray.at``:

        .. code-block:: python

            @qjit(autograph=True)
            def f(x):
                first_dim = x.shape[0]
                result = jnp.empty((first_dim,), dtype=x.dtype)

                for i in range(first_dim):
                    result[i] = x[i]
                    result[i] *= 10
                    result[i] += 5

                return result

        Under the hood, Catalyst converts anything coming in the latter notation into the
        former one. The list of supported operators includes ``=``, ``+=``, ``-=``, ``*=``,
        ``/=``, and ``**=``.

    .. details::
        :title: Static arguments

        - ``static_argnums`` defines which positional arguments should be treated as static.
          If it takes an integer, it means the argument whose index is equal to the integer is
          static. If it takes an iterable of integers, arguments whose index is contained in
          the iterable are static. Changing static arguments will introduce re-compilation.

        - ``static_argnames`` defines which named function arguments should be treated as
          static.

        A valid static argument must be hashable and its ``__hash__`` method must be able to
        reflect any changes of its attributes.

        .. code-block:: python

            @dataclass
            class MyClass:
                val: int

                def __hash__(self):
                    return hash(str(self))

            @qjit(static_argnums=1)
            def f(
                x: int,
                y: MyClass,
            ):
                return x + y.val

            f(1, MyClass(5))
            f(1, MyClass(6))  # re-compilation
            f(2, MyClass(5))  # no re-compilation

        In the example above, ``y`` is static. Note that the second function call triggers
        re-compilation since the input object is different from the previous one. However,
        the third function call directly uses the previous compiled one and does not introduce
        re-compilation.

        .. code-block:: python

            @dataclass
            class MyClass:
                val: int

                def __hash__(self):
                    return hash(str(self))

            @qjit(static_argnums=(1, 2))
            def f(
                x: int,
                y: MyClass,
                z: MyClass,
            ):
                return x + y.val + z.val

            my_obj_1 = MyClass(5)
            my_obj_2 = MyClass(6)
            f(1, my_obj_1, my_obj_2)
            my_obj_1.val = 7
            f(1, my_obj_1, my_obj_2)  # re-compilation

        In the example above, ``y`` and ``z`` are static. The second function call should make
        function ``f`` be re-compiled because ``my_obj_1`` is changed. This requires that the
        mutation is properly reflected in the hash value.

        Note that even when ``static_argnums`` is used in conjunction with type hinting,
        ahead-of-time compilation will not be possible since the static argument values are
        not yet available. Instead, compilation will be just-in-time.

    .. details::
        :title: Dynamically-shaped arrays

        There are three ways to use ``abstracted_axes``; by passing a sequence of tuples, a
        dictionary, or a sequence of dictionaries.

        Passing a sequence of tuples:

        .. code-block:: python

            abstracted_axes=((), ('n',), ('m', 'n'))

        Each tuple in the sequence corresponds to one of the arguments in the annotated
        function. Empty tuples can be used and correspond to parameters with statically known
        shapes. Non-empty tuples correspond to parameters with dynamically known shapes.

        In this example above,

        - the first argument will have a statically known shape,

        - the second argument has its zeroth axis have dynamic shape ``n``, and

        - the third argument will have its zeroth axis with dynamic shape ``m`` and first axis
          with dynamic shape ``n``.

        Passing a dictionary:

        .. code-block:: python

            abstracted_axes={0: 'n'}

        This approach allows a concise expression of the relationships between axes for
        different function arguments. In this example, it specifies that for all function
        arguments, the zeroth axis will have dynamic shape ``n``.

        Passing a sequence of dictionaries:

        .. code-block:: python

            abstracted_axes=({}, {0: 'n'}, {1: 'm', 0: 'n'})

        The example here is a more verbose version of the tuple example. This convention
        allows axes to be omitted from the list of abstracted axes.

        Using ``abstracted_axes`` can help avoid the cost of recompilation. By using
        ``abstracted_axes``, a more general version of the compiled function will be
        generated. This more general version is parametrized over the abstracted axes and
        allows results to be computed over tensors independently of their axes lengths.

        For example:

        .. code-block:: python

            @qjit
            def sum(arr):
                return jnp.sum(arr)

            sum(jnp.array([1]))     # Compilation happens here.
            sum(jnp.array([1, 1]))  # And here!

        The ``sum`` function would recompile each time an array of different size is passed
        as an argument.

        .. code-block:: python

            @qjit(abstracted_axes={0: "n"})
            def sum_abstracted(arr):
                return jnp.sum(arr)

            sum_abstracted(jnp.array([1]))     # Compilation happens here.
            sum_abstracted(jnp.array([1, 1]))  # No need to recompile.

        The ``sum_abstracted`` function would only compile once and its definition would be
        reused for subsequent function calls.
    """
    # Capture all keyword options (everything in locals() except the function itself) so they
    # can be forwarded either to a partial application (decorator-with-arguments usage) or to
    # CompileOptions directly.
    kwargs = copy.copy(locals())
    kwargs.pop("fn")

    # Used as @qjit(...) with arguments: return a decorator awaiting the function.
    if fn is None:
        return functools.partial(qjit, **kwargs)

    return QJIT(fn, CompileOptions(**kwargs))
## IMPL ## # pylint: disable=too-many-instance-attributes
class QJIT(CatalystCallable):
    """Class representing a just-in-time compiled hybrid quantum-classical function.

    .. note::

        ``QJIT`` objects are created by the :func:`~.qjit` decorator. Please see
        the :func:`~.qjit` documentation for more details.

    Args:
        fn (Callable): the quantum or classical function to compile
        compile_options (CompileOptions): compilation options to use

    :ivar original_function: This attribute stores `fn`, the quantum or classical function
                             object to compile, as is, without any modifications
    :ivar jaxpr: This attribute stores the Jaxpr compiled from the function as a string.
    :ivar mlir: This attribute stores the MLIR compiled from the function as a string.
    :ivar qir: This attribute stores the QIR in LLVM IR form compiled from the function as
               a string.
    """

    @debug_logger_init
    def __init__(self, fn, compile_options):
        functools.update_wrapper(self, fn)
        self.original_function = fn
        self.compile_options = compile_options
        self.compiler = Compiler(compile_options)
        self.fn_cache = CompilationCache(
            compile_options.static_argnums, compile_options.abstracted_axes
        )
        # Active state of the compiler.
        # TODO: rework ownership of workspace, possibly CompiledFunction
        self.workspace = None
        self.c_sig = None
        self.out_treedef = None
        self.compiled_function = None
        self.jaxed_function = None
        # IRs are only available for the most recently traced function.
        self.jaxpr = None
        self.mlir = None  # string form (historic presence)
        self.mlir_module = None
        self.qir = None
        self.out_type = None
        self.overwrite_ir = None

        # User-provided type annotations, if any; enables ahead-of-time compilation below.
        self.user_sig = get_type_annotations(fn)
        self._validate_configuration()

        # If static_argnames are present, convert them to static_argnums
        if compile_options.static_argnames is not None:
            compile_options.static_argnums = merge_static_argname_into_argnum(
                fn, compile_options.static_argnames, compile_options.static_argnums
            )

        # Patch the conversion rules by adding the included modules before the block list
        include_convertlist = tuple(
            ag_config.Convert(rule) for rule in self.compile_options.autograph_include
        )
        self.patched_module_allowlist = include_convertlist + ag_primitives.module_allowlist

        # Pre-compile with the patched conversion rules
        with Patcher(
            (ag_primitives, "module_allowlist", self.patched_module_allowlist),
        ):
            self.user_function = self.pre_compilation()

        # Static arguments require values, so we cannot AOT compile.
        if self.user_sig is not None and not self.compile_options.static_argnums:
            self.aot_compile()

        super().__init__("user_function")
    @debug_logger
    def __call__(self, *args, **kwargs):
        """Invoke the compiled function, (re)compiling it first if necessary.

        Dispatches to one of three paths: the plain Python function during tracing
        (nested QJIT calls), the JAX integration wrapper when tracers are received,
        or the compiled binary otherwise.
        """
        # Transparently call Python function in case of nested QJIT calls.
        if EvaluationContext.is_tracing():
            isQNode = isinstance(self.user_function, qml.QNode)
            if isQNode and self.compile_options.static_argnums:
                kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}
            return self.user_function(*args, **kwargs)

        requires_promotion = self.jit_compile(args, **kwargs)

        # If we receive tracers as input, dispatch to the JAX integration.
        if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
            if self.jaxed_function is None:
                self.jaxed_function = JAX_QJIT(self)  # lazy gradient compilation
            return self.jaxed_function(*args, **kwargs)

        elif requires_promotion:
            # Cast the concrete arguments to the dtypes the function was compiled for.
            dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
            args = promote_arguments(self.c_sig, dynamic_args)

        return self.run(args, kwargs)
    @debug_logger
    def aot_compile(self):
        """Compile Python function on initialization using the type hint signature.

        Runs as many of the capture → IR generation → binary compilation stages as the
        configured ``target`` requests, storing the results on the instance.
        """
        self.workspace = self._get_workspace()

        # TODO: awkward, refactor or redesign the target feature
        if self.compile_options.target in ("jaxpr", "mlir", "binary"):
            # Capture with the patched conversion rules
            with Patcher(
                (ag_primitives, "module_allowlist", self.patched_module_allowlist),
            ):
                self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                    self.user_sig or ()
                )

        if self.compile_options.target in ("mlir", "binary"):
            self.mlir_module, self.mlir = self.generate_ir()

        if self.compile_options.target in ("binary",):
            self.compiled_function, self.qir = self.compile()
            self.fn_cache.insert(
                self.compiled_function, self.user_sig, self.out_treedef, self.workspace
            )
    @debug_logger
    def jit_compile(self, args, **kwargs):
        """Compile Python function on invocation using the provided arguments.

        Args:
            args (Iterable): arguments to use for program capture

        Returns:
            bool: whether the provided arguments will require promotion to be used with the
                  compiled function
        """
        cached_fn, requires_promotion = self.fn_cache.lookup(args)

        if cached_fn is None:
            # The user declared a signature but the arguments do not match it; warn before
            # recompiling for the actual argument types.
            if self.user_sig and not self.compile_options.static_argnums:
                msg = "Provided arguments did not match declared signature, recompiling..."
                warnings.warn(msg, UserWarning)

            # Cleanup before recompilation:
            # - recompilation should always happen in new workspace
            # - compiled functions for jax integration are not yet cached
            # - close existing shared library
            self.workspace = self._get_workspace()
            self.jaxed_function = None
            if self.compiled_function and self.compiled_function.shared_object:
                self.compiled_function.shared_object.close()

            # Capture with the patched conversion rules
            with Patcher(
                (ag_primitives, "module_allowlist", self.patched_module_allowlist),
            ):
                self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                    args, **kwargs
                )

            self.mlir_module, self.mlir = self.generate_ir()
            self.compiled_function, self.qir = self.compile()

            self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)

        elif self.compiled_function is not cached_fn.compiled_fn:
            # Restore active state from cache.
            self.workspace = cached_fn.workspace
            self.compiled_function = cached_fn.compiled_fn
            self.out_treedef = cached_fn.out_treedef
            self.c_sig = cached_fn.signature
            self.jaxed_function = None

            self.compiled_function.shared_object.open()

        return requires_promotion
# Processing Stages #
[docs] @instrument @debug_logger def pre_compilation(self): """Perform pre-processing tasks on the Python function, such as AST transformations.""" processed_fn = self.original_function if self.compile_options.autograph: processed_fn = run_autograph(self.original_function) return processed_fn
    @instrument(size_from=0)
    @debug_logger
    def capture(self, args, **kwargs):
        """Capture the JAX program representation (JAXPR) of the wrapped function.

        Args:
            args (Iterable): arguments to use for program capture

        Returns:
            ClosedJaxpr: captured JAXPR
            PyTreeDef: PyTree metadata of the function output
            Tuple[Any]: the dynamic argument signature
        """
        verify_static_argnums(args, self.compile_options.static_argnums)
        static_argnums = self.compile_options.static_argnums
        abstracted_axes = self.compile_options.abstracted_axes

        # Split the arguments into dynamic (traced) and static ones, then rebuild the full
        # signature with abstract values in the dynamic slots.
        dynamic_args = filter_static_args(args, static_argnums)
        dynamic_sig = get_abstract_signature(dynamic_args)
        full_sig = merge_static_args(dynamic_sig, args, static_argnums)

        def fn_with_transform_named_sequence(*args, **kwargs):
            """
            This function behaves exactly like the user function being jitted,
            taking in the same arguments and producing the same results, except
            it injects a transform_named_sequence jax primitive at the beginning
            of the jaxpr when being traced.

            Note that we do not overwrite self.original_function and self.user_function;
            this fn_with_transform_named_sequence is ONLY used here to produce tracing
            results with a transform_named_sequence primitive at the beginning of the
            jaxpr. It is never executed or used anywhere, except being traced here.
            """
            _inject_transform_named_sequence()
            return self.user_function(*args, **kwargs)

        # Experimental path: delegate capture to PennyLane's program-capture machinery.
        if self.compile_options.experimental_capture:
            return trace_from_pennylane(
                fn_with_transform_named_sequence, static_argnums, abstracted_axes, full_sig, kwargs
            )

        def closure(qnode, *args, **kwargs):
            # Replacement for qml.QNode.__call__ during tracing: routes QNode execution
            # through Catalyst's QFunc with the configured transform pipeline.
            params = {}
            params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
            params["_out_tree_expected"] = []
            return QFunc.__call__(
                qnode,
                pass_pipeline=self.compile_options.circuit_transform_pipeline,
                *args,
                **dict(params, **kwargs),
            )

        with Patcher(
            (qml.QNode, "__call__", closure),
        ):
            # TODO: improve PyTree handling
            jaxpr, out_type, treedef = trace_to_jaxpr(
                fn_with_transform_named_sequence,
                static_argnums,
                abstracted_axes,
                full_sig,
                kwargs,
            )

        PipelineNameUniquer.reset()
        return jaxpr, out_type, treedef, dynamic_sig
    @instrument(size_from=0, has_finegrained=True)
    @debug_logger
    def generate_ir(self):
        """Generate Catalyst's intermediate representation (IR) as an MLIR module.

        Returns:
            Tuple[ir.Module, str]: the in-memory MLIR module and its string representation
        """
        mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)

        # Inject Runtime Library-specific functions (e.g. setup/teardown).
        inject_functions(mlir_module, ctx, self.compile_options.seed)

        # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
        # A separate Compiler instance with a single-pass pipeline is used so the main
        # compile options remain untouched.
        options = copy.deepcopy(self.compile_options)
        options.pipelines = [("0_canonicalize", ["canonicalize"])]
        options.lower_to_llvm = False
        canonicalizer = Compiler(options)

        # TODO: the in-memory and textual form are different after this, consider unification
        _, mlir_string = canonicalizer.run(mlir_module, self.workspace)

        return mlir_module, mlir_string
[docs] @instrument(size_from=1, has_finegrained=True) @debug_logger def compile(self): """Compile an MLIR module to LLVMIR and shared library code. Returns: Tuple[CompiledFunction, str]: the compilation result and LLVMIR """ # WARNING: assumption is that the first function is the entry point to the compiled program. entry_point_func = self.mlir_module.body.operations[0] restype = entry_point_func.type.results for res in restype: baseType = mlir.ir.RankedTensorType(res).element_type # This will make a check before sending it to the compiler that the return type # is actually available in most systems. f16 needs a special symbol and linking # will fail if it is not available. mlir_type_to_numpy_type(baseType) # The function name out of MLIR has quotes around it, which we need to remove. # The MLIR function name is actually a derived type from string which has no # `replace` method, so we need to get a regular Python string out of it. func_name = str(self.mlir_module.body.operations[0].name).replace('"', "") if self.overwrite_ir: shared_object, llvm_ir = self.compiler.run_from_ir( self.overwrite_ir, str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""), self.workspace, ) else: shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace) compiled_fn = CompiledFunction( shared_object, func_name, restype, self.out_type, self.compile_options ) return compiled_fn, llvm_ir
[docs] @instrument(has_finegrained=True) @debug_logger def run(self, args, kwargs): """Invoke a previously compiled function with the supplied arguments. Args: args (Iterable): the positional arguments to the compiled function kwargs: the keyword arguments to the compiled function Returns: Any: results of the execution arranged into the original function's output PyTrees """ results = self.compiled_function(*args, **kwargs) # TODO: Move this to the compiled function object. return tree_unflatten(self.out_treedef, results)
# Helper Methods # def _validate_configuration(self): """Run validations on the supplied options and parameters.""" if not hasattr(self.original_function, "__name__"): self.__name__ = "unknown" # allow these cases anyways? if not self.compile_options.autograph and len(self.compile_options.autograph_include) > 0: raise CompileError( "In order for 'autograph_include' to work, 'autograph' must be set to True" ) def _get_workspace(self): """Get or create a workspace to use for compilation.""" workspace_name = self.__name__ preferred_workspace_dir = os.getcwd() if self.compile_options.keep_intermediate else None return WorkspaceManager.get_or_create_workspace(workspace_name, preferred_workspace_dir)
class JAX_QJIT:
    """Wrapper class around :class:`~.QJIT` that enables compatibility with JAX transformations.

    The primary mechanism through which this is effected is by wrapping the invocation of the
    QJIT object inside a JAX ``pure_callback``. Additionally, a custom JVP is defined in order
    to support JAX-based differentiation, which is itself a ``pure_callback`` around a second
    QJIT object which invokes :func:`~.grad` on the original function. Using this class thus
    incurs additional compilation time.

    Args:
        qjit_function (QJIT): the compiled quantum function object to wrap
    """

    @debug_logger_init
    def __init__(self, qjit_function):
        @jax.custom_jvp
        def jaxed_function(*args, **kwargs):
            return self.wrap_callback(qjit_function, *args, **kwargs)

        self.qjit_function = qjit_function
        # Cache of derivative QJIT objects keyed by the argnums they differentiate.
        self.derivative_functions = {}
        self.jaxed_function = jaxed_function
        jaxed_function.defjvp(self.compute_jvp, symbolic_zeros=True)

    @staticmethod
    @debug_logger
    def wrap_callback(qjit_function, *args, **kwargs):
        """Wrap a QJIT function inside a jax host callback."""
        data = jax.pure_callback(
            qjit_function, qjit_function.jaxpr.out_avals, *args, vectorized=False, **kwargs
        )

        # Unflatten the return value w.r.t. the original PyTree definition if available
        assert qjit_function.out_treedef is not None, "PyTree shape must not be none."
        return tree_unflatten(qjit_function.out_treedef, data)

    @debug_logger
    def get_derivative_qjit(self, argnums):
        """Compile a function computing the derivative of the wrapped QJIT for the given
        argnums, reusing a previously compiled derivative when available."""
        argnum_key = "".join(str(idx) for idx in argnums)
        if argnum_key in self.derivative_functions:
            return self.derivative_functions[argnum_key]

        # Here we define the signature for the new QJIT object explicitly, rather than relying
        # on functools.wrap, in order to guarantee compilation is triggered on instantiation.
        # The signature of the original QJIT object is guaranteed to be defined by now, located
        # in QJIT.c_sig, however we don't update the original function with these annotations.
        annotations = {}
        updated_params = []
        signature = inspect.signature(self.qjit_function)
        for idx, (arg_name, param) in enumerate(signature.parameters.items()):
            annotations[arg_name] = self.qjit_function.c_sig[idx]
            updated_params.append(param.replace(annotation=annotations[arg_name]))

        def deriv_wrapper(*args, **kwargs):
            return catalyst.jacobian(self.qjit_function, argnums=argnums)(*args, **kwargs)

        deriv_wrapper.__name__ = "deriv_" + self.qjit_function.__name__
        deriv_wrapper.__annotations__ = annotations
        deriv_wrapper.__signature__ = signature.replace(parameters=updated_params)

        self.derivative_functions[argnum_key] = QJIT(
            deriv_wrapper, self.qjit_function.compile_options
        )
        return self.derivative_functions[argnum_key]

    @debug_logger
    def compute_jvp(self, primals, tangents):
        """Compute the set of results and JVPs for a QJIT function."""
        # Assume we have primals of shape `[a,b]` and results of shape `[c,d]`. Derivatives [2]
        # would get the shape `[c,d,a,b]` and tangents [1] would have the same shape as primals.
        # Now, In this function we apply tensordot using the pattern `[c,d,a,b]*[a,b] -> [c,d]`.

        # Optimization: Do not compute Jacobians for arguments which do not participate in
        #               differentiation (their tangents are symbolic zeros).
        argnums = []
        for idx, tangent in enumerate(tangents):
            if not isinstance(tangent, jax.custom_derivatives.SymbolicZero):
                argnums.append(idx)

        results = self.wrap_callback(self.qjit_function, *primals)
        results_data, _results_shape = tree_flatten(results)
        derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
        derivatives_data, _derivatives_shape = tree_flatten(derivatives)

        jvps = [jnp.zeros_like(results_data[res_idx]) for res_idx in range(len(results_data))]
        for diff_arg_idx, arg_idx in enumerate(argnums):
            tangent = tangents[arg_idx]  # [1]
            taxis = list(range(tangent.ndim))
            for res_idx in range(len(results_data)):
                # Flat Jacobian layout: one entry per (result, differentiated argument) pair.
                deriv_idx = diff_arg_idx + res_idx * len(argnums)
                deriv = derivatives_data[deriv_idx]  # [2]
                daxis = list(range(deriv.ndim - tangent.ndim, deriv.ndim))
                jvp = jnp.tensordot(deriv, tangent, axes=(daxis, taxis))
                jvps[res_idx] = jvps[res_idx] + jvp

        # jvps must match the type of primals
        # due to pytrees, primals are a tuple
        primal_type = type(primals)
        jvps = primal_type(jvps)
        if len(jvps) == 1:
            jvps = jvps[0]

        return results, jvps

    @debug_logger
    def __call__(self, *args, **kwargs):
        return self.jaxed_function(*args, **kwargs)