Source code for pennylane.transforms.optimization.single_qubit_fusion

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform for fusing sequences of single-qubit gates."""
# pylint: disable=too-many-branches

from functools import lru_cache, partial
from typing import Optional

import pennylane as qml
from pennylane.ops.qubit import Rot
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn, TensorLike

from .optimization_utils import find_next_gate, fuse_rot_angles


@lru_cache
def _get_plxpr_single_qubit_fusion():  # pylint: disable=missing-function-docstring,too-many-statements
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr

        from pennylane.capture import PlxprInterpreter
        from pennylane.capture.primitives import measure_prim
        from pennylane.operation import Operator

    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name

    class SingleQubitFusionInterpreter(PlxprInterpreter):
        """Plxpr Interpreter for applying the ``single_qubit_fusion`` transform to callables or jaxpr
        when program capture is enabled.

        .. note::

            In the process of transforming plxpr, this interpreter may reorder operations that do
            not share any wires. This will not impact the correctness of the circuit.
        """

        def __init__(self, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None):
            """Initialize the interpreter."""
            self.atol = atol
            self.exclude_gates = set(exclude_gates) if exclude_gates is not None else set()
            self.previous_ops = {}
            self._env = {}

        def setup(self) -> None:
            """Initialize the instance before interpreting equations."""
            self.previous_ops = {}
            self._env.clear()

        def cleanup(self) -> None:
            """Clean up the instance after interpreting equations."""
            self.previous_ops.clear()
            self._env.clear()

        def _retrieve_prev_ops_same_wire(self, op: Operator):
            """Retrieve and remove all previous operations that act on the same wire(s) as the given operation."""

            # The order might not be deterministic if the wires (keys) are abstract.
            # However, this only impacts operators without any shared wires,
            # which does not affect the correctness of the result.

            # If the wires are concrete, the order of the keys (wires)
            # and thus the values should reflect the order in which they are iterated
            # because in Python 3.7+ dictionaries maintain insertion order.
            previous_ops_on_wires = {
                w: self.previous_ops.pop(w) for w in op.wires if w in self.previous_ops
            }

            return previous_ops_on_wires.values()

        def _handle_non_fusible_op(self, op: Operator) -> list:
            """Handle an operation that cannot be fused into a Rot gate."""

            previous_ops_on_wires = self._retrieve_prev_ops_same_wire(op)

            res = []
            for prev_op in previous_ops_on_wires:
                # pylint: disable=protected-access
                rot = qml.Rot._primitive.impl(
                    *qml.math.stack(prev_op.single_qubit_rot_angles()), wires=prev_op.wires
                )
                res.append(super().interpret_operation(rot))

            res.append(super().interpret_operation(op))

            return res

        def _handle_fusible_op(self, op: Operator, cumulative_angles: TensorLike) -> list:
            """Handle an operation that can be potentially fused into a Rot gate."""

            # Only single-qubit gates are considered for fusion
            op_wire = op.wires[0]

            prev_op = self.previous_ops.get(op_wire)
            if prev_op is None:
                self.previous_ops[op_wire] = op
                return []

            prev_op_angles = qml.math.stack(prev_op.single_qubit_rot_angles())
            cumulative_angles = fuse_rot_angles(prev_op_angles, cumulative_angles)

            if (
                qml.math.is_abstract(cumulative_angles)
                or qml.math.requires_grad(cumulative_angles)
                or not qml.math.allclose(
                    qml.math.stack(
                        [cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]
                    ),
                    0.0,
                    atol=self.atol,
                    rtol=0,
                )
            ):
                # pylint: disable=protected-access
                new_rot = qml.Rot._primitive.impl(*cumulative_angles, wires=op.wires)
                self.previous_ops[op_wire] = new_rot
            else:
                del self.previous_ops[op_wire]

            return []

        def interpret_operation(self, op: Operator):
            """Interpret a PennyLane operation instance."""

            # Operators like Identity() have no wires
            if len(op.wires) == 0:
                return super().interpret_operation(op)

            # We interpret directly if the gate is explicitly excluded,
            # after interpreting all previous operations on the same wires.
            if op.name in self.exclude_gates:

                previous_ops_on_wires = self._retrieve_prev_ops_same_wire(op)

                for prev_op in previous_ops_on_wires:
                    super().interpret_operation(prev_op)

                return super().interpret_operation(op)

            try:
                cumulative_angles = qml.math.stack(op.single_qubit_rot_angles())
            except (NotImplementedError, AttributeError):
                return self._handle_non_fusible_op(op)

            return self._handle_fusible_op(op, cumulative_angles)

        def interpret_all_previous_ops(self) -> None:
            """Interpret all previous operations stored in the instance."""

            for op in self.previous_ops.values():
                super().interpret_operation(op)

            self.previous_ops.clear()

        def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
            """Evaluate a jaxpr.

            Args:
                jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args (tuple[TensorLike]): The arguments for the jaxpr.

            Returns:
                list[TensorLike]: the results of the execution.
            """

            self.setup()

            for arg, invar in zip(args, jaxpr.invars, strict=True):
                self._env[invar] = arg
            for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                self._env[constvar] = const

            for eqn in jaxpr.eqns:

                prim_type = getattr(eqn.primitive, "prim_type", "")

                custom_handler = self._primitive_registrations.get(eqn.primitive, None)
                if custom_handler:
                    self.interpret_all_previous_ops()
                    invals = [self.read(invar) for invar in eqn.invars]
                    outvals = custom_handler(self, *invals, **eqn.params)
                elif prim_type == "operator":
                    outvals = self.interpret_operation_eqn(eqn)
                elif prim_type == "measurement":
                    self.interpret_all_previous_ops()
                    outvals = self.interpret_measurement_eqn(eqn)
                else:
                    invals = [self.read(invar) for invar in eqn.invars]
                    subfuns, params = eqn.primitive.get_bind_params(eqn.params)
                    outvals = eqn.primitive.bind(*subfuns, *invals, **params)

                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
                for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                    self._env[outvar] = outval

            self.interpret_all_previous_ops()

            outvals = []
            for var in jaxpr.outvars:
                outval = self.read(var)
                if isinstance(outval, Operator):
                    outvals.append(super().interpret_operation(outval))
                else:
                    outvals.append(outval)

            self.cleanup()
            return outvals

    @SingleQubitFusionInterpreter.register_primitive(measure_prim)
    def _(_, *invals, **params):
        subfuns, params = measure_prim.get_bind_params(params)
        return measure_prim.bind(*subfuns, *invals, **params)

    def single_qubit_fusion_plxpr_to_plxpr(jaxpr, consts, targs, tkwargs, *args):
        """Function for applying the ``single_qubit_fusion`` transform on plxpr."""

        interpreter = SingleQubitFusionInterpreter(*targs, **tkwargs)

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return SingleQubitFusionInterpreter, single_qubit_fusion_plxpr_to_plxpr


SingleQubitFusionInterpreter, single_qubit_plxpr_to_plxpr = _get_plxpr_single_qubit_fusion()


[docs]@partial(transform, plxpr_transform=single_qubit_plxpr_to_plxpr) def single_qubit_fusion( tape: QuantumScript, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None ) -> tuple[QuantumScriptBatch, PostprocessingFn]: r"""Quantum function transform to fuse together groups of single-qubit operations into a general single-qubit unitary operation (:class:`~.Rot`). Fusion is performed only between gates that implement the property ``single_qubit_rot_angles``. Any sequence of two or more single-qubit gates (on the same qubit) with that property defined will be fused into one ``Rot``. Args: tape (QNode or QuantumTape or Callable): A quantum circuit. atol (float): An absolute tolerance for which to apply a rotation after fusion. After fusion of gates, if the fused angles :math:`\theta` are such that :math:`|\theta|\leq \text{atol}`, no rotation gate will be applied. exclude_gates (None or list[str]): A list of gates that should be excluded from full fusion. If set to ``None``, all single-qubit gates that can be fused will be fused. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], Callable]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`. **Example** >>> dev = qml.device('default.qubit', wires=1) You can apply the transform directly on :class:`QNode`: .. code-block:: python @qml.transforms.single_qubit_fusion @qml.qnode(device=dev) def qfunc(r1, r2): qml.Hadamard(wires=0) qml.Rot(*r1, wires=0) qml.Rot(*r2, wires=0) qml.RZ(r1[0], wires=0) qml.RZ(r2[0], wires=0) return qml.expval(qml.X(0)) The single qubit gates are fused before execution. .. note:: The fused angles between two sets of rotation angles are not always defined uniquely because Euler angles are not unique for some rotations. ``single_qubit_fusion`` makes a particular choice in this case. .. note:: The order of the gates resulting from the fusion may be different depending on whether program capture is enabled or not. This only impacts the order of operations that do not share any wires, so the correctness of the circuit is not affected. .. warning:: This function is not differentiable everywhere. It has singularities for specific input rotation angles, where the derivative will be NaN. .. warning:: This function is numerically unstable at its singular points. It is recommended to use it with 64-bit floating point precision. .. details:: :title: Usage Details Consider the following quantum function. .. code-block:: python def qfunc(r1, r2): qml.Hadamard(wires=0) qml.Rot(*r1, wires=0) qml.Rot(*r2, wires=0) qml.RZ(r1[0], wires=0) qml.RZ(r2[0], wires=0) return qml.expval(qml.X(0)) The circuit before optimization: >>> qnode = qml.QNode(qfunc, dev) >>> print(qml.draw(qnode)([0.1, 0.2, 0.3], [0.4, 0.5, 0.6])) 0: ──H──Rot(0.1, 0.2, 0.3)──Rot(0.4, 0.5, 0.6)──RZ(0.1)──RZ(0.4)──┤ ⟨X⟩ Full single-qubit gate fusion allows us to collapse this entire sequence into a single ``qml.Rot`` rotation gate. >>> optimized_qfunc = qml.transforms.single_qubit_fusion(qfunc) >>> optimized_qnode = qml.QNode(optimized_qfunc, dev) >>> print(qml.draw(optimized_qnode)([0.1, 0.2, 0.3], [0.4, 0.5, 0.6])) 0: ──Rot(3.57, 2.09, 2.05)──┤ ⟨X⟩ .. details:: :title: Derivation :href: derivation The matrix for an individual rotation is given by .. math:: R(\phi_j,\theta_j,\omega_j) &= \begin{bmatrix} e^{-i(\phi_j+\omega_j)/2}\cos(\theta_j/2) & -e^{i(\phi_j-\omega_j)/2}\sin(\theta_j/2)\\ e^{-i(\phi_j-\omega_j)/2}\sin(\theta_j/2) & e^{i(\phi_j+\omega_j)/2}\cos(\theta_j/2) \end{bmatrix}\\ &= \begin{bmatrix} e^{-i\alpha_j}c_j & -e^{i\beta_j}s_j \\ e^{-i\beta_j}s_j & e^{i\alpha_j}c_j \end{bmatrix}, where we introduced abbreviations :math:`\alpha_j,\beta_j=\frac{\phi_j\pm\omega_j}{2}`, :math:`c_j=\cos(\theta_j / 2)` and :math:`s_j=\sin(\theta_j / 2)` for notational brevity. The upper left entry of the matrix product :math:`R(\phi_2,\theta_2,\omega_2)R(\phi_1,\theta_1,\omega_1)` reads .. math:: x = e^{-i(\alpha_2+\alpha_1)} c_2 c_1 - e^{i(\beta_2-\beta_1)} s_2 s_1 and should equal :math:`e^{-i\alpha_f}c_f` for the fused rotation angles. This means that we can obtain :math:`\theta_f` from the magnitude of the matrix product entry above, choosing :math:`c_f=\cos(\theta_f / 2)` to be non-negative: .. math:: c_f = |x| &= \left| e^{-i(\alpha_2+\alpha_1)} c_2 c_1 -e^{i(\beta_2-\beta_1)} s_2 s_1 \right| \\ &= \sqrt{c_1^2 c_2^2 + s_1^2 s_2^2 - 2 c_1 c_2 s_1 s_2 \cos(\omega_1 + \phi_2)}. Now we again make a choice and pick :math:`\theta_f` to be non-negative: .. math:: \theta_f = 2\arccos(|x|). We can also extract the angle combination :math:`\alpha_f` from :math:`x` via :math:`\operatorname{arg}(x)`, which can be readily computed with :math:`\arctan`: .. math:: \alpha_f = -\arctan\left( \frac{-c_1c_2\sin(\alpha_1+\alpha_2)-s_1s_2\sin(\beta_2-\beta_1)} {c_1c_2\cos(\alpha_1+\alpha_2)-s_1s_2\cos(\beta_2-\beta_1)} \right). We can use the standard numerical function ``arctan2``, which computes :math:`\arctan(x_1/x_2)` from :math:`x_1` and :math:`x_2` while handling special points suitably, to obtain the argument of the underlying complex number :math:`x_2 + x_1 i`. Finally, to obtain :math:`\beta_f`, we need a second element of the matrix product from above. We compute the lower-left entry to be .. math:: y = e^{-i(\beta_2+\alpha_1)} s_2 c_1 + e^{i(\alpha_2-\beta_1)} c_2 s_1, which should equal :math:`e^{-i \beta_f}s_f`. From this, we can compute .. math:: \beta_f = -\arctan\left( \frac{-c_1s_2\sin(\alpha_1+\beta_2)+s_1c_2\sin(\alpha_2-\beta_1)} {c_1s_2\cos(\alpha_1+\beta_2)+s_1c_2\cos(\alpha_2-\beta_1)} \right). From this, we may extract .. math:: \phi_f = \alpha_f + \beta_f\qquad \omega_f = \alpha_f - \beta_f and are done. **Special cases:** There are a number of special cases for which we can skip the computation above and can combine rotation angles directly. 1. If :math:`\omega_1=\phi_2=0`, we can simply merge the ``RY`` rotation angles :math:`\theta_j` and obtain :math:`(\phi_1, \theta_1+\theta_2, \omega_2)`. 2. If :math:`\theta_j=0`, we can merge the two ``RZ`` rotations of the same ``Rot`` and obtain :math:`(\phi_1+\omega_1+\phi_2, \theta_2, \omega_2)` or :math:`(\phi_1, \theta_1, \omega_1+\phi_2+\omega_2)`. If both ``RY`` angles vanish we get :math:`(\phi_1+\omega_1+\phi_2+\omega_2, 0, 0)`. Note that this optimization is not performed for differentiable input parameters, in order to maintain differentiability. **Mathematical properties:** All functions above are well-defined on the domain we are using them on, if we handle :math:`\arctan` via standard numerical implementations such as ``np.arctan2``. Based on the choices we made in the derivation above, the fused angles will lie in the intervals .. math:: \phi_f, \omega_f \in [-\pi, \pi],\quad \theta_f \in [0, \pi]. Close to the boundaries of these intervals, ``single_qubit_fusion`` exhibits discontinuities, depending on the combination of input angles. These discontinuities also lead to singular (non-differentiable) points as discussed below. **Differentiability:** The function derived above is differentiable almost everywhere. In particular, there are two problematic scenarios at which the derivative is not defined. First, the square root is not differentiable at :math:`0`, making all input angles with :math:`|x|=0` singular. Second, :math:`\arccos` is not differentiable at :math:`1`, making all input angles with :math:`|x|=1` singular. """ # Make a working copy of the list to traverse list_copy = tape.operations.copy() new_operations = [] while len(list_copy) > 0: current_gate = list_copy[0] # If the gate should be excluded, queue it and move on regardless # of fusion potential if exclude_gates is not None: if current_gate.name in exclude_gates: new_operations.append(current_gate) list_copy.pop(0) continue # Look for single_qubit_rot_angles; if not available, queue and move on. # If available, grab the angles and try to fuse. try: cumulative_angles = qml.math.stack(current_gate.single_qubit_rot_angles()) except (NotImplementedError, AttributeError): new_operations.append(current_gate) list_copy.pop(0) continue # Find the next gate that acts on at least one of the same wires next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:]) if next_gate_idx is None: new_operations.append(current_gate) list_copy.pop(0) continue # Before entering the loop, we check to make sure the next gate is not in the # exclusion list. If it is, we should apply the original gate as-is, and not the # Rot version (example in test test_single_qubit_fusion_exclude_gates). if exclude_gates is not None: next_gate = list_copy[next_gate_idx + 1] if next_gate.name in exclude_gates: new_operations.append(current_gate) list_copy.pop(0) continue # Loop as long as a valid next gate exists while next_gate_idx is not None: next_gate = list_copy[next_gate_idx + 1] # Check first if the next gate is in the exclusion list if exclude_gates is not None: if next_gate.name in exclude_gates: break # Try to extract the angles; since the Rot angles are implemented # solely for single-qubit gates, and we used find_next_gate to obtain # the gate in question, only valid single-qubit gates on the same # wire as the current gate will be fused. try: next_gate_angles = qml.math.stack(next_gate.single_qubit_rot_angles()) except (NotImplementedError, AttributeError): break cumulative_angles = fuse_rot_angles(cumulative_angles, next_gate_angles) list_copy.pop(next_gate_idx + 1) next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:]) # If we are tracing/jitting or differentiating, don't perform any conditional checks and # apply the rotation regardless of the angles. # If not tracing or differentiating, check whether total rotation is trivial by checking # if the RY angle and the sum of the RZ angles are close to 0 if ( qml.math.is_abstract(cumulative_angles) or qml.math.requires_grad(cumulative_angles) or not qml.math.allclose( qml.math.stack([cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]), 0.0, atol=atol, rtol=0, ) ): with QueuingManager.stop_recording(): new_operations.append(Rot(*cumulative_angles, wires=current_gate.wires)) # Remove the starting gate from the list list_copy.pop(0) new_tape = tape.copy(operations=new_operations) def null_postprocessing(results): """A postprocesing function returned by a transform that only converts the batch of results into a result for a single ``QuantumTape``. """ return results[0] return [new_tape], null_postprocessing