qml.capture.run_autograph
- run_autograph(fn)
Decorator that converts the given function into graph form.
AutoGraph can be used in PennyLane’s capture workflow to convert Pythonic control flow to PennyLane native control flow. This requires the diastatic-malt package, a standalone fork of the AutoGraph module in TensorFlow (official documentation).
- Parameters
fn (Callable) – The callable to be converted. This could be a function, a QNode, or another callable object. For a QNode, the QNode.func will be converted. For another callable object, a function calling the object will be converted.
- Returns
For a function, the converted function is returned directly. For a QNode, a copy of the QNode will be returned with QNode.func replaced with the converted version of func. For any other callable obj, the returned function will be a converted version of lambda *args, **kwargs: obj(*args, **kwargs).
- Return type
Callable
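The three return cases described above can be sketched in plain Python. This is only an illustration of the dispatch rules, not PennyLane's implementation: `convert`, `FakeQNode`, `Adder`, and `run_autograph_sketch` are hypothetical names introduced here, and `convert` merely tags a function instead of rewriting its control flow.

```python
import copy
import inspect


def convert(fn):
    """Hypothetical stand-in for the real AutoGraph conversion step."""
    def converted(*args, **kwargs):
        return fn(*args, **kwargs)
    converted.is_converted = True  # marker used only by this sketch
    return converted


class FakeQNode:
    """Hypothetical minimal stand-in for a QNode: a callable holding ``func``."""
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def run_autograph_sketch(fn):
    if isinstance(fn, FakeQNode):
        # QNode case: return a copy whose ``func`` is the converted function.
        new_qnode = copy.copy(fn)
        new_qnode.func = convert(fn.func)
        return new_qnode
    if inspect.isfunction(fn):
        # Plain function: the converted function is returned directly.
        return convert(fn)
    # Any other callable: convert a function that calls the object.
    return convert(lambda *args, **kwargs: fn(*args, **kwargs))


def add_one(x):
    return x + 1


class Adder:
    def __call__(self, x):
        return x + 1


out_fn = run_autograph_sketch(add_one)
qnode = FakeQNode(add_one)
out_qnode = run_autograph_sketch(qnode)
out_obj = run_autograph_sketch(Adder())
```

In this sketch the original `qnode` is left untouched, mirroring the statement that a copy of the QNode is returned.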
Note
There are some limitations and sharp bits regarding AutoGraph; to better understand supported behaviour and limitations, see Guide for AutoGraph for plxpr capture.
Warning
Nested functions are only lazily converted by AutoGraph. If the input includes nested functions, these won’t be converted until the first time the function is traced.
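As a rough mental model of lazy conversion, the following plain-Python sketch defers a conversion step until the wrapped function is first called. It does not reproduce AutoGraph's actual mechanism: `convert`, `LazilyConverted`, and `inner` are hypothetical names, and `convert` only records that it ran.

```python
def convert(fn):
    """Hypothetical stand-in for a source-to-source conversion step."""
    convert.calls.append(fn.__name__)
    return fn


convert.calls = []


class LazilyConverted:
    """Wraps a function and converts it only on the first call."""
    def __init__(self, fn):
        self._fn = fn
        self._converted = None

    def __call__(self, *args, **kwargs):
        if self._converted is None:
            # Conversion happens here, not when the wrapper is created.
            self._converted = convert(self._fn)
        return self._converted(*args, **kwargs)


def inner(x):
    return x + 1


lazy_inner = LazilyConverted(inner)
calls_before = list(convert.calls)  # nothing has been converted yet
value = lazy_inner(1)               # first call triggers the conversion
calls_after = list(convert.calls)
```

Under this model, any problem in converting a nested function would surface at the first trace rather than when the decorator is applied.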
Example
Consider the following function including Pythonic control flow, which can’t be captured directly:
>>> import jax
>>> def f(x, n):
...     for i in range(n):
...         x += 1
...     return x
>>> jax.make_jaxpr(f)(2, 4)
TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
The error occurred while tracing the function f at /var/folders/61/wr1fxnf95tg9k56bz1_7g29r0000gq/T/ipykernel_23187/3992882129.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument n.
Passing it through AutoGraph converts the structure of the function to native PennyLane control flow with cond(), for_loop(), and while_loop(), making it possible to capture:
>>> from pennylane.capture import run_autograph
>>> ag_fn = run_autograph(f)
>>> jax.make_jaxpr(ag_fn)(2, 4)
{ lambda ; a:i64[] b:i64[]. let
    c:i64[] = for_loop[
      args_slice=slice(0, None, None)
      consts_slice=slice(0, 0, None)
      jaxpr_body_fn={ lambda ; d:i64[] e:i64[]. let f:i64[] = add e 1 in (f,) }
    ] 0 b 1 a
  in (c,) }
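For comparison, the same loop can be written by hand with a structured loop primitive, which is the kind of rewrite AutoGraph automates. The sketch below uses `jax.lax.fori_loop` rather than PennyLane's `for_loop()`, so it shows the general idea only and not the program that `run_autograph` produces; `f_structured` and `body` are names introduced here.

```python
import jax
from jax import lax


def f_structured(x, n):
    # The loop body receives the loop index and the carried value.
    def body(i, acc):
        return acc + 1

    # The trip count ``n`` may be a traced value because the loop is
    # expressed as a primitive instead of a Python ``for`` statement.
    return lax.fori_loop(0, n, body, x)


# Tracing succeeds even though ``n`` is not a concrete Python integer.
jaxpr = jax.make_jaxpr(f_structured)(2, 4)
result = jax.jit(f_structured)(2, 4)
```

Here `result` holds the value 6, matching what the original Python loop computes for `f(2, 4)`.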