Source code for pennylane.transforms.core.transform_program
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the ``TransformProgram`` class.
"""
from collections.abc import Sequence
from functools import partial
from typing import Optional, Union, overload
from pennylane.exceptions import TransformError
from pennylane.tape import QuantumScriptBatch
from pennylane.typing import BatchPostprocessingFn, PostprocessingFn, ResultBatch
from .cotransform_cache import CotransformCache
from .transform_dispatcher import TransformContainer, TransformDispatcher
def _batch_postprocessing(
results: ResultBatch,
individual_fns: list[PostprocessingFn],
slices: Union[list[slice], list[int]],
) -> ResultBatch:
"""Broadcast individual post processing functions onto their respective tapes.
Args:
results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`
Keyword Args:
individual_fns (List[Callable]): postprocessing functions converting a batch of results into a single result
corresponding to only a single :class:`~.QuantumTape`.
slices (List[slice]): the indices for the results that correspond to each individual post processing function.
>>> results = (1.0, 2.0, 3.0, 4.0)
>>> def postprocessing1(results):
... return results[0] + results[1]
>>> def postprocessing2(results):
... return results[0]+0.5
>>> def postprocessing3(results):
... return results[0]*2
>>> slices = [slice(0,2), slice(2,3), slice(3,4)]
>>> individual_fns = [postprocessing1, postprocessing2, postprocessing3]
>>> _batch_postprocessing(results, individual_fns, slices)
(3.0, 3.5, 8.0)
"""
return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
def _apply_postprocessing_stack(
results: ResultBatch,
postprocessing_stack: list[BatchPostprocessingFn],
) -> ResultBatch:
"""Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
Args:
results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`
Keyword Args:
postprocessing_stack (List(BatchPostProcessingFn)): a LIFO stack of post processing functions.
Returns:
ResultBatch: the post processed results.
>>> results = (1.0, 2.0, 3.0, 4.0)
>>> def postprocessing1(results):
... return (results[0] + results[1], results[2] + results[3])
>>> def postprocessing2(results):
.... return (results[0] + 1, results[1] + 2)
>>> _apply_postprocessing_stack(results, [postprocessing1])
(3.0, 7.0)
>>> _apply_postprocessing_stack(results, [postprocessing2, postprocessing1])
(4.0, 9.0)
"""
for postprocessing in reversed(postprocessing_stack):
results = postprocessing(results)
return results
[docs]
def null_postprocessing(results: ResultBatch) -> ResultBatch:
"""An empty postprocessing function that simply returns its input.
Args:
results (ResultBatch): Results from executing a batch of :class:`~.QuantumTape`.
Returns:
ResultBatch: the input to the function.
"""
return results
[docs]
class TransformProgram:
"""Class that contains a transform program and the methods to interact with it.
The order of execution is the order in the list containing the containers.
Args:
initial_program (Optional[Sequence[TransformContainer]]): A sequence of transforms with
which to initialize the program.
cotransform_cache (Optional[CotransformCache]): A named tuple containing the ``qnode``,
``args``, and ``kwargs`` required to compute classical cotransforms.
The main case where one would have to interact directly with a transform program is when developing a
:class:`Device <pennylane.devices.Device>`. In this case, the pre-processing method of a device
returns a transform program. You should directly refer to the device API documentation for more details.
.. warning::
This class is developer-facing and should not be used directly. Instead, use
:func:`qml.transform <pennylane.transform>` if you would like to make a custom
transform.
.. seealso:: :func:`~.pennylane.transform`
**Implemented Dunder methods**
Programs have several implemented dunder methods for easy manipulation.
>>> from pennylane.transforms.core.transform_program import TransformProgram
>>> from copy import copy
>>> program = TransformProgram()
>>> program.add_transform(qml.compile)
>>> program.add_transform(qml.transforms.cancel_inverses)
>>> [t for t in program] # Iteration
[<compile([], {})>, <cancel_inverses([], {})>]
>>> program[0]
<compile([], {})>
>>> program[::-1]
TransformProgram(cancel_inverses, compile)
>>> len(program)
2
>>> True if program else False
True
>>> True if TransformProgram() else False
False
>>> program2 = copy(program)
>>> program2 == program
True
>>> qml.compile in program
True
>>> qml.transforms.split_non_commuting in program
False
>>> program + program
TransformProgram(compile, cancel_inverses, compile, cancel_inverses)
"""
def __init__(
self,
initial_program: Optional[Sequence[TransformContainer]] = None,
cotransform_cache: Optional[CotransformCache] = None,
):
self._transform_program = list(initial_program) if initial_program else []
self.cotransform_cache = cotransform_cache
def __iter__(self):
"""list[TransformContainer]: Return an iterator to the underlying transform program."""
return self._transform_program.__iter__()
def __len__(self) -> int:
"""int: Return the number transforms in the program."""
return len(self._transform_program)
@overload
def __getitem__(self, idx: int) -> "TransformContainer": ...
@overload
def __getitem__(self, idx: slice) -> "TransformProgram": ...
def __getitem__(self, idx):
"""(TransformContainer, List[TransformContainer]): Return the indexed transform container from underlying
transform program"""
if isinstance(idx, slice):
return TransformProgram(self._transform_program[idx])
return self._transform_program[idx]
def __bool__(self) -> bool:
return bool(self._transform_program)
def __add__(self, other: "TransformProgram") -> "TransformProgram":
if self.has_final_transform and other.has_final_transform:
raise TransformError("The transform program already has a terminal transform.")
transforms = self._transform_program + other._transform_program
if self.has_final_transform:
transforms.append(transforms.pop(len(self) - 1))
cotransform_cache = None
if self.cotransform_cache:
if other.cotransform_cache:
raise ValueError("Cannot add two transform programs with cotransform caches.")
cotransform_cache = self.cotransform_cache
elif other.cotransform_cache:
cotransform_cache = other.cotransform_cache
return TransformProgram(transforms, cotransform_cache=cotransform_cache)
def __repr__(self):
"""The string representation of the transform program class."""
contents = ", ".join(f"{transform_c.transform.__name__}" for transform_c in self)
return f"TransformProgram({contents})"
def __eq__(self, other) -> bool:
if not isinstance(other, TransformProgram):
return False
return self._transform_program == other._transform_program
def __contains__(self, obj) -> bool:
if isinstance(obj, TransformContainer):
return obj in self._transform_program
if isinstance(obj, TransformDispatcher):
return any(obj.transform == t.transform for t in self)
return False
[docs]
def push_back(self, transform_container: TransformContainer):
"""Add a transform (container) to the end of the program.
Args:
transform_container(TransformContainer): A transform represented by its container.
"""
if not isinstance(transform_container, TransformContainer):
raise TransformError("Only transform container can be added to the transform program.")
# Program can only contain one informative transform and at the end of the program
if self.has_final_transform:
if transform_container.final_transform:
raise TransformError("The transform program already has a terminal transform.")
self._transform_program.insert(-1, transform_container)
return
self._transform_program.append(transform_container)
[docs]
def insert_front(self, transform_container: TransformContainer):
"""Insert the transform container at the beginning of the program.
Args:
transform_container(TransformContainer): A transform represented by its container.
"""
if (transform_container.final_transform) and not self.is_empty():
raise TransformError(
"Informative transforms can only be added at the end of the program."
)
self._transform_program.insert(0, transform_container)
[docs]
def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the end of the program.
Note that this should be a function decorated with/called by
``qml.transforms.transform``, and not a ``TransformContainer``.
Args:
transform (TransformDispatcher): The transform to add to the transform program.
*targs: Any additional arguments that are passed to the transform.
Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.
"""
if not isinstance(transform, TransformDispatcher):
raise TransformError("Only transform dispatcher can be added to the transform program.")
if transform.expand_transform:
self.push_back(TransformContainer(transform.expand_transform, targs, tkwargs))
self.push_back(
TransformContainer(
transform.transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)
[docs]
def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the beginning of the program.
Args:
transform(TransformDispatcher): The transform to add to the front of the transform program.
*targs: Any additional arguments that are passed to the transform.
Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.
"""
if transform.final_transform and not self.is_empty():
raise TransformError(
"Informative transforms can only be added at the end of the program."
)
self.insert_front(
TransformContainer(
transform.transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)
if transform.expand_transform:
self.insert_front(TransformContainer(transform.expand_transform, targs, tkwargs))
[docs]
def pop_front(self):
"""Pop the transform container at the beginning of the program.
Returns:
TransformContainer: The transform container at the beginning of the program.
"""
return self._transform_program.pop(0)
[docs]
def get_last(self):
"""Get the last transform container.
Returns:
TransformContainer: The last transform in the program.
Raises:
TransformError: It raises an error if the program is empty.
"""
if self:
return self._transform_program[-1]
raise TransformError(
"The transform program is empty and you cannot get the last transform container."
)
[docs]
def is_empty(self):
"""Check if the transform program is empty or not.
Returns:
bool: Boolean, True if empty, False otherwise.
"""
return len(self) == 0
@property
def is_informative(self) -> bool:
"""``True`` if the transform program is informative.
Returns:
bool: Boolean
"""
return self[-1].is_informative if self else False
@property
def has_final_transform(self) -> bool:
"""``True`` if the transform program has a terminal transform."""
return self[-1].final_transform if self else False # pylint: disable=no-member
[docs]
def has_classical_cotransform(self) -> bool:
"""Check if the transform program has some classical cotransforms.
Returns:
bool: Boolean
"""
return any(t.classical_cotransform is not None for t in self)
[docs]
def set_classical_component(self, qnode, args, kwargs):
"""Set the classical jacobians and argnums if the transform is hybrid with a classical cotransform."""
# pylint: disable=no-member
if self.has_classical_cotransform() and self[-1].kwargs.get("hybrid", True):
self.cotransform_cache = CotransformCache(qnode, args, kwargs)
[docs]
def prune_dynamic_transform(self, type_to_keep=1):
"""Ensures that only one or none ``dynamic_one_shot`` is applied.
Args:
type_to_keep (int): The type of the dynamic transform to keep. 0: keep none,
1: dynamic_one_shot or mid_circuit_measurements, 2: only mid_circuit_measurements.
Returns:
bool: ``True`` if a dynamic transform was found, ``False`` otherwise.
"""
i = len(self._transform_program) - 1
found = False
while i >= 0:
t = self._transform_program[i]
if "mid_circuit_measurements" in str(t) and type_to_keep > 0:
type_to_keep = 0 # keep this and do not keep the rest
found = True
elif "dynamic_one_shot" in str(t) and type_to_keep == 1:
type_to_keep = 0 # keep this and do not keep the rest
found = True
elif "dynamic_one_shot" in str(t) or "mid_circuit_measurements" in str(t):
self._transform_program.pop(i)
i -= 1
return found
def __call_tapes(
self, tapes: QuantumScriptBatch
) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]:
if not self:
return tapes, null_postprocessing
processing_fns_stack = []
for transform_container in self:
transform, targs, tkwargs, cotransform, _, _, _ = transform_container
tkwargs = {
key: value for key, value in tkwargs.items() if key not in {"argnums", "hybrid"}
}
execution_tapes, fns, slices, classical_fns = [], [], [], []
start = 0
argnums = (
self.cotransform_cache.get_argnums(transform_container)
if self.cotransform_cache
else None
)
classical_jacobians = []
for tape_idx, tape in enumerate(tapes):
if argnums is not None:
# pylint: disable=unsubscriptable-object
tape.trainable_params = argnums[tape_idx]
new_tapes, fn = transform(tape, *targs, **tkwargs)
execution_tapes.extend(new_tapes)
fns.append(fn)
end = start + len(new_tapes)
slices.append(slice(start, end))
start = end
jac = (
self.cotransform_cache.get_classical_jacobian(transform_container, tape_idx)
if self.cotransform_cache
else None
)
classical_jacobians.append(jac)
if cotransform and classical_jacobians[-1] is not None:
classical_fns.append(
partial(cotransform, cjac=classical_jacobians[-1], tape=tape)
)
if cotransform and classical_fns:
slices_classical = list(range(len(tapes)))
batch_postprocessing_classical = partial(
_batch_postprocessing, individual_fns=classical_fns, slices=slices_classical
)
batch_postprocessing_classical.__doc__ = _batch_postprocessing.__doc__
processing_fns_stack.append(batch_postprocessing_classical)
batch_postprocessing = partial(_batch_postprocessing, individual_fns=fns, slices=slices)
batch_postprocessing.__doc__ = _batch_postprocessing.__doc__
processing_fns_stack.append(batch_postprocessing)
# set input tapes for next iteration.
tapes = execution_tapes
postprocessing_fn = partial(
_apply_postprocessing_stack,
postprocessing_stack=processing_fns_stack,
)
postprocessing_fn.__doc__ = _apply_postprocessing_stack.__doc__
# Reset classical jacobians
return tuple(tapes), postprocessing_fn
def __call_jaxpr(
self, jaxpr: "jax.extend.core.Jaxpr", consts: Sequence, *args
) -> "jax.extend.core.ClosedJaxpr":
# pylint: disable=import-outside-toplevel
import jax
cur_jaxpr = jax.extend.core.ClosedJaxpr(jaxpr, consts)
for container in self:
_, targs, tkwargs, _, plxpr_transform, _, _ = container
cur_jaxpr = plxpr_transform(cur_jaxpr.jaxpr, cur_jaxpr.consts, targs, tkwargs, *args)
return cur_jaxpr
@overload
def __call__(
self, jaxpr: "jax.extend.core.Jaxpr", consts: Sequence, *args
) -> "jax.extend.core.ClosedJaxpr": ...
@overload
def __call__(
self, tapes: QuantumScriptBatch
) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]: ...
def __call__(self, *args, **kwargs):
if type(args[0]).__name__ == "Jaxpr":
return self.__call_jaxpr(*args, **kwargs)
return self.__call_tapes(*args, **kwargs)
_modules/pennylane/transforms/core/transform_program
Download Python script
Download Notebook
View on GitHub