catalyst.debug.callback

callback(callback_fn)

Execute a Python function with no return value and potential side effects from within a qjit-compiled function.

This makes it a convenient entry point for debugging, for example by printing or logging values at runtime.
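Because the callback is ordinary Python, its side effects are not limited to printing. For example, it can append runtime values to a list for later inspection. The sketch below is a minimal illustration under stated assumptions: the names `recorded` and `record_value` are hypothetical, and the function is called here with plain Python floats so it runs without Catalyst. Inside a qjit-compiled function it would be wrapped with `catalyst.debug.callback`, as in the Example section.

```python
# Hypothetical debug callback that records values as a side effect.
# In a Catalyst program it would be wrapped with catalyst.debug.callback;
# here it is called directly so the sketch runs with plain Python.

recorded: list[float] = []


def record_value(value) -> None:
    """Store a runtime value; returns None, as a debug callback must."""
    # float() accepts Python numbers and 0-d array-like scalars
    # (assumption: the callback receives scalar-shaped values)
    recorded.append(float(value))


for v in (0.5, 1.5, 2.5):
    record_value(v)

print(recorded)
```

Collecting values this way can be easier to inspect after a run than scanning printed output.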

The call to the callback is compiled into the workflow alongside the rest of the qjit-compiled code, but the callback itself is executed at runtime by the Python interpreter. This contrasts with functions that Catalyst compiles directly, which execute at runtime as machine-native code.

Parameters

callback_fn (callable) – The function to be used as a callback. Any Python function is supported, as long as it does not return a value (that is, it returns None).
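If an existing helper returns a value (for example, a formatting function that also prints), one option is to wrap it so the result is discarded before passing it to `catalyst.debug.callback`. This is a minimal sketch, not part of Catalyst: `discard_result` and `describe` are hypothetical names used only for illustration.

```python
import functools
from typing import Any, Callable


def discard_result(fn: Callable[..., Any]) -> Callable[..., None]:
    """Wrap fn so it keeps its side effects but always returns None."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs) -> None:
        fn(*args, **kwargs)  # run for side effects; ignore the return value
        return None

    return wrapper


def describe(y):
    message = f"Value of y = {y}"
    print(message)
    return message  # a return value like this is not supported in a callback


safe_describe = discard_result(describe)
result = safe_describe(0.5)  # prints the message, returns None
```

Inside a qjit-compiled function, the wrapped helper could then be used as `catalyst.debug.callback(discard_result(describe))(y)`, assuming only its side effects are needed.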

Example

debug.callback can be used as a decorator:

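import catalyst
import jax.numpy as jnp
from catalyst import qjit

@catalyst.debug.callback
def callback_fn(y):
    # Runs in the Python interpreter each time fn executes
    print("Value of y =", y)

@qjit
def fn(x):
    y = jnp.sin(x)
    callback_fn(y)
    return y ** 2
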
>>> fn(0.54)
Value of y = 0.5141359916531132
Array(0.26433582, dtype=float64)
>>> fn(1.52)
Value of y = 0.998710143975583
Array(0.99742195, dtype=float64)

It can also be used functionally:

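import logging

import catalyst
import jax.numpy as jnp
from catalyst import grad, qjit

# basicConfig attaches a stream handler and sets the root level to INFO,
# so the INFO records emitted by the callback are actually printed
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

@qjit
@grad
def fn(x):
    y = jnp.sin(x)
    catalyst.debug.callback(lambda val: log.info("Value of y = %s", val))(y)
    return y ** 2
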
>>> fn(0.543)
INFO:__main__:Value of y = 0.5167068002272901
Array(0.88476988, dtype=float64)

Note that during differentiation, the callback function will only be executed during the forward pass.
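The numbers in the gradient example can be checked by hand: the callback logs the forward-pass value y = sin(x), while fn returns the derivative of sin(x)**2, which is 2 sin(x) cos(x). A short sketch of that check using only the standard library (it does not use Catalyst):

```python
import math

x = 0.543

# Forward pass: the value the callback logs
y = math.sin(x)

# Gradient of y**2 = sin(x)**2 with respect to x, via the chain rule:
# d/dx sin(x)**2 = 2 * sin(x) * cos(x), which equals sin(2 * x)
grad_value = 2 * math.sin(x) * math.cos(x)

print(y)
print(grad_value)
```

This matches the note above: the logged number is the forward value sin(x), not the gradient that fn returns.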