# Source code for pennylane.transforms.optimization.merge_amplitude_embedding
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transform for merging AmplitudeEmbedding gates in a quantum circuit."""
from collections.abc import Sequence
from copy import copy
from functools import lru_cache, partial

import pennylane as qml
from pennylane import AmplitudeEmbedding
from pennylane.exceptions import DeviceError, TransformError
from pennylane.math import flatten, is_abstract, reshape
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms.core import transform
from pennylane.typing import PostprocessingFn


# pylint: disable=too-many-statements
@lru_cache
def _get_plxpr_merge_amplitude_embedding():
    """Build the program-capture (plxpr) implementation of ``merge_amplitude_embedding``.

    The construction is deferred into this cached function because it needs ``jax`` and
    the PennyLane capture machinery, which may not be importable.

    Returns:
        tuple: ``(MergeAmplitudeEmbeddingInterpreter, merge_amplitude_embedding_plxpr_to_plxpr)``,
        or ``(None, None)`` if any of the required imports fail.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr
        from jax.extend.core import Jaxpr

        from pennylane.capture import PlxprInterpreter
        from pennylane.capture.base_interpreter import jaxpr_to_jaxpr
        from pennylane.capture.primitives import cond_prim, measure_prim
        from pennylane.operation import Operator
    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name
    class MergeAmplitudeEmbeddingInterpreter(PlxprInterpreter):
        """Plxpr Interpreter for merging AmplitudeEmbedding gates when program capture is enabled."""

        def __init__(self):
            self._env = {}
            # Not read by this interpreter; kept so existing attribute access keeps working.
            self.dynamic_wires_encountered = False
            # Non-AmplitudeEmbedding ops buffered so a merged embedding can be placed before them.
            self.previous_ops = []
            # * visited_wires (set): tracks all wires we have encountered so far.
            # * dynamic_wires_found (bool): True if we have encountered any non-AmplitudeEmbedding
            #   ops or mid-circuit measurements that have dynamic wires so far.
            # * ops_found (bool): True if we have encountered any non-AmplitudeEmbedding ops
            #   (or mid-circuit measurements) so far.
            self.state = {"visited_wires": set(), "dynamic_wires_found": False, "ops_found": False}
            # Pending AmplitudeEmbedding data, one entry per gate, waiting to be merged.
            self.input_wires, self.input_vectors, self.input_batch_size = [], [], []

        def setup(self) -> None:
            """Setup the interpreter for a new evaluation."""
            self.previous_ops = []
            self.input_wires, self.input_vectors, self.input_batch_size = [], [], []

        def cleanup(self) -> None:
            """Clean up the interpreter after evaluation."""
            # Rebind rather than clear in place: shallow copies made with ``copy(self)`` for
            # branch/loop bodies share this dict with their parent, and the parent's
            # accumulated state must survive the copy's cleanup.
            self.state = {"visited_wires": set(), "dynamic_wires_found": False, "ops_found": False}

        def interpret_operation(self, op: Operator) -> None:
            """Interpret a PennyLane operation instance.

            If the operator is not an ``AmplitudeEmbedding`` operator, it is added to the new
            operations list; otherwise, the wires and parameters are stored for future usage.

            Args:
                op (Operator): a pennylane operator instance

            Raises:
                TransformError: if the AmplitudeEmbedding operator is applied after an operator
                    with dynamic wires, has dynamic wires itself after other operators, or acts
                    on wires already used by other operations

            Returns:
                None: returns None

            This method is only called when the operator's output is a dropped variable,
            so the output will not affect later equations in the circuit.
            """
            if not isinstance(op, AmplitudeEmbedding):
                if any(is_abstract(w) for w in op.wires):
                    # Overlap with pending embeddings cannot be decided, so emit everything
                    # buffered so far before recording this op.
                    self._flush_previous_ops()
                    self.state["dynamic_wires_found"] = True
                self.state["ops_found"] = True
                self.previous_ops.append(op)
                self.state["visited_wires"] = self.state["visited_wires"].union(set(op.wires))
                return

            if self.state["dynamic_wires_found"]:
                raise TransformError(
                    "Cannot apply qml.AmplitudeEmbedding after operators with dynamic wires as it "
                    "is indeterminable if the wires overlap."
                )

            if self.state["ops_found"] and any(is_abstract(w) for w in op.wires):
                raise TransformError(
                    "Cannot apply qml.AmplitudeEmbedding with dynamic wires after other operators "
                    "as it is indeterminable if the wires overlap."
                )

            if len(self.state["visited_wires"].intersection(set(op.wires))) > 0:
                raise TransformError(
                    "qml.AmplitudeEmbedding cannot be applied on wires already used by other operations."
                )

            self.input_wires.append(op.wires)
            self.input_vectors.append(op.parameters[0])
            self.input_batch_size.append(op.batch_size)
            self.state["visited_wires"] = self.state["visited_wires"].union(set(op.wires))

        def _merge_and_insert_at_the_start(self) -> None:
            """Merge the AmplitudeEmbedding gates and insert it at the beginning of the previously seen operations."""
            final_wires = self.input_wires[0]
            final_vector = self.input_vectors[0]
            final_batch_size = self.input_batch_size[0]

            for w, v, b in zip(
                self.input_wires[1:],
                self.input_vectors[1:],
                self.input_batch_size[1:],
                strict=True,
            ):
                # Outer product over the trailing axis; the leading ``...`` carries any batch dim.
                final_vector = final_vector[..., :, None] * v[..., None, :]
                final_batch_size = final_batch_size or b
                final_wires = final_wires + w

                # Flatten back to a (possibly batched) vector so the next outer product works.
                if final_batch_size:
                    final_vector = reshape(final_vector, (final_batch_size, -1))
                else:
                    final_vector = flatten(final_vector)

            with qml.capture.pause():
                self.previous_ops.insert(
                    0, qml.AmplitudeEmbedding(final_vector, wires=final_wires)
                )

            # Clear history of amplitude embedding gates since we've merged
            self.input_wires, self.input_vectors, self.input_batch_size = [], [], []

        def interpret_all_previous_ops(self) -> None:
            """Interpret all previous operations and clear the setup variables."""
            for op in self.previous_ops:
                super().interpret_operation(op)
            self.previous_ops.clear()

        def _flush_previous_ops(self) -> None:
            """Merge any pending embeddings in front of the buffered ops, then emit them all."""
            if self.input_wires:
                self._merge_and_insert_at_the_start()
            self.interpret_all_previous_ops()

        # pylint: disable=too-many-branches
        def eval(self, jaxpr: Jaxpr, consts: Sequence, *args) -> list:
            """Evaluate a jaxpr.

            Args:
                jaxpr (jax.extend.core.Jaxpr): the jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args (tuple[TensorLike]): The arguments for the jaxpr.

            Returns:
                list[TensorLike]: the results of the execution.
            """
            self._env = {}
            self.setup()

            for arg, invar in zip(args, jaxpr.invars, strict=True):
                self._env[invar] = arg
            for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                self._env[constvar] = const

            for eqn in jaxpr.eqns:
                custom_handler = self._primitive_registrations.get(eqn.primitive, None)
                prim_type = getattr(eqn.primitive, "prim_type", "")

                # Currently cannot merge through higher order primitives.
                # Workaround is to merge and insert the merged gate before entering
                # a higher order primitive.
                if prim_type == "higher_order":
                    self._flush_previous_ops()

                if custom_handler:
                    invals = [self.read(invar) for invar in eqn.invars]
                    outvals = custom_handler(self, *invals, **eqn.params)
                elif prim_type == "operator":
                    outvals = self.interpret_operation_eqn(eqn)
                elif prim_type == "measurement":
                    # Terminal measurements must see every gate, so flush first.
                    self._flush_previous_ops()
                    outvals = self.interpret_measurement_eqn(eqn)
                else:
                    invals = [self.read(invar) for invar in eqn.invars]
                    extra_args, params = eqn.primitive.get_bind_params(eqn.params)
                    outvals = eqn.primitive.bind(*extra_args, *invals, **params)

                if not eqn.primitive.multiple_results:
                    outvals = [outvals]
                for outvar, outval in zip(eqn.outvars, outvals, strict=True):
                    self._env[outvar] = outval

            # The following is needed because any operations inside self.previous_ops have not yet
            # been applied.
            self._flush_previous_ops()

            # Read the final result of the Jaxpr from the environment
            outvals = []
            for var in jaxpr.outvars:
                outval = self.read(var)
                if isinstance(outval, Operator):
                    outvals.append(super().interpret_operation(outval))
                else:
                    outvals.append(outval)
            self.cleanup()
            self._env = {}
            return outvals

    # Overwrite the cond primitive so that visited wires can be correctly
    # detected across the different branches.
    @MergeAmplitudeEmbeddingInterpreter.register_primitive(cond_prim)
    def _cond_primitive(self, *invals, jaxpr_branches, consts_slices, args_slice):
        """Interpret each cond branch independently, then union the state they produced."""
        args = invals[args_slice]

        new_jaxprs = []
        new_consts = []
        new_consts_slices = []
        end_const_ind = len(jaxpr_branches)

        # Store state before we begin to process the branches
        # (create copies as to not accidentally mutate the original state).
        # We cannot just copy self.state because a shallow copy would not
        # create a copy of `visited_wires`, which is a set.
        # We cannot use deepcopy as `visited_wires` may have tracers inside,
        # which have hashes specific to the instance. Copying these will cause
        # the dynamic wires in the original and copy to be different.
        initial_wires = copy(self.state["visited_wires"])
        curr_wires = copy(self.state["visited_wires"])
        initial_dynamic_wires_found = self.state["dynamic_wires_found"]
        curr_dynamic_wires_found = self.state["dynamic_wires_found"]
        initial_ops_found = self.state["ops_found"]
        curr_ops_found = self.state["ops_found"]

        for const_slice, jaxpr in zip(consts_slices, jaxpr_branches, strict=True):
            consts = invals[const_slice]
            # ``copy(self)`` is shallow, so the branch interpreter shares ``self.state``
            # and its updates are visible here once the branch has been processed.
            new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)

            # Update state so far so collisions with
            # newly seen states from the branches continue to be
            # detected after the cond
            curr_wires |= self.state["visited_wires"]
            curr_dynamic_wires_found = curr_dynamic_wires_found or self.state["dynamic_wires_found"]
            curr_ops_found = curr_ops_found or self.state["ops_found"]

            # Reset state for the next branch so we don't get false positive collisions
            # (copy so if state mutates we preserved true initial state). Update the dict in
            # place instead of rebinding ``self.state``: when this cond is nested inside another
            # branch or loop body, the enclosing interpreter shares this very dict, and
            # rebinding would hide later updates from it.
            self.state.update(
                {
                    "visited_wires": copy(initial_wires),
                    "dynamic_wires_found": initial_dynamic_wires_found,
                    "ops_found": initial_ops_found,
                }
            )

            new_jaxprs.append(new_jaxpr.jaxpr)
            new_consts.extend(new_jaxpr.consts)
            new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts)))
            end_const_ind += len(new_jaxpr.consts)

        # Reset state to all updates from all branches in the cond (in place, see above).
        self.state.update(
            {
                "visited_wires": curr_wires,
                "dynamic_wires_found": curr_dynamic_wires_found,
                "ops_found": curr_ops_found,
            }
        )

        new_args_slice = slice(end_const_ind, None)
        return cond_prim.bind(
            *invals[: len(jaxpr_branches)],
            *new_consts,
            *args,
            jaxpr_branches=new_jaxprs,
            consts_slices=new_consts_slices,
            args_slice=new_args_slice,
        )

    @MergeAmplitudeEmbeddingInterpreter.register_primitive(measure_prim)
    def _measure_primitive(self, *invals, **params):
        """Record a mid-circuit measurement's wires, flush buffered ops, and bind the measurement."""
        # Make sure to record that we have visited the wires on this measurement
        # in order to be able to detect potential wire collisions with future AE gates
        self.state["visited_wires"] = self.state["visited_wires"].union(set(invals))
        # Accumulate: a measurement on static wires must not clear a flag set by an
        # earlier operator (or measurement) with dynamic wires.
        self.state["dynamic_wires_found"] = self.state["dynamic_wires_found"] or any(
            is_abstract(w) for w in invals
        )
        self.state["ops_found"] = True

        # pylint: disable=protected-access
        self._flush_previous_ops()

        _, params = measure_prim.get_bind_params(params)
        return measure_prim.bind(*invals, **params)

    def merge_amplitude_embedding_plxpr_to_plxpr(jaxpr, consts, _, __, *args):
        """Function for applying the ``merge_amplitude_embedding`` transform on plxpr."""
        # The two ignored positional arguments are unused here (presumably the transform's
        # targs/tkwargs -- confirm against the transform core).
        interpreter = MergeAmplitudeEmbeddingInterpreter()

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return MergeAmplitudeEmbeddingInterpreter, merge_amplitude_embedding_plxpr_to_plxpr


MergeAmplitudeEmbeddingInterpreter, merge_amplitude_embedding_plxpr_to_plxpr = (
    _get_plxpr_merge_amplitude_embedding()
)
@partial(transform, plxpr_transform=merge_amplitude_embedding_plxpr_to_plxpr)
def merge_amplitude_embedding(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    r"""Quantum function transform to combine amplitude embedding templates that act on different qubits.

    Every ``AmplitudeEmbedding`` in the circuit is folded into a single embedding over the
    union of their wires, which is placed at the very start of the circuit. All other
    operations keep their relative order.

    Args:
        tape (QNode or QuantumTape or Callable): A quantum circuit.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[.QuantumTape], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    Raises:
        DeviceError: if an ``AmplitudeEmbedding`` acts on a wire that an earlier operation
            in the circuit has already used.

    **Example**

    >>> dev = qml.device('default.qubit', wires=4)

    You can apply the transform directly on :class:`QNode`:

    .. code-block:: python

        @qml.transforms.merge_amplitude_embedding
        @qml.qnode(device=dev)
        def circuit():
            qml.CNOT(wires = [0,1])
            qml.AmplitudeEmbedding([0,1], wires = 2)
            qml.AmplitudeEmbedding([0,1], wires = 3)
            return qml.state()

    >>> circuit()
    array([0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])

    .. details::
        :title: Usage Details

        You can also apply it on a quantum function.

        .. code-block:: python

            def qfunc():
                qml.CNOT(wires = [0,1])
                qml.AmplitudeEmbedding([0,1], wires = 2)
                qml.AmplitudeEmbedding([0,1], wires = 3)
                return qml.state()

        The circuit before compilation will not work because it uses two amplitude embeddings.

        Using the transformation we can join the different amplitude embeddings into a single one:

        >>> optimized_qfunc = qml.transforms.merge_amplitude_embedding(qfunc)
        >>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
        >>> print(qml.draw(optimized_qnode)())
        0: ─╭●───┤  State
        1: ─╰X───┤  State
        2: ─╭|Ψ⟩─┤  State
        3: ─╰|Ψ⟩─┤  State

    """
    kept_ops = []
    seen_wires = set()
    embed_wires, embed_vectors, embed_batch_sizes = [], [], []

    for op in tape.operations:
        op_wires = set(op.wires)

        if isinstance(op, AmplitudeEmbedding):
            # An embedding prepares a state, so it may only act on wires nothing has touched yet.
            if seen_wires & op_wires:
                raise DeviceError(
                    f"Operation {op.name} cannot be used after other Operation applied in the same qubit "
                )
            embed_wires.append(op.wires)
            embed_vectors.append(op.parameters[0])
            embed_batch_sizes.append(op.batch_size)
        else:
            kept_ops.append(op)

        seen_wires |= op_wires

    if embed_wires:
        merged_wires = embed_wires[0]
        merged_vector = embed_vectors[0]
        merged_batch = embed_batch_sizes[0]

        for wires, vector, batch in zip(
            embed_wires[1:], embed_vectors[1:], embed_batch_sizes[1:], strict=True
        ):
            # Outer product over the trailing axis (the leading ``...`` keeps any batch dim),
            # then flatten back to a (possibly batched) vector for the next iteration.
            merged_vector = merged_vector[..., :, None] * vector[..., None, :]
            merged_batch = merged_batch or batch
            merged_wires = merged_wires + wires
            merged_vector = (
                reshape(merged_vector, (merged_batch, -1)) if merged_batch else flatten(merged_vector)
            )

        # Build the merged gate without queuing it into any surrounding context.
        with QueuingManager.stop_recording():
            merged_op = AmplitudeEmbedding(merged_vector, wires=merged_wires)
        kept_ops.insert(0, merged_op)

    new_tape = tape.copy(operations=kept_ops)

    def null_postprocessing(results):
        """Unwrap the batch of results into the result of the single transformed ``QuantumTape``."""
        return results[0]

    return [new_tape], null_postprocessing