catalyst.debug.replace_ir

replace_ir(fn, stage, new_ir)

Replace the IR at a given compilation stage. The replacement IR is used the next time the function runs.

It is important that the function signature (inputs and outputs) for the next execution matches that of the provided IR, or else the behaviour is undefined.

Available stages include:

  • MLIR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.

  • LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.

Note that CoroOpt (Coroutine lowering), O2Opt (O2 optimization), and Enzyme (automatic differentiation) passes do not always happen. last denotes the stage right before object file generation.
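
The grouping above suggests that the replacement text should be written in the IR of the chosen stage (MLIR text for the MLIR stages, LLVM IR for the LLVM stages). A minimal sketch of a lookup that makes this explicit; `STAGES` and `ir_kind` are hypothetical names used for illustration and are not part of Catalyst:

```python
# Stage names copied from the list above, grouped by IR family.
# STAGES and ir_kind are illustrative helpers, not part of the Catalyst API.
STAGES = {
    "MLIR": (
        "mlir",
        "HLOLoweringPass",
        "QuantumCompilationPass",
        "BufferizationPass",
        "MLIRToLLVMDialect",
    ),
    "LLVM": ("llvm_ir", "CoroOpt", "O2Opt", "Enzyme", "last"),
}


def ir_kind(stage: str) -> str:
    """Return "MLIR" or "LLVM" for a known stage name, or raise ValueError."""
    for kind, names in STAGES.items():
        if stage in names:
            return kind
    raise ValueError(f"unknown compilation stage: {stage!r}")
```

Checking the stage name up front catches typos such as a misspelled pass name before any recompilation is attempted.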

Parameters
  • fn (QJIT) – A qjit-decorated function.

  • stage (str) – The stage whose IR is replaced. Recompilation picks up after this stage.

  • new_ir (str) – The replacement IR to use for recompilation.

Example

>>> from catalyst import qjit
>>> from catalyst.debug import get_compilation_stage, replace_ir
>>> @qjit(keep_intermediate=True)
... def f(x):
...     return x**2
>>> f(2.0)  # just-in-time compile the function
4.0

Here we modify %2 = arith.mulf %in, %in_0 : f64 to turn the square function into a cubic one:

>>> old_ir = get_compilation_stage(f, "HLOLoweringPass")
>>> new_ir = old_ir.replace(
...   "%2 = arith.mulf %in, %in_0 : f64\n",
...   "%t = arith.mulf %in, %in_0 : f64\n    %2 = arith.mulf %t, %in_0 : f64\n"
... )
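
Note that `str.replace` returns its input unchanged when the target substring is absent, so a mistyped line would leave the IR silently unmodified. A minimal sketch of a guard, assuming a hypothetical helper `replace_once` (not part of Catalyst) and a shortened stand-in string in place of the real IR:

```python
def replace_once(ir: str, old: str, new: str) -> str:
    """Replace exactly one occurrence of `old` in `ir`, or raise ValueError.

    Hypothetical helper for illustration; not part of the Catalyst API.
    """
    count = ir.count(old)
    if count != 1:
        raise ValueError(f"expected exactly 1 match for the target text, found {count}")
    return ir.replace(old, new)


# Stand-in for the text returned by get_compilation_stage (illustrative only).
sample_ir = "    %2 = arith.mulf %in, %in_0 : f64\n    return %2 : f64\n"

patched_ir = replace_once(
    sample_ir,
    "%2 = arith.mulf %in, %in_0 : f64\n",
    "%t = arith.mulf %in, %in_0 : f64\n    %2 = arith.mulf %t, %in_0 : f64\n",
)
```

With a guard like this, a target line that does not match the actual IR raises an error instead of producing an unmodified `new_ir`.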

The recompilation starts after the given checkpoint stage:

>>> replace_ir(f, "HLOLoweringPass", new_ir)
>>> f(2.0)
8.0