Source code for catalyst.debug.callback
# Copyright 2023-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module enables runtime Python callbacks with no return values."""
from catalyst.api_extensions.callbacks import base_callback
[docs]
def callback(callback_fn):
"""Execute a Python
function with no return value and potential side effects from within a qjit-compiled function.
This makes it an easy entry point for debugging, for example via printing or logging at
runtime.
The callback function will be quantum just-in-time compiled alongside the rest of the
workflow, however it will be executed at runtime by the Python virtual machine.
This is in contrast to functions which get directly qjit-compiled by Catalyst, which will
be executed at runtime by machine-native code.
Args:
callback_fn (callable): The function to be used as a callback.
Any Python-based function is supported, as long as it does not
return anything (or returns None).
.. seealso:: :func:`.debug.print`, :func:`.pure_callback`.
**Example**
``debug.callback`` can be used as a decorator:
.. code-block:: python
@catalyst.debug.callback
def callback_fn(y):
print("Value of y =", y)
@qjit
def fn(x):
y = jnp.sin(x)
callback_fn(y)
return y ** 2
>>> fn(0.54)
Value of y = 0.5141359916531132
Array(0.26433582, dtype=float64)
>>> fn(1.52)
Value of y = 0.998710143975583
Array(0.99742195, dtype=float64)
It can also be used functionally:
.. code-block:: python
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
@qjit
@grad
def fn(x):
y = jnp.sin(x)
catalyst.debug.callback(lambda _: log.info("Value of y = %s", _))(y)
return y ** 2
>>> fn(0.543)
INFO:__main__:Value of y = 0.5167068002272901
Array(0.88476988, dtype=float64)
Note that during differentiation, the callback function will only be executed
during the forward pass.
"""
@base_callback
def closure(*args, **kwargs) -> None:
retval = callback_fn(*args, **kwargs)
if retval is not None:
raise ValueError("debug.callback is expected to return None")
return closure
_modules/catalyst/debug/callback
Download Python script
Download Notebook
View on GitHub