Source code for catalyst.debug.compiler_functions

# Copyright 2023-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains debug functions to interact with the compiler and compiled functions.
"""

import os

from jax.interpreters import mlir

import catalyst
from catalyst.compiled_functions import CompiledFunction
from catalyst.compiler import Compiler
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import filter_static_args, promote_arguments
from catalyst.utils.filesystem import WorkspaceManager





[docs]def get_cmain(fn, *args): """Return a C program that calls a jitted function with the provided arguments. Args: fn (QJIT): a qjit-decorated function *args: argument values to use in the C program when invoking ``fn`` Returns: str: A C program that can be compiled and linked with the current shared object. """ EvaluationContext.check_is_not_tracing("C interface cannot be generated from tracing context.") if not isinstance(fn, catalyst.QJIT): raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.") requires_promotion = fn.jit_compile(args) if requires_promotion: dynamic_args = filter_static_args(args, fn.compile_options.static_argnums) args = promote_arguments(fn.c_sig, dynamic_args) return fn.compiled_function.get_cmain(*args)
# pylint: disable=line-too-long
[docs]def compile_from_mlir(ir, compiler=None, compile_options=None): """Compile a Catalyst function to binary code from the provided MLIR. Args: ir (str): the MLIR to compile in string form compile_options: options to use during compilation Returns: CompiledFunction: A callable that manages the compiled shared library and its invocation. **Example** The main entry point of the program is required to start with ``catalyst.entry_point``, and the program is required to contain ``setup`` and ``teardown`` functions. .. code-block:: python ir = r\""" module @workflow { func.func public @catalyst.entry_point(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} { return %arg0 : tensor<f64> } func.func @setup() { quantum.init return } func.func @teardown() { quantum.finalize return } } \""" compiled_function = debug.compile_from_mlir(ir) >>> compiled_function(0.1) [0.1] """ EvaluationContext.check_is_not_tracing("Cannot compile from IR in tracing context.") if compiler is None: compiler = Compiler(compile_options) module_name = "debug_module" workspace_dir = os.getcwd() if compiler.options.keep_intermediate else None workspace = WorkspaceManager.get_or_create_workspace("debug_workspace", workspace_dir) shared_object, _llvm_ir, func_data = compiler.run_from_ir(ir, module_name, workspace) # Parse inferred function data, like name and return types. qfunc_name = func_data[0] with mlir.ir.Context(): result_types = [mlir.ir.RankedTensorType.parse(rt) for rt in func_data[1].split(",")] return CompiledFunction(shared_object, qfunc_name, result_types, compiler.options)