qml.capture.FlatFn
class FlatFn(f, in_tree=None)
Bases: object
Wrap a function so that it caches the pytree structure of its output in the out_tree property, so that the results can be repacked later. The wrapped function also returns flattened results instead of the original result object.

If an in_tree is provided, the wrapped function accepts flattened inputs instead of the original inputs, whose tree structure is given by in_tree.

Example
>>> import jax
>>> from pennylane.capture.flatfn import FlatFn
>>> def f(x):
...     return {"y": 2+x["x"]}
>>> flat_f = FlatFn(f)
>>> arg = {"x": 0.5}
>>> res = flat_f(arg)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
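The caching behaviour described above can be sketched with jax.tree_util alone. The class below, SimpleFlatFn, is a hypothetical illustration and not PennyLane's actual implementation: on every call it flattens the output and stores the resulting PyTreeDef in out_tree, and, when an in_tree is given, it first unflattens the flat inputs back into the tuple of original arguments.

```python
import jax


class SimpleFlatFn:
    """Hypothetical sketch of the FlatFn idea; for illustration only."""

    def __init__(self, f, in_tree=None):
        self.f = f
        self.in_tree = in_tree
        self.out_tree = None  # set on each call, once the output is known

    def __call__(self, *args):
        if self.in_tree is not None:
            # Rebuild the original tuple of arguments from the flat leaves.
            args = jax.tree_util.tree_unflatten(self.in_tree, args)
        out = self.f(*args)
        # Flatten the output and remember its structure for later repacking.
        flat_out, self.out_tree = jax.tree_util.tree_flatten(out)
        return flat_out


def f(x):
    return {"y": 2 + x["x"]}


flat_f = SimpleFlatFn(f)
res = flat_f({"x": 0.5})
print(res)  # [2.5]
print(jax.tree_util.tree_unflatten(flat_f.out_tree, res))  # {'y': 2.5}

# Flat-input variant: in_tree comes from flattening a tuple of all arguments.
flat_args, in_tree = jax.tree_util.tree_flatten(({"x": 0.5},))
flat_f2 = SimpleFlatFn(f, in_tree)
res2 = flat_f2(*flat_args)
print(res2)  # [2.5]
```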
If we want a fully flattened function that also takes flat inputs instead of the original tree-structured inputs, we can provide the PyTreeDef of that input structure:

>>> flat_args, in_tree = jax.tree_util.tree_flatten((arg,))
>>> flat_f = FlatFn(f, in_tree)
>>> res = flat_f(*flat_args)
>>> res
[2.5]
>>> jax.tree_util.tree_unflatten(flat_f.out_tree, res)
{'y': 2.5}
Note that the in_tree has to be created by flattening a tuple of all input arguments, even if there is only a single argument.
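To illustrate why a tuple is required, the sketch below uses plain jax.tree_util with a hypothetical two-argument function g (not part of PennyLane). Unflattening with a tree built from the tuple of arguments returns that tuple, which can be unpacked with *args; a tree built from a single bare argument has a different structure, and unflattening it returns the argument itself, so unpacking it would pass the dictionary's keys instead.

```python
import jax


def g(a, b):
    # Hypothetical two-argument function, used only for illustration.
    return a["p"] * b


args = ({"p": 3.0}, 2.0)

# Flatten the tuple of all positional arguments, as the note requires.
flat_args, in_tree = jax.tree_util.tree_flatten(args)
print(flat_args)  # [3.0, 2.0]

# Unflattening returns the tuple, which can be splatted into g.
rebuilt = jax.tree_util.tree_unflatten(in_tree, flat_args)
print(g(*rebuilt))  # 6.0

# For a single argument, flattening it bare gives a different PyTreeDef
# than flattening it wrapped in a one-element tuple.
single = {"x": 0.5}
_, bare_tree = jax.tree_util.tree_flatten(single)
_, tuple_tree = jax.tree_util.tree_flatten((single,))
print(bare_tree == tuple_tree)  # False
```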