catalyst.debug.get_compilation_stage

get_compilation_stage(fn, stage)

Returns the intermediate representation of one of the recorded compilation stages for a JIT-compiled function.

The stages are indexed by their name in the Catalyst compilation pipeline. Stage names are either provided by the user as a compilation option or predefined in catalyst.compiler.

The predefined stages are:

  • MLIR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.

  • LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.

Note that the CoroOpt (coroutine lowering), O2Opt (O2 optimization), and Enzyme (automatic differentiation) passes do not always run. The last stage denotes the stage right before object file generation.
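As a quick reference, the predefined stage names listed above can be grouped by the IR they produce. The sketch below is purely illustrative: the names MLIR_STAGES, LLVM_STAGES, OPTIONAL_STAGES, stage_kind, and is_optional are hypothetical and are not part of Catalyst. It only mirrors this page, so user-provided pipeline stage names are deliberately not covered.

```python
# Hypothetical lookup tables mirroring the predefined stage names on this page.
# They are NOT part of Catalyst; user-defined pipeline stages are not included.
MLIR_STAGES = ("mlir", "HLOLoweringPass", "QuantumCompilationPass",
               "BufferizationPass", "MLIRToLLVMDialect")
LLVM_STAGES = ("llvm_ir", "CoroOpt", "O2Opt", "Enzyme", "last")

# Passes documented as not always running.
OPTIONAL_STAGES = frozenset({"CoroOpt", "O2Opt", "Enzyme"})


def stage_kind(stage: str) -> str:
    """Return "MLIR" or "LLVM" for a predefined stage name."""
    if stage in MLIR_STAGES:
        return "MLIR"
    if stage in LLVM_STAGES:
        return "LLVM"
    raise ValueError(f"unknown predefined compilation stage: {stage!r}")


def is_optional(stage: str) -> bool:
    """True if the stage belongs to a pass that does not always run."""
    stage_kind(stage)  # raise early for names not listed on this page
    return stage in OPTIONAL_STAGES


print(stage_kind("HLOLoweringPass"))  # MLIR
print(is_optional("Enzyme"))          # True
```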

Note

To use this function, the input function must be decorated with qjit(keep_intermediate=True).
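When inspecting several stages at once, it can help to collect them into a dictionary. The following is a minimal sketch that assumes only the get_compilation_stage(fn, stage) signature documented on this page. The helper collect_stages and its getter parameter are hypothetical; getter exists only so the helper can be exercised without Catalyst installed. Because some passes do not always run, the sketch skips any stage whose lookup raises. The exact exception type is not documented here, so catching Exception broadly is an assumption, and it will also hide other errors (for example, a missing keep_intermediate=True).

```python
from typing import Callable, Dict, Iterable, Optional


def collect_stages(
    fn,
    stages: Iterable[str],
    getter: Optional[Callable[[object, str], str]] = None,
) -> Dict[str, str]:
    """Return {stage_name: ir_text} for every requested stage that is available.

    `fn` is assumed to be a qjit-decorated function with keep_intermediate=True.
    `getter` is a hypothetical hook; by default it uses Catalyst's own function.
    """
    if getter is None:
        # Imported lazily so the helper itself does not require Catalyst.
        from catalyst import debug
        getter = debug.get_compilation_stage

    results: Dict[str, str] = {}
    for stage in stages:
        try:
            results[stage] = getter(fn, stage)
        except Exception:  # assumption: unavailable stages raise some exception
            continue
    return results
```

With Catalyst installed, a call such as collect_stages(func, ["mlir", "HLOLoweringPass", "last"]), using func from the example on this page, would return whichever of those stages could be retrieved, keyed by stage name.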

Parameters
  • fn (QJIT) – a qjit-decorated function

  • stage (str) – name of the compilation stage whose IR is returned

Returns

the IR output of the requested compilation stage

Return type

str

Example

from catalyst import debug, qjit

@qjit(keep_intermediate=True)
def func(x: float):
    return x

>>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
module @func {
  func.func public @jit_func(%arg0: tensor<f64>)
  -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
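Because the return value is a plain str, ordinary Python string tools can be used to inspect it. The sketch below lists the func.func symbols in the HLOLoweringPass output shown above. The regular expression and the function_names helper are assumptions made for illustration: they handle only the simple `func.func [public|private] @name` form seen here and are not a substitute for a real MLIR parser.

```python
import re
from typing import List

# Assumed pattern: matches `func.func`, an optional visibility keyword,
# then captures the symbol name after `@`. Not a full MLIR grammar.
_FUNC_RE = re.compile(r"func\.func\s+(?:(?:public|private)\s+)?@([\w.$-]+)")


def function_names(ir: str) -> List[str]:
    """Return the symbol names of func.func definitions found in an MLIR string."""
    return _FUNC_RE.findall(ir)


# The HLOLoweringPass output from the example above, copied verbatim.
sample = """module @func {
  func.func public @jit_func(%arg0: tensor<f64>)
  -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}"""

print(function_names(sample))  # ['jit_func', 'setup', 'teardown']
```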