# Source code for pennylane.templates.state_preparations.state_prep_mps
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the MPSPrep template.
"""
import pennylane as qml
from pennylane.operation import Operation
from pennylane.wires import Wires
class MPSPrep(Operation):
    r"""Prepares an initial state from a matrix product state (MPS) representation.

    .. note::

        Currently, this operator can only be used with ``qml.device("lightning.tensor")``.

    Args:
        mps (List[Array]): list of arrays of rank-3 and rank-2 tensors representing an MPS state as a
            product of site matrices. See the usage details section for more information.
        wires (Sequence[int]): wires that the template acts on
        id (str or None): optional identifier forwarded unchanged to the ``Operation`` constructor

    Raises:
        AssertionError: if any tensor in ``mps`` violates the rank, physical-dimension,
            power-of-two, or bond-matching constraints described in the usage details.

    **Example**

    .. code-block::

        mps = [
            np.array([[0.0, 0.107], [0.994, 0.0]]),
            np.array(
                [
                    [[0.0, 0.0], [1.0, 0.0]],
                    [[0.0, 1.0], [0.0, 0.0]],
                ]
            ),
            np.array([[-1.0, -0.0], [-0.0, -1.0]]),
        ]

        dev = qml.device("lightning.tensor", wires=3)

        @qml.qnode(dev)
        def circuit():
            qml.MPSPrep(mps, wires=[0, 1, 2])
            return qml.state()

    .. code-block:: pycon

        >>> print(circuit())
        [ 0.        +0.j -0.10705513+0.j  0.        +0.j  0.        +0.j
          0.        +0.j  0.        +0.j -0.99451217+0.j  0.        +0.j]

    .. details::
        :title: Usage Details

        The input MPS must be a list of :math:`n` tensors :math:`[A^{(0)}, ..., A^{(n-1)}]`
        with shapes :math:`d_0, ..., d_{n-1}`, respectively. The first and last tensors have rank :math:`2`
        while the intermediate tensors have rank :math:`3`.

        The first tensor must have the shape :math:`d_0 = (d_{0,0}, d_{0,1})` where :math:`d_{0,0}`
        and :math:`d_{0,1}` correspond to the physical dimension of the site and an auxiliary bond
        dimension connecting it to the next tensor, respectively.

        The last tensor must have the shape :math:`d_{n-1} = (d_{n-1,0}, d_{n-1,1})` where :math:`d_{n-1,0}`
        and :math:`d_{n-1,1}` represent the auxiliary dimension from the previous site and the physical
        dimension of the site, respectively.

        The intermediate tensors must have the shape :math:`d_j = (d_{j,0}, d_{j,1}, d_{j,2})`, where:

        - :math:`d_{j,0}` is the bond dimension connecting to the previous tensor
        - :math:`d_{j,1}` is the physical dimension for the site
        - :math:`d_{j,2}` is the bond dimension connecting to the next tensor

        Note that the bond dimensions must match between adjacent tensors such that :math:`d_{j-1,2} = d_{j,0}`.

        Additionally, the physical dimension of the site should always be fixed at :math:`2`
        (since the dimension of a qubit is :math:`2`), while the other dimensions must be powers of two.

        The following example shows a valid MPS input containing four tensors with
        dimensions :math:`[(2,2), (2,2,4), (4,2,2), (2,2)]` which satisfy the criteria described above.

        .. code-block::

            mps = [
                np.array([[0.0, 0.107], [0.994, 0.0]]),
                np.array(
                    [
                        [[0.0, 0.0, 0.0, -0.0], [1.0, 0.0, 0.0, -0.0]],
                        [[0.0, 1.0, 0.0, -0.0], [0.0, 0.0, 0.0, -0.0]],
                    ]
                ),
                np.array(
                    [
                        [[-1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]],
                        [[0.0, -1.0], [0.0, 0.0]],
                        [[0.0, 0.0], [1.0, 0.0]],
                    ]
                ),
                np.array([[-1.0, -0.0], [-0.0, -1.0]]),
            ]
    """

    def __init__(self, mps, wires, id=None):
        """Validate the tensor shapes in ``mps`` and initialise the operator.

        The checks are deliberately kept as ``assert`` statements so that the
        exception type raised on invalid input (``AssertionError``) is unchanged.
        """
        # Validate the shape and dimensions of the first tensor
        assert qml.math.isclose(
            len(qml.math.shape(mps[0])), 2
        ), "The first tensor must have exactly 2 dimensions."
        dj0, dj2 = qml.math.shape(mps[0])
        assert qml.math.isclose(
            dj0, 2
        ), "The first dimension of the first tensor must be exactly 2."
        assert qml.math.log2(
            dj2
        ).is_integer(), "The second dimension of the first tensor must be a power of 2."

        # Validate the shapes of the intermediate tensors.
        # ``dj2`` always holds the trailing bond dimension of the previous tensor.
        for i, array in enumerate(mps[1:-1], start=1):
            shape = qml.math.shape(array)
            assert qml.math.isclose(len(shape), 3), f"Tensor {i} must have exactly 3 dimensions."
            new_dj0, new_dj1, new_dj2 = shape
            assert qml.math.isclose(
                new_dj1, 2
            ), f"The second dimension of tensor {i} must be exactly 2."
            assert qml.math.log2(
                new_dj0
            ).is_integer(), f"The first dimension of tensor {i} must be a power of 2."
            assert qml.math.log2(
                new_dj2
            ).is_integer(), f"The third dimension of tensor {i} must be a power of 2."
            assert qml.math.isclose(
                new_dj0, dj2
            ), f"Dimension mismatch: tensor {i}'s first dimension does not match the previous third dimension."
            dj2 = new_dj2

        # Validate the shape and dimensions of the last tensor
        assert qml.math.isclose(
            len(qml.math.shape(mps[-1])), 2
        ), "The last tensor must have exactly 2 dimensions."
        new_dj0, new_dj1 = qml.math.shape(mps[-1])
        assert new_dj1 == 2, "The second dimension of the last tensor must be exactly 2."
        assert qml.math.log2(
            new_dj0
        ).is_integer(), "The first dimension of the last tensor must be a power of 2."
        assert qml.math.isclose(
            new_dj0, dj2
        ), "Dimension mismatch: the last tensor's first dimension does not match the previous third dimension."

        # Each site tensor is passed as a separate positional argument.
        super().__init__(*mps, wires=wires, id=id)

    @property
    def mps(self):
        """list representing the MPS input"""
        return self.data

    def _flatten(self):
        """Return ``(mps, metadata)`` where the metadata holds only the wires."""
        hyperparameters = (("wires", self.wires),)
        return self.mps, hyperparameters

    @classmethod
    def _unflatten(cls, data, metadata):
        """Rebuild an instance from the ``(data, metadata)`` pair produced by ``_flatten``."""
        hyperparams_dict = dict(metadata)
        return cls(data, **hyperparams_dict)

    def map_wires(self, wire_map):
        """Return a new ``MPSPrep`` with its wires relabelled according to ``wire_map``.

        Args:
            wire_map (dict): mapping from current wire labels to new ones; wires that
                are not keys of the mapping are left unchanged.

        Returns:
            MPSPrep: a new operator holding the same MPS tensors on the mapped wires.
        """
        new_wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
        # Forward the instance id so that relabelling wires does not drop it.
        return MPSPrep(self.mps, new_wires, id=self.id)

    @classmethod
    def _primitive_bind_call(cls, mps, wires, id=None):
        """Bind through ``cls._primitive`` when it is defined, else build the instance directly."""
        # pylint: disable=arguments-differ
        if cls._primitive is None:
            # guard against this being called when primitive is not defined.
            return type.__call__(cls, mps=mps, wires=wires, id=id)  # pragma: no cover
        # The site tensors are unpacked into separate positional arguments for the bind call.
        return cls._primitive.bind(*mps, wires=wires, id=id)
# Register a concrete implementation only when a primitive exists for MPSPrep.
if MPSPrep._primitive is not None:  # pylint: disable=protected-access

    @MPSPrep._primitive.def_impl  # pylint: disable=protected-access
    def _(*tensors, **hyperparams):
        """Construct an ``MPSPrep`` from the positional tensors and keyword hyperparameters.

        The positional arguments are regrouped into a single tuple and passed as
        the first argument; ``type.__call__`` is used to invoke the constructor.
        """
        return type.__call__(MPSPrep, tensors, **hyperparams)
# _modules/pennylane/templates/state_preparations/state_prep_mps
# Download Python script
# Download Notebook
# View on GitHub