# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the condition transform."""
import functools
from functools import wraps
from typing import Callable, Optional, Sequence, Type, Union

import pennylane as qml
from pennylane import QueuingManager
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue, MidMeasureMP, get_mcm_predicates
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.ops.op_math.symbolicop import SymbolicOp


def _add_abstract_shapes(f):
    """Wrap ``f`` so that every abstract dimension of its outputs is returned
    ahead of the outputs themselves.

    Dynamic shape support currently has a lot of dragons. This function is subject to
    change at any moment. Use duplicate code till reliable abstractions are found.

    >>> @qml.capture.FlatFn
    ... def f(x):
    ...     return x + 1
    >>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4))
    { lambda ; a:i32[] b:f32[a]. let
        c:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
        d:f32[a] = add b c
      in (d,) }
    >>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4))
    { lambda ; a:i32[] b:f32[a]. let
        c:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
        d:f32[a] = add b c
      in (a, d) }

    The wrapped function now returns the array's dimension as well as the array,
    instead of only the array.

    ``f`` is assumed to return a sequence of tensor-likes, as a ``FlatFn`` does.
    """

    def _with_shapes(*args, **kwargs):
        results = f(*args, **kwargs)
        # Collect only the abstract (traced) dimensions; concrete ints are dropped.
        dims = [
            dim
            for res in results
            for dim in getattr(res, "shape", ())
            if qml.math.is_abstract(dim)
        ]
        return (*dims, *results)

    return _with_shapes


class ConditionalTransformError(ValueError):
    """Error for using qml.cond incorrectly"""
class Conditional(SymbolicOp, Operation):
    """A Conditional Operation.

    Unless you are a Pennylane plugin developer, **you should NOT directly use this class**,
    instead, use the :func:`qml.cond <.cond>` function.

    The ``Conditional`` class is a container class that defines an operation
    that should be applied relative to a single measurement value.

    Support for executing ``Conditional`` operations is device-dependent. If a
    device doesn't support mid-circuit measurements natively, then the QNode
    will apply the :func:`defer_measurements` transform.

    Args:
        expr (MeasurementValue): the measurement outcome value to consider
        then_op (Operation): the PennyLane operation instance to apply conditionally
        id (str): custom label given to an operator instance,
            can be useful for some applications where the instance has to be identified
    """

    num_wires = AnyWires

    def __init__(self, expr, then_op: Operation, id=None):
        # NOTE(review): meas_val and _name are set before super().__init__;
        # presumably so they are already available while the base class initialises.
        # Keep this ordering.
        self.hyperparameters["meas_val"] = expr
        self._name = f"Conditional({then_op.name})"
        super().__init__(then_op, id=id)
        # Default to one ``None`` gradient-recipe entry per parameter of the wrapped op.
        if self.grad_recipe is None:
            self.grad_recipe = [None] * self.num_params

    @property
    def meas_val(self):
        """MeasurementValue: the measurement outcome value passed as ``expr``."""
        return self.hyperparameters["meas_val"]

    @property
    def num_params(self):
        """int: number of parameters, delegated to the wrapped (base) operation."""
        return self.base.num_params

    @property
    def ndim_params(self):
        """Parameter dimensionality, delegated to the wrapped (base) operation."""
        return self.base.ndim_params
class CondCallable:  # pylint:disable=too-few-public-methods
    """Base class to represent a conditional function with boolean predicates.

    Args:
        condition (bool): a conditional expression
        true_fn (callable): The function to apply if ``condition`` is ``True``
        false_fn (callable): The function to apply if ``condition`` is ``False``
        elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses.

    Passing ``false_fn`` and ``elifs`` on initialization
    is optional; these functions can be registered post-initialization
    via decorators:

    .. code-block:: python

        def f(x):
            @qml.cond(x > 0)
            def conditional(y):
                return y ** 2

            @conditional.else_if(x < -2)
            def conditional(y):
                return y

            @conditional.otherwise
            def conditional_false_fn(y):
                return -y

            return conditional(x + 1)

    >>> [f(0.5), f(-3), f(-0.5)]
    [2.25, -2, -0.5]
    """

    def __init__(self, condition, true_fn, false_fn=None, elifs=()):
        self.preds = [condition]
        self.branch_fns = [true_fn]
        self.otherwise_fn = false_fn

        # when working with `qml.capture.enabled()`,
        # it's easier to store the original `elifs` argument
        self.orig_elifs = elifs

        if false_fn is None and not qml.capture.enabled():
            self.otherwise_fn = lambda *args, **kwargs: None

        if elifs and not qml.capture.enabled():
            elif_preds, elif_fns = list(zip(*elifs))
            self.preds.extend(elif_preds)
            self.branch_fns.extend(elif_fns)

    def _elif_clauses(self):
        """Return ``self.orig_elifs`` as a tuple of ``(pred, fn)`` clauses.

        ``elifs`` may have been given either as a sequence of clauses or as a single
        ``(pred, fn)`` clause. The single-clause form is detected by its first element
        not itself being a tuple or list.
        """
        elifs = self.orig_elifs
        if len(elifs) > 0 and not isinstance(elifs[0], (tuple, list)):
            return (tuple(elifs),)
        return tuple(elifs)

    def else_if(self, pred):
        """Decorator that allows else-if functions to be registered with a corresponding
        boolean predicate.

        Args:
            pred (bool): The predicate that will determine if this branch is executed.

        Returns:
            callable: decorator that is applied to the else-if function
        """

        def decorator(branch_fn):
            self.preds.append(pred)
            self.branch_fns.append(branch_fn)
            # Rebuild rather than ``+=``: this never mutates a list the caller passed as
            # ``elifs``, and it keeps a single ``(pred, fn)`` clause well-formed.
            self.orig_elifs = (*self._elif_clauses(), (pred, branch_fn))
            return self

        return decorator

    def otherwise(self, otherwise_fn):
        """Decorator that registers the function to run if all conditional predicates
        (including any else-if predicates) evaluate to ``False``.

        Args:
            otherwise_fn (callable): the function to apply if all ``self.preds`` evaluate to ``False``
        """
        self.otherwise_fn = otherwise_fn
        return self

    @property
    def false_fn(self):
        """callable: the function to apply if all ``self.preds`` evaluate to ``False``.

        Alias for ``otherwise_fn``.
        """
        return self.otherwise_fn

    @property
    def true_fn(self):
        """callable: the function to apply if ``self.condition`` evaluates to ``True``"""
        return self.branch_fns[0]

    @property
    def condition(self):
        """bool: the condition that determines if ``self.true_fn`` is applied"""
        return self.preds[0]

    @property
    def elifs(self):
        """(List(Tuple(bool, callable))): a list of (bool, elif_fn) clauses"""
        return list(zip(self.preds[1:], self.branch_fns[1:]))

    def __call_capture_disabled(self, *args, **kwargs):
        # python fallback: run the first branch whose predicate is truthy
        for pred, branch_fn in zip(self.preds, self.branch_fns):
            if pred:
                return branch_fn(*args, **kwargs)

        return self.false_fn(*args, **kwargs)  # pylint: disable=not-callable

    def __call_capture_enabled(self, *args, **kwargs):
        import jax  # pylint: disable=import-outside-toplevel

        cond_prim = _get_cond_qfunc_prim()

        flat_true_fn = FlatFn(self.true_fn)
        branches = [
            (self.preds[0], flat_true_fn),
            *self._elif_clauses(),
            (True, self.otherwise_fn),
        ]

        # Operand layout passed to ``cond_prim.bind``:
        #   [one condition per branch | consts of each branch | abstract shapes | flat args]
        # ``end_const_ind`` tracks the end of the consts section as branches are traced.
        end_const_ind = len(branches)
        conditions = []
        jaxpr_branches = []
        consts = []
        consts_slices = []

        abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args)

        for pred, fn in branches:
            conditions.append(pred)
            if fn is None:
                jaxpr_branches.append(None)
                consts_slices.append(slice(0, 0))
            else:
                f = FlatFn(functools.partial(fn, **kwargs))
                if jax.config.jax_dynamic_shapes:
                    f = _add_abstract_shapes(f)
                jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args)
                jaxpr_branches.append(jaxpr.jaxpr)
                consts_slices.append(slice(end_const_ind, end_const_ind + len(jaxpr.consts)))
                consts += jaxpr.consts
                end_const_ind += len(jaxpr.consts)

        _validate_jaxpr_returns(jaxpr_branches)
        flat_args, _ = jax.tree_util.tree_flatten(args)
        results = cond_prim.bind(
            *conditions,
            *consts,
            *abstract_shapes,
            *flat_args,
            jaxpr_branches=jaxpr_branches,
            consts_slices=consts_slices,
            args_slice=slice(end_const_ind, None),
        )
        assert flat_true_fn.out_tree is not None, "out_tree of flat_true_fn should exist"
        # Any leading abstract shapes are dropped; only the true branch's leaves are kept.
        results = results[-flat_true_fn.out_tree.num_leaves :]
        return jax.tree_util.tree_unflatten(flat_true_fn.out_tree, results)

    def __call__(self, *args, **kwargs):
        if qml.capture.enabled():
            return self.__call_capture_enabled(*args, **kwargs)
        return self.__call_capture_disabled(*args, **kwargs)
def cond(
    condition: Union[MeasurementValue, bool],
    true_fn: Optional[Callable] = None,
    false_fn: Optional[Callable] = None,
    elifs: Sequence = (),
):
    """Quantum-compatible if-else conditionals --- condition quantum operations
    on parameters such as the results of mid-circuit qubit measurements.

    This function is restricted to simply branching on mid-circuit measurement
    results when it is not used with the :func:`~.qjit` decorator.

    When used with the :func:`~.qjit` decorator, this function allows for general
    if-elif-else constructs. All ``true_fn``, ``false_fn`` and ``elifs`` branches
    will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed
    branch determined at runtime. For more details, please see :func:`catalyst.cond`.

    .. note::

        With the Python interpreter, support for :func:`~.cond`
        is device-dependent. If a device doesn't
        support mid-circuit measurements natively, then the QNode will
        apply the :func:`defer_measurements` transform.

    .. note::

        When used with :func:`~.qjit`, this function only supports
        the Catalyst compiler. See :func:`catalyst.cond` for more details.

        Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
        as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`.

    .. note::

        When used with :func:`.pennylane.capture.enabled`, this function allows for general
        if-elif-else constructs. As with the JIT mode, all branches are captured,
        with the executed branch determined at runtime.

        Each branch can receive arguments, but the arguments must be JAX-compatible.
        If a branch returns one or more variables, every other branch must return
        the same abstract values.

    Args:
        condition (Union[.MeasurementValue, bool]): a conditional expression that may involve
            a mid-circuit measurement value (see :func:`.pennylane.measure`).
        true_fn (callable): The quantum function or PennyLane operation to
            apply if ``condition`` is ``True``. If omitted, a decorator accepting
            ``true_fn`` is returned instead. The decorator form is not available
            when ``condition`` is a mid-circuit measurement and :func:`~.qjit` is not used.
        false_fn (callable): The quantum function or PennyLane operation to
            apply if ``condition`` is ``False``
        elifs (Sequence(Tuple(bool, callable))): A sequence of (bool, elif_fn) clauses. Can only
            be used when decorated by :func:`~.qjit` or if the condition is not
            a mid-circuit measurement.

    Returns:
        function: A new function that applies the conditional equivalent of ``true_fn``. The
        returned function takes the same input arguments as ``true_fn``.

    Raises:
        TypeError: if ``true_fn`` is omitted while ``condition`` is a mid-circuit measurement
            (outside of :func:`~.qjit`).
        ConditionalTransformError: if ``elifs`` are given with a mid-circuit measurement
            condition (outside of :func:`~.qjit`), if ``true_fn`` is not callable, or if a
            conditioned quantum function contains measurements.

    **Example**

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def qnode(x, y):
            qml.Hadamard(0)
            m_0 = qml.measure(0)
            qml.cond(m_0, qml.RY)(x, wires=1)

            qml.Hadamard(2)
            qml.RY(-np.pi/2, wires=[2])
            m_1 = qml.measure(2)
            qml.cond(m_1 == 0, qml.RX)(y, wires=1)
            return qml.expval(qml.Z(1))

    .. code-block :: pycon

        >>> first_par = np.array(0.3)
        >>> sec_par = np.array(1.23)
        >>> qnode(first_par, sec_par)
        tensor(0.32677361, requires_grad=True)

    .. note::

        If the first argument of ``cond`` is a measurement value (e.g., ``m_0``
        in ``qml.cond(m_0, qml.RY)``), then ``m_0 == 1`` is considered
        internally.

    .. warning::

        Expressions with boolean logic flow using operators like ``and``,
        ``or`` and ``not`` are not supported as the ``condition`` argument.

        While such statements may not result in errors, they may result in
        incorrect behaviour.

    In just-in-time (JIT) mode using the :func:`~.qjit` decorator:

    .. code-block:: python3

        dev = qml.device("lightning.qubit", wires=1)

        @qml.qjit
        @qml.qnode(dev)
        def circuit(x: float):
            def ansatz_true():
                qml.RX(x, wires=0)
                qml.Hadamard(wires=0)

            def ansatz_false():
                qml.RY(x, wires=0)

            qml.cond(x > 1.4, ansatz_true, ansatz_false)()

            return qml.expval(qml.Z(0))

    >>> circuit(1.4)
    Array(0.16996714, dtype=float64)
    >>> circuit(1.6)
    Array(0., dtype=float64)

    Additional 'else-if' clauses can also be included via the ``elifs`` argument:

    .. code-block:: python3

        @qml.qjit
        @qml.qnode(dev)
        def circuit(x):

            def true_fn():
                qml.RX(x, wires=0)

            def elif_fn():
                qml.RY(x, wires=0)

            def false_fn():
                qml.RX(x ** 2, wires=0)

            qml.cond(x > 2.7, true_fn, false_fn, ((x > 1.4, elif_fn),))()
            return qml.expval(qml.Z(0))

    >>> circuit(1.2)
    Array(0.13042371, dtype=float64)

    .. note::

        If the above syntax is used with a ``QNode`` that is not decorated with
        :func:`~pennylane.qjit` and none of the predicates contain mid-circuit measurements,
        ``qml.cond`` will fall back to using native Python ``if``-``elif``-``else`` blocks.

    .. details::
        :title: Usage Details

        **Conditional quantum functions**

        The ``cond`` transform allows conditioning quantum functions too:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            def qfunc(par, wires):
                qml.Hadamard(wires[0])
                qml.RY(par, wires[0])

            @qml.qnode(dev)
            def qnode(x):
                qml.Hadamard(0)
                m_0 = qml.measure(0)
                qml.cond(m_0, qfunc)(x, wires=[1])
                return qml.expval(qml.Z(1))

        .. code-block :: pycon

            >>> par = np.array(0.3)
            >>> qnode(par)
            tensor(0.3522399, requires_grad=True)

        **Postprocessing multiple measurements into a condition**

        The Boolean condition for ``cond`` may consist of arithmetic expressions
        of one or multiple mid-circuit measurements:

        .. code-block:: python3

            def cond_fn(mcms):
                first_term = np.prod(mcms)
                second_term = (2 ** np.arange(len(mcms))) @ mcms
                return (1 - first_term) * (second_term > 3)

            @qml.qnode(dev)
            def qnode(x):
                ...
                mcms = [qml.measure(w) for w in range(4)]
                qml.cond(cond_fn(mcms), qml.RX)(x, wires=4)
                ...
                return qml.expval(qml.Z(1))

        **Passing two quantum functions**

        In the qubit model, single-qubit measurements may result in one of two
        outcomes. Such measurement outcomes may then be used to create
        conditional expressions.

        According to the truth value of the conditional expression passed to
        ``cond``, the transform can apply a quantum function in both the
        ``True`` and ``False`` case:

        .. code-block:: python3

            dev = qml.device("default.qubit", wires=2)

            def qfunc1(x, wires):
                qml.Hadamard(wires[0])
                qml.RY(x, wires[0])

            def qfunc2(x, wires):
                qml.Hadamard(wires[0])
                qml.RZ(x, wires[0])

            @qml.qnode(dev)
            def qnode1(x):
                qml.Hadamard(0)
                m_0 = qml.measure(0)
                qml.cond(m_0, qfunc1, qfunc2)(x, wires=[1])
                return qml.expval(qml.Z(1))

        .. code-block :: pycon

            >>> par = np.array(0.3)
            >>> qnode1(par)
            tensor(-0.1477601, requires_grad=True)

        The previous QNode is equivalent to using ``cond`` twice, inverting the
        conditional expression in the second case using the ``~`` unary
        operator:

        .. code-block:: python3

            @qml.qnode(dev)
            def qnode2(x):
                qml.Hadamard(0)
                m_0 = qml.measure(0)
                qml.cond(m_0, qfunc1)(x, wires=[1])
                qml.cond(~m_0, qfunc2)(x, wires=[1])
                return qml.expval(qml.Z(1))

        .. code-block :: pycon

            >>> qnode2(par)
            tensor(-0.1477601, requires_grad=True)

        **Quantum functions with different signatures**

        It may be that the two quantum functions passed to ``qml.cond`` have
        different signatures. In such a case, ``lambda`` functions taking no
        arguments can be used with Python closure:

        .. code-block:: python3

            dev = qml.device("default.qubit", wires=2)

            def qfunc1(x, wire):
                qml.Hadamard(wire)
                qml.RY(x, wire)

            def qfunc2(x, y, z, wire):
                qml.Hadamard(wire)
                qml.Rot(x, y, z, wire)

            @qml.qnode(dev)
            def qnode(a, x, y, z):
                qml.Hadamard(0)
                m_0 = qml.measure(0)
                qml.cond(m_0, lambda: qfunc1(a, wire=1), lambda: qfunc2(x, y, z, wire=1))()
                return qml.expval(qml.Z(1))

        .. code-block :: pycon

            >>> par = np.array(0.3)
            >>> x = np.array(1.2)
            >>> y = np.array(1.1)
            >>> z = np.array(0.3)
            >>> qnode(par, x, y, z)
            tensor(-0.30922805, requires_grad=True)
    """
    if active_jit := compiler.active_compiler():
        available_eps = compiler.AvailableCompilers.names_entrypoints
        ops_loader = available_eps[active_jit]["ops"].load()

        def _compiled_cond(fn):
            # Build the compiler's conditional for ``fn``, then attach the
            # optional 'elif' and 'else' branches.
            cond_func = ops_loader.cond(condition)(fn)
            for cond_val, elif_fn in elifs:
                cond_func.else_if(cond_val)(elif_fn)
            if false_fn is not None:
                cond_func.otherwise(false_fn)
            return cond_func

        if true_fn is None:
            if false_fn is None and not elifs:
                # Plain decorator form: defer entirely to the compiler's own decorator.
                return ops_loader.cond(condition)
            # Decorator form with extra branches: attach them rather than dropping them.
            return _compiled_cond

        return _compiled_cond(true_fn)

    if not isinstance(condition, MeasurementValue):
        # The condition is not a mid-circuit measurement. This will also work
        # when the condition is a mid-circuit measurement but qml.capture.enabled()
        if true_fn is None:
            # Decorator form: forward false_fn/elifs instead of silently discarding them.
            return lambda fn: CondCallable(condition, fn, false_fn, elifs)

        return CondCallable(condition, true_fn, false_fn, elifs)

    if true_fn is None:
        raise TypeError(
            "cond missing 1 required positional argument: 'true_fn'.\n"
            "Note that if the conditional includes a mid-circuit measurement, "
            "qml.cond cannot be used as a decorator.\n"
            "Instead, please use the form qml.cond(condition, true_fn, false_fn)."
        )

    if elifs:
        raise ConditionalTransformError(
            "'elif' branches are not supported when not using @qjit and with qml.capture.disabled()\n"
            "if the conditional includes mid-circuit measurements."
        )

    if not callable(true_fn):
        raise ConditionalTransformError(
            "Only operations and quantum functions with no measurements can be applied conditionally."
        )

    # We assume that the callable is an operation or a quantum function
    with_meas_err = (
        "Only quantum functions that contain no measurements can be applied conditionally."
    )

    def _apply_conditionally(qfunc, predicate, args, kwargs):
        """Record ``qfunc`` and wrap each of its operations in a ``Conditional`` on ``predicate``."""
        qscript = qml.tape.make_qscript(qfunc)(*args, **kwargs)

        if qscript.measurements:
            raise ConditionalTransformError(with_meas_err)

        for op in qscript.operations:
            if isinstance(op, MidMeasureMP):
                raise ConditionalTransformError(with_meas_err)
            # The result is discarded; presumably constructing the Conditional queues it.
            Conditional(predicate, op)

    @wraps(true_fn)
    def wrapper(*args, **kwargs):
        # Operators passed as arguments were fully constructed, and thus queued,
        # before this wrapper was called. Dequeue them so they are not applied
        # unconditionally.
        recorded_ops = [a for a in args if isinstance(a, Operator)] + [
            k for k in kwargs.values() if isinstance(k, Operator)
        ]
        if recorded_ops and QueuingManager.recording():
            for op in recorded_ops:
                QueuingManager.remove(op)

        # 1. Apply true_fn conditionally
        _apply_conditionally(true_fn, condition, args, kwargs)

        # 2. Apply false_fn conditionally, on the inverted condition
        if false_fn is not None:
            _apply_conditionally(false_fn, ~condition, args, kwargs)

    return wrapper
def _register_custom_staging_rule(cond_prim):
    """Register a jax staging rule for ``cond_prim`` that supports dynamic shapes.

    When ``jax_dynamic_shapes`` is disabled, the rule falls back to jax's default
    primitive processing. Otherwise it builds the output tracers from the true
    branch's output variables, rewriting abstract dimensions in their shapes to the
    newly created equation variables.
    """

    # see https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L3538
    # and https://github.com/jax-ml/jax/blob/9e62994bce7c7fcbb2f6a50c9ef89526cd2c2be6/jax/_src/lax/lax.py#L208
    # for reference to how jax is handling staging rules for dynamic shapes in v0.4.28
    # see also capture/intro_to_dynamic_shapes.md

    import jax  # pylint: disable=import-outside-toplevel
    from jax._src.interpreters import partial_eval as pe  # pylint: disable=import-outside-toplevel

    def _tracer_and_outvar(
        jaxpr_trace: pe.DynamicJaxprTrace,
        outvar: jax.core.Var,
        env: dict[jax.core.Var, jax.core.Var],
    ) -> tuple[pe.DynamicJaxprTracer, jax.core.Var]:
        """
        Create a new tracer and returned var from the true branch outvar

        returned vars are cached in env for use in future shapes
        """
        if not hasattr(outvar.aval, "shape"):
            out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval)
            return out_tracer, jaxpr_trace.makevar(out_tracer)
        # Static dims are kept; abstract dims are replaced by previously created vars.
        new_shape = [s if isinstance(s, int) else env[s] for s in outvar.aval.shape]
        new_aval = jax.core.DShapedArray(tuple(new_shape), outvar.aval.dtype)
        out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval)
        new_var = jaxpr_trace.makevar(out_tracer)

        if not isinstance(outvar, jax.core.Literal):
            env[outvar] = new_var
        return out_tracer, new_var

    def custom_staging_rule(
        jaxpr_trace: pe.DynamicJaxprTrace, *tracers: pe.DynamicJaxprTracer, **params
    ) -> Union[Sequence[pe.DynamicJaxprTracer], pe.DynamicJaxprTracer]:
        """
        Add new jaxpr equation to the jaxpr_trace and return new tracers.
        """
        if not jax.config.jax_dynamic_shapes:
            # fallback to normal behavior
            return jaxpr_trace.default_process_primitive(cond_prim, tracers, params)
        true_outvars = params["jaxpr_branches"][0].outvars

        env: dict[jax.core.Var, jax.core.Var] = {}  # branch var to new equation var
        if true_outvars:
            out_tracers, returned_vars = tuple(
                zip(
                    *(_tracer_and_outvar(jaxpr_trace, var, env) for var in true_outvars),
                    strict=True,
                )
            )
        else:
            out_tracers, returned_vars = (), ()

        invars = [jaxpr_trace.getvar(x) for x in tracers]
        eqn = pe.new_jaxpr_eqn(
            invars,
            returned_vars,
            cond_prim,
            params,
            jax.core.no_effects,
        )
        jaxpr_trace.frame.add_eqn(eqn)
        return out_tracers

    pe.custom_staging_rules[cond_prim] = custom_staging_rule


def _aval_mismatch_error(branch_type, branch_index, i, outval, expected_outval):
    """Raise a ``ValueError`` describing an abstract-value mismatch at output ``i``."""
    raise ValueError(
        f"Mismatch in output abstract values in {branch_type} branch "
        f"#{branch_index} at position {i}: "
        f"{outval} vs {expected_outval}."
    )


def _dynamic_avals_match(outval, expected_outval) -> bool:
    """Compare two abstract values when jax dynamic shapes are enabled.

    The types and dtypes must be equal, and the shapes must have the same rank. For each
    dimension, either both or neither must be a ``jax.core.Var``, and static integer
    dimensions must be equal. Two abstract dimensions are not compared with each other.
    """
    import jax  # pylint: disable=import-outside-toplevel

    if type(outval) != type(expected_outval):  # pylint: disable=unidiomatic-typecheck
        return False
    if getattr(outval, "dtype", None) != getattr(expected_outval, "dtype", None):
        return False

    shape1 = getattr(outval, "shape", ())
    shape2 = getattr(expected_outval, "shape", ())
    # A rank mismatch is reported as a regular mismatch rather than as an opaque
    # ``zip(strict=True)`` error.
    if len(shape1) != len(shape2):
        return False

    for s1, s2 in zip(shape1, shape2, strict=True):
        if isinstance(s1, jax.core.Var) != isinstance(s2, jax.core.Var):
            return False
        if isinstance(s1, int) and s1 != s2:
            return False
    return True


def _validate_abstract_values(
    outvals: list, expected_outvals: list, branch_type: str, branch_index: int
) -> None:
    """Ensure the collected abstract values match the expected ones.

    Args:
        outvals (list): abstract values returned by the branch being checked
        expected_outvals (list): abstract values returned by the true branch
        branch_type (str): branch label used in error messages (``"elif"`` or ``"false"``)
        branch_index (int): branch index used in error messages

    Raises:
        ValueError: if the number of outputs differs, or if any output's abstract value
            does not match the expected one.
    """
    import jax  # pylint: disable=import-outside-toplevel

    dynamic_shapes = jax.config.jax_dynamic_shapes

    if len(outvals) != len(expected_outvals):
        msg = (
            f"Mismatch in number of output variables in {branch_type} branch"
            f" #{branch_index}: {len(outvals)} vs {len(expected_outvals)} "
            f" for {outvals} and {expected_outvals}"
        )
        if dynamic_shapes:
            msg += "\n This may be due to different sized shapes when dynamic shapes are enabled."
        raise ValueError(msg)

    for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)):
        if dynamic_shapes:
            # we need to be a bit more manual with the comparison.
            mismatch = not _dynamic_avals_match(outval, expected_outval)
        else:
            mismatch = outval != expected_outval
        if mismatch:
            _aval_mismatch_error(branch_type, branch_index, i, outval, expected_outval)


def _validate_jaxpr_returns(jaxpr_branches):
    """Check that every non-true branch returns the same abstract values as the true branch.

    A ``None`` branch (no false function) is only allowed if the true branch returns nothing.

    Raises:
        ValueError: on a missing false branch or mismatched outputs.
    """
    out_avals_true = [out.aval for out in jaxpr_branches[0].outvars]

    for idx, jaxpr_branch in enumerate(jaxpr_branches):
        if idx == 0:
            continue

        if jaxpr_branch is None:
            if out_avals_true:
                raise ValueError(
                    "The false branch must be provided if the true branch returns any variables"
                )
            # this is tested, but coverage does not pick it up
            continue  # pragma: no cover

        out_avals_branch = [out.aval for out in jaxpr_branch.outvars]
        # The last entry is the false/otherwise branch; everything in between is an elif.
        branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false"
        _validate_abstract_values(out_avals_branch, out_avals_true, branch_type, idx - 1)


@functools.lru_cache
def _get_cond_qfunc_prim():
    """Get the cond primitive for quantum functions."""

    # pylint: disable=import-outside-toplevel
    from pennylane.capture.custom_primitives import NonInterpPrimitive

    cond_prim = NonInterpPrimitive("cond")
    cond_prim.multiple_results = True
    cond_prim.prim_type = "higher_order"
    _register_custom_staging_rule(cond_prim)

    @cond_prim.def_abstract_eval
    def _(*_, jaxpr_branches, **__):
        # All branches are validated to match the true branch, so its outputs suffice.
        return [out.aval for out in jaxpr_branches[0].outvars]

    @cond_prim.def_impl
    def _(*all_args, jaxpr_branches, consts_slices, args_slice):
        n_branches = len(jaxpr_branches)
        conditions = all_args[:n_branches]
        args = all_args[args_slice]

        # Find predicates that use mid-circuit measurements. We don't check the last
        # condition as that is always `True`.
        mcm_conditions = tuple(
            pred for pred in conditions[:-1] if isinstance(pred, MeasurementValue)
        )
        if len(mcm_conditions) != 0:
            if len(mcm_conditions) != len(conditions) - 1:
                raise ConditionalTransformError(
                    "Cannot use qml.cond with a combination of mid-circuit measurements "
                    "and other classical conditions as predicates."
                )
            conditions = get_mcm_predicates(mcm_conditions)

        for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
            if jaxpr is None:
                continue
            consts = all_args[const_slice]
            if isinstance(pred, qml.measurements.MeasurementValue):
                # MCM predicate: queue every op of this branch as a Conditional.
                with qml.queuing.AnnotatedQueue() as q:
                    out = qml.capture.eval_jaxpr(jaxpr, consts, *args)

                if len(out) != 0:
                    raise ConditionalTransformError(
                        "Only quantum functions without return values can be applied "
                        "conditionally with mid-circuit measurement predicates."
                    )
                for wrapped_op in q:
                    Conditional(pred, wrapped_op.obj)

            elif pred:
                # Classical predicate: run only the first branch that is true.
                return qml.capture.eval_jaxpr(jaxpr, consts, *args)

        return ()

    return cond_prim