qml.capture.determine_abstracted_axes

determine_abstracted_axes(args)

Compute the abstracted axes and extract the abstract shapes from the arguments.

Parameters

args (tuple) – the arguments for a higher order primitive

Returns

the abstracted axes for the arguments, and the abstract (dynamic) shapes that are not already present among the arguments

Return type

tuple, tuple

Note that “dynamic shapes” refers only to the sizes of dimensions, not to the number of dimensions. Even with dynamic shapes mode enabled, the number of dimensions cannot change.

See the intro_to_dynamic_shapes.md document for more information on how dynamic shapes work.

To make a jaxpr from arguments with dynamic shapes, the abstracted_axes keyword argument of jax.make_jaxpr must be set. Then, when evaluating the jaxpr, the variables for the dynamic shapes must be passed in front of the other arguments.

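import jax
import pennylane as qml

jax.config.update("jax_dynamic_shapes", True)

def f(n):
    # n is traced when f is traced, so x has a dynamic leading dimension
    x = jax.numpy.ones((n,))
    abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
    jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
    # the dynamic shape variables are passed before the array arguments
    return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)

print(jax.make_jaxpr(f)(3))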
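When the jaxpr is evaluated, the concrete values for the dynamic sizes go in front of the arrays. Assuming that an integer label k in abstracted_axes refers to abstract_shapes[k], the hypothetical helper dynamic_sizes_for_call below (not part of PennyLane) sketches how those values could be read off concrete shapes. It assumes the integer labels are consecutive starting at 0, and it skips string labels such as "0_arg", which refer to sizes already present among the arguments.

```python
def dynamic_sizes_for_call(abstracted_axes, shapes):
    """Hypothetical helper: concrete sizes to pass in front of the arrays.

    abstracted_axes: one dict per argument, mapping axis index -> label
    shapes: the concrete shape tuple of each argument
    """
    sizes = {}
    for axes, shape in zip(abstracted_axes, shapes):
        for axis, label in axes.items():
            if isinstance(label, str):
                continue  # e.g. "0_arg": the size is already an argument
            size = shape[axis]
            # Every axis sharing a label must have the same concrete size
            if sizes.setdefault(label, size) != size:
                raise ValueError(f"inconsistent sizes for dynamic dimension {label}")
    # Order the sizes by label (assumed: label k corresponds to abstract_shapes[k])
    return [sizes[k] for k in range(len(sizes))]


axes = ({0: 0}, {0: 0, 1: 1})
shapes = ((3,), (3, 5))
print(dynamic_sizes_for_call(axes, shapes))  # [3, 5]
```

The returned values would then be used in the same position as abstract_shapes in the call above, i.e. eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *sizes, *arrays).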

If the shape of an argument matches a previous argument, as in:

>>> def f(i, x):
...     return x
>>> def workflow(i):
...     args = (i, jax.numpy.ones((i, )))
...     abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args)
...     print("abstracted_axes: ", abstracted_axes)
...     print("abstract_shapes: ", abstract_shapes)
...     print("jaxpr: ", jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args))
>>> _ = jax.make_jaxpr(workflow)(2)
abstracted_axes:  ({}, {0: '0_arg'})
abstract_shapes:  []
jaxpr:  { lambda ; a:i32[] b:f32[a]. let  in (b,) }

We allow JAX to identify that the shape of b matches the first argument, a. This is why abstract_shapes is empty: the size is already present in the call signature, so no additional abstract shape needs to be passed. The abstracted axis is also labelled "0_arg" instead of 0. The "_arg" suffix indicates that the corresponding size is already one of the arguments rather than an entry in abstract_shapes.
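The bookkeeping described above can be modelled in plain Python. The sketch below is a toy illustration, not PennyLane's implementation: Dim is a hypothetical stand-in for a traced dimension size, scalar integer arguments are represented directly as Dim objects, array arguments are represented only by their shape tuples, and pytree flattening is ignored. Sizes are compared by identity (is), on the assumption that a dynamic size is identified by the tracer object that carries it rather than by its value.

```python
class Dim:
    """Hypothetical stand-in for a traced (dynamic) dimension size."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"Dim({self.name!r})"


def _find(items, target):
    """Return the index of the first item that *is* target, else None."""
    for idx, item in enumerate(items):
        if item is target:
            return idx
    return None


def sketch_abstracted_axes(args):
    """Toy model: each arg is a Dim (scalar int argument) or a shape tuple."""
    previous_ints = []    # dynamic sizes passed directly as arguments
    abstract_shapes = []  # dynamic sizes that must be passed separately
    abstracted_axes = []
    for arg in args:
        if isinstance(arg, Dim):
            # A scalar integer argument: nothing to abstract, but remember it
            previous_ints.append(arg)
            abstracted_axes.append({})
            continue
        axes = {}
        for axis, size in enumerate(arg):
            if isinstance(size, int):
                continue  # static size, left concrete
            arg_idx = _find(previous_ints, size)
            if arg_idx is not None:
                axes[axis] = f"{arg_idx}_arg"  # size is already an argument
                continue
            shape_idx = _find(abstract_shapes, size)
            if shape_idx is None:
                # First time this size is seen: give it a new integer label
                shape_idx = len(abstract_shapes)
                abstract_shapes.append(size)
            axes[axis] = shape_idx
        abstracted_axes.append(axes)
    return tuple(abstracted_axes), abstract_shapes


i = Dim("i")
print(sketch_abstracted_axes((i, (i,))))
# (({}, {0: '0_arg'}), [])

n = Dim("n")
print(sketch_abstracted_axes(((n, 3), (2, n))))
# (({0: 0}, {1: 0}), [Dim('n')])
```

The first call reproduces the abstracted_axes and abstract_shapes printed in the example above. The second shows a size that is not an argument being collected once into abstract_shapes and shared by both arrays through the label 0, while the static sizes 3 and 2 are left untouched.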
