catalyst.debug.get_compilation_stage
- get_compilation_stage(fn, stage)
Returns the intermediate representation of one of the recorded compilation stages for a JIT-compiled function.
The stages are indexed by their Catalyst compilation pipeline names, which are either provided by the user as a compilation option or predefined in catalyst.compiler.
All the available stages are:
MLIR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.
LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
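As a rough illustration of the stage list above, the predefined names can be kept in a small lookup table so that a name is checked before it is passed to get_compilation_stage. This is only a sketch: PREDEFINED_STAGES and stage_family are hypothetical helpers, not part of Catalyst, and user-provided pipeline names are not covered by the table.

```python
# Predefined stage names copied from the list above, grouped by IR family.
# PREDEFINED_STAGES and stage_family are illustrative helpers, not Catalyst APIs.
PREDEFINED_STAGES = {
    "MLIR": (
        "mlir",
        "HLOLoweringPass",
        "QuantumCompilationPass",
        "BufferizationPass",
        "MLIRToLLVMDialect",
    ),
    "LLVM": ("llvm_ir", "CoroOpt", "O2Opt", "Enzyme", "last"),
}


def stage_family(stage):
    """Return "MLIR" or "LLVM" for a predefined stage name, else None."""
    for family, names in PREDEFINED_STAGES.items():
        if stage in names:
            return family
    # Not a predefined name; it may still be a user-defined pipeline name.
    return None
```

A None result does not mean the stage is invalid, only that it is not one of the predefined names listed here.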
Note that the CoroOpt (coroutine lowering), O2Opt (O2 optimization), and Enzyme (automatic differentiation) passes do not always happen. last denotes the stage right before object file generation.
Note
In order to use this function, keep_intermediate=True must be set in the qjit() decorator of the input function.
- Parameters
fn (QJIT) – a qjit-decorated function
stage (str) – string corresponding to the name of the stage whose IR is returned
- Returns
output IR from the target compiler stage
- Return type
str
See also
Example
from catalyst import debug, qjit

@qjit(keep_intermediate=True)
def func(x: float):
    return x
>>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
module @func {
  func.func public @jit_func(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
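Because some passes (CoroOpt, O2Opt, Enzyme) do not always happen, it can be convenient to request several stages at once and keep only those that were recorded. The following is a minimal sketch under stated assumptions: collect_stages is a hypothetical helper, the getter is passed in as an argument so the helper does not depend on Catalyst itself, and a missing stage is assumed to surface either as an exception or as an empty result (this page does not say how a missing stage is reported).

```python
def collect_stages(fn, stages, get_stage):
    """Map each requested stage name to its IR text.

    Stages that were not recorded are skipped. `get_stage` is any callable
    with the same (fn, stage) signature as get_compilation_stage.
    """
    recorded = {}
    for stage in stages:
        try:
            ir = get_stage(fn, stage)
        except Exception:
            # Assumption: an unrecorded stage raises. The exact exception
            # type is not documented here, so the catch is deliberately broad.
            continue
        if ir:
            # Keep only stages that produced non-empty output.
            recorded[stage] = ir
    return recorded
```

With Catalyst installed, the helper could be called as collect_stages(func, ["mlir", "Enzyme", "last"], debug.get_compilation_stage), where func is the qjit-decorated function from the example above.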