Source code for pennylane.math.interface_utils

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions related to interfaces"""

import warnings
from enum import Enum
from typing import Literal, Union

import autoray as ar


class Interface(Enum):
    """Canonical set of interfaces supported."""

    AUTOGRAD = "autograd"
    NUMPY = "numpy"
    TORCH = "torch"
    JAX = "jax"
    JAX_JIT = "jax-jit"
    TF = "tf"
    TF_AUTOGRAPH = "tf-autograph"
    AUTO = "auto"

    def get_like(self):
        """Maps canonical set of interfaces to those known by autoray."""
        mapping = {
            Interface.AUTOGRAD: "autograd",
            Interface.NUMPY: "numpy",
            Interface.TORCH: "torch",
            Interface.JAX: "jax",
            Interface.JAX_JIT: "jax",
            Interface.TF: "tensorflow",
            Interface.TF_AUTOGRAPH: "tensorflow",
            Interface.AUTO: None,
        }
        return mapping[self]


InterfaceLike = Union[str, Interface, None]


INTERFACE_MAP = {
    None: Interface.NUMPY,
    "auto": Interface.AUTO,
    "autograd": Interface.AUTOGRAD,
    "numpy": Interface.NUMPY,
    "scipy": Interface.NUMPY,
    "jax": Interface.JAX,
    "jax-jit": Interface.JAX_JIT,
    "jax-python": Interface.JAX,
    "JAX": Interface.JAX,
    "torch": Interface.TORCH,
    "pytorch": Interface.TORCH,
    "tf": Interface.TF,
    "tensorflow": Interface.TF,
    "tensorflow-autograph": Interface.TF_AUTOGRAPH,
    "tf-autograph": Interface.TF_AUTOGRAPH,
}
"""dict[str, str]: maps an allowed interface specification to its canonical name."""

SupportedInterfaceUserInput = Literal[tuple(INTERFACE_MAP.keys())]
"""list[str]: allowed interface names that the user can input"""

SUPPORTED_INTERFACE_NAMES = list(Interface)
"""list[Interface]: allowed interface names"""


[docs]def get_interface(*values): """Determines the correct framework to dispatch to given a tensor-like object or a sequence of tensor-like objects. Args: *values (tensor_like): variable length argument list with single tensor-like objects Returns: str: the name of the interface To determine the framework to dispatch to, the following rules are applied: * Tensors that are incompatible (such as Torch, TensorFlow and Jax tensors) cannot both be present. * Autograd tensors *may* be present alongside Torch, TensorFlow and Jax tensors, but Torch, TensorFlow and Jax take precedence; the autograd arrays will be treated as non-differentiable NumPy arrays. A warning will be raised suggesting that vanilla NumPy be used instead. * Vanilla NumPy arrays and SciPy sparse matrices can be used alongside other tensor objects; they will always be treated as non-differentiable constants. .. warning:: ``get_interface`` defaults to ``"numpy"`` whenever Python built-in objects are passed. I.e. a list or tuple of ``torch`` tensors will be identified as ``"numpy"``: >>> get_interface([torch.tensor([1]), torch.tensor([1])]) "numpy" The correct usage in that case is to unpack the arguments ``get_interface(*[torch.tensor([1]), torch.tensor([1])])``. """ if len(values) == 1: return _get_interface_of_single_tensor(values[0]) interfaces = {_get_interface_of_single_tensor(v) for v in values} if len(interfaces - {"numpy", "scipy", "autograd"}) > 1: # contains multiple non-autograd interfaces raise ValueError("Tensors contain mixed types; cannot determine dispatch library") non_numpy_scipy_interfaces = set(interfaces) - {"numpy", "scipy"} if len(non_numpy_scipy_interfaces) > 1: # contains autograd and another interface warnings.warn( f"Contains tensors of types {non_numpy_scipy_interfaces}; dispatch will prioritize " "TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.", UserWarning, ) if "tensorflow" in interfaces: return "tensorflow" if "torch" in interfaces: return "torch" if "jax" in interfaces: return "jax" if "autograd" in interfaces: return "autograd" return "numpy"
def _get_interface_of_single_tensor(tensor): """Returns the name of the package that any array/tensor manipulations will dispatch to. The returned strings correspond to those used for PennyLane :doc:`interfaces </introduction/interfaces>`. Args: tensor (tensor_like): tensor input Returns: str: name of the interface **Example** >>> x = torch.tensor([1., 2.]) >>> get_interface(x) 'torch' >>> from pennylane import numpy as np >>> x = np.array([4, 5], requires_grad=True) >>> get_interface(x) 'autograd' """ namespace = tensor.__class__.__module__.split(".")[0] if namespace in ("pennylane", "autograd"): return "autograd" res = ar.infer_backend(tensor) if res == "builtins": return "numpy" return res
[docs]def get_deep_interface(value): """ Given a deep data structure with interface-specific scalars at the bottom, return their interface name. Args: value (list, tuple): A deep list-of-lists, tuple-of-tuples, or combination with interface-specific data hidden within it Returns: str: The name of the interface deep within the value **Example** >>> x = [[jax.numpy.array(1), jax.numpy.array(2)], [jax.numpy.array(3), jax.numpy.array(4)]] >>> get_deep_interface(x) 'jax' This can be especially useful when converting to the appropriate interface: >>> qml.math.asarray(x, like=qml.math.get_deep_interface(x)) Array([[1, 2], [3, 4]], dtype=int64) """ itr = value while isinstance(itr, (list, tuple)): if len(itr) == 0: return "numpy" itr = itr[0] return _get_interface_of_single_tensor(itr)
[docs]def get_canonical_interface_name(user_input: InterfaceLike) -> Interface: """Helper function to get the canonical interface name. Args: interface (str, Interface): reference interface Raises: ValueError: key does not exist in the interface map Returns: Interface: canonical interface """ if user_input in SUPPORTED_INTERFACE_NAMES: return user_input try: return INTERFACE_MAP[user_input] except KeyError as exc: raise ValueError( f"Unknown interface {user_input}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}." ) from exc