# Copyright 2018-2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This submodule defines a base class for composite operations."""
# pylint: disable=too-many-instance-attributes,invalid-sequence-index
import abc
import copy
from collections.abc import Callable
from functools import wraps

import pennylane as qml
from pennylane import math
from pennylane.operation import _UNSET_BATCH_SIZE, Operator
from pennylane.wires import Wires

# pylint: disable=too-many-instance-attributes


def handle_recursion_error(func):
    """Handles any recursion errors raised from too many levels of nesting.

    Args:
        func (Callable): the function or method to protect.

    Returns:
        Callable: a wrapper that forwards all positional and keyword arguments
        to ``func`` and returns its result unchanged.

    Raises:
        RuntimeError: raised by the wrapper when ``func`` raises a
            ``RecursionError``. The original error is chained as the cause.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RecursionError as e:
            # Re-raise with an actionable message; ``from e`` keeps the original traceback.
            raise RuntimeError(
                "Maximum recursion depth reached! This is likely due to nesting too many levels "
                "of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, "
                "and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you "
                "can periodically call qml.simplify on your operators."
            ) from e

    return wrapper
class CompositeOp(Operator):
    """A base class for operators that are composed of other operators.

    Args:
        operands: (tuple[~.operation.Operator]): a tuple of operators which will be combined.

    Keyword Args:
        id (str or None): id for the operator. Default is None.

    The child composite operator should define the `_op_symbol` property
    during initialization and define any relevant representations, such as
    :meth:`~.operation.Operator.matrix` and :meth:`~.operation.Operator.decomposition`.
    """

    @classmethod
    def _primitive_bind_call(cls, *args, **kwargs):
        """Bind ``cls._primitive`` with the given positional and keyword arguments."""
        # needs to be overwritten because it doesn't take wires
        return cls._primitive.bind(*args, **kwargs)

    def _flatten(self):
        """Return ``(data, metadata)``: the operands as a tuple and an empty metadata tuple."""
        return tuple(self.operands), tuple()

    @classmethod
    def _unflatten(cls, data, metadata):
        """Rebuild an instance from flattened operands; ``metadata`` is not used."""
        return cls(*data)

    _eigs = {}  # cache eigen vectors and values like in qml.Hermitian

    def __init__(
        self, *operands: Operator, id=None, _pauli_rep=None
    ):  # pylint: disable=super-init-not-called
        """Store the operands and initialise the cached attributes.

        Args:
            *operands (Operator): the operators that are combined.
            id (str or None): id for the operator. Default is None.
            _pauli_rep: pre-computed Pauli representation. When ``None``,
                ``self._build_pauli_rep()`` is called to obtain it.
                NOTE(review): ``_build_pauli_rep`` is not defined in the visible part
                of this class; it is presumably provided elsewhere.
        """
        self._id = id
        self._name = self.__class__.__name__

        self.operands = operands
        self._wires = qml.wires.Wires.all_wires([op.wires for op in operands])
        # Lazily computed caches, filled by the corresponding properties.
        self._hash = None
        self._has_overlapping_wires = None
        self._overlapping_ops = None
        self._pauli_rep = self._build_pauli_rep() if _pauli_rep is None else _pauli_rep
        self.queue()
        self._batch_size = _UNSET_BATCH_SIZE

    @handle_recursion_error
    def _check_batching(self):
        """Set ``_batch_size`` from the operands' batch sizes.

        Operands whose ``batch_size`` is ``None`` are ignored. ``_batch_size`` becomes the
        single remaining value, or ``None`` when no operand has a batch size.

        Raises:
            ValueError: if the operands report more than one distinct batch size.
        """
        batch_sizes = {op.batch_size for op in self if op.batch_size is not None}
        if len(batch_sizes) > 1:
            raise ValueError(
                "Broadcasting was attempted but the broadcasted dimensions "
                f"do not match: {batch_sizes}."
            )
        self._batch_size = batch_sizes.pop() if batch_sizes else None

    def __repr__(self):
        """Join the operands with ``_op_symbol``; operands with arithmetic depth > 0
        are wrapped in parentheses."""
        return f" {self._op_symbol} ".join(
            [f"({op})" if op.arithmetic_depth > 0 else f"{op}" for op in self]
        )

    @handle_recursion_error
    def __copy__(self):
        """Return a copy built without calling ``__init__``.

        Each operand is copied through its own ``__copy__``; every other instance
        attribute is assigned to the copy as-is (same object).
        """
        cls = self.__class__
        copied_op = cls.__new__(cls)
        copied_op.operands = tuple(s.__copy__() for s in self)

        for attr, value in vars(self).items():
            if attr not in {"operands"}:
                setattr(copied_op, attr, value)  # TODO: exclude data?

        return copied_op

    def __iter__(self):
        """Return the iterator over the underlying operands."""
        return iter(self.operands)

    def __getitem__(self, idx):
        """Return the operand at position ``idx`` of the composition."""
        return self.operands[idx]

    def __len__(self):
        """Return the number of operators in this composite operator"""
        return len(self.operands)

    @property
    @abc.abstractmethod
    def _op_symbol(self) -> str:
        """The symbol used when visualizing the composite operator"""

    @property
    @handle_recursion_error
    def data(self):
        """Flat tuple with the ``data`` entries of all operands, in operand order."""
        return tuple(d for op in self for d in op.data)

    @data.setter
    def data(self, new_data):
        """Distribute ``new_data`` over the operands.

        Each operand with ``num_params > 0`` receives the next ``num_params`` entries of
        ``new_data``; operands without parameters are skipped.
        """
        for op in self:
            op_num_params = op.num_params
            if op_num_params > 0:
                op.data = new_data[:op_num_params]
                # Drop the consumed entries so the next operand starts at index 0.
                new_data = new_data[op_num_params:]

    @property
    def num_wires(self):
        """Number of wires the operator acts on."""
        return len(self.wires)

    @property
    @handle_recursion_error
    def num_params(self):
        """Sum of ``num_params`` over all operands."""
        return sum(op.num_params for op in self)

    @property
    def has_overlapping_wires(self) -> bool:
        """Boolean expression that indicates if the factors have overlapping wires.

        The result is computed on first access and cached in ``_has_overlapping_wires``.
        """
        if self._has_overlapping_wires is None:
            wires = []
            for op in self:
                wires.extend(op.wires)
            # A wire label appearing more than once means two operands share it.
            self._has_overlapping_wires = len(wires) != len(set(wires))
        return self._has_overlapping_wires

    @property
    @abc.abstractmethod
    def is_hermitian(self):
        """This property determines if the composite operator is hermitian."""

    # pylint: disable=arguments-renamed, invalid-overridden-method
    @property
    @handle_recursion_error
    def has_matrix(self):
        """``True`` if every operand reports ``has_matrix``."""
        return all(op.has_matrix for op in self)

    @handle_recursion_error
    def eigvals(self):
        """Return the eigenvalues of the specified operator.

        This method uses pre-stored eigenvalues for standard observables where
        possible and stores the corresponding eigenvectors from the eigendecomposition.

        The operands are processed per group of :attr:`overlapping_ops`: a group with a
        single operator contributes that operator's ``eigvals()``; a larger group is wrapped
        in a temporary instance of this class and contributes its ``eigendecomposition``
        eigenvalues. Each contribution is expanded to ``self.wires`` and the stacked
        results are combined with ``_math_op`` along axis 0.

        Returns:
            array: array containing the eigenvalues of the operator
        """
        eigvals = []
        for ops in self.overlapping_ops:
            if len(ops) == 1:
                eigvals.append(
                    math.expand_vector(ops[0].eigvals(), list(ops[0].wires), list(self.wires))
                )
            else:
                tmp_composite = self.__class__(*ops)
                eigvals.append(
                    math.expand_vector(
                        tmp_composite.eigendecomposition["eigval"],
                        list(tmp_composite.wires),
                        list(self.wires),
                    )
                )

        # Convert every entry to the same interface before stacking.
        framework = math.get_deep_interface(eigvals)
        eigvals = [math.asarray(ei, like=framework) for ei in eigvals]

        return self._math_op(math.vstack(eigvals), axis=0)

    @abc.abstractmethod
    def matrix(self, wire_order=None):
        """Representation of the operator as a matrix in the computational basis."""

    @property
    def overlapping_ops(self) -> list[list[Operator]]:
        """Groups all operands of the composite operator that act on overlapping wires.

        The grouping is computed on first access and cached in ``_overlapping_ops``.

        Returns:
            List[List[Operator]]: List of lists of operators that act on overlapping wires. All the
            inner lists commute with each other.
        """
        if self._overlapping_ops is not None:
            return self._overlapping_ops

        # Each entry of ``groups`` is a two-element list: [operators, wires of the group].
        groups = []
        for op in self:
            # For every op, find all groups that have overlapping wires with it.
            i = 0
            first_group_idx = None
            while i < len(groups):
                if first_group_idx is None and any(wire in op.wires for wire in groups[i][1]):
                    # Found the first group that has overlapping wires with this op
                    groups[i][1] = groups[i][1] + op.wires
                    first_group_idx = i  # record the index of this group
                    i += 1
                elif first_group_idx is not None and any(
                    wire in op.wires for wire in groups[i][1]
                ):
                    # If the op has already been added to the first group, every subsequent
                    # group that overlaps with this op is merged into the first group
                    # (``i`` is not advanced because ``pop`` shifts the next group into place)
                    ops, wires = groups.pop(i)
                    groups[first_group_idx][0].extend(ops)
                    groups[first_group_idx][1] = groups[first_group_idx][1] + wires
                else:
                    i += 1
            if first_group_idx is not None:
                groups[first_group_idx][0].append(op)
            else:
                # Create new group
                groups.append([[op], op.wires])

        self._overlapping_ops = [group[0] for group in groups]
        return self._overlapping_ops

    @property
    def eigendecomposition(self):
        r"""Return the eigendecomposition of the matrix specified by the operator.

        This method uses pre-stored eigenvalues for standard observables where
        possible and stores the corresponding eigenvectors from the eigendecomposition.

        It transforms the input operator according to the wires specified.

        The result is cached in the class-level ``_eigs`` dictionary under ``self.hash``.
        ``math.linalg.eigh`` is used when ``is_hermitian`` is true, ``math.linalg.eig``
        otherwise.

        Returns:
            dict[str, array]: dictionary containing the eigenvalues and the
                eigenvectors of the operator.
        """
        eigen_func = math.linalg.eigh if self.is_hermitian else math.linalg.eig

        if self.hash not in self._eigs:
            mat = self.matrix()
            w, U = eigen_func(mat)
            self._eigs[self.hash] = {"eigvec": U, "eigval": w}

        return self._eigs[self.hash]

    @property
    def has_diagonalizing_gates(self):
        """Whether diagonalizing gates can be produced for this composite operator.

        Without overlapping wires, every operand must have diagonalizing gates. With
        overlapping wires, every operand that forms a group of its own must have
        diagonalizing gates and the composite must report ``has_matrix``.
        """
        if self.has_overlapping_wires:
            for ops in self.overlapping_ops:
                # if any of the single ops doesn't have diagonalizing gates,
                # the overall operator doesn't either
                if len(ops) == 1 and not ops[0].has_diagonalizing_gates:
                    return False
            # the lists of ops with multiple operators can be handled if there is a matrix
            return self.has_matrix

        return all(op.has_diagonalizing_gates for op in self)

    def diagonalizing_gates(self):
        r"""Sequence of gates that diagonalize the operator in the computational basis.

        Given the eigendecomposition :math:`O = U \Sigma U^{\dagger}` where
        :math:`\Sigma` is a diagonal matrix containing the eigenvalues,
        the sequence of diagonalizing gates implements the unitary :math:`U^{\dagger}`.

        The diagonalizing gates rotate the state into the eigenbasis
        of the operator.

        A ``DiagGatesUndefinedError`` is raised if no representation by decomposition is defined.

        .. seealso:: :meth:`~.Operator.compute_diagonalizing_gates`.

        Returns:
            list[.Operator] or None: a list of operators
        """
        diag_gates = []
        for ops in self.overlapping_ops:
            if len(ops) == 1:
                diag_gates.extend(ops[0].diagonalizing_gates())
            else:
                # Operators sharing wires are diagonalized together through the
                # eigenvectors of a temporary composite built from the group.
                tmp_composite = self.__class__(*ops)
                eigvecs = tmp_composite.eigendecomposition["eigvec"]
                diag_gates.append(
                    qml.QubitUnitary(
                        math.transpose(math.conj(eigvecs)),
                        wires=tmp_composite.wires,
                    )
                )
        return diag_gates

    @handle_recursion_error
    def label(self, decimals=None, base_label=None, cache=None):
        r"""How the composite operator is represented in diagrams and drawings.

        Args:
            decimals (int): If ``None``, no parameters are included. Else,
                how to round the parameters. Defaults to ``None``.
            base_label (Iterable[str]): Overwrite the non-parameter component of the label.
                Must be same length as ``operands`` attribute. Defaults to ``None``.
            cache (dict): Dictionary that carries information between label calls
                in the same drawing. Defaults to ``None``.

        Returns:
            str: label to use in drawings

        Raises:
            ValueError: if ``base_label`` is a string or its length differs from the
                number of operands.

        **Example (using the Sum composite operator)**

        >>> op = qml.S(0) + qml.X(0) + qml.Rot(1,2,3, wires=[1])
        >>> op.label()
        '(S+X)+Rot'
        >>> op.label(decimals=2, base_label=[["my_s", "my_x"], "inc_rot"])
        '(my_s+my_x)+inc_rot\n(1.00,\n2.00,\n3.00)'

        """

        def _label(op, decimals, base_label, cache):
            # Nested arithmetic operators are parenthesised to keep the grouping visible.
            sub_label = op.label(decimals, base_label, cache)
            return f"({sub_label})" if op.arithmetic_depth > 0 else sub_label

        if base_label is not None:
            if isinstance(base_label, str) or len(base_label) != len(self):
                raise ValueError(
                    "Composite operator labels require ``base_label`` keyword to be same length as operands."
                )
            return self._op_symbol.join(
                _label(op, decimals, lbl, cache) for op, lbl in zip(self, base_label)
            )

        return self._op_symbol.join(_label(op, decimals, None, cache) for op in self)

    def queue(self, context=qml.QueuingManager):
        """Updates each operator's owner to self, this ensures
        that the operators are not applied to the circuit repeatedly.

        When ``qml.QueuingManager.recording()`` is true, every operand is removed from
        ``context`` and this operator is appended to it; otherwise nothing is queued.

        Args:
            context: object providing ``remove`` and ``append``. Defaults to
                ``qml.QueuingManager``.

        Returns:
            CompositeOp: this operator.
        """
        if qml.QueuingManager.recording():
            for op in self:
                context.remove(op)
            context.append(self)
        return self

    @classmethod
    @abc.abstractmethod
    def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:
        """Sort composite operands by their wire indices."""

    @property
    @handle_recursion_error
    def hash(self):
        """Hash built from the operator name and the hashes of the sorted operands.

        The value is computed on first access and cached in ``_hash``.
        """
        if self._hash is None:
            self._hash = hash(
                (str(self.name), str([factor.hash for factor in self._sort(self.operands)]))
            )
        return self._hash

    # pylint:disable = missing-function-docstring
    @property
    def basis(self):
        """Always ``None`` for composite operators."""
        return None

    @property
    @handle_recursion_error
    def arithmetic_depth(self) -> int:
        """One more than the largest ``arithmetic_depth`` among the operands."""
        return 1 + max(op.arithmetic_depth for op in self)

    @property
    @abc.abstractmethod
    def _math_op(self) -> Callable:
        """The function used when combining the operands of the composite operator"""