# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a function for converting plxpr to a tape, together with the
``CollectOpsandMeas`` interpreter that it uses."""
from copy import copy

import pennylane as qml
from pennylane.capture.base_interpreter import FlattenedInterpreter
from pennylane.capture.primitives import (
    adjoint_transform_prim,
    cond_prim,
    ctrl_transform_prim,
    grad_prim,
    jacobian_prim,
    measure_prim,
    qnode_prim,
)
from pennylane.measurements import MeasurementValue, get_mcm_predicates

from .qscript import QuantumScript


class CollectOpsandMeas(FlattenedInterpreter):
    """Collect the dropped operations and measurements in a plxpr. Used by ``plxpr_to_tape``.

    .. code-block:: python

        @qml.for_loop(3)
        def loop(i):
            qml.X(i)

        def f(x):
            loop()
            qml.adjoint(qml.S)(0)
            m0 = qml.measure(0)
            qml.RX(2*x, 0)
            return qml.probs(wires=0), qml.expval(qml.Z(1))

    >>> from pennylane.tape.plxpr_conversion import CollectOpsandMeas
    >>> from jax import make_jaxpr
    >>> qml.capture.enable()
    >>> plxpr = make_jaxpr(f)(0.5)
    >>> collector = CollectOpsandMeas()
    >>> collector.eval(plxpr.jaxpr, plxpr.consts, 1.2)
    [probs(wires=[0]), expval(Z(1))]
    >>> collector.state
    {'ops': [X(0), X(1), X(2), Adjoint(S(0)), measure(wires=[0]), RX(Array(2.4, dtype=float32, weak_type=True), wires=[0])], 'measurements': [probs(wires=[0]), expval(Z(1))]}

    After execution, the collected operations and measurements are available in the ``state``
    attribute.

    Note that if the same instance is used again, the new operations will be appended to the
    same state.

    >>> collector = CollectOpsandMeas()
    >>> collector(qml.T)(0)
    >>> collector.state['ops']
    [T(0)]
    >>> collector(qml.S)(0)
    >>> collector.state['ops']
    [T(0), S(0)]

    """

    def __init__(self, state=None):
        """
        Args:
            state (dict | None): an existing state with ``"ops"`` and ``"measurements"`` lists
                to append to. If ``None``, an empty state is created by :meth:`setup`.
        """
        self.state = state
        super().__init__()

    def setup(self):
        """Initialize ``state`` with empty ``"ops"`` and ``"measurements"`` lists if unset.

        An already-present state is kept, so repeated evaluations on the same instance
        append to it.
        """
        if self.state is None:
            self.state = {"ops": [], "measurements": []}

    def interpret_operation(self, op: "pennylane.operation.Operator"):
        """Record ``op`` in ``state["ops"]``."""
        self.state["ops"].append(op)

    def interpret_measurement(self, measurement):
        """Record ``measurement`` in ``state["measurements"]`` and return it unchanged."""
        self.state["measurements"].append(measurement)
        return measurement


@CollectOpsandMeas.register_primitive(adjoint_transform_prim)
def _(self, *invals, jaxpr, lazy, n_consts):
    """Handle an adjoint transform primitive by collecting the operations in the jaxpr, and
    then applying their adjoint in reverse order.

    Only the child's operations are used; measurements collected inside the transformed
    function are not propagated.
    """
    consts = invals[:n_consts]
    args = invals[n_consts:]
    child = CollectOpsandMeas()
    child.eval(jaxpr, consts, *args)
    assert child.state
    # The adjoint of a product reverses the operator order.
    for op in reversed(child.state["ops"]):
        self.state["ops"].append(qml.adjoint(op, lazy=lazy))
    return []


@CollectOpsandMeas.register_primitive(ctrl_transform_prim)
def _(self, *invals, n_control, jaxpr, n_consts, **params):
    """Handle a control transform primitive by collecting the operations in the jaxpr,
    and then applying their controlled versions.

    ``invals`` is laid out as ``(*consts, *args, *control_wires)``; the remaining ``params``
    are forwarded to ``qml.ctrl``.
    """
    # Locate the control boundary from the front. Slicing with ``-n_control`` breaks when
    # ``n_control == 0``: ``invals[-0:]`` selects *every* input and ``invals[n:-0]`` none.
    split = len(invals) - n_control
    consts = invals[:n_consts]
    args = invals[n_consts:split]
    control = invals[split:]
    child = CollectOpsandMeas()
    child.eval(jaxpr, consts, *args)
    assert child.state
    for op in child.state["ops"]:
        self.state["ops"].append(qml.ctrl(op, control=control, **params))
    return []


@CollectOpsandMeas.register_primitive(cond_prim)
def _(self, *all_args, jaxpr_branches, consts_slices, args_slice):
    """Handle a cond primitive.

    With classical predicates, the first branch whose predicate is truthy is evaluated and
    its outputs are returned. With mid-circuit-measurement predicates, every branch is
    collected and each of its operations is wrapped in ``qml.ops.Conditional``.

    Raises:
        ValueError: if mid-circuit-measurement predicates are mixed with other classical
            predicates, or if a branch conditioned on a mid-circuit measurement returns values.
    """
    n_branches = len(jaxpr_branches)
    conditions = all_args[:n_branches]
    args = all_args[args_slice]

    # Find predicates that use mid-circuit measurements. We don't check the last
    # condition as that is always `True`.
    mcm_conditions = tuple(pred for pred in conditions[:-1] if isinstance(pred, MeasurementValue))
    if mcm_conditions:
        if len(mcm_conditions) != len(conditions) - 1:
            raise ValueError(
                "Cannot use qml.cond with a combination of mid-circuit measurements "
                "and other classical conditions as predicates."
            )
        conditions = get_mcm_predicates(mcm_conditions)

    for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
        consts = all_args[const_slice]
        if jaxpr is None:
            continue
        if isinstance(pred, MeasurementValue):
            if jaxpr.outvars:
                outvals = [v.aval for v in jaxpr.outvars]
                raise ValueError(
                    (
                        "Conditional branches of mid circuit measurements are not allowed to"
                        f" return anything with plxpr_to_tape and CollectOpsandMeas. Branch returns {outvals}"
                    )
                )
            child = CollectOpsandMeas()
            child.eval(jaxpr, consts, *args)
            assert child.state
            self.state["ops"].extend(qml.ops.Conditional(pred, op) for op in child.state["ops"])
        elif pred:
            # ``copy`` is shallow, so the branch appends into this collector's ``state``.
            # NOTE(review): presumably a copy is used so the nested ``eval`` does not
            # overwrite this interpreter's in-progress variable environment — confirm.
            return copy(self).eval(jaxpr, consts, *args)
    return ()


@CollectOpsandMeas.register_primitive(measure_prim)
def _(self, wires, reset, postselect):
    """Perform a mid-circuit measurement, record its underlying measurement operations
    (``m0.measurements``) in ``state["ops"]``, and return the measurement value."""
    m0 = qml.measure(wires, reset=reset, postselect=postselect)
    self.state["ops"].extend(m0.measurements)
    return m0


# pylint: disable=unused-argument
@CollectOpsandMeas.register_primitive(grad_prim)
def _(self, *invals, jaxpr, n_consts, **params):
    """Differentiation is not supported when collecting operations."""
    raise NotImplementedError("CollectOpsandMeas cannot handle the grad primitive")


# pylint: disable=unused-argument
@CollectOpsandMeas.register_primitive(jacobian_prim)
def _(self, *invals, jaxpr, n_consts, **params):
    """Differentiation is not supported when collecting operations."""
    raise NotImplementedError("CollectOpsandMeas cannot handle the jacobian primitive")


@CollectOpsandMeas.register_primitive(qnode_prim)
def _(
    self, *invals, shots, qnode, device, execution_config, qfunc_jaxpr, n_consts
):  # pylint: disable=too-many-arguments,unused-argument
    """Inline a QNode: collect the operations and measurements of its quantum function into
    this collector's state and return the function's outputs.

    The QNode's ``shots``, ``device`` and ``execution_config`` are ignored.
    """
    consts = invals[:n_consts]
    args = invals[n_consts:]
    child = CollectOpsandMeas()
    out = child.eval(qfunc_jaxpr, consts, *args)
    assert child.state
    self.state["ops"].extend(child.state["ops"])
    self.state["measurements"].extend(child.state["measurements"])
    return out
def plxpr_to_tape(plxpr: "jax.core.Jaxpr", consts, *args, shots=None) -> QuantumScript:
    """Convert a plxpr into a tape.

    Args:
        plxpr (jax.core.Jaxpr): a pennylane variant jaxpr
        consts (list): the consts for the jaxpr
        *args : the arguments to execute the plxpr with

    Keyword Args:
        shots (None, int, Sequence[int], Shots): the shots for the tape.

    Returns:
        QuantumScript: a single quantum script containing the quantum operations and measurements

    Raises:
        NotImplementedError: if the plxpr contains a ``grad`` or ``jacobian`` primitive
        ValueError: if a ``qml.cond`` mixes mid-circuit-measurement predicates with other
            classical predicates, or if a branch conditioned on a mid-circuit measurement
            returns values

    .. code-block:: python

        @qml.for_loop(3)
        def loop(i):
            qml.X(i)

        def f(x):
            loop()
            qml.adjoint(qml.S)(0)
            m0 = qml.measure(0)
            qml.RX(2*x, 0)
            return qml.probs(wires=0), qml.expval(qml.Z(1))

        qml.capture.enable()
        plxpr = jax.make_jaxpr(f)(0.5)
        tape = qml.tape.plxpr_to_tape(plxpr.jaxpr, plxpr.consts, 1.2)
        print(qml.drawer.tape_text(tape, decimals=2))

    .. code-block::

        0: ──X──S†──┤↗├──RX(2.40)─┤  Probs
        1: ──X────────────────────┤  <Z>
        2: ──X────────────────────┤

    """
    # A fresh collector guarantees the tape only contains this plxpr's ops and measurements.
    collector = CollectOpsandMeas()
    collector.eval(plxpr, consts, *args)
    assert collector.state
    return QuantumScript(collector.state["ops"], collector.state["measurements"], shots=shots)