qml.while_loop
- while_loop(cond_fn, allow_array_resizing='auto')
A `qjit()`-compatible while loop for PennyLane programs. When used without `qjit()` or program capture, this function falls back to a standard Python while loop.

This decorator provides a functional version of the traditional while loop, similar to `jax.lax.while_loop`. That is, any variables that are modified across iterations need to be provided as inputs and outputs to the loop body function:
- Input arguments contain the value of a variable at the start of an iteration.
- Output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration.
The final iteration values are also returned from the transformed function.
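As a minimal illustration of this input/output convention, the sketch below rewrites an ordinary imperative loop (summing 0 to 3) as a condition function and a body function that receive the carried values `(i, total)` and return their updated values. It is plain Python: the hand-written driver `run` stands in for `while_loop`, and `cond_fn`, `body_fn`, and `run` are illustrative names, not PennyLane API.

```python
# Imperative version: i and total are mutated in place.
i, total = 0, 0
while i < 4:
    total += i
    i += 1
imperative_result = (i, total)


# Functional version: every variable modified across iterations is
# passed in as an argument and handed back as a return value.
def cond_fn(i, total):
    return i < 4


def body_fn(i, total):
    # Inputs: values at the start of the iteration.
    # Outputs: values at the end, fed back into the next iteration.
    return i + 1, total + i


def run(cond_fn, body_fn, *args):
    # Hypothetical driver mirroring the semantics of while_loop.
    while cond_fn(*args):
        args = body_fn(*args)
    return args


functional_result = run(cond_fn, body_fn, 0, 0)
print(imperative_result, functional_result)  # (4, 6) (4, 6)
```

Both versions compute the same final state; the functional one simply makes the loop state explicit, which is what allows it to be captured and compiled.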
The semantics of `while_loop` are given by the following Python pseudocode:

```python
def while_loop(cond_fn, body_fn, *args):
    while cond_fn(*args):
        args = body_fn(*args)
    return args
```
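In practice `while_loop` is used as a decorator: `qml.while_loop(cond_fn)` wraps the body function, and calling the wrapped function with initial values runs the loop. The plain-Python sketch below mimics that calling convention. `reference_while_loop` is a hypothetical name, and the handling of a single carried variable (the body returns a bare value rather than a 1-tuple, as `loop_rx` does in the example on this page) is an assumption drawn from the documented examples, not from PennyLane's implementation.

```python
from typing import Any, Callable


def reference_while_loop(cond_fn: Callable[..., bool]):
    """Hypothetical pure-Python stand-in for the decorator form of while_loop."""

    def decorator(body_fn: Callable[..., Any]):
        def wrapper(*args):
            single = len(args) == 1
            while cond_fn(*args):
                result = body_fn(*args)
                # One carried value: the body returns it bare, so re-wrap it.
                args = (result,) if single else tuple(result)
            # Unwrap again so a single carried value is returned bare.
            return args[0] if single else args

        return wrapper

    return decorator


@reference_while_loop(lambda x: x < 2.0)
def square_until(x):
    return x ** 2


@reference_while_loop(lambda i, x: i < 3)
def double_three_times(i, x):
    return i + 1, 2 * x


print(square_until(1.6))           # 1.6 ** 2, roughly 2.56
print(double_three_times(0, 1.0))  # (3, 8.0)
```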
- Parameters
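cond_fn (Callable) – the condition function in the while loop; it is called with the current loop-carried values and returns a boolean that decides whether another iteration runs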
allow_array_resizing (Literal["auto", True, False]) – How to handle arrays with dynamic shapes that change between iterations. Defaults to “auto”.
- Returns
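A decorator that wraps the loop body function; calling the wrapped function with the initial values executes the loop and returns the final values.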
- Return type
Callable
- Raises
CompileError – if the compiler is not installed
Example
```python
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x: float):

    @qml.while_loop(lambda x: x < 2.0)
    def loop_rx(x):
        # perform some work and update (some of) the arguments
        qml.RX(x, wires=0)
        return x ** 2

    # apply the while loop
    loop_rx(x)

    return qml.expval(qml.Z(0))
```

```python
>>> circuit(1.6)
-0.02919952
```
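As a sanity check on this output, the classical part of the loop can be traced by hand: starting from 1.6, the body applies `RX(1.6)` and returns 1.6² = 2.56, which fails `x < 2.0`, so exactly one rotation is applied. For a single qubit starting in |0⟩, consecutive `RX` rotations add their angles and ⟨Z⟩ = cos(θ_total), so the expected value is cos(1.6) ≈ -0.0292. The plain-Python sketch below reproduces this; `rx_angles`, `expected_z`, and the `max_iters` guard are hypothetical helpers, not PennyLane API. The guard matters because for |x| ≤ 1 repeated squaring never reaches 2.0, so `circuit` would not terminate for such inputs.

```python
import math


def rx_angles(x: float, max_iters: int = 1000) -> list:
    """Return the RX angles applied by loop_rx for a starting value x.

    Hypothetical helper: replays only the classical updates of the loop.
    """
    angles = []
    while x < 2.0:
        if len(angles) >= max_iters:
            raise RuntimeError(f"no termination after {max_iters} iterations")
        angles.append(x)  # qml.RX(x, wires=0) is applied with the current x
        x = x ** 2        # the body returns x ** 2 as the next carried value
    return angles


def expected_z(x: float) -> float:
    # RX rotations about the same axis compose by adding angles, and
    # <Z> = cos(theta) for RX(theta) applied to |0>.
    return math.cos(sum(rx_angles(x)))


print(rx_angles(1.6))   # [1.6]
print(expected_z(1.6))  # about -0.0292, matching circuit(1.6)
```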
`while_loop` is also `qjit()`-compatible; when used with the `qjit()` decorator, the while loop will not be unrolled, and instead will be captured as-is during compilation and executed during runtime:

```python
>>> qml.qjit(circuit)(1.6)
Array(-0.02919952, dtype=float64)
```
Usage Details
Note
The following examples may yield different outputs depending on how the workflow function is executed. For instance, the function can be run directly as:
```python
>>> arg = 2
>>> workflow(arg)
```
Alternatively, the function can be traced with `jax.make_jaxpr` to produce a JAXPR representation, which captures the abstract computational graph for the given input and generates the abstract shapes. The resulting JAXPR can then be evaluated using `qml.capture.eval_jaxpr`:

```python
>>> jaxpr = jax.make_jaxpr(workflow)(arg)
>>> qml.capture.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg)
```
A dynamically shaped array is an array whose shape depends on an abstract value. This is an experimental JAX mode that can be turned on as follows (the last line also enables PennyLane program capture):

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update("jax_dynamic_shapes", True)
>>> qml.capture.enable()
```
When working with dynamic shapes in a `while_loop`, there are two possible options, selected with the `allow_array_resizing` argument. The default, `allow_array_resizing="auto"`, will try to choose between these two modes; if the needed mode turns out to be `allow_array_resizing=False`, the loop has to be re-captured, potentially taking more time.

`allow_array_resizing=True` treats every dynamic dimension as independent.

```python
@qml.while_loop(lambda a, b: jnp.sum(a) < 10, allow_array_resizing=True)
def f(x, y):
    return jnp.hstack([x, y]), 2*y

def workflow(i0):
    x0, y0 = jnp.ones(i0), jnp.ones(i0)
    return f(x0, y0)
```
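To see concretely how the shapes drift apart in this example, the sketch below replays the loop body eagerly in plain Python, with lists standing in for `jnp` arrays (list concatenation plays the role of `jnp.hstack`). It is a hand-written simulation for an arbitrarily chosen `i0 = 3`, not a trace of what PennyLane captures.

```python
def trace_shapes(i0: int):
    # Lists stand in for jnp.ones(i0); only their lengths matter here.
    x, y = [1.0] * i0, [1.0] * i0
    shapes = [(len(x), len(y))]
    while sum(x) < 10:  # cond: jnp.sum(a) < 10
        # body: jnp.hstack([x, y]), 2*y (both computed from the old values)
        x, y = x + y, [2 * v for v in y]
        shapes.append((len(x), len(y)))
    return shapes


print(trace_shapes(3))  # [(3, 3), (6, 3), (9, 3)]
```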
Even though `x` and `y` are initialized with the same shape, the shapes no longer match after one iteration. In this circumstance, `x` and `y` can no longer be combined with operations like `x * y`, as they do not have matching shapes.

With `allow_array_resizing=False`, anything that starts with the same dynamic dimension must keep the same shape pattern throughout the loop.

```python
@qml.while_loop(lambda a, b: jnp.sum(a) < 10, allow_array_resizing=False)
def f(x, y):
    return x * y, 2*y

def workflow(i0):
    x0 = jnp.ones(i0)
    y0 = jnp.ones(i0)
    return f(x0, y0)
```
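An eager plain-Python replay of this body (lists standing in for arrays, the elementwise product written out by hand, `i0 = 3` chosen arbitrarily) shows that neither length ever changes, so `x` and `y` share one dynamic dimension for the whole loop while only their values evolve. This is an illustrative simulation, not PennyLane's capture.

```python
def trace_shapes(i0: int):
    # Lists stand in for jnp.ones(i0).
    x, y = [1.0] * i0, [1.0] * i0
    history = [(len(x), len(y), sum(x))]
    while sum(x) < 10:  # cond: jnp.sum(a) < 10
        # body: x * y (elementwise), 2*y
        x, y = [a * b for a, b in zip(x, y)], [2 * v for v in y]
        history.append((len(x), len(y), sum(x)))
    return history


print(trace_shapes(3))  # lengths stay (3, 3) while sum(x) goes 3, 3, 6, 24
```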
Note that with `allow_array_resizing=False`, all arrays can still be resized together, as long as the pattern still matches. For example, here both `x` and `y` start with the same shape and keep the same shape as each other in every iteration.

```python
@qml.while_loop(lambda a, b: jnp.sum(a) < 10, allow_array_resizing=False)
def f(x, y):
    x = jnp.hstack([x, y])
    return x, 2*x

def workflow(i0):
    x0 = jnp.ones(i0)
    y0 = jnp.ones(i0)
    return f(x0, y0)
```
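Replaying this body eagerly in plain Python (lists as stand-ins for arrays, `i0 = 3` chosen arbitrarily) makes the shared pattern visible: both lengths grow, but they grow in lockstep. As with the previous sketches, this is an illustration rather than PennyLane's actual capture.

```python
def trace_shapes(i0: int):
    # Lists stand in for jnp.ones(i0).
    x, y = [1.0] * i0, [1.0] * i0
    shapes = [(len(x), len(y))]
    while sum(x) < 10:  # cond: jnp.sum(a) < 10
        x = x + y                # body: x = jnp.hstack([x, y])
        y = [2 * v for v in x]   # return x, 2*x
        shapes.append((len(x), len(y)))
    return shapes


print(trace_shapes(3))  # [(3, 3), (6, 6), (12, 12)]
```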
Note that new dynamic dimensions cannot yet be created inside a loop. Only arrays that already have a dynamic dimension can have that dimension change. For example, the following is not a viable `while_loop`, as `x` is initialized with an array of concrete size:

```python
def w():
    @qml.while_loop(lambda i, x: i < 5)
    def f(i, x):
        return i + 1, jnp.append(x, i)

    return f(0, jnp.array([]))
```
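One possible workaround, assuming the maximum number of elements is known before the loop starts, is to carry a preallocated buffer of fixed length plus a counter and write into it by index instead of appending, so the shape never changes. In JAX the indexed write would be `x.at[i].set(v)`; the sketch below models the pattern in plain Python with an immutable tuple, so it illustrates the idea rather than being a tested PennyLane program. `fill_buffer` and `size` are hypothetical names.

```python
def fill_buffer(size: int = 5):
    def cond_fn(i, buf):
        return i < size

    def body_fn(i, buf):
        # Functional indexed write: build a new tuple of the same length,
        # analogous to buf.at[i].set(i) on a JAX array.
        new_buf = buf[:i] + (float(i),) + buf[i + 1:]
        return i + 1, new_buf

    # Preallocated, fixed-length carry instead of an empty, growing array.
    i, buf = 0, (0.0,) * size
    lengths = [len(buf)]
    while cond_fn(i, buf):
        i, buf = body_fn(i, buf)
        lengths.append(len(buf))
    return buf, lengths


buf, lengths = fill_buffer()
print(buf)      # (0.0, 1.0, 2.0, 3.0, 4.0)
print(lengths)  # [5, 5, 5, 5, 5, 5]
```

The trade-off is that the buffer must be sized for the worst case up front, and the counter `i` has to be carried alongside it to know how many entries are valid.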