

Decorator that converts the given function into graph form.

AutoGraph can be used in PennyLane’s capture workflow to convert Pythonic control flow to PennyLane-native control flow. This requires the diastatic-malt package, a standalone fork of the AutoGraph module in TensorFlow (official documentation).


fn (Callable) – The callable to be converted. This could be a function, a QNode, or another callable object. For a QNode, QNode.func is converted. For any other callable object, a function that calls the object is converted.


For a function, the converted function is returned directly. For a QNode, a copy of the QNode is returned, with QNode.func replaced by the converted version of func. For any other callable obj, the returned function is a converted version of lambda *args, **kwargs: obj(*args, **kwargs).
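The three return cases can be pictured with the pure-Python sketch below. It is not PennyLane's implementation: `convert` is a hypothetical stand-in for the AutoGraph transformation (it only tags the function), `FakeQNode` is a hypothetical class that mimics just the `func` attribute of a QNode, and using `inspect.isfunction` to recognise plain functions is an assumption made for illustration.

```python
import copy
import inspect
from typing import Any, Callable


def convert(fn: Callable) -> Callable:
    """Hypothetical stand-in for AutoGraph: wraps fn and tags the result."""
    def converted(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    converted.is_converted = True
    return converted


class FakeQNode:
    """Hypothetical QNode-like object; only ``func`` matters for this sketch."""

    def __init__(self, func: Callable):
        self.func = func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


def run_autograph_sketch(fn: Callable) -> Callable:
    # QNode case: return a copy whose ``func`` is the converted function.
    if isinstance(fn, FakeQNode):
        new_qnode = copy.copy(fn)  # the original QNode is left untouched
        new_qnode.func = convert(fn.func)
        return new_qnode
    # Plain function case: convert it and return it directly.
    if inspect.isfunction(fn):
        return convert(fn)
    # Any other callable: convert a lambda that forwards to the object.
    return convert(lambda *args, **kwargs: fn(*args, **kwargs))
```

With this sketch, `run_autograph_sketch(FakeQNode(g))` returns a new `FakeQNode` whose `func` is converted while the original keeps `g`, which matches the "copy of the QNode" behaviour described above.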

AutoGraph has some limitations and sharp bits; to better understand its supported behaviour and limitations, see the Guide for AutoGraph for plxpr capture.


Nested functions are only lazily converted by AutoGraph. If the input includes nested functions, these won’t be converted until the first time the function is traced.
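To picture what lazy conversion means, the pure-Python sketch below defers a conversion step until the first call and then caches the result. `lazy_convert` and the `log` list are hypothetical names used only for illustration; AutoGraph's actual mechanism works by source transformation and is not reproduced here, and the "conversion" in the sketch is a no-op placeholder.

```python
from typing import Any, Callable, List, Optional


def lazy_convert(fn: Callable, log: List[str]) -> Callable:
    """Hypothetical wrapper: convert ``fn`` only when it is first called."""
    converted: Optional[Callable] = None

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal converted
        if converted is None:
            # Conversion happens here, on first use, not when wrapper is created.
            log.append(f"converting {fn.__name__}")
            converted = fn  # placeholder: a real converter would rewrite fn
        return converted(*args, **kwargs)

    return wrapper


log: List[str] = []


def inner(x):
    return 2 * x


wrapped = lazy_convert(inner, log)
before_call = list(log)  # nothing has been converted yet
results = [wrapped(3), wrapped(4)]  # first call triggers conversion, second reuses it
```

After the two calls, `log` holds a single `"converting inner"` entry, illustrating that the nested function is converted once, at first use.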


Consider the following function including Pythonic control flow, which can’t be captured directly:

>>> def f(x, n):
...     for i in range(n):
...         x += 1
...     return x
>>> jax.make_jaxpr(f)(2, 4)
TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
The error occurred while tracing the function f at /var/folders/61/wr1fxnf95tg9k56bz1_7g29r0000gq/T/ipykernel_23187/3992882129.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument n.
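The failure comes from `range(n)`: Python's `range` calls `n.__index__()` to obtain a concrete integer, and a traced value has none. The sketch below reproduces this in plain Python with a hypothetical `AbstractInt` class standing in for a JAX tracer; it is not JAX's tracer type, and it raises a plain `TypeError` rather than `TracerIntegerConversionError`.

```python
class AbstractInt:
    """Hypothetical stand-in for a traced integer with no concrete value."""

    def __index__(self) -> int:
        raise TypeError("no concrete value available while tracing")


def f(x, n):
    for i in range(n):  # range() calls n.__index__(), which fails for AbstractInt
        x += 1
    return x


try:
    f(2, AbstractInt())
    failed = False
except TypeError:
    failed = True
```

With a concrete integer, `f(2, 4)` runs normally; only the abstract stand-in triggers the error, mirroring why tracing `n` breaks the Python loop.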

Passing f through AutoGraph converts its control flow to native PennyLane control flow with cond(), for_loop(), and while_loop(), making it possible to capture:

>>> ag_fn = run_autograph(f)
>>> jax.make_jaxpr(ag_fn)(2, 4)
{ lambda ; a:i64[] b:i64[]. let
    c:i64[] = for_loop[
      args_slice=slice(0, None, None)
      consts_slice=slice(0, 0, None)
      jaxpr_body_fn={ lambda ; d:i64[] e:i64[]. let f:i64[] = add e 1 in (f,) }
    ] 0 b 1 a
  in (c,) }
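Reading the captured program, the `for_loop` primitive appears to receive the lower bound `0`, the upper bound `b` (that is, `n`), the step `1`, and the initial loop-carried value `a` (that is, `x`); its body takes the iteration index `d` and the carried value `e` and returns `e + 1`. This operand interpretation is an assumption based on the printed jaxpr. The pure-Python sketch below mirrors that functional form; `for_loop_sketch` is a hypothetical helper written for illustration, not PennyLane's `for_loop()`.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def for_loop_sketch(start: int, stop: int, step: int) -> Callable:
    """Hypothetical functional loop: body(i, carry) returns the new carry."""

    def decorator(body: Callable[[int, T], T]) -> Callable[[T], T]:
        def run(init: T) -> T:
            carry = init
            for i in range(start, stop, step):
                carry = body(i, carry)  # the body sees the index and the carried value
            return carry

        return run

    return decorator


def f_converted(x: int, n: int) -> int:
    # Same computation as f, written in the functional style AutoGraph targets.
    @for_loop_sketch(0, n, 1)
    def loop_body(i: int, x: int) -> int:
        return x + 1

    return loop_body(x)
```

Because the bounds, step, and initial state are explicit operands rather than Python-level iteration, a structure like this lets `n` stay a traced operand (`b` in the jaxpr above) instead of needing a concrete value.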


