# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the ``TransformProgram`` class.
"""
from collections import namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Optional, overload
import pennylane as qml
from pennylane.tape import QuantumScriptBatch
from pennylane.typing import BatchPostprocessingFn, PostprocessingFn, ResultBatch
from .transform_dispatcher import TransformContainer, TransformDispatcher, TransformError
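# Caches the QNode together with the positional and keyword arguments of a call, so that
# classical-cotransform quantities (classical jacobians and argnums) can be computed lazily
# when the program is applied.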
CotransformCache = namedtuple("CotransformCache", ("qnode", "args", "kwargs"))
def _get_interface(qnode, args, kwargs) -> str:
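    """Resolve the machine-learning interface name for a QNode call, falling back to ``"numpy"``
    when the interface cannot be determined."""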
if qnode.interface == "auto":
interface = qml.math.get_interface(*args, *list(kwargs.values()))
try:
interface = qml.math.get_canonical_interface_name(interface).value
except ValueError:
interface = "numpy"
else:
interface = qnode.interface
return interface
def _numpy_jac(*_, **__) -> qml.typing.TensorLike:
raise qml.QuantumFunctionError("No trainable parameters.")
def _autograd_jac(classical_function, argnums, *args, **kwargs) -> qml.typing.TensorLike:
if not qml.math.get_trainable_indices(args) and argnums is None:
raise qml.QuantumFunctionError("No trainable parameters.")
return qml.jacobian(classical_function, argnum=argnums)(*args, **kwargs)
# pylint: disable=import-outside-toplevel, unused-argument
def _tf_jac(classical_function, argnums, *args, **kwargs) -> qml.typing.TensorLike:
if not qml.math.get_trainable_indices(args):
raise qml.QuantumFunctionError("No trainable parameters.")
import tensorflow as tf
with tf.GradientTape() as tape:
gate_params = classical_function(*args, **kwargs)
return tape.jacobian(gate_params, args)
# pylint: disable=import-outside-toplevel, unused-argument
def _torch_jac(classical_function, argnums, *args, **kwargs) -> qml.typing.TensorLike:
if not qml.math.get_trainable_indices(args):
raise qml.QuantumFunctionError("No trainable parameters.")
from torch.autograd.functional import jacobian
return jacobian(partial(classical_function, **kwargs), args)
# pylint: disable=import-outside-toplevel
def _jax_jac(classical_function, argnums, *args, **kwargs) -> qml.typing.TensorLike:
import jax
if argnums is None:
if not isinstance(args[0], jax.numpy.ndarray):
raise qml.QuantumFunctionError("No trainable parameters.")
argnums = 0
return jax.jacobian(classical_function, argnums=argnums)(*args, **kwargs)
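# Maps each interface name to the helper that computes the jacobian of the classical
# preprocessing; used by ``TransformProgram._get_classical_jacobian`` below.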
_jac_map = {
None: _numpy_jac,
"numpy": _numpy_jac,
"autograd": _autograd_jac,
"tf": _tf_jac,
"torch": _torch_jac,
"jax": _jax_jac,
"jax-jit": _jax_jac,
}
# pylint: disable=unused-argument
def _classical_preprocessing(qnode, program, *args, argnums=None, **kwargs):
"""Returns the trainable gate parameters for a given QNode input."""
tape = qml.workflow.construct_tape(qnode, level=0)(*args, **kwargs)
tapes, _ = program((tape,))
res = tuple(qml.math.stack(tape.get_parameters(trainable_only=True)) for tape in tapes)
    # autograd and tf can't handle pytrees, so we need to squeeze single-tape batches
if len(tapes) == 1:
return res[0]
return res
def _jax_argnums_to_tape_trainable(qnode, argnums, program, args, kwargs):
"""This function gets the tape parameters from the QNode construction given some argnums (only for Jax).
The tape parameters are transformed to JVPTracer if they are from argnums. This function imitates the behaviour
of Jax in order to mark trainable parameters.
Args:
qnode(qml.QNode): the quantum node.
argnums(int, list[int]): the parameters that we want to set as trainable (on the QNode level).
program(qml.transforms.core.TransformProgram): the transform program to be applied on the tape.
Return:
list[float, jax.JVPTracer]: List of parameters where the trainable one are `JVPTracer`.
"""
import jax # pylint: disable=import-outside-toplevel
with jax.core.new_main(jax.interpreters.ad.JVPTrace) as main:
trace = jax.interpreters.ad.JVPTrace(main, 0)
args_jvp = [
(
jax.interpreters.ad.JVPTracer(trace, arg, jax.numpy.zeros(arg.shape))
if i in argnums
else arg
)
for i, arg in enumerate(args)
]
tape = qml.workflow.construct_tape(qnode, level=0)(*args_jvp, **kwargs)
tapes, _ = program((tape,))
del trace
return tuple(tape.get_parameters(trainable_only=False) for tape in tapes)
def _batch_postprocessing(
results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
) -> ResultBatch:
"""Broadcast individual post processing functions onto their respective tapes.
Args:
results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`
Keyword Args:
individual_fns (List[Callable]): postprocessing functions converting a batch of results into a single result
corresponding to only a single :class:`~.QuantumTape`.
slices (List[slice]): the indices for the results that correspond to each individual post processing function.
>>> results = (1.0, 2.0, 3.0, 4.0)
>>> def postprocessing1(results):
... return results[0] + results[1]
>>> def postprocessing2(results):
... return results[0]+0.5
>>> def postprocessing3(results):
... return results[0]*2
>>> slices = [slice(0,2), slice(2,3), slice(3,4)]
>>> individual_fns = [postprocessing1, postprocessing2, postprocessing3]
>>> _batch_postprocessing(results, individual_fns, slices)
(3.0, 3.5, 8.0)
"""
return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))
def _apply_postprocessing_stack(
results: ResultBatch,
postprocessing_stack: list[BatchPostprocessingFn],
) -> ResultBatch:
"""Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
Args:
results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`
Keyword Args:
postprocessing_stack (List(BatchPostProcessingFn)): a LIFO stack of post processing functions.
Returns:
ResultBatch: the post processed results.
>>> results = (1.0, 2.0, 3.0, 4.0)
>>> def postprocessing1(results):
... return (results[0] + results[1], results[2] + results[3])
>>> def postprocessing2(results):
    ... return (results[0] + 1, results[1] + 2)
>>> _apply_postprocessing_stack(results, [postprocessing1])
(3.0, 7.0)
>>> _apply_postprocessing_stack(results, [postprocessing2, postprocessing1])
(4.0, 9.0)
"""
for postprocessing in reversed(postprocessing_stack):
results = postprocessing(results)
return results
def null_postprocessing(results: ResultBatch) -> ResultBatch:
"""An empty postprocessing function that simply returns its input.
Args:
results (ResultBatch): Results from executing a batch of :class:`~.QuantumTape`.
Returns:
ResultBatch: the input to the function.
"""
return results
class TransformProgram:
"""Class that contains a transform program and the methods to interact with it.
    The transforms are executed in the order in which they appear in the program.
Args:
initial_program (Optional[Sequence[TransformContainer]]): A sequence of transforms with
which to initialize the program.
cotransform_cache (Optional[CotransformCache]): A named tuple containing the ``qnode``,
``args``, and ``kwargs`` required to compute classical cotransforms.
    The main case in which one interacts directly with a transform program is when developing a
    :class:`Device <pennylane.devices.Device>`: a device's preprocessing method returns a transform
    program. Refer to the device API documentation for more details.
.. warning::
This class is developer-facing and should not be used directly. Instead, use
:func:`qml.transform <pennylane.transform>` if you would like to make a custom
transform.
.. seealso:: :func:`~.pennylane.transform`
**Implemented Dunder methods**
Programs have several implemented dunder methods for easy manipulation.
>>> from pennylane.transforms.core.transform_program import TransformProgram
>>> from copy import copy
>>> program = TransformProgram()
>>> program.add_transform(qml.compile)
>>> program.add_transform(qml.transforms.cancel_inverses)
>>> [t for t in program] # Iteration
[<compile([], {})>, <cancel_inverses([], {})>]
>>> program[0]
<compile([], {})>
>>> program[::-1]
TransformProgram(cancel_inverses, compile)
>>> len(program)
2
>>> True if program else False
True
>>> True if TransformProgram() else False
False
>>> program2 = copy(program)
>>> program2 == program
True
>>> qml.compile in program
True
>>> qml.transforms.split_non_commuting in program
False
>>> program + program
TransformProgram(compile, cancel_inverses, compile, cancel_inverses)
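    A program can also be called directly on a batch of tapes, returning the transformed batch
    together with a batch postprocessing function. The lines below are an illustrative sketch
    (assuming the ``default.qubit`` device); the exact transformed tapes depend on the transforms
    contained in the program:
    >>> tape = qml.tape.QuantumScript([qml.X(0), qml.X(0)], [qml.expval(qml.Z(0))])
    >>> new_batch, postprocessing = program((tape,))
    >>> results = qml.execute(new_batch, qml.device("default.qubit"))
    >>> processed = postprocessing(results)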
"""
def __init__(
self,
initial_program: Optional[Sequence[TransformContainer]] = None,
cotransform_cache: Optional[CotransformCache] = None,
):
self._transform_program = list(initial_program) if initial_program else []
self.cotransform_cache = cotransform_cache
def __iter__(self):
"""list[TransformContainer]: Return an iterator to the underlying transform program."""
return self._transform_program.__iter__()
def __len__(self) -> int:
"""int: Return the number transforms in the program."""
return len(self._transform_program)
@overload
def __getitem__(self, idx: int) -> "TransformContainer": ...
@overload
def __getitem__(self, idx: slice) -> "TransformProgram": ...
def __getitem__(self, idx):
"""(TransformContainer, List[TransformContainer]): Return the indexed transform container from underlying
transform program"""
if isinstance(idx, slice):
return TransformProgram(self._transform_program[idx])
return self._transform_program[idx]
def __bool__(self) -> bool:
return bool(self._transform_program)
def __add__(self, other: "TransformProgram") -> "TransformProgram":
if self.has_final_transform and other.has_final_transform:
raise TransformError("The transform program already has a terminal transform.")
transforms = self._transform_program + other._transform_program
if self.has_final_transform:
transforms.append(transforms.pop(len(self) - 1))
cotransform_cache = None
if self.cotransform_cache:
if other.cotransform_cache:
raise ValueError("Cannot add two transform programs with cotransform caches.")
cotransform_cache = self.cotransform_cache
elif other.cotransform_cache:
cotransform_cache = other.cotransform_cache
return TransformProgram(transforms, cotransform_cache=cotransform_cache)
def __repr__(self):
"""The string representation of the transform program class."""
contents = ", ".join(f"{transform_c.transform.__name__}" for transform_c in self)
return f"TransformProgram({contents})"
def __eq__(self, other) -> bool:
if not isinstance(other, TransformProgram):
return False
return self._transform_program == other._transform_program
def __contains__(self, obj) -> bool:
if isinstance(obj, TransformContainer):
return obj in self._transform_program
if isinstance(obj, TransformDispatcher):
return any(obj.transform == t.transform for t in self)
return False
    def push_back(self, transform_container: TransformContainer):
"""Add a transform (container) to the end of the program.
Args:
transform_container(TransformContainer): A transform represented by its container.
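        As an illustrative sketch (with ``TransformContainer`` imported from this module), a
        container can be built from a dispatcher's ``transform`` attribute, mirroring what
        :meth:`add_transform` does internally:
        >>> program = TransformProgram()
        >>> program.push_back(TransformContainer(qml.transforms.cancel_inverses.transform))
        >>> program
        TransformProgram(cancel_inverses)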
"""
if not isinstance(transform_container, TransformContainer):
raise TransformError("Only transform container can be added to the transform program.")
        # The program can contain at most one final transform, and it must remain at the end of the program
if self.has_final_transform:
if transform_container.final_transform:
raise TransformError("The transform program already has a terminal transform.")
self._transform_program.insert(-1, transform_container)
return
self._transform_program.append(transform_container)
    def insert_front(self, transform_container: TransformContainer):
"""Insert the transform container at the beginning of the program.
Args:
transform_container(TransformContainer): A transform represented by its container.
"""
if (transform_container.final_transform) and not self.is_empty():
raise TransformError(
"Informative transforms can only be added at the end of the program."
)
self._transform_program.insert(0, transform_container)
    def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the end of the program.
Note that this should be a function decorated with/called by
``qml.transforms.transform``, and not a ``TransformContainer``.
Args:
transform (TransformDispatcher): The transform to add to the transform program.
*targs: Any additional arguments that are passed to the transform.
Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.
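        For example (an illustrative sketch; the keyword argument is stored in the container and
        forwarded to the transform when the program is applied to a tape):
        >>> program = TransformProgram()
        >>> program.add_transform(qml.transforms.merge_rotations, atol=1e-6)
        >>> qml.transforms.merge_rotations in program
        True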
"""
if not isinstance(transform, TransformDispatcher):
raise TransformError("Only transform dispatcher can be added to the transform program.")
if transform.expand_transform:
self.push_back(TransformContainer(transform.expand_transform, targs, tkwargs))
self.push_back(
TransformContainer(
transform.transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)
    def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the beginning of the program.
Args:
transform(TransformDispatcher): The transform to add to the front of the transform program.
*targs: Any additional arguments that are passed to the transform.
Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.
"""
if transform.final_transform and not self.is_empty():
raise TransformError(
"Informative transforms can only be added at the end of the program."
)
self.insert_front(
TransformContainer(
transform.transform,
args=targs,
kwargs=tkwargs,
classical_cotransform=transform.classical_cotransform,
plxpr_transform=transform.plxpr_transform,
is_informative=transform.is_informative,
final_transform=transform.final_transform,
)
)
if transform.expand_transform:
self.insert_front(TransformContainer(transform.expand_transform, targs, tkwargs))
    def pop_front(self):
"""Pop the transform container at the beginning of the program.
Returns:
TransformContainer: The transform container at the beginning of the program.
"""
return self._transform_program.pop(0)
    def get_last(self):
"""Get the last transform container.
Returns:
TransformContainer: The last transform in the program.
Raises:
TransformError: It raises an error if the program is empty.
"""
if self:
return self._transform_program[-1]
raise TransformError(
"The transform program is empty and you cannot get the last transform container."
)
    def is_empty(self):
"""Check if the transform program is empty or not.
Returns:
bool: Boolean, True if empty, False otherwise.
"""
return len(self) == 0
@property
def is_informative(self) -> bool:
"""``True`` if the transform program is informative.
Returns:
bool: Boolean
"""
return self[-1].is_informative if self else False
@property
def has_final_transform(self) -> bool:
"""``True`` if the transform program has a terminal transform."""
return self[-1].final_transform if self else False # pylint: disable=no-member
    def has_classical_cotransform(self) -> bool:
"""Check if the transform program has some classical cotransforms.
Returns:
bool: Boolean
"""
return any(t.classical_cotransform is not None for t in self)
    def set_classical_component(self, qnode, args, kwargs):
        """Set the classical component (the ``qnode``, ``args``, and ``kwargs``) used to compute
        classical jacobians and argnums if the final transform is hybrid with a classical cotransform."""
# pylint: disable=no-member
if self.has_classical_cotransform() and self[-1].kwargs.get("hybrid", True):
self.cotransform_cache = CotransformCache(qnode, args, kwargs)
    def prune_dynamic_transform(self, type_to_keep=1):
        """Ensure that at most one dynamic transform (``dynamic_one_shot`` or ``mid_circuit_measurements``) is applied.
Args:
type_to_keep (int): The type of the dynamic transform to keep. 0: keep none,
1: dynamic_one_shot or mid_circuit_measurements, 2: only mid_circuit_measurements.
Returns:
bool: ``True`` if a dynamic transform was found, ``False`` otherwise.
"""
i = len(self._transform_program) - 1
found = False
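        # Walk the program from the back so that the last applicable dynamic transform (according
        # to ``type_to_keep``) is kept and all other dynamic transforms are removed.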
while i >= 0:
t = self._transform_program[i]
if "mid_circuit_measurements" in str(t) and type_to_keep > 0:
type_to_keep = 0 # keep this and do not keep the rest
found = True
elif "dynamic_one_shot" in str(t) and type_to_keep == 1:
type_to_keep = 0 # keep this and do not keep the rest
found = True
elif "dynamic_one_shot" in str(t) or "mid_circuit_measurements" in str(t):
self._transform_program.pop(i)
i -= 1
return found
# pylint: disable=too-many-arguments, too-many-positional-arguments
def _get_classical_jacobian(self, index: int):
if self.cotransform_cache is None or not self[index].classical_cotransform:
return None
argnums = self[-1].kwargs.get("argnums", None) # pylint: disable=no-member
qnode, args, kwargs = self.cotransform_cache
interface = _get_interface(qnode, args, kwargs)
if interface == "jax" and "argnum" in self[index].kwargs:
raise qml.QuantumFunctionError(
"argnum does not work with the Jax interface. You should use argnums instead."
)
f = partial(_classical_preprocessing, qnode, self[:index])
classical_jacobian = _jac_map[interface](f, argnums, *args, **kwargs)
        # autograd and tf can't handle pytrees, so we need to undo the squeezing
        # done in _classical_preprocessing
tape = qml.workflow.construct_tape(qnode, level=0)(*args, **kwargs)
tapes, _ = self[:index]((tape,)) # pylint: disable=not-callable
multi_tapes = len(tapes) > 1
if not multi_tapes:
classical_jacobian = [classical_jacobian]
return classical_jacobian
def _get_argnums(self, index):
"""It can be used inside the QNode to set all argnums (tape level) using argnums from the argnums at the QNode
level.
"""
if self.cotransform_cache is None:
return None
qnode, args, kwargs = self.cotransform_cache
interface = _get_interface(qnode, args, kwargs)
transform = self[index]
argnums = self[-1].kwargs.get("argnums", None) # pylint: disable=no-member
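        # For JAX, default to treating the first positional argument as trainable when no
        # argnums were provided (matching jax.jacobian's default of argnums=0).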
argnums = [0] if interface in ["jax", "jax-jit"] and argnums is None else argnums
# pylint: disable=protected-access
if (transform._use_argnum or transform.classical_cotransform) and argnums:
params = _jax_argnums_to_tape_trainable(qnode, argnums, self[:index], args, kwargs)
return [qml.math.get_trainable_indices(param) for param in params]
return None
def __call__(
self, tapes: QuantumScriptBatch
) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]:
if not self:
return tapes, null_postprocessing
processing_fns_stack = []
for i, transform_container in enumerate(self):
transform, targs, tkwargs, cotransform, _, _, _ = transform_container
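            # the remaining container fields (plxpr_transform, is_informative, final_transform)
            # are not needed during tape execution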
tkwargs = {
key: value for key, value in tkwargs.items() if key not in {"argnums", "hybrid"}
}
execution_tapes = []
fns = []
slices = []
classical_fns = []
slices_classical = []
start = 0
start_classical = 0
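            # Both of the following are ``None`` unless a cotransform cache has been set via
            # ``set_classical_component`` (hybrid transforms with classical cotransforms).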
classical_jacobians = self._get_classical_jacobian(i)
argnums = self._get_argnums(i)
for j, tape in enumerate(tapes):
if argnums is not None:
tape.trainable_params = argnums[j]
new_tapes, fn = transform(tape, *targs, **tkwargs)
execution_tapes.extend(new_tapes)
fns.append(fn)
end = start + len(new_tapes)
slices.append(slice(start, end))
start = end
if cotransform and classical_jacobians:
classical_fns.append(
partial(cotransform, cjac=classical_jacobians[j], tape=tape)
)
slices_classical.append(slice(start_classical, start_classical + 1))
start_classical += 1
if cotransform and classical_jacobians:
batch_postprocessing_classical = partial(
_batch_postprocessing, individual_fns=classical_fns, slices=slices_classical
)
batch_postprocessing_classical.__doc__ = _batch_postprocessing.__doc__
processing_fns_stack.append(batch_postprocessing_classical)
batch_postprocessing = partial(_batch_postprocessing, individual_fns=fns, slices=slices)
batch_postprocessing.__doc__ = _batch_postprocessing.__doc__
processing_fns_stack.append(batch_postprocessing)
# set input tapes for next iteration.
tapes = execution_tapes
postprocessing_fn = partial(
_apply_postprocessing_stack,
postprocessing_stack=processing_fns_stack,
)
postprocessing_fn.__doc__ = _apply_postprocessing_stack.__doc__
return tuple(tapes), postprocessing_fn