catalyst.debug.replace_ir
replace_ir(fn, stage, new_ir)
Replace the IR at any compilation stage that will be used the next time the function runs.
It is important that the function signature (inputs and outputs) for the next execution matches that of the provided IR, or else the behaviour is undefined.
Available stages include:
MLIR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.
LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
Note that the CoroOpt (coroutine lowering), O2Opt (O2 optimization), and Enzyme (automatic differentiation) passes do not always happen. last denotes the stage right before object file generation.
Parameters
fn (QJIT) – a qjit-decorated function
stage (str) – Recompilation picks up after this stage.
new_ir (str) – The replacement IR to use for recompilation.
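Because the replacement IR has to be written in the dialect of the chosen stage, it can help to check the stage name before calling replace_ir. The following is a minimal sketch, not part of Catalyst: stage_family is a hypothetical helper, and the stage names are simply copied from the list above, so the set accepted by a particular Catalyst version is an assumption.

```python
# Stage names copied from the list above. The set accepted by a given
# Catalyst version may differ, so treat these tuples as an assumption.
MLIR_STAGES = (
    "mlir",
    "HLOLoweringPass",
    "QuantumCompilationPass",
    "BufferizationPass",
    "MLIRToLLVMDialect",
)
LLVM_STAGES = ("llvm_ir", "CoroOpt", "O2Opt", "Enzyme", "last")


def stage_family(stage: str) -> str:
    """Return "MLIR" or "LLVM" for a known stage name (hypothetical helper)."""
    if stage in MLIR_STAGES:
        return "MLIR"
    if stage in LLVM_STAGES:
        return "LLVM"
    known = ", ".join(MLIR_STAGES + LLVM_STAGES)
    raise ValueError(f"unknown stage {stage!r}; expected one of: {known}")
```

For example, stage_family("HLOLoweringPass") returns "MLIR", which indicates that new_ir should be MLIR text rather than LLVM IR. Stage names are matched case-sensitively here.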
Example
>>> from catalyst import qjit
>>> from catalyst.debug import get_compilation_stage, replace_ir
>>> @qjit(keep_intermediate=True)
... def f(x):
...     return x**2
>>> f(2.0)  # just-in-time compile the function
4.0
Here we modify %2 = arith.mulf %in, %in_0 : f64 to turn the square function into a cubic one:
>>> old_ir = get_compilation_stage(f, "HLOLoweringPass")
>>> new_ir = old_ir.replace(
...     "%2 = arith.mulf %in, %in_0 : f64\n",
...     "%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
... )
The recompilation starts after the given checkpoint stage:
>>> replace_ir(f, "HLOLoweringPass", new_ir)
>>> f(2.0)
8.0
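One pitfall of editing IR with str.replace is that it returns the input unchanged when the pattern is not found, for instance if SSA value names differ between compiler versions. The sketch below guards against that. replace_once is a hypothetical helper, not a Catalyst function, and sample_ir is an illustrative string rather than real compiler output.

```python
def replace_once(ir: str, old: str, new: str) -> str:
    """Replace `old` with `new`, requiring exactly one occurrence in `ir`."""
    count = ir.count(old)
    if count != 1:
        raise ValueError(
            f"expected exactly one occurrence of the pattern, found {count}"
        )
    return ir.replace(old, new)


# Illustrative stand-in for the text returned by get_compilation_stage.
sample_ir = "  %2 = arith.mulf %in, %in_0 : f64\n  linalg.yield %2 : f64\n"

patched = replace_once(
    sample_ir,
    "%2 = arith.mulf %in, %in_0 : f64\n",
    "%t = arith.mulf %in, %in_0 : f64\n  %2 = arith.mulf %t, %in_0 : f64\n",
)
```

The patched string could then be passed as new_ir. Failing loudly on a missing or ambiguous pattern avoids recompiling with IR that was never actually modified.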