# Source code for pennylane.transforms.optimization.single_qubit_fusion
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform for fusing sequences of single-qubit gates."""
# pylint: disable=too-many-branches

from functools import lru_cache, partial
from typing import Optional

import pennylane as qml
from pennylane.ops.qubit import Rot
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn, TensorLike

from .optimization_utils import find_next_gate, fuse_rot_angles


@lru_cache
def _get_plxpr_single_qubit_fusion():  # pylint: disable=missing-function-docstring,too-many-statements
    """Build the program-capture (plxpr) implementation of ``single_qubit_fusion``.

    The JAX and ``pennylane.capture`` imports are performed lazily inside this function;
    if any of them raises ``ImportError``, ``(None, None)`` is returned instead of failing.
    The result is memoized with ``lru_cache``, so the interpreter class below is only
    created once.

    Returns:
        tuple: ``(SingleQubitFusionInterpreter, single_qubit_fusion_plxpr_to_plxpr)``, or
        ``(None, None)`` if the optional imports are unavailable.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr

        from pennylane.capture import PlxprInterpreter
        from pennylane.capture.primitives import measure_prim
        from pennylane.operation import Operator

    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name
    class SingleQubitFusionInterpreter(PlxprInterpreter):
        """Plxpr Interpreter for applying the ``single_qubit_fusion`` transform to callables or jaxpr
        when program capture is enabled.

        Single-qubit operations that define ``single_qubit_rot_angles`` are buffered per wire in
        ``self.previous_ops`` and fused with later fusible operations on the same wire. Buffered
        operations are only interpreted ("flushed") when something forces them out: a
        non-fusible or excluded operation on the same wire, a measurement, a primitive with a
        registered custom handler, or the end of the jaxpr.

        Args:
            atol (float): absolute tolerance below which a fused rotation is considered trivial
                and dropped (see ``_handle_fusible_op``).
            exclude_gates (None or list[str]): names of gates that are never fused.

        .. note::

            In the process of transforming plxpr, this interpreter may reorder operations that do
            not share any wires. This will not impact the correctness of the circuit.
        """

        def __init__(self, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None):
            """Initialize the interpreter.

            Args:
                atol (float): absolute tolerance used to decide whether a fused rotation is trivial.
                exclude_gates (None or list[str]): gate names that are interpreted as-is, never fused.
            """
            self.atol = atol
            # Stored as a set for membership tests by gate name.
            self.exclude_gates = set(exclude_gates) if exclude_gates is not None else set()
            # Maps a wire to the single pending (not yet interpreted) operation on that wire.
            self.previous_ops = {}
            self._env = {}

        def setup(self) -> None:
            """Initialize the instance before interpreting equations."""
            self.previous_ops = {}
            self._env.clear()

        def cleanup(self) -> None:
            """Clean up the instance after interpreting equations."""
            self.previous_ops.clear()
            self._env.clear()

        def _retrieve_prev_ops_same_wire(self, op: Operator):
            """Retrieve and remove all previous operations that act on the same wire(s) as the given operation.

            Args:
                op (Operator): the operation whose wires are looked up in ``self.previous_ops``.

            Returns:
                dict_values: the pending operations that were removed, in ``op.wires`` order.
            """

            # The order might not be deterministic if the wires (keys) are abstract.
            # However, this only impacts operators without any shared wires,
            # which does not affect the correctness of the result.

            # If the wires are concrete, the order of the keys (wires)
            # and thus the values should reflect the order in which they are iterated
            # because in Python 3.7+ dictionaries maintain insertion order.
            previous_ops_on_wires = {
                w: self.previous_ops.pop(w) for w in op.wires if w in self.previous_ops
            }

            return previous_ops_on_wires.values()

        def _handle_non_fusible_op(self, op: Operator) -> list:
            """Handle an operation that cannot be fused into a Rot gate.

            Any pending operations on ``op``'s wires are first emitted as ``Rot`` gates (built from
            their ``single_qubit_rot_angles``), then ``op`` itself is interpreted.

            Returns:
                list: the interpreted results, pending ``Rot`` gates first and ``op`` last.
            """

            previous_ops_on_wires = self._retrieve_prev_ops_same_wire(op)
            res = []
            for prev_op in previous_ops_on_wires:
                # NOTE(review): ``_primitive.impl`` presumably builds a concrete ``Rot`` without
                # binding a new primitive, so only the explicit interpret call below records it
                # — confirm against PlxprInterpreter.
                # pylint: disable=protected-access
                rot = qml.Rot._primitive.impl(
                    *qml.math.stack(prev_op.single_qubit_rot_angles()), wires=prev_op.wires
                )
                res.append(super().interpret_operation(rot))

            res.append(super().interpret_operation(op))

            return res

        def _handle_fusible_op(self, op: Operator, cumulative_angles: TensorLike) -> list:
            """Handle an operation that can be potentially fused into a Rot gate.

            If no operation is pending on ``op``'s wire, ``op`` is stored unchanged. Otherwise the
            pending operation's angles are fused with ``cumulative_angles`` and the result is
            stored as a ``Rot``. If the fused rotation is concrete, does not require gradients,
            and both ``phi + omega`` and ``theta`` are within ``self.atol`` of zero, the wire's
            pending entry is removed instead.

            Args:
                op (Operator): the fusible operation.
                cumulative_angles (TensorLike): the stacked ``single_qubit_rot_angles`` of ``op``.

            Returns:
                list: always empty; nothing is interpreted until the pending op is flushed.
            """

            # Only single-qubit gates are considered for fusion
            # (assumes every op defining ``single_qubit_rot_angles`` acts on exactly one wire).
            op_wire = op.wires[0]

            prev_op = self.previous_ops.get(op_wire)
            if prev_op is None:
                self.previous_ops[op_wire] = op
                return []

            # The earlier (pending) op comes first so the fused rotation preserves gate order.
            prev_op_angles = qml.math.stack(prev_op.single_qubit_rot_angles())
            cumulative_angles = fuse_rot_angles(prev_op_angles, cumulative_angles)

            # Never drop the rotation while tracing or differentiating; otherwise drop it
            # only if the total RZ angle (phi + omega) and the RY angle (theta) are ~0.
            if (
                qml.math.is_abstract(cumulative_angles)
                or qml.math.requires_grad(cumulative_angles)
                or not qml.math.allclose(
                    qml.math.stack([cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]),
                    0.0,
                    atol=self.atol,
                    rtol=0,
                )
            ):
                # pylint: disable=protected-access
                new_rot = qml.Rot._primitive.impl(*cumulative_angles, wires=op.wires)
                self.previous_ops[op_wire] = new_rot
            else:
                del self.previous_ops[op_wire]

            return []

        def interpret_operation(self, op: Operator):
            """Interpret a PennyLane operation instance.

            Returns:
                The base-class result for wireless or excluded operators; otherwise the list
                returned by ``_handle_non_fusible_op`` or ``_handle_fusible_op``.
            """

            # Operators like Identity() have no wires
            if len(op.wires) == 0:
                return super().interpret_operation(op)

            # We interpret directly if the gate is explicitly excluded,
            # after interpreting all previous operations on the same wires.
            # Pending ops are flushed in their original (unfused-to-Rot) form here.
            if op.name in self.exclude_gates:

                previous_ops_on_wires = self._retrieve_prev_ops_same_wire(op)

                for prev_op in previous_ops_on_wires:
                    super().interpret_operation(prev_op)

                return super().interpret_operation(op)

            try:
                cumulative_angles = qml.math.stack(op.single_qubit_rot_angles())
            except (NotImplementedError, AttributeError):
                return self._handle_non_fusible_op(op)

            return self._handle_fusible_op(op, cumulative_angles)

        def interpret_all_previous_ops(self) -> None:
            """Interpret all previous operations stored in the instance, then clear the buffer."""

            for op in self.previous_ops.values():
                super().interpret_operation(op)

            self.previous_ops.clear()

        def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
            """Evaluate a jaxpr.

            Pending operations are flushed before any primitive with a registered custom
            handler, before every measurement, and once more after the last equation.

            Args:
                jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args (tuple[TensorLike]): The arguments for the jaxpr.

            Returns:
                list[TensorLike]: the results of the execution.
            """
            self.setup()

            for arg, invar in zip(args, jaxpr.invars, strict=True):
                self._env[invar] = arg
            for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                self._env[constvar] = const

            for eqn in jaxpr.eqns:

                prim_type = getattr(eqn.primitive, "prim_type", "")

                custom_handler = self._primitive_registrations.get(eqn.primitive, None)
                if custom_handler:
                    # Flush first so buffered gates are not moved past this primitive.
                    self.interpret_all_previous_ops()
                    invals = [self.read(invar) for invar in eqn.invars]
                    outvals = custom_handler(self, *invals, **eqn.params)
                elif prim_type == "operator":
                    outvals = self.interpret_operation_eqn(eqn)
                elif prim_type == "measurement":
                    # Measurements must see every preceding gate.
                    self.interpret_all_previous_ops()
                    outvals = self.interpret_measurement_eqn(eqn)
                else:
                    # Any other primitive (e.g. classical processing) is re-bound unchanged.
                    invals = [self.read(invar) for invar in eqn.invars]
                    subfuns, params = eqn.primitive.get_bind_params(eqn.params)
                    outvals = eqn.primitive.bind(*subfuns, *invals, **params)

                # Single-result primitives return a bare value; wrap it for the zip below.
                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
                for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                    self._env[outvar] = outval

            # Emit whatever is still pending after the last equation.
            self.interpret_all_previous_ops()

            outvals = []
            for var in jaxpr.outvars:
                outval = self.read(var)
                # Operators returned from the jaxpr are interpreted directly, without fusion.
                if isinstance(outval, Operator):
                    outvals.append(super().interpret_operation(outval))
                else:
                    outvals.append(outval)
            self.cleanup()
            return outvals

    @SingleQubitFusionInterpreter.register_primitive(measure_prim)
    def _(_, *invals, **params):
        """Re-bind ``measure_prim`` unchanged; ``eval`` flushes pending ops before calling this."""
        subfuns, params = measure_prim.get_bind_params(params)
        return measure_prim.bind(*subfuns, *invals, **params)

    def single_qubit_fusion_plxpr_to_plxpr(jaxpr, consts, targs, tkwargs, *args):
        """Function for applying the ``single_qubit_fusion`` transform on plxpr.

        Args:
            jaxpr: the jaxpr to transform.
            consts: the constants of ``jaxpr``.
            targs (tuple): positional transform arguments forwarded to the interpreter.
            tkwargs (dict): keyword transform arguments forwarded to the interpreter.
            *args: example arguments used to trace the transformed function.

        Returns:
            The jaxpr produced by ``make_jaxpr`` for the interpreter-wrapped evaluation.
        """

        interpreter = SingleQubitFusionInterpreter(*targs, **tkwargs)

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return SingleQubitFusionInterpreter, single_qubit_fusion_plxpr_to_plxpr


# Both names are ``None`` when the program-capture dependencies cannot be imported.
SingleQubitFusionInterpreter, single_qubit_plxpr_to_plxpr = _get_plxpr_single_qubit_fusion()
@partial(transform, plxpr_transform=single_qubit_plxpr_to_plxpr)
def single_qubit_fusion(
    tape: QuantumScript, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    r"""Quantum function transform to fuse together groups of single-qubit
    operations into a general single-qubit unitary operation (:class:`~.Rot`).

    Fusion is performed only between gates that implement the method
    ``single_qubit_rot_angles``. Any sequence of two or more single-qubit gates
    (on the same qubit) with that method defined will be fused into one ``Rot``.

    Args:
        tape (QNode or QuantumTape or Callable): A quantum circuit.
        atol (float): An absolute tolerance for which to apply a rotation after fusion.
            After fusion of gates, if the fused angles :math:`(\phi, \theta, \omega)` are such
            that both :math:`|\phi + \omega|\leq \text{atol}` and :math:`|\theta|\leq \text{atol}`,
            no rotation gate will be applied. This check is skipped, and the rotation is always
            applied, when the angles are abstract (e.g., while tracing/jitting) or require
            gradients.
        exclude_gates (None or list[str]): A list of gates that should be excluded
            from full fusion. If set to ``None``, all single-qubit gates that can
            be fused will be fused.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], Callable]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    **Example**

    >>> dev = qml.device('default.qubit', wires=1)

    You can apply the transform directly to a :class:`QNode`:

    .. code-block:: python

        @qml.transforms.single_qubit_fusion
        @qml.qnode(device=dev)
        def qfunc(r1, r2):
            qml.Hadamard(wires=0)
            qml.Rot(*r1, wires=0)
            qml.Rot(*r2, wires=0)
            qml.RZ(r1[0], wires=0)
            qml.RZ(r2[0], wires=0)
            return qml.expval(qml.X(0))

    The single qubit gates are fused before execution.

    .. note::

        The fused angles between two sets of rotation angles are not always defined uniquely
        because Euler angles are not unique for some rotations. ``single_qubit_fusion``
        makes a particular choice in this case.

    .. note::

        A fusible gate that is followed on its wire by a gate that cannot be fused (and is
        not listed in ``exclude_gates``) is still rewritten as an equivalent ``Rot``, or
        dropped if that rotation is trivial within ``atol``. If the following gate *is*
        listed in ``exclude_gates``, the original gate is kept unchanged.

    .. note::

        The order of the gates resulting from the fusion may be different depending
        on whether program capture is enabled or not. This only impacts the order of
        operations that do not share any wires, so the correctness of the circuit is
        not affected.

    .. warning::

        This function is not differentiable everywhere. It has singularities for specific
        input rotation angles, where the derivative will be NaN.

    .. warning::

        This function is numerically unstable at its singular points. It is recommended to use
        it with 64-bit floating point precision.

    .. details::
        :title: Usage Details

        Consider the following quantum function.

        .. code-block:: python

            def qfunc(r1, r2):
                qml.Hadamard(wires=0)
                qml.Rot(*r1, wires=0)
                qml.Rot(*r2, wires=0)
                qml.RZ(r1[0], wires=0)
                qml.RZ(r2[0], wires=0)
                return qml.expval(qml.X(0))

        The circuit before optimization:

        >>> qnode = qml.QNode(qfunc, dev)
        >>> print(qml.draw(qnode)([0.1, 0.2, 0.3], [0.4, 0.5, 0.6]))
        0: ──H──Rot(0.1, 0.2, 0.3)──Rot(0.4, 0.5, 0.6)──RZ(0.1)──RZ(0.4)──┤  ⟨X⟩

        Full single-qubit gate fusion allows us to collapse this entire sequence into a
        single ``qml.Rot`` rotation gate.

        >>> optimized_qfunc = qml.transforms.single_qubit_fusion(qfunc)
        >>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
        >>> print(qml.draw(optimized_qnode)([0.1, 0.2, 0.3], [0.4, 0.5, 0.6]))
        0: ──Rot(3.57, 2.09, 2.05)──┤  ⟨X⟩

    .. details::
        :title: Derivation
        :href: derivation

        The matrix for an individual rotation is given by

        .. math::

            R(\phi_j,\theta_j,\omega_j)
            &=
            \begin{bmatrix}
            e^{-i(\phi_j+\omega_j)/2}\cos(\theta_j/2) & -e^{i(\phi_j-\omega_j)/2}\sin(\theta_j/2)\\
            e^{-i(\phi_j-\omega_j)/2}\sin(\theta_j/2) & e^{i(\phi_j+\omega_j)/2}\cos(\theta_j/2)
            \end{bmatrix}\\
            &=
            \begin{bmatrix}
            e^{-i\alpha_j}c_j & -e^{i\beta_j}s_j \\
            e^{-i\beta_j}s_j & e^{i\alpha_j}c_j
            \end{bmatrix},

        where we introduced abbreviations :math:`\alpha_j,\beta_j=\frac{\phi_j\pm\omega_j}{2}`,
        :math:`c_j=\cos(\theta_j / 2)` and :math:`s_j=\sin(\theta_j / 2)` for notational brevity.
        The upper left entry of the matrix product
        :math:`R(\phi_2,\theta_2,\omega_2)R(\phi_1,\theta_1,\omega_1)` reads

        .. math::

            x = e^{-i(\alpha_2+\alpha_1)} c_2 c_1 - e^{i(\beta_2-\beta_1)} s_2 s_1

        and should equal :math:`e^{-i\alpha_f}c_f` for the fused rotation angles.
        This means that we can obtain :math:`\theta_f` from the magnitude of the matrix product
        entry above, choosing :math:`c_f=\cos(\theta_f / 2)` to be non-negative:

        .. math::

            c_f = |x| &=
            \left|
            e^{-i(\alpha_2+\alpha_1)} c_2 c_1
            -e^{i(\beta_2-\beta_1)} s_2 s_1
            \right| \\
            &= \sqrt{c_1^2 c_2^2 + s_1^2 s_2^2 - 2 c_1 c_2 s_1 s_2 \cos(\omega_1 + \phi_2)}.

        Now we again make a choice and pick :math:`\theta_f` to be non-negative:

        .. math::

            \theta_f = 2\arccos(|x|).

        We can also extract the angle combination :math:`\alpha_f` from :math:`x` via
        :math:`\operatorname{arg}(x)`, which can be readily computed with :math:`\arctan`:

        .. math::

            \alpha_f = -\arctan\left(
            \frac{-c_1c_2\sin(\alpha_1+\alpha_2)-s_1s_2\sin(\beta_2-\beta_1)}
            {c_1c_2\cos(\alpha_1+\alpha_2)-s_1s_2\cos(\beta_2-\beta_1)}
            \right).

        We can use the standard numerical function ``arctan2``, which
        computes :math:`\arctan(x_1/x_2)` from :math:`x_1` and :math:`x_2` while handling
        special points suitably, to obtain the argument of the underlying complex number
        :math:`x_2 + x_1 i`.

        Finally, to obtain :math:`\beta_f`, we need a second element of the matrix product from
        above. We compute the lower-left entry to be

        .. math::

            y = e^{-i(\beta_2+\alpha_1)} s_2 c_1 + e^{i(\alpha_2-\beta_1)} c_2 s_1,

        which should equal :math:`e^{-i \beta_f}s_f`. From this, we can compute

        .. math::

            \beta_f = -\arctan\left(
            \frac{-c_1s_2\sin(\alpha_1+\beta_2)+s_1c_2\sin(\alpha_2-\beta_1)}
            {c_1s_2\cos(\alpha_1+\beta_2)+s_1c_2\cos(\alpha_2-\beta_1)}
            \right).

        From this, we may extract

        .. math::

            \phi_f = \alpha_f + \beta_f\qquad
            \omega_f = \alpha_f - \beta_f

        and are done.

        **Special cases:**

        There are a number of special cases for which we can skip the computation above and
        can combine rotation angles directly.

        1. If :math:`\omega_1=\phi_2=0`, we can simply merge the ``RY`` rotation angles
           :math:`\theta_j` and obtain :math:`(\phi_1, \theta_1+\theta_2, \omega_2)`.

        2. If :math:`\theta_j=0`, we can merge the two ``RZ`` rotations of the same ``Rot``
           and obtain :math:`(\phi_1+\omega_1+\phi_2, \theta_2, \omega_2)` or
           :math:`(\phi_1, \theta_1, \omega_1+\phi_2+\omega_2)`. If both ``RY`` angles vanish
           we get :math:`(\phi_1+\omega_1+\phi_2+\omega_2, 0, 0)`.

        Note that this optimization is not performed for differentiable input parameters,
        in order to maintain differentiability.

        **Mathematical properties:**

        All functions above are well-defined on the domain we are using them on,
        if we handle :math:`\arctan` via standard numerical implementations such as
        ``np.arctan2``.
        Based on the choices we made in the derivation above, the fused angles will lie in
        the intervals

        .. math::

            \phi_f, \omega_f \in [-\pi, \pi],\quad \theta_f \in [0, \pi].

        Close to the boundaries of these intervals, ``single_qubit_fusion`` exhibits
        discontinuities, depending on the combination of input angles.
        These discontinuities also lead to singular (non-differentiable) points as discussed below.

        **Differentiability:**

        The function derived above is differentiable almost everywhere.
        In particular, there are two problematic scenarios at which the derivative is not defined.
        First, the square root is not differentiable at :math:`0`, making all input angles with
        :math:`|x|=0` singular. Second, :math:`\arccos` is not differentiable at :math:`1`, making
        all input angles with :math:`|x|=1` singular.

    """
    # Gate names that must not take part in fusion. Built once as a set (as the plxpr
    # interpreter does), so both code paths interpret ``exclude_gates`` the same way and
    # the checks below need no ``exclude_gates is not None`` guards.
    excluded = set(exclude_gates) if exclude_gates is not None else set()

    # Make a working copy of the list to traverse; gates are popped from it as they are
    # emitted unchanged, absorbed into a fused ``Rot``, or dropped as trivial.
    list_copy = tape.operations.copy()
    new_operations = []

    while list_copy:
        current_gate = list_copy[0]

        # If the gate should be excluded, queue it and move on regardless
        # of fusion potential.
        if current_gate.name in excluded:
            new_operations.append(current_gate)
            list_copy.pop(0)
            continue

        # Look for single_qubit_rot_angles; if not available, queue and move on.
        # If available, grab the angles and try to fuse.
        try:
            cumulative_angles = qml.math.stack(current_gate.single_qubit_rot_angles())
        except (NotImplementedError, AttributeError):
            new_operations.append(current_gate)
            list_copy.pop(0)
            continue

        # Find the next gate that acts on at least one of the same wires. The search
        # starts after ``current_gate``, hence the ``+ 1`` index offsets below.
        next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:])

        if next_gate_idx is None:
            new_operations.append(current_gate)
            list_copy.pop(0)
            continue

        # If the very next gate on this wire is excluded, apply the original gate as-is
        # rather than its Rot version (example in test test_single_qubit_fusion_exclude_gates).
        if list_copy[next_gate_idx + 1].name in excluded:
            new_operations.append(current_gate)
            list_copy.pop(0)
            continue

        # Absorb subsequent fusible gates on the same wire for as long as possible.
        while next_gate_idx is not None:
            next_gate = list_copy[next_gate_idx + 1]

            # Fusion stops at an excluded gate.
            if next_gate.name in excluded:
                break

            # Try to extract the angles; since the Rot angles are implemented
            # solely for single-qubit gates, and we used find_next_gate to obtain
            # the gate in question, only valid single-qubit gates on the same
            # wire as the current gate will be fused.
            try:
                next_gate_angles = qml.math.stack(next_gate.single_qubit_rot_angles())
            except (NotImplementedError, AttributeError):
                break

            # Earlier gate first, so the fused rotation preserves the gate order.
            cumulative_angles = fuse_rot_angles(cumulative_angles, next_gate_angles)

            list_copy.pop(next_gate_idx + 1)
            next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:])

        # If we are tracing/jitting or differentiating, don't perform any conditional checks and
        # apply the rotation regardless of the angles.
        # If not tracing or differentiating, check whether total rotation is trivial by checking
        # if the RY angle and the sum of the RZ angles are close to 0
        if (
            qml.math.is_abstract(cumulative_angles)
            or qml.math.requires_grad(cumulative_angles)
            or not qml.math.allclose(
                qml.math.stack([cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]),
                0.0,
                atol=atol,
                rtol=0,
            )
        ):
            # The Rot is collected explicitly into ``new_operations``; keep it off any
            # active queuing context.
            with QueuingManager.stop_recording():
                new_operations.append(Rot(*cumulative_angles, wires=current_gate.wires))

        # Remove the starting gate from the list
        list_copy.pop(0)

    new_tape = tape.copy(operations=new_operations)

    def null_postprocessing(results):
        """A postprocessing function returned by a transform that only converts the batch of results
        into a result for a single ``QuantumTape``.
        """
        return results[0]

    return [new_tape], null_postprocessing