catalyst.vmap
vmap(fn=None, *, in_axes=0, out_axes=0, axis_size=None)
A qjit()-compatible vectorizing map. Creates a function which maps an input function over argument axes.
Parameters
- fn (Callable) – A Python function containing PennyLane quantum operations.
- in_axes (Union[int, Sequence[Any]]) – Specifies which axis (or axes) of the input arguments to map over.
- out_axes (Union[int, Sequence[Any]]) – Specifies where the mapped axis should appear in the output.
- axis_size (int) – Optionally, the size of the axis to be mapped. If omitted, the size of the mapped axis is inferred from the provided arguments.
Returns
Vectorized version of fn.
Return type
Callable
Raises
ValueError – Invalid in_axes, out_axes, or axis_size values.
Note
Using vmap will prevent AOT compilation from working, since annotated type information will no longer be valid when batching arguments/results.
Example
For example, consider the following QNode:
```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit, vmap

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
    qml.RX(jnp.pi * x[0] + y, wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.expval(qml.PauliZ(0))
```

```pycon
>>> circuit(jnp.array([0.1, 0.2, 0.3]), jnp.pi)
Array(-0.93005586, dtype=float64)
```
We can use catalyst.vmap to introduce additional batch dimensions to our input arguments, without needing to use a Python for loop:

```pycon
>>> x = jnp.array([[0.1, 0.2, 0.3],
...                [0.4, 0.5, 0.6],
...                [0.7, 0.8, 0.9]])
>>> y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])
>>> qjit(vmap(circuit))(x, y)
Array([-0.93005586, -0.97165424, -0.6987465 ], dtype=float64)
```
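Conceptually, a batched call like the one above evaluates the function once per slice along the batch axis and stacks the results. The sketch below checks that equivalence with jax.vmap, which catalyst.vmap is documented to match (and to dispatch to outside of qjit). The function f is a hypothetical classical stand-in that reuses the gate angles of circuit; it does not simulate any quantum operations, and the sketch illustrates the semantics only, not Catalyst's internal loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical classical stand-in for `circuit`: it combines the same
# input angles but does not simulate any quantum operations.
def f(x, y):
    return jnp.cos(jnp.pi * x[0] + y) * jnp.sin(x[1] ** 2) + x[1] * x[2]

x = jnp.array([[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]])
y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])

# Vectorized evaluation: maps over axis 0 of both x and y (in_axes=0).
batched = jax.vmap(f)(x, y)

# Explicit loop: evaluate each (row of x, element of y) pair and stack.
looped = jnp.stack([f(x[i], y[i]) for i in range(x.shape[0])])

print(batched.shape)                  # (3,)
print(jnp.allclose(batched, looped))  # True
```

The loop form is what vmap saves you from writing by hand; under qjit, Catalyst replaces the Python loop with a compiled one.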
catalyst.vmap() has been implemented to match the behaviour of jax.vmap, so it should be a drop-in replacement in most cases. Under the hood, it automatically inserts Catalyst-compatible for loops, which are compiled and executed outside of Python for increased performance. Outside of a Catalyst qjit-compiled function, vmap simply dispatches to jax.vmap.
Selecting batching axes for arguments
The in_axes parameter provides several modes that allow both coarse- and fine-grained control over which arguments the batching transformation is applied to. Enabling batching for a particular argument requires the selected axis to have the same size as the batch size, which is shared by all batched arguments. The following modes are supported:
- int: Specifies the same batch axis for all arguments.
- Tuple[int]: Specifies a different batch axis for each argument.
- Tuple[int | None]: Same as the previous mode, but selectively disables batching for certain arguments with a None value.
- Tuple[int | PyTree[int] | None]: Same as the previous mode, but specifies a different batch axis for each leaf of an argument. (Note that the PyTreeDefs, i.e. the container structure, must match between the in_axes element and the corresponding argument.)
- Tuple[int | PyTree[int | None] | None]: Same as the previous mode, but selectively disables batching for individual PyTree leaves.
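The sketch below exercises several of these in_axes modes. It uses jax.vmap so that it runs without a quantum device, on the assumption (stated above) that catalyst.vmap follows the same in_axes semantics; the function g and its arguments are hypothetical. The last step shows that JAX raises a ValueError when a batched axis disagrees with the batch size; Catalyst's own error message may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical function of three arguments: a vector, a scalar,
# and a dict (a PyTree) holding a weight vector and a bias.
def g(x, s, params):
    return s * jnp.dot(params["w"], x) + params["b"]

xs = jnp.arange(12.0).reshape(3, 4)  # batch of 3 vectors of length 4
s = 2.0                              # shared scalar, never batched
params = {"w": jnp.ones(4), "b": jnp.array([0.0, 1.0, 2.0])}

# int: the same batch axis (0) for every argument.
# Each row's squared norm: 14, 126, 366.
sq_norms = jax.vmap(jnp.dot, in_axes=0)(xs, xs)

# Per-argument axes with None and a PyTree entry: batch xs along axis 0,
# keep s fixed, and inside params batch only the leaf "b".
in_axes = (0, None, {"w": None, "b": 0})
out = jax.vmap(g, in_axes=in_axes)(xs, s, params)
# Row sums of xs are 6, 22, 38, so out == 2 * [6, 22, 38] + [0, 1, 2].

# A different batch axis for xs: axis 1 of the transposed array
# yields the same slices, hence the same result.
out_t = jax.vmap(g, in_axes=(1, None, {"w": None, "b": 0}))(xs.T, s, params)
print(jnp.allclose(out, out_t))  # True

# Every batched axis must match the batch size: a bias of length 2
# conflicts with the batch size 3 taken from xs.
bad_params = {"w": jnp.ones(4), "b": jnp.array([0.0, 1.0])}
try:
    jax.vmap(g, in_axes=in_axes)(xs, s, bad_params)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print(mismatch_raised)  # True
```

Note how the dict inside in_axes mirrors the keys of params exactly; this is the matching container structure that the PyTree modes require.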
The out_axes parameter can also be used to specify the position of the mapped axis in the output. out_axes supports the same modes as in_axes.
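As a minimal sketch of out_axes, again using jax.vmap and assuming catalyst.vmap treats out_axes the same way, the hypothetical function h below returns a tuple, and a tuple-valued out_axes places the batch axis differently for each output.

```python
import jax
import jax.numpy as jnp

def h(v):
    # Hypothetical function returning a pair: a scaled copy of v and its sum.
    return 2.0 * v, jnp.sum(v)

batch = jnp.arange(6.0).reshape(3, 2)  # 3 vectors of length 2

# Default out_axes=0: the batch dimension leads every output.
a0, s0 = jax.vmap(h)(batch)
print(a0.shape, s0.shape)  # (3, 2) (3,)

# Tuple form: batch axis last for the first output, first for the second.
a1, s1 = jax.vmap(h, out_axes=(1, 0))(batch)
print(a1.shape, s1.shape)  # (2, 3) (3,)
```

Moving the batch axis with out_axes avoids a separate transpose after the call; here a1 equals a0 transposed.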