# Source code for pennylane.pytrees.pytrees
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An internal module for working with pytrees.
"""
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional
import pennylane.queuing
# Optional JAX support: ``has_jax`` records whether ``jax.tree_util`` could be
# imported. ``jax_tree_util`` is only bound when the import succeeds, so any
# use of it must be guarded by ``has_jax``.
has_jax = True
try:
    import jax.tree_util as jax_tree_util
except ImportError:
    has_jax = False
# Type aliases used by the flatten/unflatten registration signatures.
# ``Leaves``: the child objects produced when a node is flattened.
# ``Metadata``: whatever else is needed to rebuild the node from its leaves.
Leaves = Any
Metadata = Any
# A flatten function maps an object to a ``(leaves, metadata)`` pair.
FlattenFn = Callable[[Any], tuple[Leaves, Metadata]]
# An unflatten function rebuilds an object from ``(leaves, metadata)``.
UnflattenFn = Callable[[Leaves, Metadata], Any]
def flatten_list(obj: list):
    """Flatten a list.

    Args:
        obj (list): the list to flatten.

    Returns:
        tuple: ``(obj, None)``. The list itself (not a copy) serves as the
        leaves, and no metadata is needed to rebuild it.
    """
    return obj, None
def flatten_tuple(obj: tuple):
    """Flatten a tuple.

    Args:
        obj (tuple): the tuple to flatten.

    Returns:
        tuple: ``(obj, None)``. The tuple itself serves as the leaves, and no
        metadata is needed to rebuild it.
    """
    return obj, None
def flatten_dict(obj: dict):
    """Flatten a dictionary.

    Args:
        obj (dict): the dictionary to flatten.

    Returns:
        tuple: ``(values, keys)`` where ``values`` is the dictionary's values
        view (the leaves) and ``keys`` is a tuple of its keys (the metadata),
        both in the dictionary's iteration order.
    """
    # The keys are snapshotted into a tuple; the values are returned as a live
    # view, which callers are expected to iterate straight away.
    return obj.values(), tuple(obj.keys())
# Registry mapping each pytree type to the function that splits an instance
# into ``(leaves, metadata)``. Pre-populated with the builtin containers.
flatten_registrations: dict[type, FlattenFn] = {
    list: flatten_list,
    tuple: flatten_tuple,
    dict: flatten_dict,
}
def unflatten_list(data, _) -> list:
    """Unflatten a list.

    Args:
        data (Iterable): the leaves of the list.
        _ : unused metadata slot, kept so the signature matches ``UnflattenFn``.

    Returns:
        list: ``data`` itself if it is already a list, otherwise a new list
        built from ``data``.
    """
    # Avoid copying when the caller already handed us a list.
    return data if isinstance(data, list) else list(data)
def unflatten_tuple(data, _) -> tuple:
    """Unflatten a tuple.

    Args:
        data (Iterable): the leaves of the tuple.
        _ : unused metadata slot, kept so the signature matches ``UnflattenFn``.

    Returns:
        tuple: a tuple holding the elements of ``data``.
    """
    return tuple(data)
def unflatten_dict(data, metadata) -> dict:
    """Unflatten a dictionary.

    Args:
        data (Iterable): the dictionary values (the leaves).
        metadata (Iterable): the dictionary keys, in the same order as ``data``.

    Returns:
        dict: the dictionary pairing each key in ``metadata`` with the value
        at the same position in ``data``.
    """
    # ``zip`` stops at the shorter input, so unmatched keys/values are dropped.
    return dict(zip(metadata, data))
# Registry mapping each pytree type to the function that rebuilds an instance
# from ``(leaves, metadata)``. Pre-populated with the builtin containers.
unflatten_registrations: dict[type, UnflattenFn] = {
    list: unflatten_list,
    tuple: unflatten_tuple,
    dict: unflatten_dict,
}
# Names under which each registered pytree type is known. The builtin
# containers are registered up front under the ``builtins`` namespace.
type_to_typename: dict[type, str] = {
    list: "builtins.list",
    dict: "builtins.dict",
    tuple: "builtins.tuple",
}
# Inverse lookup (typename -> type), built by pairing each name with its type.
typename_to_type: dict[str, type] = dict(
    zip(type_to_typename.values(), type_to_typename.keys())
)
def _register_pytree_with_pennylane(
    pytree_type: type, typename: str, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn
):
    """Record a pytree type in the module-level PennyLane registries.

    Args:
        pytree_type (type): the type to register.
        typename (str): the name under which the type is registered.
        flatten_fn (Callable): splits an instance into leaves and metadata.
        unflatten_fn (Callable): rebuilds an instance from leaves and metadata.

    Any existing entry for ``pytree_type`` or ``typename`` is overwritten.
    """
    # Keep the type <-> name maps in sync with each other.
    type_to_typename[pytree_type] = typename
    typename_to_type[typename] = pytree_type

    flatten_registrations[pytree_type] = flatten_fn
    unflatten_registrations[pytree_type] = unflatten_fn
def _register_pytree_with_jax(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn):
    """Register a pytree type with ``jax.tree_util``.

    Args:
        pytree_type (type): the type to register.
        flatten_fn (Callable): splits an instance into leaves and metadata.
        unflatten_fn (Callable): rebuilds an instance from ``(leaves, metadata)``.
    """

    def jax_unflatten(aux, parameters):
        # JAX invokes the unflatten callback as ``(aux_data, children)``, whereas
        # ``unflatten_fn`` takes ``(leaves, metadata)``; swap the argument order.
        return unflatten_fn(parameters, aux)

    jax_tree_util.register_pytree_node(pytree_type, flatten_fn, jax_unflatten)
def register_pytree(
    pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn, *, namespace: str = "qml"
):
    """Register a type with all available pytree backends.

    Current backends are jax and pennylane.

    Args:
        pytree_type (type): the type to register, such as ``qml.RX``
        flatten_fn (Callable): a function that splits an object into trainable leaves and
            hashable metadata.
        unflatten_fn (Callable): a function that reconstructs an object from its leaves and
            metadata.
        namespace (str): A prefix for the name under which this type will be registered.

    Returns:
        None

    Side Effects:
        ``pytree`` type becomes registered with available backends.

    .. seealso:: :func:`~.flatten`, :func:`~.unflatten`.

    """
    # The registered name is ``<namespace>.<qualified class name>``.
    typename = f"{namespace}.{pytree_type.__qualname__}"
    _register_pytree_with_pennylane(pytree_type, typename, flatten_fn, unflatten_fn)

    # JAX is optional; only register with it when the import succeeded.
    if has_jax:
        _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn)
def is_pytree(type_: type[Any]) -> bool:
    """Returns True if ``type_`` is a registered Pytree.

    Args:
        type_ (type): the type to look up.

    Returns:
        bool: whether ``type_`` has an entry in the typename registry.
    """
    return type_ in type_to_typename
def get_typename(pytree_type: type[Any]) -> str:
    """Return the typename under which ``pytree_type`` was registered.

    Args:
        pytree_type (type): a registered pytree type.

    Returns:
        str: the registered name of ``pytree_type``.

    Raises:
        TypeError: If ``pytree_type`` is not a Pytree.

    >>> get_typename(list)
    'builtins.list'
    >>> import pennylane
    >>> get_typename(pennylane.PauliX)
    'qml.PauliX'
    """
    try:
        return type_to_typename[pytree_type]
    except KeyError as exc:
        # Surface the failed lookup as a TypeError, chaining the original error.
        raise TypeError(f"{pytree_type!r} is not a Pytree type") from exc
def get_typename_type(typename: str) -> type[Any]:
    """Return the Pytree type with given ``typename``.

    Args:
        typename (str): the name under which a pytree type was registered.

    Returns:
        type: the type registered under ``typename``.

    Raises:
        ValueError: If ``typename`` is not the name of a pytree type.

    >>> import pennylane
    >>> get_typename_type("builtins.list")
    <class 'list'>
    >>> get_typename_type("qml.PauliX")
    <class 'pennylane.ops.qubit.non_parametric_ops.PauliX'>
    """
    try:
        return typename_to_type[typename]
    except KeyError as exc:
        # Surface the failed lookup as a ValueError, chaining the original error.
        raise ValueError(f"{typename!r} is not the name of a Pytree type.") from exc
@dataclass(repr=False)
class PyTreeStructure:
    """A pytree data structure, holding the type, metadata, and child pytree structures.

    ``str()`` gives a compact form and ``repr()`` the constructor-like form:

    >>> op = qml.adjoint(qml.RX(0.1, 0))
    >>> data, structure = qml.pytrees.flatten(op)
    >>> print(structure)
    PyTree(AdjointOperation, (), [PyTree(RX, (Wires([0]), ()), [Leaf])])

    A leaf is defined as just a ``PyTreeStructure`` with ``type_=None``.
    """

    type_: Optional[type[Any]] = None
    """The type corresponding to the node. If ``None``, then the structure is a leaf."""

    metadata: Metadata = ()
    """Any metadata needed to reproduce the original object."""

    children: list["PyTreeStructure"] = field(default_factory=list)
    """The children of the pytree node. Can be either other structures or terminal leaves."""

    @property
    def is_leaf(self) -> bool:
        """Whether or not the structure is a leaf."""
        return self.type_ is None

    def __repr__(self):
        # Leaves have no type to name, so they get a fixed representation.
        if self.is_leaf:
            return "PyTreeStructure()"
        return f"PyTreeStructure({self.type_.__name__}, {self.metadata}, {self.children})"

    def __str__(self):
        if self.is_leaf:
            return "Leaf"
        # Join the children's ``str`` forms explicitly; formatting the list
        # directly would use their ``repr`` instead.
        children_string = ", ".join(str(c) for c in self.children)
        return f"PyTree({self.type_.__name__}, {self.metadata}, [{children_string}])"
# Module-level leaf structure (``type_=None``, no metadata, no children).
# ``flatten`` returns this same object for every leaf it encounters, so it is
# shared between structures and should not be mutated.
leaf = PyTreeStructure(None, (), [])
def flatten(
    obj: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> tuple[list[Any], PyTreeStructure]:
    """Flattens a pytree into leaves and a structure.

    Only objects whose exact type has a registered flatten function are
    traversed; anything else is returned as a single leaf.

    Args:
        obj (Any): any object.
        is_leaf (Callable[[Any], bool] | None = None): an optionally specified
            function that will be called at each flattening step. It should return
            a boolean, with ``True`` stopping the traversal and the whole subtree being
            treated as a leaf, and ``False`` indicating the flattening should traverse
            the current object.

    Returns:
        tuple[list[Any], PyTreeStructure]: a list of leaves and a structure
        representing the object

    See also :func:`~.unflatten`.

    **Example**

    >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0))
    >>> data, structure = flatten(op)
    >>> data
    [1.2, 2.3, 3.4]
    >>> print(structure)
    PyTree(AdjointOperation, (), [PyTree(Rot, (Wires([0]), ()), [Leaf, Leaf, Leaf])])

    """
    # Lookup is by exact type, so subclasses of registered types are not matched.
    flatten_fn = flatten_registrations.get(type(obj), None)
    # set the flag is_leaf_node if is_leaf argument is provided and returns true
    is_leaf_node = is_leaf(obj) if is_leaf is not None else False
    if flatten_fn is None or is_leaf_node:
        return [obj], leaf

    leaves, metadata = flatten_fn(obj)

    # Recurse into each child, concatenating the leaves in order and recording
    # one child structure per child.
    flattened_leaves = []
    child_structures = []
    for child in leaves:
        child_leaves, child_structure = flatten(child, is_leaf)
        flattened_leaves += child_leaves
        child_structures.append(child_structure)

    structure = PyTreeStructure(type(obj), metadata, child_structures)
    return flattened_leaves, structure
def unflatten(data: list[Any], structure: PyTreeStructure) -> Any:
    """Bind data to a structure to reconstruct a pytree object.

    Leaves are consumed from ``data`` in order, one per leaf in ``structure``.

    Args:
        data (Iterable): iterable of numbers and numeric arrays
        structure (Structure, Leaf): The pytree structure object

    Returns:
        A repacked pytree.

    .. seealso:: :func:`~.flatten`

    **Example**

    >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0))
    >>> data, structure = flatten(op)
    >>> unflatten([-2, -3, -4], structure)
    Adjoint(Rot(-2, -3, -4, wires=[0]))

    """
    # Rebuild the object with queuing paused, and share a single iterator so
    # every recursive call draws the next unused leaf.
    with pennylane.queuing.QueuingManager.stop_recording():
        return _unflatten(iter(data), structure)
def _unflatten(new_data, structure):
    """Recursively rebuild the object described by ``structure``.

    Args:
        new_data (Iterator): iterator over the leaves; advanced by one item
            for every leaf encountered.
        structure (PyTreeStructure): the structure of the node to rebuild.

    Returns:
        The next item of ``new_data`` for a leaf, otherwise the object built by
        the unflatten function registered for ``structure.type_``.
    """
    if structure.is_leaf:
        # NOTE: if ``new_data`` is exhausted this raises ``StopIteration``, which
        # surfaces as ``RuntimeError`` when raised inside the generator
        # expression below.
        return next(new_data)

    # Children are rebuilt left to right, so leaves are consumed in the same
    # order in which flattening produced them.
    children = tuple(_unflatten(new_data, s) for s in structure.children)
    return unflatten_registrations[structure.type_](children, structure.metadata)
# _modules/pennylane/pytrees/pytrees
# Download Python script
# Download Notebook
# View on GitHub