catalyst.pure_callback

pure_callback(callback_fn, result_type=None)

Execute and return the results of a functionally pure Python function from within a qjit-compiled function.

The callback function will be quantum just-in-time compiled alongside the rest of the workflow; however, it will be executed at runtime by the Python virtual machine. This is in contrast to functions that are directly qjit-compiled by Catalyst, which are executed at runtime as machine-native code.

Note

Callbacks do not automatically support differentiation. To use them within functions that are being differentiated, please define their vector-Jacobian product (see below for more details).

Parameters
  • callback_fn (callable) –

    The pure function to be used as a callback. Any Python-based function is supported, as long as it:

    • is a pure function (meaning it is deterministic — for the same function arguments, the same result is always returned — and has no side effects, such as modifying a non-local variable),

    • has a signature that can be inspected (that is, it is not a NumPy ufunc or Python builtin),

    • has a return type and shape that are deterministic and known ahead of time.

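  • result_type (type) – The type returned by the function. When pure_callback is used as a decorator, the result type is given by the return type hint of callback_fn instead. For array results, a jax.ShapeDtypeStruct can describe the expected shape and data type.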
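The requirements on callback_fn can be illustrated with the standard library alone. The following is a minimal sketch, not Catalyst's implementation: wrapped_sin is a hypothetical helper, and the use of inspect.signature here is an assumption about the kind of inspection the requirements imply. It shows that a plain Python wrapper exposes both its parameters and its return type hint, and that a pure function gives the same result for the same argument.

```python
import inspect
import math


def wrapped_sin(x) -> float:
    # Hypothetical wrapper: a plain Python function whose signature
    # and return annotation can be inspected.
    return math.sin(x)


sig = inspect.signature(wrapped_sin)
param_names = list(sig.parameters)   # names of the positional parameters
return_hint = sig.return_annotation  # the declared result type

# Spot-check of purity: the same argument gives the same result.
# This is a necessary condition only, not a proof of purity.
is_repeatable = wrapped_sin(0.5) == wrapped_sin(0.5)
```

Wrapping a builtin or a NumPy ufunc in a small function like this is one way to give it an inspectable signature and a place to declare the result type.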

Example

pure_callback can be used as a decorator. In this case, we must specify the result type via a type hint:

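import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

import catalyst
from catalyst import qjit

@catalyst.pure_callback
def callback_fn(x) -> float:
    # here we call non-JAX compatible code, such
    # as standard NumPy
    return np.sin(x)

@qjit
def fn(x):
    return jnp.cos(callback_fn(x ** 2))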
>>> fn(0.654)
Array(0.9151995, dtype=float64)

It can also be used functionally:

>>> @qjit
... def add_one(x):
...     return catalyst.pure_callback(lambda x: x + 1, int)(x)
>>> add_one(2)
Array(3, dtype=int64)

For callback functions that return arrays, a jax.ShapeDtypeStruct object can be created to specify the expected return shape and data type:

@qjit
def fn(x):
    x = jnp.cos(x)

    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.complex128)

    @catalyst.pure_callback
    def callback_fn(y) -> result_shape:
        return jax.jit(jnp.fft.fft)(y)

    x = callback_fn(x)
    return x
>>> fn(jnp.array([0.1, 0.2]))
Array([1.97507074+0.j, 0.01493759+0.j], dtype=complex128)

Pure callbacks must have custom gradients manually registered with the Catalyst compiler in order to support differentiation.

This can be done via the pure_callback.fwd and pure_callback.bwd methods, which specify how the forward pass and the backward pass (the vector-Jacobian product) of the callback should be computed:

@catalyst.pure_callback
def callback_fn(x) -> float:
    return np.sin(x[0]) * x[1]

@callback_fn.fwd
def callback_fn_fwd(x):
    # returns the evaluated function as well as residual
    # values that may be useful for the backwards pass
    return callback_fn(x), x

@callback_fn.bwd
def callback_fn_vjp(res, dy):
    # Accepts residuals from the forward pass, as well
    # as (one or more) cotangent vectors dy, and returns
    # a tuple of VJPs corresponding to each input parameter.

    def vjp(x, dy) -> (jax.ShapeDtypeStruct((2,), jnp.float64),):
        return (np.array([np.cos(x[0]) * dy * x[1], np.sin(x[0]) * dy]),)

    # The VJP function can also be a pure callback
    return catalyst.pure_callback(vjp)(res, dy)
>>> @qml.qjit
... @catalyst.grad
... def f(x):
...     y = jnp.array([jnp.cos(x[0]), x[1]])
...     return jnp.sin(callback_fn(y))
>>> f(jnp.array([0.1, 0.2]))
Array([-0.01071923,  0.82698717], dtype=float64)
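A hand-written VJP is easy to get wrong, so it can help to check it outside of Catalyst first. The sketch below uses only the standard library and is not part of the Catalyst API: callback, callback_vjp, finite_difference and grad_f are hypothetical names that mirror the example above. It compares the VJP against central finite differences, then applies the chain rule by hand to f.

```python
import math


def callback(y):
    # Same function as callback_fn above, written with the standard library.
    return math.sin(y[0]) * y[1]


def callback_vjp(y, dy):
    # Hand-written VJP: the cotangent dy times each partial derivative.
    return [math.cos(y[0]) * y[1] * dy, math.sin(y[0]) * dy]


def finite_difference(fn, y, eps=1e-6):
    # Central differences, perturbing one input at a time.
    grads = []
    for i in range(len(y)):
        plus = list(y)
        minus = list(y)
        plus[i] += eps
        minus[i] -= eps
        grads.append((fn(plus) - fn(minus)) / (2 * eps))
    return grads


def grad_f(x):
    # Chain rule for f(x) = sin(callback([cos(x[0]), x[1]])).
    y = [math.cos(x[0]), x[1]]
    dy = math.cos(callback(y))  # cotangent coming from the outer sin
    vjp = callback_vjp(y, dy)
    # Propagate through y[0] = cos(x[0]); y[1] = x[1] passes through unchanged.
    return [vjp[0] * -math.sin(x[0]), vjp[1]]


y = [0.3, 0.7]
analytic = callback_vjp(y, 1.0)
numeric = finite_difference(callback, y)
gradient = grad_f([0.1, 0.2])
```

Here analytic and numeric should agree closely, and gradient should match the values printed by the qjit-compiled example to roughly six decimal places.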