Source code for pennylane.gradients.pulse_gradient
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains functions for computing the stochastic parameter-shift gradient
of pulse sequences in a qubit-based quantum tape.
"""
import warnings
from functools import partial
import numpy as np
import pennylane as qml
from pennylane import transform
from pennylane.pulse import HardwareHamiltonian, ParametrizedEvolution
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn
from .general_shift_rules import eigvals_to_frequencies, generate_shift_rule
from .gradient_transform import (
_all_zero_grad,
_no_trainable_grad,
assert_no_state_returns,
assert_no_trainable_tape_batching,
assert_no_variance,
choose_trainable_params,
find_and_validate_gradient_methods,
reorder_grads,
)
from .parameter_shift import _make_zero_rep
has_jax = True
try:
import jax
import jax.numpy as jnp
except ImportError:
has_jax = False
def _assert_has_jax(transform_name):
"""Check that JAX is installed and imported correctly, otherwise raise an error.
Args:
        transform_name (str): Name of the gradient transform that requires JAX; used in the
            error message
"""
if not has_jax: # pragma: no cover
raise ImportError(
f"Module jax is required for the {transform_name} gradient transform. "
"You can install jax via: pip install jax jaxlib"
)
def raise_pulse_diff_on_qnode(transform_name):
"""Raises an error as the gradient transform with the provided name does
not support direct application to QNodes.
"""
    msg = (
        f"Applying the {transform_name} gradient transform to a QNode directly is currently "
        "not supported. Please use differentiation via a JAX entry point "
        "(jax.grad, jax.jacobian, ...) instead."
    )
raise NotImplementedError(msg)
def _split_evol_ops(op, ob, tau):
r"""Randomly split a ``ParametrizedEvolution`` with respect to time into two operations and
insert a Pauli rotation using a given Pauli word and rotation angles :math:`\pm\pi/2`.
This yields two groups of three operations each.
Args:
op (ParametrizedEvolution): operation to split up.
ob (`~.Operator`): generating Hamiltonian term to insert the parameter-shift rule for.
        tau (float or tensor_like): split-up time(s). If multiple times are passed, the split-up
            operations are set up to return intermediate time evolution results, effectively
            leading to broadcasting.
Returns:
tuple[list[`~.Operation`]]: The split-time evolution, expressed as three operations in the
inner lists. The number of tuples is given by the number of shifted terms in the
parameter-shift rule of the generating Hamiltonian term ``ob``.
tensor_like: Coefficients of the parameter-shift rule of the provided generating Hamiltonian
term ``ob``.
"""
t0, *_, t1 = op.t
# If there are multiple values for tau, use broadcasting
if bcast := qml.math.ndim(tau) > 0:
        # With broadcasting, create a sorted array of [t_0, *sorted(taus), t_1]
        # and use it for both the pulse before and the pulse after the inserted operation.
        # The way we slice the resulting tape results later on accommodates the additional
        # time points t_0 and t_1 in the array.
tau = jnp.sort(tau)
before_t = jnp.concatenate([jnp.array([t0]), tau, jnp.array([t1])])
after_t = before_t.copy()
else:
# Create a time interval from start to split and one from split to end
before_t = jax.numpy.array([t0, tau])
after_t = jax.numpy.array([tau, t1])
if qml.pauli.is_pauli_word(ob):
prefactor = qml.pauli.pauli_word_prefactor(ob)
word = qml.pauli.pauli_word_to_string(ob)
insert_ops = [qml.PauliRot(shift, word, ob.wires) for shift in [np.pi / 2, -np.pi / 2]]
coeffs = [prefactor, -prefactor]
else:
with warnings.catch_warnings():
if len(ob.wires) <= 4:
warnings.filterwarnings(
"ignore", ".*the eigenvalues will be computed numerically.*"
)
eigvals = qml.eigvals(ob)
coeffs, shifts = zip(*generate_shift_rule(eigvals_to_frequencies(tuple(eigvals))))
insert_ops = [qml.exp(qml.dot([-1j * shift], [ob])) for shift in shifts]
# Create Pauli rotations to be inserted at tau
ode_kwargs = op.odeint_kwargs
# If we are broadcasting, make use of the `return_intermediate` and `complementary` features
ops = tuple(
[
op(op.data, before_t, return_intermediate=bcast, **ode_kwargs),
insert_op,
op(op.data, after_t, return_intermediate=bcast, complementary=bcast, **ode_kwargs),
]
for insert_op in insert_ops
)
return ops, jnp.array(coeffs)
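# Illustrative sketch of ``_split_evol_ops`` above (added for exposition, not executed): for a
# pulse ``op`` with time window ``op.t = [0.1, 0.9]`` and a plain Pauli-word term ``ob = qml.Y(0)``,
# the call ``_split_evol_ops(op, ob, tau=0.5)`` returns two groups of three operations,
#     (op evolved on [0.1, 0.5], qml.PauliRot(+pi/2, "Y", 0), op evolved on [0.5, 0.9])
#     (op evolved on [0.1, 0.5], qml.PauliRot(-pi/2, "Y", 0), op evolved on [0.5, 0.9])
# together with the parameter-shift coefficients ``jnp.array([1.0, -1.0])``.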
def _split_evol_tape(tape, split_evolve_ops, op_idx):
"""Replace a marked ``ParametrizedEvolution`` in a given tape by provided operations, creating
one tape per group of operations.
Args:
tape (QuantumTape): original tape
split_evolve_ops (tuple[list[qml.Operation]]): The time-split evolution operations as
created by ``_split_evol_ops``. For each group of operations, a new tape
is created.
op_idx (int): index of the operation to replace within the tape
Returns:
list[QuantumTape]: new tapes with replaced operation, one tape per group of operations in
``split_evolve_ops``.
"""
ops_pre = tape.operations[:op_idx]
ops_post = tape.operations[op_idx + 1 :]
return [
qml.tape.QuantumScript(ops_pre + split + ops_post, tape.measurements, shots=tape.shots)
for split in split_evolve_ops
]
# pylint: disable=too-many-arguments
def _parshift_and_integrate(
results,
cjacs,
int_prefactor,
psr_coeffs,
single_measure,
has_partitioned_shots,
use_broadcasting,
):
"""Apply the parameter-shift rule post-processing to tape results and contract
with classical Jacobians, effectively evaluating the numerical integral of the stochastic
parameter-shift rule.
Args:
results (list): Tape evaluation results, corresponding to the modified quantum
circuit result when using the applicable parameter shifts and the sample splitting
times. Results should be ordered such that the different shifted circuits for a given
splitting time are grouped together
cjacs (tensor_like): classical Jacobian evaluated at the splitting times
int_prefactor (float): prefactor of the numerical integration, corresponding to the size
of the time range divided by the number of splitting time samples
psr_coeffs (tensor_like or tuple[tensor_like]): Coefficients of the parameter-shift
rule to contract the results with before integrating numerically.
single_measure (bool): Whether the results contain a single measurement per shot setting
has_partitioned_shots (bool): Whether the results have a shot vector axis
use_broadcasting (bool): Whether broadcasting was used in the tapes that returned the
``results``.
Returns:
tensor_like or tuple[tensor_like] or tuple[tuple[tensor_like]]: Gradient entry
"""
def _contract(coeffs, res, cjac):
"""Contract three tensors, the first two like a standard matrix multiplication
and the result with the third tensor along the first axes."""
return jnp.tensordot(jnp.tensordot(coeffs, res, axes=1), cjac, axes=[[0], [0]])
if isinstance(psr_coeffs, tuple):
num_shifts = [len(c) for c in psr_coeffs]
def _psr_and_contract(res_list, cjacs, int_prefactor):
"""Execute the parameter-shift rule and contract with classical Jacobians.
This function assumes multiple generating terms for the pulse parameter
of interest"""
res = jnp.stack(res_list)
idx = 0
# Preprocess the results: Reshape, create slices for different generating terms
if use_broadcasting:
# Slice the results according to the different generating terms. Slice away the
# first and last value for each term, which correspond to the initial condition
# and the final value of the time evolution, but not to splitting times
res = tuple(res[idx : (idx := idx + n), 1:-1] for n in num_shifts)
else:
shape = jnp.shape(res)
num_taus = shape[0] // sum(num_shifts)
# Reshape the slices of the results corresponding to different generating terms.
# Afterwards the first axis corresponds to the splitting times and the second axis
# corresponds to the different shifts of the respective term.
# Finally move the shifts-axis to the first position of each term.
res = tuple(
jnp.moveaxis(
jnp.reshape(
res[idx : (idx := idx + n * num_taus)], (num_taus, n) + shape[1:]
),
1,
0,
)
for n in num_shifts
)
# Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
# and include the rescaling factor from the Monte Carlo integral and from global
# prefactors of Pauli word generators.
diff_per_term = jnp.array(
[_contract(c, r, cjac) for c, r, cjac in zip(psr_coeffs, res, cjacs)]
)
return qml.math.sum(diff_per_term, axis=0) * int_prefactor
else:
num_shifts = len(psr_coeffs)
def _psr_and_contract(res_list, cjacs, int_prefactor):
"""Execute the parameter-shift rule and contract with classical Jacobians.
This function assumes a single generating term for the pulse parameter
of interest"""
res = jnp.stack(res_list)
# Preprocess the results: Reshape, create slices for different generating terms
if use_broadcasting:
# Slice away the first and last values, corresponding to the initial condition
# and the final value of the time evolution, but not to splitting times
res = res[:, 1:-1]
else:
# Reshape the results such that the first axis corresponds to the splitting times
# and the second axis corresponds to different shifts. All other axes are untouched.
# Afterwards move the shifts-axis to the first position.
shape = jnp.shape(res)
new_shape = (shape[0] // num_shifts, num_shifts) + shape[1:]
res = jnp.moveaxis(jnp.reshape(res, new_shape), 1, 0)
# Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
# and include the rescaling factor from the Monte Carlo integral and from global
# prefactors of Pauli word generators.
return _contract(psr_coeffs, res, cjacs) * int_prefactor
nesting_layers = (not single_measure) + has_partitioned_shots
if nesting_layers == 1:
return tuple(_psr_and_contract(r, cjacs, int_prefactor) for r in zip(*results))
if nesting_layers == 0:
# Single measurement without shot vector
return _psr_and_contract(results, cjacs, int_prefactor)
# Multiple measurements with shot vector. Not supported with broadcasting yet.
if use_broadcasting:
# TODO: Remove once #2690 is resolved
raise NotImplementedError(
"Broadcasting, multiple measurements and shot vectors are currently not "
"supported all simultaneously by stoch_pulse_grad."
)
return tuple(
tuple(_psr_and_contract(_r, cjacs, int_prefactor) for _r in zip(*r)) for r in zip(*results)
)
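# Shape bookkeeping for ``_parshift_and_integrate`` above (explanatory note, not executed):
# for a single generating term with S shifts, N splitting times and a scalar measurement,
# ``res`` is brought into shape (S, N), so that
#     jnp.tensordot(psr_coeffs, res, axes=1)        # combine the S shifted results -> shape (N,)
#     jnp.tensordot(..., cjacs, axes=[[0], [0]])    # contract the N splitting times with cjacs
# applies the shift rule per splitting time and sums over the samples; multiplying by
# ``int_prefactor = (t1 - t0) / N`` turns that sum into the Monte Carlo integral estimate.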
# pylint: disable=too-many-arguments
@partial(transform, final_transform=True)
def stoch_pulse_grad(
tape: QuantumScript,
argnum=None,
num_split_times=1,
sampler_seed=None,
use_broadcasting=False,
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
r"""Compute the gradient of a quantum circuit composed of pulse sequences by applying the
stochastic parameter shift rule.
For a pulse-based cost function :math:`C(\boldsymbol{v}, T)`
with variational parameters :math:`\boldsymbol{v}` and evolution time :math:`T`, it is given by
    (cf. Eqn. (6) in `Leng et al. (2022) <https://arxiv.org/abs/2210.15812>`__ with altered
notation):
.. math::
\frac{\partial C}{\partial v_k}
= \int_{0}^{T} \mathrm{d}\tau \sum_{j=1}^m
\frac{\partial f_j}{\partial v_k}(\boldsymbol{v}, \tau)
\left[C_j^{(+)}(\boldsymbol{v}, \tau) - C_j^{(-)}(\boldsymbol{v}, \tau)\right]
Here, :math:`f_j` are the pulse envelopes that capture the time dependence of the pulse
Hamiltonian:
.. math::
H(\boldsymbol{v}, t) = H_\text{drift} + \sum_j f_j(\boldsymbol{v}, t) H_j,
and :math:`C_j^{(\pm)}` are modified cost functions:
.. math::
C_j^{(\pm)}(\boldsymbol{v}, \tau)&=
\bra{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)} B
\ket{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)} \\
\ket{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)}
&= U_{\boldsymbol{v}}(T, \tau) e^{-i (\pm \frac{\pi}{4}) H_j}
U_{\boldsymbol{v}}(\tau, 0)\ket{\psi_0}.
    That is, the :math:`j`\ th modified time evolution in these circuits interrupts the
evolution generated by the pulse Hamiltonian by inserting a rotation gate generated by
the corresponding Hamiltonian term :math:`H_j` with a rotation angle of
:math:`\pm\frac{\pi}{4}`.
See below for a more detailed description. The integral in the first equation above
is estimated numerically in the stochastic parameter-shift rule. For this, it samples
split times :math:`\tau` and averages the modified cost functions and the Jacobians
of the envelopes :math:`\partial f_j / \partial v_k` at the sampled times suitably.
Args:
tape (QuantumTape): quantum circuit to differentiate
argnum (int or list[int] or None): Trainable tape parameter indices to differentiate
with respect to. If not provided, the derivatives with respect to all
trainable parameters are returned. Note that the indices are with respect to
the list of trainable parameters.
num_split_times (int): number of time samples to use in the stochastic parameter-shift
rule underlying the differentiation; also see details
        sampler_seed (int): randomness seed to be used for the time samples in the stochastic
            parameter-shift rule
        use_broadcasting (bool): Whether to use broadcasting across the different sampled
            splitting times. If ``False`` (the default), one set of modified tapes per
            splitting time is created; if ``True``, only a single set of broadcasted, modified
            tapes is created, increasing performance on simulators.
Returns:
tuple[List[QuantumTape], function]:
The transformed circuit as described in :func:`qml.transform <pennylane.transform>`. Executing this circuit
will provide the Jacobian in the form of a tensor, a tuple, or a nested tuple depending upon the nesting
structure of measurements in the original circuit.
This transform realizes the stochastic parameter-shift rule for pulse sequences, as introduced
in `Banchi and Crooks (2018) <https://quantum-journal.org/papers/q-2021-01-25-386/>`_ and
`Leng et al. (2022) <https://arxiv.org/abs/2210.15812>`_.
.. note::
This function requires the JAX interface and does not work with other autodiff interfaces
commonly encountered with PennyLane.
Finally, this transform is not JIT-compatible yet.
.. note::
This function uses a basic sampling approach with a uniform distribution to estimate the
integral appearing in the stochastic parameter-shift rule. In many cases, there are
probability distributions that lead to smaller variances of the estimator.
In addition, the sampling approach will not reduce trivially to simpler parameter-shift
rules when used with simple pulses (see details and examples below), potentially leading
to imprecise results and/or unnecessarily large computational efforts.
.. warning::
This transform may not be applied directly to QNodes. Use JAX entrypoints
(``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape level.
Also see the examples below.
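    As a rough tape-level sketch (an illustrative addition; it assumes a quantum tape ``tape``
    whose trainable parameters are the pulse parameters, and a device ``dev``):

    .. code-block:: python

        # Apply the transform to the tape directly instead of to a QNode
        grad_tapes, postproc_fn = qml.gradients.stoch_pulse_grad(tape, num_split_times=3)
        results = qml.execute(grad_tapes, dev)
        grad = postproc_fn(results)
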
**Examples**
Consider a pulse program with a single two-qubit pulse, generated by a Hamiltonian
with three terms: the non-trainable term :math:`\frac{1}{2}X_0`, the trainable
    constant (over time) term :math:`v_1 Z_0 Z_1` and the trainable sinusoidal term
    :math:`\sin(v_2 t) (\frac{1}{5} Y_0 + \frac{3}{5} X_1)`.
.. code-block:: python
jax.config.update("jax_enable_x64", True)
dev = qml.device("default.qubit")
def sin(p, t):
return jax.numpy.sin(p * t)
ZZ = qml.Z(0) @ qml.Z(1)
Y_plus_X = qml.dot([1/5, 3/5], [qml.Y(0), qml.X(1)])
H = 0.5 * qml.X(0) + qml.pulse.constant * ZZ + sin * Y_plus_X
def ansatz(params):
qml.evolve(H)(params, (0.2, 0.4))
return qml.expval(qml.Y(1))
qnode = qml.QNode(ansatz, dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad)
The program takes the two parameters :math:`v_1, v_2` for the two trainable terms:
>>> params = [jax.numpy.array(0.4), jax.numpy.array(1.3)]
>>> qnode(params)
Array(-0.0905377, dtype=float64)
And as we registered the differentiation method :func:`~.stoch_pulse_grad`,
we can compute its gradient in a hardware compatible manner:
>>> jax.grad(qnode)(params)
[Array(0.00109782, dtype=float64, weak_type=True),
Array(-0.05833371, dtype=float64, weak_type=True)] # results may differ
Note that the derivative is computed using a stochastic parameter-shift rule,
which is based on a sampled approximation of an integral expression (see theoretical
background below). This makes the computed derivative an approximate quantity subject
to statistical fluctuations with notable variance. The number of samples used to
approximate the integral can be chosen with ``num_split_times``, the seed for the
sampling can be fixed with ``sampler_seed``:
.. code-block:: python
qnode = qml.QNode(
ansatz,
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=5, # Use 5 samples for the approximation
sampler_seed=18, # Fix randomness seed
)
>>> jax.grad(qnode)(params)
[Array(0.00207256, dtype=float64, weak_type=True),
Array(-0.05989856, dtype=float64, weak_type=True)]
We may activate the option ``use_broadcasting`` to improve the performance when running
on classical simulators. Internally, it reuses intermediate results of the time evolution.
We can compare the performance with a simple test:
.. code-block:: python
from time import process_time
faster_grad_qnode = qml.QNode(
ansatz,
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=5, # Use 5 samples for the approximation
sampler_seed=18, # Fix randomness seed
use_broadcasting=True, # Activate broadcasting
)
times = []
for node in [qnode, faster_grad_qnode]:
start = process_time()
jax.grad(node)(params)
times.append(process_time() - start)
>>> print(times) # Show the gradient computation times in seconds.
[55.75785480000002, 12.297400500000009]
.. warning::
As the option ``use_broadcasting=True`` adds a broadcasting dimension to the modified
circuits, it is not compatible with circuits that already are broadcasted.
.. details::
:title: Theoretical background
:href: theory
Consider a pulse generated by a time-dependent Hamiltonian
.. math::
H(\boldsymbol{v}, t) = H_\text{drift} + \sum_j f_j(v_j, t) H_j,
where :math:`\boldsymbol{v}=\{v_j\}` are variational parameters and :math:`t` is the time.
In addition, consider a cost function that is based on using this pulse for
a duration :math:`T`
in a pulse sequence and measuring the expectation value of an observable.
For simplicity we absorb the parts of the sequence
before and after the considered pulse into the initial state and the observable,
respectively:
.. math::
C(\boldsymbol{v}, t) =
\bra{\psi_0} U_{\boldsymbol{v}}(T, 0)^\dagger B U_{\boldsymbol{v}}(T, 0)\ket{\psi_0}.
        Here, we denoted the unitary evolution under :math:`H(\boldsymbol{v}, t)` from time
        :math:`t_1` to :math:`t_2` as :math:`U_{\boldsymbol{v}}(t_2, t_1)`.
Then the derivative of :math:`C` with respect to a specific parameter :math:`v_k`
is given by (see Eqn. (6) of `Leng et al. (2022) <https://arxiv.org/abs/2210.15812>`_)
.. math::
\frac{\partial C}{\partial v_k}
= \int_{0}^{T} \mathrm{d}\tau \sum_{j=1}^m
\frac{\partial f_j}{\partial v_k}(\boldsymbol{v}, \tau)
\widetilde{C_j}(\boldsymbol{v}, \tau).
Here, the integral ranges over the duration of the pulse, the partial derivatives of
the coefficient functions, :math:`\partial f_j / \partial v_k`, are computed classically,
and :math:`\widetilde{C_j}` is a linear combination of the results from modified pulse
sequence executions based on generalized parameter-shift rules
(see e.g. `Kyriienko and Elfving (2022) <https://arxiv.org/abs/2108.01218>`_ or
`Wierichs et al. (2022) <https://doi.org/10.22331/q-2022-03-30-677>`_ for more details
and :func:`~.param_shift` for an implementation of the non-stochastic generalized shift
        rules).
Given the parameter shift rule with coefficients :math:`\{y_\ell\}` and shifts
:math:`\{x_\ell\}` for the single-parameter pulse :math:`\exp(-i \theta H_j)`,
the linear combination is given by
.. math::
            \widetilde{C_j}(\boldsymbol{v}, \tau)&=\sum_{\ell} y_\ell
\bra{\psi_{j}(\boldsymbol{v}, x_\ell, \tau)} B
\ket{\psi_{j}(\boldsymbol{v}, x_\ell, \tau)} \\
\ket{\psi_{j}(\boldsymbol{v}, x_\ell, \tau)}
&= U_{\boldsymbol{v}}(T, \tau) e^{-i x_\ell H_j}
U_{\boldsymbol{v}}(\tau, 0)\ket{\psi_0}.
In practice, the time integral over :math:`\tau` is computed by sampling values for
the time, evaluating the integrand, and averaging appropriately. The probability
distribution used for the sampling may have a significant impact on the quality of the
obtained estimates, in particular with regards to their variance.
        In this function, a uniform distribution over the interval :math:`[0, T]` is used,
        which often can be improved upon.
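        Spelled out (a restatement added here for clarity, using the notation from above):
        drawing :math:`N` splitting times :math:`\tau_i` uniformly from :math:`[0, T]`
        yields the estimator

        .. math::

            \frac{\partial C}{\partial v_k}
            \approx \frac{T}{N} \sum_{i=1}^{N} \sum_{j=1}^m
            \frac{\partial f_j}{\partial v_k}(\boldsymbol{v}, \tau_i)
            \widetilde{C_j}(\boldsymbol{v}, \tau_i),

        where :math:`N` is set by ``num_split_times``.
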
**Examples**
Consider the pulse generated by
.. math::
H(\boldsymbol{v}, t) = \frac{1}{2} X_0 + v_1 Z_0 Z_1 + \sin(v_2 t) X_1
and the observable :math:`B=Y_1`. There are two variational parameters, :math:`v_1`
and :math:`v_2`, for which we may compute the derivative of the cost function:
.. math::
\frac{\partial C}{\partial v_1}
&= \int_{0}^{T} \mathrm{d}\tau \ \widetilde{C_1}((v_1, v_2), \tau)\\
\frac{\partial C}{\partial v_2}
&= \int_{0}^{T} \mathrm{d}\tau \cos(v_2 \tau) \tau \ \widetilde{C_2}((v_1, v_2), \tau)\\
\widetilde{C_j}((v_1, v_2), \tau)&=
\bra{\psi_{j}((v_1, v_2), \pi/4, \tau)} B
\ket{\psi_{j}((v_1, v_2), \pi/4, \tau)}\\
&-\bra{\psi_{j}((v_1, v_2), -\pi/4, \tau)} B
\ket{\psi_{j}((v_1, v_2), -\pi/4, \tau)} \\
\ket{\psi_{j}((v_1, v_2), x, \tau)}
&= U_{(v_1, v_2)}(T, \tau) e^{-i x H_j}U_{(v_1, v_2)}(\tau, 0)\ket{0}.
Here we used the partial derivatives
.. math::
\frac{\partial f_1}{\partial v_1}&= 1\\
\frac{\partial f_2}{\partial v_2}&= \cos(v_2 t) t \\
\frac{\partial f_1}{\partial v_2}=
\frac{\partial f_2}{\partial v_1}&= 0
and the fact that both :math:`H_1=Z_0 Z_1` and :math:`H_2=X_1`
have two unique eigenvalues and therefore admit a two-term parameter-shift rule
(see e.g. `Schuld et al. (2018) <https://arxiv.org/abs/1811.11184>`_).
As a second scenario, consider the single-qubit pulse generated by
.. math::
H((v_1, v_2), t) = v_1 \sin(v_2 t) X
together with the observable :math:`B=Z`.
You may already notice that this pulse can be rewritten as a :class:`~.RX` rotation,
because we have a single Hamiltonian term and the spectrum of :math:`H` consequently
will be constant up to rescaling.
In particular, the unitary time evolution under the Schrödinger equation is given by
.. math::
U_{(v_1, v_2)}(t_2, t_1) &=
\exp\left(-i\int_{t_1}^{t_2} \mathrm{d}\tau v_1 \sin(v_2 \tau) X\right)\\
&=\exp(-i\theta(v_1, v_2) X)\\
\theta(v_1, v_2) &= \int_{t_1}^{t_2} \mathrm{d}\tau v_1 \sin(v_2 \tau)\\
&=-\frac{v_1}{v_2}(\cos(v_2 t_2) - \cos(v_2 t_1)).
As the ``RX`` rotation satisfies a (non-stochastic) two-term parameter-shift rule,
we could compute the derivatives with respect to :math:`v_1` and :math:`v_2` by
implementing :math:`\exp(-i\theta(v_1, v_2) X)`, applying the two-term shift rule
and evaluating the classical Jacobian of the mapping :math:`\theta(v_1, v_2)`.
Using the stochastic parameter-shift rule instead will lead to approximation errors.
This is because the approximated integral not only includes the shifted circuit
evaluations, which do not depend on :math:`\tau` in this example, but also on the
classical Jacobian, which is *not* constant over :math:`\tau`.
Therefore, it is important to implement pulses in the simplest way possible.
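        To make the equivalence concrete, here is an illustrative sketch (the helper
        ``equivalent_rx_angle`` is not part of PennyLane):

        .. code-block:: python

            import jax.numpy as jnp

            def equivalent_rx_angle(v1, v2, t1, t2):
                # theta(v1, v2) = \int_{t1}^{t2} v1 * sin(v2 * tau) dtau
                theta = -v1 / v2 * (jnp.cos(v2 * t2) - jnp.cos(v2 * t1))
                # qml.RX(phi, wires) implements exp(-i phi / 2 X), so phi = 2 * theta
                return 2 * theta

        Differentiating ``qml.RX(equivalent_rx_angle(v1, v2, t1, t2), 0)`` with the standard
        two-term shift rule then avoids the sampling error discussed above.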
"""
# pylint:disable=unused-argument
transform_name = "stochastic pulse parameter-shift"
_assert_has_jax(transform_name)
assert_no_state_returns(tape.measurements, transform_name)
assert_no_variance(tape.measurements, transform_name)
assert_no_trainable_tape_batching(tape, transform_name)
if num_split_times < 1:
raise ValueError(
"Expected a positive number of samples for the stochastic pulse "
f"parameter-shift gradient, got {num_split_times}."
)
if argnum is None and not tape.trainable_params:
return _no_trainable_grad(tape)
if use_broadcasting and tape.batch_size is not None:
raise ValueError("Broadcasting is not supported for tapes that already are broadcasted.")
trainable_params = choose_trainable_params(tape, argnum)
diff_methods = find_and_validate_gradient_methods(tape, "analytic", trainable_params)
if all(g == "0" for g in diff_methods.values()):
return _all_zero_grad(tape)
argnum = [i for i, dm in diff_methods.items() if dm == "A"]
sampler_seed = sampler_seed or np.random.randint(18421)
key = jax.random.PRNGKey(sampler_seed)
return _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, use_broadcasting)
def _generate_tapes_and_cjacs(
tape, operation, key, num_split_times, use_broadcasting, par_idx=None
):
"""Generate the tapes and compute the classical Jacobians for one given
generating Hamiltonian term of one pulse.
Args:
tape (QuantumScript): Tape for which to compute the stochastic pulse parameter-shift
gradient tapes.
operation (tuple[Operation, int, int]): Information about the pulse operation to be
shifted. The first entry is the operation itself, the second entry is its position
in the ``tape``, and the third entry is the index of the differentiated parameter
(and generating term) within the ``HardwareHamiltonian`` of the operation.
        key (tuple[int]): Randomness key to create splitting times.
num_split_times (int): Number of splitting times at which to create shifted tapes for
the stochastic shift rule.
        use_broadcasting (bool): Whether to use broadcasting in the shift rule or not.
        par_idx (int or None): Optional index of the differentiated parameter within the
            coefficient's Jacobian. Only used when multiple parameters feed into a single
            coefficient, as happens for ``HardwareHamiltonian`` (see ``_tapes_data_hardware``).
Returns:
list[QuantumScript]: Gradient tapes for the indicated operation and Hamiltonian term.
list[tensor_like]: Classical Jacobian at the splitting times for the given parameter.
float: Prefactor for the Monte Carlo estimate of the integral in the stochastic shift rule.
tensor_like: Parameter-shift coefficients for the shift rule of the indicated term.
"""
op, op_idx, term_idx = operation
coeff, ob = op.H.coeffs_parametrized[term_idx], op.H.ops_parametrized[term_idx]
if par_idx is None:
cjac_fn = jax.jacobian(coeff, argnums=0)
else:
# For `par_idx is not None`, we need to extract the entry of the coefficient
# Jacobian that belongs to the parameter of interest. This only happens when
# more than one parameter effectively feeds into one coefficient (HardwareHamiltonian)
def cjac_fn(params, t):
return jax.jacobian(coeff, argnums=0)(params, t)[par_idx]
t0, *_, t1 = op.t
taus = jnp.sort(jax.random.uniform(key, shape=(num_split_times,)) * (t1 - t0) + t0)
if isinstance(op.H, HardwareHamiltonian):
op_data = op.H.reorder_fn(op.data, op.H.coeffs_parametrized)
else:
op_data = op.data
cjacs = [cjac_fn(op_data[term_idx], tau) for tau in taus]
if use_broadcasting:
split_evolve_ops, psr_coeffs = _split_evol_ops(op, ob, taus)
tapes = _split_evol_tape(tape, split_evolve_ops, op_idx)
else:
tapes = []
for tau in taus:
split_evolve_ops, psr_coeffs = _split_evol_ops(op, ob, tau)
tapes.extend(_split_evol_tape(tape, split_evolve_ops, op_idx))
int_prefactor = (t1 - t0) / num_split_times
return tapes, cjacs, int_prefactor, psr_coeffs
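# Tape bookkeeping for ``_generate_tapes_and_cjacs`` above (explanatory note, not executed):
# without broadcasting, a Pauli-word generator (two shifted terms) and ``num_split_times=3``
# yield 2 * 3 = 6 tapes, ordered as
#     [tau_1/+shift, tau_1/-shift, tau_2/+shift, tau_2/-shift, tau_3/+shift, tau_3/-shift],
# which is exactly the per-splitting-time grouping that ``_parshift_and_integrate`` expects.
# With broadcasting, only a single pair of broadcasted tapes is created instead.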
def _tapes_data_hardware(tape, operation, key, num_split_times, use_broadcasting):
"""Create tapes and gradient data for a trainable parameter of a HardwareHamiltonian,
taking into account its reordering function.
Args:
tape (QuantumScript): Tape for which to compute the stochastic pulse parameter-shift
gradient tapes.
operation (tuple[Operation, int, int]): Information about the pulse operation to be
shifted. The first entry is the operation itself, the second entry is its position
in the ``tape``, and the third entry is the index of the differentiated parameter
within the ``HardwareHamiltonian`` of the operation.
        key (tuple[int]): Randomness key to create splitting times in ``_generate_tapes_and_cjacs``
num_split_times (int): Number of splitting times at which to create shifted tapes for
the stochastic shift rule.
use_broadcasting (bool): Whether to use broadcasting in the shift rule or not.
Returns:
list[QuantumScript]: Gradient tapes for the indicated operation and Hamiltonian term.
tuple: Gradient postprocessing data.
See comment below.
This function analyses the ``reorder_fn`` of the ``HardwareHamiltonian`` of the pulse
that is being differentiated. Given a ``term_idx``, the index of the parameter
in the Hamiltonian, stochastic parameter shift tapes are created for all terms in the
Hamiltonian into which the parameter feeds. While this is a one-to-one relation for
standard ``ParametrizedHamiltonian`` objects, the reordering function of
    the ``HardwareHamiltonian`` may require creating tapes for multiple Hamiltonian terms;
    ``_generate_tapes_and_cjacs`` is called once for each such term.
The returned gradient data has four entries:
1. ``int``: Total number of tapes created for all the terms that depend on the indicated
parameter.
2. ``tuple[tensor_like]``: Classical Jacobians for all terms and splitting times
3. ``float``: Prefactor for the Monte Carlo estimate of the integral in the stochastic
shift rule.
4. ``tuple[tensor_like]``: Parameter-shift coefficients for all terms.
The tuple axes in the second and fourth entry correspond to the different terms in the
Hamiltonian.
"""
op, op_idx, term_idx = operation
# Map a simple enumeration of numbers from HardwareHamiltonian input parameters to
# ParametrizedHamiltonian parameters. This is typically a fan-out function.
fake_params, allowed_outputs = np.arange(op.num_params), set(range(op.num_params))
reordered = op.H.reorder_fn(fake_params, op.H.coeffs_parametrized)
def _raise():
raise ValueError(
"Only permutations, fan-out or fan-in functions are allowed as reordering functions "
"in HardwareHamiltonians treated by stoch_pulse_grad. The reordering function of "
f"{op.H} mapped {fake_params} to {reordered}."
)
cjacs, tapes, psr_coeffs = [], [], []
for coeff_idx, x in enumerate(reordered):
# Find out whether the value term_idx, corresponding to the current parameter of interest,
# has been mapped to x (for scalar x) or into x (for 1d x). If so, generate tapes and data
# Also check that only allowed outputs have been produced by the reordering function.
if not hasattr(x, "__len__"):
if x not in allowed_outputs:
_raise()
if x != term_idx:
continue
cjac_idx = None
else:
if not all(_x in list(range(op.num_params)) for _x in x):
_raise()
if term_idx not in x:
continue
cjac_idx = np.argwhere([_x == term_idx for _x in x])[0][0]
_operation = (op, op_idx, coeff_idx)
# Overwriting int_prefactor does not matter, it is equal for all parameters in this op,
# because it only consists of the duration `op.t[-1]-op.t[0]` and `num_split_times`
_tapes, _cjacs, int_prefactor, _psr_coeffs = _generate_tapes_and_cjacs(
tape, _operation, key, num_split_times, use_broadcasting, cjac_idx
)
cjacs.append(qml.math.stack(_cjacs))
tapes.extend(_tapes)
psr_coeffs.append(_psr_coeffs)
# The fact that psr_coeffs are a tuple only for hardware Hamiltonian generators will be
# used in `_parshift_and_integrate`.
data = (len(tapes), tuple(cjacs), int_prefactor, tuple(psr_coeffs))
return tapes, data
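# Hypothetical illustration of the fan-out handling in ``_tapes_data_hardware`` (not taken
# from a concrete HardwareHamiltonian): if ``reorder_fn`` maps the enumerated inputs
# ``[0, 1]`` to ``[0, [1, 0], 1]``, then the parameter with ``term_idx=0`` feeds into the
# coefficients at positions 0 and 1, so shift tapes are generated for both of these terms,
# with ``cjac_idx=1`` selecting the relevant entry of the coefficient Jacobian for the second.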
# pylint: disable=too-many-arguments
def _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, use_broadcasting):
r"""Compute the gradient of a quantum circuit composed of pulse sequences that measures
an expectation value or probabilities, by applying the stochastic parameter shift rule.
See the main function for the signature.
"""
tapes = []
gradient_data = []
for idx in range(tape.num_params):
if idx not in argnum:
# Only the number of tapes is needed to indicate a zero gradient entry
gradient_data.append((0, None, None, None))
continue
key, _key = jax.random.split(key)
operation = tape.get_operation(idx)
op, *_ = operation
if not isinstance(op, ParametrizedEvolution):
raise ValueError(
"stoch_pulse_grad does not support differentiating parameters of "
"other operations than pulses."
)
if isinstance(op.H, HardwareHamiltonian):
# Treat HardwareHamiltonians separately because they have a reordering function
_tapes, data = _tapes_data_hardware(
tape, operation, key, num_split_times, use_broadcasting
)
else:
_tapes, cjacs, int_prefactor, psr_coeffs = _generate_tapes_and_cjacs(
tape, operation, _key, num_split_times, use_broadcasting
)
data = (len(_tapes), qml.math.stack(cjacs), int_prefactor, psr_coeffs)
tapes.extend(_tapes)
gradient_data.append(data)
num_measurements = len(tape.measurements)
single_measure = num_measurements == 1
num_params = len(tape.trainable_params)
has_partitioned_shots = tape.shots.has_partitioned_shots
tape_specs = (single_measure, num_params, num_measurements, tape.shots)
def processing_fn(results):
start = 0
grads = []
for num_tapes, cjacs, int_prefactor, psr_coeffs in gradient_data:
if num_tapes == 0:
grads.append(None)
continue
res = results[start : start + num_tapes]
start += num_tapes
# Apply the postprocessing of the parameter-shift rule and contract
# with classical Jacobian, effectively computing the integral approximation
g = _parshift_and_integrate(
res,
cjacs,
int_prefactor,
psr_coeffs,
single_measure,
has_partitioned_shots,
use_broadcasting,
)
grads.append(g)
# g will have been defined at least once (because otherwise all gradients would have
# been zero), providing a representative for a zero gradient to emulate its type/shape.
zero_rep = _make_zero_rep(g, single_measure, has_partitioned_shots)
# Fill in zero-valued gradients
grads = [zero_rep if g is None else g for g in grads]
return reorder_grads(grads, tape_specs)
return tapes, processing_fn
@stoch_pulse_grad.custom_qnode_transform
def stoch_pulse_grad_qnode_wrapper(self, qnode, targs, tkwargs):
"""A custom QNode wrapper for the gradient transform :func:`~.stoch_pulse_grad`.
It raises an error, so that applying ``stoch_pulse_grad`` to a ``QNode`` directly
is not supported.
"""
# pylint:disable=unused-argument
transform_name = "stochastic pulse parameter-shift"
raise_pulse_diff_on_qnode(transform_name)