print

print(x, memref=False)

A qjit()-compatible print function for printing values at runtime.

Numeric array values are printed at runtime. Other objects, including strings, are converted to their string representation and printed at runtime as constants.

Parameters
  • x (jax.Array, Any) – A single jax array whose numeric values are printed at runtime, or any object whose string representation will be treated as a constant and printed at runtime.

  • memref (Optional[bool]) – When set to True, also prints information about how the array is stored in memory, via the so-called “memref” descriptor. This includes the base memory address of the data buffer, the offset into that buffer, the rank of the array, the size of each dimension, and the strides between elements.

Example

import jax.numpy as jnp
from catalyst import debug, qjit

@qjit
def func(x: float):
    debug.print(x, memref=True)
    debug.print("exit")
>>> func(jnp.array(0.43))
Unranked Memref base@ = 0x5629ff2b6680 rank = 0 offset = 0 sizes = [] strides = [] data =
[0.43]
exit

Outside a qjit()-compiled function, the operation falls back to Python's built-in print function.

Note

Python f-strings will not work as expected, because they are evaluated in Python while the function is being traced, and the resulting string is treated as a constant. Any array values embedded in them are therefore printed in their compile-time (traced) representation instead of their actual runtime data.