catalyst.debug.print_memref

print_memref(x)

A qjit()-compatible print function for printing numeric values at runtime together with their memref information.

Prints numeric values at runtime along with the value’s metadata.

Tensors in the Catalyst runtime are represented as memref descriptor structs. For more information about memref descriptors, see the MLIR documentation. This function prints the base memory address of the data buffer, the offset, the rank of the array, the size of each dimension, and the strides between elements, followed by the data itself.
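As a rough illustration of what these fields mean, the following pure-Python sketch models a strided memref descriptor. The class StridedMemref and the helper row_major_strides are hypothetical names introduced here, not part of Catalyst. The sketch assumes MLIR's strided layout, in which offset and strides are counted in elements, so element (i_0, ..., i_{r-1}) sits at buffer position offset + sum(i_k * stride_k). It also assumes 8-byte (float64) elements when converting a position into a byte address.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StridedMemref:
    """Hypothetical, simplified model of a strided memref descriptor.

    Only the fields that print_memref reports are modelled here; the real
    descriptor is a struct managed by the Catalyst runtime.
    """
    base: int           # base address of the data buffer, in bytes
    offset: int         # position of the first element, counted in elements
    sizes: List[int]    # extent of each dimension
    strides: List[int]  # step between neighbours along each dimension, in elements

    @property
    def rank(self) -> int:
        # The rank is simply the number of dimensions.
        return len(self.sizes)

    def linear_index(self, index: Tuple[int, ...]) -> int:
        """Buffer position of an element: offset + sum(i_k * stride_k)."""
        if len(index) != self.rank:
            raise ValueError("index length must equal the rank")
        for i, n in zip(index, self.sizes):
            if not 0 <= i < n:
                raise IndexError("index out of bounds")
        return self.offset + sum(i * s for i, s in zip(index, self.strides))

    def address(self, index: Tuple[int, ...], itemsize: int = 8) -> int:
        """Byte address of an element, assuming itemsize bytes per element."""
        return self.base + self.linear_index(index) * itemsize


def row_major_strides(sizes: List[int]) -> List[int]:
    """Strides of a contiguous row-major (C-order) array, in elements."""
    strides = []
    acc = 1
    for n in reversed(sizes):
        strides.append(acc)
        acc *= n
    return list(reversed(strides))
```

For a 2x3 row-major array, row_major_strides([2, 3]) gives [3, 1], so element (1, 2) sits at position 5. A transposed view of the same buffer would keep base and offset but use sizes [3, 2] and strides [1, 3]. A rank-0 array, like the one in the example, has empty sizes and strides and a single element at position offset.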

Parameters

x (jax.Array, Any) – A single JAX array whose numeric values are printed at runtime.

See also

print()

Example

from catalyst import debug, qjit
import jax.numpy as jnp

@qjit
def func(x: float):
    debug.print_memref(x)

>>> func(jnp.array(0.43))
Unranked Memref base@ = 0x5629ff2b6680 rank = 0 offset = 0 sizes = [] strides = [] data =
[0.43]
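If this output is captured (for example in a test log), the header line can be split back into its fields. The sketch below is a hypothetical helper, parse_memref_header, that is not part of Catalyst. It assumes the header keeps exactly the layout shown in the example output above, which is not a documented, stable interface and may change between versions. The "Unranked " prefix is treated as optional.

```python
import re
from typing import Dict, List, Union

# Assumed header layout, taken from the example output; not a stable API.
_HEADER = re.compile(
    r"(?:Unranked )?Memref base@ = (?P<base>0x[0-9a-fA-F]+) "
    r"rank = (?P<rank>\d+) offset = (?P<offset>\d+) "
    r"sizes = \[(?P<sizes>[^\]]*)\] strides = \[(?P<strides>[^\]]*)\] data ="
)


def _ints(text: str) -> List[int]:
    """Turn a comma-separated list such as '2, 3' into [2, 3]."""
    return [int(tok) for tok in text.split(",") if tok.strip()]


def parse_memref_header(line: str) -> Dict[str, Union[int, List[int]]]:
    """Hypothetical helper: split a print_memref header line into its fields."""
    m = _HEADER.search(line)
    if m is None:
        raise ValueError(f"not a memref header: {line!r}")
    fields = {
        "base": int(m["base"], 16),  # hexadecimal address of the data buffer
        "rank": int(m["rank"]),
        "offset": int(m["offset"]),
        "sizes": _ints(m["sizes"]),
        "strides": _ints(m["strides"]),
    }
    # Sanity check: one size per dimension.
    if len(fields["sizes"]) != fields["rank"]:
        raise ValueError("number of sizes does not match the rank")
    return fields
```

Applied to the header printed in the example, this yields rank 0 with empty sizes and strides, matching the single scalar 0.43 in the data section.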

Outside a qjit()-compiled function, the operation falls back to Python’s built-in print() function.