Source code for pennylane.pytrees.pytrees

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An internal module for working with pytrees.
"""
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional

import autograd

import pennylane.queuing

has_jax = True
try:
    import jax.tree_util as jax_tree_util
except ImportError:
    has_jax = False

Leaves = Any
Metadata = Any

FlattenFn = Callable[[Any], tuple[Leaves, Metadata]]
UnflattenFn = Callable[[Leaves, Metadata], Any]


def flatten_list(obj: list):
    """Flatten a list."""
    return obj, None


def flatten_tuple(obj: tuple):
    """Flatten a tuple."""
    return obj, None


def flatten_dict(obj: dict):
    """Flatten a dictionary."""
    return obj.values(), tuple(obj.keys())


flatten_registrations: dict[type, FlattenFn] = {
    list: flatten_list,
    tuple: flatten_tuple,
    dict: flatten_dict,
}


def unflatten_list(data, _) -> list:
    """Unflatten a list."""
    return data if isinstance(data, list) else list(data)


def unflatten_tuple(data, _) -> tuple:
    """Unflatten a tuple."""
    return tuple(data)


def unflatten_dict(data, metadata) -> dict:
    """Unflatten a dictionary."""
    return dict(zip(metadata, data))


unflatten_registrations: dict[type, UnflattenFn] = {
    list: unflatten_list,
    tuple: unflatten_tuple,
    dict: unflatten_dict,
}

type_to_typename: dict[type, str] = {
    list: "builtins.list",
    dict: "builtins.dict",
    tuple: "builtins.tuple",
}

typename_to_type: dict[str, type] = {name: type_ for type_, name in type_to_typename.items()}


def _register_pytree_with_pennylane(
    pytree_type: type, typename: str, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn
):
    type_to_typename[pytree_type] = typename
    typename_to_type[typename] = pytree_type

    flatten_registrations[pytree_type] = flatten_fn
    unflatten_registrations[pytree_type] = unflatten_fn


def _register_pytree_with_jax(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn):
    def jax_unflatten(aux, parameters):
        return unflatten_fn(parameters, aux)

    jax_tree_util.register_pytree_node(pytree_type, flatten_fn, jax_unflatten)


[docs]def register_pytree( pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn, *, namespace: str = "qml" ): """Register a type with all available pytree backends. Current backends are jax and pennylane. Args: pytree_type (type): the type to register, such as ``qml.RX`` flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata. unflatten_fn (Callable): a function that reconstructs an object from its leaves and metadata. namespace (str): A prefix for the name under which this type will be registered. Returns: None Side Effects: ``pytree`` type becomes registered with available backends. .. seealso:: :func:`~.flatten`, :func:`~.unflatten`. """ typename = f"{namespace}.{pytree_type.__qualname__}" _register_pytree_with_pennylane(pytree_type, typename, flatten_fn, unflatten_fn) if has_jax: _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn)
[docs]def is_pytree(type_: type[Any]) -> bool: """Returns True if ``type_`` is a registered Pytree.""" return type_ in type_to_typename
def get_typename(pytree_type: type[Any]) -> str: """Return the typename under which ``pytree_type`` was registered. Raises: TypeError: If ``pytree_type`` is not a Pytree. >>> get_typename(list) 'builtins.list' >>> import pennylane >>> get_typename(pennylane.PauliX) 'qml.PauliX' """ try: return type_to_typename[pytree_type] except KeyError as exc: raise TypeError(f"{repr(pytree_type)} is not a Pytree type") from exc def get_typename_type(typename: str) -> type[Any]: """Return the Pytree type with given ``typename``. Raises: ValueError: If ``typename`` is not the name of a pytree type. >>> import pennylane >>> get_typename_type("builtins.list") <class 'list'> >>> get_typename_type("qml.PauliX") <class 'pennylane.ops.qubit.non_parametric_ops.PauliX'> """ try: return typename_to_type[typename] except KeyError as exc: raise ValueError(f"{repr(typename)} is not the name of a Pytree type.") from exc
[docs]@dataclass(repr=False) class PyTreeStructure: """A pytree data structure, holding the type, metadata, and child pytree structures. >>> op = qml.adjoint(qml.RX(0.1, 0)) >>> data, structure = qml.pytrees.flatten(op) >>> structure PyTree(AdjointOperation, (), [PyTree(RX, (Wires([0]), ()), [Leaf])]) A leaf is defined as just a ``PyTreeStructure`` with ``type_=None``. """ type_: Optional[type[Any]] = None """The type corresponding to the node. If ``None``, then the structure is a leaf.""" metadata: Metadata = () """Any metadata needed to reproduce the original object.""" children: list["PyTreeStructure"] = field(default_factory=list) """The children of the pytree node. Can be either other structures or terminal leaves.""" @property def is_leaf(self) -> bool: """Whether or not the structure is a leaf.""" return self.type_ is None def __repr__(self): if self.is_leaf: return "PyTreeStructure()" return f"PyTreeStructure({self.type_.__name__}, {self.metadata}, {self.children})" def __str__(self): if self.is_leaf: return "Leaf" children_string = ", ".join(str(c) for c in self.children) return f"PyTree({self.type_.__name__}, {self.metadata}, [{children_string}])"
leaf = PyTreeStructure(None, (), [])
[docs]def flatten( obj: Any, is_leaf: Optional[Callable[[Any], bool]] = None ) -> tuple[list[Any], PyTreeStructure]: """Flattens a pytree into leaves and a structure. Args: obj (Any): any object. is_leaf (Callable[[Any], bool] | None = None): an optionally specified function that will be called at each flattening step. It should return a boolean, with ``True`` stopping the traversal and the whole subtree being treated as a leaf, and ``False`` indicating the flattening should traverse the current object. Returns: List[Any], Union[Structure, Leaf]: a list of leaves and a structure representing the object See also :func:`~.unflatten`. **Example** >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) >>> data, structure = flatten(op) >>> data [1.2, 2.3, 3.4] >>> structure <PyTree(AdjointOperation, (), (<PyTree(Rot, (Wires([0]), ()), (Leaf, Leaf, Leaf))>,))> """ flatten_fn = flatten_registrations.get(type(obj), None) # set the flag is_leaf_node if is_leaf argument is provided and returns true is_leaf_node = is_leaf(obj) if is_leaf is not None else False if flatten_fn is None or is_leaf_node: return [obj], leaf leaves, metadata = flatten_fn(obj) flattened_leaves = [] child_structures = [] for l in leaves: child_leaves, child_structure = flatten(l, is_leaf) flattened_leaves += child_leaves child_structures.append(child_structure) structure = PyTreeStructure(type(obj), metadata, child_structures) return flattened_leaves, structure
[docs]def unflatten(data: list[Any], structure: PyTreeStructure) -> Any: """Bind data to a structure to reconstruct a pytree object. Args: data (Iterable): iterable of numbers and numeric arrays structure (Structure, Leaf): The pytree structure object Returns: A repacked pytree. .. seealso:: :func:`~.flatten` **Example** >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) >>> data, structure = flatten(op) >>> unflatten([-2, -3, -4], structure) Adjoint(Rot(-2, -3, -4, wires=[0])) """ with pennylane.queuing.QueuingManager.stop_recording(): return _unflatten(iter(data), structure)
def _unflatten(new_data, structure): if structure.is_leaf: return next(new_data) children = tuple(_unflatten(new_data, s) for s in structure.children) return unflatten_registrations[structure.type_](children, structure.metadata) # pylint: disable=no-member register_pytree( autograd.builtins.list, lambda obj: (list(obj), ()), lambda data, _: autograd.builtins.list(data), ) register_pytree( autograd.builtins.tuple, lambda obj: (list(obj), ()), lambda data, _: autograd.builtins.tuple(data), ) register_pytree( autograd.builtins.SequenceBox, lambda obj: (list(obj), ()), lambda data, _: autograd.builtins.SequenceBox(data), )