catalyst.while_loop
- while_loop(cond_fn, allow_array_resizing: bool = False)
A qjit()-compatible while-loop decorator for PennyLane/Catalyst.

This decorator provides a functional version of the traditional while loop, similar to jax.lax.while_loop. That is, any variables that are modified across iterations must be provided as inputs to, and returned as outputs from, the loop body function:
- Input arguments contain the value of a variable at the start of an iteration.
- Output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration.
- The final iteration values are also returned from the transformed function.
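Here is a minimal plain-Python sketch of this convention, with no Catalyst involved. An imperative loop that mutates `total` and `i` is rewritten so that both values are passed into the body and returned from it. The loop bound `n` is only read, so it is captured from the enclosing scope. The names `imperative_sum`, `functional_sum`, `cond`, and `body` are illustrative only.

```python
def imperative_sum(n):
    # Imperative style: `total` and `i` are mutated in place.
    total, i = 0, 0
    while i < n:
        total += i
        i += 1
    return total


def functional_sum(n):
    # Functional style: every variable the loop modifies is both an input
    # and an output of the body; `n` is only read, so it is captured from
    # the enclosing scope instead of being carried through the loop.
    def cond(total, i):
        return i < n

    def body(total, i):
        return total + i, i + 1  # values at the end of this iteration

    args = (0, 0)  # values at the start of the first iteration
    while cond(*args):
        args = body(*args)  # outputs become the next iteration's inputs
    total, _ = args  # the final iteration values are returned
    return total


print(imperative_sum(5), functional_sum(5))  # 10 10
```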
This form of control flow can also be called from the Python interpreter without needing to use qjit().

The semantics of while_loop are given by the following Python pseudo-code:

    def while_loop(cond_fun, body_fun, *args):
        while cond_fun(*args):
            args = body_fun(*args)
        return args
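The pseudo-code assumes the body always returns a tuple. If a body returns a single bare value, `*args` would fail on the next check. The sketch below is a runnable pure-Python stand-in that handles both cases and mirrors the decorator form `while_loop(cond_fn)(body)`. The helper name `py_while_loop` is hypothetical. It models only the iteration semantics and ignores everything qjit() adds, such as tracing, compilation, and the array-resizing rules.

```python
def py_while_loop(cond_fn):
    """Hypothetical pure-Python stand-in for the decorator form of while_loop."""

    def decorator(body_fn):
        def wrapper(*args):
            while cond_fn(*args):
                out = body_fn(*args)
                # A body with a single loop-carried value returns it bare,
                # so normalise to a tuple before unpacking it again.
                args = out if isinstance(out, tuple) else (out,)
            # Mirror the body's convention: one value out if one value in.
            return args[0] if len(args) == 1 else args

        return wrapper

    return decorator


@py_while_loop(lambda x: x < 2.0)
def square_until(x):
    return x ** 2


@py_while_loop(lambda a, i: i < 3)
def double_three_times(a, i):
    return a * 2, i + 1


print(round(square_until(1.6), 10))  # 2.56
print(double_three_times(1, 0))      # (8, 3)
```

If the condition is false on entry, the inputs are returned unchanged. This matches the pseudo-code, where the body never runs.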
- Parameters
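cond_fn (Callable) – the condition function of the while loop. It is called with the current loop-carried values before each iteration, and the loop continues while it evaluates to True.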
allow_array_resizing (bool) – whether to allow arrays to change shape/size within the loop. Defaults to False, which allows out-of-scope dynamically-shaped arrays to be captured by the loop and binary operations to be applied to arrays of the same shape. Set this to True to modify dimension sizes within the loop. However, outer-scope dynamically-shaped arrays will then no longer be captured, and arrays of the same shape cannot be used in binary operations.
- Returns
A wrapper around the while-loop function.
- Return type
Callable
- Raises
TypeError – Invalid return type of the condition expression.
Example
    import pennylane as qml
    from catalyst import qjit, while_loop

    dev = qml.device("lightning.qubit", wires=1)

    @qjit
    @qml.qnode(dev)
    def circuit(x: float):

        @while_loop(lambda x: x < 2.0)
        def loop_rx(x):
            # perform some work and update (some of) the arguments
            qml.RX(x, wires=0)
            return x ** 2

        # apply the while loop
        final_x = loop_rx(x)

        return qml.expval(qml.PauliZ(0)), final_x

    >>> circuit(1.6)
    (Array(-0.02919952, dtype=float64), Array(2.56, dtype=float64))
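As a sanity check on this output, the loop can be replayed in plain Python. This is a sketch, not Catalyst code, and `replay_loop_rx` is a hypothetical name. Starting from x = 1.6, the body runs once, applying RX(1.6) and returning 1.6 ** 2 = 2.56. That value fails the condition x < 2.0, so the loop stops. For a single qubit starting in |0>, RX(theta) gives <Z> = cos(theta), and RX rotations about the same axis compose by adding their angles. The expectation value is therefore cos(1.6), about -0.0292.

```python
import math


def replay_loop_rx(x):
    # Plain-Python replay of loop_rx: record each angle passed to qml.RX.
    angles = []
    while x < 2.0:
        angles.append(x)  # qml.RX(x, wires=0) uses the current value of x
        x = x ** 2
    return angles, x


angles, final_x = replay_loop_rx(1.6)
# RX rotations about the same axis add up, and <Z> after RX(theta)|0> is cos(theta).
expval_z = math.cos(sum(angles))
print(angles, round(final_x, 8), round(expval_z, 8))  # [1.6] 2.56 -0.02919952
```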
By default, allow_array_resizing is False, allowing dynamically-shaped arrays from outside the while loop to be correctly captured, and arrays of the same shape to be used in binary operations:

    >>> import jax.numpy as jnp
    >>> import catalyst
    >>> from catalyst import qjit
    >>> @qjit(abstracted_axes={0: 'n'})
    ... def g(x, y):
    ...     @catalyst.while_loop(lambda i: jnp.sum(i) > 2., allow_array_resizing=False)
    ...     def loop(a):
    ...         # Attempt to capture `x` from the outer scope,
    ...         # and apply a binary operation '*' between the two arrays.
    ...         return a * x
    ...     return loop(y)
    >>> x = jnp.array([0.1, 0.2, 0.3])
    >>> y = jnp.array([5.2, 10.3, 2.4])
    >>> g(x, y)
    Array([0.052, 0.412, 0.216], dtype=float64)
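The same computation can be replayed with plain Python lists to see where this result comes from. This is a sketch, and `replay_g` is a hypothetical name. The condition sum > 2 holds for y (sum 17.9) and for y * x (sum 3.3), but not for y * x * x (sum 0.68), so the body runs twice. The element-wise product only makes sense when `a` and the captured `x` have the same length, which is why the binary operation needs both arrays to have the same shape.

```python
def replay_g(x, y):
    # Plain-Python replay of g: x is read from the enclosing scope on every
    # iteration, and `a * x` becomes an element-wise product of two lists.
    a = list(y)
    while sum(a) > 2.0:
        if len(a) != len(x):
            raise ValueError("a * x requires arrays of the same shape")
        a = [ai * xi for ai, xi in zip(a, x)]  # element-wise a * x
    return a


result = replay_g([0.1, 0.2, 0.3], [5.2, 10.3, 2.4])
print([round(v, 6) for v in result])  # [0.052, 0.412, 0.216]
```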
However, if you wish to have the while loop return differently sized arrays at each iteration, set allow_array_resizing to True:

    >>> import jax.numpy as jnp
    >>> from catalyst import qjit, while_loop
    >>> @qjit
    ... def f(N):
    ...     a0 = jnp.ones([N])
    ...     b0 = jnp.ones([N])
    ...     @while_loop(lambda _a, _b, i: i < 3, allow_array_resizing=True)
    ...     def loop(a, _, i):
    ...         i += 1
    ...         b = jnp.ones([i + 1])
    ...         return (a, b, i)  # return array of new dimensions
    ...     return loop(a0, b0, 0)
    >>> f(2)
    (Array([1., 1.], dtype=float64), Array([1., 1., 1., 1.], dtype=float64), Array(3, dtype=int64))
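Tracking only the array lengths in plain Python shows why resizing is needed here. This is a sketch, and `replay_f_shapes` is a hypothetical helper. `a` keeps length N throughout, while `b` starts at length N and is rebuilt with lengths 2, 3, and finally 4 as `i` counts up to 3. The loop-carried `b` therefore has a different shape at the end of each iteration.

```python
def replay_f_shapes(n):
    # Plain-Python replay of f that tracks only array lengths.
    a_len, b_len, i = n, n, 0  # a0 and b0 both start with length N
    b_lens = [b_len]
    while i < 3:
        i += 1
        b_len = i + 1  # b = jnp.ones([i + 1]) is rebuilt with a new length
        b_lens.append(b_len)
    return a_len, b_lens, i  # a is returned unchanged


print(replay_f_shapes(2))  # (2, [2, 2, 3, 4], 3)
```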
Note that when allow_array_resizing=True, dynamically-shaped arrays can no longer be captured from outer scopes by the while loop, and binary operations between arrays of the same shape are not supported.

For more details on dynamically-shaped arrays, please see Dynamically-shaped arrays.