Source code for pennylane.capture.capture_operators
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines the abstract classes and primitives for capturing operators.
"""
from functools import lru_cache
from typing import Optional, Type
import pennylane as qml
from .capture_diff import create_non_jvp_primitive
# jax is an optional dependency: remember whether the import worked so that the
# rest of the module can check ``has_jax`` instead of failing at import time.
try:
    import jax
except ImportError:
    has_jax = False
else:
    has_jax = True
@lru_cache  # built lazily on the first call, then served from the cache
def _get_abstract_operator() -> type:
    """Build the ``AbstractOperator`` class and return it.

    The class subclasses ``jax.core.AbstractValue``, so it is defined inside this
    function rather than at module level; the module can then be imported even
    when jax is not installed.

    Returns:
        type: the ``AbstractOperator`` class. Because of ``lru_cache`` every call
        returns the same class object.

    Raises:
        ImportError: if jax could not be imported.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError("Jax is required for plxpr.")  # pragma: no cover

    class AbstractOperator(jax.core.AbstractValue):
        """An operator captured into plxpr."""

        # -- identity: every instance compares equal and shares one hash value --

        def __hash__(self):
            return hash("AbstractOperator")

        def __eq__(self, other):
            return isinstance(other, AbstractOperator)

        # -- arithmetic helpers, each delegating to the matching pennylane function --
        # NOTE(review): these names look like the hooks jax tracers dispatch their
        # arithmetic dunder methods to -- confirm against jax.core.Tracer.

        @staticmethod
        def _add(a, b):
            return qml.sum(a, b)

        @staticmethod
        def _mul(a, b):
            # the second argument is handed to ``qml.s_prod`` first
            return qml.s_prod(b, a)

        @staticmethod
        def _rmul(a, b):
            # same argument order as ``_mul``
            return qml.s_prod(b, a)

        @staticmethod
        def _matmul(*args):
            return qml.prod(*args)

        @staticmethod
        def _pow(a, b):
            return qml.pow(a, b)

        # -- methods that have no definition for operators yet --

        # pylint: disable=missing-function-docstring
        def at_least_vspace(self):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def join(self, other):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

        # pylint: disable=missing-function-docstring
        def update(self, **kwargs):
            # TODO: investigate the proper definition of this method
            raise NotImplementedError

    def _unchanged(aval, _):
        # identity handler: the abstract operator is returned as-is
        return aval

    jax.core.raise_to_shaped_mappings[AbstractOperator] = _unchanged

    return AbstractOperator
def create_operator_primitive(
    operator_type: Type["qml.operation.Operator"],
) -> Optional["jax.core.Primitive"]:
    """Create a primitive corresponding to an operator type.

    Called when defining any :class:`~.Operator` subclass, and is used to set the
    ``Operator._primitive`` class property.

    The primitive's concrete implementation rebuilds the operator. If the keyword
    ``n_wires`` is present, the trailing ``n_wires`` positional arguments are treated
    as wires: they are cast to ``int`` and forwarded as the ``wires`` keyword, while the
    leading positional arguments are forwarded unchanged. With ``n_wires=0`` no
    positional argument is treated as a wire and ``wires=()`` is forwarded. Without
    ``n_wires`` all arguments are forwarded as given.

    Args:
        operator_type (type): a subclass of qml.operation.Operator

    Returns:
        Optional[jax.core.Primitive]: A new jax primitive with the same name as the operator subclass.
        ``None`` is returned if jax is not available.
    """
    if not has_jax:
        return None

    primitive = create_non_jvp_primitive()(operator_type.__name__)

    @primitive.def_impl
    def _(*args, **kwargs):
        # NOTE(review): ``type.__call__`` is presumably used instead of
        # ``operator_type(...)`` to bypass a metaclass ``__call__`` that would bind
        # the primitive again -- confirm against the Operator metaclass.
        if "n_wires" not in kwargs:
            return type.__call__(operator_type, *args, **kwargs)
        n_wires = kwargs.pop("n_wires")

        if n_wires == 0:
            # A slice such as ``args[-0:]`` (or ``args[None:]``) selects *every*
            # argument, so the zero-wire case is handled explicitly: everything is
            # a parameter and there are no wires.
            params, wire_args = args, ()
        else:
            params, wire_args = args[:-n_wires], args[-n_wires:]

        # need to convert array values into integers
        # for plxpr, all wires must be integers
        wires = tuple(int(w) for w in wire_args)
        return type.__call__(operator_type, *params, wires=wires, **kwargs)

    abstract_type = _get_abstract_operator()

    @primitive.def_abstract_eval
    def _(*_, **__):
        # the abstract value does not depend on the inputs
        return abstract_type()

    return primitive
_modules/pennylane/capture/capture_operators
Download Python script
Download Notebook
View on GitHub