Source code for pennylane.ops.op_math.composite
# Copyright 2018-2022 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines a base class for composite operations.
"""
# pylint: disable=too-many-instance-attributes,invalid-sequence-index
import abc
import copy
from collections.abc import Callable
from functools import wraps
import pennylane as qml
from pennylane import math
from pennylane.operation import _UNSET_BATCH_SIZE, Operator
from pennylane.wires import Wires
# pylint: disable=too-many-instance-attributes
def handle_recursion_error(func):
"""Handles any recursion errors raised from too many levels of nesting."""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except RecursionError as e:
raise RuntimeError(
"Maximum recursion depth reached! This is likely due to nesting too many levels "
"of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, "
"and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you "
"can periodically call qml.simplify on your operators."
) from e
return wrapper
[docs]class CompositeOp(Operator):
"""A base class for operators that are composed of other operators.
Args:
operands: (tuple[~.operation.Operator]): a tuple of operators which will be combined.
Keyword Args:
id (str or None): id for the operator. Default is None.
The child composite operator should define the `_op_symbol` property
during initialization and define any relevant representations, such as
:meth:`~.operation.Operator.matrix` and :meth:`~.operation.Operator.decomposition`.
"""
@classmethod
def _primitive_bind_call(cls, *args, **kwargs):
# needs to be overwritten because it doesnt take wires
return cls._primitive.bind(*args, **kwargs)
def _flatten(self):
return tuple(self.operands), tuple()
@classmethod
def _unflatten(cls, data, metadata):
return cls(*data)
_eigs = {} # cache eigen vectors and values like in qml.Hermitian
def __init__(
self, *operands: Operator, id=None, _pauli_rep=None
): # pylint: disable=super-init-not-called
self._id = id
self._name = self.__class__.__name__
self.operands = operands
self._wires = qml.wires.Wires.all_wires([op.wires for op in operands])
self._hash = None
self._has_overlapping_wires = None
self._overlapping_ops = None
self._pauli_rep = self._build_pauli_rep() if _pauli_rep is None else _pauli_rep
self.queue()
self._batch_size = _UNSET_BATCH_SIZE
@handle_recursion_error
def _check_batching(self):
batch_sizes = {op.batch_size for op in self if op.batch_size is not None}
if len(batch_sizes) > 1:
raise ValueError(
"Broadcasting was attempted but the broadcasted dimensions "
f"do not match: {batch_sizes}."
)
self._batch_size = batch_sizes.pop() if batch_sizes else None
def __repr__(self):
return f" {self._op_symbol} ".join(
[f"({op})" if op.arithmetic_depth > 0 else f"{op}" for op in self]
)
@handle_recursion_error
def __copy__(self):
cls = self.__class__
copied_op = cls.__new__(cls)
copied_op.operands = tuple(s.__copy__() for s in self)
for attr, value in vars(self).items():
if attr not in {"operands"}:
setattr(copied_op, attr, value) # TODO: exclude data?
return copied_op
def __iter__(self):
"""Return the iterator over the underlying operands."""
return iter(self.operands)
def __getitem__(self, idx):
"""Return the operand at position ``idx`` of the composition."""
return self.operands[idx]
def __len__(self):
"""Return the number of operators in this composite operator"""
return len(self.operands)
@property
@abc.abstractmethod
def _op_symbol(self) -> str:
"""The symbol used when visualizing the composite operator"""
@property
@handle_recursion_error
def data(self):
"""Create data property"""
return tuple(d for op in self for d in op.data)
@data.setter
def data(self, new_data):
"""Set the data property"""
for op in self:
op_num_params = op.num_params
if op_num_params > 0:
op.data = new_data[:op_num_params]
new_data = new_data[op_num_params:]
@property
def num_wires(self):
"""Number of wires the operator acts on."""
return len(self.wires)
@property
@handle_recursion_error
def num_params(self):
return sum(op.num_params for op in self)
@property
def has_overlapping_wires(self) -> bool:
"""Boolean expression that indicates if the factors have overlapping wires."""
if self._has_overlapping_wires is None:
wires = []
for op in self:
wires.extend(list(op.wires))
self._has_overlapping_wires = len(wires) != len(set(wires))
return self._has_overlapping_wires
@property
@abc.abstractmethod
def is_hermitian(self):
"""This property determines if the composite operator is hermitian."""
# pylint: disable=arguments-renamed, invalid-overridden-method
@property
@handle_recursion_error
def has_matrix(self):
return all(op.has_matrix or isinstance(op, qml.ops.Hamiltonian) for op in self)
[docs] @handle_recursion_error
def eigvals(self):
"""Return the eigenvalues of the specified operator.
This method uses pre-stored eigenvalues for standard observables where
possible and stores the corresponding eigenvectors from the eigendecomposition.
Returns:
array: array containing the eigenvalues of the operator
"""
eigvals = []
for ops in self.overlapping_ops:
if len(ops) == 1:
eigvals.append(
qml.utils.expand_vector(ops[0].eigvals(), list(ops[0].wires), list(self.wires))
)
else:
tmp_composite = self.__class__(*ops)
eigvals.append(
qml.utils.expand_vector(
tmp_composite.eigendecomposition["eigval"],
list(tmp_composite.wires),
list(self.wires),
)
)
framework = math.get_deep_interface(eigvals)
eigvals = [math.asarray(ei, like=framework) for ei in eigvals]
return self._math_op(math.vstack(eigvals), axis=0)
[docs] @abc.abstractmethod
def matrix(self, wire_order=None):
"""Representation of the operator as a matrix in the computational basis."""
@property
def overlapping_ops(self) -> list[list[Operator]]:
"""Groups all operands of the composite operator that act on overlapping wires.
Returns:
List[List[Operator]]: List of lists of operators that act on overlapping wires. All the
inner lists commute with each other.
"""
if self._overlapping_ops is not None:
return self._overlapping_ops
groups = []
for op in self:
# For every op, find all groups that have overlapping wires with it.
i = 0
first_group_idx = None
while i < len(groups):
if first_group_idx is None and any(wire in op.wires for wire in groups[i][1]):
# Found the first group that has overlapping wires with this op
groups[i][1] = groups[i][1] + op.wires
first_group_idx = i # record the index of this group
i += 1
elif first_group_idx is not None and any(wire in op.wires for wire in groups[i][1]):
# If the op has already been added to the first group, every subsequent
# group that overlaps with this op is merged into the first group
ops, wires = groups.pop(i)
groups[first_group_idx][0].extend(ops)
groups[first_group_idx][1] = groups[first_group_idx][1] + wires
else:
i += 1
if first_group_idx is not None:
groups[first_group_idx][0].append(op)
else:
# Create new group
groups.append([[op], op.wires])
self._overlapping_ops = [group[0] for group in groups]
return self._overlapping_ops
@property
def eigendecomposition(self):
r"""Return the eigendecomposition of the matrix specified by the operator.
This method uses pre-stored eigenvalues for standard observables where
possible and stores the corresponding eigenvectors from the eigendecomposition.
It transforms the input operator according to the wires specified.
Returns:
dict[str, array]: dictionary containing the eigenvalues and the
eigenvectors of the operator.
"""
eigen_func = math.linalg.eigh if self.is_hermitian else math.linalg.eig
if self.hash not in self._eigs:
mat = self.matrix()
w, U = eigen_func(mat)
self._eigs[self.hash] = {"eigvec": U, "eigval": w}
return self._eigs[self.hash]
@property
def has_diagonalizing_gates(self):
if self.has_overlapping_wires:
for ops in self.overlapping_ops:
# if any of the single ops doesn't have diagonalizing gates, the overall operator doesn't either
if len(ops) == 1 and not ops[0].has_diagonalizing_gates:
return False
# the lists of ops with multiple operators can be handled if there is a matrix
return self.has_matrix
return all(op.has_diagonalizing_gates for op in self)
[docs] def diagonalizing_gates(self):
r"""Sequence of gates that diagonalize the operator in the computational basis.
Given the eigendecomposition :math:`O = U \Sigma U^{\dagger}` where
:math:`\Sigma` is a diagonal matrix containing the eigenvalues,
the sequence of diagonalizing gates implements the unitary :math:`U^{\dagger}`.
The diagonalizing gates rotate the state into the eigenbasis
of the operator.
A ``DiagGatesUndefinedError`` is raised if no representation by decomposition is defined.
.. seealso:: :meth:`~.Operator.compute_diagonalizing_gates`.
Returns:
list[.Operator] or None: a list of operators
"""
diag_gates = []
for ops in self.overlapping_ops:
if len(ops) == 1:
diag_gates.extend(ops[0].diagonalizing_gates())
else:
tmp_sum = self.__class__(*ops)
eigvecs = tmp_sum.eigendecomposition["eigvec"]
diag_gates.append(
qml.QubitUnitary(math.transpose(math.conj(eigvecs)), wires=tmp_sum.wires)
)
return diag_gates
[docs] @handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
r"""How the composite operator is represented in diagrams and drawings.
Args:
decimals (int): If ``None``, no parameters are included. Else,
how to round the parameters. Defaults to ``None``.
base_label (Iterable[str]): Overwrite the non-parameter component of the label.
Must be same length as ``operands`` attribute. Defaults to ``None``.
cache (dict): Dictionary that carries information between label calls
in the same drawing. Defaults to ``None``.
Returns:
str: label to use in drawings
**Example (using the Sum composite operator)**
>>> op = qml.S(0) + qml.X(0) + qml.Rot(1,2,3, wires=[1])
>>> op.label()
'(S+X)+Rot'
>>> op.label(decimals=2, base_label=[["my_s", "my_x"], "inc_rot"])
'(my_s+my_x)+inc_rot\n(1.00,\n2.00,\n3.00)'
"""
def _label(op, decimals, base_label, cache):
sub_label = op.label(decimals, base_label, cache)
return f"({sub_label})" if op.arithmetic_depth > 0 else sub_label
if base_label is not None:
if isinstance(base_label, str) or len(base_label) != len(self):
raise ValueError(
"Composite operator labels require ``base_label`` keyword to be same length as operands."
)
return self._op_symbol.join(
_label(op, decimals, lbl, cache) for op, lbl in zip(self, base_label)
)
return self._op_symbol.join(_label(op, decimals, None, cache) for op in self)
[docs] def queue(self, context=qml.QueuingManager):
"""Updates each operator's owner to self, this ensures
that the operators are not applied to the circuit repeatedly."""
if qml.QueuingManager.recording():
for op in self:
context.remove(op)
context.append(self)
return self
@classmethod
@abc.abstractmethod
def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:
"""Sort composite operands by their wire indices."""
@property
@handle_recursion_error
def hash(self):
if self._hash is None:
self._hash = hash(
(str(self.name), str([factor.hash for factor in self._sort(self.operands)]))
)
return self._hash
# pylint:disable = missing-function-docstring
@property
def basis(self):
return None
@property
@handle_recursion_error
def arithmetic_depth(self) -> int:
return 1 + max(op.arithmetic_depth for op in self)
@property
@abc.abstractmethod
def _math_op(self) -> Callable:
"""The function used when combining the operands of the composite operator"""
[docs] @handle_recursion_error
def map_wires(self, wire_map: dict):
# pylint:disable=protected-access
cls = self.__class__
new_op = cls.__new__(cls)
new_op.operands = tuple(op.map_wires(wire_map=wire_map) for op in self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op.data = copy.copy(self.data)
if self._overlapping_ops is not None:
new_op._overlapping_ops = [
[o.map_wires(wire_map) for o in _ops] for _ops in self._overlapping_ops
]
else:
new_op._overlapping_ops = None
for attr, value in vars(self).items():
if attr not in {"data", "operands", "_wires", "_overlapping_ops"}:
setattr(new_op, attr, value)
if (p_rep := new_op.pauli_rep) is not None:
new_op._pauli_rep = p_rep.map_wires(wire_map)
return new_op
@abc.abstractmethod
def _build_pauli_rep(self):
"""The function to generate the pauli representation for the composite operator."""
_modules/pennylane/ops/op_math/composite
Download Python script
Download Notebook
View on GitHub