Guide for AutoGraph for plxpr capture
When capturing PennyLane programs as a plxpr instance using AutoGraph, you can represent quantum programs with structure. That is, you can use classical control flow (such as conditionals and loops) with quantum operations and measurements, and this structure is captured and preserved in the plxpr representation.
PennyLane provides various high-level functions, such as cond(), for_loop(), and while_loop(), that work with native PennyLane quantum operations. However, it can sometimes take a bit of work to rewrite existing Python code using these specific control flow functions. AutoGraph is an experimental feature of PennyLane capture that allows PennyLane capture to work with native Python control flow, such as if statements and for loops.
Here, we’ll aim to provide an overview of AutoGraph, as well as various restrictions and constraints you may discover.
Note
When converting code in these examples, we will use the make_plxpr() function, which uses AutoGraph by default.
When creating the initial plxpr representation, we must call the constructor function produced by make_plxpr() with some initial values, which should have the same type and shape as the values we intend to use when evaluating:
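import pennylane as qml
from pennylane.capture import make_plxpr
qml.capture.enable()  # program capture must be enabled before calling make_plxpr()
def f(x):
    if x > 5:
        x = x ** 2
    return x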
>>> plxpr = make_plxpr(f)(0.0) # x will be a float
Once the plxpr representation is created, we can evaluate it using JAX’s eval_jaxpr function:
>>> from jax.core import eval_jaxpr
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 5.3) # evaluate f(5.3)
[Array(28.09, dtype=float64, weak_type=True)]
Using AutoGraph
The AutoGraph feature in PennyLane is supported by the diastatic-malt package, a standalone fork of the AutoGraph module in TensorFlow (official documentation).
The make_plxpr() function uses AutoGraph by default. Consider a function using Python control flow:
import jax.numpy as jnp
import pennylane as qml
dev = qml.device("default.qubit", wires=4)
@qml.qnode(dev)
def cost(weights, data):
    for w in dev.wires:
        qml.X(w)
    for x in weights:
        for j, p in enumerate(x):
            if p > 0:
                qml.RX(p, wires=j)
            elif p < 0:
                qml.RY(p, wires=j)
        for j in range(4):
            qml.CNOT(wires=[j, jnp.mod((j + 1), 4)])
    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))
While this function cannot be captured directly, because its control flow depends on the values of the function’s inputs (the inputs are treated as JAX tracers at capture time, which don’t have concrete values), it can be captured by converting it to native PennyLane syntax via AutoGraph. This is the default behaviour of make_plxpr().
>>> weights = jnp.linspace(-1, 1, 20).reshape([5, 4])
>>> data = jnp.ones([4])
>>> plxpr = make_plxpr(cost)(weights, data)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, weights, data)
[Array(-0.45165857, dtype=float64)]
This would be equivalent to writing the following program without AutoGraph, using cond() and for_loop() instead:
@qml.qnode(dev)
def cost(weights, data):
    @qml.for_loop(0, 4, 1)
    def initialize_loop(w):
        qml.X(w)
    @qml.for_loop(0, jnp.shape(weights)[0], 1)
    def layer_loop(i):
        x = weights[i]
        @qml.for_loop(0, 4, 1)
        def wire_loop(j):
            @qml.cond(x[j] > 0)
            def trainable_gate():
                qml.RX(x[j], wires=j)
            @trainable_gate.else_if(x[j] < 0)
            def trainable_gate():
                qml.RY(x[j], wires=j)
            trainable_gate()
        @qml.for_loop(0, 4, 1)
        def cnot_loop(j):
            qml.CNOT(wires=[j, jnp.mod((j + 1), 4)])
        wire_loop()
        cnot_loop()
    initialize_loop()
    layer_loop()
    return qml.expval(qml.PauliZ(0) + qml.PauliZ(3))
Once converted to native PennyLane control flow manually, AutoGraph is no longer needed:
>>> plxpr = make_plxpr(cost, autograph=False)(weights, data)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, weights, data)
[Array(-0.45165857, dtype=float64)]
Currently, AutoGraph supports converting the following Python statements:
- if statements (including elif and else)
- for loops
- while loops
break and continue statements are currently not supported. The logical operators and, or, and not are also currently unsupported when applied to dynamic variables (see Logical statements below).
Nested functions
AutoGraph will continue to work even when the function itself calls nested functions. All functions called within the top-level function will also have their Python control flow captured and converted by AutoGraph.
In addition, built-in functions from jax, pennylane, and catalyst are automatically excluded from the AutoGraph conversion.
def f(x):
    if x > 5:
        y = x ** 2
    else:
        y = x ** 3
    return y
def g(x, n):
    for i in range(n):
        x = x + f(x)
    return x
>>> plxpr = make_plxpr(g)(0.0, 1) # initialize with arguments of correct type and shape
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.4, 6)
[Array(22.14135448, dtype=float64)]
If statements
While most if statements you may write in Python will be automatically converted, there are some important constraints and restrictions to be aware of.
Return statements
Return statements are generally supported inside of if/elif/else statements; however, the returned values require a matching shape and structure across branches.
For example, consider the following pattern, where arrays of different shapes are returned from each branch:
def f(x):
    if x > 5:
        return jnp.array([1, 2])
    return jnp.array([0])
This will generate the following error:
>>> make_plxpr(f)(0)
ValueError: Mismatch in output abstract values in false branch #0 at position 1:
ShapedArray(int64[1]) vs ShapedArray(int64[2])
This applies whenever the branches produce outputs with different structure. The structure of a function’s output is defined by properties such as the number of results, the containers used (for example, lists or dictionaries), or more generally any (compile-time) PyTree metadata.
Different branches must assign the same type
Different branches of an if statement must always assign variables with the same type across branches if those variables are used in the outer scope (external variables). The type must be the same in the sense that the structure of the variable should not change across branches, and the dtypes must match.
Consider this function, in which the dtype of the elements of y differs between branches:
>>> def f(x):
...     if x > 1:
...         y = jnp.array([1.0, 2.0, 3.0])
...     else:
...         y = jnp.array([4, 5, 6])
...     return jnp.sum(y)
>>> make_plxpr(f)(0.5)
ValueError: Mismatch in output abstract values in false branch #0 at position 0: ShapedArray(int64[3]) vs ShapedArray(float64[3])
Instead, all possible outcomes for y at the end of the if/else block need to have the same shape, dtype, and structure:
>>> def f(x):
...     if x > 1:
...         y = jnp.array([1.0, 2.0, 3.0])
...     else:
...         y = jnp.array([4.0, 5.0, 6.0])
...     return jnp.sum(y)
>>> plxpr = make_plxpr(f)(0.5)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5)
[Array(15., dtype=float64)]
More generally, this also applies to common container classes such as dict, list, and tuple. If one branch assigns an external variable, then all other branches must also assign the external variable with the same type, nested structure, number of elements, element types, and array shapes.
Changing a variable type
We can change the type of an existing variable y, as long as we make sure to change it in all branches. This means we will need to include an else statement to also change the type; otherwise, capture fails:
>>> def f(x):
...     y = -1.0
...     if x > 5:
...         y = 4
...     return y
>>> plxpr = make_plxpr(f)(0.5)
ValueError: Mismatch in output abstract values in false branch #0 at position 0: ShapedArray(float64[], weak_type=True) vs ShapedArray(int64[], weak_type=True)
Even if we want to keep the value in the else branch, we need to update it to the new data type:
>>> def f(x):
...     y = -1.0
...     if x > 5:
...         y = 4
...     else:
...         y = -1
...     return y
>>> plxpr = make_plxpr(f)(0.5)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5)
[Array(-1, dtype=int64)]
Compatible type assignments
Within an if statement, variable assignments must use JAX-compatible types (Booleans, Python numeric types, JAX arrays, and PennyLane quantum operators). Non-compatible types (such as strings) used after the if statement will result in an error when the function is captured:
>>> def f(x):
...     if x > 5:
...         y = "a"
...     else:
...         y = "b"
...     return y
>>> plxpr = make_plxpr(f)(0.5)
TypeError: Value 'a' with type <class 'str'> is not a valid JAX type
For loops
Most for loop constructs will be properly captured and converted by AutoGraph.
dev = qml.device("default.qubit", wires=1)
@qml.qnode(dev)
def f():
    for x in jnp.array([0, 1, 2]):
        qml.RY(x * jnp.pi / 4, wires=0)
    return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)()
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts)
[Array(-0.70710678, dtype=float64)]
This includes automatic unpacking and enumeration through JAX arrays:
>>> def f(weights):
...     z = 0.
...     for i, (x, y) in enumerate(weights):
...         z = i * x + i ** 2 * y
...     return z
>>> weights = jnp.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).T
>>> plxpr = make_plxpr(f)(weights)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, weights)
[Array(8.4, dtype=float64)]
The Python range function is also supported by AutoGraph, even when its input is a dynamic variable (i.e., its numeric value is only known at runtime):
>>> def f(n):
...     x = -jnp.log(n)
...     for k in range(1, n + 1):
...         x = x + 1 / k
...     return x
>>> plxpr = make_plxpr(f)(0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 1000)
[Array(0.57771558, dtype=float64, weak_type=True)]
Indexing within a loop
Indexing arrays within a for loop will generally work, but care must be taken.
For example, using a for loop with static bounds to index a JAX array is straightforward:
>>> dev = qml.device("default.qubit", wires=3)
>>> @qml.qnode(dev)
... def f(x):
...     for i in range(3):
...         qml.RX(x[i], wires=i)
...     return qml.expval(qml.PauliZ(0))
>>> weights = jnp.array([0.1, 0.2, 0.3])
>>> plxpr = make_plxpr(f)(weights)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, weights)
[Array(0.99500417, dtype=float64)]
However, indexing within a for loop with AutoGraph requires that the object being indexed is a JAX array or a dynamic runtime variable. If the array you are indexing within the for loop is not a JAX array or dynamic variable, an error will be raised:
>>> @qml.qnode(dev)
... def f():
...     x = [0.1, 0.2, 0.3]
...     for i in range(3):
...         qml.RX(x[i], wires=i)
...     return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)()
AutoGraphError: Tracing of an AutoGraph converted for loop failed with an exception:
TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
The error occurred while tracing the function functional_for [...]
To allow AutoGraph conversion to work in this case, simply convert the list to a JAX array:
>>> @qml.qnode(dev)
... def f():
...     x = jnp.array([0.1, 0.2, 0.3])
...     for i in range(3):
...         qml.RX(x[i], wires=i)
...     return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)()
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts)
[Array(0.99500417, dtype=float64)]
If the object you are indexing cannot be converted to a JAX array, it is not possible for AutoGraph to capture this for loop.
If you are updating elements of the array, this must be done using the JAX .at and .set syntax:
>>> def f():
...     my_list = jnp.empty(2, dtype=int)
...     for i in range(2):
...         my_list = my_list.at[i].set(i)  # not my_list[i] = i
...     return my_list
>>> plxpr = make_plxpr(f)()
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts)
[Array([0, 1], dtype=int64)]
Dynamic indexing
Indexing into arrays where the for loop has dynamic bounds (that is, where the size of the loop is set by a dynamic runtime variable) will also work, as long as the object indexed is a JAX array:
>>> @qml.qnode(dev)
... def f(n):
...     x = jnp.array([0.0, 1 / 4 * jnp.pi, 2 / 4 * jnp.pi])
...     for i in range(n):
...         qml.RY(x[i], wires=0)
...     return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)(0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 2)
[Array(0.70710678, dtype=float64)]
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 3)
[Array(-0.70710678, dtype=float64)]
However, AutoGraph conversion will fail if the object being indexed by a loop with dynamic bounds is not a JAX array, because standard Python objects cannot be indexed with dynamic variables. Ensure that all objects that are indexed within dynamic for loops are JAX arrays.
Break and continue
Within a for loop, the control flow statements break and continue are not currently supported.
Updating and assigning variables
for loops that update variables can also be converted with AutoGraph:
>>> def f(x):
...     for y in [0, 4, 5]:
...         x = x + y
...     return x
>>> plxpr = make_plxpr(f)(0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 3)
[Array(12, dtype=int64)]
However, as with conditionals, a similar restriction applies: variables that are updated across iterations of the loop must have a JAX-compatible type (Booleans, Python numeric types, and JAX arrays).
You can also use temporary variables within a for loop:
>>> def f(x):
...     for y in [0, 4, 5]:
...         c = 2
...         x = x + y * c
...     return x
>>> plxpr = make_plxpr(f)(0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 4)
[Array(22, dtype=int64)]
Temporary variables used inside a loop — and that are not passed to a function within the loop — do not have any type restrictions.
While loops
Most while loop constructs will be properly captured and converted by AutoGraph:
>>> def f(param):
...     n = 0.
...     while param < 0.5:
...         param *= 1.2
...         n += 1
...     return n
>>> plxpr = make_plxpr(f)(0.0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.1)
[Array(9., dtype=float64, weak_type=True)]
Indexing within a loop
Indexing arrays within a while loop will generally work, but care must be taken.
For example, using a while loop variable to index a JAX array is straightforward:
>>> dev = qml.device("default.qubit", wires=3)
>>> @qml.qnode(dev)
... def f(x):
...     i = 0
...     while i < 3:
...         qml.RX(x[i], wires=i)
...         i += 1
...     return qml.expval(qml.PauliZ(0))
>>> weights = jnp.array([0.1, 0.2, 0.3])
>>> plxpr = make_plxpr(f)(weights)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, weights)
[Array(0.99500417, dtype=float64)]
However, indexing within a while loop with AutoGraph requires that the object being indexed is a JAX array:
>>> @qml.qnode(dev)
... def f():
...     x = [0.1, 0.2, 0.3]
...     i = 0
...     while i < 3:
...         qml.RX(x[i], wires=i)
...         i += 1
...     return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)()
TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
The error occurred while tracing the function functional_while at [...]
To allow AutoGraph conversion to work in this case, simply convert the list to a JAX array:
>>> @qml.qnode(dev)
... def f():
...     x = jnp.array([0.1, 0.2, 0.3])
...     i = 0
...     while i < 3:
...         qml.RX(x[i], wires=i)
...         i += 1
...     return qml.expval(qml.PauliZ(0))
>>> plxpr = make_plxpr(f)()
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts)
[Array(0.99500417, dtype=float64)]
If the object you are indexing cannot be converted to a JAX array, it is not possible for AutoGraph to capture this while loop.
If you are updating elements of the array, this must be done using the JAX .at and .set syntax:
>>> def f():
...     my_list = jnp.empty(2, dtype=int)
...     i = 0
...     while i < 2:
...         my_list = my_list.at[i].set(i)  # not my_list[i] = i
...         i += 1
...     return my_list
>>> plxpr = make_plxpr(f)()
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts)
[Array([0, 1], dtype=int64)]
Break and continue
Within a while loop, the control flow statements break and continue are not currently supported.
Updating and assigning variables
As with for loops, while loops that update variables can also be converted with AutoGraph:
>>> def f(x):
...     while x < 5:
...         x = x + 2
...     return x
>>> plxpr = make_plxpr(f)(0.0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 4.4)
[Array(6.4, dtype=float64, weak_type=True)]
However, as with conditionals, a similar restriction applies: variables that are updated across iterations of the loop must have a JAX-compatible type (Booleans, Python numeric types, and JAX arrays).
You can also use temporary variables within a while loop:
>>> def f(x):
...     while x < 5:
...         c = "hi"
...         x = x + 2 * len(c)
...     return x
>>> plxpr = make_plxpr(f)(0.0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 4.4)
[Array(8.4, dtype=float64, weak_type=True)]
Temporary variables used inside a loop—and that are not passed to a function within the loop—do not have any type restrictions.
A caveat regarding updating variables in a while loop is that it is not possible to update variables inside the loop test statement. For example, while the following works in standard Python:
>>> def fn(limit):
...     i = 0
...     y = 0
...     while (i := y) < limit:
...         y += 1
...     return i
>>> fn(10)
10
any updates to the variables inside the while test (in this case, (i := y)) will be ignored by AutoGraph:
>>> plxpr = make_plxpr(fn)(0)
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 10)
[0]
Logical statements
AutoGraph in PennyLane currently does not support capturing logical statements that involve dynamic variables, that is, statements involving and, not, and or that return Booleans.
Debugging
One useful tool in debugging issues is to examine the plxpr representation of the captured function, in order to verify that AutoGraph is correctly capturing the control flow. For example, consider:
def f(x, n):
    for i in range(n):
        if x > 5:
            y = x ** 2
        else:
            y = x ** 3
        x = x + y
    return x
We can verify that the control flow is being correctly captured and converted by examining the plxpr representation of the compiled program:
>>> make_plxpr(f)(0.0, 0)
{ lambda ; a:f64[] b:i64[]. let
    c:f64[] = for_loop[
      args_slice=slice(0, None, None)
      consts_slice=slice(0, 0, None)
      jaxpr_body_fn={ lambda ; d:i64[] e:f64[]. let
          f:bool[] = gt e 5.0
          g:f64[] = cond[
            args_slice=slice(4, None, None)
            consts_slices=[slice(2, 3, None), slice(3, 4, None)]
            jaxpr_branches=[{ lambda a:f64[]; . let b:f64[] = integer_pow[y=2] a in (b,) }, { lambda a:f64[]; . let b:f64[] = integer_pow[y=3] a in (b,) }]
          ] f True e e
          h:f64[] = add e g
        in (h,) }
    ] 0 b 1 a
  in (c,) }
Here, we can see the cond operation inside the for loop, and the two branches of the if statement represented by the jaxpr_branches list.
In addition, the function autograph_source() is provided, and allows you to view the converted Python code generated by AutoGraph:
>>> def f(n):
...     x = - jnp.log(n)
...     for k in range(1, n + 1):
...         x = x + 1 / k
...     return x
>>> plxpr = make_plxpr(f)(0)
>>> print(qml.capture.autograph.autograph_source(f))
def ag__f(n):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=ag__.Feature.BUILTIN_FUNCTIONS, internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        x = -ag__.converted_call(ag__.ld(jnp).log, (ag__.ld(n),), None, fscope)
        def get_state():
            return (x,)
        def set_state(vars_):
            nonlocal x
            x, = vars_
        def loop_body(itr):
            nonlocal x
            k = itr
            x = ag__.ld(x) + 1 / ag__.ld(k)
        k = ag__.Undefined('k')
        ag__.for_stmt(ag__.converted_call(ag__.ld(range), (1, ag__.ld(n) + 1), None, fscope), None, loop_body, get_state, set_state, ('x',), {'iterate_names': 'k'})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)
Warning
Nested functions are only lazily converted by AutoGraph. If the input includes nested functions, these won’t be converted until the first time the function is traced. This is important to be aware of when examining the output of AutoGraph for debugging purposes. In an example like:
def f(x):
    if x > 5:
        y = x ** 2
    else:
        y = x ** 3
    return y
def g(x, n):
    for i in range(n):
        x = x + f(x)
    return x
ag_fn = make_plxpr(g)
we can access autograph_source(g), but we will get an error for autograph_source(f):
>>> autograph_source(f)
AutoGraphError: The given function was not converted by AutoGraph. If you expect the given function to be converted, please submit a bug report.
This is because f has only been lazily converted. To examine the inner function’s AutoGraph conversion, we must call the function returned by make_plxpr() with values at least once, so that it is traced:
>>> plxpr = ag_fn(0, 0)
>>> autograph_source(f)
def ag__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=False, optional_features=ag__.Feature.BUILTIN_FUNCTIONS, internal_convert_user_code=True)) as fscope:
        ...
Native Python control flow without AutoGraph
It’s important to note that native Python control flow, in cases where the control flow parameters are static, will continue to work with PennyLane without AutoGraph. However, if AutoGraph is not enabled, such control flow will be evaluated at capture time and not preserved in the captured program.
Let’s consider an example where a for loop is evaluated at capture time:
>>> def f(x):
...     for i in range(2):
...         print(i, x)
...         x = x / 2
...     return x ** 2
>>> plxpr = make_plxpr(f, autograph=False)(0.0)
0 Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
1 Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
>>> plxpr
{ lambda ; a:f64[]. let
    b:f64[] = div a 2.0
    c:f64[] = div b 2.0
    d:f64[] = integer_pow[y=2] c
  in (d,) }
Here, the loop is evaluated at capture time rather than at runtime. Notice the multiple tracers that have been printed out during program capture, one for each loop iteration, as well as the unrolling of the loop in the resulting plxpr.
With AutoGraph, we instead get a single print of the tracers, and the captured program contains a for loop that can be evaluated at runtime:
>>> plxpr = make_plxpr(f, autograph=True)(0.0)
Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)> Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
>>> plxpr
{ lambda ; a:f64[]. let
    b:f64[] = for_loop[
      args_slice=slice(0, None, None)
      consts_slice=slice(0, 0, None)
      jaxpr_body_fn={ lambda ; c:i64[] d:f64[]. let
          e:f64[] = div d 2.0
        in (e,) }
    ] 0 2 1 a
    f:f64[] = integer_pow[y=2] b
  in (f,) }
In-place JAX array updates
To update array values when using JAX, the JAX syntax for array assignment (which uses the array at and set methods) must be used. Since JAX arrays are immutable, these methods return a new array, which must be reassigned:
def f(x):
    first_dim = x.shape[0]
    result = jnp.empty((first_dim,), dtype=x.dtype)
    for i in range(first_dim):
        result = result.at[i].set(x[i] * 2)
    return result
>>> plxpr = make_plxpr(f)(jnp.zeros(3))
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, jnp.array([0.1, 0.2, 0.3]))
[Array([0.2, 0.4, 0.6], dtype=float64)]
Similarly, to update array values with an operation when using JAX, the JAX syntax for array update (which uses the array at method together with add, multiply, etc.) must be used:
>>> def f(x):
...     first_dim = x.shape[0]
...     result = jnp.copy(x)
...
...     for i in range(first_dim):
...         result = result.at[i].multiply(2)
...
...     return result
>>> plxpr = make_plxpr(f)(jnp.zeros(3))
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, jnp.array([0.1, 0.2, 0.3]))
[Array([0.2, 0.4, 0.6], dtype=float64)]
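Outside of any captured function, the same .at methods can be tried in plain JAX. The snippet below illustrates that JAX arrays are immutable: each .at[...] call returns a new array and leaves the original untouched, which is why the loops above reassign result on every iteration.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Each .at[...] call returns a new array; x itself is never modified.
doubled = x.at[1].multiply(2)   # [1.0, 4.0, 3.0]
shifted = x.at[0].add(10)       # [11.0, 2.0, 3.0]
replaced = x.at[2].set(-1.0)    # [1.0, 2.0, -1.0]
```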