pure_callback

pure_callback(callback_fn, result_type=None)

Execute and return the results of a functionally pure Python function from within a qjit-compiled function.

The callback function is quantum just-in-time compiled alongside the rest of the workflow; however, it is executed at runtime by the Python virtual machine. This contrasts with functions compiled directly by Catalyst with qjit, which are executed at runtime as machine-native code.
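For comparison, plain JAX offers an analogous primitive, `jax.pure_callback`, which also calls back into Python from inside `jax.jit`-compiled code. The sketch below uses that JAX primitive (not Catalyst) to illustrate the idea: the compiled function only needs the callback's output shape and dtype, while the NumPy body runs in the interpreter at runtime. It assumes JAX is installed with 64-bit mode enabled. It is an analogy only, not how Catalyst implements `pure_callback`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Use float64 so the result matches double-precision NumPy
jax.config.update("jax_enable_x64", True)

def host_sin(x):
    # Executed by the Python interpreter at runtime using standard NumPy
    return np.asarray(np.sin(x), dtype=np.float64)

@jax.jit
def fn(x):
    # The compiled code only knows the promised output: a float64 scalar
    out_type = jax.ShapeDtypeStruct((), jnp.float64)
    y = jax.pure_callback(host_sin, out_type, x ** 2)
    return jnp.cos(y)

print(fn(0.654))
```

The printed value should agree with the `array(0.9151995)` shown in the Catalyst example further down, since both compute `cos(sin(x ** 2))`.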

Note

Callbacks do not currently support differentiation, and cannot be used inside functions that catalyst.grad() is applied to.

Parameters
  • callback_fn (callable) –

    The pure function to be used as a callback. Any Python-based function is supported, as long as it:

    • is a pure function (meaning it is deterministic — for the same function arguments, the same result is always returned — and has no side effects, such as modifying a non-local variable),

    • has a signature that can be inspected (that is, it is not a NumPy ufunc or Python builtin),

    • has a return type and shape that are deterministic and known ahead of time.

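  • result_type (type) – The type returned by the function. If omitted, the result type must be given as a return type hint on callback_fn. For functions that return arrays, a jax.ShapeDtypeStruct describing the expected shape and data type can be used.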
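The requirements on callback_fn can be illustrated without Catalyst. The sketch below, using only NumPy and the standard `inspect` module, wraps the ufunc `np.sin` in an ordinary Python function so that it has an inspectable signature, and contrasts it with an impure function that modifies a non-local variable. The names `sin_callback` and `counting_callback` are hypothetical examples, not part of Catalyst.

```python
import inspect
import numpy as np

# np.sin is a NumPy ufunc, which the requirements exclude as a callback.
# A thin Python wrapper gives it an ordinary, inspectable signature.
def sin_callback(x):
    return np.sin(x)

# Not suitable: this function modifies a non-local variable, so calling
# it twice with the same argument returns different results.
calls = 0

def counting_callback(x):
    global calls
    calls += 1
    return x + calls

print(inspect.signature(sin_callback))                   # (x)
print(sin_callback(0.5) == sin_callback(0.5))            # True
print(counting_callback(1.0) == counting_callback(1.0))  # False
```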

Example

pure_callback can be used as a decorator. In this case, we must specify the result type via a type hint:

import numpy as np
import jax.numpy as jnp
import catalyst
from catalyst import qjit

@catalyst.pure_callback
def callback_fn(x) -> float:
    # here we call non-JAX compatible code, such
    # as standard NumPy
    return np.sin(x)

@qjit
def fn(x):
    return jnp.cos(callback_fn(x ** 2))
>>> fn(0.654)
array(0.9151995)

It can also be used functionally:

>>> import catalyst
>>> from catalyst import qjit
>>> @qjit
... def add_one(x):
...     return catalyst.pure_callback(lambda x: x + 1, int)(x)
>>> add_one(2)
array(3)
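In both styles, pure_callback receives the function together with a result type, taken either from the result_type argument or, in the decorator form, from the return type hint. The pure-Python stand-in below, `toy_pure_callback` (a hypothetical name), mimics only that bookkeeping so that it runs without Catalyst. It performs no compilation and is an assumption about the calling convention, not Catalyst's implementation.

```python
import inspect

def toy_pure_callback(callback_fn, result_type=None):
    """Hypothetical stand-in: resolve the result type, then call through."""
    if result_type is None:
        # Decorator style: fall back to the return type hint
        result_type = inspect.signature(callback_fn).return_annotation
        if result_type is inspect.Signature.empty:
            raise TypeError("result type must be passed or given as a type hint")

    def wrapper(*args):
        result = callback_fn(*args)
        # Check that the callback honoured its declared result type
        if not isinstance(result, result_type):
            raise TypeError(
                f"expected {result_type.__name__}, got {type(result).__name__}"
            )
        return result

    wrapper.result_type = result_type
    return wrapper

# Decorator style: the result type comes from the type hint
@toy_pure_callback
def half(x) -> float:
    return x / 2

# Functional style: the result type is passed explicitly
add_one = toy_pure_callback(lambda x: x + 1, int)

print(half(3))     # 1.5
print(add_one(2))  # 3
```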

For callback functions that return arrays, a jax.ShapeDtypeStruct object can be created to specify the expected return shape and data type:

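import jax
import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit
def fn(x):
    x = jnp.cos(x)

    # the callback returns a complex array with the same shape as x
    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.complex128)

    @catalyst.pure_callback
    def callback_fn(y) -> result_shape:
        return jax.jit(jnp.fft.fft)(y)

    x = callback_fn(x)
    return x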
>>> fn(jnp.array([0.1, 0.2]))
array([1.97507074+0.j, 0.01493759+0.j])
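The same pattern can be sketched with JAX's own `jax.pure_callback` inside `jax.jit`, again as an analogy rather than Catalyst's mechanism. Here the callback body uses standard NumPy (`np.fft.fft`) instead of JAX, and a `jax.ShapeDtypeStruct` declares the promised output shape and dtype. The sketch assumes JAX is installed and enables 64-bit mode so that `complex128` is available.

```python
import jax
import jax.numpy as jnp
import numpy as np

# complex128 results require JAX's 64-bit mode
jax.config.update("jax_enable_x64", True)

def host_fft(y):
    # Runs in Python on a concrete array; ensure the promised dtype
    return np.fft.fft(y).astype(np.complex128)

@jax.jit
def fn(x):
    x = jnp.cos(x)
    # Promised output: same shape as x, with complex128 entries
    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.complex128)
    return jax.pure_callback(host_fft, result_shape, x)

print(fn(jnp.array([0.1, 0.2])))
```

For a two-element input, the FFT is simply the sum and the difference of the entries, here cos(0.1) + cos(0.2) and cos(0.1) - cos(0.2), which is why both outputs have zero imaginary part.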