catalyst.accelerate

accelerate(func=None, *, dev=None)
Execute a `jax.jit` accelerated function on classical accelerators, such as GPUs, from within a qjit-compiled function.

Parameters

- func (Callable or PjitFunction) – The function to be classically accelerated from within the qjit-compiled workflow. This function may already be just-in-time compiled with JAX via the `jax.jit` decorator and a specified device. If not, it will be implicitly JIT-compiled, and so must be JIT compatible.
- dev (jax.Device) – The classical accelerator device the JIT-compiled function will run on. Available devices can be retrieved via `jax.devices()`. If not provided, the default value `jax.devices()[0]`, as determined by JAX, will be used.
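Because `jax.devices("gpu")` fails on machines without a GPU backend, it can help to pick the device defensively before passing it as `dev=` to `accelerate`. The sketch below uses only JAX; the helper name `pick_device` is hypothetical, and it assumes (as current JAX releases do) that requesting an unavailable backend raises `RuntimeError`. On failure it falls back to `jax.devices()[0]`, the same default described above.

```python
import jax


def pick_device(preferred="gpu"):
    """Hypothetical helper: first device of `preferred`, else JAX's default device."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        # Assumption: JAX raises RuntimeError when the requested backend is
        # unavailable; fall back to jax.devices()[0], the default for `dev`.
        return jax.devices()[0]


print(jax.devices())  # all devices on JAX's default backend
dev = pick_device()
print(dev.platform)   # e.g. "gpu" on a CUDA machine, "cpu" otherwise
```

The returned `dev` could then be used as `accelerate(dev=dev)`.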
Example
```python
import jax
import jax.numpy as jnp
from catalyst import accelerate, qjit

@accelerate(dev=jax.devices("gpu")[0])
def classical_fn(x):
    return jnp.sin(x) ** 2

@qjit
def hybrid_fn(x):
    y = classical_fn(jnp.sqrt(x))  # will be executed on a GPU
    return jnp.cos(y)
```
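Since `hybrid_fn` in this example contains no quantum operations, its numerical result can be cross-checked against a plain JAX function, for instance on a machine without a GPU. The sketch below is such a reference; the name `reference_hybrid_fn` is hypothetical, and it uses neither `qjit` nor `accelerate`, so it says nothing about where Catalyst places the computation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def reference_hybrid_fn(x):
    # Plain-JAX version of hybrid_fn: cos(sin(sqrt(x)) ** 2)
    y = jnp.sin(jnp.sqrt(x)) ** 2
    return jnp.cos(y)


print(reference_hybrid_fn(0.0))               # sqrt(0) = 0, so the result is cos(0) = 1
print(reference_hybrid_fn(jnp.pi ** 2 / 4))   # sin(pi/2) ** 2 = 1, so the result is cos(1)
```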
In addition, you can accelerate functions that have already been `jax.jit` decorated:

```python
import jax
import jax.numpy as jnp
from catalyst import accelerate, qjit

@jax.jit
def classical_fn(x):
    x = jax.device_put(x, jax.local_devices(backend="gpu")[0])
    return jnp.sin(x) ** 2

@qjit
def hybrid_fn(x):
    y = accelerate(classical_fn)(x)  # will be executed on a GPU
    return jnp.cos(y)
```
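Outside of Catalyst, a common way to control where a `jax.jit` function runs is to commit its input to a device before the call: JAX then executes the compiled computation on the device holding the committed input. The sketch below illustrates this plain-JAX placement rule, not Catalyst's internal mechanism, and assumes the CPU device so that it runs on any machine; substitute a GPU device such as `jax.devices("gpu")[0]` where one is available.

```python
import jax
import jax.numpy as jnp

# Assumption: use the CPU so the sketch runs everywhere
cpu = jax.devices("cpu")[0]


@jax.jit
def classical_fn(x):
    return jnp.sin(x) ** 2


# Commit the input to the chosen device; the jitted computation follows it
x = jax.device_put(jnp.array([0.0, 0.5, 1.0]), cpu)
y = classical_fn(x)
print(y.devices())  # the set of devices holding the result
```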
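Accelerated functions also fully support automatic differentiation with `grad()`, `jacobian()`, and other Catalyst differentiation functions:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg
from catalyst import accelerate, grad, qjit

@qjit
@grad
def f(x):
    expm = accelerate(jax.scipy.linalg.expm)
    return jnp.sum(expm(jnp.sin(x)) ** 2)
```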
```pycon
>>> x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
>>> f(x)
Array([[2.80120452, 1.67518663],
       [1.61605839, 4.42856163]], dtype=float64)
```
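To sanity-check gradients of accelerated functions, the same classical loss can be differentiated in plain JAX and compared against a central finite difference. The sketch below does that for one gradient entry; the name `loss` is hypothetical, 64-bit mode is enabled to match the `float64` output shown above, and no specific gradient values are asserted.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Match the float64 dtype of the result shown above
jax.config.update("jax_enable_x64", True)


def loss(x):
    # Same classical computation as `f`, without qjit or accelerate
    return jnp.sum(jax.scipy.linalg.expm(jnp.sin(x)) ** 2)


x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
g = jax.grad(loss)(x)

# Central finite-difference estimate of d(loss)/dx[0, 0]
eps = 1e-6
e00 = jnp.zeros_like(x).at[0, 0].set(eps)
fd_00 = (loss(x + e00) - loss(x - e00)) / (2 * eps)

print(g)
print(fd_00, g[0, 0])
```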