# Source code for pennylane.transforms.optimization.cancel_inverses
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform for cancelling adjacent inverse gates in quantum circuits."""
# pylint: disable=too-many-branches
from functools import lru_cache, partial

from pennylane.ops.op_math import Adjoint
from pennylane.ops.qubit.attributes import (
    self_inverses,
    symmetric_over_all_wires,
    symmetric_over_control_wires,
)
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires

from .optimization_utils import find_next_gate


def _ops_equal(op1, op2):
    """Checks if two operators are equal up to class, data, hyperparameters, and wires"""
    return (
        op1.__class__ is op2.__class__
        and (op1.data == op2.data)
        and (op1.hyperparameters == op2.hyperparameters)
        and (op1.wires == op2.wires)
    )


def _are_inverses(op1, op2):
    """Checks if two operators are inverses of each other

    Args:
        op1 (~.Operator)
        op2 (~.Operator)

    Returns:
        Bool
    """
    # op1 is self-inverse and the next gate is also op1
    if op1 in self_inverses and op1.name == op2.name:
        return True

    # op1 is an `Adjoint` class and its base is equal to op2
    if isinstance(op1, Adjoint) and _ops_equal(op1.base, op2):
        return True

    # op2 is an `Adjoint` class and its base is equal to op1
    if isinstance(op2, Adjoint) and _ops_equal(op2.base, op1):
        return True

    return False


@lru_cache
def _get_plxpr_cancel_inverses():  # pylint: disable=too-many-statements
    """Build the program-capture (plxpr) implementation of ``cancel_inverses``.

    Returns:
        tuple: ``(CancelInversesInterpreter, cancel_inverses_plxpr_to_plxpr)``, or
        ``(None, None)`` when the optional capture dependencies (e.g. ``jax``) cannot be
        imported.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr

        from pennylane.capture import PlxprInterpreter
        from pennylane.capture.primitives import measure_prim
        from pennylane.operation import Operator
    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name

    class CancelInversesInterpreter(PlxprInterpreter):
        """Plxpr Interpreter for applying the ``cancel_inverses`` transform to callables or jaxpr
        when program capture is enabled.

        Operators are not applied as soon as they are read. Instead, each one is held in
        ``previous_ops`` (a mapping from wire to the latest pending operator on that wire)
        until it either cancels with the next operator on its wires or has to be applied.
        At any time, the pending operators act on pairwise-disjoint sets of wires.

        .. note::

            In the process of transforming plxpr, this interpreter may reorder operations that do
            not share any wires. This will not impact the correctness of the circuit.
        """

        def __init__(self):
            super().__init__()
            self.previous_ops = {}

        def setup(self) -> None:
            """Initialize the instance before interpreting equations."""
            self.previous_ops = {}

        @staticmethod
        def _cancels_with(op, prev_op) -> bool:
            """Return whether ``op`` cancels against the pending operator ``prev_op``.

            The two operators must be inverses of each other (see ``_are_inverses``) and act on
            the same set of wires. If the wire order differs, cancelling is only allowed when
            ``op`` is symmetric over all of its wires, or symmetric over its control wires and
            both operators share the same control wires.
            """
            if not _are_inverses(op, prev_op):
                return False

            # Same wires in the same order: always safe to cancel.
            if op.wires == prev_op.wires:
                return True

            # Different wire order: both operators must still act on exactly the same wires.
            if len(op.wires) != len(prev_op.wires):
                return False
            if len(Wires.shared_wires([op.wires, prev_op.wires])) != len(op.wires):
                return False

            # Symmetric over all wires (e.g. CZ): any permutation cancels.
            if op in symmetric_over_all_wires:
                return True

            # Symmetric over control wires: the control wires must coincide as a set.
            # NOTE: this assumes the target is a single wire, placed last.
            if op in symmetric_over_control_wires:
                shared_controls = Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])
                return len(shared_controls) == len(op.wires) - 1

            return False

        def _pop_ops_on_wires(self, wires) -> list:
            """Remove and return every pending operator acting on any of ``wires``.

            All wires of a returned operator are removed from ``previous_ops``, not only the
            ones in ``wires``. Operators are deduplicated by identity and returned in the order
            they are first encountered while walking ``wires``.
            """
            found = {}
            for w in wires:
                stored = self.previous_ops.get(w)
                if stored is not None:
                    found.setdefault(id(stored), stored)

            pending = list(found.values())
            for stored in pending:
                for w in stored.wires:
                    self.previous_ops.pop(w, None)
            return pending

        def interpret_operation(self, op: Operator):
            """Interpret a PennyLane operation instance.

            This method cancels operations that are the adjoint of the previous operation on the
            same wires. Otherwise, it applies every pending operation that shares a wire with
            ``op`` and stores ``op`` as pending on all of its wires.

            Args:
                op (Operator): a pennylane operator instance

            Returns:
                Any

            This method is only called when the operator's output is a dropped variable,
            so the output will not affect later equations in the circuit.

            See also: :meth:`~.interpret_operation_eqn`.

            """
            # Operators without wires can neither cancel nor block anything; apply them directly.
            if len(op.wires) == 0:
                return super().interpret_operation(op)

            prev_op = self.previous_ops.get(op.wires[0], None)
            if prev_op is not None and self._cancels_with(op, prev_op):
                # prev_op occupies exactly op's wires, so this removes it entirely.
                for w in op.wires:
                    self.previous_ops.pop(w)
                return []

            # No cancellation. Every pending operator that shares *any* wire with ``op`` must
            # be applied before ``op`` is stored. This includes operators on wires other than
            # ``op.wires[0]``: overwriting them would silently drop them, or emit them later in
            # the wrong order.
            flushed = self._pop_ops_on_wires(op.wires)
            for w in op.wires:
                self.previous_ops[w] = op

            # Explicit loop rather than a comprehension: zero-argument ``super()`` does not
            # work inside a comprehension scope on Python < 3.12.
            res = []
            for o in flushed:
                res.append(super().interpret_operation(o))
            return res

        def interpret_all_previous_ops(self) -> None:
            """Interpret all operators in ``previous_ops``.

            This is done when any previously uninterpreted operators, saved for cancellation,
            no longer need to be stored."""
            # Deduplicate by identity: a multi-wire operator appears once per wire.
            remaining = {id(o): o for o in self.previous_ops.values()}
            for o in remaining.values():
                super().interpret_operation(o)
            self.previous_ops.clear()

        def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
            """Evaluate a jaxpr.

            Args:
                jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args (tuple[TensorLike]): The arguments for the jaxpr.

            Returns:
                list[TensorLike]: the results of the execution.

            """
            # pylint: disable=too-many-branches,attribute-defined-outside-init
            self._env = {}
            self.setup()

            for arg, invar in zip(args, jaxpr.invars, strict=True):
                self._env[invar] = arg
            for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                self._env[constvar] = const

            for eqn in jaxpr.eqns:

                custom_handler = self._primitive_registrations.get(eqn.primitive, None)
                if custom_handler:
                    # Interpret any stored ops so that they are applied before the custom
                    # primitive is handled
                    self.interpret_all_previous_ops()
                    invals = [self.read(invar) for invar in eqn.invars]
                    outvals = custom_handler(self, *invals, **eqn.params)
                elif getattr(eqn.primitive, "prim_type", "") == "operator":
                    outvals = self.interpret_operation_eqn(eqn)
                elif getattr(eqn.primitive, "prim_type", "") == "measurement":
                    # Measurements must see every pending operator applied first.
                    self.interpret_all_previous_ops()
                    outvals = self.interpret_measurement_eqn(eqn)
                else:
                    invals = [self.read(invar) for invar in eqn.invars]
                    subfuns, params = eqn.primitive.get_bind_params(eqn.params)
                    outvals = eqn.primitive.bind(*subfuns, *invals, **params)

                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
                for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                    self._env[outvar] = outval

            # The following is needed because any operations inside self.previous_ops have not
            # yet been applied. At this point, we **know** that any operations that should be
            # cancelled have been cancelled, and operations left inside self.previous_ops should
            # be applied
            self.interpret_all_previous_ops()

            # Read the final result of the Jaxpr from the environment
            outvals = []
            for var in jaxpr.outvars:
                outval = self.read(var)
                if isinstance(outval, Operator):
                    outvals.append(super().interpret_operation(outval))
                else:
                    outvals.append(outval)
            self.cleanup()
            self._env = {}
            return outvals

    @CancelInversesInterpreter.register_primitive(measure_prim)
    def _(_, *invals, **params):
        # Mid-circuit measurements are re-bound unchanged; the generic custom-handler path in
        # ``eval`` has already flushed pending operators before reaching this handler.
        subfuns, params = measure_prim.get_bind_params(params)
        return measure_prim.bind(*subfuns, *invals, **params)

    def cancel_inverses_plxpr_to_plxpr(jaxpr, consts, targs, tkwargs, *args):
        """Function for applying the ``cancel_inverses`` transform on plxpr."""

        interpreter = CancelInversesInterpreter(*targs, **tkwargs)

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return CancelInversesInterpreter, cancel_inverses_plxpr_to_plxpr


CancelInversesInterpreter, cancel_inverses_plxpr_to_plxpr = _get_plxpr_cancel_inverses()
@partial(transform, plxpr_transform=cancel_inverses_plxpr_to_plxpr)
def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Quantum function transform to remove any operations that are applied next to their
    (self-)inverses or adjoint.

    Args:
        tape (QNode or QuantumTape or Callable): A quantum circuit.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]:
        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    **Example**

    You can apply the cancel inverses transform directly on :class:`~.QNode`.

    >>> dev = qml.device('default.qubit', wires=3)

    .. code-block:: python

        @cancel_inverses
        @qml.qnode(device=dev)
        def circuit(x, y, z):
            qml.Hadamard(wires=0)
            qml.Hadamard(wires=1)
            qml.Hadamard(wires=0)
            qml.RX(x, wires=2)
            qml.RY(y, wires=1)
            qml.X(1)
            qml.RZ(z, wires=0)
            qml.RX(y, wires=2)
            qml.CNOT(wires=[0, 2])
            qml.X(1)
            return qml.expval(qml.Z(0))

    >>> circuit(0.1, 0.2, 0.3)
    0.999999999999999

    .. details::
        :title: Usage Details

        You can also apply it on quantum functions:

        .. code-block:: python

            def qfunc(x, y, z):
                qml.Hadamard(wires=0)
                qml.Hadamard(wires=1)
                qml.Hadamard(wires=0)
                qml.RX(x, wires=2)
                qml.RY(y, wires=1)
                qml.X(1)
                qml.RZ(z, wires=0)
                qml.RX(y, wires=2)
                qml.CNOT(wires=[0, 2])
                qml.X(1)
                return qml.expval(qml.Z(0))

        The circuit before optimization:

        >>> qnode = qml.QNode(qfunc, dev)
        >>> print(qml.draw(qnode)(1, 2, 3))
        0: ──H─────────H─────────RZ(3.00)─╭●────┤  <Z>
        1: ──H─────────RY(2.00)──X────────│───X─┤
        2: ──RX(1.00)──RX(2.00)───────────╰X────┤

        We can see that there are two adjacent Hadamards on the first qubit that
        should cancel each other out. Similarly, there are two Pauli-X gates on the
        second qubit that should cancel. We can obtain a simplified circuit by running
        the ``cancel_inverses`` transform:

        >>> optimized_qfunc = cancel_inverses(qfunc)
        >>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
        >>> print(qml.draw(optimized_qnode)(1, 2, 3))
        0: ──RZ(3.00)───────────╭●─┤  <Z>
        1: ──H─────────RY(2.00)─│──┤
        2: ──RX(1.00)──RX(2.00)─╰X─┤

    """
    # Working copy of the operations; gates are consumed from the front.
    list_copy = tape.operations.copy()
    operations = []

    while list_copy:
        current_gate = list_copy.pop(0)

        # Find the next gate that acts on at least one of the same wires
        next_gate_idx = find_next_gate(current_gate.wires, list_copy)

        # If no such gate is found, queue the operation and move on
        if next_gate_idx is None:
            operations.append(current_gate)
            continue

        next_gate = list_copy[next_gate_idx]

        # Only a pair of mutually inverse gates can possibly cancel
        if _are_inverses(current_gate, next_gate):
            # Same wires in the same order: we can safely remove both
            if current_gate.wires == next_gate.wires:
                list_copy.pop(next_gate_idx)
                continue

            # The wires differ. Unless both gates act on exactly the same set of wires
            # (same count and full overlap), they cannot cancel.
            full_overlap = len(
                Wires.shared_wires([current_gate.wires, next_gate.wires])
            ) == len(current_gate.wires)
            if len(current_gate.wires) != len(next_gate.wires) or not full_overlap:
                operations.append(current_gate)
                continue

            # Same wires in a different order: gates that are "symmetric" over all
            # wires (e.g., CZ) can be cancelled.
            if current_gate in symmetric_over_all_wires:
                list_copy.pop(next_gate_idx)
                continue

            # For other gates, as long as the control wires are the same, we can still
            # cancel (e.g., the Toffoli gate).
            if current_gate in symmetric_over_control_wires:
                # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates
                if (
                    len(Wires.shared_wires([current_gate.wires[:-1], next_gate.wires[:-1]]))
                    == len(current_gate.wires) - 1
                ):
                    list_copy.pop(next_gate_idx)
                    continue

        # Keep the gate in any case where
        # - the two gates are not inverses of each other,
        # - the wires are permuted but the gate has no wire symmetry, or
        # - the control-wire symmetry does not apply because the control wires differ.
        operations.append(current_gate)

    new_tape = tape.copy(operations=operations)

    def null_postprocessing(results):
        """A postprocessing function returned by a transform that only converts the batch of
        results into a result for a single ``QuantumTape``.
        """
        return results[0]

    return [new_tape], null_postprocessing