pure_callback
pure_callback(callback_fn, result_type=None)
Execute and return the results of a functionally pure Python function from within a qjit-compiled function.
The call to the callback is quantum just-in-time compiled alongside the rest of the workflow, but the callback itself is executed at runtime by the Python virtual machine. This is in contrast to functions that are qjit-compiled directly by Catalyst, which are executed at runtime as machine-native code.
Note
Callbacks do not currently support differentiation, and cannot be used inside functions to which catalyst.grad() is applied.
Parameters
callback_fn (callable) – The pure function to be used as a callback. Any Python-based function is supported, as long as it:
is a pure function (meaning it is deterministic, so the same arguments always produce the same result, and it has no side effects, such as modifying a non-local variable),
has a signature that can be inspected (that is, it is not a NumPy ufunc or Python builtin), and
has a return type and shape that are deterministic and known ahead of time.
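result_type (type) – The type returned by the function, or a jax.ShapeDtypeStruct for array results. If omitted, the result type must be given as a return type hint on callback_fn.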
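Both the "inspectable signature" requirement and the optional result_type relate to Python's standard inspect.signature. The sketch below uses a hypothetical helper, resolve_result_type, written only for illustration (it is not part of Catalyst); it assumes, as the decorator example suggests, that an omitted result_type is recovered from the callback's return type hint.

```python
import inspect


def resolve_result_type(callback_fn, result_type=None):
    """Hypothetical helper (not Catalyst API): choose a callback's result type.

    An explicit result_type wins; otherwise fall back to the return type
    hint, which requires the callback to have an inspectable signature.
    """
    if result_type is not None:
        return result_type
    try:
        signature = inspect.signature(callback_fn)
    except (TypeError, ValueError) as err:
        # Some builtins and C-implemented callables expose no signature.
        raise TypeError(f"cannot inspect the signature of {callback_fn!r}") from err
    if signature.return_annotation is inspect.Signature.empty:
        raise TypeError("no result_type given and no return type hint found")
    return signature.return_annotation


def annotated(x) -> float:
    return x * 2.0


print(resolve_result_type(annotated))             # <class 'float'>
print(resolve_result_type(lambda x: x + 1, int))  # <class 'int'>
```

A lambda has an inspectable signature but no return annotation, which is why the functional form passes the type explicitly.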
Example
pure_callback can be used as a decorator. In this case, the result type must be specified via a return type hint:

```python
import catalyst
import jax.numpy as jnp
import numpy as np
from catalyst import qjit

@catalyst.pure_callback
def callback_fn(x) -> float:
    # Here we call non-JAX-compatible code,
    # such as standard NumPy.
    return np.sin(x)

@qjit
def fn(x):
    return jnp.cos(callback_fn(x ** 2))
```
>>> fn(0.654)
array(0.9151995)
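When testing an example like this, it can help to evaluate the same expression eagerly in plain NumPy, outside of qjit. The snippet below is only such a reference check (it does not use Catalyst); it should agree with the value printed above to the displayed precision.

```python
import numpy as np

# Same computation as fn(0.654): cos(sin(x ** 2)), evaluated eagerly in NumPy.
x = 0.654
reference = np.cos(np.sin(x ** 2))
print(reference)
```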
It can also be used functionally:
>>> @qjit
... def add_one(x):
...     return catalyst.pure_callback(lambda x: x + 1, int)(x)
>>> add_one(2)
array(3)
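For callback functions that return arrays, a jax.ShapeDtypeStruct object can be created to specify the expected return shape and data type:

```python
import catalyst
import jax
import jax.numpy as jnp
from catalyst import qjit

@qjit
def fn(x):
    x = jnp.cos(x)
    # The callback returns a complex array with the same shape as x.
    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.complex128)

    @catalyst.pure_callback
    def callback_fn(y) -> result_shape:
        # Executed by Python at runtime; calls JAX's FFT.
        return jax.jit(jnp.fft.fft)(y)

    x = callback_fn(x)
    return x
```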
>>> fn(jnp.array([0.1, 0.2]))
array([1.97507074+0.j, 0.01493759+0.j])
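The declared jax.ShapeDtypeStruct has to agree with what the callback actually returns. As an illustration outside of qjit, using NumPy's np.fft.fft as a stand-in for the jitted JAX FFT, the check below shows that an n-point FFT of a real float64 array returns n complex128 values, which is why the example reuses x.shape together with jnp.complex128.

```python
import numpy as np

# Input to the callback after the jnp.cos step in fn, evaluated eagerly.
y = np.cos(np.array([0.1, 0.2]))
out = np.fft.fft(y)

# A 1-D FFT preserves the input length and returns complex values,
# matching ShapeDtypeStruct(x.shape, jnp.complex128) in the example.
print(out.shape, out.dtype)  # (2,) complex128
print(out)
```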