Source code for pennylane.transforms.optimization.merge_amplitude_embedding
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform for merging AmplitudeEmbedding gates in a quantum circuit."""
from copy import copy
from functools import lru_cache, partial
from typing import Sequence
import pennylane as qml
from pennylane import AmplitudeEmbedding
from pennylane.exceptions import DeviceError, TransformError
from pennylane.math import flatten, is_abstract, reshape
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms.core import transform
from pennylane.typing import PostprocessingFn
# pylint: disable=too-many-statements
@lru_cache
def _get_plxpr_merge_amplitude_embedding(): # pylint: disable=missing-docstring
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr
from jax.extend.core import Jaxpr
from pennylane.capture import PlxprInterpreter
from pennylane.capture.base_interpreter import jaxpr_to_jaxpr
from pennylane.capture.primitives import cond_prim, measure_prim
from pennylane.operation import Operator
except ImportError: # pragma: no cover
return None, None
# pylint: disable=redefined-outer-name
class MergeAmplitudeEmbeddingInterpreter(PlxprInterpreter):
"""Plxpr Interpreter for merging AmplitudeEmbedding gates when program capture is enabled."""
def __init__(self):
    """Create an interpreter with empty bookkeeping."""
    # Variable environment filled in while a jaxpr is being evaluated.
    self._env = {}
    self.dynamic_wires_encountered = False

    # Non-AmplitudeEmbedding operators that have been seen but not yet interpreted.
    self.previous_ops = []

    # Wires, state vectors and batch sizes of the AmplitudeEmbedding gates that are
    # waiting to be merged (one entry per gate, kept in lockstep).
    self.input_wires = []
    self.input_vectors = []
    self.input_batch_size = []

    # Collision-detection state:
    # * visited_wires (set): every wire encountered so far.
    # * dynamic_wires_found (bool): True once a non-AmplitudeEmbedding op with
    #   dynamic wires has been encountered.
    # * ops_found (bool): True once any non-AmplitudeEmbedding op has been encountered.
    self.state = dict(visited_wires=set(), dynamic_wires_found=False, ops_found=False)
def setup(self) -> None:
"""Setup the interpreter for a new evaluation."""
self.previous_ops = []
self.input_wires, self.input_vectors, self.input_batch_size = [], [], []
def cleanup(self) -> None:
    """Restore the collision-detection state to its pristine value after an evaluation."""
    self.state = dict(visited_wires=set(), dynamic_wires_found=False, ops_found=False)
def interpret_operation(self, op: Operator) -> None:
    """Interpret a PennyLane operation instance.

    If the operator is not an ``AmplitudeEmbedding`` operator, it is appended to
    ``self.previous_ops`` so it can be interpreted later; if it acts on abstract
    (dynamic) wires, any pending embeddings are merged and all deferred operators
    are interpreted first. Otherwise, the wires, state vector and batch size of the
    ``AmplitudeEmbedding`` are stored for a later merge.

    Args:
        op (Operator): a pennylane operator instance

    Raises:
        TransformError: if an ``AmplitudeEmbedding`` is applied after an operator with
            dynamic wires, if it has dynamic wires itself and other operators have
            already been seen, or if it acts on wires that were already visited.

    Returns:
        None: returns None

    This method is only called when the operator's output is a dropped variable,
    so the output will not affect later equations in the circuit.
    """
    if not isinstance(op, AmplitudeEmbedding):
        if any(is_abstract(w) for w in op.wires):
            # Overlap with dynamic wires cannot be decided, so everything collected
            # so far is emitted before this operator is recorded.
            if self.input_wires:
                self._merge_and_insert_at_the_start()
            self.interpret_all_previous_ops()
            self.state["dynamic_wires_found"] = True

        self.state["ops_found"] = True
        self.previous_ops.append(op)
        self.state["visited_wires"] = self.state["visited_wires"] | set(op.wires)
        return

    if self.state["dynamic_wires_found"]:
        raise TransformError(
            "Cannot apply qml.AmplitudeEmbedding after operators with dynamic wires as it "
            "is indeterminable if the wires overlap."
        )

    if self.state["ops_found"] and any(is_abstract(w) for w in op.wires):
        raise TransformError(
            "Cannot apply qml.AmplitudeEmbedding with dynamic wires after other operators "
            "as it is indeterminable if the wires overlap."
        )

    embedding_wires = set(op.wires)
    if self.state["visited_wires"] & embedding_wires:
        raise TransformError(
            "qml.AmplitudeEmbedding cannot be applied on wires already used by other operations."
        )

    # Remember the embedding; it is merged with its siblings later on.
    self.input_wires.append(op.wires)
    self.input_vectors.append(op.parameters[0])
    self.input_batch_size.append(op.batch_size)
    self.state["visited_wires"] = self.state["visited_wires"] | embedding_wires
def _merge_and_insert_at_the_start(self) -> None:
    """Fuse all pending AmplitudeEmbedding gates into one and queue it ahead of the deferred operators.

    The state vectors are combined pairwise through an outer product over their last
    axis, the wires are concatenated in the order the gates were seen, and the first
    truthy batch size is kept. The pending-embedding lists are emptied afterwards.
    """
    wires_seen = self.input_wires
    vectors_seen = self.input_vectors
    batches_seen = self.input_batch_size

    merged_wires = wires_seen[0]
    merged_vector = vectors_seen[0]
    merged_batch = batches_seen[0]

    remaining = zip(wires_seen[1:], vectors_seen[1:], batches_seen[1:], strict=True)
    for extra_wires, extra_vector, extra_batch in remaining:
        merged_vector = merged_vector[..., :, None] * extra_vector[..., None, :]
        merged_batch = merged_batch or extra_batch
        merged_wires = merged_wires + extra_wires

    # Collapse the outer-product axes back into a single state-vector axis,
    # keeping a leading batch axis when any of the inputs was batched.
    merged_vector = (
        reshape(merged_vector, (merged_batch, -1)) if merged_batch else flatten(merged_vector)
    )

    # NOTE(review): capture is paused presumably so that the operator is created as a
    # concrete object that can be stored in the list -- confirm against qml.capture docs.
    with qml.capture.pause():
        merged_gate = qml.AmplitudeEmbedding(merged_vector, wires=merged_wires)
        self.previous_ops.insert(0, merged_gate)

    # The embeddings are now merged, so forget the individual gates.
    self.input_wires, self.input_vectors, self.input_batch_size = [], [], []
def interpret_all_previous_ops(self) -> None:
    """Pass every deferred operator to the parent interpreter, in order, then empty the queue."""
    deferred = self.previous_ops
    for deferred_op in deferred:
        super().interpret_operation(deferred_op)
    del deferred[:]
# pylint: disable=too-many-branches
def eval(self, jaxpr: Jaxpr, consts: Sequence, *args) -> list:
    """Evaluate a jaxpr.

    Operators are deferred while the equations are walked; the pending
    ``AmplitudeEmbedding`` gates are merged and all deferred operators are emitted
    before a higher-order primitive, before a measurement, and once more after the
    last equation. Handlers registered for a primitive take precedence over the
    dispatch on ``prim_type``.

    The interpreter state is reset even if interpretation raises, so that a reused
    interpreter instance does not carry visited wires over from a failed evaluation.

    Args:
        jaxpr (jax.extend.core.Jaxpr): the jaxpr to evaluate
        consts (list[TensorLike]): the constant variables for the jaxpr
        *args (tuple[TensorLike]): The arguments for the jaxpr.

    Returns:
        list[TensorLike]: the results of the execution.
    """
    self._env = {}
    self.setup()

    try:
        for arg, invar in zip(args, jaxpr.invars, strict=True):
            self._env[invar] = arg
        for const, constvar in zip(consts, jaxpr.constvars, strict=True):
            self._env[constvar] = const

        for eqn in jaxpr.eqns:
            custom_handler = self._primitive_registrations.get(eqn.primitive, None)
            prim_type = getattr(eqn.primitive, "prim_type", "")

            # Currently cannot merge through higher order primitives.
            # Workaround is to merge and insert the merged gate before entering
            # a higher order primitive.
            if prim_type == "higher_order":
                if self.input_wires:
                    self._merge_and_insert_at_the_start()
                self.interpret_all_previous_ops()

            if custom_handler:
                invals = [self.read(invar) for invar in eqn.invars]
                outvals = custom_handler(self, *invals, **eqn.params)
            elif prim_type == "operator":
                outvals = self.interpret_operation_eqn(eqn)
            elif prim_type == "measurement":
                # Everything queued must be applied before the measurement is emitted.
                if self.input_wires:
                    self._merge_and_insert_at_the_start()
                self.interpret_all_previous_ops()
                outvals = self.interpret_measurement_eqn(eqn)
            else:
                invals = [self.read(invar) for invar in eqn.invars]
                extra_args, params = eqn.primitive.get_bind_params(eqn.params)
                outvals = eqn.primitive.bind(*extra_args, *invals, **params)

            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                self._env[outvar] = outval

        # The following is needed because any operations inside self.previous_ops
        # have not yet been applied.
        if self.input_wires:
            self._merge_and_insert_at_the_start()
        self.interpret_all_previous_ops()

        # Read the final result of the Jaxpr from the environment
        outvals = []
        for var in jaxpr.outvars:
            outval = self.read(var)
            if isinstance(outval, Operator):
                outvals.append(super().interpret_operation(outval))
            else:
                outvals.append(outval)
        return outvals
    finally:
        # ``setup`` does not reset ``self.state``, so this must run on every exit path.
        self.cleanup()
        self._env = {}
# Overwrite the cond primitive so that visited wires can be correctly
# detected across the different branches.
@MergeAmplitudeEmbeddingInterpreter.register_primitive(cond_prim)
def _(self, *invals, jaxpr_branches, consts_slices, args_slice):
    """Interpret every branch of a ``cond`` and re-bind it with the transformed branches.

    Each branch is interpreted starting from the state that was current before the
    ``cond``, so that wires used in one branch do not cause false collisions in
    another. After all branches are processed, the state becomes the combination of
    what every branch saw: visited wires are unioned and the boolean flags are
    OR-ed, so a flag raised in any branch stays raised after the ``cond``.

    Args:
        *invals: the flat inputs of the ``cond`` equation; the first
            ``len(jaxpr_branches)`` entries are forwarded unchanged, followed by
            the branch constants and the branch arguments.
        jaxpr_branches (list): the jaxpr of each branch, or ``None`` for a missing branch.
        consts_slices (list[slice]): for each branch, where its constants sit in ``invals``.
        args_slice (slice): where the branch arguments sit in ``invals``.

    Returns:
        The result of binding ``cond_prim`` with the transformed branches.
    """
    args = invals[args_slice]

    new_jaxprs = []
    new_consts = []
    new_consts_slices = []
    end_const_ind = len(jaxpr_branches)

    # Store state before we begin to process the branches
    # (create copies as to not accidentally mutate the original state).
    # We cannot just copy self.state because a shallow copy would not
    # create a copy of `visited_wires`, which is a set.
    # We cannot use deepcopy as `visited_wires` may have tracers inside,
    # which have hashes specific to the instance. Copying these will cause
    # the dynamic wires in the original and copy to be different.
    initial_wires = copy(self.state["visited_wires"])
    curr_wires = copy(self.state["visited_wires"])
    initial_dynamic_wires_found = self.state["dynamic_wires_found"]
    curr_dynamic_wires_found = self.state["dynamic_wires_found"]
    initial_ops_found = self.state["ops_found"]
    curr_ops_found = self.state["ops_found"]

    for const_slice, jaxpr in zip(consts_slices, jaxpr_branches):
        consts = invals[const_slice]
        if jaxpr is None:
            new_jaxprs.append(None)
            new_consts_slices.append(slice(0, 0))
        else:
            # NOTE(review): the shallow copy shares the ``state`` dict with ``self``,
            # which appears to be how the branch's updates become visible through
            # ``self.state`` below -- confirm before changing how the copy is made.
            new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)

            # Update state so far so collisions with
            # newly seen states from the branches continue to be
            # detected after the cond. Flags are accumulated with ``or`` so that a
            # later branch cannot erase what an earlier branch found.
            curr_wires |= self.state["visited_wires"]
            curr_dynamic_wires_found = (
                curr_dynamic_wires_found or self.state["dynamic_wires_found"]
            )
            curr_ops_found = curr_ops_found or self.state["ops_found"]

            # Reset state for the next branch so we don't get false positive collisions
            # (copy so if state mutates we preserved true initial state)
            self.state = {
                "visited_wires": copy(initial_wires),
                "dynamic_wires_found": initial_dynamic_wires_found,
                "ops_found": initial_ops_found,
            }

            new_jaxprs.append(new_jaxpr.jaxpr)
            new_consts.extend(new_jaxpr.consts)
            new_consts_slices.append(
                slice(end_const_ind, end_const_ind + len(new_jaxpr.consts))
            )
            end_const_ind += len(new_jaxpr.consts)

    # Reset state to all updates from all branches in the cond
    self.state = {
        "visited_wires": curr_wires,
        "dynamic_wires_found": curr_dynamic_wires_found,
        "ops_found": curr_ops_found,
    }

    new_args_slice = slice(end_const_ind, None)
    return cond_prim.bind(
        *invals[: len(jaxpr_branches)],
        *new_consts,
        *args,
        jaxpr_branches=new_jaxprs,
        consts_slices=new_consts_slices,
        args_slice=new_args_slice,
    )
@MergeAmplitudeEmbeddingInterpreter.register_primitive(measure_prim)
def _(self, *invals, **params):
    """Record a mid-circuit measurement, emit everything queued so far, then bind it.

    The measured wires are added to the visited wires and ``ops_found`` is set.
    ``dynamic_wires_found`` is raised when any measured wire is abstract, and is
    never lowered here: a measurement on static wires must not hide dynamic wires
    seen earlier.
    """
    # Make sure to record that we have visited the wires on this measurement
    # in order to be able to detect potential wire collisions with future AE gates
    self.state["visited_wires"] = self.state["visited_wires"].union(set(invals))
    self.state["dynamic_wires_found"] = self.state["dynamic_wires_found"] or any(
        is_abstract(w) for w in invals
    )
    self.state["ops_found"] = True

    # pylint: disable=protected-access
    if self.input_wires:
        self._merge_and_insert_at_the_start()
    self.interpret_all_previous_ops()

    _, params = measure_prim.get_bind_params(params)
    return measure_prim.bind(*invals, **params)
def merge_amplitude_embedding_plxpr_to_plxpr(jaxpr, consts, _, __, *args):
    """Apply the ``merge_amplitude_embedding`` transform to a plxpr.

    A fresh interpreter evaluates ``jaxpr`` with ``consts`` while being traced on
    ``args``. The third and fourth positional arguments are accepted but not used.
    """
    merger = MergeAmplitudeEmbeddingInterpreter()

    def run_interpreter(*traced_args):
        return merger.eval(jaxpr, consts, *traced_args)

    traced = make_jaxpr(run_interpreter)
    return traced(*args)
return MergeAmplitudeEmbeddingInterpreter, merge_amplitude_embedding_plxpr_to_plxpr
# Build the interpreter class and the plxpr transform once, at import time.
# Both names are ``None`` when the optional imports inside the factory fail
# (its ``except ImportError`` branch returns ``None, None``).
MergeAmplitudeEmbeddingInterpreter, merge_amplitude_embedding_plxpr_to_plxpr = (
_get_plxpr_merge_amplitude_embedding()
)
[docs]
@partial(transform, plxpr_transform=merge_amplitude_embedding_plxpr_to_plxpr)
def merge_amplitude_embedding(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    r"""Quantum function transform to combine amplitude embedding templates that act on different qubits.

    Args:
        tape (QNode or QuantumTape or Callable): A quantum circuit.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[.QuantumTape], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    Raises:
        DeviceError: if an ``AmplitudeEmbedding`` acts on a wire that an earlier operation
            in the circuit (including another ``AmplitudeEmbedding``) has already used.

    **Example**

    >>> dev = qml.device('default.qubit', wires=4)

    You can apply the transform directly on :class:`QNode`:

    .. code-block:: python

        @qml.transforms.merge_amplitude_embedding
        @qml.qnode(device=dev)
        def circuit():
            qml.CNOT(wires = [0,1])
            qml.AmplitudeEmbedding([0,1], wires = 2)
            qml.AmplitudeEmbedding([0,1], wires = 3)
            return qml.state()

    >>> circuit()
    [0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]

    .. details::
        :title: Usage Details

        You can also apply it on a quantum function.

        .. code-block:: python

            def qfunc():
                qml.CNOT(wires = [0,1])
                qml.AmplitudeEmbedding([0,1], wires = 2)
                qml.AmplitudeEmbedding([0,1], wires = 3)
                return qml.state()

        The circuit will not work before compilation because it uses two amplitude embeddings.
        Using the transformation we can join the different amplitude embeddings into a single one:

        >>> optimized_qfunc = qml.transforms.merge_amplitude_embedding(qfunc)
        >>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
        >>> print(qml.draw(optimized_qnode)())
        0: ─╭●──────────────────────┤  State
        1: ─╰X──────────────────────┤  State
        2: ─╭AmplitudeEmbedding(M0)─┤  State
        3: ─╰AmplitudeEmbedding(M0)─┤  State
        M0 =
        [0.+0.j 0.+0.j 0.+0.j 1.+0.j]

    """
    new_operations = []
    visited_wires = set()
    input_wires, input_vectors, input_batch_size = [], [], []

    for current_gate in tape.operations:
        wires_set = set(current_gate.wires)

        # Anything that is not an AmplitudeEmbedding is kept as is.
        if not isinstance(current_gate, AmplitudeEmbedding):
            new_operations.append(current_gate)
            visited_wires |= wires_set
            continue

        # An embedding may only act on qubits that have not been used yet.
        if visited_wires & wires_set:
            raise DeviceError(
                f"Operation {current_gate.name} cannot be used after other Operation applied in the same qubit "
            )

        input_wires.append(current_gate.wires)
        input_vectors.append(current_gate.parameters[0])
        input_batch_size.append(current_gate.batch_size)
        visited_wires |= wires_set

    if input_wires:
        final_wires = input_wires[0]
        final_vector = input_vectors[0]
        final_batch_size = input_batch_size[0]

        # Merge all parameters and qubits into a single one: the state vectors are
        # combined through an outer product over their last axis.
        for w, v, b in zip(
            input_wires[1:], input_vectors[1:], input_batch_size[1:], strict=True
        ):
            final_vector = final_vector[..., :, None] * v[..., None, :]
            final_batch_size = final_batch_size or b
            final_wires = final_wires + w

        # Collapse the outer-product axes, keeping a leading batch axis if present.
        if final_batch_size:
            final_vector = reshape(final_vector, (final_batch_size, -1))
        else:
            final_vector = flatten(final_vector)

        # The merged gate goes first; it is created without being queued.
        with QueuingManager.stop_recording():
            new_operations.insert(0, AmplitudeEmbedding(final_vector, wires=final_wires))

    new_tape = tape.copy(operations=new_operations)

    def null_postprocessing(results):
        """A postprocessing function returned by a transform that only converts the batch of results
        into a result for a single ``QuantumTape``.
        """
        return results[0]

    return [new_tape], null_postprocessing
_modules/pennylane/transforms/optimization/merge_amplitude_embedding
Download Python script
Download Notebook
View on GitHub