catalyst.debug.get_compilation_stage
- get_compilation_stage(fn, stage)
Returns the intermediate representation of one of the recorded compilation stages for a JIT-compiled function.
The stages are indexed by their Catalyst compilation pipeline names, which are either provided by the user as a compilation option or predefined in catalyst.compiler. The predefined stages are:
- MLIR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.
- LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
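As an illustration, the predefined stage names listed above can be captured in a small lookup table. This is a hypothetical sketch: PREDEFINED_STAGES and stage_ir_kind are not part of Catalyst, and the table only mirrors the documented list, so a user-defined pipeline name is reported as unknown (None) rather than rejected.

```python
from typing import Dict, Optional, Tuple

# Hypothetical lookup table mirroring the predefined stage names documented
# above, grouped by the kind of IR each stage produces. Not a Catalyst API.
PREDEFINED_STAGES: Dict[str, Tuple[str, ...]] = {
    "MLIR": (
        "mlir",
        "HLOLoweringPass",
        "QuantumCompilationPass",
        "BufferizationPass",
        "MLIRToLLVMDialect",
    ),
    "LLVM": ("llvm_ir", "CoroOpt", "O2Opt", "Enzyme", "last"),
}


def stage_ir_kind(stage: str) -> Optional[str]:
    """Return "MLIR" or "LLVM" for a predefined stage name.

    Any other name yields None; it may still be valid if it is a
    user-provided pipeline name.
    """
    for kind, names in PREDEFINED_STAGES.items():
        if stage in names:
            return kind
    return None


print(stage_ir_kind("HLOLoweringPass"))  # MLIR
print(stage_ir_kind("last"))             # LLVM
```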
Note that the CoroOpt (coroutine lowering), O2Opt (O2 optimization), and Enzyme (automatic differentiation) passes do not always run. The stage named last is the one right before object file generation.
Note
In order to use this function, keep_intermediate=True must be set in the qjit() decorator of the input function.
- Parameters:
fn (QJIT) – a qjit-decorated function
stage (str) – string corresponding to the name of the stage whose IR is returned
- Returns:
output IR from the target compiler stage
- Return type:
str
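When dumping several stages at once, the optional CoroOpt, O2Opt, and Enzyme passes may have no recorded output. The following is a minimal sketch under stated assumptions: collect_stages and OPTIONAL_STAGES are hypothetical helpers, and the code assumes that requesting a stage that was not recorded raises an exception (the exact exception type is not specified here, so the default catches Exception). The stage getter is passed in as a callable so the logic can be exercised without Catalyst installed.

```python
from typing import Callable, Dict, Iterable, Tuple, Type

# Passes that, per the documentation, do not always run.
OPTIONAL_STAGES = frozenset({"CoroOpt", "O2Opt", "Enzyme"})


def collect_stages(
    fn,
    stages: Iterable[str],
    get_stage: Callable[[object, str], str],
    missing: Tuple[Type[BaseException], ...] = (Exception,),
) -> Dict[str, str]:
    """Collect the IR text for each requested stage.

    get_stage is expected to behave like get_compilation_stage(fn, stage).
    ASSUMPTION: an error for an optional pass means that pass did not run,
    so it is skipped; an error for any other stage is re-raised.
    """
    outputs: Dict[str, str] = {}
    for stage in stages:
        try:
            outputs[stage] = get_stage(fn, stage)
        except missing:
            if stage not in OPTIONAL_STAGES:
                raise
    return outputs
```

With Catalyst installed, one could pass debug.get_compilation_stage as get_stage for a function decorated with @qjit(keep_intermediate=True), for example collect_stages(func, ["mlir", "Enzyme", "last"], debug.get_compilation_stage). Narrowing missing to the exception type your Catalyst version actually raises avoids hiding unrelated errors.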
Example
from catalyst import debug, qjit

@qjit(keep_intermediate=True)
def func(x: float):
    return x
>>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
module @func {
  func.func public @jit_func(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
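Because the return value is a plain str, ordinary string tools are enough for quick inspection. As a rough sketch, the hypothetical helper mlir_function_names below uses a regular expression (a heuristic, not a real MLIR parser) to list the func.func symbols in the module printed in the example.

```python
import re
from typing import List

# Heuristic pattern for MLIR function operations such as
# "func.func public @jit_func(" or "func.func @setup(".
_FUNC_DEF = re.compile(
    r"func\.func\s+(?:(?:public|private|nested)\s+)?@([\w.$-]+)\s*\("
)


def mlir_function_names(ir: str) -> List[str]:
    """Return the symbol names of all func.func operations in MLIR text."""
    return _FUNC_DEF.findall(ir)


# The HLOLoweringPass output shown in the example above.
example_ir = """
module @func {
  func.func public @jit_func(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
"""

print(mlir_function_names(example_ir))  # ['jit_func', 'setup', 'teardown']
```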