"""Contains tools and decorators for registering batch transforms."""
# pylint: disable=too-few-public-methods
import copy
import functools
import inspect
import os
import types
import warnings

from typing import Callable, Tuple

import pennylane as qml
from pennylane.typing import ResultBatch

# Signature of a classical post-processing function: it maps the results of
# executing a batch of tapes to the post-processed results.
PostprocessingFn = Callable[[ResultBatch], ResultBatch]
# A batch of tapes of arbitrary length. Note the trailing ``...``: a bare
# ``Tuple[QuantumScript]`` would denote a tuple containing exactly one tape.
QuantumTapeBatch = Tuple[qml.tape.QuantumScript, ...]

class batch_transform:
    r"""Class for registering a tape transform that takes a tape, and outputs
    a batch of tapes to be independently executed on a quantum device.

    .. warning::

        Use of ``batch_transform`` to create a custom transform is deprecated. Instead
        switch to using the new :func:`transform` function. Follow the instructions
        in the :func:`transform` documentation for further details.

    Examples of such transforms include quantum gradient shift rules (such
    as finite-differences and the parameter-shift rule) and metrics such as
    the quantum Fisher information matrix.

    Args:
        transform_fn (function): The function to register as the batch tape transform.
            It can have an arbitrary number of arguments, but the first argument
            **must** be the input tape.
        expand_fn (function): An expansion function (if required) to be applied to the
            input tape before the transformation takes place.
            It **must** take the same input arguments as ``transform_fn``.
        differentiable (bool): Specifies whether the transform is differentiable or
            not. A transform may be non-differentiable for several reasons:

            - It does not use an autodiff framework for its tensor manipulations;
            - It returns a non-differentiable or non-numeric quantity, such as
              a boolean, string, or integer.

            In such a case, setting ``differentiable=False`` instructs the decorator
            to mark the output as 'constant', reducing potential overhead.

    **Example**

    A valid batch tape transform is a function that satisfies the following:

    - The first argument must be a tape.

    - Depending on the structure of this input tape, various quantum operations, functions,
      and templates may be called.

    - Any internal classical processing should use the ``qml.math`` module to ensure
      the transform is differentiable.

    - The transform should return a tuple containing:

      * Multiple transformed tapes to be executed on a device.
      * A classical processing function for post-processing the executed tape results.
        This processing function should have the signature ``f(list[tensor_like]) → Any``.
        If ``None``, no classical processing is applied to the results.

    For example:

    .. code-block:: python

        @qml.batch_transform
        def my_transform(tape, a, b):
            '''Generates two tapes, one with all RX replaced with RY,
            and the other with all RX replaced with RZ.'''

            ops1 = []
            ops2 = []

            # loop through all operations on the input tape
            for op in tape.operations:
                if op.name == "RX":
                    wires = op.wires
                    param = op.parameters[0]

                    ops1.append(qml.RY(a * qml.math.abs(param), wires=wires))
                    ops2.append(qml.RZ(b * qml.math.abs(param), wires=wires))
                else:
                    ops1.append(op)
                    ops2.append(op)

            tape1 = qml.tape.QuantumTape(ops1, tape.measurements)
            tape2 = qml.tape.QuantumTape(ops2, tape.measurements)

            def processing_fn(results):
                return qml.math.sum(qml.math.stack(results))

            return [tape1, tape2], processing_fn

    We can apply this transform to a quantum tape:

    >>> ops = [qml.Hadamard(wires=0), qml.RX(-0.5, wires=0)]
    >>> tape = qml.tape.QuantumTape(ops, [qml.expval(qml.X(0))])
    >>> tapes, fn = my_transform(tape, 0.65, 2.5)
    >>> print(qml.drawer.tape_text(tapes[0], decimals=2))
    0: ──H──RY(0.33)─┤ <X>
    >>> print(qml.drawer.tape_text(tapes[1], decimals=2))
    0: ──H──RZ(1.25)─┤ <X>

    We can execute these tapes manually:

    >>> dev = qml.device("default.qubit", wires=1)
    >>> res = qml.execute(tapes, dev, interface="autograd", gradient_fn=qml.gradients.param_shift)
    >>> print(res)
    [0.9476507264148154, 0.31532236239526856]

    Applying the processing function, we retrieve the end result of the transform:

    >>> print(fn(res))
    1.2629730888100839

    Alternatively, we may also transform a QNode directly, using either
    decorator syntax:

    >>> @my_transform(0.65, 2.5)
    ... @qml.qnode(dev)
    ... def circuit(x):
    ...     qml.Hadamard(wires=0)
    ...     qml.RX(x, wires=0)
    ...     return qml.expval(qml.X(0))
    >>> print(circuit(-0.5))
    1.2629730888100839

    or by transforming an existing QNode:

    >>> @qml.qnode(dev)
    ... def circuit(x):
    ...     qml.Hadamard(wires=0)
    ...     qml.RX(x, wires=0)
    ...     return qml.expval(qml.X(0))
    >>> circuit = my_transform(circuit, 0.65, 2.5)
    >>> print(circuit(-0.5))
    1.2629730888100839

    Batch tape transforms are fully differentiable:

    >>> x = np.array(-0.5, requires_grad=True)
    >>> gradient = qml.grad(circuit)(x)
    >>> print(gradient)
    2.5800122591960153

    .. details::
        :title: Usage Details

        **Expansion functions**

        Tape expansion, decomposition, or manipulation may always be
        performed within the custom batch transform. However, by specifying
        a separate expansion function, PennyLane can access this separate
        expansion function where needed via

        >>> my_transform.expand_fn

        The provided ``expand_fn`` must have the same input arguments as
        ``transform_fn`` and return a ``tape``. Following the example above:

        .. code-block:: python

            def expand_fn(tape, a, b):
                stopping_crit = lambda obj: obj.name != "PhaseShift"
                return tape.expand(depth=10, stop_at=stopping_crit)

            my_transform = batch_transform(my_transform, expand_fn)

        Note that:

        - the transform arguments ``a`` and ``b`` must be passed to
          the expansion function, and
        - the expansion function must return a single tape.
    """

    def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
        if os.environ.get("SPHINX_BUILD") == "1":
            # If called during a Sphinx documentation build,
            # simply return the original function rather than
            # instantiating the object. This allows the signature to
            # be correctly displayed in the documentation.

            warnings.warn(
                "Batch transformations have been disabled, as a Sphinx "
                "build has been detected via SPHINX_BUILD='1'. If this is not the "
                "case, please set the environment variable SPHINX_BUILD='0'.",
                UserWarning,
            )

            # Identity decorator: presumably keeps ``@fn.custom_qnode_wrapper``
            # usages importable during docs builds, since the decorated function
            # is returned unchanged.
            args[0].custom_qnode_wrapper = lambda x: x

            return args[0]

        return super().__new__(cls)

    def __init__(self, transform_fn, expand_fn=None, differentiable=True):
        """Register ``transform_fn`` (and optional ``expand_fn``) as a batch transform.

        Raises:
            ValueError: if ``transform_fn`` is not callable.
        """
        if not callable(transform_fn):
            raise ValueError(
                f"The batch transform function to register, {transform_fn}, "
                "does not appear to be a valid Python function or callable."
            )

        # NOTE(review): this message ends with "details:" but no link follows;
        # the intended URL appears to be missing — confirm and restore.
        warnings.warn(
            "Use of `batch_transform` to create a custom transform is deprecated. Instead "
            "switch to using the new qml.transform function. Follow the instructions here for "
            "further details:",
            qml.PennyLaneDeprecationWarning,
        )
        self.transform_fn = transform_fn
        self.expand_fn = expand_fn
        self.differentiable = differentiable
        self.qnode_wrapper = self.default_qnode_wrapper
        # Copy __name__, __doc__, etc. from the transform function onto this object.
        functools.update_wrapper(self, transform_fn)

    def custom_qnode_wrapper(self, fn):
        """Register a custom QNode execution wrapper function
        for the batch transform.

        **Example**

        .. code-block:: python

            @qml.batch_transform
            def my_transform(tape, *targs, **tkwargs):
                ...
                return tapes, processing_fn

            @my_transform.custom_qnode_wrapper
            def my_custom_qnode_wrapper(self, qnode, targs, tkwargs):
                def wrapper_fn(*args, **kwargs):
                    # construct QNode
                    qnode.construct(args, kwargs)
                    # apply transform to QNode's tapes
                    tapes, processing_fn = self.construct(qnode.qtape, *targs, **tkwargs)
                    # execute tapes and return processed result
                    ...
                    return processing_fn(results)
                return wrapper_fn

        The custom QNode execution wrapper must have arguments
        ``self`` (the batch transform object), ``qnode`` (the input QNode
        to transform and execute), ``targs`` and ``tkwargs`` (the transform
        arguments and keyword arguments respectively).

        It should return a callable object that accepts the *same* arguments
        as the QNode, and returns the transformed numerical result.

        The default :meth:`~.default_qnode_wrapper` method may be called
        if only pre- or post-processing dependent on QNode arguments is required:

        .. code-block:: python

            @my_transform.custom_qnode_wrapper
            def my_custom_qnode_wrapper(self, qnode, targs, tkwargs):
                transformed_qnode = self.default_qnode_wrapper(qnode, targs, tkwargs)

                def wrapper_fn(*args, **kwargs):
                    args, kwargs = pre_process(args, kwargs)
                    res = transformed_qnode(*args, **kwargs)
                    ...
                    return ...
                return wrapper_fn
        """
        # Bind ``fn`` to this instance so it receives the batch transform as ``self``.
        self.qnode_wrapper = types.MethodType(fn, self)

    def default_qnode_wrapper(self, qnode, targs, tkwargs):
        """A wrapper method that takes a QNode and transform arguments,
        and returns a function that 'wraps' the QNode execution.

        The returned function accepts the same arguments as the QNode, plus the
        reserved keyword arguments ``shots`` (forwarded to :func:`~.execute` as
        ``override_shots``) and ``argnums`` (forwarded to :meth:`construct`). It
        returns the output of applying the tape transform to the QNode's
        constructed tape.

        Note that a ``max_diff`` entry is removed from ``tkwargs`` in place and
        used to override the QNode's execution ``max_diff``.

        Raises:
            ValueError: if the quantum function has an argument named ``shots``,
                which is reserved for overriding the number of device shots.
        """
        # ``max_diff`` configures execution, not the tape transform, so it is
        # removed from the transform keyword arguments here.
        transform_max_diff = tkwargs.pop("max_diff", None)

        if "shots" in inspect.signature(qnode.func).parameters:
            raise ValueError(
                "Detected 'shots' as an argument of the quantum function to transform. "
                "The 'shots' argument name is reserved for overriding the number of shots "
                "taken by the device."
            )

        def _wrapper(*args, **kwargs):
            shots = kwargs.pop("shots", False)
            argnums = kwargs.pop("argnums", None)

            # Use a per-call copy: ``tkwargs`` is shared across calls (and with
            # ``wrapper.construct``), so writing ``argnums`` into it would leak
            # one call's ``argnums`` into every subsequent call.
            call_tkwargs = dict(tkwargs)
            if argnums:
                call_tkwargs["argnums"] = argnums  # pragma: no cover

            old_interface = qnode.interface

            if old_interface == "auto":
                qnode.interface = qml.math.get_interface(*args, *list(kwargs.values()))

            try:
                qnode.construct(args, kwargs)
                tapes, processing_fn = self.construct(qnode.qtape, *targs, **call_tkwargs)

                interface = qnode.interface
                execute_kwargs = getattr(qnode, "execute_kwargs", {}).copy()
                max_diff = execute_kwargs.pop("max_diff", 2)
                max_diff = transform_max_diff or max_diff

                gradient_fn = getattr(qnode, "gradient_fn", qnode.diff_method)
                gradient_kwargs = getattr(qnode, "gradient_kwargs", {})
            finally:
                # Restore the user's "auto" setting even if construction fails,
                # so the QNode is not left with a stale, resolved interface.
                if old_interface == "auto":
                    qnode.interface = "auto"

            if interface is None or not self.differentiable:
                gradient_fn = None

            res = qml.execute(
                tapes,
                device=qnode.device,
                gradient_fn=gradient_fn,
                interface=interface,
                max_diff=max_diff,
                override_shots=shots,
                gradient_kwargs=gradient_kwargs,
                **execute_kwargs,
            )

            return processing_fn(res)

        return _wrapper

    def __call__(self, *targs, **tkwargs):
        """Apply the batch transform, dispatching on the first positional argument.

        - :class:`~.Device`: returns a copy of the device whose ``batch_transform``
          applies this transform.
        - :class:`~.QuantumScript`: returns ``(tapes, processing_fn)``.
        - :class:`~.QNode`: returns a wrapped QNode that applies the transform
          when executed.
        - anything else (or no arguments): all arguments are treated as transform
          arguments and a decorator accepting a device, tape or QNode is returned.

        QNode wrappers and decorators additionally carry ``tape_fn``,
        ``expand_fn`` and ``differentiable`` attributes.
        """
        qnode = None

        if targs:
            qnode, *targs = targs

        if isinstance(qnode, qml.Device):
            # Input is a quantum device.
            # dev = some_transform(dev, *transform_args)
            return self._device_wrapper(*targs, **tkwargs)(qnode)

        if isinstance(qnode, qml.tape.QuantumScript):
            # Input is a quantum tape.
            # tapes, fn = some_transform(tape, *transform_args)
            return self._tape_wrapper(*targs, **tkwargs)(qnode)

        if isinstance(qnode, qml.QNode):
            # Input is a QNode:
            # result = some_transform(qnode, *transform_args)(*qnode_args)
            wrapper = self._wrap_qnode(qnode, targs, tkwargs)

        else:
            # Input is not a QNode nor a quantum tape nor a device.
            # Assume Python decorator syntax:
            #
            # result = some_transform(*transform_args)(qnode)(*qnode_args)
            #
            # or
            #
            # @some_transform(*transform_args)
            # @qml.qnode(dev)
            # def circuit(...):
            #     ...
            # result = circuit(*qnode_args)

            # Prepend the input to the transform args,
            # and create a wrapper function.
            if qnode is not None:
                targs = (qnode,) + tuple(targs)

            def wrapper(qnode):
                if isinstance(qnode, qml.Device):
                    return self._device_wrapper(*targs, **tkwargs)(qnode)

                if isinstance(qnode, qml.tape.QuantumScript):
                    return self._tape_wrapper(*targs, **tkwargs)(qnode)

                return self._wrap_qnode(qnode, targs, tkwargs)

        wrapper.tape_fn = functools.partial(self.transform_fn, *targs, **tkwargs)
        wrapper.expand_fn = self.expand_fn
        wrapper.differentiable = self.differentiable
        return wrapper

    def _wrap_qnode(self, qnode, targs, tkwargs):
        """Wrap ``qnode`` with the registered QNode wrapper and attach a ``construct`` helper."""
        wrapper = self.qnode_wrapper(qnode, targs, tkwargs)
        wrapper = functools.wraps(qnode)(wrapper)

        def _construct(args, kwargs):
            # Build the QNode's tape for these arguments, then apply the transform.
            qnode.construct(args, kwargs)
            return self.construct(qnode.qtape, *targs, **tkwargs)

        wrapper.construct = _construct
        return wrapper

    def construct(self, tape, *targs, **tkwargs):
        """Applies the batch tape transform to an input tape.

        Args:
            tape (.QuantumTape): the tape to be transformed
            *targs: positional arguments to pass to the tape transform
            **tkwargs: keyword arguments to pass to the tape transform. The
                reserved keys ``_expand`` (default ``True``; whether to apply
                ``expand_fn``) and ``argnums`` (if given, assigned to
                ``tape.trainable_params``) are consumed here and not forwarded.

        Returns:
            tuple[list[tapes], callable]: list of transformed tapes
            to execute and a post-processing function.
        """
        expand = tkwargs.pop("_expand", True)
        argnums = tkwargs.pop("argnums", None)

        if expand and self.expand_fn is not None:
            tape = self.expand_fn(tape, *targs, **tkwargs)

        if argnums is not None:
            tape.trainable_params = argnums  # pragma: no cover

        tapes, processing_fn = self.transform_fn(tape, *targs, **tkwargs)

        if processing_fn is None:
            # A transform may return ``None``; substitute the identity.

            def processing_fn(x):
                return x

        return tapes, processing_fn

    def _device_wrapper(self, *targs, **tkwargs):
        """Return a function that deep-copies a device and installs this transform on the copy."""

        def _wrapper(dev):
            new_dev = copy.deepcopy(dev)
            new_dev.batch_transform = lambda tape: self.construct(tape, *targs, **tkwargs)
            return new_dev

        return _wrapper

    def _tape_wrapper(self, *targs, **tkwargs):
        """Return a function mapping a tape to ``(tapes, processing_fn)``."""
        return lambda tape: self.construct(tape, *targs, **tkwargs)
def map_batch_transform(
    transform: Callable, tapes: QuantumTapeBatch
) -> Tuple[QuantumTapeBatch, PostprocessingFn]:
    """Map a transform over multiple tapes.

    Args:
        transform (Callable): the transform to be mapped. It is called once per
            input tape and must return ``(new_tapes, post_processing_fn)``.
        tapes (Sequence[QuantumTape]): The sequence of tapes the transform
            should be applied to. Each tape in the sequence is transformed by
            the transform.

    Returns:
        tuple[list[QuantumTape], Callable]: the tapes produced for every input
        tape, concatenated in input order, and a post-processing function. The
        function splits a flat batch of results back into one chunk per input
        tape and applies the matching post-processing function to each chunk.

    **Example**

    Consider the following tapes:

    .. code-block:: python

        H = qml.Z(0) @ qml.Z(1) - qml.X(0)

        ops1 = [
            qml.RX(0.5, wires=0),
            qml.RY(0.1, wires=1),
            qml.CNOT(wires=(0,1))
        ]
        measurements1 = [qml.expval(H)]
        tape1 = qml.tape.QuantumTape(ops1, measurements1)

        ops2 = [qml.Hadamard(0), qml.CRX(0.5, wires=(0,1)), qml.CNOT((0,1))]
        measurements2 = [qml.expval(H + 0.5 * qml.Y(0))]
        tape2 = qml.tape.QuantumTape(ops2, measurements2)

    We can use ``map_batch_transform`` to map a single transform across both
    of these tapes in such a way that allows us to submit a single job for
    execution:

    >>> tapes, fn = map_batch_transform(qml.transforms.hamiltonian_expand, [tape1, tape2])
    >>> dev = qml.device("default.qubit", wires=2)
    >>> fn(qml.execute(tapes, dev, qml.gradients.param_shift))
    [array(0.99500417), array(0.8150893)]
    """
    execution_tapes = []
    batch_fns = []
    tape_counts = []

    # Transform every input tape. All produced tapes go into one flat batch,
    # and we remember how many came from each input and which post-processing
    # function belongs to them.
    for input_tape in tapes:
        produced_tapes, post_fn = transform(input_tape)
        execution_tapes.extend(produced_tapes)
        batch_fns.append(post_fn)
        tape_counts.append(len(produced_tapes))

    def processing_fn(res: ResultBatch) -> ResultBatch:
        """Split a flat result batch per input tape and post-process each chunk.

        Args:
            res (ResultBatch): the results of executing the flat batch of tapes

        Returns:
            ResultBatch: one post-processed result per original input tape

        Closure variables:
            tape_counts: how many tapes each application of the transform produced
            batch_fns: the post-processing function for each of those sub-batches
        """
        final_results = []
        start = 0
        for apply_fn, n_tapes in zip(batch_fns, tape_counts):
            stop = start + n_tapes
            final_results.append(apply_fn(res[start:stop]))
            start = stop
        return final_results

    return execution_tapes, processing_fn