qml.capture.register_custom_staging_rule

register_custom_staging_rule(primitive, get_outvars_from_params)

Register a custom staging rule for a primitive, so that the output variables created when the primitive is staged match the variables returned by get_outvars_from_params.

Parameters
  • primitive (jax.core.Primitive) – the JAX primitive to register a custom staging rule for.

  • get_outvars_from_params (Callable[[dict], list[jax.core.Var]]) – a function that takes the equation's params and returns the list of jax.core.Var objects that the primitive's outputs should mimic.

For example, cond_prim registers its custom staging rule as follows:

register_custom_staging_rule(cond_prim, lambda params: params['jaxpr_branches'][0].outvars)

The outputs of any cond_prim equation will then match the output variables of the first jaxpr branch.
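To see what a get_outvars_from_params callable actually receives and returns, the sketch below traces two small branch functions with jax.make_jaxpr and feeds their jaxprs to a callable equivalent to the cond_prim example above. It is a minimal illustration under stated assumptions: the params dict is hypothetical and only mimics the "jaxpr_branches" key used in the example, the branch functions are made up, and the sketch does not call register_custom_staging_rule itself or reproduce PennyLane's internals.

```python
import jax
import jax.numpy as jnp


def branch_true(x):
    # Hypothetical "true" branch of a conditional.
    return x * 2.0


def branch_false(x):
    # Hypothetical "false" branch of a conditional.
    return x - 1.0


x = jnp.ones(3)

# Trace each branch; ClosedJaxpr.jaxpr is the underlying Jaxpr, whose
# .outvars attribute is the list of output variables.
branches = [jax.make_jaxpr(f)(x).jaxpr for f in (branch_true, branch_false)]

# Hypothetical params dict, shaped after the cond_prim example.
params = {"jaxpr_branches": branches}


def get_outvars_from_params(params):
    # Mirrors the cond_prim rule: mimic the outputs of the first branch.
    return params["jaxpr_branches"][0].outvars


outvars = get_outvars_from_params(params)
# Each output variable carries an abstract value (shape and dtype).
print([v.aval for v in outvars])
```

The callable only inspects params and returns existing variables; the staging rule then uses them as the template for the primitive's outputs.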
