qml.capture.determine_abstracted_axes
- determine_abstracted_axes(args)
Compute the abstracted axes and extract the abstract shapes from the arguments.
- Parameters
args (tuple) – the arguments for a higher order primitive
- Returns
the corresponding abstracted axes and dynamic shapes
- Return type
tuple, tuple
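The idea behind the two returned tuples can be sketched in plain Python. This is a simplified illustration, not PennyLane's implementation. It assumes the following, all hypothetical: a dimension counts as dynamic whenever its size is not an int; each distinct dynamic dimension object gets a letter, and the same object seen again reuses that letter; arguments are a flat tuple rather than a general pytree; and an input with no dynamic axes yields (None, ()). DynDim and FakeArray are made-up stand-ins for JAX tracers and arrays.

```python
from dataclasses import dataclass
from string import ascii_letters


@dataclass(eq=False)  # compare by identity, like two distinct tracers
class DynDim:
    """Hypothetical stand-in for a traced (non-integer) dimension size."""
    name: str


@dataclass
class FakeArray:
    """Hypothetical stand-in for an array; only ``shape`` is inspected."""
    shape: tuple


def sketch_abstracted_axes(args):
    """Sketch: label each dynamic axis with a letter and collect the
    unique dynamic sizes in first-seen order."""
    shapes = []      # unique dynamic dimension objects
    abstracted = []  # one {axis_index: letter} dict per argument
    for arg in args:
        axes = {}
        for i, size in enumerate(getattr(arg, "shape", ())):
            if isinstance(size, int):
                continue  # static size: nothing to abstract
            for j, seen in enumerate(shapes):
                if size is seen:  # same dynamic size seen before: reuse letter
                    axes[i] = ascii_letters[j]
                    break
            else:  # new dynamic size: assign the next letter
                axes[i] = ascii_letters[len(shapes)]
                shapes.append(size)
        abstracted.append(axes)
    if not any(abstracted):
        return None, ()  # assumption: nothing dynamic, nothing to abstract
    return tuple(abstracted), tuple(shapes)


n, m = DynDim("n"), DynDim("m")
axes, dyn = sketch_abstracted_axes((FakeArray((n, 3)), FakeArray((n, m)), 2.0))
print(axes)  # ({0: 'a'}, {0: 'a', 1: 'b'}, {})
```

Reusing a letter for the same size object is what tells the tracer that two axes must always have equal length; the actual sizes (here n and m) are what the caller later passes alongside the arrays.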
Note that “dynamic shapes” refers only to the sizes of dimensions, not to the number of dimensions. Even with dynamic shapes mode enabled, the number of dimensions cannot change.
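The rule in this note can be stated as a small check. The helper below, fits_traced_shape, is hypothetical and is not part of PennyLane or JAX. It assumes an abstracted-axes dict of the form {axis_index: letter} and reports whether an input of a new shape could reuse a program traced for the original shape. Sizes may change only on abstracted axes, axes that share a letter must share a size, and the number of dimensions must match.

```python
def fits_traced_shape(traced_shape, abstracted, new_shape):
    """Hypothetical check: can ``new_shape`` reuse a program traced for
    ``traced_shape`` whose dynamic axes are given by ``abstracted``?"""
    if len(new_shape) != len(traced_shape):
        return False  # the number of dimensions is never dynamic
    sizes_by_letter = {}
    for i, (old, new) in enumerate(zip(traced_shape, new_shape)):
        if i in abstracted:
            # axes labelled with the same letter must have the same size
            if sizes_by_letter.setdefault(abstracted[i], new) != new:
                return False
        elif new != old:
            return False  # static axes must keep their traced size
    return True


print(fits_traced_shape((3, 4), {0: "a"}, (5, 4)))     # True: only axis 0 changed
print(fits_traced_shape((3, 4), {0: "a"}, (3, 4, 1)))  # False: rank changed
```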
See the intro_to_dynamic_shapes.md document for more information on how dynamic shapes work.

To make a jaxpr from arguments with dynamic shapes, the abstracted_axes keyword argument must be set. Then, when calling the jaxpr, variables for the dynamic shapes must be passed:

```python
import jax
import pennylane as qml

# Dynamic shapes are an experimental JAX feature and must be enabled explicitly.
jax.config.update("jax_dynamic_shapes", True)

def f(n):
    x = jax.numpy.ones((n,))
    # Determine which axes of x are dynamic and collect their sizes.
    abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
    jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
    # The dynamic sizes are passed ahead of the array arguments.
    return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)
```