Source code for pennylane.transforms.defer_measurements
# Copyright 2018-2022 Xanadu Quantum Technologies Inc.# Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at# http://www.apache.org/licenses/LICENSE-2.0# Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Code for the tape transform implementing the deferred measurement principle."""fromfunctoolsimportlru_cache,partialfromnumbersimportNumberfromtypingimportCallable,Optional,Sequence,Unionfromwarningsimportwarnimportpennylaneasqmlfrompennylane.measurementsimport(CountsMP,MeasurementValue,MidMeasureMP,ProbabilityMP,SampleMP,get_mcm_predicates,)frompennylane.ops.op_mathimportctrlfrompennylane.queuingimportQueuingManagerfrompennylane.tapeimportQuantumScript,QuantumScriptBatchfrompennylane.transformsimportTransformError,transformfrompennylane.typingimportPostprocessingFnfrompennylane.wiresimportWires# pylint: disable=too-many-branches, protected-access, too-many-statementsdef_check_tape_validity(tape:QuantumScript):"""Helper function to check that the tape is valid."""cv_types=(qml.operation.CVOperation,qml.operation.CVObservable)ops_cv=any(isinstance(op,cv_types)andop.name!="Identity"foropintape.operations)obs_cv=any(isinstance(getattr(op,"obs",None),cv_types)andnotisinstance(getattr(op,"obs",None),qml.Identity)foropintape.measurements)ifops_cvorobs_cv:raiseValueError("Continuous variable operations and observables are not supported.")formpintape.measurements:ifisinstance(mp,(CountsMP,ProbabilityMP,SampleMP))andnot(mp.obsormp._wiresormp.mvisnotNone):raiseValueError(f"Cannot use {mp.__class__.__name__} as a measurement without specifying wires ""when using qml.defer_measurements. Deferred measurements can occur ""automatically when using mid-circuit measurements on a device that does not ""support them.")ifmp.__class__.__name__=="StateMP":raiseValueError("Cannot use StateMP as a measurement when using qml.defer_measurements. ""Deferred measurements can occur automatically when using mid-circuit ""measurements on a device that does not support them.")samples_present=any(isinstance(mp,SampleMP)formpintape.measurements)postselect_present=any(op.postselectisnotNoneforopintape.operationsifisinstance(op,MidMeasureMP))ifpostselect_presentandsamples_presentandtape.batch_sizeisnotNone:raiseValueError("Returning qml.sample is not supported when postselecting mid-circuit ""measurements with broadcasting")def_collect_mid_measure_info(tape:QuantumScript):"""Helper function to collect information related to mid-circuit measurements in the tape."""# Find wires that are reused after measurementmeasured_wires=[]reused_measurement_wires=set()any_repeated_measurements=Falseis_postselecting=Falseforopintape:ifisinstance(op,MidMeasureMP):ifop.postselectisnotNone:is_postselecting=Trueifop.reset:reused_measurement_wires.add(op.wires[0])ifop.wires[0]inmeasured_wires:any_repeated_measurements=Truemeasured_wires.append(op.wires[0])else:reused_measurement_wires=reused_measurement_wires.union(set(measured_wires).intersection(op.wires.toset()))returnmeasured_wires,reused_measurement_wires,any_repeated_measurements,is_postselectingdefnull_postprocessing(results):"""A postprocessing function returned by a transform that only converts the batch of results into a result for a single ``QuantumTape``. """returnresults[0]@lru_cachedef_get_plxpr_defer_measurements():try:# pylint: disable=import-outside-toplevelimportjaxfrompennylane.captureimportPlxprInterpreterfrompennylane.capture.primitivesimportcond_prim,ctrl_transform_prim,measure_primexceptImportError:# pragma: no coverreturnNone,None# pylint: disable=redefined-outer-nameclassDeferMeasurementsInterpreter(PlxprInterpreter):"""Interpreter for applying the defer_measurements transform to plxpr."""# pylint: disable=unnecessary-lambda-assignment,attribute-defined-outside-init,no-self-usedef__init__(self,num_wires):super().__init__()self._num_wires=num_wires# We use a dict here instead of a normal int variable because we want the state to mutate# when we interpret higher-order primitives# We store all used wires rather than just the max because if the wires are tracers, then# we can't do comparisons to find the max wireself.state={"cur_target":num_wires-1,"used_wires":set()}defcleanup(self)->None:"""Perform any final steps after iterating through all equations. This method resets the internal ``state``, specifically the ``cur_idx`` entry. ``cur_idx`` tracks the index within the auxiliary wires, determining which will be used for the target wire of the next mid-circuit measurement's replacement :class:`~pennylane.CNOT`. """self.state={"cur_target":self._num_wires-1,"used_wires":set()}def_update_used_wires(self,wires:qml.wires.Wires,cur_target:int):"""Update the state with the number of wires that have been used and validate that there is no overlap between the used circuit wires and the mid-circuit measurement target wires. Args: wires (pennylane.wires.Wires): wires to add to the set of used wires cur_target (int): target wire to be used for a mid-circuit measurement Raises: TransformError: if there is overlap between the used circuit wires and mid-circuit measurement target wires """self.state["used_wires"]|=wires.toset()ifself.state["used_wires"].intersection(range(cur_target,self._num_wires)):raiseTransformError("Too many mid-circuit measurements for the specified number of wires ""with 'defer_measurements'.")definterpret_dynamic_operation(self,data,struct,inds):"""Interpret an operation that uses mid-circuit measurement outcomes as parameters. * This will not work if mid-circuit measurement values are used to specify operator wires. * This will not work if more than one parameter uses mid-circuit measurement values. Args: data (TensorLike): Flattened data of the operator struct (PyTreeDef): Pytree structure of the operator inds (Sequence[int]): Indices of mid-circuit measurement values in ``data`` Raises: TransformError: if there is overlap between the used circuit wires and mid-circuit measurement target wires """iflen(inds)>1:raiseTransformError("Cannot create operations with multiple parameters based on ""mid-circuit measurements with 'defer_measurements'.")idx=inds[0]mv=data[idx]forbranch,valueinmv.items():data[idx]=valueop=jax.tree_util.tree_unflatten(struct,data)qml.ctrl(op,mv.wires,control_values=branch)definterpret_operation(self,op:"qml.operation.Operator"):"""Interpret a PennyLane operation instance. Args: op (Operator): a pennylane operator instance Returns: Any This method is only called when the operator's output is a dropped variable, so the output will not affect later equations in the circuit. See also: :meth:`~.interpret_operation_eqn`. """# Range for comparison is [cur_target + 1, num_wires) because cur_target# is the _next_ wire to be used for an MCM. We want to check if the used# wires overlap with the already applied MCMs.self._update_used_wires(op.wires,self.state["cur_target"]+1)# We treat operators with operators based on mid-circuit measurement values# separately, and otherwise default to the standard behaviourdata,struct=jax.tree_util.tree_flatten(op)mcm_data_inds=[]fori,dinenumerate(data):ifisinstance(d,MeasurementValue):mcm_data_inds.append(i)ifmcm_data_inds:returnself.interpret_dynamic_operation(data,struct,mcm_data_inds)returnjax.tree_util.tree_unflatten(struct,data)definterpret_measurement(self,measurement:"qml.measurement.MeasurementProcess"):"""Interpret a measurement process instance. Args: measurement (MeasurementProcess): a measurement instance. See also :meth:`~.interpret_measurement_eqn`. """ifmeasurement.mvisnotNone:kwargs={"wires":measurement.wires,"eigvals":measurement.eigvals()}ifisinstance(measurement,CountsMP):kwargs["all_outcomes"]=measurement.all_outcomesmeasurement=type(measurement)(**kwargs)else:# Range for comparison is [cur_target + 1, num_wires) because cur_target# is the _next_ wire to be used for an MCM. We want to check if the used# wires overlap with the already applied MCMs.self._update_used_wires(measurement.wires,self.state["cur_target"]+1)returnsuper().interpret_measurement(measurement)defresolve_mcm_values(self,primitive:"jax.core.Primitive",subfuns:Sequence[Callable],invals:Sequence[Union[MeasurementValue,Number]],params:dict,)->MeasurementValue:"""Create a ``MeasurementValue`` that captures all classical processing of the input ``eqn`` in its ``processing_fn``. Args: primitive (jax.core.Primitive): Jax primitive subfuns (Sequence[Callable]): Callable positional arguments to the primitive. These are created by pre-processing jaxpr equation parameters. invals (Sequence[Union[MeasurementValue, Number]]): Inputs to the primitive params (dict): Keyword arguments to the primitive Returns: MeasurementValue: ``MeasurementValue`` containing classical processing information for applying the input equation to mid-circuit measurement outcomes. """# pylint: disable=protected-access# MeasurementValue._apply is for applying a new operation to the current# MeasurementValue# It is used when either:# 1. A unary operation is performed on one MeasurementValue, e.g., ~m0# 2. A binary operation is performed on a MeasurementValue and a scalar,# e.g., m0 + 1## MeasurementValue._transform_bin_op is for applying a binary operation# to two MeasurementValuesassertlen(invals)<=2processing_fn=partial(primitive.bind,*subfuns,**params)# One MeasurementValueiflen(invals)==1:m0=invals[0]returnm0._apply(processing_fn)# Two MeasurementValuesifall(isinstance(inval,MeasurementValue)forinvalininvals):m0,m1=invalsreturnm0._transform_bin_op(processing_fn,m1)# One MeasurementValue, one number[m0,other]=invalsifisinstance(invals[0],MeasurementValue)elseinvals[::-1]returnm0._apply(lambdax:processing_fn(x,other))defeval(self,jaxpr:"jax.core.Jaxpr",consts:list,*args)->list:"""Evaluate a jaxpr. Args: jaxpr (jax.core.Jaxpr): the jaxpr to evaluate consts (list[TensorLike]): the constant variables for the jaxpr *args (tuple[TensorLike]): The arguments for the jaxpr. Returns: list[TensorLike]: the results of the execution. """self._env={}self.setup()forarg,invarinzip(args,jaxpr.invars,strict=True):self._env[invar]=argforconst,constvarinzip(consts,jaxpr.constvars,strict=True):self._env[constvar]=constforeqninjaxpr.eqns:primitive=eqn.primitivecustom_handler=self._primitive_registrations.get(primitive,None)ifcustom_handler:invals=[self.read(invar)forinvarineqn.invars]outvals=custom_handler(self,*invals,**eqn.params)elifgetattr(primitive,"prim_type","")=="operator":outvals=self.interpret_operation_eqn(eqn)elifgetattr(primitive,"prim_type","")=="measurement":outvals=self.interpret_measurement_eqn(eqn)else:invals=[self.read(invar)forinvarineqn.invars]subfuns,params=primitive.get_bind_params(eqn.params)ifany(isinstance(inval,MeasurementValue)forinvalininvals):outvals=self.resolve_mcm_values(primitive,subfuns,invals,params)else:outvals=primitive.bind(*subfuns,*invals,**params)ifnotprimitive.multiple_results:outvals=[outvals]foroutvar,outvalinzip(eqn.outvars,outvals,strict=True):self._env[outvar]=outval# Read the final result of the Jaxpr from the environmentoutvals=[]forvarinjaxpr.outvars:outval=self.read(var)ifisinstance(outval,qml.operation.Operator):outvals.append(self.interpret_operation(outval))else:outvals.append(outval)self.cleanup()self._env={}returnoutvals@DeferMeasurementsInterpreter.register_primitive(measure_prim)def_(self,wires,reset=False,postselect=None):cur_target=self.state["cur_target"]# Range for comparison is [cur_target, num_wires) because cur_target# is the _current_ wire to be used for an MCM.self._update_used_wires(Wires(wires),cur_target)# Using type.__call__ instead of normally constructing the class prevents# the primitive corresponding to the class to get binded. We do not want the# MidMeasureMP's primitive to get recorded.meas=type.__call__(MidMeasureMP,Wires(cur_target),reset=reset,postselect=postselect,id=str(cur_target))cnot_wires=(wires,cur_target)ifpostselectisnotNone:qml.Projector(jax.numpy.array([postselect]),wires=wires)qml.CNOT(wires=cnot_wires)ifreset:ifpostselectisNone:qml.CNOT(wires=cnot_wires[::-1])elifpostselect==1:qml.PauliX(wires=wires)self.state["cur_target"]-=1returnMeasurementValue([meas],lambdax:x)@DeferMeasurementsInterpreter.register_primitive(cond_prim)def_(self,*invals,jaxpr_branches,consts_slices,args_slice):# pylint: disable=unused-argumentn_branches=len(jaxpr_branches)conditions=invals[:n_branches]ifnotany(isinstance(c,MeasurementValue)forcinconditions):returnPlxprInterpreter._primitive_registrations[cond_prim](self,*invals,jaxpr_branches=jaxpr_branches,consts_slices=consts_slices,args_slice=args_slice,)conditions=get_mcm_predicates(conditions[:-1])args=invals[args_slice]fori,(condition,jaxpr)inenumerate(zip(conditions,jaxpr_branches,strict=True)):ifjaxprisNone:# If a false branch isn't provided, the jaxpr corresponding to the condition# for the false branch will be None. That is the only scenario where we would# reach here.continueifisinstance(condition,MeasurementValue):control_wires=Wires([m.wires[0]formincondition.measurements])forbranch,valueincondition.items():# When reduce_postselected is True, some branches can be ()cur_consts=invals[consts_slices[i]]qml.cond(value,ctrl_transform_prim.bind)(*cur_consts,*args,*control_wires,jaxpr=jaxpr,n_control=len(control_wires),control_values=branch,work_wires=None,n_consts=len(cur_consts),)return[]defdefer_measurements_plxpr_to_plxpr(jaxpr,consts,targs,tkwargs,*args):"""Function for applying the ``defer_measurements`` transform on plxpr."""ifnottkwargs.get("num_wires",None):raiseValueError("'num_wires' argument for qml.defer_measurements must be provided ""when qml.capture.enabled() is True.")iftkwargs.pop("reduce_postselected",False):warn("Cannot set 'reduce_postselected=True' with qml.capture.enabled() ""when using qml.defer_measurements. Argument will be ignored.",UserWarning,)iftkwargs.pop("allow_postselect",False):warn("Cannot set 'allow_postselect=True' with qml.capture.enabled() ""when using qml.defer_measurements. Argument will be ignored.",UserWarning,)interpreter=DeferMeasurementsInterpreter(*targs,**tkwargs)defwrapper(*inner_args):returninterpreter.eval(jaxpr,consts,*inner_args)returnjax.make_jaxpr(wrapper)(*args)returnDeferMeasurementsInterpreter,defer_measurements_plxpr_to_plxprDeferMeasurementsInterpreter,defer_measurements_plxpr_to_plxpr=_get_plxpr_defer_measurements()# pylint: disable=unused-argument
[docs]@partial(transform,plxpr_transform=defer_measurements_plxpr_to_plxpr)defdefer_measurements(tape:QuantumScript,reduce_postselected:bool=True,allow_postselect:bool=True,num_wires:Optional[int]=None,)->tuple[QuantumScriptBatch,PostprocessingFn]:"""Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. This transform uses the `deferred measurement principle <https://en.wikipedia.org/wiki/Deferred_Measurement_Principle>`_ and applies to qubit-based quantum functions. Support for mid-circuit measurements is device-dependent. If a device doesn't support mid-circuit measurements natively, then the QNode will apply this transform. .. note:: The transform uses the :func:`~.ctrl` transform to implement operations controlled on mid-circuit measurement outcomes. The set of operations that can be controlled as such depends on the set of operations supported by the chosen device. .. note:: Devices that inherit from :class:`~pennylane.devices.QubitDevice` **must** be initialized with an additional wire for each mid-circuit measurement after which the measured wire is reused or reset for ``defer_measurements`` to transform the quantum tape correctly. .. note:: This transform does not change the list of terminal measurements returned by the quantum function. .. note:: When applying the transform on a quantum function that contains the :class:`~.Snapshot` instruction, state information corresponding to simulating the transformed circuit will be obtained. No post-measurement states are considered. .. warning:: :func:`~.pennylane.state` is not supported with the ``defer_measurements`` transform. Additionally, :func:`~.pennylane.probs`, :func:`~.pennylane.sample` and :func:`~.pennylane.counts` can only be used with ``defer_measurements`` if wires or an observable are explicitly specified. .. warning:: ``defer_measurements`` does not support using custom wire labels if any measured wires are reused or reset. Args: tape (QNode or QuantumTape or Callable): a quantum circuit. reduce_postselected (bool): Whether to use postselection information to reduce the number of operations and control wires in the output tape. Active by default. This is currently ignored if program capture is enabled. allow_postselect (bool): Whether postselection is allowed. In order to perform postselection with ``defer_measurements``, the device must support the :class:`~.Projector` operation. Defaults to ``True``. This is currently ignored if program capture is enabled. num_wires (int): Optional argument to specify the total number of circuit wires. This is only used if program capture is enabled. Returns: qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`. Raises: ValueError: If custom wire labels are used with qubit reuse or reset ValueError: If any measurements with no wires or observable are present ValueError: If continuous variable operations or measurements are present ValueError: If using the transform with any device other than :class:`default.qubit <~pennylane.devices.DefaultQubit>` and postselection is used **Example** Suppose we have a quantum function with mid-circuit measurements and conditional operations: .. code-block:: python3 def qfunc(par): qml.RY(0.123, wires=0) qml.Hadamard(wires=1) m_0 = qml.measure(1) qml.cond(m_0, qml.RY)(par, wires=0) return qml.expval(qml.Z(0)) The ``defer_measurements`` transform allows executing such quantum functions without having to perform mid-circuit measurements: >>> dev = qml.device('default.qubit', wires=2) >>> transformed_qfunc = qml.defer_measurements(qfunc) >>> qnode = qml.QNode(transformed_qfunc, dev) >>> par = np.array(np.pi/2, requires_grad=True) >>> qnode(par) tensor(0.43487747, requires_grad=True) We can also differentiate parameters passed to conditional operations: >>> qml.grad(qnode)(par) tensor(-0.49622252, requires_grad=True) Reusing and resetting measured wires will work as expected with the ``defer_measurements`` transform: .. code-block:: python3 dev = qml.device("default.qubit", wires=3) @qml.qnode(dev) def func(x, y): qml.RY(x, wires=0) qml.CNOT(wires=[0, 1]) m_0 = qml.measure(1, reset=True) qml.cond(m_0, qml.RY)(y, wires=0) qml.RX(np.pi/4, wires=1) return qml.probs(wires=[0, 1]) Executing this QNode: >>> pars = np.array([0.643, 0.246], requires_grad=True) >>> func(*pars) tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True) .. details:: :title: Usage Details By default, ``defer_measurements`` makes use of postselection information of mid-circuit measurements in the circuit in order to reduce the number of controlled operations and control wires. We can explicitly switch this feature off and compare the created circuits with and without this optimization. Consider the following circuit: .. code-block:: python3 @qml.qnode(qml.device("default.qubit")) def node(x): qml.RX(x, 0) qml.RX(x, 1) qml.RX(x, 2) mcm0 = qml.measure(0, postselect=0, reset=False) mcm1 = qml.measure(1, postselect=None, reset=True) mcm2 = qml.measure(2, postselect=1, reset=False) qml.cond(mcm0+mcm1+mcm2==1, qml.RX)(0.5, 3) return qml.expval(qml.Z(0) @ qml.Z(3)) Without the optimization, we find three gates controlled on the three measured qubits. They correspond to the combinations of controls that satisfy the condition ``mcm0+mcm1+mcm2==1``. >>> print(qml.draw(qml.defer_measurements(node, reduce_postselected=False))(0.6)) 0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────────────────────────────────┤ ╭<Z@Z> 1: ──RX(0.60)─────────│──╭●─╭X───────────────────────────────────────┤ │ 2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|─╭○────────╭○────────╭●────────┤ │ 3: ───────────────────│──│──│──────────├RX(0.50)─├RX(0.50)─├RX(0.50)─┤ ╰<Z@Z> 4: ───────────────────╰X─│──│──────────├○────────├●────────├○────────┤ 5: ──────────────────────╰X─╰●─────────╰●────────╰○────────╰○────────┤ If we do not explicitly deactivate the optimization, we obtain a much simpler circuit: >>> print(qml.draw(qml.defer_measurements(node))(0.6)) 0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────┤ ╭<Z@Z> 1: ──RX(0.60)─────────│──╭●─╭X───────────┤ │ 2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|───┤ │ 3: ───────────────────│──│──│──╭RX(0.50)─┤ ╰<Z@Z> 4: ───────────────────╰X─│──│──│─────────┤ 5: ──────────────────────╰X─╰●─╰○────────┤ There is only one controlled gate with only one control wire. .. details:: :title: Deferred measurements with program capture ``qml.defer_measurements`` can be applied to callables when program capture is enabled. To do so, the ``num_wires`` argument must be provided, which should be an integer corresponding to the total number of available wires. For ``m`` mid-circuit measurements, ``range(num_wires - m, num_wires)`` will be the range of wires used to map mid-circuit measurements to ``CNOT`` gates. .. warning:: While the transform includes validation to avoid overlap between wires of the original circuit and mid-circuit measurement target wires, if any wires of the original ciruit are traced, i.e. dependent on dynamic arguments to the transformed workflow, the validation may not catch overlaps. Consider the following example: .. code-block:: python from functools import partial import jax qml.capture.enable() @qml.capture.expand_plxpr_transforms @partial(qml.defer_measurements, num_wires=1) def f(n): qml.measure(n) >>> jax.make_jaxpr(f)(0) { lambda ; a:i64[]. let _:AbstractOperator() = CNOT[n_wires=2] a 0 in () } The circuit gets transformed without issue because the concrete value of the measured wire is unknown. However, execution with n = 0 would raise an error, as the CNOT wires would be (0, 0). Thus, users must by cautious when transforming a circuit. **For ``n`` total wires and ``c`` circuit wires, the number of mid-circuit measurements allowed is ``n - c``.** Using ``defer_measurements`` with program capture enabled introduces new features and restrictions: **New features** * Arbitrary classical processing of mid-circuit measurement values is now possible. With program capture disabled, only limited classical processing, as detailed in the documentation for :func:`~pennylane.measure`. With program capture enabled, any unary or binary ``jax.numpy`` functions that can be applied to scalars can be used with mid-circuit measurements. * Using mid-circuit measurements as gate parameters is now possible. This feature currently has the following restrictions: * Mid-circuit measurement values cannot be used for multiple parameters of the same gate. * Mid-circuit measurement values cannot be used as wires. .. code-block:: python from functools import partial import jax import jax.numpy as jnp qml.capture.enable() @qml.capture.expand_plxpr_transforms @partial(qml.defer_measurements, num_wires=10) def f(): m0 = qml.measure(0) phi = jnp.sin(jnp.pi * m0) qml.RX(phi, 0) return qml.expval(qml.PauliZ(0)) >>> jax.make_jaxpr(f)() { lambda ; . let _:AbstractOperator() = CNOT[n_wires=2] 0 9 a:f64[] = mul 0.0 3.141592653589793 b:f64[] = sin a c:AbstractOperator() = RX[n_wires=1] b 0 _:AbstractOperator() = Controlled[ control_values=(False,) work_wires=Wires([]) ] c 9 d:f64[] = mul 1.0 3.141592653589793 e:f64[] = sin d f:AbstractOperator() = RX[n_wires=1] e 0 _:AbstractOperator() = Controlled[ control_values=(True,) work_wires=Wires([]) ] f 9 g:AbstractOperator() = PauliZ[n_wires=1] 0 h:AbstractMeasurement(n_wires=None) = expval_obs g in (h,) } The above dummy example showcases how the transform is applied when the aforementioned features are used. **What doesn't work** * mid-circuit measurement values cannot be used in the condition for a :func:`~pennylane.while_loop`. * :func:`~pennylane.measure` cannot be used inside the body of loop primitives (:func:`~pennylane.while_loop`, :func:`~pennylane.for_loop`). * If a branch of :func:`~pennylane.cond` uses mid-circuit measurements as its predicate, then all other branches must also use mid-circuit measurement values as predicates. * For an ``n``-parameter gate, mid-circuit measurement values can only be used for 1 of the ``n`` parameters. * :func:`~pennylane.measure` can only be used in the bodies of branches of :func:`~pennylane.cond` if none of the branches use mid-circuit measurements as predicates * :func:`~pennylane.measure` cannot be used inside the body of functions being transformed with :func:`~pennylane.adjoint` or :func:`~pennylane.ctrl`. """ifnotany(isinstance(o,MidMeasureMP)forointape.operations):return(tape,),null_postprocessing_check_tape_validity(tape)new_operations=[]# Find wires that are reused after measurement(measured_wires,reused_measurement_wires,any_repeated_measurements,is_postselecting,)=_collect_mid_measure_info(tape)ifis_postselectingandnotallow_postselect:raiseValueError("Postselection is not allowed on the device with deferred measurements. The device ""must support the Projector gate to apply postselection.")iflen(reused_measurement_wires)>0andnotall(isinstance(w,int)forwintape.wires):raiseValueError("qml.defer_measurements does not support custom wire labels with qubit reuse/reset.")# Apply controlled operations to store measurement outcomes and replace# classically controlled operationscontrol_wires={}cur_wire=(max(tape.wires)+1ifreused_measurement_wiresorany_repeated_measurementselseNone)foropintape.operations:ifisinstance(op,MidMeasureMP):_=measured_wires.pop(0)ifop.postselectisnotNone:withQueuingManager.stop_recording():new_operations.append(qml.Projector([op.postselect],wires=op.wires[0]))# Store measurement outcome in new wire if wire gets reusedifop.wires[0]inreused_measurement_wiresorop.wires[0]inmeasured_wires:control_wires[op.id]=cur_wirewithQueuingManager.stop_recording():new_operations.append(qml.CNOT([op.wires[0],cur_wire]))ifop.reset:withQueuingManager.stop_recording():# No need to manually reset if postselecting on |0>ifop.postselectisNone:new_operations.append(qml.CNOT([cur_wire,op.wires[0]]))elifop.postselect==1:# We know that the measured wire will be in the |1> state if# postselected |1>. So we can just apply a PauliX instead of# a CNOT to resetnew_operations.append(qml.X(op.wires[0]))cur_wire+=1else:control_wires[op.id]=op.wires[0]elifop.__class__.__name__=="Conditional":withQueuingManager.stop_recording():new_operations.extend(_add_control_gate(op,control_wires,reduce_postselected))else:new_operations.append(op)new_measurements=[]formpintape.measurements:ifmp.mvisnotNone:# Update measurement value wires. We can't use `qml.map_wires` because the same# wire can map to different control wires when multiple mid-circuit measurements# are made on the same wire. This mapping is determined by the id of the# MidMeasureMPs. Thus, we need to manually map wires for each MidMeasureMP.ifisinstance(mp.mv,MeasurementValue):new_ms=[qml.map_wires(m,{m.wires[0]:control_wires[m.id]})forminmp.mv.measurements]new_m=MeasurementValue(new_ms,mp.mv.processing_fn)else:new_m=[]forvalinmp.mv:new_ms=[qml.map_wires(m,{m.wires[0]:control_wires[m.id]})forminval.measurements]new_m.append(MeasurementValue(new_ms,val.processing_fn))withQueuingManager.stop_recording():new_mp=(type(mp)(obs=new_m)ifnotisinstance(mp,CountsMP)elseCountsMP(obs=new_m,all_outcomes=mp.all_outcomes))else:new_mp=mpnew_measurements.append(new_mp)new_tape=tape.copy(operations=new_operations,measurements=new_measurements)ifis_postselectingandnew_tape.batch_sizeisnotNone:# Split tapes if broadcasting with postselectionreturnqml.transforms.broadcast_expand(new_tape)return[new_tape],null_postprocessing
def_add_control_gate(op,control_wires,reduce_postselected):"""Helper function to add control gates"""ifreduce_postselected:control=[control_wires[m.id]forminop.meas_val.measurementsifm.postselectisNone]items=op.meas_val.postselected_items()else:control=[control_wires[m.id]forminop.meas_val.measurements]items=op.meas_val.items()new_ops=[]forbranch,valueinitems:ifvalue:# Empty sampling branches can occur when using _postselected_itemsnew_op=(op.baseifbranch==()elsectrl(op.base,control=Wires(control),control_values=branch))new_ops.append(new_op)returnnew_ops