Source code for pennylane.pulse.convenience_functions

# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains convenience functions for pulse programming."""
from collections.abc import Callable
from typing import Optional, Union

import numpy as np

has_jax = True
try:
    import jax.numpy as jnp
except ImportError:
    has_jax = False


# pylint: disable=unused-argument
def constant(scalar, time):
    r"""Returns the given ``scalar``, for use in defining a :class:`~.ParametrizedHamiltonian` with a
    trainable coefficient.

    Args:
        scalar (float): the scalar to be returned
        time (float): Time. This argument is not used, but is required to match the call
            signature of :class:`~.ParametrizedHamiltonian`.

    Returns:
        float: The input ``scalar``.

    This function is mainly used to build a :class:`~.ParametrizedHamiltonian` that can be
    differentiated with respect to its time-independent term. It is an alias for
    ``lambda scalar, t: scalar``.

    **Example**

    The ``constant`` function can be used to create a parametrized Hamiltonian:

    >>> H = qml.pulse.constant * qml.X(0)

    When calling the parametrized Hamiltonian, ``constant`` will always return the input parameter:

    >>> params = [5]
    >>> H(params, t=8)
    5 * X(0)
    >>> H(params, t=5)
    5 * X(0)

    We can differentiate the parametrized Hamiltonian with respect to the constant parameter:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        jax.config.update("jax_enable_x64", True)

        dev = qml.device("default.qubit")

        @qml.qnode(dev, interface="jax")
        def circuit(params):
            qml.evolve(H)(params, t=2)
            return qml.expval(qml.Z(0))

    >>> params = jnp.array([5.0])
    >>> circuit(params)
    Array(0.40808193, dtype=float64)
    >>> jax.grad(circuit)(params)
    Array([-3.65178003], dtype=float64)
    """
    return scalar
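Because ``constant`` makes the coefficient time-independent, the Hamiltonian in the example above is simply ``H = p * X(0)``, so the expectation value and its gradient have a closed form. The plain-Python sketch below checks the docstring numbers analytically. It assumes that ``qml.evolve`` applies the exact propagator ``exp(-i p t X)`` to the initial state ``|0>``, and the helper names ``expval_z`` and ``grad_expval_z`` are hypothetical, not part of PennyLane.

```python
import math


def constant(scalar, time):
    """Mirror of ``qml.pulse.constant``: return ``scalar`` and ignore ``time``."""
    return scalar


def expval_z(p, t):
    """<Z> after evolving |0> under H = constant(p, t) * X for a total time t (hypothetical helper)."""
    # With a time-independent coefficient c = p, U = exp(-i c t X)
    # = cos(c t) I - i sin(c t) X, so U|0> = cos(c t)|0> - i sin(c t)|1>
    # and <Z> = cos^2(c t) - sin^2(c t) = cos(2 c t).
    c = constant(p, t)
    return math.cos(2 * c * t)


def grad_expval_z(p, t):
    """Analytic derivative d<Z>/dp = -2 t sin(2 p t) (hypothetical helper)."""
    return -2 * t * math.sin(2 * p * t)


print(expval_z(5.0, 2.0))       # approximately 0.408082
print(grad_expval_z(5.0, 2.0))  # approximately -3.651781
```

For ``p = 5`` and ``t = 2`` this gives ``cos(20)`` and ``-4 sin(20)``, which agree with the JAX outputs above to within about 1e-6; the small residual is presumably the tolerance of the numerical ODE solver.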
def rect(
    x: Union[float, Callable], windows: Optional[Union[tuple[float], list[tuple[float]]]] = None
):
    """Takes a scalar or a scalar-valued function, x, and applies a rectangular window to it,
    such that the returned function is x inside the window and 0 outside it.

    Creates a callable for defining a :class:`~.ParametrizedHamiltonian`.

    Args:
        x (Union[float, Callable]): either a scalar, or a function that accepts two arguments:
            the trainable parameters and time
        windows (Union[Tuple[float], List[Tuple[float]]]): a single ``(t_start, t_end)`` tuple,
            or a list of such tuples, defining the time windows in which ``x`` is evaluated.
            Both window edges are included. If ``None``, ``x`` is evaluated at all times.
            Defaults to ``None``.

    Returns:
        callable: A callable ``f(p, t)`` which evaluates the given function/scalar ``x`` inside
        the time windows defined in ``windows``, and otherwise returns 0.

    .. note::
        If ``x`` is a function, it must accept two arguments: the trainable parameters and time.
        The primary use of ``rect`` is for numerical simulations via
        :class:`~.ParametrizedEvolution`, which assumes ``t`` to be a single scalar argument.
        To evaluate the returned callable efficiently at multiple times, broadcast over ``t``
        with `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`_
        (see the examples below).

    **Example**

    Here we use :func:`~.rect` to create a parametrized coefficient that has a value of ``0``
    outside the time interval ``t=(1, 7)``, and is defined by ``jnp.polyval(p, t)`` within the
    interval:

    .. code-block:: python3

        import jax
        import jax.numpy as jnp
        import matplotlib.pyplot as plt
        import pennylane as qml

        def f(p, t):
            return jnp.polyval(p, t)

        p = jnp.array([1, 2, 3])
        time = jnp.linspace(0, 10, 1000)
        windows = [(1, 7)]
        windowed_f = qml.pulse.rect(f, windows=windows)

        y1 = f(p, time)
        # rect assumes a scalar t, so vectorize over the time axis with jax.vmap
        y2 = jax.vmap(windowed_f, (None, 0))(p, time)

        plt.plot(time, y1, label=f"polyval(p={p}, t)")
        plt.plot(time, y2, label=f"rect(polyval, windows={windows})(p={p}, t)")
        plt.legend()
        plt.xlabel("t")
        plt.show()

    .. figure:: ../../_static/pulse/rect_example.png
        :align: center
        :width: 60%
        :target: javascript:void(0);

    Note that in order to efficiently create ``y2``, we broadcast ``windowed_f`` over the time
    argument using `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`_.

    ``rect`` can be used to create a :class:`~.ParametrizedHamiltonian` in the following way:

    >>> H = qml.pulse.rect(jnp.polyval, windows=[(1, 7)]) * qml.X(0)

    The resulting Hamiltonian will be non-zero only inside the window:

    >>> H([[1, 3]], t=2)  # inside the window
    5.0 * X(0)
    >>> H([[1, 3]], t=0.5)  # outside the window
    0.0 * X(0)

    It is also possible to define multiple windows for the same function:

    .. code-block:: python

        windows = [(1, 7), (9, 14)]
        H = qml.pulse.rect(jnp.polyval, windows) * qml.X(0)

    When calling the :class:`.ParametrizedHamiltonian`, ``rect`` returns the value of the given
    function at times inside any of the windows, and 0 otherwise.

    One can also pass a scalar to the ``rect`` function:

    >>> H = qml.pulse.rect(10, (1, 7)) * qml.X(0)

    In this case, ``rect`` will return the given scalar only when the time is inside the provided
    time windows:

    >>> params = [None]  # the parameter value won't be used!
    >>> H(params, t=8)
    0.0 * X(0)
    >>> H(params, t=5)
    10.0 * X(0)
    """
    if not has_jax:
        raise ImportError(
            "Module jax is required for any pulse-related convenience function. "
            "You can install jax via: pip install jax==0.4.10 jaxlib==0.4.10"
        )

    if windows is not None:
        is_nested = any(hasattr(w, "__len__") for w in windows)
        single_window = len(windows) == 2 and not is_nested
        if single_window:
            windows = [windows]
        elif not all(hasattr(w, "__len__") and len(w) == 2 for w in windows):
            raise ValueError("At least one provided window is not a two-element sequence.")

    if not callable(x):

        def _f(_, __):
            return jnp.array(x, dtype=float)

    else:
        _f = x

    def f(p, t):
        # cast to float so that integer parameters do not produce an integer-valued output
        p = jnp.array(p, dtype=float)
        if windows is not None:
            ti, tf = zip(*windows)
            ti, tf = jnp.array(ti), jnp.array(tf)
            return jnp.where(jnp.any((t >= ti) & (t <= tf)), _f(p, t), 0)
        return _f(p, t)

    return f
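To make the windowing rule concrete without JAX, here is a plain-Python sketch of the same logic for a scalar ``t``. The names ``rect_sketch`` and ``polyval`` are hypothetical stand-ins, not PennyLane API. The sketch mirrors the window normalization above and treats both window edges as inclusive, as the ``(t >= ti) & (t <= tf)`` test does.

```python
def polyval(p, t):
    """Evaluate the polynomial with coefficients p (highest degree first) at t."""
    result = 0.0
    for c in p:
        result = result * t + c  # Horner's rule
    return result


def rect_sketch(x, windows=None):
    """Plain-Python sketch of ``rect`` for a scalar ``t`` (hypothetical helper)."""
    if windows is not None:
        # A bare (ti, tf) pair is promoted to a one-element list of windows.
        is_nested = any(hasattr(w, "__len__") for w in windows)
        if len(windows) == 2 and not is_nested:
            windows = [windows]
        elif not all(hasattr(w, "__len__") and len(w) == 2 for w in windows):
            raise ValueError("At least one provided window is not a two-element sequence.")

    # Wrap a scalar so that both cases share the f(p, t) call signature.
    inner = x if callable(x) else (lambda _p, _t: float(x))

    def f(p, t):
        if windows is None:
            return inner(p, t)
        # Both edges are inclusive, like (t >= ti) & (t <= tf).
        inside = any(ti <= t <= tf for ti, tf in windows)
        return inner(p, t) if inside else 0.0

    return f


g = rect_sketch(polyval, windows=[(1, 7), (9, 14)])
print(g([1, 3], 2), g([1, 3], 8), g([1, 3], 10))  # 5.0 0.0 13.0
```

Unlike ``jnp.where``, which evaluates both branches and then selects one, this sketch only calls the wrapped function inside a window; the returned values are the same, but the JAX version stays traceable under ``jit`` and ``vmap``.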
def pwc(timespan):
    """Takes a time span and returns a callable for creating a function that is piece-wise
    constant in time. The returned function takes arguments ``(p, t)``, where ``p`` is an array
    that defines the bin values for the function.

    Creates a callable for defining a :class:`~.ParametrizedHamiltonian`.

    Args:
        timespan(Union[float, tuple(float, float)]): The time span defining the region where the
            function is non-zero. If a scalar is provided, the time span is defined as
            ``(0, timespan)``.

    Returns:
        callable: a function that takes two arguments: an ``array`` of trainable parameters, and a
        ``float`` defining the time at which the function is evaluated.

    For a scalar ``timespan`` and ``0 <= t < timespan``, the convenience function ``pwc``
    essentially implements

    .. code-block:: python3

        def pwc(timespan):
            def wrapped(p, t):
                return p[int(len(p) * t / timespan)]

            return wrapped

    This function can be used to create a parametrized coefficient function that is piece-wise
    constant within ``timespan``, and 0 outside it. When creating the callable, only the time
    span is passed. The number of bins and the bin values are set only when ``params`` is passed
    to the callable: each element of ``params`` sets the value of one bin, and the bins divide
    the time span evenly. The time ``t`` then selects the element of ``params`` whose bin
    contains ``t``.

    .. code-block:: python3

        import jax.numpy as jnp
        import matplotlib.pyplot as plt
        import pennylane as qml

        params = jnp.array([1, 2, 3, 4, 5])
        time = jnp.linspace(0, 10, 1000)
        timespan = (2, 7)
        y = qml.pulse.pwc(timespan)(params, time)

        plt.plot(time, y, label=f"params={params}, timespan={timespan}")
        plt.legend()
        plt.show()

    .. figure:: ../../_static/pulse/pwc_example.png
        :align: center
        :width: 60%
        :target: javascript:void(0);

    .. warning::
        The final time in the time span indicates the time at which the function output switches
        from ``params[-1]`` to 0. As such, the above function returns ``5`` for a time slightly
        smaller than the final time in ``timespan``, but it returns ``0`` for the final time
        itself:

        >>> qml.pulse.pwc(timespan)(params, 6.999999)
        Array(5., dtype=float32)
        >>> qml.pulse.pwc(timespan)(params, 7.)
        Array(0., dtype=float32)

    **Example**

    >>> timespan = (2, 7)
    >>> f1 = qml.pulse.pwc(timespan)
    >>> H = f1 * qml.X(0)

    The resulting function ``f1`` has the call signature ``f1(params, t)``. If passed an array of
    parameters and a time, it will assign the array as the constants in the piece-wise function,
    and select the constant corresponding to the specified time, based on the time interval
    defined by ``timespan``.

    In the following example, passing an array to the callable returned by ``pwc((2, 7))``
    evenly distributes the array values over the interval from ``t=2`` to ``t=7``. The time ``t``
    is then used to select one of the array values based on this distribution.

    >>> H(params=[[11, 12, 13, 14, 15]], t=2.3)
    11.0 * X(0)
    >>> H(params=[[11, 12, 13, 14, 15]], t=2.5)  # different time, same bin, same result
    11.0 * X(0)
    >>> H(params=[[11, 12, 13, 14, 15]], t=3.1)  # next bin
    12.0 * X(0)
    >>> H(params=[[11, 12, 13, 14, 15]], t=8)  # outside the window returns 0
    0.0 * X(0)
    """
    if not has_jax:
        raise ImportError(
            "Module jax is required for any pulse-related convenience function. "
            "You can install jax via: pip install jax==0.4.3 jaxlib==0.4.3"
        )

    if isinstance(timespan, (tuple, list)):
        t0, t1 = timespan
    else:
        t0 = 0
        t1 = timespan

    def func(params, t):
        num_bins = len(params)
        params = jnp.concatenate([jnp.array(params), jnp.zeros(1)])
        # map t to a bin index; times outside [t0, t1] get idx=-1, which (like idx=num_bins
        # at t=t1) selects the appended zero
        idx = num_bins / (t1 - t0) * (t - t0)
        idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)
        return params[idx]

    return func
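The bin-selection rule can be sketched in plain Python for a scalar ``t``. The name ``pwc_sketch`` is hypothetical and not part of PennyLane. Instead of appending a zero and indexing it with ``-1``, the sketch returns ``0.0`` directly whenever the fractional index falls outside ``[0, num_bins)``, which gives the same values, including ``0`` at the final time itself.

```python
def pwc_sketch(timespan):
    """Plain-Python sketch of ``pwc`` for a scalar ``t`` (hypothetical helper)."""
    # A scalar timespan T is read as the interval (0, T).
    if isinstance(timespan, (tuple, list)):
        t0, t1 = timespan
    else:
        t0, t1 = 0, timespan

    def func(params, t):
        num_bins = len(params)
        # Fractional bin index: the span is cut into num_bins bins of equal width.
        idx = num_bins / (t1 - t0) * (t - t0)
        # t < t0, t > t1 and t == t1 (idx == num_bins) all give 0.
        if idx < 0 or idx >= num_bins:
            return 0.0
        return float(params[int(idx)])

    return func


f1 = pwc_sketch((2, 7))
params = [11, 12, 13, 14, 15]
print([f1(params, t) for t in (2.3, 2.5, 3.1, 6.999999, 7.0, 8.0)])
# [11.0, 11.0, 12.0, 15.0, 0.0, 0.0]
```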
def pwc_from_function(timespan, num_bins):
    """
    Decorates a smooth function, creating a piece-wise constant function that approximates it.

    Creates a callable for defining a :class:`~.ParametrizedHamiltonian`.

    Args:
        timespan(Union[float, tuple(float, float)]): The time span defining the region where the
            function is non-zero. If a ``float`` is provided, the time span is defined as
            ``(0, timespan)``.
        num_bins(int): number of bins for time-binning the function

    Returns:
        callable: a function that takes some smooth function ``f(params, t)`` and converts it to a
        piece-wise constant function spanning ``timespan`` in ``num_bins`` equal-width bins.

    **Example**

    .. code-block:: python3

        import pennylane as qml

        def smooth_function(params, t):
            return params[0] * t + params[1]

        timespan = 10
        num_bins = 10

        binned_function = qml.pulse.pwc_from_function(timespan, num_bins)(smooth_function)

    >>> binned_function([2, 4], 3), smooth_function([2, 4], 3)  # t = 3
    (Array(10.666667, dtype=float32), 10)
    >>> binned_function([2, 4], 3.2), smooth_function([2, 4], 3.2)  # t = 3.2
    (Array(10.666667, dtype=float32), 10.4)
    >>> binned_function([2, 4], 4.5), smooth_function([2, 4], 4.5)  # t = 4.5
    (Array(12.888889, dtype=float32), 13.0)

    The same effect can be achieved by decorating the smooth function:

    .. code-block:: python

        from pennylane.pulse.convenience_functions import pwc_from_function

        @pwc_from_function(timespan, num_bins)
        def fn(params, t):
            return params[0] * t + params[1]

    >>> fn([2, 4], 3)
    Array(10.666667, dtype=float32)
    """
    if not has_jax:
        raise ImportError(
            "Module jax is required for any pulse-related convenience function. "
            "You can install jax via: pip install jax==0.4.3 jaxlib==0.4.3"
        )

    if isinstance(timespan, tuple):
        t0, t1 = timespan
    else:
        t0 = 0
        t1 = timespan

    def inner(fn):
        time_bins = np.linspace(t0, t1, num_bins)

        def wrapper(params, t):
            constants = jnp.array(list(fn(params, time_bins)) + [0])

            idx = num_bins / (t1 - t0) * (t - t0)
            # check that idx lies in [0, num_bins] before casting to int, so that values in
            # (-1, 0) are not truncated to 0; out-of-range times get idx=-1 (the appended zero)
            idx = jnp.where((idx >= 0) & (idx <= num_bins), jnp.array(idx, dtype=int), -1)

            return constants[idx]

        return wrapper

    return inner
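The example values follow from how the smooth function is sampled. ``np.linspace(t0, t1, num_bins)`` includes both endpoints, so the sample points are spaced ``(t1 - t0) / (num_bins - 1)`` apart, while the bins are ``(t1 - t0) / num_bins`` wide. Bin ``k`` therefore takes the value of the smooth function at ``t0 + k * (t1 - t0) / (num_bins - 1)`` rather than at the bin's left edge. For the example, bin 3 covers ``[3, 4)`` but uses ``f(10/3) = 10.666...``. The plain-Python sketch below reproduces this rule for a scalar ``t``; ``pwc_from_function_sketch`` is a hypothetical name, and it samples the function point by point instead of passing the whole ``time_bins`` array to ``fn`` as the source does.

```python
def pwc_from_function_sketch(timespan, num_bins):
    """Plain-Python sketch of ``pwc_from_function`` for a scalar ``t`` (hypothetical helper)."""
    if isinstance(timespan, tuple):
        t0, t1 = timespan
    else:
        t0, t1 = 0, timespan

    # Same points as np.linspace(t0, t1, num_bins) up to rounding: both endpoints
    # are included, so the sample spacing is (t1 - t0) / (num_bins - 1).
    if num_bins > 1:
        step = (t1 - t0) / (num_bins - 1)
        time_bins = [t0 + k * step for k in range(num_bins)]
    else:
        time_bins = [t0]

    def inner(fn):
        def wrapper(params, t):
            # One sample of the smooth function per bin.
            constants = [fn(params, tb) for tb in time_bins]
            # The bin index uses the bin width (t1 - t0) / num_bins.
            idx = num_bins / (t1 - t0) * (t - t0)
            if idx < 0 or idx >= num_bins:
                return 0.0
            return constants[int(idx)]

        return wrapper

    return inner


def smooth_function(params, t):
    return params[0] * t + params[1]


binned = pwc_from_function_sketch(10, 10)(smooth_function)
print(binned([2, 4], 3), binned([2, 4], 4.5))  # approximately 10.6667 and 12.8889
```

A consequence of this sampling is that the last bin, ``[9, 10)`` here, takes the value at the endpoint ``t1 = 10``.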