catalyst.accelerate

accelerate(func=None, *, dev=None)

Execute a jax.jit-accelerated function on classical accelerators, such as GPUs, from within a qjit-compiled function.

Parameters
  • func (Callable or PjitFunction) – The function to be classically accelerated from within the qjit-compiled workflow. The function may already be just-in-time compiled with JAX via the jax.jit decorator and targeted at a specific device. If it is not, it will be JIT-compiled implicitly, so it must be JIT compatible.

  • dev (jax.Device) – The classical accelerator device the JIT-compiled function will run on. Available devices can be retrieved via jax.devices(). If not provided, JAX's default device, jax.devices()[0], is used.
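The func parameter notes that a function not already wrapped in jax.jit is JIT-compiled implicitly and must therefore be JIT compatible. The plain-JAX sketch below (it does not call Catalyst) illustrates one common way a function fails that requirement: a Python `if` on an array value. The helper names branch_in_python and branch_in_jax are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def branch_in_python(x):
    # Hypothetical example: a Python `if` needs a concrete boolean,
    # but under jax.jit `x` is an abstract tracer, so this fails.
    if x > 0:
        return jnp.sin(x)
    return jnp.cos(x)


def branch_in_jax(x):
    # JIT-compatible version: jnp.where evaluates both branches
    # and selects between them element-wise at run time.
    return jnp.where(x > 0, jnp.sin(x), jnp.cos(x))


try:
    jax.jit(branch_in_python)(1.0)
    python_branch_failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError, raised by recent JAX versions,
    # is a subclass of ConcretizationTypeError.
    python_branch_failed = True

result = jax.jit(branch_in_jax)(1.0)
print(python_branch_failed, result)
```

Rewriting data-dependent branches with jnp.where (or jax.lax.cond) is the usual way to make such a function safe to pass to accelerate without a prior jax.jit.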

Example

import jax
import jax.numpy as jnp
from catalyst import accelerate, qjit

@accelerate(dev=jax.devices("gpu")[0])
def classical_fn(x):
    return jnp.sin(x) ** 2

@qjit
def hybrid_fn(x):
    y = classical_fn(jnp.sqrt(x))  # will be executed on a GPU
    return jnp.cos(y)
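The signature accelerate(func=None, *, dev=None) lets the decorator be used both bare and with a dev keyword. The sketch below mimics only that calling convention in plain JAX: a hypothetical accelerate_like decorator that JIT-compiles the function and places its inputs on dev, defaulting to jax.devices()[0]. It is not Catalyst's implementation and does not integrate with qjit; it runs on a CPU-only machine because it uses the default device.

```python
import functools

import jax
import jax.numpy as jnp


def accelerate_like(func=None, *, dev=None):
    """Hypothetical stand-in, NOT catalyst.accelerate: jit-compile `func`
    and run it on `dev` (JAX's default device if `dev` is None)."""
    if func is None:
        # Called as @accelerate_like(dev=...): return a decorator.
        return functools.partial(accelerate_like, dev=dev)

    device = dev if dev is not None else jax.devices()[0]
    jitted = jax.jit(func)

    @functools.wraps(func)
    def wrapper(*args):
        # Move all positional array arguments to the target device.
        args = jax.device_put(args, device)
        return jitted(*args)

    return wrapper


@accelerate_like(dev=jax.devices()[0])
def classical_fn(x):
    return jnp.sin(x) ** 2


@accelerate_like  # bare form: uses the default device
def double(x):
    return 2 * x


y = classical_fn(jnp.sqrt(2.0))
z = double(jnp.array([1.0, 2.0]))
print(y, z)
```

The `func is None` check is the standard Python idiom that makes one callable work both as `@decorator` and as `@decorator(...)`.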

Functions that have already been decorated with jax.jit can also be accelerated:

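import jax
import jax.numpy as jnp
from catalyst import accelerate, qjit

@jax.jit
def classical_fn(x):
    # Move the input onto the first local GPU before computing
    x = jax.device_put(x, jax.local_devices(backend="gpu")[0])
    return jnp.sin(x) ** 2

@qjit
def hybrid_fn(x):
    y = accelerate(classical_fn)(x)  # will be executed on a GPU
    return jnp.cos(y)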
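The pre-jitted example assumes a GPU is present. To try the jax.jit half of it on a machine without one, the plain-JAX sketch below falls back to JAX's default device when no GPU backend is available. The fallback is an assumption added for portability; it is not something the example above or Catalyst does on its own.

```python
import jax
import jax.numpy as jnp

# Assumption for portability: use the first GPU if JAX can see one,
# otherwise fall back to JAX's default device.
try:
    target = jax.devices("gpu")[0]
except RuntimeError:
    target = jax.devices()[0]


@jax.jit
def classical_fn(x):
    # Place the traced input on `target` before the computation.
    x = jax.device_put(x, target)
    return jnp.sin(x) ** 2


x = jnp.array([0.0, 0.5, 1.0])
y = classical_fn(x)
print(target, y)
```

On a GPU machine this behaves like the documented classical_fn; on a CPU-only machine the same code runs on the CPU.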

Accelerated functions also fully support autodifferentiation with grad(), jacobian(), and other Catalyst differentiation functions:

import jax
import jax.numpy as jnp
import jax.scipy.linalg
from catalyst import accelerate, grad, qjit

@qjit
@grad
def f(x):
    expm = accelerate(jax.scipy.linalg.expm)
    return jnp.sum(expm(jnp.sin(x)) ** 2)

>>> x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
>>> f(x)
Array([[2.80120452, 1.67518663],
       [1.61605839, 4.42856163]], dtype=float64)
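Because the accelerated function is ordinary JAX code, its derivative can be cross-checked independently. The sketch below evaluates the same cost function in plain JAX (no qjit or accelerate), differentiates it with jax.grad, and compares the result against central finite differences. The helper finite_diff_grad is hypothetical, written only for this check; the sketch makes no claim about the exact numbers Catalyst prints.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# float64, matching the dtype printed above and keeping finite differences accurate
jax.config.update("jax_enable_x64", True)


def f(x):
    # Same cost function as above, in plain JAX
    return jnp.sum(jax.scipy.linalg.expm(jnp.sin(x)) ** 2)


def finite_diff_grad(fn, x, eps=1e-6):
    # Hypothetical helper: central differences, one matrix entry at a time
    g = np.zeros(x.shape)
    for idx in np.ndindex(x.shape):
        e = np.zeros(x.shape)
        e[idx] = eps
        g[idx] = float(fn(x + e) - fn(x - e)) / (2 * eps)
    return g


x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
g_ad = jax.grad(f)(x)  # automatic differentiation through expm
g_fd = finite_diff_grad(f, x)
print(g_ad)
```

Agreement between the two gradients confirms that jax.scipy.linalg.expm is differentiable end to end, which is what accelerate relies on when its output feeds into grad().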