# Source code for pennylane.labs.dla.variational_kak
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper Functionality to compute the kak decomposition variationally, as outlined in https://arxiv.org/abs/2104.00728"""
# pylint: disable=too-many-arguments, too-many-positional-arguments
import warnings
from datetime import datetime
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import pennylane as qml
from pennylane.operation import Operator
from pennylane.pauli import PauliSentence
from .cartan_subalgebra import adjvec_to_op, op_to_adjvec
# Optional dependencies: the variational routines below need jax (with 64-bit
# precision enabled) and optax. ``has_jax`` records whether both are importable.
try:
    import jax
    import jax.numpy as jnp
    import optax
except ImportError:
    has_jax = False
else:
    has_jax = True
    jax.config.update("jax_enable_x64", True)
def _apply_adjoint_action(theta, vec_H, adj):
    r"""Apply :math:`\text{Ad}_{K_1 \cdots K_{|\mathfrak{k}|}}` to the adjoint vector ``vec_H``.

    Multiplies ``expm(theta[i] * adj[:, i])`` onto ``vec_H``, starting with the last index
    ``i = len(theta) - 1`` and ending with ``i = 0``. The result is the adjoint vector of
    :math:`K_1 \cdots K_{|\mathfrak{k}|} H K^\dagger_{|\mathfrak{k}|} \cdots K^\dagger_1`.

    Args:
        theta: parameters, one per basis element of :math:`\mathfrak{k}`
        vec_H: adjoint vector that is transformed
        adj: adjoint representation cropped so that ``adj[:, i]`` is the square matrix
            acting on ``vec_H`` for the ``i``-th element of :math:`\mathfrak{k}`

    Returns:
        The transformed adjoint vector.
    """
    for i in reversed(range(len(theta))):
        vec_H = jax.scipy.linalg.expm(theta[i] * adj[:, i]) @ vec_H
    return vec_H


def variational_kak_adj(H, g, dims, adj, verbose=False, opt_kwargs=None, pick_min=False):
    r"""
    Variational KaK decomposition of Hermitian ``H`` using the adjoint representation.

    Given a Cartan decomposition (:func:`~cartan_decomp`) :math:`\mathfrak{g} = \mathfrak{k} \oplus \mathfrak{m}`,
    a Hermitian operator :math:`H \in \mathfrak{m}`,
    and a horizontal Cartan subalgebra (:func:`~cartan_subalgebra`) :math:`\mathfrak{a} \subset \mathfrak{m}`,
    this function computes
    :math:`a \in \mathfrak{a}` and :math:`K_c \in e^{i\mathfrak{k}}` such that

    .. math:: H = K_c a K_c^\dagger.

    In particular, :math:`a = \sum_j c_j a_j` is decomposed in terms of commuting operators :math:`a_j \in \mathfrak{a}`.
    This allows for the immediate decomposition

    .. math:: e^{-i t H} = K_c e^{-i t a} K_c^\dagger = K_c \left(\prod_j e^{-i t c_j a_j} \right) K_c^\dagger.

    The result is provided in terms of the adjoint vector representation of :math:`a \in \mathfrak{a}`
    (see :func:`adjvec_to_op`), i.e. the ordered coefficients :math:`c_j` in :math:`a = \sum_j c_j m_j`
    with the basis elements :math:`m_j \in (\tilde{\mathfrak{m}} \oplus \mathfrak{a})` and
    the optimal parameters :math:`\theta` such that

    .. math:: K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}

    for the ordered basis of :math:`\mathfrak{k}` given by the first :math:`|\mathfrak{k}|` elements of ``g``.
    Note that we define :math:`K_c` mathematically with the descending order of basis elements :math:`k_j \in \mathfrak{k}` such that
    the resulting circuit has the canonical ascending order. In particular, a PennyLane quantum function that describes the circuit given
    the optimal parameters ``theta_opt`` and the basis ``k`` containing the operators, is given by the following.

    .. code-block:: python

        def Kc(theta_opt: Iterable[float], k: Iterable[Operator]):
            assert len(theta_opt) == len(k)
            for theta_j, k_j in zip(theta_opt, k):
                qml.exp(-1j * theta_j * k_j)

    Internally, this function performs a modified version of `2104.00728 <https://arxiv.org/abs/2104.00728>`__,
    in particular minimizing the cost function

    .. math:: f(\theta) = \langle H, K(\theta) e^{-i \sum_{j=1}^{|\mathfrak{a}|} \pi^j a_j} K(\theta)^\dagger \rangle,

    see eq. (6) therein and our `demo <https://pennylane.ai/qml/demos/tutorial_fixed_depth_hamiltonian_simulation_via_cartan_decomposition>`__ for more details.
    Instead of relying on having Pauli words, we use the adjoint representation
    for a more general evaluation of the cost function. The rest is the same.

    .. seealso:: `The KAK decomposition theory (demo) <https://pennylane.ai/qml/demos/tutorial_kak_decomposition>`__, `The KAK decomposition in practice (demo) <https://pennylane.ai/qml/demos/tutorial_fixed_depth_hamiltonian_simulation_via_cartan_decomposition>`__.

    Args:
        H (Union[Operator, PauliSentence, np.ndarray]): Hamiltonian to decompose
        g (List[Union[Operator, PauliSentence, np.ndarray]]): DLA of the Hamiltonian
        dims (Tuple[int]): Tuple of dimensions ``(dim_k, dim_mtilde, dim_a)`` of
            Cartan decomposition :math:`\mathfrak{g} = \mathfrak{k} \oplus (\tilde{\mathfrak{m}} \oplus \mathfrak{a})`
        adj (np.ndarray): Adjoint representation of dimension ``(dim_g, dim_g, dim_g)``,
            with the implicit ordering ``(k, mtilde, a)``.
        verbose (bool): Plot the optimization
        opt_kwargs (dict): Keyword arguments for the optimization like initial starting values
            for :math:`\theta` of dimension ``(dim_k,)``, given as ``theta0``.
            Also includes ``n_epochs``, ``lr``, ``b1``, ``b2``, ``verbose``, ``interrupt_tol``, see :func:`~run_opt`.
            The dictionary passed in is not modified.
        pick_min (bool): Whether to pick the parameter set with lowest cost function value during the optimization
            as optimal parameters. Otherwise picks the last parameter set.

    Returns:
        Tuple(np.ndarray, np.ndarray): ``(adjvec_a, theta_opt)``: The adjoint vector representation
        ``adjvec_a`` of dimension ``(dim_mtilde + dim_a,)``, with respect to the basis of
        :math:`\mathfrak{m} = \tilde{\mathfrak{m}} + \mathfrak{a}` of the CSA element
        :math:`a \in \mathfrak{a}` s.t. :math:`H = K a K^\dagger`. For a successful optimization, the entries
        corresponding to :math:`\tilde{\mathfrak{m}}` should be close to zero.
        The second return value, ``theta_opt``, are the optimal coefficients :math:`\theta` of the
        decomposition :math:`K = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}` for the basis :math:`k_j \in \mathfrak{k}`.

    **Example**

    Let us perform a KaK decomposition for the transverse field Ising model Hamiltonian, exemplarily for :math:`n=3` qubits on a chain.
    We start with some boilerplate code to perform a Cartan decomposition using the :func:`~concurrence_involution`, which places the Hamiltonian
    in the horizontal subspace :math:`\mathfrak{m}`. From this we re-order :math:`\mathfrak{g} = \mathfrak{k} + \mathfrak{m}` and finally compute a
    :func:`~cartan_subalgebra` :math:`\mathfrak{a}` in :math:`\mathfrak{m} = \tilde{\mathfrak{m}} \oplus \mathfrak{a}`.

    .. code-block:: python

        import pennylane as qml
        import numpy as np
        import jax.numpy as jnp
        import jax

        from pennylane import X, Z
        from pennylane.labs.dla import (
            cartan_decomp,
            cartan_subalgebra,
            check_cartan_decomp,
            concurrence_involution,
            validate_kak,
            variational_kak_adj,
            adjvec_to_op,
        )

        n = 3

        gens = [X(i) @ X(i + 1) for i in range(n - 1)]
        gens += [Z(i) for i in range(n)]
        H = qml.sum(*gens)

        g = qml.lie_closure(gens)
        g = [op.pauli_rep for op in g]

        involution = concurrence_involution

        assert not involution(H)
        k, m = cartan_decomp(g, involution=involution)
        assert check_cartan_decomp(k, m)

        g = k + m
        adj = qml.structure_constants(g)

        g, k, mtilde, a, adj = cartan_subalgebra(g, k, m, adj, tol=1e-14, start_idx=0)

    Due to the canonical ordering of all constituents, it suffices to tell ``variational_kak_adj`` the dimensions of ``dims = (len(k), len(mtilde), len(a))``,
    alongside the Hamiltonian ``H``, the Lie algebra ``g`` and its adjoint representation ``adj``. Internally, the function is performing a variational
    optimization to find a local extremum of a suitably constructed loss function that finds as its extremum the decomposition

    .. math:: K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}

    in form of the optimal parameters :math:`\{\theta_j\}` for the respective :math:`k_j \in \mathfrak{k}`.
    The resulting :math:`K_c` then informs the CSA element ``a``
    of the KaK decomposition via :math:`a = K_c^\dagger H K_c`. This is detailed in `2104.00728 <https://arxiv.org/abs/2104.00728>`__.

    >>> dims = (len(k), len(mtilde), len(a))
    >>> adjvec_a, theta_opt = variational_kak_adj(H, g, dims, adj, opt_kwargs={"n_epochs": 3000})

    As a result, we are provided with the adjoint vector representation of the CSA element
    :math:`a \in \mathfrak{a}` with respect to the basis ``mtilde+a`` and the optimal parameters of dimension :math:`|\mathfrak{k}|`.

    Let us perform some sanity checks to better understand the resulting outputs.
    We can turn that element back to an operator using :func:`adjvec_to_op` and from that to a matrix for which we can check Hermiticity.

    .. code-block:: python

        m = mtilde + a
        [a_op] = adjvec_to_op([adjvec_a], m)
        a_m = qml.matrix(a_op, wire_order=range(n))

        assert np.allclose(a_m, a_m.conj().T)

    Let us now confirm that we get back the original Hamiltonian from the resulting :math:`K_c` and :math:`a`.
    In particular, we want to confirm :math:`H = K_c a K_c^\dagger` for :math:`K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}`.

    .. code-block:: python

        assert len(theta_opt) == len(k)

        def Kc(theta_opt):
            for th, op in zip(theta_opt, k):
                qml.exp(-1j * th * op.operation())

        Kc_m = qml.matrix(Kc, wire_order=range(n))(theta_opt)

        # check Unitary property of Kc
        assert np.allclose(Kc_m.conj().T @ Kc_m, np.eye(2**n))

        H_reconstructed = Kc_m @ a_m @ Kc_m.conj().T

        H_m = qml.matrix(H, wire_order=range(n))

        # check Hermitian property of reconstructed Hamiltonian
        assert np.allclose(
            H_reconstructed, H_reconstructed.conj().T
        )

        # confirm reconstruction was successful to some given numerical tolerance
        assert np.allclose(H_m, H_reconstructed, atol=1e-6)

    Instead of performing these checks by hand, we can use the helper function :func:`~validate_kak`.

    >>> assert validate_kak(H, g, k, (adjvec_a, theta_opt), n, 1e-6)
    """
    if not has_jax:  # pragma: no cover
        raise ImportError(
            "jax and optax are required for variational_kak_adj. You can install them with pip install jax jaxlib optax."
        )  # pragma: no cover

    # Work on a copy: ``theta0`` is popped and ``verbose`` is set below, which must not leak
    # into the caller's dictionary (re-using it would otherwise silently drop ``theta0``).
    opt_kwargs = {} if opt_kwargs is None else dict(opt_kwargs)

    if not isinstance(H, PauliSentence):
        H = H.pauli_rep

    dim_k, dim_mtilde, dim_h = dims
    dim_m = dim_mtilde + dim_h

    # Keep the square (dim_m, dim_m) matrices adj[:, i] of the first dim_k elements (k),
    # restricted to the trailing dim_m rows/columns (m = mtilde + a).
    adj_cropped = adj[-dim_m:, :dim_k, -dim_m:]

    ## creating the gamma vector expanded on the whole m: zero on mtilde, pi**j on the a part.
    # Slicing from dim_mtilde equals slicing from -dim_h for dim_h > 0, and stays
    # correct (an empty slice) for dim_h == 0, where -0: would select everything.
    gammavec = jnp.zeros(dim_m)
    gammavec = gammavec.at[dim_mtilde:].set([np.pi**i for i in range(dim_h)])

    def loss(theta, vec_H, adj):
        # this is different to Appendix F 1 in https://arxiv.org/pdf/2104.00728
        # Making use of adjoint representation
        # should be faster, and most importantly allow for treatment of sums of paulis
        assert adj.shape == (len(vec_H), len(theta), len(vec_H))
        # Implement Ad_(K_1 .. K_|k|) (vec_H), so that we get K_1 .. K_|k| H K^†_|k| .. K^†_1
        vec_H = _apply_adjoint_action(theta, vec_H, adj)
        return (gammavec @ vec_H).real

    value_and_grad = jax.jit(jax.value_and_grad(loss))

    [vec_H] = op_to_adjvec([H], g[-dim_m:], is_orthogonal=False)

    theta0 = opt_kwargs.pop("theta0", None)
    if theta0 is None:
        theta0 = jax.random.normal(jax.random.PRNGKey(0), (dim_k,))

    opt_kwargs["verbose"] = verbose

    thetas, energy, _ = run_opt(
        partial(value_and_grad, vec_H=vec_H, adj=adj_cropped), theta0, **opt_kwargs
    )

    if verbose >= 1:
        plt.plot(np.asarray(energy) - np.min(energy))
        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.show()

    idx = np.argmin(energy) if pick_min else -1
    if verbose:
        # run_opt may stop early (interrupt_tol / KeyboardInterrupt), so report the
        # number of entries actually recorded rather than the requested n_epochs.
        print(f"Picking entry with index {idx} out of {len(energy) - 1} ({pick_min=}).")
    theta_opt = thetas[idx]

    # Implement Ad_(K_1 .. K_|k|) (vec_H) like in the loss, with optimized parameters now.
    vec_H = _apply_adjoint_action(theta_opt, vec_H, adj_cropped)
    return vec_H, theta_opt
def validate_kak(H, g, k, kak_res, n, error_tol, verbose=False):
    r"""Helper function to validate a KaK decomposition :math:`H = K_c a K_c^\dagger`.

    Args:
        H (Union[Operator, PauliSentence]): Original Hamiltonian.
        g (List[Union[Operator, PauliSentence, np.ndarray]]): Lie algebra ordered as
            ``k + mtilde + a``; the CSA element is expanded in ``g[len(k):]``.
        k (List[Union[Operator, PauliSentence, np.ndarray]]): Basis of :math:`\mathfrak{k}`.
            If every element is an ``np.ndarray``, the elements are used as matrices directly.
            Otherwise ``op.operation()`` is converted to a matrix.
        kak_res (Tuple[np.ndarray, np.ndarray]): ``(adjvec_a, theta_opt)`` as returned by
            :func:`~variational_kak_adj`.
        n (int): Number of qubits; all matrices are built with ``wire_order=range(n)``.
        error_tol (float): Absolute tolerance for comparing the original and reconstructed ``H``.
        verbose (bool): Print both matrices.

    Returns:
        bool: Whether the reconstructed Hamiltonian matches ``H`` up to ``error_tol``.
        If not, a ``UserWarning`` with the Frobenius norm of the difference is issued.

    Raises:
        AssertionError: If the CSA element is not Hermitian, the number of parameters does not
            match ``len(k)``, :math:`K_c` is not unitary, or the reconstruction is not Hermitian.
    """
    # Dense mode: all elements of k are already matrices (previously this identical
    # condition was evaluated twice).
    is_dense = all(isinstance(op, np.ndarray) for op in k)

    vec_a, theta_opt = kak_res
    [a_elem] = adjvec_to_op([vec_a], g[len(k) :])  # sum(c * op for c, op in zip(vec_a, m))
    if isinstance(a_elem, Operator):
        a_elem_m = qml.matrix(a_elem, wire_order=range(n))
    elif isinstance(a_elem, PauliSentence):
        a_elem_m = a_elem.to_mat(wire_order=range(n))
    else:
        a_elem_m = a_elem

    # validate the CSA element a is Hermitian
    assert np.allclose(a_elem_m, a_elem_m.conj().T), "CSA element `a` not Hermitian"

    # validate K_c a K_c^† reproduces H
    # Build Km = e^{i th_1 k_1} .. e^{i th_|k| k_|k|}, so that Km^† equals the circuit
    # matrix K_c of applying e^{-i th_j k_j} in ascending order of j.
    Km = jnp.eye(2**n)
    assert len(theta_opt) == len(k)
    for th, op in zip(theta_opt, k):
        opm = op if is_dense else qml.matrix(op.operation(), wire_order=range(n))
        Km = Km @ jax.scipy.linalg.expm(1j * th * opm)
    assert np.allclose(Km @ Km.conj().T, np.eye(2**n))

    # Compute K_c^† a K_c
    H_reconstructed = Km.conj().T @ a_elem_m @ Km

    # Use the same wire order as all other matrices; range(len(H.wires)) produced a
    # smaller (mismatched) matrix whenever H does not act on every one of the n wires.
    H_m = qml.matrix(H, wire_order=range(n))

    if verbose:
        print(f"Original matrix: {H_m}")
        print(f"Reconstructed matrix: {H_reconstructed}")

    assert np.allclose(
        H_reconstructed, H_reconstructed.conj().T
    ), "Reconstructed Hamiltonian not Hermitian"

    success = np.allclose(H_m, H_reconstructed, atol=error_tol)

    if not success:
        error = np.linalg.norm(H_m - H_reconstructed, ord="fro")
        warnings.warn(
            "The reconstructed H is not numerical identical to the original H.\n"
            f"We can still check for unitary equivalence: {error}",
            UserWarning,
        )

    return success
def run_opt(
    value_and_grad,
    theta,
    n_epochs=500,
    lr=0.1,
    b1=0.99,
    b2=0.999,
    verbose=False,
    interrupt_tol=None,
):
    """Boilerplate jax optimization using the Adam optimizer from ``optax``.

    Args:
        value_and_grad (callable): Function mapping parameters ``theta`` to a tuple
            ``(value, gradient)``; it is traced inside a ``jax.jit``-compiled step.
        theta: Initial parameters.
        n_epochs (int): Maximum number of optimization steps.
        lr (float): Learning rate passed to ``optax.adam``.
        b1 (float): ``b1`` parameter of ``optax.adam``.
        b2 (float): ``b2`` parameter of ``optax.adam``.
        verbose (bool): Print progress about 20 times during the optimization and a final summary.
        interrupt_tol (float): If given, stop early once the gradient norm drops below this value.

    Returns:
        Tuple[list, list, list]: ``(thetas, energy, gradients)`` with one entry per performed step:
        the parameters after the update, and the value and gradient evaluated before the update.
        A ``KeyboardInterrupt`` stops the loop and returns the results collected so far.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError(
            "jax and optax are required for run_opt. You can install them with pip install jax jaxlib optax."
        )  # pragma: no cover

    optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2)
    opt_state = optimizer.init(theta)

    energy, gradients, thetas = [], [], []

    @jax.jit
    def step(opt_state, theta):
        val, grad_circuit = value_and_grad(theta)
        updates, opt_state = optimizer.update(grad_circuit, opt_state)
        theta = optax.apply_updates(theta, updates)

        return opt_state, theta, val, grad_circuit

    # Print roughly 20 progress lines; max(1, ...) avoids a ZeroDivisionError for n_epochs < 20.
    print_every = max(1, n_epochs // 20)

    t0 = datetime.now()
    ## Optimization loop
    try:
        for n in range(n_epochs):
            opt_state, theta, val, grad_circuit = step(opt_state, theta)

            energy.append(val)
            gradients.append(grad_circuit)
            thetas.append(theta)
            if (
                interrupt_tol is not None
                and (norm := np.linalg.norm(grad_circuit)) < interrupt_tol
            ):
                # n is zero-based, so n + 1 steps have been performed at this point.
                print(
                    f"Interrupting after {n + 1} epochs because gradient norm is {norm} < {interrupt_tol}"
                )
                break
            if verbose:
                if n == 0:
                    print("First optimization round performed")
                if n % print_every == 0:
                    print(f"Epoch {n:5d}: {val:.8f}")
    except KeyboardInterrupt:
        print(
            "KeyboardInterrupt received. Cancelled the optimization and will return intermediate result."
        )
    t1 = datetime.now()

    # energy is empty if no step completed (n_epochs == 0 or an immediate interrupt);
    # the summary would otherwise fail on an unbound value / np.min of an empty list.
    if verbose and energy:
        print(f"final loss: {energy[-1]}; min loss: {np.min(energy)}; after {t1 - t0}")

    return thetas, energy, gradients
# _modules/pennylane/labs/dla/variational_kak
# Download Python script
# Download Notebook
# View on GitHub