qml.capture.register_custom_staging_rule
register_custom_staging_rule(primitive, get_outvars_from_params)
Register a custom staging rule for a primitive, where the output should match the variables retrieved by get_outvars_from_params.
Parameters
primitive (jax.core.Primitive) – the JAX primitive to register a custom staging rule for
get_outvars_from_params (Callable[[dict], list[jax.core.Var]]) – a function that takes the equation's params and returns the jax.core.Var objects that the primitive's outputs need to mimic
For example, cond_prim requests its custom staging rule like this:
register_custom_staging_rule(cond_prim, lambda params: params['jaxpr_branches'][0].outvars)
The outputs of any cond_prim will then match the output variables of the first jaxpr branch.
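To make the role of get_outvars_from_params more concrete, the following is a minimal sketch using only JAX. It does not call register_custom_staging_rule itself. The branch functions and the hand-built params dictionary are hypothetical stand-ins for what an equation's params might contain; only the lambda is taken from the example above.

```python
import jax
import jax.numpy as jnp


# Hypothetical branch functions; both return outputs with the same shapes.
def true_branch(x):
    return x + 1.0, jnp.sum(x)


def false_branch(x):
    return x * 2.0, jnp.sum(x)


x = jnp.ones(3)

# Trace each branch to a jaxpr (make_jaxpr returns a ClosedJaxpr; .jaxpr unwraps it).
branches = [jax.make_jaxpr(f)(x).jaxpr for f in (true_branch, false_branch)]

# Assumed layout of the equation params, mirroring the 'jaxpr_branches' key above.
params = {"jaxpr_branches": branches}

# The same callable that the cond_prim example passes to register_custom_staging_rule.
get_outvars_from_params = lambda params: params["jaxpr_branches"][0].outvars

# The variables whose abstract values the primitive's outputs would mimic.
outvars = get_outvars_from_params(params)
shapes = [v.aval.shape for v in outvars]
print(shapes)
```

The callable only inspects the params, so it can be tried out on hand-built jaxprs like these before it is registered for a real primitive.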