catalyst.for_loop

for_loop(lower_bound, upper_bound, step, allow_array_resizing=False)

A qjit()-compatible for-loop decorator for PennyLane/Catalyst.

Note

Catalyst can automatically convert Python for loop statements for you. This requires setting autograph=True; see the qjit() function or the corresponding documentation page for more details.

This for-loop representation is a functional version of the traditional for-loop, similar to jax.lax.fori_loop. That is, any variables that are modified across iterations need to be provided as inputs to, and outputs of, the loop body function:

  • Input arguments contain the value of a variable at the start of an iteration.

  • Output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration.

The final iteration values are also returned from the transformed function.

This form of control flow can also be called from the Python interpreter without needing to use qjit().

The semantics of for_loop are given by the following Python pseudo-code:

def for_loop(lower_bound, upper_bound, step, loop_fn, *args):
    for i in range(lower_bound, upper_bound, step):
        args = loop_fn(i, *args)
    return args
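To experiment with these semantics outside of Catalyst, the pseudo-code above can be turned into a small pure-Python model of the decorator form `for_loop(lower_bound, upper_bound, step)(body)(*init)`. This is a sketch under stated assumptions: `reference_for_loop` is a hypothetical name, it performs no tracing or compilation, and it treats any tuple returned by the body as several loop-carried values (a single carried value is passed and returned bare). The Fibonacci body shows the index arriving as the first argument and never being returned.

```python
from typing import Any, Callable


def reference_for_loop(lower_bound: int, upper_bound: int, step: int):
    """Hypothetical pure-Python model of for_loop's calling convention (no tracing)."""

    def decorator(body: Callable[..., Any]) -> Callable[..., Any]:
        def run(*init: Any) -> Any:
            carried = init
            for i in range(lower_bound, upper_bound, step):
                result = body(i, *carried)
                # Several loop-carried values come back as a tuple;
                # a single carried value comes back bare.
                carried = result if isinstance(result, tuple) else (result,)
            # Mirror the output shape: bare value for one carried variable, tuple otherwise.
            return carried[0] if len(carried) == 1 else carried

        return run

    return decorator


@reference_for_loop(0, 10, 1)
def fib(i, a, b):
    # The index i is received but not returned; only the carried values are.
    return b, a + b


print(fib(0, 1))  # (55, 89)
```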

Unlike jax.lax.fori_loop, the step can be negative if it is known at tracing time (i.e., it is a constant). If a negative step is not known at tracing time, the loop produces no iterations.
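The number of times the body runs follows from the range semantics: for a constant step it is max(0, ⌈(upper_bound − lower_bound) / step⌉), for either sign of step. The sketch below computes this with a hypothetical helper `num_iterations`; the `step_is_constant` flag is an assumption used only to encode the documented rule that a negative step unknown at tracing time yields an empty loop. The zero-step check mirrors Python's `range()`, not a documented Catalyst behaviour.

```python
def num_iterations(lower_bound: int, upper_bound: int, step: int,
                   step_is_constant: bool = True) -> int:
    """Hypothetical helper: how many times a for_loop body executes."""
    if step == 0:
        # range() also rejects a zero step.
        raise ValueError("step must be non-zero")
    if step < 0 and not step_is_constant:
        # Documented behaviour: a non-constant negative step produces no iterations.
        return 0
    # ceil((upper - lower) / step) with exact integer arithmetic,
    # valid for both positive and negative step.
    return max(0, -((lower_bound - upper_bound) // step))


# For constant steps the formula agrees with len(range(...)).
for lo, hi, st in [(0, 10, 2), (0, 10, 3), (5, 0, -1), (0, 5, -1), (3, 3, 1)]:
    assert num_iterations(lo, hi, st) == len(range(lo, hi, st))

print(num_iterations(10, 0, -2))                          # 5
print(num_iterations(10, 0, -2, step_is_constant=False))  # 0
```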

Parameters
  • lower_bound (int) – starting value of the iteration index

  • upper_bound (int) – (exclusive) upper bound of the iteration index

  • step (int) – increment applied to the iteration index at the end of each iteration

  • allow_array_resizing (bool) – whether to allow arrays to change shape/size within the for loop. By default this is False, which allows out-of-scope dynamically-shaped arrays to be captured by the for loop and binary operations to be applied to arrays of the same shape. Set this to True to modify dimension sizes within the for loop; however, outer-scope dynamically-shaped arrays will then no longer be captured, and arrays of the same shape cannot be used in binary operations.

Returns

A wrapper around the loop body function. Note that the loop body function must always have the iteration index as its first argument, which can be used arbitrarily inside the loop body. As the value of the index across iterations is handled automatically by the provided loop bounds, it must not be returned from the function.

Return type

Callable[[int, …], …]

Example

import jax.numpy as jnp
import pennylane as qml
from catalyst import for_loop, qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(n: int, x: float):

    def loop_rx(i, x):
        # perform some work and update (some of) the arguments
        qml.RX(x, wires=0)

        # update the value of x for the next iteration
        return jnp.sin(x)

    # apply the for loop
    final_x = for_loop(0, n, 1)(loop_rx)(x)

    return qml.expval(qml.PauliZ(0)), final_x

>>> circuit(7, 1.6)
(Array(0.97926626, dtype=float64), Array(0.55395718, dtype=float64))
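These values can be reproduced classically, which is a useful sanity check on the loop-carried semantics. Over n = 7 iterations the body applies RX with the current x and then carries sin(x) forward, so the final x is sin applied seven times to 1.6. Successive RX rotations on the same wire compose into one RX by the summed angle, and RX(θ)|0⟩ has ⟨Z⟩ = cos θ, so the expectation value is the cosine of the sum of the seven angles. The helper `classical_circuit` is hypothetical and uses only the standard library; it is a re-derivation, not Catalyst code.

```python
import math
from typing import Tuple


def classical_circuit(n: int, x: float) -> Tuple[float, float]:
    """Classically re-derive circuit(n, x) from the example above."""
    total_angle = 0.0
    for _ in range(n):
        total_angle += x  # qml.RX(x, wires=0) adds x to the accumulated rotation
        x = math.sin(x)   # value carried into the next iteration
    # For RX(theta)|0>, the Pauli-Z expectation value is cos(theta).
    return math.cos(total_angle), x


expval, final_x = classical_circuit(7, 1.6)
print(expval, final_x)
```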

Note that dynamically-shaped arrays are also supported within for loops, while loops, and conditional statements:

>>> @qjit
... def f(shape):
...     a = jnp.ones([shape], dtype=float)
...     @for_loop(0, 10, 2)
...     def loop(i, a):
...         return a + i
...     return loop(a)
>>> f(5)
Array([21., 21., 21., 21., 21.], dtype=float64)
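The value 21 follows from the loop bounds: range(0, 10, 2) yields the indices 0, 2, 4, 6 and 8, each added element-wise to an array of ones, so every element ends at 1 + 20 = 21 regardless of `shape`. A plain-Python model, with a list standing in for the dynamically-sized JAX array (an assumption for illustration; no tracing happens and `model_f` is a hypothetical name):

```python
from typing import List


def model_f(shape: int) -> List[float]:
    """List-based stand-in for f(shape) from the example above."""
    a = [1.0] * shape                   # jnp.ones([shape])
    for i in range(0, 10, 2):           # for_loop(0, 10, 2)
        a = [value + i for value in a]  # a + i, applied to every element
    return a


print(model_f(5))  # [21.0, 21.0, 21.0, 21.0, 21.0]
print(model_f(2))  # [21.0, 21.0]
```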

By default, allow_array_resizing is False, allowing dynamically-shaped arrays from outside the for loop to be correctly captured, and arrays of the same shape to be used in binary operations:

>>> @qjit(abstracted_axes={1: 'n'})
... def g(x, y):
...     @catalyst.for_loop(0, 10, 1)
...     def loop(_, a):
...         # Capture `x` from the outer scope and apply the binary
...         # operation '*' between the two same-shaped arrays.
...         return a * x
...     return jnp.sum(loop(y))
>>> a = jnp.ones([1,3], dtype=float)
>>> b = jnp.ones([1,3], dtype=float)
>>> g(a, b)
Array(3., dtype=float64)

However, if you wish to have the for loop return differently sized arrays at each iteration, set allow_array_resizing to True:

>>> @qjit()
... def f(N):
...     a = jnp.ones([N], dtype=float)
...     @for_loop(0, 10, 1, allow_array_resizing=True)
...     def loop(i, _):
...         return jnp.ones([i], dtype=float) # return an array with a new size
...     return loop(a)
>>> f(5)
Array([1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float64)
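The result has nine elements rather than five because the body ignores its carried input (bound to `_`), so the final array's length is set by the index of the last iteration. With range(0, 10, 1) that index is 9, so the output is jnp.ones([9]) whatever `N` is. A list-based model (a hypothetical stand-in named `model_resizing_f`, not Catalyst code) makes the changing size explicit:

```python
from typing import List, Tuple


def model_resizing_f(N: int) -> Tuple[List[float], List[int]]:
    """List-based stand-in for f(N) with allow_array_resizing=True."""
    a = [1.0] * N              # jnp.ones([N]); only used as the initial value
    sizes = []                 # length of the carried array after each iteration
    for i in range(0, 10, 1):  # for_loop(0, 10, 1)
        a = [1.0] * i          # body ignores its input and returns jnp.ones([i])
        sizes.append(len(a))
    return a, sizes


final, sizes = model_resizing_f(5)
print(len(final))  # 9
print(sizes)       # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```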

Note that when allow_array_resizing=True, dynamically-shaped arrays can no longer be captured from outer scopes by the for loop, and binary operations between arrays of the same shape are not supported.

For more details on dynamically-shaped arrays, please see the Dynamically-shaped arrays documentation page.