qml.capture.determine_abstracted_axes
- determine_abstracted_axes(args)
Compute the abstracted axes and extract the abstract shapes from the arguments.
- Parameters
args (tuple) – the arguments for a higher-order primitive
- Returns
the corresponding abstracted axes and dynamic shapes
- Return type
tuple, tuple
Note that “dynamic shapes” refers only to the sizes of dimensions, not to the number of dimensions. Even with dynamic shapes mode enabled, the number of dimensions cannot change.
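One way to picture this note is to treat an abstract type as the shape with every dynamic axis replaced by a placeholder. The helper below is a hypothetical illustration (abstract_type is not a PennyLane or JAX function): the rank stays fixed while the sizes of dynamic axes are free to vary.

```python
def abstract_type(shape, dynamic_axes):
    """Hypothetical illustration: replace the sizes of dynamic axes with "?".

    The number of entries (the rank) is kept unchanged, mirroring the note that
    dynamic shapes cover dimension sizes but not the number of dimensions.
    """
    return tuple("?" if i in dynamic_axes else size for i, size in enumerate(shape))


# Vectors of different lengths share one abstract type when axis 0 is dynamic...
print(abstract_type((3,), {0}) == abstract_type((7,), {0}))  # True
# ...but a matrix never matches a vector, whatever its sizes.
print(abstract_type((3, 4), {0}) == abstract_type((3,), {0}))  # False
```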
See the intro_to_dynamic_shapes.md document for more information on how dynamic shapes work.

To make a jaxpr from arguments with dynamic shapes, the abstracted_axes keyword argument must be set. Then, when calling the jaxpr, the variables for the dynamic shapes must be passed before the regular arguments:

```python
import jax
import pennylane as qml

jax.config.update("jax_dynamic_shapes", True)

def f(n):
    x = jax.numpy.ones((n,))
    # Label the dynamic axes of x and collect the sizes that are not already arguments.
    abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
    jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
    # The dynamic sizes are passed to the jaxpr ahead of the array itself.
    return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)
```
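As a rough, library-free illustration of that calling convention, the sketch below stands in for the jaxpr with an ordinary Python function whose first parameter is the dynamic size. The names sum_standin and call_with_dynamic_shapes are hypothetical; only the argument order (dynamic sizes first, then the original arguments) is taken from the eval_jaxpr call in the example.

```python
def sum_standin(n, x):
    """Hypothetical stand-in for the traced jnp.sum: the dynamic size `n`
    is an explicit leading input, followed by the array `x` of length `n`."""
    if len(x) != n:
        raise ValueError(f"expected an array of length {n}, got {len(x)}")
    return sum(x)


def call_with_dynamic_shapes(fn, abstract_shapes, args):
    """Mirror the argument order used with eval_jaxpr:
    dynamic sizes first, then the original arguments."""
    return fn(*abstract_shapes, *args)


x = [1.0, 1.0, 1.0]
print(call_with_dynamic_shapes(sum_standin, [len(x)], (x,)))  # 3.0
```

Forgetting the leading size (or passing one that disagrees with the array) fails, which is the plain-Python analogue of omitting abstract_shapes from the jaxpr call.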
Consider a case where the shape of one argument matches a previous argument:

```python
>>> import jax
>>> import pennylane as qml
>>> jax.config.update("jax_dynamic_shapes", True)
>>> def f(i, x):
...     return x
>>> def workflow(i):
...     args = (i, jax.numpy.ones((i, )))
...     abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args)
...     print("abstracted_axes: ", abstracted_axes)
...     print("abstract_shapes: ", abstract_shapes)
...     print("jaxpr: ", jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args))
>>> _ = jax.make_jaxpr(workflow)(2)
abstracted_axes:  ({}, {0: '0_arg'})
abstract_shapes:  []
jaxpr:  { lambda ; a:i32[] b:f32[a]. let in (b,) }
```
Here JAX is allowed to identify that the shape of `b` matches the first argument, `a`. This is why `abstract_shapes` is empty: the dimension size is already present in the call signature, so no additional variable needs to be passed. The abstracted axis is also labeled "0_arg" instead of 0. The "_arg" suffix indicates that the corresponding abstract axis was already present among the arguments.
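To make these labeling rules concrete, here is a minimal pure-Python sketch. It is not PennyLane's implementation: Dim, FakeArray and abstracted_axes_model are hypothetical names, dynamic sizes are modeled as placeholder objects compared by identity (standing in for JAX tracers), and the number in the "<index>_arg" label is assumed to be the position of the matching argument, which is consistent with the "0_arg" output above.

```python
from dataclasses import dataclass


class Dim:
    """Hypothetical placeholder for a dynamic (traced) dimension size."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"Dim({self.name!r})"


@dataclass
class FakeArray:
    """Hypothetical array stand-in that only carries a shape."""

    shape: tuple


def abstracted_axes_model(args):
    """Sketch of the labeling rules (assumed, not PennyLane's code).

    Returns one dict per argument mapping axis index -> label, plus the list
    of dynamic sizes that are not already among the arguments.
    """
    abstract_shapes = []  # new dynamic sizes, in first-seen order
    previous_args = []  # arguments seen so far
    axes = []
    for arg in args:
        arg_axes = {}
        for i, size in enumerate(getattr(arg, "shape", ())):
            if isinstance(size, int):
                continue  # static dimension: nothing to abstract
            # 1) The size is itself an earlier argument: label it "<index>_arg".
            match = next((k for k, p in enumerate(previous_args) if p is size), None)
            if match is not None:
                arg_axes[i] = f"{match}_arg"
                continue
            # 2) The size was already collected: reuse its integer label.
            match = next((k for k, s in enumerate(abstract_shapes) if s is size), None)
            if match is not None:
                arg_axes[i] = match
                continue
            # 3) A new dynamic size: give it the next integer label and record it.
            arg_axes[i] = len(abstract_shapes)
            abstract_shapes.append(size)
        axes.append(arg_axes)
        previous_args.append(arg)
    return tuple(axes), abstract_shapes


n = Dim("n")
# The size of the array is the first argument itself (the doctest case).
print(abstracted_axes_model((n, FakeArray((n,)))))  # (({}, {0: '0_arg'}), [])
# The size only appears inside the array's shape (the first example's case).
print(abstracted_axes_model((FakeArray((n,)),)))  # (({0: 0},), [Dim('n')])
```

Comparing by identity rather than by value matters here: two different dynamic sizes could hold the same concrete value at runtime, but only the same traced object is guaranteed to be the same dimension.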