# Source code for pennylane.capture.flatfn
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines a utility for capturing higher order primitives that return pytrees.
"""
from functools import update_wrapper
has_jax = True
try:
import jax
except ImportError:
has_jax = False
# pylint: disable=too-few-public-methods
class FlatFn:
    """Wrap a function so that it caches the pytree shape of the output into the ``out_tree``
    property, so that the results can be repacked later. It also returns flattened results
    instead of the original result object.

    If an ``in_tree`` is provided, the function accepts flattened inputs instead of the
    original inputs with tree structure given by ``in_tree``.

    Args:
        f (Callable): the function to wrap. Its metadata (``__name__``, ``__doc__``, ...)
            is copied onto the wrapper via ``functools.update_wrapper``.
        in_tree (PyTreeDef, optional): pytree structure of the tuple of all positional
            arguments. When given, the wrapper must be called with the flattened leaves,
            which are unflattened with this structure before ``f`` is called.

    **Example**

    >>> import jax
    >>> from pennylane.capture.flatfn import FlatFn
    >>> def f(x):
    ...     return {"y": 2+x["x"]}
    >>> flat_f = FlatFn(f)
    >>> arg = {"x": 0.5}
    >>> res = flat_f(arg)
    >>> res
    [2.5]
    >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
    {'y': 2.5}

    If we want to use a fully flattened function that also takes flat inputs instead of
    the original inputs with tree structure, we can provide the ``PyTreeDef`` for this input
    structure:

    >>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
    >>> flat_f = FlatFn(f, in_tree)
    >>> res = flat_f(*flat_args)
    >>> res
    [2.5]
    >>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
    {'y': 2.5}

    Note that the ``in_tree`` has to be created by flattening a tuple of all input
    arguments, even if there is only a single argument.

    Keyword arguments passed to the wrapper are forwarded to ``f`` unchanged; they are
    not described by ``in_tree`` and are never unflattened.
    """

    def __init__(self, f, in_tree=None):
        # ``update_wrapper`` merges ``f.__dict__`` into ``self.__dict__``. It must run
        # BEFORE the instance state is assigned; otherwise a wrapped callable that
        # carries attributes named ``f``/``in_tree``/``out_tree`` (for example another
        # ``FlatFn``) would silently overwrite the values set here.
        update_wrapper(self, f)
        self.f = f
        self.in_tree = in_tree
        # Populated on every call with the pytree structure of the latest output.
        self.out_tree = None

    def __call__(self, *args, **kwargs):
        """Call the wrapped function and return the flattened leaves of its output.

        Args:
            *args: the original positional arguments, or their flattened leaves if an
                ``in_tree`` was provided at construction.
            **kwargs: keyword arguments forwarded to the wrapped function as-is.

        Returns:
            list: the leaves of the output of the wrapped function. The matching
            structure is stored in ``out_tree``.
        """
        if self.in_tree is not None:
            # Rebuild the original tuple of positional arguments from the flat leaves.
            args = jax.tree_util.tree_unflatten(self.in_tree, args)
        out = self.f(*args, **kwargs)
        out_flat, out_tree = jax.tree_util.tree_flatten(out)
        # Cache the output structure so that callers can repack ``out_flat`` later.
        self.out_tree = out_tree
        return out_flat
# _modules/pennylane/capture/flatfn
# Download Python script
# Download Notebook
# View on GitHub