# Source code for pennylane.capture.capture_operators
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines the abstract classes and primitives for capturing operators.
"""

from functools import lru_cache
from typing import Optional, Type

import pennylane as qml

has_jax = True
try:
    import jax
except ImportError:
    has_jax = False


@lru_cache  # construct the first time lazily
def _get_abstract_operator() -> type:
    """Create an AbstractOperator once in a way protected from lack of a jax install.

    The class is defined inside this function so that merely importing the module does not
    require jax; ``lru_cache`` ensures the class is only built once and the same type object
    is returned on every subsequent call.

    Returns:
        type: the ``AbstractOperator`` class, a subclass of ``jax.core.AbstractValue``.

    Raises:
        ImportError: if jax is not installed.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError("Jax is required for plxpr.")  # pragma: no cover

    class AbstractOperator(jax.core.AbstractValue):
        """An operator captured into plxpr."""

        # pylint: disable=missing-function-docstring
        def at_least_vspace(self):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def join(self, other):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def update(self, **kwargs):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        def __eq__(self, other):
            # All abstract operators are interchangeable: no per-instance state is tracked.
            return isinstance(other, AbstractOperator)

        def __hash__(self):
            # Constant hash, consistent with the class-wide equality above.
            return hash("AbstractOperator")

        # The underscore-prefixed static methods below are arithmetic hooks; presumably they
        # are dispatched from jax ``Tracer`` dunder methods (``__matmul__``, ``__mul__``, ...)
        # when the tracer's aval is an AbstractOperator.
        # NOTE(review): confirm the dispatch against the pinned jax version.

        @staticmethod
        def _matmul(*args):
            """Operator product ``a @ b``."""
            return qml.prod(*args)

        @staticmethod
        def _mul(a, b):
            """``a * b`` with ``a`` the operator; ``qml.s_prod`` takes the scalar first."""
            return qml.s_prod(b, a)

        @staticmethod
        def _rmul(a, b):
            """Reflected ``b * a`` with ``a`` the operator; the scalar is passed first."""
            return qml.s_prod(b, a)

        @staticmethod
        def _add(a, b):
            """Operator sum ``a + b``."""
            return qml.sum(a, b)

        @staticmethod
        def _pow(a, b):
            """Operator power ``a ** b``."""
            return qml.pow(a, b)

    # Older jax versions need a mapping that tells ``raise_to_shaped`` to keep this aval as-is.
    # Newer jax releases no longer expose the registry, so only register when it exists
    # instead of failing with an AttributeError.
    if hasattr(jax.core, "raise_to_shaped_mappings"):
        jax.core.raise_to_shaped_mappings[AbstractOperator] = lambda aval, _: aval

    return AbstractOperator
def create_operator_primitive(
    operator_type: Type["qml.operation.Operator"],
) -> Optional["jax.core.Primitive"]:
    """Create a primitive corresponding to an operator type.

    Called when defining any :class:`~.Operator` subclass, and is used to set the
    ``Operator._primitive`` class property.

    Args:
        operator_type (type): a subclass of qml.operation.Operator

    Returns:
        Optional[jax.core.Primitive]: A new jax primitive with the same name as the operator subclass.
        ``None`` is returned if jax is not available.

    """
    if not has_jax:
        return None

    from .custom_primitives import NonInterpPrimitive  # pylint: disable=import-outside-toplevel

    primitive = NonInterpPrimitive(operator_type.__name__)
    primitive.prim_type = "operator"

    @primitive.def_impl
    def _(*args, **kwargs):
        # Concrete implementation: instantiate the operator directly, bypassing any custom
        # ``__call__`` by going through ``type.__call__``.
        if "n_wires" not in kwargs:
            return type.__call__(operator_type, *args, **kwargs)
        n_wires = kwargs.pop("n_wires")

        # The last ``n_wires`` positional arguments are the wires and the rest are parameters.
        # Splitting at ``len(args) - n_wires`` (instead of ``-n_wires``) keeps ``n_wires == 0``
        # correct: no wires and every argument stays a parameter. ``max(..., 0)`` keeps the
        # previous behaviour if more wires than arguments are ever reported.
        split = max(len(args) - n_wires, 0)
        # need to convert array values into integers
        # for plxpr, all wires must be integers
        # could be abstract when using tracing evaluation in interpreter
        wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in args[split:])
        args = args[:split]
        return type.__call__(operator_type, *args, wires=wires, **kwargs)

    abstract_type = _get_abstract_operator()

    @primitive.def_abstract_eval
    def _(*_, **__):
        # Every captured operator evaluates abstractly to the shared AbstractOperator type.
        return abstract_type()

    return primitive