Source code for catalyst.api_extensions.function_maps
# Copyright 2022-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains public API functions which represent higher-order
transformations on functions, for example the vectorization map which adds
additional dimensions to the inputs and outputs of a function.
"""
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.tree_util import tree_flatten, tree_leaves, tree_structure, tree_unflatten
from catalyst.api_extensions.control_flow import for_loop
from catalyst.tracing.contexts import EvaluationContext
## API ##
def vmap(
    fn: Callable,
    in_axes: Union[int, Sequence[Any]] = 0,
    out_axes: Union[int, Sequence[Any]] = 0,
    axis_size: Optional[int] = None,
) -> Callable:
    """A :func:`~.qjit` compatible vectorizing map.

    Creates a function which maps an input function over argument axes.

    Args:
        fn (Callable): A Python function containing PennyLane quantum operations.
        in_axes (Union[int, Sequence[Any]]): Specifies the value(s) over which input
            array axes to map. A list is treated like a tuple, as in ``jax.vmap``.
        out_axes (Union[int, Sequence[Any]]): Specifies where the mapped axis should appear
            in the output.
        axis_size (int): An integer can be optionally provided to indicate the size of the
            axis to be mapped. If omitted, the size of the mapped axis will be inferred from
            the provided arguments.

    Returns:
        Callable: Vectorized version of ``fn``.

    Raises:
        ValueError: Invalid ``in_axes``, ``out_axes``, and ``axis_size`` values.

    **Example**

    For example, consider the following QNode:

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qml.qnode(dev)
        def circuit(x, y):
            qml.RX(jnp.pi * x[0] + y, wires=0)
            qml.RY(x[1] ** 2, wires=0)
            qml.RX(x[1] * x[2], wires=0)
            return qml.expval(qml.PauliZ(0))

    >>> circuit(jnp.array([0.1, 0.2, 0.3]), jnp.pi)
    Array(-0.93005586, dtype=float64)

    We can use ``catalyst.vmap`` to introduce additional batch dimensions
    to our input arguments, without needing to use a Python for loop:

    >>> x = jnp.array([[0.1, 0.2, 0.3],
    ...                [0.4, 0.5, 0.6],
    ...                [0.7, 0.8, 0.9]])
    >>> y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])
    >>> qjit(vmap(circuit))(x, y)
    array([-0.93005586, -0.97165424, -0.6987465 ])

    ``catalyst.vmap()`` has been implemented to match the same behaviour of
    ``jax.vmap``, so should be a drop-in replacement in most cases.
    Under-the-hood, it is automatically inserting Catalyst-compatible for loops,
    which will be compiled and executed outside of Python for increased performance.

    Outside of a Catalyst qjit-compiled function, ``vmap`` will simply dispatch to
    ``jax.vmap``.

    .. details::
        :title: Selecting batching axes for arguments

        The ``in_axes`` parameter provides different modes that allow large- and fine-grained
        control over which arguments to apply the batching transformation on. Enabling batching
        for a particular argument requires that the selected axis be of the same size as the
        determined batch size, which is the same for all arguments.

        The following modes are supported:

        - ``int``: Specifies the same batch axis for all arguments
        - ``Tuple[int]``: Specify a different batch axis for each argument
        - ``Tuple[int | None]``: Same as previous, but selectively disable batching for certain
          arguments with a ``None`` value
        - ``Tuple[int | PyTree[int] | None]``: Same as previous, but specify a different batch
          axis for each leaf of an argument (Note that the ``PyTreeDefs``, i.e. the container
          structure, must match between the ``in_axes`` element and the corresponding argument.)
        - ``Tuple[int | PyTree[int | None] | None]``: Same as previous, but selectively disable
          batching for individual PyTree leaves

        The ``out_axes`` parameter can be also used to specify the positions of the mapped axis
        in the output. ``out_axes`` is subject to the same modes as well.
    """
    # The positional arguments always arrive as a tuple, so a list of axis specifiers could
    # never pass the structure check below; ``jax.vmap`` converts lists the same way.
    if isinstance(in_axes, list):
        in_axes = tuple(in_axes)

    # Check the validity of in_axes and out_axes
    if not all(isinstance(l, int) for l in tree_leaves(in_axes)):
        raise ValueError(
            "Invalid 'in_axes'; it must be an int or a tuple of PyTrees with integer leaves, "
            f"but got {in_axes}"
        )

    if not all(isinstance(l, int) for l in tree_leaves(out_axes)):
        raise ValueError(
            "Invalid 'out_axes'; it must be an int or a tuple of PyTree with integer leaves, "
            f"but got {out_axes}"
        )

    def batched_fn(*args, **kwargs):
        """Vectorization wrapper around the hybrid program using catalyst.for_loop"""

        # Dispatch to jax.vmap when it is called outside qjit.
        # NOTE(review): ``axis_size`` is not forwarded on this path.
        if not EvaluationContext.is_tracing():
            return jax.vmap(fn, in_axes, out_axes)(*args, **kwargs)

        args_flat, args_tree = tree_flatten(args)
        in_axes_flat, _ = tree_flatten(in_axes, is_leaf=lambda x: x is None)

        # Check the validity of the input arguments w.r.t. in_axes
        in_axes_deep_struct = tree_structure(in_axes, is_leaf=lambda x: x is None)
        args_deep_struct = tree_structure(args, is_leaf=lambda x: x is None)
        if not isinstance(in_axes, int) and in_axes_deep_struct != args_deep_struct:
            raise ValueError(
                "Invalid 'in_axes'; it must be an int or match the length of positional "
                f"arguments, but got {in_axes_deep_struct} axis specifiers "
                f"and {args_deep_struct} arguments."
            )

        if isinstance(in_axes, int):
            in_axes_flat = [in_axes] * len(args_flat)

        batch_size = _get_batch_size(args_flat, in_axes_flat, axis_size)
        batch_loc = _get_batch_loc(in_axes_flat)

        def _slice_args(index):
            """Rebuild ``args`` with every mapped leaf replaced by its ``index``-th slice."""
            # Copy so that ``args_flat`` (captured by the loop body) is never overwritten.
            sliced_flat = list(args_flat)
            for loc in batch_loc:
                sliced_flat[loc] = jnp.take(args_flat[loc], index, axis=in_axes_flat[loc])
            return tree_unflatten(args_tree, sliced_flat)

        # Run 'fn' once on the first slice to learn the output structure, shapes and dtypes
        init_result = fn(*_slice_args(0), **kwargs)

        # Check the validity of the output w.r.t. out_axes
        out_axes_deep_struct = tree_structure(out_axes, is_leaf=lambda x: x is None)
        init_result_deep_struct = tree_structure(init_result, is_leaf=lambda x: x is None)
        if not isinstance(out_axes, int) and out_axes_deep_struct != init_result_deep_struct:
            raise ValueError(
                "Invalid 'out_axes'; it must be an int or match "
                "the number of function results, but got "
                f"{out_axes_deep_struct} axis specifiers and {init_result_deep_struct} results."
            )

        init_result_flat, init_result_tree = tree_flatten(init_result)
        num_axes_out = len(init_result_flat)

        if isinstance(out_axes, int):
            out_axes_flat = [out_axes] * num_axes_out
        else:
            out_axes_flat, _ = tree_flatten(out_axes, is_leaf=lambda x: x is None)

        # NOTE(review): leaves whose out axis is ``None`` keep the batch dimension at
        # position 0 (they are only excluded from the axis move below).
        out_loc = _get_batch_loc(out_axes_flat)

        # One (batch_size, *leaf_shape) buffer per output leaf, with row 0 already filled
        # from the initial call. ``jnp.asarray`` lets plain Python scalars through.
        batched_result_list = []
        for leaf in init_result_flat:
            leaf = jnp.asarray(leaf)
            buffer = jnp.zeros(shape=(batch_size, *leaf.shape), dtype=leaf.dtype)
            batched_result_list.append(buffer.at[0].set(leaf))

        # Apply mapping batched_args[1:] ---> fn(args)
        @for_loop(1, batch_size, 1)
        def loop_fn(i, batched_result_list):
            res_flat, _ = tree_flatten(fn(*_slice_args(i), **kwargs))
            return [
                batched_result_list[j].at[i].set(res_flat[j]) for j in range(num_axes_out)
            ]

        batched_result_list = list(loop_fn(batched_result_list))

        # The batch dimension currently sits at position 0. Move it to the requested
        # position, as ``jax.vmap`` does. A plain swap of axis 0 with ``out_axes`` would
        # also reorder the remaining axes of results with three or more dimensions.
        for loc in out_loc:
            ax = out_axes_flat[loc]
            if ax != 0:
                batched_result_list[loc] = jnp.moveaxis(batched_result_list[loc], 0, ax)

        # Unflatten batched_result before return
        return tree_unflatten(init_result_tree, batched_result_list)

    return batched_fn
## PRIVATE ##
def _get_batch_loc(axes_flat):
"""
Get the list of mapping locations in the flattened list of in-axes or out-axes.
This function takes a flattened list of axes and identifies all elements with a
non-None value. The resulting list contains the indices of these non-None values,
indicating where the mapping should apply.
Args:
axes_flat (List): Flattened list of in-axes or out-axes including `None` elements.
Returns:
List: A list of indices representing the locations where the mapping should be applied.
"""
return [i for i, d in enumerate(axes_flat) if d is not None]
def _get_batch_size(args_flat, axes_flat, axis_size):
"""Get the batch size based on the provided arguments and axes specifiers, or the manually
specified batch size by the user request. The batch size must be the same for all arguments.
Args:
args_flat (List): Flatten list of arguments.
axes_flat (List): Flatten list of `in_axes` or `our_axes` including `None` elements.
axis_size (Optional[int]): Optional default batch size.
Returns:
int: Returns the batch size used as the upper bound of the QJIT-compatible for loop
in the computation of vmap.
Raises:
ValueError: The batch size must be the same for all arguments.
ValueError: The default batch is expected to be None, or less than or equal
to the computed batch size.
"""
batch_sizes = []
for arg, d in zip(args_flat, axes_flat):
shape = np.shape(arg) if arg.shape else (1,)
if d is not None and len(shape) > d:
batch_sizes.append(shape[d])
if any(size != batch_sizes[0] for size in batch_sizes[1:]):
raise ValueError(
"Invalid batch sizes; expected the batch size to be the same for all arguments, "
f"but got batch_sizes={batch_sizes} from args_flat={args_flat}"
)
batch_size = batch_sizes[0] if batch_sizes else 0
if axis_size is not None:
if axis_size <= batch_size:
batch_size = axis_size
else:
raise ValueError(
"Invalid 'axis_size'; the default batch is expected to be None, "
"or less than or equal to the computed batch size, but got "
f"axis_size={axis_size} > batch_size={batch_size}"
)
if not batch_size:
raise ValueError(
f"Invalid batch size; it must be a non-zero integer, but got {batch_size}."
)
return batch_size
_modules/catalyst/api_extensions/function_maps
Download Python script
Download Notebook
View on GitHub