catalyst.vmap

vmap(fn=None, *, in_axes=0, out_axes=0, axis_size=None)

A qjit() compatible vectorizing map. Creates a function which maps an input function over argument axes.

Parameters
  • fn (Callable) – A Python function containing PennyLane quantum operations.

  • in_axes (Union[int, Sequence[Any]]) – Specifies which axis (or axes) of the input arrays to map over.

  • out_axes (Union[int, Sequence[Any]]) – Specifies where the mapped axis should appear in the output.

  • axis_size (int) – An integer can be optionally provided to indicate the size of the axis to be mapped. If omitted, the size of the mapped axis will be inferred from the provided arguments.

Returns

Vectorized version of fn.

Return type

Callable

Raises

ValueError – Raised when in_axes, out_axes, or axis_size has an invalid value.

Note

Using vmap prevents ahead-of-time (AOT) compilation from working, since annotated type information is no longer valid once arguments and results are batched.

Example

For example, consider the following QNode:

import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit, vmap

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y):
  qml.RX(jnp.pi * x[0] + y, wires=0)
  qml.RY(x[1] ** 2, wires=0)
  qml.RX(x[1] * x[2], wires=0)
  return qml.expval(qml.PauliZ(0))

>>> circuit(jnp.array([0.1, 0.2, 0.3]), jnp.pi)
Array(-0.93005586, dtype=float64)

We can use catalyst.vmap to introduce additional batch dimensions to our input arguments, without needing to use a Python for loop:

>>> x = jnp.array([[0.1, 0.2, 0.3],
...                [0.4, 0.5, 0.6],
...                [0.7, 0.8, 0.9]])
>>> y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])
>>> qjit(vmap(circuit))(x, y)
Array([-0.93005586, -0.97165424, -0.6987465 ], dtype=float64)

catalyst.vmap() has been implemented to match the behaviour of jax.vmap, so it should be a drop-in replacement in most cases. Under the hood, it automatically inserts Catalyst-compatible for loops, which are compiled and executed outside of Python for increased performance.

Outside of a Catalyst qjit-compiled function, vmap will simply dispatch to jax.vmap.
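
To make the "no Python for loop" point concrete, the sketch below compares an explicit loop with a vectorizing map. It calls jax.vmap directly, which is what the text says vmap dispatches to outside of qjit. The function f is a hypothetical classical stand-in for the circuit, not part of Catalyst, and the behaviour inside qjit is assumed to match.

```python
import jax
import jax.numpy as jnp

# Hypothetical classical stand-in for the circuit above:
# takes one length-3 vector `x` and one scalar `y`.
def f(x, y):
    return jnp.cos(jnp.pi * x[0] + y) * jnp.cos(x[1] ** 2) + x[1] * x[2]

x = jnp.array([[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]])
y = jnp.array([jnp.pi, jnp.pi / 2, jnp.pi / 4])

# Explicit Python loop over the leading axis of both arguments.
looped = jnp.stack([f(x[i], y[i]) for i in range(x.shape[0])])

# The same computation as a vectorizing map (in_axes=0 by default).
batched = jax.vmap(f)(x, y)  # shape (3,)
```

Both results should agree elementwise; the mapped version simply avoids writing the loop by hand.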

The in_axes parameter provides different modes that allow coarse- and fine-grained control over which arguments the batching transformation is applied to. Enabling batching for a particular argument requires that the selected axis have the same size as the batch size, which is the same for all arguments.
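
The batch-size requirement can be illustrated as follows. This is a sketch using jax.vmap as a stand-in, on the assumption that catalyst.vmap enforces the same rule; the function add and the arrays are made up for illustration.

```python
import jax
import jax.numpy as jnp

def add(a, b):
    return a + b

a = jnp.ones((3, 2))  # axis 0 has size 3
b = jnp.ones((4, 2))  # axis 0 has size 4

# Mapping axis 0 of both arguments fails: the batch sizes (3 and 4) disagree.
try:
    jax.vmap(add)(a, b)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

# Mapping axis 0 of `a` and axis 1 of `c` works: both selected axes have size 3.
c = jnp.ones((2, 3))
ok = jax.vmap(add, in_axes=(0, 1))(a, c)  # shape (3, 2)
```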

The following modes are supported:

  • int: Specifies the same batch axis for all arguments

  • Tuple[int]: Specify a different batch axis for each argument

  • Tuple[int | None]: Same as previous, but selectively disable batching for certain arguments with a None value

  • Tuple[int | PyTree[int] | None]: Same as previous, but specify a different batch axis for each leaf of an argument (Note that the PyTreeDefs, i.e. the container structure, must match between the in_axes element and the corresponding argument.)

  • Tuple[int | PyTree[int | None] | None]: Same as previous, but selectively disable batching for individual PyTree leaves
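
The modes listed above might look as follows in practice. This sketch again uses jax.vmap as a stand-in, assuming catalyst.vmap accepts the same in_axes values; the functions g and h and the params dictionary are hypothetical.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x * y  # x and y are length-2 vectors per batch element

x = jnp.arange(6.0).reshape(3, 2)  # batch of 3 along axis 0
y_t = jnp.ones((2, 3))             # batch of 3 along axis 1
w = jnp.array([10.0, 20.0])        # not batched

# int: the same batch axis for every argument.
same = jax.vmap(g, in_axes=0)(x, x)

# Tuple[int]: a different batch axis per argument.
per_arg = jax.vmap(g, in_axes=(0, 1))(x, y_t)

# Tuple[int | None]: `w` is passed unchanged to every batch element.
partial = jax.vmap(g, in_axes=(0, None))(x, w)

def h(x, params):
    return params["scale"] * x + params["shift"]

params = {"scale": jnp.array([1.0, 2.0, 3.0]),  # one scale per batch element
          "shift": jnp.zeros(2)}                # shared by all batch elements

# PyTree with None leaves: the in_axes dict mirrors the structure of `params`,
# batching "scale" along axis 0 and leaving "shift" unbatched.
per_leaf = jax.vmap(h, in_axes=(0, {"scale": 0, "shift": None}))(x, params)
```

All four results have shape (3, 2); only the way each argument is sliced differs.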

The out_axes parameter can also be used to specify the position of the mapped axis in the output. out_axes supports the same modes as in_axes.
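
A brief sketch of how out_axes moves the mapped axis, using jax.vmap as a stand-in and assuming catalyst.vmap behaves the same way; the functions powers and two_outputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def powers(x):
    # x is a scalar per batch element; returns a length-2 vector.
    return jnp.stack([x, x ** 2])

xs = jnp.array([1.0, 2.0, 3.0])

# Default: the mapped axis is the first axis of the output.
default = jax.vmap(powers)(xs)            # shape (3, 2)

# out_axes=1: the mapped axis becomes the second axis of the output.
moved = jax.vmap(powers, out_axes=1)(xs)  # shape (2, 3)

def two_outputs(x):
    return x, jnp.stack([x, x])

# Tuple form: a separate output axis for each returned value.
first, second = jax.vmap(two_outputs, out_axes=(0, 1))(xs)
```

Here moved is the transpose of default, and first and second have shapes (3,) and (2, 3) respectively.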