Callbacks and GPUs
While Catalyst aims to support all classical processing functionality provided by JAX, there are cases where you may need to perform a host callback to execute arbitrary Python-compatible code. Use cases include:

- runtime debugging and logging,
- executing classical subroutines on accelerators such as GPUs or TPUs, or
- incorporating non-JAX-compatible classical subroutines within a larger QJIT workflow.

Catalyst supports all of these via a collection of callback functions.
Overview
Catalyst provides several callback functions:
- callback() supports callbacks of functions with no return values. This makes it an easy entry point for debugging, for example via printing or logging at runtime. A minimal sketch follows below.
- pure_callback() supports callbacks of pure functions: functions with no side effects that accept parameters and return values. However, the return type and shape of the function must be known in advance, and are provided as a type signature. Note that to use pure_callback() within functions that are being differentiated, a custom VJP rule must be defined so that the Catalyst compiler knows how to differentiate the callback. This can be done via the pure_callback.fwd and pure_callback.bwd methods; see the pure_callback() documentation for more details.
- accelerate() is similar to pure_callback(), but is designed to work only with functions that are jax.jit-compatible. As a result of this restriction, return types do not have to be provided upfront, and these callbacks can be executed directly on classical accelerators such as GPUs and TPUs.
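For example, here is a minimal sketch of callback() used for runtime logging (assuming the decorator is importable as catalyst.debug.callback; check the callback() documentation for the exact import path):

import jax.numpy as jnp
from catalyst import qjit
from catalyst.debug import callback  # assumed import path for callback()

# The wrapped function returns nothing and runs on the host at runtime.
@callback
def log_value(x):
    print("intermediate value:", x)

@qjit
def f(x):
    y = jnp.sin(x)
    log_value(y)  # host callback: logs y each time f executes
    return y ** 2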
In addition, print(), a convenient wrapper around the Python print function, is provided for runtime printing support.
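For instance, a one-line sketch (assuming print() is available as catalyst.debug.print):

from catalyst import qjit, debug

@qjit
def g(x):
    debug.print(x)  # prints the runtime value of x, not an abstract tracer
    return x ** 2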
Callbacks to arbitrary Python
When coming across functionality that is not yet supported by Catalyst, such as scipy.integrate.simpson, Python callbacks can be used to call arbitrary Python code within a qjit-compiled function, as long as the return shape and type are known:
import scipy as sp
import jax.numpy as jnp
from catalyst import pure_callback, qjit

@pure_callback
def simpson(x, y) -> float:
    # Executed on the host via a callback; runs arbitrary SciPy code.
    return sp.integrate.simpson(y, x=x)

@qjit
def integrate_xsq(a, b):
    x = jnp.linspace(a, b, 100)
    return simpson(x, x ** 2)
>>> integrate_xsq(-1, 1)
Array(0.66666667, dtype=float64)
>>> integrate_xsq(-1, 2)
Array(3., dtype=float64)
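For vector-valued results, the full return shape and dtype can be declared in the annotation. A hedged sketch, assuming a jax.ShapeDtypeStruct annotation is accepted (the pure_callback() docstring referenced below is authoritative):

import numpy as np
import jax
import jax.numpy as jnp
from catalyst import pure_callback, qjit

@pure_callback
def sin_cos(x) -> jax.ShapeDtypeStruct((2,), jnp.float64):
    # The annotation declares a length-2 float64 result ahead of time.
    return np.array([np.sin(x), np.cos(x)])

@qjit
def h(x):
    return jnp.sum(sin_cos(x))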
Please see the pure_callback() docstring for more details, including how to define vector-Jacobian product (VJP) rules for autodifferentiation and how to specify the return type of vector-valued functions.
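As a rough illustration of such a VJP rule, here is a sketch assuming JAX custom_vjp-style conventions (the forward method returns the result together with residuals; the backward method maps residuals and a cotangent to a tuple of input cotangents). The exact signatures expected by Catalyst may differ; the pure_callback() docstring is authoritative.

import numpy as np
import jax.numpy as jnp
from catalyst import grad, pure_callback, qjit

@pure_callback
def host_sin(x) -> float:
    return np.sin(x)

@host_sin.fwd
def host_sin_fwd(x):
    # Return the primal result along with residuals for the backward pass.
    return host_sin(x), x

@host_sin.bwd
def host_sin_bwd(res, cotangent):
    # VJP of sin: cos(x) * cotangent, where x was saved as the residual.
    return (jnp.cos(res) * cotangent,)

@qjit
@grad
def dhost_sin(x):
    return host_sin(x)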
Callbacks to JIT-compatible code
If a function is JIT-compatible, then accelerate() can be used, removing the need to manually provide return shape and dtype information:
import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit
def fn(x):
    x = jnp.sin(x)
    y = catalyst.accelerate(jnp.fft.fft)(x)
    return jnp.sum(y)
>>> x = np.array([1.0, 2.0, 1.0, -1.0, 1.5])
>>> fn(x)
Array(4.20735492+0.j, dtype=complex128)
Accelerated functions also fully support autodifferentiation with grad(), jacobian(), and other Catalyst differentiation functions, without needing to specify VJP rules manually:
import jax
import jax.numpy as jnp
import catalyst
from catalyst import qjit, grad

@qjit
@grad
def f(x):
    expm = catalyst.accelerate(jax.scipy.linalg.expm)
    return jnp.sum(expm(jnp.sin(x)) ** 2)
>>> x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
>>> f(x)
Array([[2.80120452, 1.67518663],
[1.61605839, 4.42856163]], dtype=float64)
Accelerator (GPU and TPU) support
accelerate()
can also be used to execute classical subroutines on
classical accelerators such as GPUs and TPUs:
import jax
import jax.numpy as jnp
from catalyst import accelerate, qjit

@accelerate(dev=jax.devices("gpu")[0])
def classical_fn(x):
    return jnp.sin(x) ** 2

@qjit
def hybrid_fn(x):
    y = classical_fn(jnp.sqrt(x))  # will be executed on a GPU
    return jnp.cos(y)
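Note that jax.devices("gpu") raises a RuntimeError when no GPU backend is available. A small, hedged guard for portable examples; the resulting device can then be passed to accelerate(dev=dev) as above:

import jax

try:
    dev = jax.devices("gpu")[0]
except RuntimeError:
    # No GPU backend registered; fall back to the CPU device.
    dev = jax.devices("cpu")[0]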