Source code for pennylane.pulse.parametrized_evolution
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=too-few-public-methods,function-redefined
"""
This file contains the ``ParametrizedEvolution`` operator.
"""
import warnings
from collections.abc import Sequence
from typing import Union
import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import functions
from pennylane.typing import TensorLike
from .hardware_hamiltonian import HardwareHamiltonian
from .parametrized_hamiltonian import ParametrizedHamiltonian
has_jax = True
try:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from .parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree
except ImportError as e:
has_jax = False
class ParametrizedEvolution(Operation):
    r"""
    ParametrizedEvolution(H, params=None, t=None, return_intermediate=False, complementary=False, dense=None, id=None, **odeint_kwargs)

    Parametrized evolution gate, created by passing a :class:`~.ParametrizedHamiltonian` to
    the :func:`~.pennylane.evolve` function.

    For a time-dependent Hamiltonian of the form

    .. math:: H(\{v_j\}, t) = H_\text{drift} + \sum_j f_j(v_j, t) H_j

    it implements the corresponding time-evolution operator :math:`U(t_0, t_1)`, which is the
    solution to the time-dependent Schrodinger equation

    .. math:: \frac{d}{dt}U(t) = -i H(\{v_j\}, t) U(t).

    Under the hood, it uses a numerical ordinary differential equation (ODE) solver. It requires ``jax``,
    and will not work with other machine learning frameworks typically encountered in PennyLane.

    Args:
        H (ParametrizedHamiltonian): Hamiltonian to evolve
        params (Optional[list]): trainable parameters, passed as list where each element corresponds to
            the parameters of a scalar-valued function of the Hamiltonian being evolved.
        t (Union[float, List[float]]): If a float, it corresponds to the duration of the evolution.
            If a list of floats, the ODE solver will use all the provided time values, and
            perform intermediate steps if necessary. It is recommended to just provide a start
            and end time unless matrices of the time evolution at intermediate times need
            to be computed. Note that such absolute times only have meaning within an instance of
            ``ParametrizedEvolution`` and will not affect other gates.
            To return the matrix at intermediate evolution times, activate ``return_intermediate``
            (see below).
        id (str or None): id for the operator. Default is None.

    Keyword Args:
        atol (float, optional): Absolute error tolerance for the ODE solver. Defaults to ``1.4e-8``.
        rtol (float, optional): Relative error tolerance for the ODE solver. The error is estimated
            from comparing a 4th and 5th order Runge-Kutta step in the Dopri5 algorithm. This error
            is guaranteed to stay below ``tol = atol + rtol * abs(y)`` through adaptive step size
            selection. Defaults to 1.4e-8.
        mxstep (int, optional): maximum number of steps to take for each timepoint for the
            ODE solver. Defaults to ``jnp.inf``.
        hmax (float, optional): maximum step size allowed for the ODE solver. Defaults to ``jnp.inf``.
        return_intermediate (bool): Whether or not the ``matrix`` method returns all intermediate
            solutions of the time evolution at the times provided in ``t = [t_0,...,t_f]``.
            If ``False`` (the default), only the matrix for the full time evolution is returned.
            If ``True``, all solutions including the initial condition are returned;
            when used in a circuit, this results in ``ParametrizedEvolution`` being a broadcasted
            operation, see the usage details ("Computing intermediate time evolution") below.
        complementary (bool): Whether or not to compute the complementary time evolution when using
            ``return_intermediate=True`` (ignored otherwise).
            If ``False`` (the default), the usual solutions to the Schrodinger equation
            :math:`\{U(t_0, t_0), U(t_0, t_1),\dots, U(t_0, t_f)\}` are computed,
            where :math:`t_i` are the additional times provided in ``t``.
            If ``True``, the *remaining* time evolution to :math:`t_f` is computed instead, returning
            :math:`\{U(t_0, t_f), U(t_1, t_f),\dots, U(t_{f-1}, t_f), U(t_f, t_f)\}`.
        dense (bool): Whether the evolution should use dense matrices. By default, this is decided by
            the number of wires, i.e. ``dense = len(wires) < 3``.

    .. warning::

        The :class:`~.ParametrizedHamiltonian` must be Hermitian at all times. This is not explicitly checked
        when creating a :class:`~.ParametrizedEvolution` from the :class:`~.ParametrizedHamiltonian`.

    **Example**

    To create a :class:`~.ParametrizedEvolution`, we first define a :class:`~.ParametrizedHamiltonian`
    describing the system, and then pass it to :func:`~pennylane.evolve`:

    .. code-block:: python

        from jax import numpy as jnp

        f1 = lambda p, t: jnp.sin(p * t)
        H = f1 * qml.Y(0)

        ev = qml.evolve(H)

    The initial :class:`~.ParametrizedEvolution` does not have set parameters, and so will not
    have a matrix defined. To obtain an Operator with a matrix, it must be passed parameters and
    a time interval:

    >>> qml.matrix(ev([1.2], t=[0, 4]))
    Array([[ 0.72454906+0.j, -0.6892243 +0.j],
           [ 0.6892243 +0.j,  0.72454906+0.j]], dtype=complex64)

    The parameters can be updated by calling the :class:`~.ParametrizedEvolution` again with different inputs.

    When calling the :class:`~.ParametrizedEvolution`, keyword arguments can be passed to specify
    behaviour of the ODE solver.

    The :class:`~.ParametrizedEvolution` can be implemented in a QNode:

    .. code-block:: python

        import jax

        jax.config.update("jax_enable_x64", True)
        dev = qml.device("default.qubit", wires=1)

        @jax.jit
        @qml.qnode(dev, interface="jax")
        def circuit(params):
            qml.evolve(H)(params, t=[0, 10])
            return qml.expval(qml.Z(0))

    >>> params = [1.2]
    >>> circuit(params)
    Array(0.96632722, dtype=float64)

    >>> jax.grad(circuit)(params)
    [Array(2.35694829, dtype=float64)]

    .. note::

        In the example above, the decorator ``@jax.jit`` is used to compile this execution just-in-time. This means
        the first execution will typically take a little longer with the benefit that all following executions
        will be significantly faster, see the jax docs on jitting. JIT-compiling is optional, and one can remove
        the decorator when only single executions are of interest.

    .. warning::

        The time argument ``t`` corresponds to the time window used to compute the scalar-valued
        functions present in the :class:`ParametrizedHamiltonian` class. Consequently, executing
        two ``ParametrizedEvolution`` operators using the same time window does not mean both
        operators are executed simultaneously, but rather that both evaluate their respective
        scalar-valued functions using the same time window. See Usage Details.

    .. note::

        Using ``return_intermediate`` in a quantum circuit leads to broadcasted execution,
        which can lead to unintended additional computational cost.
        Also consider the usage details below.

    .. details::
        :title: Usage Details

        The parameters used when calling the ``ParametrizedEvolution`` are expected to have the same order
        as the functions used to define the :class:`~.ParametrizedHamiltonian`. For example:

        .. code-block:: python3

            def f1(p, t):
                return jnp.sin(p[0] * t**2) + p[1]

            def f2(p, t):
                return p * jnp.cos(t)

            H = 2 * qml.X(0) + f1 * qml.Y(0) + f2 * qml.Z(0)
            ev = qml.evolve(H)

        >>> params = [[4.6, 2.3], 1.2]
        >>> qml.matrix(ev(params, t=0.5))
        Array([[-0.18354285-0.26303384j, -0.7271658 -0.606923j  ],
               [ 0.7271658 -0.606923j  , -0.18354285+0.26303384j]], dtype=complex64)

        Internally the solver is using ``f1([4.6, 2.3], t)`` and ``f2(1.2, t)`` at each timestep when
        finding the matrix.

        In the case where we have defined two Hamiltonians, ``H1`` and ``H2``, and we want to find a time evolution
        where the two are driven simultaneously for some period of time, it is important that both are included in
        the same call of :func:`~.pennylane.evolve`.
        For non-commuting operations, applying ``qml.evolve(H1)(params, t=[0, 10])`` followed by
        ``qml.evolve(H2)(params, t=[0, 10])`` will **not** apply the two pulses simultaneously, despite the overlapping
        time window. Instead, it will execute ``H1`` in the ``[0, 10]`` time window, and then subsequently execute
        ``H2`` using the same time window to calculate the evolution, but without taking into account how the time
        evolution of ``H1`` affects the evolution of ``H2`` and vice versa.

        Consider two non-commuting :class:`ParametrizedHamiltonian` objects:

        .. code-block:: python

            from jax import numpy as jnp

            ops = [qml.X(0), qml.Y(1), qml.Z(2)]
            coeffs = [lambda p, t: p for _ in range(3)]
            H1 = qml.dot(coeffs, ops)  # time-independent parametrized Hamiltonian

            ops = [qml.Z(0), qml.Y(1), qml.X(2)]
            coeffs = [lambda p, t: p * jnp.sin(t) for _ in range(3)]
            H2 = qml.dot(coeffs, ops)  # time-dependent parametrized Hamiltonian

        The evolutions of the :class:`ParametrizedHamiltonian` can be used in a QNode.

        .. code-block:: python

            dev = qml.device("default.qubit", wires=3)

            @qml.qnode(dev, interface="jax")
            def circuit1(params):
                qml.evolve(H1)(params, t=[0, 10])
                qml.evolve(H2)(params, t=[0, 10])
                return qml.expval(qml.Z(0) @ qml.Z(1) @ qml.Z(2))

            @qml.qnode(dev, interface="jax")
            def circuit2(params):
                qml.evolve(H1 + H2)(params, t=[0, 10])
                return qml.expval(qml.Z(0) @ qml.Z(1) @ qml.Z(2))

        In ``circuit1``, the two Hamiltonians are evolved over the same time window, but inside different operators.
        In ``circuit2``, we add the two to form a single :class:`~.ParametrizedHamiltonian`. This will combine the
        two so that the expected parameters will be ``params1 + params2`` (as an addition of ``list``).
        They can then be included inside a single :class:`~.ParametrizedEvolution`.

        The resulting evolutions of ``circuit1`` and ``circuit2`` are **not** identical:

        >>> params = jnp.array([1., 2., 3.])
        >>> circuit1(params)
        Array(-0.01542578, dtype=float64)

        >>> params = jnp.concatenate([params, params])  # H1 + H2 requires 6 parameters!
        >>> circuit2(params)
        Array(-0.78235162, dtype=float64)

        Here, ``circuit1`` is not executing the evolution of ``H1`` and ``H2`` simultaneously, but rather
        executing ``H1`` in the ``[0, 10]`` time window and then executing ``H2`` with the same time window,
        without taking into account how the time evolution of ``H1`` affects the evolution of ``H2`` and vice versa!

        One can also provide a list of time values that the ODE solver will use to calculate the evolution of the
        ``ParametrizedHamiltonian``. Keep in mind that the ODE solver uses an adaptive step size, thus
        it might use additional intermediate time values.

        .. code-block:: python

            t = jnp.arange(0., 10.1, 0.1)

            @qml.qnode(dev, interface="jax")
            def circuit(params):
                qml.evolve(H1 + H2)(params, t=t)
                return qml.expval(qml.Z(0) @ qml.Z(1) @ qml.Z(2))

        >>> circuit(params)
        Array(-0.78235162, dtype=float64)
        >>> jax.grad(circuit)(params)
        Array([-4.80708632,  3.70323783, -1.32958799, -2.40642477,  0.68105214,
               -0.52269657], dtype=float64)

        Given that we used the same time window (``[0, 10]``), the results are the same as before.

        **Computing intermediate time evolution**

        As discussed above, the ODE solver will evaluate the Schrodinger equation at
        intermediate times in any case. By passing additional time values explicitly in the time
        window ``t`` and setting ``return_intermediate=True``, the ``matrix`` method will
        return the matrices for the intermediate time evolutions as well:

        .. math::

            \{U(t_0, t_0), U(t_0, t_1), \dots, U(t_0, t_{f-1}), U(t_0, t_f)\}.

        The first entry here is the initial condition :math:`U(t_0, t_0)=1`. For a simple
        time-dependent single-qubit Hamiltonian, this feature looks like the following:

        .. code-block:: python

            ops = [qml.Z(0), qml.Y(0), qml.X(0)]
            coeffs = [lambda p, t: p * jnp.cos(t) for _ in range(3)]
            H = qml.dot(coeffs, ops)  # time-dependent parametrized Hamiltonian

            param = [jnp.array(0.2), jnp.array(1.1), jnp.array(-1.3)]
            time = jnp.linspace(0.1, 0.4, 6)  # Six time points from 0.1 to 0.4

            ev = qml.evolve(H)(param, time, return_intermediate=True)

        >>> ev_mats = ev.matrix()
        >>> ev_mats.shape
        (6, 2, 2)

        Note that the broadcasting axis has length ``len(time)`` and is the first axis of the
        returned tensor.
        We may use this feature within QNodes executed on a simulator, returning the
        measurements for all intermediate time steps:

        .. code-block:: python

            dev = qml.device("default.qubit", wires=1)

            @qml.qnode(dev, interface="jax")
            def circuit(param, time):
                qml.evolve(H)(param, time, return_intermediate=True)
                return qml.probs(wires=[0])

        >>> circuit(param, time)
        Array([[1.        , 0.        ],
               [0.98977406, 0.01022594],
               [0.95990416, 0.04009584],
               [0.91236167, 0.08763833],
               [0.84996865, 0.15003133],
               [0.77614817, 0.22385181]], dtype=float64)

        **Computing complementary time evolution**

        When using ``return_intermediate=True``, the partial time evolutions share the *initial*
        time :math:`t_0`. For some applications, however, it may be useful to compute the
        complementary time evolutions, i.e. the partial evolutions that share the *final* time
        :math:`t_f`. This can be activated by setting ``complementary=True``, which will make
        ``ParametrizedEvolution.matrix`` return the matrices

        .. math::

            \{U(t_0, t_f), U(t_1, t_f), \dots, U(t_f, t_f)\}.

        Using the Hamiltonian from the example above:

        >>> complementary_ev = ev(param, time, return_intermediate=True, complementary=True)
        >>> comp_ev_mats = complementary_ev.matrix()
        >>> comp_ev_mats.shape
        (6, 2, 2)

        If we multiply the matrices computed before with ``complementary=False`` with these
        complementary evolution matrices from the left, we obtain the full time evolution,
        which we can check by comparing to the last entry of ``ev_mats``:

        >>> for mat, c_mat in zip(ev_mats, comp_ev_mats):
        ...     print(qml.math.allclose(c_mat @ mat, ev_mats[-1]))
        True
        True
        True
        True
        True
        True
    """

    _name = "ParametrizedEvolution"
    num_wires = AnyWires
    grad_method = "A"

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        H: ParametrizedHamiltonian,
        params: list = None,
        t: Union[float, list[float]] = None,
        return_intermediate: bool = False,
        complementary: bool = False,
        dense: bool = None,
        id=None,
        **odeint_kwargs,
    ):
        # Every operator must be representable as a matrix so the ODE right-hand side can be built.
        if not all(op.has_matrix or isinstance(op, qml.ops.Hamiltonian) for op in H.ops):
            raise ValueError(
                "All operators inside the parametrized hamiltonian must have a matrix defined."
            )
        # A matrix can only be computed once both the parameters and the time window are known.
        self._has_matrix = params is not None and t is not None
        self.H = H
        self.odeint_kwargs = odeint_kwargs
        if t is None:
            self.t = None
        else:
            if isinstance(t, (list, tuple)):
                t = qml.math.stack(t)
            # A scalar ``t`` is interpreted as a duration, i.e. the time window ``[0, t]``.
            self.t = qml.math.cast(qml.math.stack([0.0, t]) if qml.math.ndim(t) == 0 else t, float)
        if complementary and not return_intermediate:
            warnings.warn(
                "The keyword argument complementary does not have any effect if "
                "return_intermediate is set to False."
            )
        if params is None:
            params = []
        else:
            # HardwareHamiltonian instances are exempt from the length check.
            if not isinstance(H, HardwareHamiltonian) and len(params) != len(H.coeffs_parametrized):
                raise ValueError(
                    "The length of the params argument and the number of scalar-valued functions "
                    f"in the Hamiltonian must be the same. Received {len(params)=} parameters but "
                    f"expected {len(H.coeffs_parametrized)} parameters."
                )
        super().__init__(*params, wires=H.wires, id=id)
        self.hyperparameters["return_intermediate"] = return_intermediate
        self.hyperparameters["complementary"] = complementary
        self._check_time_batching()
        self.dense = len(self.wires) < 3 if dense is None else dense

    def __call__(
        self, params, t, return_intermediate=None, complementary=None, dense=None, **odeint_kwargs
    ):
        """Return a new ``ParametrizedEvolution`` with the given parameters and time window.

        Settings that are not provided (``return_intermediate``, ``complementary``, ``dense``)
        are inherited from this instance, and ``odeint_kwargs`` are merged on top of this
        instance's solver options. If a queuing context is recording, this (parameter-less)
        instance is removed from the queue so only the new operator remains.

        Raises:
            ImportError: if ``jax`` is not installed.
        """
        if not has_jax:
            raise ImportError(
                "Module jax is required for the ``ParametrizedEvolution`` class. "
                "You can install jax via: pip install jax"
            )
        # Need to cast all elements inside params to `jnp.arrays` to make sure they are not cast
        # to `np.arrays` inside `Operator.__init__`
        params = [jnp.array(p) for p in params]
        # Inherit return_intermediate and complementary from self if not provided.
        if return_intermediate is None:
            return_intermediate = self.hyperparameters["return_intermediate"]
        if complementary is None:
            complementary = self.hyperparameters["complementary"]
        if dense is None:
            dense = self.dense
        odeint_kwargs = {**self.odeint_kwargs, **odeint_kwargs}
        if qml.QueuingManager.recording():
            qml.QueuingManager.remove(self)

        return ParametrizedEvolution(
            H=self.H,
            params=params,
            t=t,
            return_intermediate=return_intermediate,
            complementary=complementary,
            dense=dense,
            id=self.id,
            **odeint_kwargs,
        )

    def _check_time_batching(self):
        """Check whether the time argument is broadcasted/batched.

        With ``return_intermediate=True``, ``matrix`` returns one matrix per entry of ``self.t``:
        the initial condition is included for ``complementary=False``, and the final entry
        :math:`U(t_f, t_f)` is included for ``complementary=True``. The batch size is therefore
        ``len(self.t)`` in both cases.
        """
        if not self.hyperparameters["return_intermediate"] or self.t is None:
            return
        self._batch_size = self.t.shape[0]

    def map_wires(self, wire_map):
        """Map the operator's wires and those of the underlying Hamiltonian using ``wire_map``."""
        mapped_op = super().map_wires(wire_map)
        mapped_op.H = self.H.map_wires(wire_map)
        return mapped_op

    @property
    def hash(self):
        """int: Integer hash that uniquely represents the operator."""
        return hash(
            (
                str(self.name),
                tuple(self.wires.tolist()),
                str(self.hyperparameters.values()),
                str(self.t),
                str(self.data),
                self.H,
                str(self.odeint_kwargs.values()),
            )
        )

    def _flatten(self):
        """Split the operator into trainable ``data`` and the metadata consumed by ``_unflatten``."""
        data = self.data
        odeint_kwargs_tuples = tuple((key, value) for key, value in self.odeint_kwargs.items())
        # ``t`` is stored as a tuple (presumably so the metadata can be compared/hashed).
        t = self.t if self.t is None else tuple(self.t)
        metadata = (
            t,
            self.H,
            self.hyperparameters["return_intermediate"],
            self.hyperparameters["complementary"],
            self.dense,
            odeint_kwargs_tuples,
        )
        return data, metadata

    @classmethod
    def _unflatten(cls, data, metadata):
        """Rebuild an operator from the output of ``_flatten``."""
        t, H, return_intermediate, complementary, dense, odeint_kwargs = metadata
        return cls(
            H,
            None if len(data) == 0 else data,
            t,
            return_intermediate=return_intermediate,
            complementary=complementary,
            dense=dense,
            **dict(odeint_kwargs),
        )

    # pylint: disable=arguments-renamed, invalid-overridden-method
    @property
    def has_matrix(self):
        """bool: whether both parameters and a time window were supplied at construction."""
        return self._has_matrix

    # pylint: disable=import-outside-toplevel
    def matrix(self, wire_order=None):
        r"""Compute the time-evolution matrix by integrating the Schrodinger equation with ``odeint``.

        Args:
            wire_order (Iterable or None): wire order passed on to ``qml.math.expand_matrix``.

        Returns:
            TensorLike: :math:`U(t_0, t_f)` by default. With ``return_intermediate=True``, the
            stack of evolutions at every time in ``t`` (time axis first), or of the complementary
            evolutions :math:`U(t_i, t_f)` if ``complementary=True``.

        Raises:
            ImportError: if ``jax`` is not installed.
            ValueError: if the parameters or the time window have not been set.
        """
        if not has_jax:
            raise ImportError(
                "Module jax is required for the ``ParametrizedEvolution`` class. "
                "You can install jax via: pip install jax"
            )
        if not self.has_matrix:
            raise ValueError(
                "The parameters and the time window are required to compute the matrix. "
                "You can update its values by calling the class: EV(params, t)."
            )
        # Initial condition U(t_0, t_0) is the identity on the operator's wires.
        y0 = jnp.eye(2 ** len(self.wires), dtype=complex)
        with jax.ensure_compile_time_eval():
            H_jax = ParametrizedHamiltonianPytree.from_hamiltonian(
                self.H, dense=self.dense, wire_order=self.wires
            )

        def fun(y, t):
            """dy/dt = -i H(t) y"""
            return (-1j * H_jax(self.data, t=t)) @ y

        # ``mat`` holds one solution per time in ``self.t``, stacked along the first axis.
        mat = odeint(fun, y0, self.t, **self.odeint_kwargs)
        if self.hyperparameters["return_intermediate"] and self.hyperparameters["complementary"]:
            # Compute U(t_0, t_f)@U(t_0, t_i)^\dagger, where i indexes the first axis of mat
            mat = qml.math.tensordot(mat[-1], qml.math.conj(mat), axes=[[1], [-1]])
            # The previous line leaves the axis indexing the t_i as second, so we move it up
            mat = qml.math.moveaxis(mat, 1, 0)
        elif not self.hyperparameters["return_intermediate"]:
            mat = mat[-1]
        return qml.math.expand_matrix(mat, wires=self.wires, wire_order=wire_order)

    def label(self, decimals=None, base_label=None, cache=None):
        r"""A customizable string representation of the operator.

        Args:
            decimals=None (int): If ``None``, no parameters are included. Else,
                specifies how to round the parameters.
            base_label=None (str): overwrite the non-parameter component of the label
            cache=None (dict): dictionary that carries information between label calls
                in the same drawing

        Returns:
            str: label to use in drawings

        **Example:**

        >>> H = qml.X(1) + qml.pulse.constant * qml.Y(0) + jnp.polyval * qml.Y(1)
        >>> params = [0.2, [1, 2, 3]]
        >>> op = qml.evolve(H)(params, t=2)
        >>> cache = {'matrices': []}

        >>> op.label()
        "Parametrized\nEvolution"
        >>> op.label(decimals=2, cache=cache)
        "Parametrized\nEvolution\n(p=[0.20,M0], t=[0. 2.])"
        >>> op.label(base_label="my_label")
        "my_label"
        >>> op.label(decimals=2, base_label="my_label", cache=cache)
        "my_label\n(p=[0.20,M0], t=[0. 2.])"

        Array-like parameters are stored in ``cache['matrices']``.
        """
        op_label = base_label or "Parametrized\nEvolution"

        if self.num_params == 0 or decimals is None:
            return op_label

        params = self.parameters
        has_cache = cache and isinstance(cache.get("matrices", None), list)

        # Array-valued parameters are shown as references into ``cache['matrices']``; without a
        # usable cache they cannot be displayed, so fall back to the bare label.
        if any(qml.math.ndim(p) for p in params) and not has_cache:
            return op_label

        def _format_number(x):
            return format(qml.math.toarray(x), f".{decimals}f")

        def _format_arraylike(x):
            # Reuse an existing cache entry for an identical array, otherwise append a new one.
            for i, mat in enumerate(cache["matrices"]):
                if qml.math.shape(x) == qml.math.shape(mat) and qml.math.allclose(x, mat):
                    return f"M{i}"
            mat_num = len(cache["matrices"])
            cache["matrices"].append(x)
            return f"M{mat_num}"

        # ``qml.math.ndim`` (rather than ``p.shape``) also handles plain Python scalars.
        param_strings = [
            _format_arraylike(p) if qml.math.ndim(p) else _format_number(p) for p in params
        ]
        p = ",".join(param_strings)
        return f"{op_label}\n(p=[{p}], t={self.t})"
@functions.bind_new_parameters.register
def _bind_new_parameters_parametrized_evol(op: ParametrizedEvolution, params: Sequence[TensorLike]):
    """Build a ``ParametrizedEvolution`` identical to ``op`` except for its trainable parameters.

    The Hamiltonian, time window, ``return_intermediate``/``complementary`` hyperparameters,
    ``dense`` flag and ODE solver keyword arguments are all copied from ``op``; only ``params``
    is replaced.
    """
    hyperparams = op.hyperparameters
    return ParametrizedEvolution(
        H=op.H,
        params=params,
        t=op.t,
        return_intermediate=hyperparams["return_intermediate"],
        complementary=hyperparams["complementary"],
        dense=op.dense,
        **op.odeint_kwargs,
    )
_modules/pennylane/pulse/parametrized_evolution
Download Python script
Download Notebook
View on GitHub