Source code for pennylane.ops.functions.assert_valid
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.ops.functions.assert_valid function for determining whether or not an
Operator class is correctly defined.
"""

import copy
import pickle
from string import ascii_lowercase

import numpy as np
import scipy.sparse

import pennylane as qml
from pennylane.operation import EigvalsUndefinedError


def _assert_error_raised(func, error, failure_comment):
    """Wrap ``func`` so that calling the wrapper asserts that ``func`` raises ``error``.

    Args:
        func (Callable): the callable expected to raise
        error (type[Exception]): the exception type that must be raised
        failure_comment (str): assertion message used when ``error`` is not raised

    Returns:
        Callable: a function with the same call signature as ``func`` that raises
        ``AssertionError`` if ``func`` returns without raising ``error``.
    """

    def inner_func(*args, **kwargs):
        error_raised = False
        try:
            func(*args, **kwargs)
        except error:
            error_raised = True
        assert error_raised, failure_comment

    return inner_func


def _letter_wire_map(wires):
    """Return a map relabelling each wire in ``wires`` with a distinct lowercase letter.

    Callers must ensure ``len(wires) <= len(ascii_lowercase)``; larger inputs raise ``IndexError``.
    """
    return {w: ascii_lowercase[i] for i, w in enumerate(wires)}


def _check_decomposition(op, skip_wire_mapping):
    """Checks involving the decomposition.

    If ``op.has_decomposition`` is True, checks that ``decomposition``,
    ``compute_decomposition`` and the operations queued by ``decomposition`` agree, that the
    decomposition is a list of operators not containing ``op``'s own class, and (unless
    skipped) that mapping ``op``'s wires maps the wires of its decomposition. Otherwise checks
    that both decomposition methods raise ``DecompositionUndefinedError``.
    """
    if op.has_decomposition:
        decomp = op.decomposition()
        try:
            compute_decomp = type(op).compute_decomposition(
                *op.data, wires=op.wires, **op.hyperparameters
            )
        except (qml.operation.DecompositionUndefinedError, TypeError):
            # sometimes decomposition is defined but not compute_decomposition
            # Also sometimes compute_decomposition can have a different signature
            compute_decomp = decomp
        with qml.queuing.AnnotatedQueue() as queued_decomp:
            op.decomposition()
        processed_queue = qml.tape.QuantumTape.from_queue(queued_decomp)

        assert isinstance(decomp, list), "decomposition must be a list"
        assert isinstance(compute_decomp, list), "decomposition must be a list"
        assert op.__class__ not in [
            decomp_op.__class__ for decomp_op in decomp
        ], "an operator should not be included in its own decomposition"

        for o1, o2, o3 in zip(decomp, compute_decomp, processed_queue):
            assert o1 == o2, "decomposition must match compute_decomposition"
            assert o1 == o3, "decomposition must match queued operations"
            assert isinstance(o1, qml.operation.Operator), "decomposition must contain operators"

        # Wires are relabelled with single lowercase letters, so the mapping check is only
        # possible for operators acting on at most 26 wires; beyond that, building the map
        # would raise IndexError.
        if skip_wire_mapping or len(op.wires) > len(ascii_lowercase):
            return

        # Check that mapping wires transitions to the decomposition
        wire_map = _letter_wire_map(op.wires)
        mapped_op = op.map_wires(wire_map)
        # calling `map_wires` on a Controlled operator generates a new `op` from the controls and
        # base, so may return a different class of operator. We only compare decomps of `op` and
        # `mapped_op` if `mapped_op` **has** a decomposition.
        # see MultiControlledX([0, 1]) and CNOT([0, 1]) as an example
        if mapped_op.has_decomposition:
            mapped_decomp = mapped_op.decomposition()
            orig_decomp = op.decomposition()
            for mapped_decomp_op, orig_decomp_op in zip(mapped_decomp, orig_decomp):
                assert (
                    mapped_decomp_op.wires == qml.map_wires(orig_decomp_op, wire_map).wires
                ), "Operators in decomposition of wire-mapped operator must have mapped wires."
    else:
        failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
        _assert_error_raised(
            op.decomposition,
            qml.operation.DecompositionUndefinedError,
            failure_comment=failure_comment,
        )()
        _assert_error_raised(
            op.compute_decomposition,
            qml.operation.DecompositionUndefinedError,
            failure_comment=failure_comment,
        )(*op.data, wires=op.wires, **op.hyperparameters)


def _check_matrix(op):
    """Check that if the operation says it has a matrix, it does. Otherwise a ``MatrixUndefinedError`` should be raised.

    A defined matrix must be ``TensorLike`` with shape ``(2**n, 2**n)`` for ``n`` wires.
    """
    if op.has_matrix:
        mat = op.matrix()
        assert isinstance(mat, qml.typing.TensorLike), "matrix must be a TensorLike"
        l = 2 ** len(op.wires)
        failure_comment = f"matrix must be two dimensional with shape ({l}, {l})"
        assert qml.math.shape(mat) == (l, l), failure_comment
    else:
        failure_comment = (
            "If has_matrix is False, the matrix method must raise a ``MatrixUndefinedError``."
        )
        _assert_error_raised(
            op.matrix, qml.operation.MatrixUndefinedError, failure_comment=failure_comment
        )()


def _check_sparse_matrix(op):
    """Check that if the operation says it has a sparse matrix, it does. Otherwise a ``SparseMatrixUndefinedError`` should be raised.

    A defined sparse matrix must default to CSR format, have shape ``(2**n, 2**n)`` for
    ``n`` wires, and honour the ``format`` argument for ``"csc"``, ``"lil"`` and ``"coo"``.
    """
    if op.has_sparse_matrix:
        mat = op.sparse_matrix()
        assert isinstance(mat, scipy.sparse.csr_matrix), "sparse matrix should default to csr format"
        l = 2 ** len(op.wires)
        failure_comment = f"matrix must be two dimensional with shape ({l}, {l})"
        assert qml.math.shape(mat) == (l, l), failure_comment

        assert isinstance(
            op.sparse_matrix(format="csc"), scipy.sparse.csc_matrix
        ), "sparse matrix should be formatted as csc"
        assert isinstance(
            op.sparse_matrix(format="lil"), scipy.sparse.lil_matrix
        ), "sparse matrix should be formatted as lil"
        assert isinstance(
            op.sparse_matrix(format="coo"), scipy.sparse.coo_matrix
        ), "sparse matrix should be formatted as coo"
    else:
        failure_comment = "If has_sparse_matrix is False, the sparse_matrix method must raise a ``SparseMatrixUndefinedError``."
        _assert_error_raised(
            op.sparse_matrix,
            qml.operation.SparseMatrixUndefinedError,
            failure_comment=failure_comment,
        )()


def _check_matrix_matches_decomp(op):
    """Check that if both the matrix and decomposition are defined, they match."""
    if op.has_matrix and op.has_decomposition:
        mat = op.matrix()
        decomp_mat = qml.matrix(qml.tape.QuantumScript(op.decomposition()), wire_order=op.wires)
        failure_comment = (
            f"matrix and matrix from decomposition must match. Got \n{mat}\n\n{decomp_mat}"
        )
        assert qml.math.allclose(mat, decomp_mat), failure_comment


def _check_eigendecomposition(op):
    """Checks involving diagonalizing gates and eigenvalues.

    Checks that ``diagonalizing_gates`` agrees with ``compute_diagonalizing_gates`` (or raises
    ``DiagGatesUndefinedError`` when ``has_diagonalizing_gates`` is False), that ``eigvals``
    agrees with ``compute_eigvals`` when the latter is defined, and that the eigenvalues plus
    diagonalizing gates reproduce the operator's matrix.
    """
    if op.has_diagonalizing_gates:
        dg = op.diagonalizing_gates()
        try:
            compute_dg = type(op).compute_diagonalizing_gates(
                *op.data, wires=op.wires, **op.hyperparameters
            )
        except (qml.operation.DiagGatesUndefinedError, TypeError):
            # sometimes diagonalizing gates is defined but not compute_diagonalizing_gates
            # compute_diagonalizing_gates might also have a different call signature
            compute_dg = dg

        for op1, op2 in zip(dg, compute_dg):
            assert op1 == op2, "diagonalizing_gates and compute_diagonalizing_gates must match"
    else:
        failure_comment = "If has_diagonalizing_gates is False, diagonalizing_gates must raise a DiagGatesUndefinedError"
        _assert_error_raised(
            op.diagonalizing_gates, qml.operation.DiagGatesUndefinedError, failure_comment
        )()

    try:
        eg = op.eigvals()
    except EigvalsUndefinedError:
        eg = None

    has_eigvals = True
    try:
        compute_eg = type(op).compute_eigvals(*op.data, **op.hyperparameters)
    except EigvalsUndefinedError:
        compute_eg = eg
        has_eigvals = False

    if has_eigvals:
        assert qml.math.allclose(eg, compute_eg), "eigvals and compute_eigvals must match"

    if has_eigvals and op.has_diagonalizing_gates:
        dg = qml.prod(*dg[::-1]) if len(dg) > 0 else qml.Identity(op.wires)
        eg = qml.QubitUnitary(np.diag(eg), wires=op.wires)
        decomp = qml.prod(qml.adjoint(dg), eg, dg)
        # The product's wire order follows the order of its factors, which can be a
        # permutation of op.wires; compare both matrices in the same (op.wires) order.
        decomp_mat = qml.matrix(decomp, wire_order=op.wires)
        original_mat = qml.matrix(op, wire_order=op.wires)
        failure_comment = f"eigenvalues and diagonalizing gates must be able to reproduce the original operator. Got \n{decomp_mat}\n\n{original_mat}"
        assert qml.math.allclose(decomp_mat, original_mat), failure_comment


def _check_generator(op):
    """Checks that if an operator's has_generator property is True, it has a generator.

    A defined generator must be an ``Operator`` such that ``exp(1j * op.data[0] * generator)``
    reproduces the matrix of ``op``. Otherwise ``generator`` must raise
    ``GeneratorUndefinedError``.
    """
    if op.has_generator:
        gen = op.generator()
        assert isinstance(gen, qml.operation.Operator), "generator must be an Operator"
        new_op = qml.exp(gen, 1j * op.data[0])
        assert qml.math.allclose(
            qml.matrix(op, wire_order=op.wires), qml.matrix(new_op, wire_order=op.wires)
        ), "exponentiating the generator must reproduce the operator"
    else:
        failure_comment = (
            "If has_generator is False, the generator method must raise a ``GeneratorUndefinedError``."
        )
        _assert_error_raised(
            op.generator, qml.operation.GeneratorUndefinedError, failure_comment=failure_comment
        )()


def _check_copy(op, skip_deepcopy):
    """Check that copies and deep copies give identical objects."""
    copied_op = copy.copy(op)
    assert qml.equal(copied_op, op), "copied op must be equal with qml.equal"
    assert copied_op == op, "copied op must be equivalent to original operation"
    assert copied_op is not op, "copied op must be a separate instance from original operation"
    if not skip_deepcopy:
        assert qml.equal(copy.deepcopy(op), op), "deep copied op must also be equal"


# pylint: disable=import-outside-toplevel, protected-access
def _check_pytree(op):
    """Check that the operator is a pytree.

    ``_flatten`` must return hashable metadata, ``_unflatten`` must rebuild an equal operator,
    and, when jax is installed, the jax pytree round trip must reproduce the operator with its
    data as the leaves.
    """
    data, metadata = op._flatten()
    try:
        # Only the ability to hash matters; the hash value itself may legitimately be 0.
        hash(metadata)
    except Exception as e:
        raise AssertionError(
            f"metadata output from _flatten must be hashable. Got metadata {metadata}"
        ) from e
    try:
        new_op = type(op)._unflatten(data, metadata)
    except Exception as e:
        message = (
            f"{type(op).__name__}._unflatten must be able to reproduce the original operation from "
            f"{data} and {metadata}. You may need to override either the _unflatten or _flatten method. "
            f"\nFor local testing, try type(op)._unflatten(*op._flatten())"
        )
        raise AssertionError(message) from e
    assert op == new_op, "metadata and data must be able to reproduce the original operation"

    try:
        import jax
    except ImportError:
        return
    leaves, struct = jax.tree_util.tree_flatten(op)
    unflattened_op = jax.tree_util.tree_unflatten(struct, leaves)
    assert unflattened_op == op, f"op must be a valid pytree. Got {unflattened_op} instead of {op}."

    for d1, d2 in zip(op.data, leaves):
        assert qml.math.allclose(
            d1, d2
        ), f"data must be the terminal leaves of the pytree. Got {d1}, {d2}"


def _check_capture(op):
    """Check that the operator survives capture into jaxpr and evaluation back out.

    Skipped when jax is not installed or when any wire label is not an ``int``. Program capture
    is enabled only for the duration of the check. Any failure is re-raised as ``ValueError``.
    """
    try:
        import jax
    except ImportError:
        return

    if not all(isinstance(w, int) for w in op.wires):
        return

    qml.capture.enable()
    try:
        jaxpr = jax.make_jaxpr(lambda obj: obj)(op)
        data, _ = jax.tree_util.tree_flatten(op)
        new_op = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *data)[0]
        assert op == new_op
    except Exception as e:
        raise ValueError(
            (
                "The capture of the operation into jaxpr failed somehow."
                " This capture mechanism is currently experimental and not a core"
                " requirement, but will be necessary in the future."
                " Please see the capture module documentation for more information."
            )
        ) from e
    finally:
        qml.capture.disable()


def _check_pickle(op):
    """Check that an operation can be dumped and reloaded with pickle."""
    pickled = pickle.dumps(op)
    unpickled = pickle.loads(pickled)
    assert unpickled == op, "operation must be able to be pickled and unpickled"


# pylint: disable=no-member
def _check_bind_new_parameters(op):
    """Check that bind new parameters can create a new op with different data.

    Every data element is replaced by itself multiplied by 0.0, and the rebound operator must
    carry exactly that data.
    """
    new_data = [d * 0.0 for d in op.data]
    new_data_op = qml.ops.functions.bind_new_parameters(op, new_data)
    failure_comment = "bind_new_parameters must be able to update the operator with new data."
    for d1, d2 in zip(new_data_op.data, new_data):
        assert qml.math.allclose(d1, d2), failure_comment


def _check_differentiation(op):
    """Checks that the operator can be executed and differentiated correctly.

    Compares the Jacobian of ``probs`` on ``default.qubit`` computed with parameter-shift
    against the one computed with backpropagation. Operators without parameters are skipped.
    """
    if op.num_params == 0:
        return

    data, struct = qml.pytrees.flatten(op)

    def circuit(*args):
        qml.apply(qml.pytrees.unflatten(args, struct))
        return qml.probs(wires=op.wires)

    qnode_ref = qml.QNode(circuit, qml.device("default.qubit"), diff_method="backprop")
    qnode_ps = qml.QNode(circuit, qml.device("default.qubit"), diff_method="parameter-shift")

    # Integer leaves are passed through unchanged; everything else becomes a trainable array.
    params = [x if isinstance(x, int) else qml.numpy.array(x) for x in data]

    ps = qml.jacobian(qnode_ps)(*params)
    expected_bp = qml.jacobian(qnode_ref)(*params)

    error_msg = (
        "Parameter-shift does not produce the same Jacobian as with backpropagation. "
        "This might be a bug, or it might be expected due to the mathematical nature "
        "of backpropagation, in which case, this test can be skipped for this operator."
    )

    if isinstance(ps, tuple):
        for actual, expected in zip(ps, expected_bp):
            assert qml.math.allclose(actual, expected), error_msg
    else:
        assert qml.math.allclose(ps, expected_bp), error_msg


def _check_wires(op, skip_wire_mapping):
    """Check that wires are a ``Wires`` class and can be mapped.

    The mapping check relabels wires with single lowercase letters, so it is skipped when
    ``skip_wire_mapping`` is True or the operator acts on more than 26 wires.
    """
    assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"

    if skip_wire_mapping or len(op.wires) > len(ascii_lowercase):
        return
    wire_map = _letter_wire_map(op.wires)
    mapped_op = op.map_wires(wire_map)
    new_wires = qml.wires.Wires(list(ascii_lowercase[: len(op.wires)]))
    assert mapped_op.wires == new_wires, "wires must be mappable with map_wires"
def assert_valid(
    op: qml.operation.Operator,
    skip_deepcopy=False,
    skip_pickle=False,
    skip_wire_mapping=False,
    skip_differentiation=False,
) -> None:
    """Runs basic validation checks on an :class:`~.operation.Operator` to make
    sure it has been correctly defined.

    Args:
        op (.Operator): an operator instance to validate

    Keyword Args:
        skip_deepcopy=False: If ``True``, deepcopy tests are not run.
        skip_pickle=False : If ``True``, pickling tests are not run. Set to ``True`` when
            testing a locally defined operator, as pickle cannot handle local objects
        skip_wire_mapping=False: If ``True``, the checks that relabel the operator's wires
            with ``map_wires`` (for the operator itself and for its decomposition) are not run.
            These checks are always skipped for operators acting on more than 26 wires.
        skip_differentiation=False: If ``True``, differentiation tests are not run. Set to
            ``True`` when the operator is parametrized but not differentiable.

    Raises:
        AssertionError: if any of the validation checks fail.
        ValueError: if the experimental capture of the operator into jaxpr fails.

    **Examples:**

    .. code-block:: python

        class MyOp(qml.operation.Operator):

            def __init__(self, data, wires):
                self.data = data
                super().__init__(wires=wires)

        op = MyOp(qml.numpy.array(0.5), wires=0)

    .. code-block::

        >>> assert_valid(op)
        AssertionError: op.data must be a tuple

    .. code-block:: python

        class MyOp(qml.operation.Operator):

            def __init__(self, wires):
                self.hyperparameters["unhashable_list"] = []
                super().__init__(wires=wires)

        op = MyOp(wires = 0)
        assert_valid(op)

    .. code-block::

        AssertionError: metadata output from _flatten must be hashable. This also applies to hyperparameters
    """
    assert isinstance(op.data, tuple), "op.data must be a tuple"
    assert isinstance(op.parameters, list), "op.parameters must be a list"
    for d, p in zip(op.data, op.parameters):
        assert isinstance(d, qml.typing.TensorLike), "each data element must be tensorlike"
        assert qml.math.allclose(d, p), "data and parameters must match."

    # Wire-mapping checks relabel wires with single lowercase letters, so they can only be
    # run for operators acting on at most 26 wires.
    can_map_wires = len(op.wires) <= len(ascii_lowercase)
    if can_map_wires:
        _check_wires(op, skip_wire_mapping)
    _check_copy(op, skip_deepcopy)
    _check_pytree(op)
    if not skip_pickle:
        _check_pickle(op)
    _check_bind_new_parameters(op)

    # Forward the wire-count limit too; otherwise the decomposition check would try to
    # build a letter map for more than 26 wires and fail with IndexError.
    _check_decomposition(op, skip_wire_mapping or not can_map_wires)
    _check_matrix(op)
    _check_matrix_matches_decomp(op)
    _check_sparse_matrix(op)
    _check_eigendecomposition(op)
    _check_generator(op)
    if not skip_differentiation:
        _check_differentiation(op)
    _check_capture(op)