Source code for pennylane.transforms.core.transform_dispatcher

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module defines the data structure that encapsulates a quantum transform.
"""
from __future__ import annotations

import os
import warnings
from collections.abc import Callable, Sequence
from copy import copy
from functools import lru_cache, partial, singledispatch, update_wrapper, wraps

from pennylane import capture, math
from pennylane.capture import autograph
from pennylane.exceptions import TransformError
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operator
from pennylane.pytrees import flatten
from pennylane.queuing import AnnotatedQueue, QueuingManager, apply
from pennylane.tape import QuantumScript
from pennylane.typing import ResultBatch


@lru_cache
def _create_transform_primitive():
    try:
        # pylint: disable=import-outside-toplevel
        from pennylane.capture.custom_primitives import QmlPrimitive
    except ImportError:
        return None

    transform_prim = QmlPrimitive("transform")
    transform_prim.multiple_results = True
    transform_prim.prim_type = "transform"

    # pylint: disable=too-many-arguments, disable=unused-argument
    @transform_prim.def_impl
    def _impl(*all_args, inner_jaxpr, args_slice, consts_slice, **_):
        args = all_args[slice(*args_slice)]
        consts = all_args[slice(*consts_slice)]
        return capture.eval_jaxpr(inner_jaxpr, consts, *args)

    @transform_prim.def_abstract_eval
    def _abstract_eval(*_, inner_jaxpr, **__):
        return [out.aval for out in inner_jaxpr.outvars]

    return transform_prim


def _create_plxpr_fallback_transform(tape_transform):
    # pylint: disable=import-outside-toplevel
    try:
        import jax

        from pennylane.tape import plxpr_to_tape
    except ImportError:
        return None

    def plxpr_fallback_transform(jaxpr, consts, targs, tkwargs, *args):
        # Restore tkwargs from hashable tuple to dict
        tkwargs = dict(tkwargs)

        def wrapper(*inner_args):
            tape = plxpr_to_tape(jaxpr, consts, *inner_args)
            with capture.pause():
                tapes, _ = tape_transform(tape, *targs, **tkwargs)

            if len(tapes) > 1:
                raise TransformError(
                    f"Cannot apply {tape_transform.__name__} transform with program "
                    "capture enabled. Only transforms that return a single QuantumTape "
                    "and null processing function are usable with program capture."
                )

            for op in tapes[0].operations:
                data, struct = jax.tree_util.tree_flatten(op)
                jax.tree_util.tree_unflatten(struct, data)

            out = []
            for mp in tapes[0].measurements:
                data, struct = jax.tree_util.tree_flatten(mp)
                out.append(jax.tree_util.tree_unflatten(struct, data))

            return tuple(out)

        abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(args)
        return jax.make_jaxpr(wrapper, abstracted_axes=abstracted_axes)(*abstract_shapes, *args)

    return plxpr_fallback_transform


def specific_apply_transform(transform, obj, *targs, **tkwargs):
    """Default implementation behind ``Transform._apply_transform``.

    It simply forwards the object and the transform arguments to
    ``transform.generic_apply_transform``.
    """
    generic = transform.generic_apply_transform
    return generic(obj, *targs, **tkwargs)


@singledispatch
def generic_apply_transform(obj, transform, *targs, **tkwargs):
    """Apply a generic transform to a specific type of object.

    This is the singledispatch function used by ``Transform.generic_apply_transform``.
    The object comes before the transform in the signature so that ``singledispatch``
    can dispatch on the type of the object.

    This default implementation handles every object without a registered
    implementation (e.g., not a QNode, tape, etc.). It returns a ``BoundTransform``
    that stores the object, together with any additional args and kwargs, as transform
    parameters. This enables patterns like:

        decompose(gate_set=gate_set) + merge_rotations(1e-6)

    where transforms are called with just configuration parameters and combined into
    a CompilePipeline.
    """
    bound_args = (obj,) + targs
    return BoundTransform(transform, args=bound_args, kwargs=tkwargs)


# pragma: no cover
def _dummy_register(obj):  # just used for sphinx
    if isinstance(obj, type):  # pragma: no cover
        return lambda arg: arg  # pragma: no cover
    return obj  # pragma: no cover


[docs] class Transform: # pylint: disable=too-many-instance-attributes r"""Generalizes a function that transforms tapes to work with additional circuit-like objects such as a :class:`~.QNode`. ``transform`` should be applied to a function that transforms tapes. Once validated, the result will be an object that is able to transform PennyLane's range of circuit-like objects: :class:`~.QuantumTape`, quantum function and :class:`~.QNode`. A circuit-like object can be transformed either via decoration or by passing it functionally through the created transform. Args: tape_transform (Callable | None): The input quantum transform must be a function that satisfies the following requirements: * Accepts a :class:`~.QuantumScript` as its first input and returns a sequence of :class:`~.QuantumScript` and a processing function. * The transform must have the following structure (type hinting is optional): ``my_tape_transform(tape: qml.tape.QuantumScript, ...) -> tuple[qml.tape.QuantumScriptBatch, qml.typing.PostprocessingFn]`` pass_name (str | None): the name of the associated MLIR pass to be applied when Catalyst is used. See Usage Details for more information. Keyword Args: expand_transform=None (Optional[Callable]): An optional transform that is applied directly before the input transform. It must be a function that satisfies the same requirements as ``tape_transform``. classical_cotransform=None (Optional[Callable]): A classical co-transform is a function to post-process the classical jacobian and the quantum jacobian and has the signature: ``my_cotransform(qjac, cjac, tape) -> tensor_like`` is_informative=False (bool): Whether or not a transform is informative. If true, the transform is queued at the end of the compile pipeline and the tapes or qnode aren't executed. final_transform=False (bool): Whether or not the transform is terminal. If true, the transform is queued at the end of the compile pipeline. ``is_informative`` supersedes ``final_transform``. 
use_argnum_in_expand=False (bool): Whether to use ``argnum`` of the tape to determine trainable parameters during the expansion transform process. plxpr_transform=None (Optional[Callable]): Function for transforming plxpr. **Experimental** **Example** First define an input tape transform with the necessary structure defined above. In this example, we copy the tape and sum the results of the execution of the two tapes. .. code-block:: python from pennylane.tape import QuantumScript, QuantumScriptBatch from pennylane.typing import PostprocessingFn def my_quantum_transform(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]: tape1 = tape tape2 = tape.copy() def post_processing_fn(results): return qml.math.sum(results) return [tape1, tape2], post_processing_fn We want to be able to apply this transform on both a ``qfunc`` and a :class:`pennylane.QNode` and will use ``transform`` to achieve this. ``transform`` validates the signature of your input quantum transform and makes it capable of transforming ``qfunc`` and :class:`pennylane.QNode` in addition to quantum tapes. Let's define a circuit as a :class:`pennylane.QNode`: .. code-block:: python dev = qml.device("default.qubit") @qml.qnode(device=dev) def qnode_circuit(a): qml.Hadamard(wires=0) qml.CNOT(wires=[0, 1]) qml.X(0) qml.RZ(a, wires=1) return qml.expval(qml.Z(0)) We first apply ``transform`` to ``my_quantum_transform``: >>> dispatched_transform = qml.transform(my_quantum_transform) Now you can use the dispatched transform directly on a :class:`pennylane.QNode`. For :class:`pennylane.QNode`, the dispatched transform populates the ``CompilePipeline`` of your QNode. The transform and its processing function are applied in the execution. 
>>> transformed_qnode = dispatched_transform(qnode_circuit) >>> transformed_qnode <QNode: device='<default.qubit device at ...>', interface='auto', diff_method='best', shots='Shots(total=None)'> >>> transformed_qnode.transform_program CompilePipeline(my_quantum_transform) If we apply ``dispatched_transform`` a second time to the :class:`pennylane.QNode`, we would add it to the compile pipeline again and therefore the transform would be applied twice before execution. >>> transformed_qnode = dispatched_transform(transformed_qnode) >>> transformed_qnode.transform_program CompilePipeline(my_quantum_transform, my_quantum_transform) When a transformed QNode is executed, the QNode's compile pipeline is applied to the generated tape and creates a sequence of tapes to be executed. The execution results are then post-processed in the reverse order of the compile pipeline to obtain the final results. .. details:: :title: Dispatch a transform onto a batch of tapes We can compose multiple transforms when working in the tape paradigm and apply them to more than one tape. The following example demonstrates how to apply a transform to a batch of tapes. **Example** In this example, we apply sequentially a transform to a tape and another one to a batch of tapes. We then execute the transformed tapes on a device and post-process the results. .. code-block:: python import pennylane as qml H = qml.PauliY(2) @ qml.PauliZ(1) + 0.5 * qml.PauliZ(2) + qml.PauliZ(1) measurement = [qml.expval(H)] operations = [qml.Hadamard(0), qml.RX(0.2, 0), qml.RX(0.6, 0), qml.CNOT((0, 1))] tape = qml.tape.QuantumTape(operations, measurement) batch1, function1 = qml.transforms.split_non_commuting(tape) batch2, function2 = qml.transforms.merge_rotations(batch1) dev = qml.device("default.qubit", wires=3) result = dev.execute(batch2) The first ``split_non_commuting`` transform splits the original tape, returning a batch of tapes ``batch1`` and a processing function ``function1``. 
The second ``merge_rotations`` transform is applied to the batch of tapes returned by the first transform. It returns a new batch of tapes ``batch2``, each of which has been transformed by the second transform, and a processing function ``function2``. >>> batch2 (<QuantumTape: wires=[0, 1, 2], params=1>, <QuantumTape: wires=[0, 1, 2], params=1>) >>> type(function2) <class 'function'> We can combine the processing functions to post-process the results of the execution. >>> function1(function2(result)) np.float64(0.499...) .. details:: :title: Signature of a transform A dispatched transform is able to handle several PennyLane circuit-like objects: - :class:`pennylane.QNode` - a quantum function (callable) - :class:`pennylane.tape.QuantumScript` - a batch of :class:`pennylane.tape.QuantumScript` - :class:`pennylane.devices.Device`. For each object, the transform will be applied in a different way, but it always preserves the underlying tape-based quantum transform behaviour. The return of a dispatched transform depends upon which of the above objects is passed as an input: - For a :class:`~.QNode` input, the underlying transform is added to the QNode's :class:`~.CompilePipeline` and the return is the transformed :class:`~.QNode`. For each execution of the :class:`pennylane.QNode`, it first applies the compile pipeline on the original captured circuit. Then the transformed circuits are executed by a device and finally the post-processing function is applied on the results. When experimental program capture is enabled, transforming a :class:`~.QNode` returns a new function to which the transform has been added as a higher-order primitive. - For a quantum function (callable) input, the transform builds the tape when the quantum function is executed and then applies itself to the tape. The resulting tape is then converted back to a quantum function (callable). It therefore returns a transformed quantum function (Callable). 
The limitation is that the underlying transform can only return a sequence containing a single tape, because quantum functions only support a single circuit. When experimental program capture is enabled, transforming a function (callable) returns a new function to which the transform has been added as a higher-order primitive. - For a :class:`~.QuantumScript, the underlying quantum transform is directly applied on the :class:`~.QuantumScript`. It returns a sequence of :class:`~.QuantumScript` and a processing function to be applied after execution. - For a batch of :class:`pennylane.tape.QuantumScript`, the quantum transform is mapped across all the tapes. It returns a sequence of :class:`~.QuantumScript` and a processing function to be applied after execution. Each tape in the sequence is transformed by the transform. - For a :class:`~.devices.Device`, the transform is added to the device's compile pipeline and a transformed :class:`pennylane.devices.Device` is returned. The transform is added to the end of the device program and will be last in the overall compile pipeline. .. details:: :title: Transforms with Catalyst If a compilation pass is written in MLIR, using it in a ``qjit``'d workflow requires that it have a transform with a matching ``pass_name``. This ensures that the transform is properly applied as part of the lower-level compilation. For example, we can create a transform that will apply the ``cancel-inverses`` pass, like the in-built ``qml.transforms.cancel_inverses`` transform. .. code-block:: python my_transform = qml.transform(pass_name="cancel-inverses") @qml.qjit @my_transform @qml.qnode(qml.device('lightning.qubit', wires=4)) def circuit(): qml.X(0) qml.X(0) return qml.expval(qml.Z(0)) We can see that the instruction to apply ``"cancel-inverses"`` is present in the initial MLIR. 
>>> circuit() Array(1., dtype=float64) >>> print(circuit.mlir[200:600]) tensor<f64> } module @module_circuit { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { %0 = transform.apply_registered_pass "cancel-inverses" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> transform.yield } } func.func public @circui Transforms can have both tape-based and ``pass_name``-based definitions. For example, the transform below called ``my_transform`` has both definitions. In this case, the MLIR pass will take precedence when being ``qjit``'d if only MLIR passes can occur after. .. code-block:: python from functools import partial @partial(qml.transform, pass_name="my-pass-name") def my_transform(tape): return (tape, ), lambda res: res[0] Note that any transform with only a ``pass_name`` definition *must* occur after any purely tape-based transform, as tape transforms occur prior to lowering to MLIR. >>> @qml.qjit ... @qml.defer_measurements ... @qml.transform(pass_name="cancel-inverses") ... @qml.qnode(qml.device('lightning.qubit', wires=4)) ... def c(): ... qml.X(0) ... qml.X(0) ... return qml.expval(qml.Z(0)) ... Traceback (most recent call last): ... ValueError: <cancel-inverses((), {})> without a tape definition occurs before tape transform <defer_measurements((), {})>. .. details:: :title: Transforms with experimental program capture To define a transform that can be applied directly to plxpr without the need to create ``QuantumScript``\ s, users must provide the ``plxpr_transform`` argument. If this argument is not provided, executing transformed functions is not guaranteed to work. More details about this are provided below. The ``plxpr_transform`` argument should be a function that applies the respective transform to ``jax.extend.core.Jaxpr`` and returns a transformed ``jax.extend.core.ClosedJaxpr``. 
``plxpr_transform`` can assume that no transform primitives are present in the input plxpr, and its implementation does not need to account for these primitives. The exact expected signature of ``plxpr_transform`` is shown in the example below: .. code-block:: python def dummy_plxpr_transform( jaxpr: jax.extend.core.Jaxpr, consts: list, targs: list, tkwargs: dict, *args ) -> jax.extend.core.ClosedJaxpr: ... Once the ``plxpr_transform`` argument is provided, the transform can be easily used with program capture enabled! To do so, apply the transform as you normally would: .. code-block:: python qml.capture.enable() @qml.transforms.cancel_inverses def circuit(): qml.X(0) qml.S(1) qml.X(0) qml.adjoint(qml.S(1)) return qml.expval(qml.Z(1)) >>> jax.make_jaxpr(circuit)() { lambda ; . let a:AbstractMeasurement(n_wires=None) = transform[ args_slice=(0, 0, None) consts_slice=(0, 0, None) inner_jaxpr={ lambda ; . let _:AbstractOperator() = PauliX[n_wires=1] 0:i...[] _:AbstractOperator() = S[n_wires=1] 1:i...[] _:AbstractOperator() = PauliX[n_wires=1] 0:i...[] b:AbstractOperator() = S[n_wires=1] 1:i...[] _:AbstractOperator() = Adjoint b c:AbstractOperator() = PauliZ[n_wires=1] 1:i...[] d:AbstractMeasurement(n_wires=None) = expval_obs c in (d,) } targs_slice=(0, None, None) tkwargs=() transform=<transform: cancel_inverses> ] in (a,) } As shown, the transform gets applied as a higher-order primitive, with the jaxpr representation of the function being transformed stored in the ``inner_jaxpr`` parameter of the transform's primitive. **Fallback implementation of plxpr transforms:** If a transform that does not define a ``plxpr_transform`` is applied to a function, a fallback implementation of the transform is used. This fallback implementation converts the function into a :func:`~pennylane.tape.QuantumScript`, which is then transformed as a traditional tape. 
However, because of the constraints of program capture, many transforms will not be compatible with this fallback implementation: * Transforms that return multiple tapes are not compatible. * Transforms that require non-trivial post-processing of results are not compatible. * Dynamically shaped arrays are not compatible. * Functions that are being transformed that contain control flow dependent on dynamic parameters are not compatible. This includes: * :func:`pennylane.cond` with dynamic parameters as predicates. * :func:`pennylane.for_loop` with dynamic parameters for ``start``, ``stop``, or ``step``. * :func:`pennylane.while_loop` does not work. .. warning:: Currently, executing a function to which a transform has been applied will raise a ``NotImplementedError``. See below for details on how to use functions that are transformed. To perform the transform, the :func:`pennylane.capture.expand_plxpr_transforms` function should be used. This function accepts a function to which transforms have been applied as an input, and returns a new function that has been transformed: >>> transformed_circuit = qml.capture.expand_plxpr_transforms(circuit) >>> jax.make_jaxpr(transformed_circuit)() { lambda ; . let a:AbstractOperator() = PauliZ[n_wires=1] 1:i...[] b:AbstractMeasurement(n_wires=None) = expval_obs a in (b,) } """ def __new__( # pylint: disable=too-many-arguments cls, tape_transform: Callable | None = None, pass_name: None | str = None, *, expand_transform: Callable | None = None, classical_cotransform: Callable | None = None, is_informative: bool = False, final_transform: bool = False, use_argnum_in_expand: bool = False, plxpr_transform=None, ) -> Transform: if os.environ.get("SPHINX_BUILD") == "1": # If called during a Sphinx documentation build, # simply return the original function rather than # instantiating the object. This allows the signature to # be correctly displayed in the documentation. 
warnings.warn( "Transforms have been disabled, as a Sphinx " "build has been detected via SPHINX_BUILD='1'. If this is not the " "case, please set the environment variable SPHINX_BUILD='0'.", UserWarning, ) tape_transform.custom_qnode_transform = lambda x: x tape_transform.register = _dummy_register return tape_transform return super().__new__(cls) # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, tape_transform: Callable | None = None, pass_name: None | str = None, *, expand_transform: Callable | None = None, classical_cotransform: Callable | None = None, is_informative: bool = False, final_transform: bool = False, use_argnum_in_expand: bool = False, plxpr_transform=None, ): if tape_transform is not None and not callable(tape_transform): raise TransformError( f"The function to register, {tape_transform}, does " "not appear to be a valid Python function or callable." ) if expand_transform is not None and not callable(expand_transform): raise TransformError("The expand function must be a valid Python function.") if classical_cotransform is not None and not callable(classical_cotransform): raise TransformError("The classical co-transform must be a valid Python function.") if tape_transform is None and pass_name is None: raise ValueError("Transforms must define either a tape transform or a pass_name") self._tape_transform = tape_transform self._expand_transform = expand_transform self._classical_cotransform = classical_cotransform self._is_informative = is_informative # is_informative supersedes is_final_transform self._is_final_transform = is_informative or final_transform self._custom_qnode_transform = None self._pass_name = pass_name self._use_argnum_in_expand = use_argnum_in_expand if tape_transform: update_wrapper(self, tape_transform) self._apply_transform = singledispatch(partial(specific_apply_transform, self)) self._plxpr_transform = plxpr_transform or _create_plxpr_fallback_transform(tape_transform) @property def 
pass_name(self) -> None | str: """The name of the equivalent MLIR pass.""" return self._pass_name @property def register(self): """Returns a decorator for registering a specific application behavior for a given transform and a new class. .. code-block:: python @qml.transform def printer(tape): print("I have a tape: ", tape) return (tape, ), lambda x: x[0] @printer.register def _(obj: qml.operation.Operator, *targs, **tkwargs): print("I have an operator:", obj) return obj >>> printer(qml.X(0)) I have an operator: X(0) X(0) """ return self._apply_transform.register
[docs] def generic_apply_transform(self, obj, *targs, **tkwargs): """generic_apply_transform(obj, *targs, **tkwargs) Generic application of a transform that forms the default for all transforms. Args: obj: The object we want to transform *targs: The arguments for the transform **tkwargs: The keyword arguments for the transform. """ return generic_apply_transform(obj, self, *targs, **tkwargs)
[docs] @staticmethod def generic_register(arg): """Returns a decorator for registering a default application behavior for a transform for a new class. Given a special new class, we can register how transforms should apply to them via: .. code-block:: python class Subroutine: def __repr__(self): return f"<Subroutine: {self.ops}>" def __init__(self, ops): self.ops = ops from pennylane.transforms.core import Transform @Transform.generic_register def apply_to_subroutine(obj: Subroutine, transform, *targs, **tkwargs): tape = qml.tape.QuantumScript(obj.ops) batch, _ = transform(tape, *targs, **tkwargs) return Subroutine(batch[0].operations) >>> qml.transforms.cancel_inverses(Subroutine([qml.Y(0), qml.X(0), qml.X(0)])) <Subroutine: [Y(0)]> The type can also be explicitly provided like: .. code-block:: python @Transform.generic_register(Subroutine) def apply_to_subroutine(obj: Subroutine, transform, *targs, **tkwargs): tape = qml.tape.QuantumScript(obj.ops) batch, _ = transform(tape, *targs, **tkwargs) return Subroutine(batch[0].operations) to more explicitly force registration for a given type. """ return generic_apply_transform.register(arg)
def __call__(self, *args, **kwargs): if not args and not kwargs: raise TypeError( f"{self!r} requires at least one argument. " "Provide a tape, qfunc, QNode, or device to transform, " "or provide keyword arguments to create a BoundTransform for composition." ) if not args and kwargs: return BoundTransform(self, kwargs=kwargs) return self._apply_transform(*args, **kwargs) def __repr__(self): name = self._tape_transform.__name__ if self._tape_transform else self.pass_name return f"<transform: {name}>" def __add__(self, other): """Add two transforms to create a CompilePipeline.""" if not isinstance(other, (Transform, BoundTransform)): return NotImplemented # Technically this is checked in the CompilePipeline dunders but we still # do it here to raise a more informative error message. if self.is_final_transform and other.is_final_transform: raise TransformError( f"Both {self} and {other} are final transforms and cannot be combined." ) if self.expand_transform: # pylint: disable=import-outside-toplevel from .compile_pipeline import CompilePipeline return CompilePipeline(self, other) # Convert this transform to a BoundTransform (no args/kwargs) and delegate return BoundTransform(self) + other def __mul__(self, n): """Multiply by an integer to create a compile pipeline with this transform repeated.""" if self.expand_transform: # pylint: disable=import-outside-toplevel from .compile_pipeline import CompilePipeline return CompilePipeline(self) * n # Convert to container (no args/kwargs) and delegate return BoundTransform(self) * n __rmul__ = __mul__ @property def tape_transform(self): """The tape transform.""" return self._tape_transform @property def expand_transform(self): """The expand transform.""" return self._expand_transform @property def classical_cotransform(self): """The classical co-transform.""" return self._classical_cotransform @property def plxpr_transform(self): """Function for transforming plxpr.""" return self._plxpr_transform @property def 
is_informative(self): """``True`` if the transform is informative.""" return self._is_informative @property def is_final_transform(self): """``True`` if the transformed tapes must be executed.""" return self._is_final_transform
[docs] def custom_qnode_transform(self, fn): """Register a custom QNode execution wrapper function for the batch transform. **Example** .. code-block:: python3 @transform def my_transform(tape, *targs, **tkwargs): ... return tapes, processing_fn @my_transform.custom_qnode_transform def my_custom_qnode_wrapper(self, qnode, targs, tkwargs): new_tkwargs = dict(tkwargs) new_tkwargs['shots'] = 100 return self.generic_apply_transform(qnode, *targs, **new_tkwargs) The custom QNode execution wrapper must have arguments ``self`` (the batch transform object), ``qnode`` (the input QNode to transform and execute), ``targs`` and ``tkwargs`` (the transform arguments and keyword arguments respectively). It should return a QNode that accepts the *same* arguments as the input QNode with the transform applied. The default :meth:`~.generic_apply_transform` method may be called if only pre- or post-processing dependent on QNode arguments is required. """ # unfortunately, we don't have access to qml.QNode here, or in the places where # transforms are defining custom qnode transforms, so we still need to have this # "hold onto until later" approach # potentially can remove this patch by moving source code self._custom_qnode_transform = fn
[docs] def default_qnode_transform(self, qnode, targs, tkwargs): """ The default method that takes in a QNode and returns another QNode with the transform applied. """ # same comment as custom_qnode_transform :( qnode = copy(qnode) qnode.transform_program.append(BoundTransform(self, args=targs, kwargs=tkwargs)) return qnode
[docs] class BoundTransform: # pylint: disable=too-many-instance-attributes """A transform with bound inputs. Args: transform: Any transform. args (Sequence[Any]): The positional arguments to use with the transform. kwargs (Dict | None): The keyword arguments for use with the transform. Keyword Args: use_argnum (bool): An advanced option used in conjunction with calculating classical cotransforms of jax workflows. .. seealso:: :func:`~.pennylane.transform` >>> bound_t = BoundTransform(qml.transforms.merge_rotations, (), {"atol": 1e-4}) >>> bound_t <merge_rotations((), {'atol': 0.0001})> The class can also be created by directly calling the transform with its inputs: >>> qml.transforms.merge_rotations(atol=1e-4) <merge_rotations((), {'atol': 0.0001})> These objects can now directly applied to anything individual transforms can apply to: .. code-block:: python @bound_t @qml.qnode(qml.device('null.qubit', wires=2)) def c(x): qml.RX(x, 0) qml.RX(-x + 1e-6, 0) qml.RY(x, 1) qml.RY(-x + 1e-2, 1) return qml.probs(wires=(0,1)) If we draw this circuit, we can see that the ``merge_rotations`` transforms was applied with a tolerance of ``1e-4``. The ``RX`` gates sufficiently close to zero disappear, while the ``RY`` gates that are further from zero remain. >>> print(qml.draw(c)(1.0)) 0: ───────────┤ ╭Probs 1: ──RY(0.01)─┤ ╰Probs Repeated versions of the bound transform can be created with multiplication: >>> bound_t * 3 CompilePipeline(merge_rotations, merge_rotations, merge_rotations) And it can be used in conjunction with both individual transforms, bound transforms, and compile pipelines. 
>>> bound_t + qml.transforms.cancel_inverses CompilePipeline(merge_rotations, cancel_inverses) >>> bound_t + qml.transforms.cancel_inverses + bound_t CompilePipeline(merge_rotations, cancel_inverses, merge_rotations) """ def __hash__(self): hashable_dict = tuple((key, value) for key, value in self.kwargs.items()) return hash((self.tape_transform, self.pass_name, self.args, hashable_dict)) def __init__( self, transform: Transform, args: tuple | list = (), kwargs: None | dict = None, *, use_argnum: bool = False, **transform_config, ): if not isinstance(transform, Transform): transform = Transform(transform, **transform_config) elif transform_config: raise ValueError( f"transform_config kwargs {transform_config} cannot be passed if a transform is provided." ) self._transform = transform self._args = tuple(args) self._kwargs = kwargs or {} self._use_argnum = use_argnum def __repr__(self): name = self.tape_transform.__name__ if self.tape_transform else self.pass_name return f"<{name}({self._args}, {self._kwargs})>" def __call__(self, obj): return self._transform(obj, *self.args, **self.kwargs) def __iter__(self): return iter( ( self._transform.tape_transform, self._args, self._kwargs, self._transform.classical_cotransform, self._transform.plxpr_transform, self._transform.is_informative, self._transform.is_final_transform, ) ) def __eq__(self, other: object) -> bool: if not isinstance(other, BoundTransform): return False return ( self.args == other.args and self.tape_transform == other.tape_transform and self.pass_name == other.pass_name and self.kwargs == other.kwargs and self.classical_cotransform == other.classical_cotransform and self.is_informative == other.is_informative and self.is_final_transform == other.is_final_transform ) @property def tape_transform(self) -> Callable | None: """The raw tape transform definition for the transform.""" return self._transform.tape_transform @property def transform(self) -> Callable | None: """The raw tape transform definition of 
the transform.""" # TODO: deprecate this in the next version return self.tape_transform @property def expand_transform(self) -> BoundTransform | None: """The expand_transform associated with this transform.""" if not self._transform.expand_transform: return None return BoundTransform( self._transform.expand_transform, args=self.args, kwargs=self.kwargs, use_argnum=self._transform._use_argnum_in_expand, # pylint:disable=protected-access ) @property def pass_name(self) -> None | str: """The name of the corresponding Catalyst pass, if it exists.""" return self._transform.pass_name @property def args(self) -> tuple: """The stored quantum transform's ``args``.""" return self._args @property def kwargs(self) -> dict: """The stored quantum transform's ``kwargs``.""" return self._kwargs @property def classical_cotransform(self) -> None | Callable: """The stored quantum transform's classical co-transform.""" return self._transform.classical_cotransform @property def plxpr_transform(self) -> None | Callable: """The stored quantum transform's PLxPR transform. **UNMAINTAINED AND EXPERIMENTAL** """ return self._transform.plxpr_transform @property def is_informative(self) -> bool: """Whether or not a transform is informative. If true the transform is queued at the end of the transform program and the tapes or qnode aren't executed. This property is rare, but used by such transforms as ``qml.transforms.commutation_dag``. """ return self._transform.is_informative @property def is_final_transform(self) -> bool: """Whether or not the transform must be the last one to be executed in a ``CompilePipeline``. This property is ``True`` for most gradient transforms. 
""" return self._transform.is_final_transform def __add__(self, other): """Add two transforms to create a CompilePipeline.""" if not isinstance(other, (Transform, BoundTransform)): return NotImplemented # Import here to avoid circular import # pylint: disable=import-outside-toplevel from .compile_pipeline import CompilePipeline if self.is_final_transform and other.is_final_transform: raise TransformError( f"Both {self} and {other} are final transforms and cannot be combined." ) return CompilePipeline(self, other) def __mul__(self, n): """Multiply by an integer to create a pipeline with this transform repeated.""" # Import here to avoid circular import from .compile_pipeline import CompilePipeline # pylint: disable=import-outside-toplevel if not isinstance(n, int): return NotImplemented if n < 0: raise ValueError("Cannot multiply transform container by negative integer") if self.is_final_transform and n > 1: raise TransformError( f"{self} is a final transform and cannot be applied more than once." ) return CompilePipeline(self) * n __rmul__ = __mul__
@Transform.generic_register
def _apply_to_tape(obj: QuantumScript, transform, *targs, **tkwargs):
    """Apply a transform to a single tape.

    Args:
        obj (QuantumScript): the tape to transform.
        transform: the transform being dispatched. Its ``tape_transform``,
            ``expand_transform`` and ``is_informative`` attributes are used.
        *targs: positional arguments forwarded to the expand and tape transforms.
        **tkwargs: keyword arguments forwarded to the expand and tape transforms.

    Returns:
        A ``(transformed_tapes, processing_fn)`` pair, or, if the transform is
        informative, the result of calling ``processing_fn`` on the transformed tapes.

    Raises:
        NotImplementedError: if the transform has no tape transform.
    """
    if transform.tape_transform is None:
        raise NotImplementedError(f"transform {transform} has no defined tape transform.")

    if transform.expand_transform:
        expanded_tapes, expand_processing = transform.expand_transform(obj, *targs, **tkwargs)
        transformed_tapes = []
        processing_and_slices = []
        start = 0
        for tape in expanded_tapes:
            intermediate_tapes, post_processing_fn = transform.tape_transform(
                tape, *targs, **tkwargs
            )
            transformed_tapes.extend(intermediate_tapes)
            # Remember which part of the flat result list belongs to this expanded tape.
            end = start + len(intermediate_tapes)
            processing_and_slices.append((post_processing_fn, slice(start, end)))
            start = end

        def processing_fn(results):
            # Post-process each expanded tape's results, then undo the expansion.
            processed_results = [fn(results[sl]) for fn, sl in processing_and_slices]
            return expand_processing(processed_results)

    else:
        transformed_tapes, processing_fn = transform.tape_transform(obj, *targs, **tkwargs)

    if transform.is_informative:
        return processing_fn(transformed_tapes)
    return transformed_tapes, processing_fn


def _capture_apply(obj, transform, *targs, **tkwargs):
    """Wrap ``obj`` so that calling it binds the transform primitive around its jaxpr.

    Args:
        obj (Callable): the quantum function to transform.
        transform: the transform, passed to the primitive as the ``transform`` parameter.
        *targs: positional transform arguments, bound as trailing operands of the primitive.
        **tkwargs: keyword transform arguments, passed as the ``tkwargs`` parameter.

    Returns:
        Callable: a function with the call signature of ``obj``.
    """

    @autograph.wraps(obj)
    def qfunc_transformed(*args, **kwargs):
        import jax  # pylint: disable=import-outside-toplevel

        flat_qfunc = capture.flatfn.FlatFn(obj)
        # Keyword arguments are closed over rather than traced: only the leaves of ``args``
        # are bound to the primitive below, so the jaxpr inputs must match ``args`` alone.
        jaxpr = jax.make_jaxpr(partial(flat_qfunc, **kwargs))(*args)
        flat_args = jax.tree_util.tree_leaves(args)

        n_args = len(flat_args)
        n_consts = len(jaxpr.consts)
        # Slice bounds are passed as (start, stop) tuples: the primitive implementation
        # rebuilds them with ``slice(*args_slice)``, and tuples are hashable on every
        # supported Python version whereas ``slice`` objects are not.
        args_slice = (0, n_args)
        consts_slice = (n_args, n_args + n_consts)
        targs_slice = (n_args + n_consts, None)

        results = _create_transform_primitive().bind(
            *flat_args,
            *jaxpr.consts,
            *targs,
            inner_jaxpr=jaxpr.jaxpr,
            args_slice=args_slice,
            consts_slice=consts_slice,
            targs_slice=targs_slice,
            tkwargs=tkwargs,
            transform=transform,
        )

        assert flat_qfunc.out_tree is not None
        return jax.tree_util.tree_unflatten(flat_qfunc.out_tree, results)

    return qfunc_transformed


@Transform.generic_register
def apply_to_callable(obj: Callable, transform, *targs, **tkwargs):
    """Apply a transform to a Callable object.

    Args:
        obj (Callable): the quantum function to transform.
        transform: the transform being dispatched.
        *targs: positional arguments forwarded to the transform.
        **tkwargs: keyword arguments forwarded to the transform.

    Returns:
        Callable: a wrapped function. When called, it records ``obj`` onto a tape,
        transforms the tape and re-queues the operations and measurements of the single
        resulting tape. For an informative transform it returns the processed output.

    Raises:
        TransformError: if ``obj`` is a ``QJIT`` object, or (when the wrapped function is
            called) if the transform returns a number of tapes other than one.
    """
    if obj.__class__.__name__ == "QJIT":
        raise TransformError(
            "Functions that are wrapped / decorated with qjit cannot subsequently be"
            f" transformed with a PennyLane transform (attempted {transform})."
            f" For the desired affect, ensure that qjit is applied after {transform}."
        )

    @wraps(obj)
    def qfunc_transformed(*args, **kwargs):
        if capture.enabled():
            return _capture_apply(obj, transform, *targs, **tkwargs)(*args, **kwargs)

        # removes the argument to the qfuncs from the active queuing context.
        leaves, _ = flatten((args, kwargs), lambda leaf: isinstance(leaf, Operator))
        for leaf in leaves:
            if isinstance(leaf, Operator):
                QueuingManager.remove(leaf)

        with AnnotatedQueue() as q:
            qfunc_output = obj(*args, **kwargs)

        tape = QuantumScript.from_queue(q)

        with QueuingManager.stop_recording():
            if transform.is_informative:
                # Call the raw tape transform so that the tapes and the processing function
                # are both available below.
                transformed_tapes, processing_fn = transform.tape_transform(
                    tape, *targs, **tkwargs
                )
            else:
                transformed_tapes, processing_fn = transform(tape, *targs, **tkwargs)

        if len(transformed_tapes) != 1:
            raise TransformError(
                "Impossible to dispatch your transform on quantum function, because more than "
                "one tape is returned"
            )

        transformed_tape = transformed_tapes[0]

        if transform.is_informative:
            return processing_fn(transformed_tapes)

        # Re-queue the transformed circuit into the active queuing context.
        for op in transformed_tape.operations:
            apply(op)

        mps = [apply(mp) for mp in transformed_tape.measurements]

        if not mps:
            return qfunc_output

        # Return the new measurements in the same kind of container the qfunc returned.
        if isinstance(qfunc_output, MeasurementProcess):
            return tuple(mps) if len(mps) > 1 else mps[0]

        if isinstance(qfunc_output, (tuple, list)):
            return type(qfunc_output)(mps)

        interface = math.get_interface(qfunc_output)
        return math.asarray(mps, like=interface)

    return qfunc_transformed


@Transform.generic_register
def _apply_to_sequence(obj: Sequence, transform, *targs, **tkwargs):
    """Apply a transform to every tape in a sequence of tapes.

    Args:
        obj (Sequence): the tapes to transform; every element must be a ``QuantumScript``.
        transform: the transform being dispatched; it is called on each tape.
        *targs: positional arguments forwarded to the transform.
        **tkwargs: keyword arguments forwarded to the transform.

    Returns:
        tuple: ``(execution_tapes, processing_fn)``, where ``execution_tapes`` is a tuple of
        all tapes produced by the transform and ``processing_fn`` post-processes the results.

    Raises:
        TransformError: if any element of ``obj`` is not a ``QuantumScript``.
    """
    # Validate the whole sequence before any tape is transformed, and report the type of
    # the element that is actually invalid.
    for t in obj:
        if not isinstance(t, QuantumScript):
            raise TransformError(
                f"Transforms can only apply to sequences of QuantumScript, not {type(t)}"
            )

    execution_tapes = []
    batch_fns = []
    tape_counts = []

    for t in obj:
        # Preprocess the tapes by applying transforms
        # to each tape, and storing corresponding tapes
        # for execution, processing functions, and list of tape lengths.
        new_tapes, fn = transform(t, *targs, **tkwargs)
        execution_tapes.extend(new_tapes)
        batch_fns.append(fn)
        tape_counts.append(len(new_tapes))

    def processing_fn(res: ResultBatch) -> ResultBatch:
        """Applies a batch of post-processing functions to results.

        Args:
            res (ResultBatch): the results of executing a batch of circuits.

        Returns:
            ResultBatch: results that have undergone classical post processing.

        Closure variables:
            tape_counts: the number of tapes outputted from each application of the transform.
            batch_fns: the post processing functions to apply to each sub-batch.

        """
        count = 0
        final_results = []

        for f, s in zip(batch_fns, tape_counts):
            # apply any batch transform post-processing
            new_res = f(res[count : count + s])
            final_results.append(new_res)
            count += s

        return tuple(final_results)

    return tuple(execution_tapes), processing_fn


# Alternative names for the two classes (presumably older names kept for
# backwards compatibility -- confirm before removing).
TransformContainer = BoundTransform
TransformDispatcher = Transform