catalyst.pure_callback
- pure_callback(callback_fn, result_type=None)
Execute and return the results of a functionally pure Python function from within a qjit-compiled function.
The callback function will be quantum just-in-time compiled alongside the rest of the workflow; however, it will be executed at runtime by the Python virtual machine. This is in contrast to functions that are directly qjit-compiled by Catalyst, which are executed at runtime as machine-native code.
Note
Callbacks do not automatically support differentiation. To use them within functions that are being differentiated, please define their vector-Jacobian product (see below for more details).
- Parameters
callback_fn (callable) –
The pure function to be used as a callback. Any Python-based function is supported, as long as it:
is a pure function (meaning it is deterministic — for the same function arguments, the same result is always returned — and has no side effects, such as modifying a non-local variable),
has a signature that can be inspected (that is, it is not a NumPy ufunc or Python builtin),
has a return type and shape that are deterministic and known ahead of time.
result_type (type) – The type returned by the function.
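The requirements on callback_fn can be illustrated with a minimal sketch that uses only the Python standard library. The helper has_inspectable_signature and the functions scale and scale_and_log are hypothetical names introduced for illustration; the sketch assumes that "a signature that can be inspected" corresponds to inspect.signature succeeding, which is an assumption and not a statement about Catalyst's internals.

```python
import inspect


def has_inspectable_signature(fn):
    """Hypothetical helper: True if inspect.signature can introspect fn."""
    try:
        inspect.signature(fn)
    except (TypeError, ValueError):
        # Some builtins and C-implemented callables may end up here.
        return False
    return True


def scale(x):
    # Pure: the result depends only on the argument, with no side effects.
    return 2.0 * x


_calls = []


def scale_and_log(x):
    # Impure: mutates a non-local list, so repeated calls with the
    # same argument return different values.
    _calls.append(x)
    return 2.0 * x + len(_calls)
```

A function like scale would be a reasonable callback candidate, whereas scale_and_log violates both the determinism and the no-side-effects conditions.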
Example
pure_callback can be used as a decorator. In this case, we must specify the result type via a type hint:

    import jax
    import jax.numpy as jnp
    import numpy as np

    import catalyst
    from catalyst import qjit

    @catalyst.pure_callback
    def callback_fn(x) -> float:
        # here we call non-JAX compatible code, such
        # as standard NumPy
        return np.sin(x)

    @qjit
    def fn(x):
        return jnp.cos(callback_fn(x ** 2))

>>> fn(0.654)
Array(0.9151995, dtype=float64)
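To make the role of the type hint concrete, the following standard-library sketch shows one way a decorator could read a function's return annotation. The helper result_type_of is hypothetical and is not part of Catalyst; the sketch only assumes that the annotation is available through inspect.signature.

```python
import inspect


def result_type_of(fn):
    """Hypothetical helper: return fn's return annotation, or None if absent."""
    annotation = inspect.signature(fn).return_annotation
    if annotation is inspect.Signature.empty:
        return None
    return annotation


def annotated(x) -> float:
    return x * 2.0


def unannotated(x):
    return x * 2.0


print(result_type_of(annotated))
print(result_type_of(unannotated))
```

Without the annotation there is nothing for a decorator to read, which is why the decorator form needs the type hint while the functional form takes the result type as an explicit argument.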
It can also be used functionally:
>>> @qjit
... def add_one(x):
...     return catalyst.pure_callback(lambda x: x + 1, int)(x)
>>> add_one(2)
Array(3, dtype=int64)
For callback functions that return arrays, a jax.ShapeDtypeStruct object can be created to specify the expected return shape and data type:

    @qjit
    def fn(x):
        x = jnp.cos(x)
        result_shape = jax.ShapeDtypeStruct(x.shape, jnp.complex128)

        @catalyst.pure_callback
        def callback_fn(y) -> result_shape:
            return jax.jit(jnp.fft.fft)(y)

        x = callback_fn(x)
        return x

>>> fn(jnp.array([0.1, 0.2]))
Array([1.97507074+0.j, 0.01493759+0.j], dtype=complex128)
Differentiating callbacks with custom VJP rules
Pure callbacks must have custom gradients manually registered with the Catalyst compiler in order to support differentiation.
This can be done via the pure_callback.fwd and pure_callback.bwd methods, to specify how the forwards and backwards pass (the vector-Jacobian product) of the callback should be computed:

    @catalyst.pure_callback
    def callback_fn(x) -> float:
        return np.sin(x[0]) * x[1]

    @callback_fn.fwd
    def callback_fn_fwd(x):
        # returns the evaluated function as well as residual
        # values that may be useful for the backwards pass
        return callback_fn(x), x

    @callback_fn.bwd
    def callback_fn_vjp(res, dy):
        # Accepts residuals from the forward pass, as well
        # as (one or more) cotangent vectors dy, and returns
        # a tuple of VJPs corresponding to each input parameter.

        def vjp(x, dy) -> (jax.ShapeDtypeStruct((2,), jnp.float64),):
            return (np.array([np.cos(x[0]) * dy * x[1], np.sin(x[0]) * dy]),)

        # The VJP function can also be a pure callback
        return catalyst.pure_callback(vjp)(res, dy)
>>> import pennylane as qml
>>> @qml.qjit
... @catalyst.grad
... def f(x):
...     y = jnp.array([jnp.cos(x[0]), x[1]])
...     return jnp.sin(callback_fn(y))
>>> f(jnp.array([0.1, 0.2]))
Array([-0.01071923, 0.82698717], dtype=float64)
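As a sanity check on the VJP rule above, the following sketch re-implements the same functions with the standard-library math module (no Catalyst or JAX involved) and compares the chain-rule gradient against central finite differences. The names grad_f and finite_difference_grad are hypothetical helpers introduced for illustration only.

```python
import math


def callback_fn(x):
    # Same function as the callback above, written with the math module.
    return math.sin(x[0]) * x[1]


def callback_fn_vjp(x, dy):
    # Same rule as the registered backward pass: dy times the gradient.
    return [math.cos(x[0]) * dy * x[1], math.sin(x[0]) * dy]


def f(x):
    y = [math.cos(x[0]), x[1]]
    return math.sin(callback_fn(y))


def grad_f(x):
    y = [math.cos(x[0]), x[1]]
    dy = math.cos(callback_fn(y))    # derivative of the outer sin
    vjp_y = callback_fn_vjp(y, dy)   # cotangent with respect to y
    # Pull the cotangent back through y = [cos(x[0]), x[1]].
    return [-math.sin(x[0]) * vjp_y[0], vjp_y[1]]


def finite_difference_grad(fn, x, eps=1e-6):
    # Central differences, one coordinate at a time.
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad


x = [0.1, 0.2]
analytic = grad_f(x)
numeric = finite_difference_grad(f, x)
print(analytic)
print(numeric)
```

Both lists should agree with each other, and with the gradient shown in the doctest above, to within rounding. A check of this kind is a cheap way to catch sign or ordering mistakes before registering a hand-written VJP.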