Source code for pennylane.transforms.core.transform_dispatcher
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the transform dispatcher and the transform container.
"""
import functools
import os
import types
import warnings
from collections.abc import Sequence
from copy import copy
import pennylane as qml
from pennylane.typing import ResultBatch
class TransformError(Exception):
"""Raised when there is an error with the transform logic."""
[docs]def register_primitive_for_expansion(primitive, plxpr_transform):
"""Register a transform such that it can be expanded when applied to a function with
program capture enabled."""
# pylint: disable=import-outside-toplevel
try:
import jax
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
except ImportError:
return
@ExpandTransformsInterpreter.register_primitive(primitive)
def _(
self, *invals, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs
): # pylint: disable=too-many-arguments,missing-docstring
if plxpr_transform is None:
raise NotImplementedError
inner_args = invals[args_slice]
inner_consts = invals[consts_slice]
targs = invals[targs_slice]
def wrapper(*args):
return copy(self).eval(inner_jaxpr, inner_consts, *args)
unravelled_jaxpr = jax.make_jaxpr(wrapper)(*inner_args)
final_jaxpr = plxpr_transform(
unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *inner_args
)
return copy(self).eval(final_jaxpr.jaxpr, final_jaxpr.consts, *inner_args)
[docs]class TransformDispatcher: # pylint: disable=too-many-instance-attributes
r"""Converts a transform that has the signature ``(tape -> Sequence(tape), fn)`` to a transform dispatcher
that can act on :class:`pennylane.tape.QuantumTape`, quantum function, :class:`pennylane.QNode`,
:class:`pennylane.devices.Device`.
.. warning::
This class is developer-facing and should not be used directly. Instead, use
:func:`qml.transform <pennylane.transform>` if you would like to make a custom
transform.
.. seealso:: :func:`~.pennylane.transform`
"""
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
if os.environ.get("SPHINX_BUILD") == "1":
# If called during a Sphinx documentation build,
# simply return the original function rather than
# instantiating the object. This allows the signature to
# be correctly displayed in the documentation.
warnings.warn(
"Transforms have been disabled, as a Sphinx "
"build has been detected via SPHINX_BUILD='1'. If this is not the "
"case, please set the environment variable SPHINX_BUILD='0'.",
UserWarning,
)
args[0].custom_qnode_transform = lambda x: x
return args[0]
return super().__new__(cls)
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
transform,
expand_transform=None,
classical_cotransform=None,
is_informative=False,
final_transform=False,
use_argnum_in_expand=False,
plxpr_transform=None,
): # pylint:disable=redefined-outer-name
self._transform = transform
self._expand_transform = expand_transform
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative
# is_informative supersedes final_transform
self._final_transform = is_informative or final_transform
self._qnode_transform = self.default_qnode_transform
self._use_argnum_in_expand = use_argnum_in_expand
functools.update_wrapper(self, transform)
self._plxpr_transform = plxpr_transform
self._primitive = _create_transform_primitive(self._transform.__name__)
register_primitive_for_expansion(self._primitive, self._plxpr_transform)
def __call__(
self, *targs, **tkwargs
): # pylint: disable=too-many-return-statements,too-many-branches
obj = None
if targs:
# assume the first argument passed to the transform
# is the object we wish to transform
obj, *targs = targs
if isinstance(obj, qml.tape.QuantumScript):
if self._expand_transform:
expanded_tapes, expand_processing = self._expand_transform(obj, *targs, **tkwargs)
transformed_tapes = []
processing_and_slices = []
start = 0
for tape in expanded_tapes:
intermediate_tapes, post_processing_fn = self._transform(
tape, *targs, **tkwargs
)
transformed_tapes.extend(intermediate_tapes)
end = start + len(intermediate_tapes)
processing_and_slices.append(tuple([post_processing_fn, slice(start, end)]))
start = end
def processing_fn(results):
processed_results = [fn(results[slice]) for fn, slice in processing_and_slices]
return expand_processing(processed_results)
else:
transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
if self.is_informative:
return processing_fn(transformed_tapes)
return transformed_tapes, processing_fn
if isinstance(obj, qml.QNode):
if qml.capture.enabled():
return self._capture_callable_transform(obj, targs, tkwargs)
return self._qnode_transform(obj, targs, tkwargs)
if isinstance(obj, qml.devices.Device):
return self._device_transform(obj, targs, tkwargs)
if obj.__class__.__name__ == "QJIT":
raise TransformError(
"Functions that are wrapped / decorated with qjit cannot subsequently be"
f" transformed with a PennyLane transform (attempted {self})."
f" For the desired affect, ensure that qjit is applied after {self}."
)
if callable(obj):
if qml.capture.enabled():
return self._capture_callable_transform(obj, targs, tkwargs)
return self._qfunc_transform(obj, targs, tkwargs)
if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
return self._batch_transform(obj, targs, tkwargs)
# Input is not a QNode nor a quantum tape nor a device.
# Assume Python decorator syntax:
#
# result = some_transform(*transform_args)(qnode)(*qnode_args)
raise TransformError(
"Decorating a QNode with @transform_fn(**transform_kwargs) has been "
"removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) "
"instead, or call the transform directly using qnode = transform_fn(qnode, "
"**transform_kwargs). Visit the deprecations page for more details: "
"https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles",
)
def __repr__(self):
return f"<transform: {self._transform.__name__}>"
@property
def transform(self):
"""The quantum transform."""
return self._transform
@property
def expand_transform(self):
"""The expand transform."""
return self._expand_transform
@property
def classical_cotransform(self):
"""The classical co-transform."""
return self._classical_cotransform
@property
def plxpr_transform(self):
"""Function for transforming plxpr."""
return self._plxpr_transform
@property
def is_informative(self):
"""``True`` if the transform is informative."""
return self._is_informative
@property
def final_transform(self):
"""``True`` if the transformed tapes must be executed."""
return self._final_transform
[docs] def custom_qnode_transform(self, fn):
"""Register a custom QNode execution wrapper function
for the batch transform.
**Example**
.. code-block:: python
@transform
def my_transform(tape, *targs, **tkwargs):
...
return tapes, processing_fn
@my_transform.custom_qnode_transform
def my_custom_qnode_wrapper(self, qnode, targs, tkwargs):
tkwargs = {**tkwargs, shots=100}
return self.default_qnode_transform(qnode, targs, tkwargs)
The custom QNode execution wrapper must have arguments
``self`` (the batch transform object), ``qnode`` (the input QNode
to transform and execute), ``targs`` and ``tkwargs`` (the transform
arguments and keyword arguments respectively).
It should return a QNode that accepts the *same* arguments as the
input QNode with the transform applied.
The default :meth:`~.default_qnode_transform` method may be called
if only pre- or post-processing dependent on QNode arguments is required.
"""
self._qnode_transform = types.MethodType(fn, self)
[docs] def default_qnode_transform(self, qnode, targs, tkwargs):
"""
The default method that takes in a QNode and returns another QNode
with the transform applied.
"""
qnode = copy(qnode)
if self.expand_transform:
qnode.add_transform(
TransformContainer(
self._expand_transform,
args=targs,
kwargs=tkwargs,
use_argnum=self._use_argnum_in_expand,
)
)
qnode.add_transform(
TransformContainer(
self._transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=self._classical_cotransform,
plxpr_transform=self._plxpr_transform,
is_informative=self._is_informative,
final_transform=self._final_transform,
)
)
return qnode
def _capture_callable_transform(self, qfunc, targs, tkwargs):
"""Apply the transform on a quantum function when program capture is enabled"""
@functools.wraps(qfunc)
def qfunc_transformed(*args, **kwargs):
import jax # pylint: disable=import-outside-toplevel
flat_qfunc = qml.capture.flatfn.FlatFn(qfunc)
jaxpr = jax.make_jaxpr(functools.partial(flat_qfunc, **kwargs))(*args)
n_args = len(args)
n_consts = len(jaxpr.consts)
args_slice = slice(0, n_args)
consts_slice = slice(n_args, n_args + n_consts)
targs_slice = slice(n_args + n_consts, None)
results = self._primitive.bind(
*args,
*jaxpr.consts,
*targs,
inner_jaxpr=jaxpr.jaxpr,
args_slice=args_slice,
consts_slice=consts_slice,
targs_slice=targs_slice,
tkwargs=tkwargs,
)
assert flat_qfunc.out_tree is not None
return jax.tree_util.tree_unflatten(flat_qfunc.out_tree, results)
return qfunc_transformed
def _qfunc_transform(self, qfunc, targs, tkwargs):
"""Apply the transform on a quantum function."""
@functools.wraps(qfunc)
def qfunc_transformed(*args, **kwargs):
with qml.queuing.AnnotatedQueue() as q:
qfunc_output = qfunc(*args, **kwargs)
tape = qml.tape.QuantumScript.from_queue(q)
with qml.QueuingManager.stop_recording():
transformed_tapes, processing_fn = self._transform(tape, *targs, **tkwargs)
if len(transformed_tapes) != 1:
raise TransformError(
"Impossible to dispatch your transform on quantum function, because more than "
"one tape is returned"
)
transformed_tape = transformed_tapes[0]
if self.is_informative:
return processing_fn(transformed_tapes)
for op in transformed_tape.circuit:
qml.apply(op)
mps = transformed_tape.measurements
if not mps:
return qfunc_output
if isinstance(qfunc_output, qml.measurements.MeasurementProcess):
return tuple(mps) if len(mps) > 1 else mps[0]
if isinstance(qfunc_output, (tuple, list)):
return type(qfunc_output)(mps)
interface = qml.math.get_interface(qfunc_output)
return qml.math.asarray(mps, like=interface)
return qfunc_transformed
def _device_transform(self, original_device, targs, tkwargs):
"""Apply the transform on a device"""
if self._expand_transform:
raise TransformError("Device transform does not support expand transforms.")
if self._is_informative:
raise TransformError("Device transform does not support informative transforms.")
if self._final_transform:
raise TransformError("Device transform does not support final transforms.")
class TransformedDevice(type(original_device)):
"""A transformed device with updated preprocess method."""
def __init__(self, original_device, transform):
for key, value in original_device.__dict__.items():
self.__setattr__(key, value)
self.transform = transform
self._original_device = original_device
def __repr__(self):
return f"Transformed Device({original_device.__repr__()} with additional preprocess transform {self.transform})"
def preprocess(
self,
execution_config: qml.devices.ExecutionConfig = qml.devices.DefaultExecutionConfig,
):
"""This function updates the original device transform program to be applied."""
program, config = self.original_device.preprocess(execution_config)
program.push_back(TransformContainer(self.transform, args=targs, kwargs=tkwargs))
return program, config
@property
def original_device(self):
"""Return the original device."""
return self._original_device
return TransformedDevice(original_device, self._transform)
def _batch_transform(self, original_batch, targs, tkwargs):
"""Apply the transform on a batch of tapes."""
execution_tapes = []
batch_fns = []
tape_counts = []
for t in original_batch:
# Preprocess the tapes by applying transforms
# to each tape, and storing corresponding tapes
# for execution, processing functions, and list of tape lengths.
new_tapes, fn = self(t, *targs, **tkwargs)
execution_tapes.extend(new_tapes)
batch_fns.append(fn)
tape_counts.append(len(new_tapes))
def processing_fn(res: ResultBatch) -> ResultBatch:
"""Applies a batch of post-processing functions to results.
Args:
res (ResultBatch): the results of executing a batch of circuits.
Returns:
ResultBatch: results that have undergone classical post processing.
Closure variables:
tape_counts: the number of tapes outputted from each application of the transform.
batch_fns: the post processing functions to apply to each sub-batch.
"""
count = 0
final_results = []
for f, s in zip(batch_fns, tape_counts):
# apply any batch transform post-processing
new_res = f(res[count : count + s])
final_results.append(new_res)
count += s
return tuple(final_results)
return tuple(execution_tapes), processing_fn
[docs]class TransformContainer: # pylint: disable=too-many-instance-attributes, too-many-positional-arguments
"""Class to store a quantum transform with its ``args``, ``kwargs`` and classical co-transforms. Use
:func:`~.pennylane.transform`.
.. warning::
This class is developer-facing and should not be used directly. Instead, use
:func:`qml.transform <pennylane.transform>` if you would like to make a custom
transform.
.. seealso:: :func:`~.pennylane.transform`
"""
def __init__(
self,
transform,
args=None,
kwargs=None,
classical_cotransform=None,
plxpr_transform=None,
is_informative=False,
final_transform=False,
use_argnum=False,
): # pylint:disable=redefined-outer-name,too-many-arguments,too-many-positional-arguments
self._transform = transform
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._plxpr_transform = plxpr_transform
self._is_informative = is_informative
self._final_transform = is_informative or final_transform
self._use_argnum = use_argnum
def __repr__(self):
return f"<{self._transform.__name__}({self._args}, {self._kwargs})>"
def __iter__(self):
return iter(
(
self._transform,
self._args,
self._kwargs,
self._classical_cotransform,
self._plxpr_transform,
self._is_informative,
self.final_transform,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, TransformContainer):
return False
return (
self.args == other.args
and self.transform == other.transform
and self.kwargs == other.kwargs
and self.classical_cotransform == other.classical_cotransform
and self.is_informative == other.is_informative
and self.final_transform == other.final_transform
)
@property
def transform(self):
"""The stored quantum transform."""
return self._transform
@property
def args(self):
"""The stored quantum transform's ``args``."""
return self._args
@property
def kwargs(self):
"""The stored quantum transform's ``kwargs``."""
return self._kwargs
@property
def classical_cotransform(self):
"""The stored quantum transform's classical co-transform."""
return self._classical_cotransform
@property
def plxpr_transform(self):
"""The stored quantum transform's PLxPR transform."""
return self._plxpr_transform
@property
def is_informative(self):
"""``True`` if the transform is informative."""
return self._is_informative
@property
def final_transform(self):
"""``True`` if the transform needs to be executed"""
return self._final_transform
def _create_transform_primitive(name):
try:
# pylint: disable=import-outside-toplevel
import jax
except ImportError:
return None
transform_prim = jax.core.Primitive(name + "_transform")
transform_prim.multiple_results = True
@transform_prim.def_impl
def _(
*all_args, inner_jaxpr, args_slice, consts_slice, targs_slice, tkwargs
): # pylint: disable=unused-argument
raise NotImplementedError
@transform_prim.def_abstract_eval
def _(*_, inner_jaxpr, **__):
return [out.aval for out in inner_jaxpr.outvars]
return transform_prim
_modules/pennylane/transforms/core/transform_dispatcher
Download Python script
Download Notebook
View on GitHub