qml.capture.determine_abstracted_axes

determine_abstracted_axes(args)

Compute the abstracted axes and extract the abstract shapes from the arguments.

Parameters

args (tuple) – the arguments for a higher order primitive

Returns

the corresponding abstracted axes and dynamic shapes

Return type

tuple, tuple
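The two return values can be illustrated with a minimal pure-Python sketch of the idea. This is not PennyLane's implementation. It assumes a hypothetical DynamicDim class as a stand-in for a dimension size that is only known at trace time, and it assumes letters as axis names. Each argument gets one {axis_index: name} dictionary, which is assumed to resemble the per-argument form accepted by the abstracted_axes argument of jax.make_jaxpr. The distinct dynamic dimensions are collected in order of first appearance. The real function operates on JAX arrays and pytrees rather than on plain shape tuples.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union


class DynamicDim:
    """Hypothetical stand-in for a dimension size known only at trace time.

    Instances compare and hash by identity, so two arrays share a dynamic
    dimension only if they hold the *same* DynamicDim object.
    """

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"DynamicDim({self.name!r})"


Dim = Union[int, DynamicDim]


def sketch_abstracted_axes(
    shapes: Sequence[Tuple[Dim, ...]],
) -> Tuple[Optional[Tuple[Dict[int, str], ...]], Tuple[DynamicDim, ...]]:
    """Hypothetical sketch: give each distinct dynamic dimension a letter.

    Returns one {axis_index: letter} dict per argument, plus the distinct
    dynamic dimensions in order of first appearance.
    """
    letters = "abcdefghijklmnopqrstuvwxyz"  # sketch assumes at most 26 distinct dims
    dim_to_letter: Dict[DynamicDim, str] = {}
    abstract_shapes: List[DynamicDim] = []
    abstracted_axes: List[Dict[int, str]] = []

    for shape in shapes:
        axes: Dict[int, str] = {}
        for axis, dim in enumerate(shape):
            if isinstance(dim, DynamicDim):
                if dim not in dim_to_letter:
                    # first time this dimension is seen: assign the next letter
                    dim_to_letter[dim] = letters[len(abstract_shapes)]
                    abstract_shapes.append(dim)
                axes[axis] = dim_to_letter[dim]
        abstracted_axes.append(axes)

    if not abstract_shapes:
        # nothing is dynamic; None is make_jaxpr's default for abstracted_axes
        return None, ()
    return tuple(abstracted_axes), tuple(abstract_shapes)


n, m = DynamicDim("n"), DynamicDim("m")
axes, dims = sketch_abstracted_axes([(n,), (n, m), (3,)])
print(axes)  # ({0: 'a'}, {0: 'a', 1: 'b'}, {})
print(dims)  # (DynamicDim('n'), DynamicDim('m'))
```

Because the same DynamicDim object n appears in the first two shapes, both of its axes receive the letter 'a', and n is listed only once among the abstract shapes.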

Note that “dynamic shapes” refers only to the sizes of dimensions, not to the number of dimensions. Even with dynamic shapes mode enabled, the number of dimensions cannot change.

See the intro_to_dynamic_shapes.md document for more information on how dynamic shapes work.

To make a jaxpr from arguments with dynamic shapes, the abstracted_axes keyword argument of jax.make_jaxpr must be set. Then, when evaluating the jaxpr, the variables for the dynamic shapes must be passed before the regular arguments.

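import jax
import pennylane as qml

# dynamic shapes are an experimental JAX feature and must be enabled explicitly
jax.config.update("jax_dynamic_shapes", True)

def f(n):
    x = jax.numpy.ones((n,))
    # which axes of x are dynamic, and the sizes of those axes
    abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
    jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
    # the dynamic sizes are passed before the array arguments
    return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)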
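The calling convention described above (size variables first, then the arrays) can be sketched in plain Python. The helper hypothetical_size_args below is an illustration only and is not part of PennyLane or JAX. Given per-argument {axis: name} dictionaries and the concrete shapes of the arrays, it collects one size per distinct axis name, in order of first appearance. Two assumptions are built in: that this is the order in which the jaxpr expects its size inputs, and, as with JAX's abstracted_axes, that axes sharing a name must have the same size. It works on shape tuples rather than arrays so that it runs without JAX.

```python
from typing import Dict, Sequence, Tuple


def hypothetical_size_args(
    abstracted_axes: Sequence[Dict[int, str]],
    shapes: Sequence[Tuple[int, ...]],
) -> Tuple[int, ...]:
    """Hypothetical helper: read the dynamic sizes off concrete shapes.

    Sizes are returned in order of first appearance of each axis name;
    axes that share a name must agree on their size.
    """
    sizes: Dict[str, int] = {}  # dicts preserve insertion order
    for axes, shape in zip(abstracted_axes, shapes):
        for axis, name in axes.items():
            size = shape[axis]
            if name not in sizes:
                sizes[name] = size
            elif sizes[name] != size:
                raise ValueError(f"axis {name!r} has sizes {sizes[name]} and {size}")
    return tuple(sizes.values())


# the size variables go first, then the arrays themselves (represented by their shapes)
axes = ({0: "a"}, {0: "a", 1: "b"})
shapes = [(4,), (4, 2)]
call_args = (*hypothetical_size_args(axes, shapes), *shapes)
print(call_args)  # (4, 2, (4,), (4, 2))
```

Checking that shared names agree on their size mirrors the rule that a single size variable stands for every axis tagged with that name.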
