# Source code for pennylane.capture.capture_measurements
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines the abstract classes and primitives for capturing measurements.
"""

from collections.abc import Callable
from functools import lru_cache
from typing import Optional, Type

import pennylane as qml

# jax is optional: record whether it is importable instead of failing at module import time.
has_jax = True
try:
    import jax
except ImportError:
    has_jax = False


@lru_cache
def _get_abstract_measurement():
    """Create (on first call) and return the ``AbstractMeasurement`` abstract value class.

    The class subclasses ``jax.core.AbstractValue``, so it is defined lazily inside this
    function rather than at module level, where jax may be unavailable. Because the function
    takes no arguments and is wrapped in ``lru_cache``, every call returns the same class
    object.

    Returns:
        type: the ``AbstractMeasurement`` class.

    Raises:
        ImportError: if jax is not installed.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError("Jax is required for plxpr.")  # pragma: no cover

    class AbstractMeasurement(jax.core.AbstractValue):
        """An abstract measurement.

        Args:
            abstract_eval (Callable): See :meth:`~.MeasurementProcess._abstract_eval`. A function of
                ``n_wires``, ``has_eigvals``, ``num_device_wires`` and ``shots`` that returns a shape
                and numeric type.
            n_wires=None (Optional[int]): the number of wires
            has_eigvals=False (bool): Whether or not the measurement contains eigenvalues in a wires+eigvals
                diagonal representation.

        """

        def __init__(
            self, abstract_eval: Callable, n_wires: Optional[int] = None, has_eigvals: bool = False
        ):
            self._abstract_eval = abstract_eval
            self._n_wires = n_wires
            self.has_eigvals: bool = has_eigvals

        def abstract_eval(self, num_device_wires: int, shots: int) -> tuple[tuple, type]:
            """Calculate the shape and dtype for an evaluation with specified number of device
            wires and shots.

            """
            return self._abstract_eval(
                n_wires=self._n_wires,
                has_eigvals=self.has_eigvals,
                num_device_wires=num_device_wires,
                shots=shots,
            )

        @property
        def n_wires(self) -> Optional[int]:
            """The number of wires for a wire based measurement.

            Options are:
            * ``None``: The measurement is observable based or single mcm based
            * ``0``: The measurement is broadcasted across all available device wires
            * ``int>0``: A wire or mcm based measurement with specified wires or mid circuit measurements.

            """
            return self._n_wires

        def __repr__(self):
            if self.has_eigvals:
                return f"AbstractMeasurement(n_wires={self.n_wires}, has_eigvals=True)"
            return f"AbstractMeasurement(n_wires={self.n_wires})"

        # pylint: disable=missing-function-docstring
        def at_least_vspace(self):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def join(self, other):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def update(self, **kwargs):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # Every AbstractMeasurement compares equal to (and hashes like) every other one,
        # regardless of ``n_wires`` or ``has_eigvals``.
        def __eq__(self, other):
            return isinstance(other, AbstractMeasurement)

        def __hash__(self):
            return hash("AbstractMeasurement")

    # ``raise_to_shaped_mappings`` is not present in every jax release. Register the identity
    # mapping when the table exists, and skip it otherwise instead of raising AttributeError.
    raise_to_shaped_mappings = getattr(jax.core, "raise_to_shaped_mappings", None)
    if raise_to_shaped_mappings is not None:
        raise_to_shaped_mappings[AbstractMeasurement] = lambda aval, _: aval

    return AbstractMeasurement
def create_measurement_obs_primitive(
    measurement_type: Type["qml.measurements.MeasurementProcess"], name: str
) -> Optional["jax.core.Primitive"]:
    """Build the jax primitive that captures ``measurement_type`` applied to an operator.

    This is called by default when a class inheriting from :class:`~.MeasurementProcess` is
    defined, and the result is used to set its ``MeasurementProcess._obs_primitive`` property.
    The abstract input of the primitive is an operator.

    Args:
        measurement_type (type): a subclass of :class:`~.MeasurementProcess`
        name (str): the preferred string name for the class, e.g. ``"expval"``. The primitive
            is named ``name + "_obs"``.

    Returns:
        Optional[jax.core.Primitive]: the new primitive, or ``None`` when jax is not available.
    """
    if not has_jax:
        return None

    # pylint: disable=import-outside-toplevel
    from .custom_primitives import NonInterpPrimitive

    obs_primitive = NonInterpPrimitive(name + "_obs")
    obs_primitive.prim_type = "measurement"

    def _concrete_impl(obs, **kwargs):
        # ``type.__call__`` builds the instance directly. Presumably this bypasses a custom
        # metaclass ``__call__`` on MeasurementProcess; confirm against the metaclass.
        return type.__call__(measurement_type, obs=obs, **kwargs)

    obs_primitive.def_impl(_concrete_impl)

    measurement_aval_cls = _get_abstract_measurement()

    def _abstract_impl(*_, **__):
        # Observable-based measurements have no wires of their own, so ``n_wires`` is None.
        return measurement_aval_cls(
            measurement_type._abstract_eval,  # pylint: disable=protected-access
            n_wires=None,
        )

    obs_primitive.def_abstract_eval(_abstract_impl)

    return obs_primitive
def create_measurement_mcm_primitive(
    measurement_type: Type["qml.measurements.MeasurementProcess"], name: str
) -> Optional["jax.core.Primitive"]:
    """Build the jax primitive that captures ``measurement_type`` applied to mid-circuit measurements.

    This is called by default when a class inheriting from :class:`~.MeasurementProcess` is
    defined, and the result is used to set its ``MeasurementProcess._mcm_primitive`` property.
    The abstract inputs of the primitive are classical mid-circuit measurement results.

    Args:
        measurement_type (type): a subclass of :class:`~.MeasurementProcess`
        name (str): the preferred string name for the class, e.g. ``"expval"``. The primitive
            is named ``name + "_mcm"``.

    Returns:
        Optional[jax.core.Primitive]: the new primitive, or ``None`` when jax is not available.
    """
    if not has_jax:
        return None

    # pylint: disable=import-outside-toplevel
    from .custom_primitives import NonInterpPrimitive

    mcm_primitive = NonInterpPrimitive(name + "_mcm")
    mcm_primitive.prim_type = "measurement"

    def _concrete_impl(*mcms, single_mcm=True, **kwargs):
        # With ``single_mcm`` the measurement receives the lone result itself.
        # Otherwise it receives the whole tuple of results.
        if single_mcm:
            obs = mcms[0]
        else:
            obs = mcms
        return type.__call__(measurement_type, obs=obs, **kwargs)

    mcm_primitive.def_impl(_concrete_impl)

    measurement_aval_cls = _get_abstract_measurement()

    def _abstract_impl(*mcms, **__):
        # NOTE(review): this gives n_wires == 1 even for a single mcm (single_mcm=True). The
        # AbstractMeasurement.n_wires docstring lists single-mcm measurements under ``None``.
        # Confirm which one the ``_abstract_eval`` implementations expect.
        return measurement_aval_cls(
            measurement_type._abstract_eval,  # pylint: disable=protected-access
            n_wires=len(mcms),
        )

    mcm_primitive.def_abstract_eval(_abstract_impl)

    return mcm_primitive
def create_measurement_wires_primitive(
    measurement_type: Type["qml.measurements.MeasurementProcess"], name: str
) -> Optional["jax.core.Primitive"]:
    """Create a primitive corresponding to the input type where the abstract inputs are the wires.

    Called by default when defining any class inheriting from :class:`~.MeasurementProcess`, and
    is used to set the ``MeasurementProcess._wires_primitive`` property.

    The primitive's positional arguments are the wires. When it is bound with
    ``has_eigvals=True``, the *last* positional argument is the eigenvalues instead of a wire.

    Args:
        measurement_type (type): a subclass of :class:`~.MeasurementProcess`
        name (str): the preferred string name for the class. For example, ``"expval"``.
            ``"_wires"`` is appended to this name for the name of the primitive.

    Returns:
        Optional[jax.core.Primitive]: A new jax primitive. ``None`` is returned if jax is not
        available.
    """
    if not has_jax:
        return None

    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel

    primitive = NonInterpPrimitive(name + "_wires")
    primitive.prim_type = "measurement"

    @primitive.def_impl
    def _(*args, has_eigvals=False, **kwargs):
        if has_eigvals:
            # Trailing argument carries the eigenvalues; everything before it is a wire.
            wires = qml.wires.Wires(args[:-1])
            kwargs["eigvals"] = args[-1]
        else:
            wires = qml.wires.Wires(args)
        return type.__call__(measurement_type, wires=wires, **kwargs)

    abstract_type = _get_abstract_measurement()

    @primitive.def_abstract_eval
    def _(*args, has_eigvals=False, **_):
        abstract_eval = measurement_type._abstract_eval  # pylint: disable=protected-access
        # Exclude the trailing eigvals argument from the wire count when present.
        n_wires = len(args) - 1 if has_eigvals else len(args)
        return abstract_type(abstract_eval, n_wires=n_wires, has_eigvals=has_eigvals)

    return primitive