# Source code for catalyst.debug.compiler_functions
# Copyright 2023-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains debug functions to interact with the compiler and compiled functions.
"""
import logging
import os
import platform
import re
import shutil
import subprocess
import catalyst
from catalyst.compiler import LinkerDriver
from catalyst.logging import debug_logger
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import filter_static_args, promote_arguments
# Module-level logger. Attaching a NullHandler means that records emitted here are
# discarded unless the application configures logging handlers of its own.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
@debug_logger
def get_compilation_stage(fn, stage):
    """Returns the intermediate representation of one of the recorded compilation
    stages for a JIT-compiled function.

    The stages are indexed by their Catalyst compilation pipeline name, which are either provided
    by the user as a compilation option, or predefined in ``catalyst.compiler``.

    All the available stages are:

    - MLIR: ``mlir``, ``HLOLoweringPass``, ``QuantumCompilationPass``, ``BufferizationPass``,
      and ``MLIRToLLVMDialect``.

    - LLVM: ``llvm_ir``, ``CoroOpt``, ``O2Opt``, ``Enzyme``, and ``last``.

    Note that ``CoroOpt`` (Coroutine lowering), ``O2Opt`` (O2 optimization), and ``Enzyme``
    (automatic differentiation) passes do not always happen. ``last`` denotes the stage
    right before object file generation.

    .. note::

        In order to use this function, ``keep_intermediate=True`` must be
        set in the :func:`~.qjit` decorator of the input function.

    Args:
        fn (QJIT): a qjit-decorated function
        stage (str): string corresponding with the name of the stage to be printed

    Returns:
        str: output IR from the target compiler stage

    Raises:
        TypeError: if ``fn`` is not a ``QJIT`` object.

    .. seealso:: :doc:`/dev/debugging`, :func:`~.replace_ir`.

    **Example**

    .. code-block:: python

        @qjit(keep_intermediate=True)
        def func(x: float):
            return x

    >>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
    module @func {
      func.func public @jit_func(%arg0: tensor<f64>)
      -> tensor<f64> attributes {llvm.emit_c_interface} {
        return %arg0 : tensor<f64>
      }
      func.func @setup() {
        quantum.init
        return
      }
      func.func @teardown() {
        quantum.finalize
        return
      }
    }
    """
    # Refuse to run while a program is being traced. The message describes this function;
    # it previously reused the C-interface wording from ``get_cmain``.
    EvaluationContext.check_is_not_tracing(
        "Compilation stages cannot be retrieved from tracing context."
    )

    if not isinstance(fn, catalyst.QJIT):
        raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

    return fn.compiler.get_output_of(stage, fn.workspace)
@debug_logger
def get_compilation_stages_groups(options):
    """Look up the grouped compilation stages configured in ``options``.

    Args:
        options: compilation options object whose ``get_stages()`` method is queried
            (presumably a ``CompileOptions`` instance; confirm against callers).

    Returns:
        list[tuple]: one tuple per compilation stage, pairing the stage name with the
        list of passes that stage contains.
    """
    stage_groups = options.get_stages()
    return stage_groups
@debug_logger
def get_cmain(fn, *args):
    """Build the source of a C program whose ``main`` invokes a jitted function.

    The function is (re)compiled for ``args`` if needed. The argument values are then
    embedded in the generated program.

    Args:
        fn (QJIT): a qjit-decorated function
        *args: argument values to use in the C program when invoking ``fn``

    Returns:
        str: A C program that can be compiled and linked with the current shared object.

    Raises:
        TypeError: if ``fn`` is not a ``QJIT`` object.
    """
    EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.")

    if not isinstance(fn, catalyst.QJIT):
        raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

    call_args = args
    # A truthy result from jit_compile means the arguments must be promoted to the
    # compiled signature (after removing the static ones) before they are embedded.
    if fn.jit_compile(args):
        static_argnums = fn.compile_options.static_argnums
        dynamic_values = filter_static_args(args, static_argnums)
        call_args = promote_arguments(fn.c_sig, dynamic_values)

    return fn.compiled_function.get_cmain(*call_args)
@debug_logger
def replace_ir(fn, stage, new_ir):
    r"""Replace the IR at any compilation stage that will be used the next time the function runs.

    It is important that the function signature (inputs and outputs) for the next execution matches
    that of the provided IR, or else the behaviour is undefined.

    Available stages include:

    - MLIR: ``mlir``, ``HLOLoweringPass``, ``QuantumCompilationPass``, ``BufferizationPass``,
      and ``MLIRToLLVMDialect``.

    - LLVM: ``llvm_ir``, ``CoroOpt``, ``O2Opt``, ``Enzyme``, and ``last``.

    Note that ``CoroOpt`` (Coroutine lowering), ``O2Opt`` (O2 optimization), and ``Enzyme``
    (automatic differentiation) passes do not always happen. ``last`` denotes the stage
    right before object file generation.

    Args:
        fn (QJIT): a qjit-decorated function
        stage (str): Recompilation picks up after this stage.
        new_ir (str): The replacement IR to use for recompilation.

    Raises:
        TypeError: if ``fn`` is not a ``QJIT`` object.

    .. seealso:: :doc:`/dev/debugging`, :func:`~.get_compilation_stage`.

    **Example**

    >>> from catalyst.debug import get_compilation_stage, replace_ir
    >>> @qjit(keep_intermediate=True)
    ... def f(x):
    ...     return x**2
    >>> f(2.0)  # just-in-time compile the function
    4.0

    Here we modify ``%2 = arith.mulf %in, %in_0 : f64`` to turn the square function into a cubic one:

    >>> old_ir = get_compilation_stage(f, "HLOLoweringPass")
    >>> new_ir = old_ir.replace(
    ...     "%2 = arith.mulf %in, %in_0 : f64\n",
    ...     "%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
    ... )

    The recompilation starts after the given checkpoint stage:

    >>> replace_ir(f, "HLOLoweringPass", new_ir)
    >>> f(2.0)
    8.0
    """
    # Validate before mutating anything. Without this check a wrong object would get
    # ``overwrite_ir`` assigned and only then fail on the missing ``compiler`` attribute.
    if not isinstance(fn, catalyst.QJIT):
        raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

    fn.overwrite_ir = new_ir
    fn.compiler.options.checkpoint_stage = stage
    # Clear the cache so that the next call presumably recompiles from the checkpoint
    # instead of reusing an earlier compilation.
    fn.fn_cache.clear()
def _patch_macos_install_names(
    shared_object_file, output_file, original_shared_object_file, relocated
):  # pragma: nocover
    """Rewrite macOS dylib references so the generated executable can load its libraries.

    Every dependency of ``shared_object_file`` recorded under a ``/DLC...`` path (related to
    the openblas libraries shipped with scipy) is redirected to ``@rpath/<dylib name>``. It
    is then resolved through the rpaths set when the executable was linked. When the shared
    object is a copy of the workspace original (``relocated=True``), the executable's
    reference to the original shared object is redirected to the copy as well.

    Args:
        shared_object_file (str): shared object that the executable was linked against
        output_file (str): path of the linked executable
        original_shared_object_file (str): shared object path inside the qjit workspace
        relocated (bool): whether ``shared_object_file`` is a copy of the original

    Raises:
        FileNotFoundError: if ``otool`` or ``install_name_tool`` cannot be found on ``PATH``.
        subprocess.CalledProcessError: if one of the tools exits with a non-zero status.
    """
    otool_path = shutil.which("otool")
    install_name_tool_path = shutil.which("install_name_tool")
    # shutil.which returns None for missing tools. Fail with a clear message instead of
    # letting subprocess.run reject a None executable with an obscure TypeError.
    if otool_path is None or install_name_tool_path is None:
        raise FileNotFoundError(
            "compile_executable requires 'otool' and 'install_name_tool' on macOS; "
            "make sure both are available on PATH."
        )

    otool_result = subprocess.run(
        [otool_path, "-l", shared_object_file], capture_output=True, text=True, check=True
    )
    # Each distinct /DLC path only needs to be rewritten once; dict.fromkeys keeps order.
    dlc_entries = dict.fromkeys(re.findall(r"/DLC[^)]+\.dylib", otool_result.stdout))
    for entry in dlc_entries:
        new_entry = f"@rpath/{os.path.basename(entry)}"
        subprocess.run(
            [install_name_tool_path, "-change", entry, new_entry, shared_object_file],
            capture_output=True,
            text=True,
            check=True,
        )

    # Update the path of the shared library if it was copied out of the workspace.
    if relocated:
        subprocess.run(
            [
                install_name_tool_path,
                "-change",
                original_shared_object_file,
                shared_object_file,
                output_file,
            ],
            capture_output=True,
            text=True,
            check=True,
        )


@debug_logger
def compile_executable(fn, *args):
    """Generate an executable binary for the native host architecture from a
    :func:`~.qjit` decorated function with provided arguments.

    Args:
        fn (QJIT): a qjit-decorated function
        *args: argument values to use in the C program when invoking ``fn``

    Returns:
        str: the path of output binary

    Raises:
        TypeError: if ``fn`` is not a ``QJIT`` object.

    **Example**

    For example, considering the following function where we are
    using :func:`~.print_memref` to print (at runtime) information about
    variable ``y``:

    .. code-block:: python

        @qjit
        def f(x):
            y = x * x
            debug.print_memref(y)
            return y

    >>> f(5)
    MemRef: base@ = 0x64fc9dd5ffc0 rank = 0 offset = 0 sizes = [] strides = [] data =
    25
    Array(25, dtype=int64)

    We can now use ``compile_executable`` to compile this function to a binary.
    The executable will be saved in the directory for intermediate results if
    ``keep_intermediate=True``. Otherwise, the executable will appear in the Catalyst project
    root.

    >>> from catalyst.debug import compile_executable
    >>> binary = compile_executable(f, 5)
    >>> print(binary)
    /path/to/executable

    Executing this function from a shell environment:

    .. code-block:: shell

        $ /path/to/executable
        MemRef: base@ = 0x64fc9dd5ffc0 rank = 0 offset = 0 sizes = [] strides = [] data =
        25
    """
    # Same guards as get_cmain, checked up front so invalid calls have no side effects
    # (no execution of fn, no file copies).
    EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.")

    if not isinstance(fn, catalyst.QJIT):
        raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

    # If fn is not compiled yet, call it once so that a compiled shared object exists.
    if not fn.compiled_function:
        fn(*args)

    f_name = str(fn.__name__)
    keep_intermediate = fn.compile_options.keep_intermediate
    # Artifacts go into the qjit workspace when intermediates are kept, otherwise into the
    # current working directory.
    workspace = str(fn.workspace) if keep_intermediate else os.getcwd()
    main_c_file = os.path.join(workspace, "main.c")
    output_file = os.path.join(workspace, f_name + ".out")
    shared_object_file = os.path.join(workspace, f_name + ".so")
    original_shared_object_file = os.path.join(str(fn.workspace), f_name + ".so")

    # Generate the C driver before touching the filesystem. get_cmain calls fn.jit_compile,
    # which may presumably rebuild the shared object for these arguments, so the copy below
    # must happen afterwards. A failure here also no longer leaves an empty main.c behind.
    main_c_source = get_cmain(fn, *args)

    # Copy the shared object next to the executable when the workspace is not kept.
    if not keep_intermediate:
        shutil.copy(original_shared_object_file, shared_object_file)

    with open(main_c_file, "w", encoding="utf-8") as file:
        file.write(main_c_source)

    options = fn.compiler.options

    # Set search path mainly for gfortran and quadmath, which are located in the same
    # directory as openblas from scipy.
    if platform.system() == "Linux":
        object_directory = "$ORIGIN"
    else:  # pragma: nocover
        object_directory = "@loader_path"

    link_so_flags = [
        f"-Wl,-rpath,{workspace}",
        shared_object_file,
        f"-Wl,-rpath,{object_directory}",
    ]
    LinkerDriver.run(main_c_file, outfile=output_file, flags=link_so_flags, options=options)

    # Patch DLC prefix related to openblas
    if platform.system() == "Darwin":  # pragma: nocover
        _patch_macos_install_names(
            shared_object_file,
            output_file,
            original_shared_object_file,
            relocated=not keep_intermediate,
        )

    return output_file
# Page source: _modules/catalyst/debug/compiler_functions
# (Documentation page links: Download Python script, Download Notebook, View on GitHub)