Source code for pennylane.templates.state_preparations.state_prep_mps

# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the MPSPrep template.
"""

import numpy as np

import pennylane as qml
from pennylane.operation import Operation
from pennylane.wires import Wires


def _validate_mps_shape(mps):
    r"""Validate that the MPS dimensions are correct.

    The first and last tensors must be rank 2 and every intermediate tensor must be rank 3.
    The physical dimension of each site must be 2, every bond dimension must be a power of
    two, and the bond dimensions of neighbouring tensors must match.

    Args:
        mps (list[TensorLike]): List of tensors representing the MPS.

    Raises:
        AssertionError: if any of the shape constraints listed above is violated.
    """

    # Validate the shape and dimensions of the first tensor: (physical, bond)
    assert qml.math.isclose(
        len(qml.math.shape(mps[0])), 2
    ), "The first tensor must have exactly 2 dimensions."
    dj0, dj2 = qml.math.shape(mps[0])
    assert qml.math.isclose(dj0, 2), "The first dimension of the first tensor must be exactly 2."
    assert qml.math.log2(
        dj2
    ).is_integer(), "The second dimension of the first tensor must be a power of 2."

    # Validate the shapes of the intermediate tensors: (bond, physical, bond).
    # ``dj2`` always holds the right bond dimension of the previously checked tensor.
    for i, array in enumerate(mps[1:-1], start=1):
        shape = qml.math.shape(array)
        assert qml.math.isclose(len(shape), 3), f"Tensor {i} must have exactly 3 dimensions."
        new_dj0, new_dj1, new_dj2 = shape
        assert qml.math.isclose(
            new_dj1, 2
        ), f"The second dimension of tensor {i} must be exactly 2."
        assert qml.math.log2(
            new_dj0
        ).is_integer(), f"The first dimension of tensor {i} must be a power of 2."
        assert qml.math.log2(
            new_dj2
        ).is_integer(), f"The third dimension of tensor {i} must be a power of 2."
        assert qml.math.isclose(
            new_dj0, dj2
        ), f"Dimension mismatch: tensor {i}'s first dimension does not match the previous third dimension."
        dj2 = new_dj2

    # Validate the shape and dimensions of the last tensor: (bond, physical)
    assert qml.math.isclose(
        len(qml.math.shape(mps[-1])), 2
    ), "The last tensor must have exactly 2 dimensions."
    new_dj0, new_dj1 = qml.math.shape(mps[-1])
    assert new_dj1 == 2, "The second dimension of the last tensor must be exactly 2."
    assert qml.math.log2(
        new_dj0
    ).is_integer(), "The first dimension of the last tensor must be a power of 2."
    assert qml.math.isclose(
        new_dj0, dj2
    ), "Dimension mismatch: the last tensor's first dimension does not match the previous third dimension."


def right_canonicalize_mps(mps):
    r"""Transform a matrix product state (MPS) into its right-canonical form.

    A right-canonicalized MPS is a matrix product state in which the constituent tensors, :math:`A^{(j)}`, satisfy
    the following orthonormality condition [Eq. (21) of `arXiv:2310.18410 <https://arxiv.org/pdf/2310.18410>`_]:

    .. math::

        \sum_{d_{j,1}, d_{j,2}} A^{(j)}_{d_{j, 0}, d_{j, 1}, d_{j, 2}} \left( A^{(j)}_{d'_{j, 0}, d_{j, 1}, d_{j, 2}} \right)^* = \delta_{d_{j, 0}, d'_{j, 0}},

    where :math:`d_{i,j}` denotes the :math:`j` dimension of the :math:`i` tensor and :math:`\delta` is the Kronecker delta.

    Args:
        mps (list[TensorLike]): List of tensors representing the MPS. Any sequence of tensors
            (for example a tuple) is accepted; the input sequence itself is not modified.

    Returns:
        List of tensors representing the MPS in right-canonical form with the same dimensions as the initial MPS.

    Raises:
        AssertionError: if the tensor shapes do not pass ``_validate_mps_shape``.

    .. seealso:: :class:`~.MPSPrep`.

    **Example**

    .. code-block::

        n_sites = 4

        mps = ([np.ones((2, 4))] +
               [np.ones((4, 2, 4)) for _ in range(1, n_sites - 1)] +
               [np.ones((4, 2))])

        mps_rc = qml.right_canonicalize_mps(mps)

        # Check that the right-canonical definition is fulfilled
        for i in range(1, n_sites - 1):
            tensor = mps_rc[i]
            contraction_matrix = np.tensordot(tensor, tensor.conj(), axes=([1, 2], [1, 2]))
            assert np.allclose(contraction_matrix, np.eye(tensor.shape[0]))

    .. details::
        :title: Usage Details

        The input MPS must be a list of :math:`n` tensors :math:`[A^{(0)}, ..., A^{(n-1)}]`
        with shapes :math:`d_0, ..., d_{n-1}`, respectively. The first and last tensors have rank :math:`2`
        while the intermediate tensors have rank :math:`3`.

        The first tensor must have the shape :math:`d_0 = (d_{0,0}, d_{0,1})` where :math:`d_{0,0}`
        and :math:`d_{0,1}` correspond to the physical dimension of the site and an auxiliary bond
        dimension connecting it to the next tensor, respectively.

        The last tensor must have the shape :math:`d_{n-1} = (d_{n-1,0}, d_{n-1,1})` where :math:`d_{n-1,0}`
        and :math:`d_{n-1,1}` represent the auxiliary dimension from the previous site and the physical
        dimension of the site, respectively.

        The intermediate tensors must have the shape :math:`d_j = (d_{j,0}, d_{j,1}, d_{j,2})`, where:

        - :math:`d_{j,0}` is the bond dimension connecting to the previous tensor
        - :math:`d_{j,1}` is the physical dimension for the site
        - :math:`d_{j,2}` is the bond dimension connecting to the next tensor

        Note that the bond dimensions must match between adjacent tensors such that :math:`d_{j-1,2} = d_{j,0}`.

        Additionally, the physical dimension of the site should always be fixed at :math:`2`
        (since the dimension of a qubit is :math:`2`), while the other dimensions must be powers of two.

        The following example shows a valid MPS input containing four tensors with
        dimensions :math:`[(2,2), (2,2,4), (4,2,2), (2,2)]` which satisfy the criteria described above.

        .. code-block::

            mps = [
                np.array([[0.0, 0.107], [0.994, 0.0]]),
                np.array(
                    [
                        [[0.0, 0.0, 0.0, -0.0], [1.0, 0.0, 0.0, -0.0]],
                        [[0.0, 1.0, 0.0, -0.0], [0.0, 0.0, 0.0, -0.0]],
                    ]
                ),
                np.array(
                    [
                        [[-1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]],
                        [[0.0, -1.0], [0.0, 0.0]],
                        [[0.0, 0.0], [1.0, 0.0]],
                    ]
                ),
                np.array([[-1.0, -0.0], [-0.0, -1.0]]),
            ]
    """

    _validate_mps_shape(mps)

    # Shallow copy into a list: ``list(...)`` behaves like ``list.copy()`` for lists but also
    # accepts tuples and other sequences, which do not provide a ``copy`` method.
    mps = list(mps)

    # Promote the boundary tensors to rank 3 so that every site can be handled uniformly.
    mps[0] = mps[0].reshape((1, *mps[0].shape))
    mps[-1] = mps[-1].reshape((*mps[-1].shape, 1))

    # The early-exit check needs concrete values, so it is skipped for abstract tensors.
    if not qml.math.is_abstract(mps[0]):
        is_right_canonical = True
        for tensor in mps[1:-1]:
            # Right-canonical definition
            input_matrix = qml.math.tensordot(tensor, tensor.conj(), axes=([1, 2], [1, 2]))
            if not qml.math.allclose(input_matrix, qml.math.eye(tensor.shape[0])):
                is_right_canonical = False
                break

        if is_right_canonical:
            # Undo the rank promotion of the boundary tensors before returning.
            mps[0] = mps[0][0]
            mps[-1] = mps[-1][:, :, 0]
            return mps

    # Largest dimension found among the intermediate tensors; used as truncation bound below.
    d_shapes = [dim for tensor in mps[1:-1] for dim in tensor.shape]
    max_bond_dim = qml.math.max(d_shapes)

    n_sites = len(mps)
    output_mps = [None] * n_sites

    # Procedure analogous to the left-canonical conversion but starting from the right and storing the Vd,
    # where Vd is the right matrix in the Singular Value Decomposition (SVD)
    for i in range(n_sites - 1, 0, -1):
        chi_left, d, chi_right = mps[i].shape
        input_matrix = mps[i].reshape(chi_left, d * chi_right)
        u_matrix, s_diag, vd_matrix = qml.math.linalg.svd(input_matrix, full_matrices=False)

        # Truncate SVD components if needed
        chi_new = min(int(max_bond_dim), len(s_diag))
        u_matrix = u_matrix[:, :chi_new]
        s_diag = s_diag[:chi_new]
        vd_matrix = vd_matrix[:chi_new, :]

        # Store Vd reshaped as an MPS tensor in the output MPS
        output_mps[i] = vd_matrix.reshape(chi_new, d, chi_right)

        # Contract U with diag(S) and merge it with the preceding MPS tensor, preserving the canonical structure
        mps[i - 1] = qml.math.tensordot(
            mps[i - 1], u_matrix @ qml.math.diag(s_diag), axes=([2], [0])
        )

    # Restore the rank-2 shape of the boundary tensors.
    output_mps[0] = mps[0][0]
    output_mps[-1] = output_mps[-1][:, :, 0]

    return output_mps


class MPSPrep(Operation):
    r"""Prepares an initial state from a matrix product state (MPS) representation.

    .. note::

        This operator is natively supported on the ``lightning.tensor`` device, which is designed to run MPS structures
        efficiently. For other devices, this operation prepares the state vector represented by the MPS using a
        gate-based decomposition from Eq. (23) in `arXiv:2310.18410 <https://arxiv.org/pdf/2310.18410>`_, which requires
        the right canonicalization of the MPS using the :func:`~.right_canonicalize_mps` function and defining
        auxiliary qubits with ``work_wires``.

    Args:
        mps (list[TensorLike]): list of arrays of rank-3 and rank-2 tensors representing an MPS state
            as a product of site matrices. See the usage details section for more information.
        wires (Sequence[int]): wires that the template acts on. It should match the number of MPS tensors.
        work_wires (Sequence[int]): list of extra qubits needed in the decomposition. If the maximum dimension
            of the MPS tensors is :math:`2^k`, then :math:`k` ``work_wires`` will be needed. If no ``work_wires`` are given,
            this operator can only be executed on the ``lightning.tensor`` device. Default is ``None``.
        right_canonicalize (bool): indicates whether a conversion to right-canonical form should be performed to the MPS.
            Default is ``False``.

    .. seealso:: :func:`~.right_canonicalize_mps`.

    **Example**

    Example using the ``lightning.tensor`` device:

    .. code-block::

        mps = [
            np.array([[0.0, 0.107], [0.994, 0.0]]),
            np.array(
                [
                    [[0.0, 0.0], [1.0, 0.0]],
                    [[0.0, 1.0], [0.0, 0.0]],
                ]
            ),
            np.array([[-1.0, -0.0], [-0.0, -1.0]]),
        ]

        dev = qml.device("lightning.tensor", wires=3)
        @qml.qnode(dev)
        def circuit():
            qml.MPSPrep(mps, wires = [0,1,2])
            return qml.state()

    .. code-block:: pycon

        >>> print(circuit())
        [ 0.        +0.j -0.10705513+0.j  0.        +0.j  0.        +0.j
          0.        +0.j  0.        +0.j -0.99451217+0.j  0.        +0.j]

    Example using the ``default.qubit`` device:

    .. code-block::

        dev = qml.device("default.qubit", wires=4)
        @qml.qnode(dev)
        def circuit():
            qml.MPSPrep(mps, wires = [1,2,3], work_wires = [0])
            return qml.state()

    .. code-block:: pycon

        >>> print(circuit()[:8])
        [ 0.        +0.j -0.10705513+0.j  0.        +0.j  0.        +0.j
          0.        +0.j  0.        +0.j -0.99451217+0.j  0.        +0.j]

    .. details::
        :title: Usage Details

        The input MPS must be a list of :math:`n` tensors :math:`[A^{(0)}, ..., A^{(n-1)}]`
        with shapes :math:`d_0, ..., d_{n-1}`, respectively. The first and last tensors have rank :math:`2`
        while the intermediate tensors have rank :math:`3`.

        The first tensor must have the shape :math:`d_0 = (d_{0,0}, d_{0,1})` where :math:`d_{0,0}`
        and :math:`d_{0,1}` correspond to the physical dimension of the site and an auxiliary bond
        dimension connecting it to the next tensor, respectively.

        The last tensor must have the shape :math:`d_{n-1} = (d_{n-1,0}, d_{n-1,1})` where :math:`d_{n-1,0}`
        and :math:`d_{n-1,1}` represent the auxiliary dimension from the previous site and the physical
        dimension of the site, respectively.

        The intermediate tensors must have the shape :math:`d_j = (d_{j,0}, d_{j,1}, d_{j,2})`, where:

        - :math:`d_{j,0}` is the bond dimension connecting to the previous tensor
        - :math:`d_{j,1}` is the physical dimension for the site
        - :math:`d_{j,2}` is the bond dimension connecting to the next tensor

        Note that the bond dimensions must match between adjacent tensors such that :math:`d_{j-1,2} = d_{j,0}`.

        Additionally, the physical dimension of the site should always be fixed at :math:`2`
        (since the dimension of a qubit is :math:`2`), while the other dimensions must be powers of two.

        The following example shows a valid MPS input containing four tensors with
        dimensions :math:`[(2,2), (2,2,4), (4,2,2), (2,2)]` which satisfy the criteria described above.

        .. code-block::

            mps = [
                np.array([[0.0, 0.107], [0.994, 0.0]]),
                np.array(
                    [
                        [[0.0, 0.0, 0.0, -0.0], [1.0, 0.0, 0.0, -0.0]],
                        [[0.0, 1.0, 0.0, -0.0], [0.0, 0.0, 0.0, -0.0]],
                    ]
                ),
                np.array(
                    [
                        [[-1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]],
                        [[0.0, -1.0], [0.0, 0.0]],
                        [[0.0, 0.0], [1.0, 0.0]],
                    ]
                ),
                np.array([[-1.0, -0.0], [-0.0, -1.0]]),
            ]
    """

    def __init__(
        self, mps, wires, work_wires=None, right_canonicalize=False, id=None
    ):  # pylint: disable=too-many-arguments,too-many-positional-arguments

        _validate_mps_shape(mps)

        self.hyperparameters["input_wires"] = qml.wires.Wires(wires)
        self.hyperparameters["right_canonicalize"] = right_canonicalize

        # An empty or missing ``work_wires`` is stored as ``None``; the operator then acts
        # on the input wires only.
        if work_wires:
            self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
            all_wires = self.hyperparameters["input_wires"] + self.hyperparameters["work_wires"]
        else:
            self.hyperparameters["work_wires"] = None
            all_wires = self.hyperparameters["input_wires"]

        super().__init__(*mps, wires=all_wires, id=id)

    @property
    def mps(self):
        """list representing the MPS input"""
        return self.data

    def _flatten(self):
        # Metadata is stored as (name, value) pairs whose names match the ``__init__``
        # keyword arguments, so that ``_unflatten`` can forward them directly.
        hyperparameters = (
            ("wires", self.hyperparameters["input_wires"]),
            ("work_wires", self.hyperparameters["work_wires"]),
            ("right_canonicalize", self.hyperparameters["right_canonicalize"]),
        )
        return self.mps, hyperparameters

    @classmethod
    def _unflatten(cls, data, metadata):
        hyperparams_dict = dict(metadata)
        return cls(data, **hyperparams_dict)

    def map_wires(self, wire_map):
        """Return a new ``MPSPrep`` with its input and work wires relabelled.

        Args:
            wire_map (dict): mapping from current wire labels to new wire labels. Wires that
                are not keys of the mapping are left unchanged.

        Returns:
            MPSPrep: new operator built from the same MPS on the mapped wires.
        """
        new_wires = Wires(
            [wire_map.get(wire, wire) for wire in self.hyperparameters["input_wires"]]
        )

        # ``work_wires`` is stored as ``None`` when the operator was created without work
        # wires; iterating over it unconditionally would raise a ``TypeError``.
        work_wires = self.hyperparameters["work_wires"]
        if work_wires is None:
            new_work_wires = None
        else:
            new_work_wires = Wires([wire_map.get(wire, wire) for wire in work_wires])

        return MPSPrep(
            self.mps, new_wires, new_work_wires, self.hyperparameters["right_canonicalize"]
        )

    @classmethod
    def _primitive_bind_call(cls, mps, wires, id=None):
        # pylint: disable=arguments-differ
        # NOTE(review): this signature does not accept ``work_wires`` or ``right_canonicalize``,
        # so they are not forwarded on this code path -- confirm whether that is intended.
        if cls._primitive is None:
            # guard against this being called when primitive is not defined.
            return type.__call__(cls, mps=mps, wires=wires, id=id)  # pragma: no cover
        return cls._primitive.bind(*mps, wires=wires, id=id)

    def decomposition(self):  # pylint: disable=arguments-differ
        """Decompose this operator by calling :meth:`compute_decomposition` with its own
        parameters and hyperparameters."""
        # ``input_wires`` is passed through the ``wires`` argument instead of as a keyword.
        filtered_hyperparameters = {
            key: value for key, value in self.hyperparameters.items() if key != "input_wires"
        }
        return self.compute_decomposition(
            self.parameters, wires=self.hyperparameters["input_wires"], **filtered_hyperparameters
        )

    @staticmethod
    def compute_decomposition(
        mps, wires, work_wires, right_canonicalize=False
    ):  # pylint: disable=arguments-differ
        r"""Representation of the operator as a product of other operators.
        The decomposition follows Eq. (23) in `arXiv:2310.18410 <https://arxiv.org/pdf/2310.18410>`_.

        Args:
            mps (list[Array]): list of arrays of rank-3 and rank-2 tensors representing an MPS state as a
                product of site matrices.
            wires (Sequence[int]): wires that the template acts on. It should match the number of MPS tensors.
            work_wires (Sequence[int]): list of extra qubits needed in the decomposition. If the maximum dimension
                of the MPS tensors is :math:`2^k`, then :math:`k` ``work_wires`` will be needed. If no ``work_wires`` are given,
                this operator can only be executed on the ``lightning.tensor`` device. Default is ``None``.
            right_canonicalize (bool): Indicates whether a conversion to right-canonical form should be performed to the mps.
                Default is ``False``.

        Returns:
            list[.Operator]: Decomposition of the operator

        Raises:
            ValueError: if ``work_wires`` is ``None`` or if there are fewer work wires than
                needed to represent the largest bond dimension of the MPS.
        """

        if work_wires is None:
            raise ValueError("The qml.MPSPrep decomposition requires `work_wires` to be specified.")

        # The last axis of every tensor except the final one is a bond dimension.
        max_bond_dimension = max((tensor.shape[-1] for tensor in mps[:-1]), default=0)

        if max_bond_dimension > 2 ** len(work_wires):
            raise ValueError(
                f"Incorrect number of `work_wires`. At least {int(qml.math.ceil(qml.math.log2(max_bond_dimension)))} `work_wires` must be provided."
            )

        ops = []
        n_wires = len(work_wires) + 1

        # Shallow copy into a list so the caller's sequence is left untouched; unlike
        # ``mps.copy()`` this also works when ``mps`` is a tuple.
        mps = list(mps)

        # Transform the MPS to ensure that the generated matrix is unitary
        if right_canonicalize:
            mps = right_canonicalize_mps(mps)

        # Promote the boundary tensors to rank 3 so that every site is handled uniformly.
        mps[0] = mps[0].reshape((1, *mps[0].shape))
        mps[-1] = mps[-1].reshape((*mps[-1].shape, 1))

        interface, dtype = qml.math.get_interface(mps[0]), mps[0].dtype

        for i, Ai in enumerate(mps):

            # Encode the tensor Ai in a unitary matrix following Eq.23 in https://arxiv.org/pdf/2310.18410
            vectors = []
            for column in Ai:

                vector = qml.math.zeros(2**n_wires, like=interface, dtype=dtype)

                # The first half of the vector receives ``column[0]`` and the second half
                # receives ``column[1]``. JAX arrays are immutable, hence the ``.at[...].set``.
                if interface == "jax":
                    vector = vector.at[: len(column[0])].set(column[0])
                    vector = vector.at[
                        2 ** (n_wires - 1) : 2 ** (n_wires - 1) + len(column[1])
                    ].set(column[1])
                else:
                    vector[: len(column[0])] = column[0]
                    vector[2 ** (n_wires - 1) : 2 ** (n_wires - 1) + len(column[1])] = column[1]

                vectors.append(vector)

            vectors = qml.math.stack(vectors).T

            # The unitary is completed using QR decomposition. The generator is re-seeded for
            # every tensor, so the padding columns are reproducible.
            d, k = vectors.shape
            new_columns = qml.math.array(np.random.RandomState(42).random((d, d - k)))
            unitary_matrix, R = qml.math.linalg.qr(qml.math.hstack([vectors, new_columns]))
            unitary_matrix *= qml.math.sign(
                qml.math.diag(R)
            )  # Enforce uniqueness for QR decomposition

            ops.append(qml.QubitUnitary(unitary_matrix, wires=[wires[i]] + work_wires))

        return ops


# Register the implementation of the MPSPrep primitive; this is skipped when no
# primitive has been created for the class.
if MPSPrep._primitive is not None:  # pylint: disable=protected-access

    @MPSPrep._primitive.def_impl  # pylint: disable=protected-access
    def _(*args, **kwargs):
        """Build an ``MPSPrep`` instance from the arguments bound to the primitive.

        All positional arguments are passed together as a single tuple in the first
        constructor argument (``mps``); keyword arguments are forwarded unchanged.
        """
        # NOTE(review): ``type.__call__`` is invoked directly, presumably to bypass a
        # custom metaclass ``__call__`` -- confirm against the ``Operation`` metaclass.
        return type.__call__(MPSPrep, args, **kwargs)