Source code for pennylane.measurements.mid_measure
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.measure measurement.
"""
import uuid
from collections.abc import Hashable
from functools import lru_cache
from typing import Generic, Optional, TypeVar, Union

import pennylane as qml
from pennylane.wires import Wires

from .measurements import MeasurementProcess, MidMeasure


def measure(
    wires: Union[Hashable, Wires], reset: bool = False, postselect: Optional[int] = None
):
    r"""Perform a mid-circuit measurement in the computational basis on the
    supplied qubit.

    Computational basis measurements are performed using the 0, 1 convention
    rather than the ±1 convention.
    Measurement outcomes can be used to conditionally apply operations, and measurement
    statistics can be gathered and returned by a quantum function.

    If a device doesn't support mid-circuit measurements natively, then the
    QNode will apply the :func:`defer_measurements` transform.

    **Example:**

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func(x, y):
            qml.RY(x, wires=0)
            qml.CNOT(wires=[0, 1])
            m_0 = qml.measure(1)

            qml.cond(m_0, qml.RY)(y, wires=0)
            return qml.probs(wires=[0])

    Executing this QNode:

    >>> pars = np.array([0.643, 0.246], requires_grad=True)
    >>> func(*pars)
    tensor([0.90165331, 0.09834669], requires_grad=True)

    Wires can be reused after measurement. Moreover, measured wires can be reset
    to the :math:`|0 \rangle` state by setting ``reset=True``.

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func():
            qml.X(1)
            m_0 = qml.measure(1, reset=True)
            return qml.probs(wires=[1])

    Executing this QNode:

    >>> func()
    tensor([1., 0.], requires_grad=True)

    Mid-circuit measurements can be manipulated using the following arithmetic operators:
    ``+``, ``-``, ``*``, ``/``, ``~`` (not), ``&`` (and), ``|`` (or), ``==``, ``<=``,
    ``>=``, ``<``, ``>`` with other mid-circuit measurements or scalars.

    .. Note ::

        Python ``not``, ``and``, ``or``, do not work since these do not have dunder methods.
        Instead use ``~``, ``&``, ``|``.

    Mid-circuit measurement results can be processed with the usual measurement functions such as
    :func:`~.expval`. For QNodes with finite shots, :func:`~.sample` applied to a mid-circuit measurement
    result will return a binary sequence of samples.
    See :ref:`here <mid_circuit_measurements_statistics>` for more details.

    .. Note ::

        Computational basis measurements are performed using the 0, 1 convention rather than the ±1 convention.
        So, for example, ``expval(qml.measure(0))`` and ``expval(qml.Z(0))`` will give different answers.

    .. code-block:: python3

        dev = qml.device("default.qubit")

        @qml.qnode(dev)
        def circuit(x, y):
            qml.RX(x, wires=0)
            qml.RY(y, wires=1)
            m0 = qml.measure(1)
            return (
                qml.sample(m0), qml.expval(m0), qml.var(m0), qml.probs(op=m0), qml.counts(op=m0),
            )

    >>> circuit(1.0, 2.0, shots=1000)
    (array([0, 1, 1, ..., 1, 1, 1]), 0.702, 0.20919600000000002, array([0.298, 0.702]), {0: 298, 1: 702})

    Args:
        wires (Wires): The wire to measure.
        reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle`
            state after measurement.
        postselect (Optional[int]): Which basis state to postselect after a mid-circuit
            measurement. None by default. If postselection is requested, only the post-measurement
            state that is used for postselection will be considered in the remaining circuit.

    Returns:
        MidMeasureMP: measurement process instance

    Raises:
        QuantumFunctionError: if multiple wires were specified

    .. details::
        :title: Postselection

        Postselection discards outcomes that do not meet the criteria provided by the ``postselect``
        argument. For example, specifying ``postselect=1`` on wire 0 would be equivalent to projecting
        the state vector onto the :math:`|1\rangle` state on wire 0:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.sample(wires=1)

        By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the probability of
        measuring ``1`` on wire 1 after postselection should also be 1. Executing this QNode with 10 shots:

        >>> func(np.pi / 2, shots=10)
        array([1, 1, 1, 1, 1, 1, 1])

        Note that only 7 samples are returned.
        This is because samples that do not meet the postselection criteria are
        thrown away.

        If postselection is requested on a state with zero probability of being measured, the result may
        contain ``NaN`` or ``Inf`` values:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.probs(wires=1)

        >>> func(0.0)
        tensor([nan, nan], requires_grad=True)

        In the case of ``qml.sample``, an empty array will be returned:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.sample(wires=[0, 1])

        >>> func(0.0, shots=[10, 10])
        (array([], shape=(0, 2), dtype=int64), array([], shape=(0, 2), dtype=int64))

        .. note::

            Currently, postselection support is only available on ``default.qubit``. Using postselection
            on other devices will raise an error.

        .. warning::

            All measurements are supported when using postselection. However, postselection on a zero probability
            state can cause some measurements to break:

            * With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these
              measurements will raise errors if there are no valid samples after postselection. This will occur
              with postselection states that have zero or close to zero probability.

            * With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except
              ``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the
              postselection state has zero probability.

            * When using JIT, ``QNode``'s may have unexpected behaviour when postselection on a zero
              probability state is performed. Due to floating point precision, the zero probability may not be
              detected, thus letting execution continue as normal without ``NaN`` or ``Inf`` values or empty
              samples, leading to unexpected or incorrect results.

    """
    if qml.capture.enabled():
        primitive = _create_mid_measure_primitive()
        return primitive.bind(wires, reset=reset, postselect=postselect)

    return _measure_impl(wires, reset=reset, postselect=postselect)


def _measure_impl(
    wires: Union[Hashable, Wires], reset: Optional[bool] = False, postselect: Optional[int] = None
):
    """Concrete implementation of qml.measure"""
    wires = Wires(wires)
    if len(wires) > 1:
        raise qml.QuantumFunctionError(
            "Only a single qubit can be measured in the middle of the circuit"
        )

    # Create a UUID and a map between MP and MV to support serialization
    measurement_id = str(uuid.uuid4())
    mp = MidMeasureMP(wires=wires, reset=reset, postselect=postselect, id=measurement_id)
    return MeasurementValue([mp], processing_fn=lambda v: v)


@lru_cache
def _create_mid_measure_primitive():
    """Create a primitive corresponding to a mid-circuit measurement type.

    Called when using :func:`~pennylane.measure`.

    Returns:
        jax.core.Primitive: A new jax primitive corresponding to a mid-circuit
        measurement.

    """
    # pylint: disable=import-outside-toplevel
    import jax

    from pennylane.capture.custom_primitives import NonInterpPrimitive

    mid_measure_p = NonInterpPrimitive("measure")

    @mid_measure_p.def_impl
    def _(wires, reset=False, postselect=None):
        return _measure_impl(wires, reset=reset, postselect=postselect)

    @mid_measure_p.def_abstract_eval
    def _(*_, **__):
        dtype = jax.numpy.int64 if jax.config.jax_enable_x64 else jax.numpy.int32
        return jax.core.ShapedArray((), dtype)

    return mid_measure_p


T = TypeVar("T")
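
Because `_create_mid_measure_primitive` is wrapped in `functools.lru_cache` and takes no arguments, repeated calls should hand back the same object instead of building a new primitive each time. The sketch below shows only that caching behaviour in plain Python. `make_primitive` and the `object()` stand-in are hypothetical names for illustration and do not involve JAX or PennyLane.

```python
from functools import lru_cache


@lru_cache
def make_primitive():
    """Hypothetical stand-in for a factory that builds an expensive object once."""
    # A bare object() plays the role of the primitive here (assumption for illustration).
    return object()


first = make_primitive()
second = make_primitive()

# Both calls return the cached instance, so identity comparison holds.
print(first is second)
```

With a zero-argument function, the cache holds a single entry, which makes the decorator behave like a lazily initialised singleton.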


class MidMeasureMP(MeasurementProcess):
    """Mid-circuit measurement.

    This class additionally stores information about unknown measurement outcomes in the qubit model.
    Measurements on a single qubit in the computational basis are assumed.

    Please refer to :func:`pennylane.measure` for detailed documentation.

    Args:
        wires (.Wires): The wires the measurement process applies to.
            This can only be specified if an observable was not provided.
        reset (bool): Whether to reset the wire after measurement.
        postselect (Optional[int]): Which basis state to postselect after a mid-circuit
            measurement. None by default. If postselection is requested, only the post-measurement
            state that is used for postselection will be considered in the remaining circuit.
        id (str): Custom label given to a measurement instance.
    """

    _shortname = MidMeasure  #! Note: deprecated. Change the value to "measure" in v0.42

    def _flatten(self):
        metadata = (("wires", self.raw_wires), ("reset", self.reset), ("id", self.id))
        return (None, None), metadata

    def __init__(
        self,
        wires: Optional[Wires] = None,
        reset: Optional[bool] = False,
        postselect: Optional[int] = None,
        id: Optional[str] = None,
    ):
        self.batch_size = None
        super().__init__(wires=Wires(wires), id=id)
        self.reset = reset
        self.postselect = postselect

    # pylint: disable=arguments-renamed, arguments-differ
    @classmethod
    def _primitive_bind_call(cls, wires=None, reset=False, postselect=None, id=None):
        wires = () if wires is None else wires
        return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect, id=id)

    @classmethod
    def _abstract_eval(
        cls,
        n_wires: Optional[int] = None,
        has_eigvals=False,
        shots: Optional[int] = None,
        num_device_wires: int = 0,
    ) -> tuple:
        return (), int

    def label(self, decimals=None, base_label=None, cache=None):  # pylint: disable=unused-argument
        r"""How the mid-circuit measurement is represented in diagrams and drawings.

        Args:
            decimals=None (Int): If ``None``, no parameters are included. Else,
                how to round the parameters.
            base_label=None (Iterable[str]): overwrite the non-parameter component of the label.
                Must be same length as ``obs`` attribute.
            cache=None (dict): dictionary that carries information between label calls
                in the same drawing

        Returns:
            str: label to use in drawings
        """
        _label = "┤↗"
        if self.postselect is not None:
            _label += "₁" if self.postselect == 1 else "₀"

        _label += "├" if not self.reset else "│ │0⟩"

        return _label

    @property
    def samples_computational_basis(self):
        return False

    @property
    def _queue_category(self):
        return "_ops"

    @property
    def hash(self):
        """int: Returns an integer hash uniquely representing the measurement process"""
        fingerprint = (
            self.__class__.__name__,
            tuple(self.wires.tolist()),
            self.id,
        )

        return hash(fingerprint)

    @property
    def data(self):
        """The data of the measurement. Needed to match the Operator API."""
        return []

    @property
    def name(self):
        """The name of the measurement. Needed to match the Operator API."""
        return self.__class__.__name__

    @property
    def num_params(self):
        """The number of parameters. Needed to match the Operator API."""
        return 0


class MeasurementValue(Generic[T]):
    """A class representing unknown measurement outcomes in the qubit model.

    Measurements on a single qubit in the computational basis are assumed.

    Args:
        measurements (list[.MidMeasureMP]): The measurement(s) that this object depends on.
        processing_fn (callable): A lazy transformation applied to the measurement values.
    """

    name = "MeasurementValue"

    def __init__(self, measurements, processing_fn):
        self.measurements = measurements
        self.processing_fn = processing_fn

    def items(self):
        """A generator representing all the possible outcomes of the MeasurementValue."""
        num_meas = len(self.measurements)
        for i in range(2**num_meas):
            branch = tuple(int(b) for b in f"{i:0{num_meas}b}")
            yield branch, self.processing_fn(*branch)

    def postselected_items(self):
        """A generator representing all the possible outcomes of the MeasurementValue,
        taking postselection into account."""
        # pylint: disable=stop-iteration-return
        ps = {i: p for i, m in enumerate(self.measurements) if (p := m.postselect) is not None}
        num_non_ps = len(self.measurements) - len(ps)
        if num_non_ps == 0:
            yield (), self.processing_fn(*ps.values())
            return
        for i in range(2**num_non_ps):
            # Create the branch ignoring postselected measurements
            non_ps_branch = tuple(int(b) for b in f"{i:0{num_non_ps}b}")
            # We want a consumable iterable and the static tuple above
            non_ps_branch_iter = iter(non_ps_branch)
            # Extend the branch to include postselected measurements
            full_branch = tuple(
                ps[j] if j in ps else next(non_ps_branch_iter)
                for j in range(len(self.measurements))
            )
            # Return the reduced non-postselected branch and the processing function
            # evaluated on the full branch
            yield non_ps_branch, self.processing_fn(*full_branch)

    @property
    def wires(self):
        """Returns a list of wires corresponding to the mid-circuit measurements."""
        return Wires.all_wires([m.wires for m in self.measurements])

    @property
    def branches(self):
        """A dictionary representing all possible outcomes of the MeasurementValue."""
        ret_dict = {}
        num_meas = len(self.measurements)
        for i in range(2**num_meas):
            branch = tuple(int(b) for b in f"{i:0{num_meas}b}")
            ret_dict[branch] = self.processing_fn(*branch)
        return ret_dict

    def map_wires(self, wire_map):
        """Returns a copy of the current ``MeasurementValue`` with the wires of each measurement changed
        according to the given wire map.

        Args:
            wire_map (dict): dictionary containing the old wires as keys and the new wires as values

        Returns:
            MeasurementValue: new ``MeasurementValue`` instance with measurement wires mapped
        """
        mapped_measurements = [m.map_wires(wire_map) for m in self.measurements]
        return MeasurementValue(mapped_measurements, self.processing_fn)

    def _transform_bin_op(self, base_bin, other):
        """Helper function for defining dunder binary operations."""
        if isinstance(other, MeasurementValue):
            # pylint: disable=protected-access
            return self._merge(other)._apply(lambda t: base_bin(t[0], t[1]))
        # if `other` is not a MeasurementValue then apply it to each branch
        return self._apply(lambda v: base_bin(v, other))

    def __invert__(self):
        """Return a copy of the measurement value with an inverted control
        value."""
        return self._apply(qml.math.logical_not)

    def __bool__(self) -> bool:
        raise ValueError(
            "The truth value of a MeasurementValue is undefined. To condition on a MeasurementValue, please use qml.cond instead."
        )

    def __eq__(self, other):
        return self._transform_bin_op(lambda a, b: a == b, other)

    def __ne__(self, other):
        return self._transform_bin_op(lambda a, b: a != b, other)

    def __add__(self, other):
        return self._transform_bin_op(lambda a, b: a + b, other)

    def __radd__(self, other):
        return self._apply(lambda v: other + v)

    def __sub__(self, other):
        return self._transform_bin_op(lambda a, b: a - b, other)

    def __rsub__(self, other):
        return self._apply(lambda v: other - v)

    def __mul__(self, other):
        return self._transform_bin_op(lambda a, b: a * b, other)

    def __rmul__(self, other):
        return self._apply(lambda v: other * qml.math.cast_like(v, other))

    def __truediv__(self, other):
        return self._transform_bin_op(lambda a, b: a / b, other)

    def __rtruediv__(self, other):
        return self._apply(lambda v: other / v)

    def __lt__(self, other):
        return self._transform_bin_op(lambda a, b: a < b, other)

    def __le__(self, other):
        return self._transform_bin_op(lambda a, b: a <= b, other)

    def __gt__(self, other):
        return self._transform_bin_op(lambda a, b: a > b, other)

    def __ge__(self, other):
        return self._transform_bin_op(lambda a, b: a >= b, other)

    def __and__(self, other):
        return self._transform_bin_op(qml.math.logical_and, other)

    def __or__(self, other):
        return self._transform_bin_op(qml.math.logical_or, other)

    def __mod__(self, other):
        return self._transform_bin_op(qml.math.mod, other)

    def __xor__(self, other):
        return self._transform_bin_op(qml.math.logical_xor, other)

    def _apply(self, fn):
        """Apply a post computation to this measurement"""
        return MeasurementValue(self.measurements, lambda *x: fn(self.processing_fn(*x)))

    def concretize(self, measurements: dict):
        """Returns a concrete value from a dictionary of hashes with concrete values."""
        values = tuple(measurements[meas] for meas in self.measurements)
        return self.processing_fn(*values)

    def _merge(self, other: "MeasurementValue"):
        """Merge two measurement values"""

        # create a new merged list with no duplicates and in lexical ordering
        merged_measurements = list(set(self.measurements).union(set(other.measurements)))
        merged_measurements.sort(key=lambda m: m.id)

        # create a new function that selects the correct indices for each sub function
        def merged_fn(*x):
            sub_args_1 = (x[i] for i in [merged_measurements.index(m) for m in self.measurements])
            sub_args_2 = (x[i] for i in [merged_measurements.index(m) for m in other.measurements])

            out_1 = self.processing_fn(*sub_args_1)
            out_2 = other.processing_fn(*sub_args_2)

            return out_1, out_2

        return MeasurementValue(merged_measurements, merged_fn)

    def __getitem__(self, i):
        branch = tuple(int(b) for b in f"{i:0{len(self.measurements)}b}")
        return self.processing_fn(*branch)

    def __str__(self):
        lines = []
        num_meas = len(self.measurements)
        for i in range(2**num_meas):
            branch = tuple(int(b) for b in f"{i:0{num_meas}b}")
            id_branch_mapping = [
                f"{self.measurements[j].id}={branch[j]}" for j in range(len(branch))
            ]
            lines.append(
                "if " + ",".join(id_branch_mapping) + " => " + str(self.processing_fn(*branch))
            )
        return "\n".join(lines)

    def __repr__(self):
        return f"MeasurementValue(wires={self.wires.tolist()})"
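
The methods `items`, `branches`, `__getitem__` and `__str__` all enumerate outcomes by formatting an integer as a zero-padded binary string, and `_apply` builds new values by composing functions lazily. The following is a minimal standalone sketch of both ideas in plain Python, with no PennyLane dependency. The names `enumerate_branches` and `compose` are hypothetical helpers introduced here for illustration only.

```python
def enumerate_branches(num_meas, processing_fn):
    """Yield (branch, value) pairs for every outcome of ``num_meas`` binary measurements."""
    for i in range(2**num_meas):
        # Format i as a zero-padded binary string, e.g. 2 -> "10" when num_meas == 2.
        branch = tuple(int(b) for b in f"{i:0{num_meas}b}")
        yield branch, processing_fn(*branch)


def compose(processing_fn, fn):
    """Hypothetical analogue of ``_apply``: wrap an existing function in a post-computation."""
    return lambda *x: fn(processing_fn(*x))


# Example processing function (assumption): the sum of two measurement outcomes.
total = lambda a, b: a + b
branches = dict(enumerate_branches(2, total))

# Nothing is evaluated until a concrete branch is supplied.
is_two = compose(total, lambda v: v == 2)
flags = dict(enumerate_branches(2, is_two))
print(branches)
print(flags)
```

The branch tuples are ordered so that the first measurement is the most significant bit, which matches the indexing used by `__getitem__` in the listing.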


def get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[MeasurementValue]:
    r"""Function to make mid-circuit measurement predicates mutually exclusive.

    The ``conditions`` are predicates to the ``if`` and ``elif`` branches of ``qml.cond``.
    This function updates all the ``MeasurementValue``\ s in ``conditions`` such that
    reconciling the correct branch is never ambiguous.

    Args:
        conditions (Sequence[MeasurementValue]): Sequence containing predicates for ``if``
            and all ``elif`` branches of a function decorated with :func:`~pennylane.cond`.

    Returns:
        Sequence[MeasurementValue]: Updated sequence of mutually exclusive predicates.
    """
    new_conds = [conditions[0]]
    false_cond = ~conditions[0]

    for c in conditions[1:]:
        new_conds.append(false_cond & c)
        false_cond = false_cond & ~c

    new_conds.append(false_cond)
    return new_conds
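
The loop in `get_mcm_predicates` can be checked on ordinary booleans: each new predicate is "all earlier conditions were false, and this one is true", and the final entry covers the `else` case. The sketch below is a plain-Python analogue that works on already-evaluated truth values rather than `MeasurementValue` objects. `make_exclusive` is a hypothetical name used only for this illustration.

```python
from itertools import product


def make_exclusive(conditions):
    """Plain-boolean analogue of the predicate rewriting, for evaluated conditions."""
    new_conds = [conditions[0]]
    false_cond = not conditions[0]

    for c in conditions[1:]:
        # This branch fires only if every earlier condition was false.
        new_conds.append(false_cond and c)
        false_cond = false_cond and not c

    # The trailing entry corresponds to the ``else`` branch.
    new_conds.append(false_cond)
    return new_conds


# For every assignment of three conditions, exactly one rewritten predicate is true.
counts = [sum(make_exclusive(conds)) for conds in product([False, True], repeat=3)]
example = make_exclusive((False, True, True))
print(example)
```

Note that the result has one more entry than the input, since the `else` predicate is appended after the `if` and `elif` predicates.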


def find_post_processed_mcms(circuit):
    """Return the subset of mid-circuit measurements which are required for post-processing.

    This includes any mid-circuit measurement that is post-selected or the object of a terminal
    measurement.
    """
    post_processed_mcms = set(
        op
        for op in circuit.operations
        if isinstance(op, MidMeasureMP) and op.postselect is not None
    )
    for m in circuit.measurements:
        if isinstance(m.mv, list):
            for mv in m.mv:
                post_processed_mcms = post_processed_mcms | set(mv.measurements)
        elif m.mv is not None:
            post_processed_mcms = post_processed_mcms | set(m.mv.measurements)
    return post_processed_mcms