Source code for pennylane.transforms.core.transform_dispatcher

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the transform dispatcher and the transform container.
"""
import functools
import os
import copy
import warnings
import types

import pennylane as qml


class TransformError(Exception):
    """Raised when there is an error with the transform logic."""


[docs]class TransformDispatcher: r"""Converts a transform that has the signature ``(tape -> Sequence(tape), fn)`` to a transform dispatcher that can act on :class:`pennylane.tape.QuantumTape`, quantum function, :class:`pennylane.QNode`, :class:`pennylane.devices.Device`. .. warning:: This class is developer-facing and should not be used directly. Instead, use :func:`qml.transform <pennylane.transform>` if you would like to make a custom transform. .. seealso:: :func:`~.pennylane.transform` """ def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument if os.environ.get("SPHINX_BUILD") == "1": # If called during a Sphinx documentation build, # simply return the original function rather than # instantiating the object. This allows the signature to # be correctly displayed in the documentation. warnings.warn( "Transforms have been disabled, as a Sphinx " "build has been detected via SPHINX_BUILD='1'. If this is not the " "case, please set the environment variable SPHINX_BUILD='0'.", UserWarning, ) args[0].custom_qnode_transform = lambda x: x return args[0] return super().__new__(cls) # pylint: disable=too-many-arguments def __init__( self, transform, expand_transform=None, classical_cotransform=None, is_informative=False, final_transform=False, ): # pylint:disable=redefined-outer-name self._transform = transform self._expand_transform = expand_transform self._classical_cotransform = classical_cotransform self._is_informative = is_informative # is_informative supersedes final_transform self._final_transform = is_informative or final_transform self._qnode_transform = self.default_qnode_transform functools.update_wrapper(self, transform) def __call__(self, *targs, **tkwargs): # pylint: disable=too-many-return-statements obj = None if targs: # assume the first argument passed to the transform # is the object we wish to transform obj, *targs = targs if isinstance(obj, qml.tape.QuantumScript): if self._expand_transform: expanded_tapes, expand_processing = self._expand_transform(obj, *targs, **tkwargs) transformed_tapes = [] processing_and_sclices = [] start = 0 for tape in expanded_tapes: intermediate_tapes, post_processing_fn = self._transform( tape, *targs, **tkwargs ) transformed_tapes.extend(intermediate_tapes) end = start + len(intermediate_tapes) processing_and_sclices.append(tuple([post_processing_fn, slice(start, end)])) start = end def processing_fn(results): processed_results = [fn(results[slice]) for fn, slice in processing_and_sclices] return expand_processing(processed_results) else: transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs) if self.is_informative: return processing_fn(transformed_tapes) return transformed_tapes, processing_fn if isinstance(obj, qml.QNode): return self._qnode_transform(obj, targs, tkwargs) # TODO: Remove with the previous device generation if isinstance(obj, qml.Device): return self._old_device_transform(obj, targs, tkwargs) if isinstance(obj, qml.devices.Device): return self._device_transform(obj, targs, tkwargs) if callable(obj): return self._qfunc_transform(obj, targs, tkwargs) # Input is not a QNode nor a quantum tape nor a device. # Assume Python decorator syntax: # # result = some_transform(*transform_args)(qnode)(*qnode_args) warnings.warn( "Decorating a QNode with @transform_fn(**transform_kwargs) has been " "deprecated and will be removed in a future version. Please decorate " "with @functools.partial(transform_fn, **transform_kwargs) instead, " "or call the transform directly using qnode = transform_fn(qnode, " "**transform_kwargs). Visit the deprecations page for more details: " "https://docs.pennylane.ai/en/stable/development/deprecations.html", UserWarning, ) if obj is not None: targs = (obj, *targs) def wrapper(obj): return self(obj, *targs, **tkwargs) wrapper.__doc__ = ( f"Partial of transform {self._transform} with bound arguments and keyword arguments." ) return wrapper def __repr__(self): return f"<transform: {self._transform.__name__}>" @property def transform(self): """The quantum transform.""" return self._transform @property def expand_transform(self): """The expand transform.""" return self._expand_transform @property def classical_cotransform(self): """The classical co-transform.""" return self._classical_cotransform @property def is_informative(self): """``True`` if the transform is informative.""" return self._is_informative @property def final_transform(self): """``True`` if the transformed tapes must be executed.""" return self._final_transform
[docs] def custom_qnode_transform(self, fn): """Register a custom QNode execution wrapper function for the batch transform. **Example** .. code-block:: python @transform def my_transform(tape, *targs, **tkwargs): ... return tapes, processing_fn @my_transform.custom_qnode_transform def my_custom_qnode_wrapper(self, qnode, targs, tkwargs): tkwargs = {**tkwargs, shots=100} return self.default_qnode_transform(qnode, targs, tkwargs) The custom QNode execution wrapper must have arguments ``self`` (the batch transform object), ``qnode`` (the input QNode to transform and execute), ``targs`` and ``tkwargs`` (the transform arguments and keyword arguments respectively). It should return a QNode that accepts the *same* arguments as the input QNode with the transform applied. The default :meth:`~.default_qnode_transform` method may be called if only pre- or post-processing dependent on QNode arguments is required. """ self._qnode_transform = types.MethodType(fn, self)
[docs] def default_qnode_transform(self, qnode, targs, tkwargs): """ The default method that takes in a QNode and returns another QNode with the transform applied. """ qnode = copy.copy(qnode) if self.expand_transform: qnode.add_transform(TransformContainer(self._expand_transform, targs, tkwargs)) qnode.add_transform( TransformContainer( self._transform, targs, tkwargs, self._classical_cotransform, self._is_informative, self._final_transform, ) ) return qnode
def _qfunc_transform(self, qfunc, targs, tkwargs): """Apply the transform on a quantum function.""" def qfunc_transformed(*args, **kwargs): with qml.queuing.AnnotatedQueue() as q: qfunc_output = qfunc(*args, **kwargs) tape = qml.tape.QuantumScript.from_queue(q) transformed_tapes, processing_fn = self._transform(tape, *targs, **tkwargs) if len(transformed_tapes) != 1: raise TransformError( "Impossible to dispatch your transform on quantum function, because more than " "one tape is returned" ) transformed_tape = transformed_tapes[0] if self.is_informative: return processing_fn(transformed_tapes) for op in transformed_tape.circuit: qml.apply(op) mps = transformed_tape.measurements if not mps: return qfunc_output if isinstance(qfunc_output, qml.measurements.MeasurementProcess): return tuple(mps) if len(mps) > 1 else mps[0] if isinstance(qfunc_output, (tuple, list)): return type(qfunc_output)(mps) interface = qml.math.get_interface(qfunc_output) return qml.math.asarray(mps, like=interface) return qfunc_transformed def _old_device_transform(self, original_device, targs, tkwargs): """Apply the transform on a device""" if self._expand_transform: raise TransformError("Device transform does not support expand transforms.") if self._is_informative: raise TransformError("Device transform does not support informative transforms.") if self._final_transform: raise TransformError("Device transform does not support final transforms.") new_dev = copy.deepcopy(original_device) transform = self._transform @new_dev.custom_expand def new_expand_fn(self, tape, *args, **kwargs): # pylint: disable=unused-variable tapes, _ = transform(tape, *targs, **tkwargs) tape = tapes[0] return self.default_expand_fn(tape, *args, **kwargs) return new_dev def _device_transform(self, original_device, targs, tkwargs): """Apply the transform on a device""" if self._expand_transform: raise TransformError("Device transform does not support expand transforms.") if self._is_informative: raise TransformError("Device transform does not support informative transforms.") if self._final_transform: raise TransformError("Device transform does not support final transforms.") class TransformedDevice(type(original_device)): """A transformed device with updated preprocess method.""" def __init__(self, original_device, transform): for key, value in original_device.__dict__.items(): self.__setattr__(key, value) self.transform = transform self._original_device = original_device def __repr__(self): return f"Transformed Device({original_device.__repr__()} with additional preprocess transform {self.transform})" def preprocess( self, execution_config: qml.devices.ExecutionConfig = qml.devices.DefaultExecutionConfig, ): """This function updates the original device transform program to be applied.""" program, config = self.original_device.preprocess(execution_config) program.push_back(TransformContainer(self.transform, targs, tkwargs)) return program, config @property def original_device(self): """Return the original device.""" return self._original_device return TransformedDevice(original_device, self._transform)
[docs]class TransformContainer: """Class to store a quantum transform with its ``args``, ``kwargs`` and classical co-transforms. Use :func:`~.pennylane.transform`. .. warning:: This class is developer-facing and should not be used directly. Instead, use :func:`qml.transform <pennylane.transform>` if you would like to make a custom transform. .. seealso:: :func:`~.pennylane.transform` """ def __init__( self, transform, args=None, kwargs=None, classical_cotransform=None, is_informative=False, final_transform=False, ): # pylint:disable=redefined-outer-name,too-many-arguments self._transform = transform self._args = args or [] self._kwargs = kwargs or {} self._classical_cotransform = classical_cotransform self._is_informative = is_informative self._final_transform = is_informative or final_transform def __iter__(self): return iter( ( self._transform, self._args, self._kwargs, self._classical_cotransform, self._is_informative, self.final_transform, ) ) @property def transform(self): """The stored quantum transform.""" return self._transform @property def args(self): """The stored quantum transform's ``args``.""" return self._args @property def kwargs(self): """The stored quantum transform's ``kwargs``.""" return self._kwargs @property def classical_cotransform(self): """The stored quantum transform's classical co-transform.""" return self._classical_cotransform @property def is_informative(self): """``True`` if the transform is informative.""" return self._is_informative @property def final_transform(self): """``True`` if the transform needs to be executed""" return self._final_transform