# Source code for pennylane.labs.dla.variational_kak

# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper Functionality to compute the kak decomposition variationally, as outlined in https://arxiv.org/abs/2104.00728"""
# pylint: disable=too-many-arguments, too-many-positional-arguments
import warnings
from datetime import datetime
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

import pennylane as qml
from pennylane.operation import Operator
from pennylane.pauli import PauliSentence

from .cartan_subalgebra import adjvec_to_op, op_to_adjvec

has_jax = True
try:
    import jax
    import jax.numpy as jnp
    import optax

    jax.config.update("jax_enable_x64", True)
except ImportError:
    has_jax = False


def variational_kak_adj(H, g, dims, adj, verbose=False, opt_kwargs=None, pick_min=False):
    r"""
    Variational KaK decomposition of Hermitian ``H`` using the adjoint representation.

    Given a Cartan decomposition (:func:`~cartan_decomp`) :math:`\mathfrak{g} = \mathfrak{k} \oplus \mathfrak{m}`,
    a Hermitian operator :math:`H \in \mathfrak{m}`,
    and a horizontal Cartan subalgebra (:func:`~cartan_subalgebra`) :math:`\mathfrak{a} \subset \mathfrak{m}`,
    this function computes :math:`a \in \mathfrak{a}` and :math:`K_c \in e^{i\mathfrak{k}}` such that

    .. math:: H = K_c a K_c^\dagger.

    In particular, :math:`a = \sum_j c_j a_j` is decomposed in terms of commuting operators :math:`a_j \in \mathfrak{a}`.
    This allows for the immediate decomposition

    .. math:: e^{-i t H} = K_c e^{-i t a} K_c^\dagger = K_c \left(\prod_j e^{-i t c_j a_j} \right) K_c^\dagger.

    The result is provided in terms of the adjoint vector representation of :math:`a \in \mathfrak{a}`
    (see :func:`adjvec_to_op`), i.e. the ordered coefficients :math:`c_j` in :math:`a = \sum_j c_j m_j`
    with the basis elements :math:`m_j \in (\tilde{\mathfrak{m}} \oplus \mathfrak{a})`, and the optimal
    parameters :math:`\theta` such that

    .. math:: K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}

    for the ordered basis of :math:`\mathfrak{k}` given by the first :math:`|\mathfrak{k}|` elements of ``g``.
    Note that we define :math:`K_c` mathematically with the descending order of basis elements
    :math:`k_j \in \mathfrak{k}` such that the resulting circuit has the canonical ascending order.
    In particular, a PennyLane quantum function that describes the circuit given the optimal parameters
    ``theta_opt`` and the basis ``k`` containing the operators is given by the following.

    .. code-block:: python

        def Kc(theta_opt: Iterable[float], k: Iterable[Operator]):
            assert len(theta_opt) == len(k)
            for theta_j, k_j in zip(theta_opt, k):
                qml.exp(-1j * theta_j * k_j)

    Internally, this function performs a modified version of `2104.00728 <https://arxiv.org/abs/2104.00728>`__,
    in particular minimizing the cost function

    .. math:: f(\theta) = \langle H, K_c(\theta) \gamma K_c(\theta)^\dagger \rangle,
        \qquad \gamma = \sum_{j=1}^{|\mathfrak{a}|} \pi^{j-1} a_j,

    see eq. (6) therein and our `demo <https://pennylane.ai/qml/demos/tutorial_fixed_depth_hamiltonian_simulation_via_cartan_decomposition>`__
    for more details. Instead of relying on having Pauli words, we use the adjoint representation
    for a more general evaluation of the cost function: it is computed as the dot product of the
    coefficient vector of :math:`\gamma` (zero on :math:`\tilde{\mathfrak{m}}`) with the adjoint vector of
    the transformed Hamiltonian. The rest is the same.

    .. seealso:: `The KAK decomposition theory (demo) <https://pennylane.ai/qml/demos/tutorial_kak_decomposition>`__,
        `The KAK decomposition in practice (demo) <https://pennylane.ai/qml/demos/tutorial_fixed_depth_hamiltonian_simulation_via_cartan_decomposition>`__.

    Args:
        H (Union[Operator, PauliSentence, np.ndarray]): Hamiltonian to decompose. ``Operator`` inputs
            are converted to a ``PauliSentence`` via their ``pauli_rep``; other inputs are passed to
            :func:`op_to_adjvec` as they are.
        g (List[Union[Operator, PauliSentence, np.ndarray]]): DLA of the Hamiltonian
        dims (Tuple[int]): Tuple of dimensions ``(dim_k, dim_mtilde, dim_a)`` of Cartan decomposition
            :math:`\mathfrak{g} = \mathfrak{k} \oplus (\tilde{\mathfrak{m}} \oplus \mathfrak{a})`
        adj (np.ndarray): Adjoint representation of dimension ``(dim_g, dim_g, dim_g)``,
            with the implicit ordering ``(k, mtilde, a)``.
        verbose (bool): Plot the optimization and print progress
        opt_kwargs (dict): Keyword arguments for the optimization like initial starting values
            for :math:`\theta` of dimension ``(dim_k,)``, given as ``theta0``. If ``theta0`` is not
            given, a standard normal vector drawn with ``jax.random.PRNGKey(0)`` is used.
            Also includes ``n_epochs``, ``lr``, ``b1``, ``b2``, ``interrupt_tol``, see :func:`~run_opt`.
            A ``verbose`` entry is overridden by the ``verbose`` argument. The dictionary is copied,
            so the caller's dictionary is not modified.
        pick_min (bool): Whether to pick the parameter set with lowest cost function value during the
            optimization as optimal parameters. Otherwise picks the last parameter set.

    Returns:
        Tuple(np.ndarray, np.ndarray): ``(adjvec_a, theta_opt)``: The adjoint vector representation
        ``adjvec_a`` of dimension ``(dim_mtilde + dim_a,)``, with respect to the basis of
        :math:`\mathfrak{m} = \tilde{\mathfrak{m}} + \mathfrak{a}` of the CSA element
        :math:`a \in \mathfrak{a}` s.t. :math:`H = K_c a K_c^\dagger`. For a successful optimization,
        the entries corresponding to :math:`\tilde{\mathfrak{m}}` should be close to zero.
        The second return value, ``theta_opt``, are the optimal coefficients :math:`\theta` of the
        decomposition :math:`K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}` for the basis
        :math:`k_j \in \mathfrak{k}`.

    Raises:
        ImportError: if ``jax`` or ``optax`` are not installed.

    **Example**

    Let us perform a KaK decomposition for the transverse field Ising model Hamiltonian, exemplarily
    for :math:`n=3` qubits on a chain. We start with some boilerplate code to perform a Cartan
    decomposition using the :func:`~concurrence_involution`, which places the Hamiltonian in the
    horizontal subspace :math:`\mathfrak{m}`. From this we re-order
    :math:`\mathfrak{g} = \mathfrak{k} + \mathfrak{m}` and finally compute a :func:`~cartan_subalgebra`
    :math:`\mathfrak{a}` in :math:`\mathfrak{m} = \tilde{\mathfrak{m}} \oplus \mathfrak{a}`.

    .. code-block:: python

        import pennylane as qml
        import numpy as np
        import jax.numpy as jnp
        import jax

        from pennylane import X, Z
        from pennylane.labs.dla import (
            cartan_decomp,
            cartan_subalgebra,
            check_cartan_decomp,
            concurrence_involution,
            validate_kak,
            variational_kak_adj,
            adjvec_to_op,
        )

        n = 3

        gens = [X(i) @ X(i + 1) for i in range(n - 1)]
        gens += [Z(i) for i in range(n)]
        H = qml.sum(*gens)

        g = qml.lie_closure(gens)
        g = [op.pauli_rep for op in g]

        involution = concurrence_involution

        assert not involution(H)
        k, m = cartan_decomp(g, involution=involution)
        assert check_cartan_decomp(k, m)

        g = k + m
        adj = qml.structure_constants(g)

        g, k, mtilde, a, adj = cartan_subalgebra(g, k, m, adj, tol=1e-14, start_idx=0)

    Due to the canonical ordering of all constituents, it suffices to tell ``variational_kak_adj``
    the dimensions ``dims = (len(k), len(mtilde), len(a))``, alongside the Hamiltonian ``H``,
    the Lie algebra ``g`` and its adjoint representation ``adj``.
    Internally, the function performs a variational optimization to find a local extremum of a
    suitably constructed loss function whose extremum yields the decomposition

    .. math:: K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}

    in form of the optimal parameters :math:`\{\theta_j\}` for the respective :math:`k_j \in \mathfrak{k}`.
    The resulting :math:`K_c` then informs the CSA element ``a`` of the KaK decomposition via
    :math:`a = K_c^\dagger H K_c`. This is detailed in `2104.00728 <https://arxiv.org/abs/2104.00728>`__.

    >>> dims = (len(k), len(mtilde), len(a))
    >>> adjvec_a, theta_opt = variational_kak_adj(H, g, dims, adj, opt_kwargs={"n_epochs": 3000})

    As a result, we are provided with the adjoint vector representation of the CSA element
    :math:`a \in \mathfrak{a}` with respect to the basis ``mtilde + a`` and the optimal parameters
    of dimension :math:`|\mathfrak{k}|`.

    Let us perform some sanity checks to better understand the resulting outputs.
    We can turn that element back to an operator using :func:`adjvec_to_op` and from that to a
    matrix for which we can check Hermiticity.

    .. code-block:: python

        m = mtilde + a
        [a_op] = adjvec_to_op([adjvec_a], m)
        a_m = qml.matrix(a_op, wire_order=range(n))

        assert np.allclose(a_m, a_m.conj().T)

    Let us now confirm that we get back the original Hamiltonian from the resulting :math:`K_c`
    and :math:`a`. In particular, we want to confirm :math:`H = K_c a K_c^\dagger` for
    :math:`K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}`.

    .. code-block:: python

        assert len(theta_opt) == len(k)

        def Kc(theta_opt):
            for th, op in zip(theta_opt, k):
                qml.exp(-1j * th * op.operation())

        Kc_m = qml.matrix(Kc, wire_order=range(n))(theta_opt)

        # check Unitary property of Kc
        assert np.allclose(Kc_m.conj().T @ Kc_m, np.eye(2**n))

        H_reconstructed = Kc_m @ a_m @ Kc_m.conj().T

        H_m = qml.matrix(H, wire_order=range(len(H.wires)))

        # check Hermitian property of reconstructed Hamiltonian
        assert np.allclose(
            H_reconstructed, H_reconstructed.conj().T
        )

        # confirm reconstruction was successful to some given numerical tolerance
        assert np.allclose(H_m, H_reconstructed, atol=1e-6)

    Instead of performing these checks by hand, we can use the helper function :func:`~validate_kak`.

    >>> assert validate_kak(H, g, k, (adjvec_a, theta_opt), n, 1e-6)
    """
    if not has_jax:  # pragma: no cover
        raise ImportError(
            "jax and optax are required for variational_kak_adj. You can install them with pip install jax jaxlib optax."
        )  # pragma: no cover

    # Copy so that popping ``theta0`` and setting ``verbose`` below never
    # mutates the dictionary owned by the caller.
    opt_kwargs = {} if opt_kwargs is None else dict(opt_kwargs)

    # Operators are turned into their Pauli representation; PauliSentences and
    # dense arrays are handed to ``op_to_adjvec`` unchanged.
    if isinstance(H, Operator):
        H = H.pauli_rep

    dim_k, dim_mtilde, dim_h = dims
    dim_m = dim_mtilde + dim_h

    # Keep only the action of the k generators (middle axis) on m = mtilde + a
    adj_cropped = adj[-dim_m:, :dim_k, -dim_m:]

    ## creating the gamma vector expanded on the whole m (zero on mtilde, pi**j on a)
    gammavec = jnp.zeros(dim_m)
    gammavec = gammavec.at[-dim_h:].set([np.pi**i for i in range(dim_h)])

    def loss(theta, vec_H, adj):
        # this is different to Appendix F 1 in https://arxiv.org/pdf/2104.00728
        # Making use of adjoint representation
        # should be faster, and most importantly allow for treatment of sums of paulis
        assert adj.shape == (len(vec_H), len(theta), len(vec_H))

        # Implement Ad_(K_1 .. K_|k|) (vec_H), so that we get K_1 .. K_|k| H K^†_|k| .. K^†_1
        for i in range(dim_k - 1, -1, -1):
            vec_H = jax.scipy.linalg.expm(theta[i] * adj[:, i]) @ vec_H

        return (gammavec @ vec_H).real

    value_and_grad = jax.jit(jax.value_and_grad(loss))

    [vec_H] = op_to_adjvec([H], g[-dim_m:], is_orthogonal=False)

    theta0 = opt_kwargs.pop("theta0", None)
    if theta0 is None:
        theta0 = jax.random.normal(jax.random.PRNGKey(0), (dim_k,))

    opt_kwargs["verbose"] = verbose

    thetas, energy, _ = run_opt(
        partial(value_and_grad, vec_H=vec_H, adj=adj_cropped), theta0, **opt_kwargs
    )

    if verbose >= 1 and energy:
        energy_arr = np.asarray(energy)
        plt.plot(energy_arr - np.min(energy_arr))
        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.show()

    # ``run_opt`` stores energy[i] as the loss evaluated at the parameters *before*
    # update i, while thetas[i] holds the parameters *after* update i. The parameters
    # that produced energy[i] are therefore thetas[i - 1] (or theta0 for i == 0).
    if not thetas:
        # No optimization step was completed (e.g. n_epochs=0 or an immediate interrupt)
        idx = -1
        theta_opt = theta0
    elif pick_min:
        idx = int(np.argmin(energy))
        theta_opt = theta0 if idx == 0 else thetas[idx - 1]
    else:
        idx = -1
        theta_opt = thetas[-1]

    if verbose:
        print(f"Picking entry with index {idx} out of {len(energy) - 1} ({pick_min=}).")

    # Implement Ad_(K_1 .. K_|k|) (vec_H) like in the loss, with optimized parameters now.
    # The result is the adjoint vector of the CSA element a.
    for i in range(dim_k - 1, -1, -1):
        vec_H = jax.scipy.linalg.expm(theta_opt[i] * adj_cropped[:, i]) @ vec_H

    return vec_H, theta_opt
def validate_kak(H, g, k, kak_res, n, error_tol, verbose=False):
    r"""Helper function to validate a KaK decomposition :math:`H = K_c a K_c^\dagger`.

    The CSA element :math:`a` is reconstructed from its adjoint vector, and the unitary
    :math:`K_c = \prod_{j=|\mathfrak{k}|}^{1} e^{-i \theta_j k_j}` is built from the parameters
    ``theta_opt``, e.g. as returned by :func:`~variational_kak_adj`. Both are turned into dense
    ``(2**n, 2**n)`` matrices on the wires ``range(n)``, and :math:`K_c a K_c^\dagger` is compared
    with the matrix of ``H``.

    Args:
        H (Operator): original Hamiltonian (anything accepted by :func:`~pennylane.matrix`)
        g (List[Union[Operator, PauliSentence, np.ndarray]]): Lie algebra basis ordered as
            ``(k, mtilde, a)``; the elements after the first ``len(k)`` form the basis of
            :math:`\mathfrak{m}` with respect to which ``adjvec_a`` is given
        k (List[Union[Operator, PauliSentence, np.ndarray]]): basis of :math:`\mathfrak{k}`,
            in the same order as ``theta_opt``
        kak_res (Tuple[np.ndarray, np.ndarray]): ``(adjvec_a, theta_opt)`` as returned by
            :func:`~variational_kak_adj`
        n (int): number of qubits; all matrices are built on the wires ``range(n)``
        error_tol (float): absolute tolerance for comparing reconstructed and original matrices
        verbose (bool): print the original and reconstructed matrices

    Returns:
        bool: whether ``H`` is reproduced within ``error_tol``. If it is not, a ``UserWarning``
        reporting the Frobenius norm of the difference is emitted.

    Raises:
        ImportError: if ``jax`` is not installed.
        AssertionError: if :math:`a` or the reconstructed Hamiltonian is not Hermitian, if
            ``len(theta_opt) != len(k)``, or if the assembled :math:`K_c` is not unitary.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError(
            "jax is required for validate_kak. You can install it with pip install jax jaxlib."
        )  # pragma: no cover

    wire_order = range(n)

    def _k_matrix(op):
        """Dense matrix of one basis element of k on ``wire_order``."""
        if isinstance(op, Operator):
            return qml.matrix(op, wire_order=wire_order)
        if hasattr(op, "operation"):
            # PauliSentence / PauliWord: go through their Operator representation
            return qml.matrix(op.operation(), wire_order=wire_order)
        # already a dense matrix
        return op

    vec_a, theta_opt = kak_res

    # Reconstruct the CSA element a = sum_j c_j m_j from its adjoint vector w.r.t. m = g[len(k):]
    [a_elem] = adjvec_to_op([vec_a], g[len(k) :])
    if isinstance(a_elem, Operator):
        a_elem_m = qml.matrix(a_elem, wire_order=wire_order)
    elif isinstance(a_elem, PauliSentence):
        a_elem_m = a_elem.to_mat(wire_order=wire_order)
    else:
        a_elem_m = a_elem

    assert np.allclose(a_elem_m, a_elem_m.conj().T), "CSA element `a` not Hermitian"

    # Km = exp(i θ_1 k_1) exp(i θ_2 k_2) .. exp(i θ_|k| k_|k|), which is the adjoint of
    # K_c = exp(-i θ_|k| k_|k|) .. exp(-i θ_1 k_1)
    assert len(theta_opt) == len(k)
    Km = jnp.eye(2**n)
    for th, op in zip(theta_opt, k):
        Km = Km @ jax.scipy.linalg.expm(1j * th * _k_matrix(op))

    assert np.allclose(Km @ Km.conj().T, np.eye(2**n))

    # Km^† a Km = K_c a K_c^†, which should reproduce H
    H_reconstructed = Km.conj().T @ a_elem_m @ Km

    # Use the same wire order as for a and Km so that all matrices have shape (2**n, 2**n),
    # even if H does not act on every wire.
    H_m = qml.matrix(H, wire_order=wire_order)

    if verbose:
        print(f"Original matrix: {H_m}")
        print(f"Reconstructed matrix: {H_reconstructed}")

    assert np.allclose(
        H_reconstructed, H_reconstructed.conj().T
    ), "Reconstructed Hamiltonian not Hermitian"

    success = np.allclose(H_m, H_reconstructed, atol=error_tol)

    if not success:
        error = np.linalg.norm(H_m - H_reconstructed, ord="fro")
        warnings.warn(
            "The reconstructed H is not numerical identical to the original H.\n"
            f"We can still check for unitary equivalence: {error}",
            UserWarning,
        )

    return success
def run_opt(
    value_and_grad,
    theta,
    n_epochs=500,
    lr=0.1,
    b1=0.99,
    b2=0.999,
    verbose=False,
    interrupt_tol=None,
):
    """Boilerplate jax optimization using the Adam optimizer from ``optax``.

    Args:
        value_and_grad (callable): function mapping the parameters ``theta`` to a tuple
            ``(loss, gradient)``; it is called inside a ``jax.jit``-compiled step
        theta (array): initial parameters
        n_epochs (int): maximum number of optimization steps
        lr (float): learning rate passed to ``optax.adam``
        b1 (float): ``b1`` decay rate passed to ``optax.adam``
        b2 (float): ``b2`` decay rate passed to ``optax.adam``
        verbose (bool): print the loss roughly every ``n_epochs // 20`` epochs and a final summary
        interrupt_tol (float): if given, stop early once the gradient norm drops below this value

    Returns:
        Tuple(list, list, list): ``(thetas, energy, gradients)``. ``energy[i]`` and
        ``gradients[i]`` are evaluated at the parameters *before* the ``i``-th update, while
        ``thetas[i]`` holds the parameters *after* it. A ``KeyboardInterrupt`` stops the loop
        and the results gathered so far are returned.

    Raises:
        ImportError: if ``jax`` or ``optax`` are not installed.
    """
    if not has_jax:  # pragma: no cover
        raise ImportError(
            "jax and optax are required for run_opt. You can install them with pip install jax jaxlib optax."
        )  # pragma: no cover

    optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2)
    opt_state = optimizer.init(theta)

    energy, gradients, thetas = [], [], []

    @jax.jit
    def step(opt_state, theta):
        val, grad_circuit = value_and_grad(theta)
        updates, opt_state = optimizer.update(grad_circuit, opt_state)
        theta = optax.apply_updates(theta, updates)
        return opt_state, theta, val, grad_circuit

    # Report progress about 20 times; max(1, ...) avoids a modulo by zero for n_epochs < 20
    print_every = max(1, n_epochs // 20)

    t0 = datetime.now()
    ## Optimization loop
    try:
        for n in range(n_epochs):
            opt_state, theta, val, grad_circuit = step(opt_state, theta)

            energy.append(val)
            gradients.append(grad_circuit)
            thetas.append(theta)

            if (
                interrupt_tol is not None
                and (norm := np.linalg.norm(grad_circuit)) < interrupt_tol
            ):
                print(
                    f"Interrupting after {n} epochs because gradient norm is {norm} < {interrupt_tol}"
                )
                break

            if verbose:
                if n == 0:
                    print("First optimization round performed")
                if n % print_every == 0:
                    print(f"Epoch {n:5d}: {val:.8f}")
    except KeyboardInterrupt:
        print(
            "KeyboardInterrupt received. Cancelled the optimization and will return intermediate result."
        )

    t1 = datetime.now()
    # ``energy`` is empty when no step completed (n_epochs == 0 or an immediate interrupt);
    # there is then no loss to report.
    if verbose and energy:
        print(f"final loss: {energy[-1]}; min loss: {np.min(energy)}; after {t1 - t0}")

    return thetas, energy, gradients