# Source code for pennylane.transforms.optimization.commute_controlled
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transforms for pushing commuting gates through targets/control qubits."""

from collections import deque
from functools import lru_cache, partial
from itertools import islice
from typing import Optional, Sequence

from pennylane.operation import Operator
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires

from .optimization_utils import find_next_gate


@lru_cache
def _get_plxpr_commute_controlled():  # pylint: disable=missing-function-docstring,too-many-statements
    """Build the program-capture (plxpr) machinery for ``commute_controlled``.

    The construction is deferred to a cached function so that JAX and the capture
    module are only imported when available.

    Returns:
        tuple: ``(CommuteControlledInterpreter, commute_controlled_plxpr_to_plxpr)``,
        or ``(None, None)`` if the required imports fail.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr

        from pennylane.capture import PlxprInterpreter
        from pennylane.capture.primitives import measure_prim
    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name

    class CommuteControlledInterpreter(PlxprInterpreter):
        """Plxpr Interpreter for applying the ``commute_controlled`` transform to callables or jaxpr
        when program capture is enabled.

        .. note::

            If the direction is set to ``"right"``, this class interprets the operations by scanning
            them backward after the jaxpr has been traversed (pushing the gates to the right of
            controlled gates). This is because we can only traverse the jaxpr in a forward direction
            with the current implementation when program capture is enabled. This is less efficient
            than setting the direction to ``"left"``, which allows the interpreter to push the gates
            to the left of controlled gates as it interprets the jaxpr. Despite this, the default
            direction is ``"right"`` to maintain compatibility with the current default value of
            the transform.
        """

        def __init__(self, direction="right"):
            """Initialize the interpreter.

            Args:
                direction (str): ``"left"`` or ``"right"``.

            Raises:
                ValueError: if ``direction`` is not ``"left"`` or ``"right"``.
            """
            if direction not in ("left", "right"):
                raise ValueError(
                    f"Direction for commute_controlled must be 'left' or 'right'. Got {direction}"
                )
            self.direction = direction
            # Buffer of operators not yet handed to the parent interpreter.
            self.op_deque = deque()
            self._env = {}
            # In "left" mode this mirrors ``len(self.op_deque)``; in "right" mode it is the
            # scan cursor used by ``_interpret_all_operations_right``.
            self.current_index = 0

        def cleanup(self) -> None:
            """Clean up the instance after interpreting equations."""
            self.op_deque.clear()
            # Keep the bookkeeping consistent with the now-empty buffer so that a reused
            # interpreter does not start from a stale index.
            self.current_index = 0

        def _interpret_operation_left(self, op: Operator) -> list:
            """Interpret a PennyLane operation and push it through controlled operations as far
            left as possible.

            The operator is only buffered here; it is emitted later by
            ``interpret_all_previous_ops``.

            Returns:
                list: always an empty list.
            """
            # This function follows the same logic used in the `_commute_controlled_left` function.
            if not _can_be_pushed_through(op):
                self.op_deque.append(op)
                self.current_index = len(self.op_deque)
                return []

            # Start from the real end of the buffer instead of a separately maintained
            # counter. The buffer is emptied before every measurement / custom primitive,
            # so a counter that is not reset would index past the end of the deque.
            new_index = len(self.op_deque)
            prev_gate_idx = _find_previous_gate_on_wires(op.wires, self.op_deque)

            while prev_gate_idx is not None:
                prev_gate = self.op_deque[new_index - (prev_gate_idx + 1)]

                if not _can_push_through(prev_gate) or not _can_commute(op, prev_gate):
                    break

                new_index -= prev_gate_idx + 1
                prev_gate_idx = _find_previous_gate_on_wires(
                    op.wires, tuple(islice(self.op_deque, new_index))
                )

            self.op_deque.insert(new_index, op)
            self.current_index = len(self.op_deque)

            return []

        def _interpret_all_operations_right(self) -> None:
            """Push all single-qubit gates as far right as possible through controlled operations."""
            # This function follows the same logic used in the `_commute_controlled_right` function.
            self.current_index = len(self.op_deque) - 1

            while self.current_index >= 0:
                current_gate = self.op_deque[self.current_index]

                if not _can_be_pushed_through(current_gate):
                    self.current_index -= 1
                    continue

                next_gate_idx = find_next_gate(
                    current_gate.wires,
                    tuple(islice(self.op_deque, self.current_index + 1, len(self.op_deque))),
                )
                new_index = self.current_index

                while next_gate_idx is not None:
                    next_gate = self.op_deque[new_index + next_gate_idx + 1]

                    if not _can_push_through(next_gate) or not _can_commute(
                        current_gate, next_gate
                    ):
                        break

                    new_index += next_gate_idx + 1
                    next_gate_idx = find_next_gate(
                        current_gate.wires,
                        tuple(islice(self.op_deque, new_index + 1, len(self.op_deque))),
                    )

                # Insert after the target position first, then remove the original entry
                # (which lies before the insertion point, so its index is unaffected).
                self.op_deque.insert(new_index + 1, current_gate)
                del self.op_deque[self.current_index]
                self.current_index -= 1

        def interpret_operation(self, op: Operator):
            """Interpret a PennyLane operation instance."""
            if self.direction == "left":
                return self._interpret_operation_left(op)

            # If the direction is right, we append the operator
            # to the list while we scan through the operators forwards.
            self.op_deque.append(op)
            return []

        def interpret_all_previous_ops(self) -> None:
            """Interpret all previous operations stored in the instance, then empty the buffer."""
            if self.direction == "right":
                # The whole segment is only known now, so push the gates right by
                # traversing the buffered operators backwards.
                self._interpret_all_operations_right()

            for op in self.op_deque:
                super().interpret_operation(op)
            self.op_deque.clear()
            # The buffer is empty again; reset the index so the next segment starts at 0.
            self.current_index = 0

        def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
            """Evaluate a jaxpr.

            Args:
                jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args (tuple[TensorLike]): The arguments for the jaxpr.

            Returns:
                list[TensorLike]: the results of the execution.
            """
            self.setup()
            self._env = {}
            for arg, invar in zip(args, jaxpr.invars, strict=True):
                self._env[invar] = arg
            for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                self._env[constvar] = const

            for eqn in jaxpr.eqns:
                prim_type = getattr(eqn.primitive, "prim_type", "")
                custom_handler = self._primitive_registrations.get(eqn.primitive, None)

                if custom_handler:
                    # Custom primitives act as barriers: flush buffered gates first.
                    self.interpret_all_previous_ops()
                    invals = [self.read(invar) for invar in eqn.invars]
                    outvals = custom_handler(self, *invals, **eqn.params)
                elif prim_type == "operator":
                    outvals = self.interpret_operation_eqn(eqn)
                elif prim_type == "measurement":
                    # Gates must not be reordered past a measurement.
                    self.interpret_all_previous_ops()
                    outvals = self.interpret_measurement_eqn(eqn)
                else:
                    invals = [self.read(invar) for invar in eqn.invars]
                    subfuns, params = eqn.primitive.get_bind_params(eqn.params)
                    outvals = eqn.primitive.bind(*subfuns, *invals, **params)

                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
                for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                    self._env[outvar] = outval

            self.interpret_all_previous_ops()

            outvals = []
            for var in jaxpr.outvars:
                outval = self.read(var)
                if isinstance(outval, Operator):
                    outvals.append(super().interpret_operation(outval))
                else:
                    outvals.append(outval)

            self._env = {}
            self.cleanup()
            return outvals

    @CommuteControlledInterpreter.register_primitive(measure_prim)
    def _(_, *invals, **params):
        # Re-bind mid-circuit measurements unchanged.
        _, params = measure_prim.get_bind_params(params)
        return measure_prim.bind(*invals, **params)

    def commute_controlled_plxpr_to_plxpr(
        jaxpr, consts, targs, tkwargs, *args
    ):  # pylint: disable=unused-argument
        """Apply ``commute_controlled`` to a jaxpr and return the transformed jaxpr."""
        interpreter = CommuteControlledInterpreter(direction=tkwargs.get("direction", "right"))

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return CommuteControlledInterpreter, commute_controlled_plxpr_to_plxpr


CommuteControlledInterpreter, commute_controlled_plxpr_to_plxpr = _get_plxpr_commute_controlled()


def _find_previous_gate_on_wires(wires: Wires, prevs_ops: Sequence) -> Optional[int]:
    """Find the closest preceding gate in ``prevs_ops`` that shares wires with ``wires``.

    The search runs over ``reversed(prevs_ops)``. The returned index therefore counts
    backwards from the end (0 is the last element), or is ``None`` if no gate matches.
    """
    return find_next_gate(wires, reversed(prevs_ops))


def _shares_control_wires(op: Operator, ctrl_gate: Operator) -> bool:
    """Check if the operation shares wires with the control wires of the provided controlled gate."""
    return len(Wires.shared_wires([Wires(op.wires), ctrl_gate.control_wires])) > 0


def _can_commute(op1: Operator, op2: Operator) -> bool:
    """Determine whether the single-qubit gate ``op1`` commutes with the controlled gate ``op2``.

    The decision uses ``op1.basis``, ``op2.basis`` and the control wires of ``op2``.
    Callers only invoke this when the two gates are known to share a wire.
    """
    # Case 1: overlap is on the control wires. Only Z-type gates go through
    if _shares_control_wires(op1, op2):
        return op1.basis == "Z"

    # Case 2: since we know the gates overlap somewhere, and it's a
    # single-qubit gate, if it wasn't on a control it's the target.
    return op1.basis == op2.basis


def _can_push_through(op: Operator) -> bool:
    """Check if the provided gate can be pushed through."""
    # Only go ahead if information is available.
    # If the gate does not have control_wires defined, it is not
    # controlled so we won't push through.
    return hasattr(op, "basis") and len(op.control_wires) > 0


def _can_be_pushed_through(op: Operator) -> bool:
    """Check if the provided gate is a single-qubit gate that can be pushed through."""
    # We are looking only at the gates that can be pushed through
    # controls/targets; these are single-qubit gates with the basis
    # property specified.
    return hasattr(op, "basis") and len(op.wires) == 1


def _commute_controlled_right(op_list):
    """Push commuting single qubit gates to the right of controlled gates.

    Args:
        op_list (list[Operation]): The initial list of operations. It is reordered in place.

    Returns:
        list[Operation]: The modified list of operations with all single-qubit gates as far right
        as possible.
    """
    # We will go through the list backwards; whenever we find a single-qubit
    # gate, we will extract it and push it through controlled gates as far as
    # possible to the right.
    current_location = len(op_list) - 1

    while current_location >= 0:
        current_gate = op_list[current_location]

        if not _can_be_pushed_through(current_gate):
            current_location -= 1
            continue

        # Find the next gate that contains an overlapping wire
        next_gate_idx = find_next_gate(current_gate.wires, op_list[current_location + 1 :])

        new_location = current_location

        # Loop as long as a valid next gate exists
        while next_gate_idx is not None:
            next_gate = op_list[new_location + next_gate_idx + 1]

            if not _can_push_through(next_gate) or not _can_commute(current_gate, next_gate):
                break

            new_location += next_gate_idx + 1
            next_gate_idx = find_next_gate(current_gate.wires, op_list[new_location + 1 :])

        # After we have gone as far as possible, move the gate to new location
        op_list.insert(new_location + 1, current_gate)
        op_list.pop(current_location)
        current_location -= 1

    return op_list


def _commute_controlled_left(op_list):
    """Push commuting single qubit gates to the left of controlled gates.

    Args:
        op_list (list[Operation]): The initial list of operations. It is reordered in place.

    Returns:
        list[Operation]: The modified list of operations with all single-qubit gates as far left
        as possible.
    """
    # We will go through the list forwards; whenever we find a single-qubit
    # gate, we will extract it and push it through controlled gates as far as
    # possible back to the left.
    current_location = 0

    while current_location < len(op_list):
        current_gate = op_list[current_location]

        if not _can_be_pushed_through(current_gate):
            current_location += 1
            continue

        # Pass a backwards copy of the list
        prev_gate_idx = find_next_gate(current_gate.wires, op_list[:current_location][::-1])

        new_location = current_location

        while prev_gate_idx is not None:
            prev_gate = op_list[new_location - prev_gate_idx - 1]

            if not _can_push_through(prev_gate) or not _can_commute(current_gate, prev_gate):
                break

            new_location -= prev_gate_idx + 1
            prev_gate_idx = find_next_gate(current_gate.wires, op_list[:new_location][::-1])

        op_list.pop(current_location)
        op_list.insert(new_location, current_gate)
        current_location += 1

    return op_list
@partial(transform, plxpr_transform=commute_controlled_plxpr_to_plxpr)
def commute_controlled(
    tape: QuantumScript, direction="right"
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Quantum transform to move commuting gates past control and target qubits of controlled operations.

    Args:
        tape (QNode or QuantumTape or Callable): A quantum circuit.
        direction (str): The direction in which to move single-qubit gates.
            Options are "right" (default), or "left". Single-qubit gates will
            be pushed through controlled operations as far as possible in the
            specified direction.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    Raises:
        ValueError: if ``direction`` is not ``"left"`` or ``"right"``.

    **Example**

    >>> dev = qml.device('default.qubit', wires=3)

    You can apply the transform directly on :class:`QNode`:

    .. code-block:: python

        @partial(commute_controlled, direction="right")
        @qml.qnode(device=dev)
        def circuit(theta):
            qml.CZ(wires=[0, 2])
            qml.X(2)
            qml.S(wires=0)

            qml.CNOT(wires=[0, 1])

            qml.Y(1)
            qml.CRY(theta, wires=[0, 1])
            qml.PhaseShift(theta/2, wires=0)

            qml.Toffoli(wires=[0, 1, 2])
            qml.T(wires=0)
            qml.RZ(theta/2, wires=1)

            return qml.expval(qml.Z(0))

    >>> circuit(0.5)
    0.9999999999999999

    .. details::
        :title: Usage Details

        You can also apply it on quantum function.

        .. code-block:: python

            def qfunc(theta):
                qml.CZ(wires=[0, 2])
                qml.X(2)
                qml.S(wires=0)

                qml.CNOT(wires=[0, 1])

                qml.Y(1)
                qml.CRY(theta, wires=[0, 1])
                qml.PhaseShift(theta/2, wires=0)

                qml.Toffoli(wires=[0, 1, 2])
                qml.T(wires=0)
                qml.RZ(theta/2, wires=1)

                return qml.expval(qml.Z(0))

        >>> qnode = qml.QNode(qfunc, dev)
        >>> print(qml.draw(qnode)(0.5))
        0: ─╭●──S─╭●────╭●─────────Rϕ(0.25)─╭●──T────────┤  <Z>
        1: ─│─────╰X──Y─╰RY(0.50)───────────├●──RZ(0.25)─┤
        2: ─╰Z──X───────────────────────────╰X───────────┤

        Diagonal gates on either side of control qubits do not affect the outcome
        of controlled gates; thus we can push all the single-qubit gates on the
        first qubit together on the right (and fuse them if desired). Similarly, X
        gates commute with the target of ``CNOT`` and ``Toffoli`` (and ``PauliY``
        with ``CRY``). We can use the transform to push single-qubit gates as
        far as possible through the controlled operations:

        >>> optimized_qfunc = commute_controlled(qfunc, direction="right")
        >>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
        >>> print(qml.draw(optimized_qnode)(0.5))
        0: ─╭●─╭●─╭●───────────╭●──S─────────Rϕ(0.25)──T─┤  <Z>
        1: ─│──╰X─╰RY(0.50)──Y─├●──RZ(0.25)──────────────┤
        2: ─╰Z─────────────────╰X──X─────────────────────┤

    """
    if direction not in ("left", "right"):
        raise ValueError("Direction for commute_controlled must be 'left' or 'right'")

    # The helpers below reorder their argument in place. ``tape.operations`` may hand
    # back the tape's own list, so work on a copy to leave the input tape untouched.
    op_list = list(tape.operations)

    if direction == "right":
        op_list = _commute_controlled_right(op_list)
    else:
        op_list = _commute_controlled_left(op_list)

    new_tape = tape.copy(operations=op_list)

    def null_postprocessing(results):
        """A postprocessing function returned by a transform that only converts the batch of results
        into a result for a single ``QuantumTape``.
        """
        return results[0]

    return [new_tape], null_postprocessing