# Source code for pennylane.ops.functions.assert_valid
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.ops.functions.assert_valid function for determining whether or not an
Operator class is correctly defined.
"""
import copy
import pickle
from string import ascii_lowercase
import numpy as np
import scipy.sparse
import pennylane as qml
from pennylane.operation import EigvalsUndefinedError
def _assert_error_raised(func, error, failure_comment):
    """Build a callable asserting that ``func`` raises ``error``.

    Args:
        func (Callable): the callable that is expected to fail
        error (type): the exception class ``func`` must raise
        failure_comment (str): assertion message used when ``func`` returns normally

    Returns:
        Callable: a wrapper taking the same arguments as ``func``. It returns ``None`` when
        ``error`` is raised, fails an ``assert`` with ``failure_comment`` when ``func``
        returns normally, and lets any other exception propagate unchanged.
    """

    def _expect_failure(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except error:
            # the expected failure happened; nothing left to check
            return
        assert False, failure_comment

    return _expect_failure
def _check_decomposition(op, skip_wire_mapping):
    """Checks involving the decomposition.

    If ``op.has_decomposition`` is ``True``:

    * ``op.decomposition()`` and ``compute_decomposition`` must both return lists with the
      same number of element-wise equal operators (``compute_decomposition`` falls back to
      ``op.decomposition()`` when it raises ``DecompositionUndefinedError`` or ``TypeError``),
    * the operators queued while calling ``op.decomposition()`` must match the returned
      ones element-wise,
    * no operator in the decomposition may have the same class as ``op``, and
    * unless ``skip_wire_mapping`` is ``True``, the decomposition of the wire-mapped operator
      (if it has one) must act on the correspondingly mapped wires.

    If ``op.has_decomposition`` is ``False``, both ``op.decomposition()`` and
    ``op.compute_decomposition(...)`` must raise a ``DecompositionUndefinedError``.

    Args:
        op (.Operator): the operator to check
        skip_wire_mapping (bool): whether to skip the wire-mapping check. That check relabels
            the wires with ``string.ascii_lowercase`` and therefore supports at most 26 wires.

    Raises:
        AssertionError: if any of the checks fails
    """
    if op.has_decomposition:
        decomp = op.decomposition()
        try:
            compute_decomp = type(op).compute_decomposition(
                *op.data, wires=op.wires, **op.hyperparameters
            )
        except (qml.operation.DecompositionUndefinedError, TypeError):
            # sometimes decomposition is defined but not compute_decomposition
            # Also sometimes compute_decomposition can have a different signature
            compute_decomp = decomp
        # decompose again inside a queuing context to record what the decomposition queues
        with qml.queuing.AnnotatedQueue() as queued_decomp:
            op.decomposition()
        processed_queue = qml.tape.QuantumTape.from_queue(queued_decomp)

        assert isinstance(decomp, list), "decomposition must be a list"
        assert isinstance(compute_decomp, list), "decomposition must be a list"
        assert op.__class__ not in [
            decomp_op.__class__ for decomp_op in decomp
        ], "an operator should not be included in its own decomposition"
        # zip() below silently stops at the shorter list, so compare the lengths explicitly
        assert len(decomp) == len(
            compute_decomp
        ), "decomposition and compute_decomposition must contain the same number of operators"

        for o1, o2, o3 in zip(decomp, compute_decomp, processed_queue):
            assert o1 == o2, "decomposition must match compute_decomposition"
            assert o1 == o3, "decomposition must match queued operations"
            assert isinstance(o1, qml.operation.Operator), "decomposition must contain operators"

        if skip_wire_mapping:
            return
        # Check that mapping wires transitions to the decomposition
        wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
        mapped_op = op.map_wires(wire_map)
        # calling `map_wires` on a Controlled operator generates a new `op` from the controls and
        # base, so may return a different class of operator. We only compare decomps of `op` and
        # `mapped_op` if `mapped_op` **has** a decomposition.
        # see MultiControlledX([0, 1]) and CNOT([0, 1]) as an example
        if mapped_op.has_decomposition:
            mapped_decomp = mapped_op.decomposition()
            for mapped_decomp_op, orig_decomp_op in zip(mapped_decomp, decomp):
                assert (
                    mapped_decomp_op.wires == qml.map_wires(orig_decomp_op, wire_map).wires
                ), "Operators in decomposition of wire-mapped operator must have mapped wires."
    else:
        failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
        _assert_error_raised(
            op.decomposition,
            qml.operation.DecompositionUndefinedError,
            failure_comment=failure_comment,
        )()
        _assert_error_raised(
            op.compute_decomposition,
            qml.operation.DecompositionUndefinedError,
            failure_comment=failure_comment,
        )(*op.data, wires=op.wires, **op.hyperparameters)
def _check_matrix(op):
    """Check that ``op.has_matrix`` agrees with the behaviour of ``op.matrix()``.

    When ``op.has_matrix`` is ``True`` the matrix must be a ``TensorLike`` whose shape is
    ``(2**n, 2**n)`` with ``n = len(op.wires)``. Otherwise calling ``op.matrix()`` must raise a
    ``MatrixUndefinedError``.
    """
    if not op.has_matrix:
        _assert_error_raised(
            op.matrix,
            qml.operation.MatrixUndefinedError,
            failure_comment="If has_matrix is False, the matrix method must raise a ``MatrixUndefinedError``.",
        )()
        return

    matrix = op.matrix()
    assert isinstance(matrix, qml.typing.TensorLike), "matrix must be a TensorLike"
    dim = 2 ** len(op.wires)
    assert qml.math.shape(matrix) == (
        dim,
        dim,
    ), f"matrix must be two dimensional with shape ({dim}, {dim})"
def _check_sparse_matrix(op):
    """Check that if the operation says it has a sparse matrix, it does. Otherwise a
    ``SparseMatrixUndefinedError`` should be raised.

    Args:
        op (.Operator): the operator to check

    Raises:
        AssertionError: if ``op.has_sparse_matrix`` is ``True`` but ``op.sparse_matrix()`` is
            not a ``scipy.sparse.csr_matrix`` of shape ``(2**n, 2**n)`` with
            ``n = len(op.wires)``, or if ``op.has_sparse_matrix`` is ``False`` but
            ``op.sparse_matrix()`` does not raise a ``SparseMatrixUndefinedError``.
    """
    if op.has_sparse_matrix:
        mat = op.sparse_matrix()
        assert isinstance(
            mat, scipy.sparse.csr_matrix
        ), "sparse matrix must be a scipy.sparse.csr_matrix"
        dim = 2 ** len(op.wires)
        failure_comment = f"sparse matrix must be two dimensional with shape ({dim}, {dim})"
        # ``mat`` is known to be a scipy sparse matrix here, so its own ``shape`` suffices
        assert mat.shape == (dim, dim), failure_comment
    else:
        failure_comment = "If has_sparse_matrix is False, the sparse_matrix method must raise a ``SparseMatrixUndefinedError``."
        _assert_error_raised(
            op.sparse_matrix,
            qml.operation.SparseMatrixUndefinedError,
            failure_comment=failure_comment,
        )()
def _check_matrix_matches_decomp(op):
    """Check that the matrix and the decomposition of ``op`` agree whenever both exist.

    The decomposition is wrapped in a ``QuantumScript`` and its matrix, computed in the wire
    order ``op.wires``, is compared to ``op.matrix()`` with ``allclose``.
    """
    if not (op.has_matrix and op.has_decomposition):
        return

    expected = op.matrix()
    from_decomp = qml.matrix(qml.tape.QuantumScript(op.decomposition()), wire_order=op.wires)
    assert qml.math.allclose(
        expected, from_decomp
    ), f"matrix and matrix from decomposition must match. Got \n{expected}\n\n {from_decomp}"
def _check_eigendecomposition(op):
    """Checks involving diagonalizing gates and eigenvalues.

    * If ``op.has_diagonalizing_gates`` is ``True``, ``op.diagonalizing_gates()`` must match
      ``compute_diagonalizing_gates`` in length and element-wise. The comparison is skipped
      when the latter raises ``DiagGatesUndefinedError`` or ``TypeError``. Otherwise
      ``op.diagonalizing_gates()`` must raise a ``DiagGatesUndefinedError``.
    * If ``compute_eigvals`` is defined, ``op.eigvals()`` must be ``allclose`` to it.
    * If eigenvalues and diagonalizing gates are both available, the product
      ``adjoint(U) @ diag(eigvals) @ U`` must reproduce the matrix of ``op``. Here ``U`` is
      the product of the diagonalizing gates, and both matrices are taken in the wire order
      ``op.wires``.

    Raises:
        AssertionError: if any of the checks fails
    """
    if op.has_diagonalizing_gates:
        diag_gates = op.diagonalizing_gates()
        try:
            compute_diag_gates = type(op).compute_diagonalizing_gates(
                *op.data, wires=op.wires, **op.hyperparameters
            )
        except (qml.operation.DiagGatesUndefinedError, TypeError):
            # sometimes diagonalizing gates is defined but not compute_diagonalizing_gates
            # compute_diagonalizing_gates might also have a different call signature
            compute_diag_gates = diag_gates
        # zip() below silently stops at the shorter list, so compare the lengths explicitly
        assert len(diag_gates) == len(
            compute_diag_gates
        ), "diagonalizing_gates and compute_diagonalizing_gates must match"
        for op1, op2 in zip(diag_gates, compute_diag_gates):
            assert op1 == op2, "diagonalizing_gates and compute_diagonalizing_gates must match"
    else:
        failure_comment = "If has_diagonalizing_gates is False, diagonalizing_gates must raise a DiagGatesUndefinedError"
        _assert_error_raised(
            op.diagonalizing_gates, qml.operation.DiagGatesUndefinedError, failure_comment
        )()

    try:
        eigvals = op.eigvals()
    except EigvalsUndefinedError:
        eigvals = None

    has_eigvals = True
    try:
        compute_eigvals = type(op).compute_eigvals(*op.data, **op.hyperparameters)
    except EigvalsUndefinedError:
        compute_eigvals = eigvals
        has_eigvals = False

    if has_eigvals:
        assert qml.math.allclose(eigvals, compute_eigvals), "eigvals and compute_eigvals must match"

    if has_eigvals and op.has_diagonalizing_gates:
        diagonalizer = (
            qml.prod(*diag_gates[::-1]) if len(diag_gates) > 0 else qml.Identity(op.wires)
        )
        eigval_op = qml.QubitUnitary(np.diag(eigvals), wires=op.wires)
        reconstructed = qml.prod(qml.adjoint(diagonalizer), eigval_op, diagonalizer)
        # The wire order of the product follows its factors (starting with the diagonalizing
        # gates), which need not equal ``op.wires``; pin both matrices to ``op.wires``.
        decomp_mat = qml.matrix(reconstructed, wire_order=op.wires)
        original_mat = qml.matrix(op, wire_order=op.wires)
        failure_comment = f"eigenvalues and diagonalizing gates must be able to reproduce the original operator. Got \n{decomp_mat}\n\n{original_mat}"
        assert qml.math.allclose(decomp_mat, original_mat), failure_comment
def _check_generator(op):
    """Checks that if an operator's has_generator property is True, it has a generator.

    When ``op.has_generator`` is ``True``, ``op.generator()`` must return an
    :class:`~.Operator` ``G`` such that ``qml.exp(G, 1j * op.data[0])`` has the same matrix
    as ``op`` (both taken in the wire order ``op.wires``). The first data element is therefore
    used as the exponent coefficient. Otherwise ``op.generator()`` must raise a
    ``GeneratorUndefinedError``.

    Raises:
        AssertionError: if any of the checks fails
    """
    if op.has_generator:
        gen = op.generator()
        assert isinstance(gen, qml.operation.Operator), "generator must be an Operator"
        new_op = qml.exp(gen, 1j * op.data[0])
        assert qml.math.allclose(
            qml.matrix(op, wire_order=op.wires), qml.matrix(new_op, wire_order=op.wires)
        ), "exponentiating the generator must reproduce the matrix of the operator"
    else:
        failure_comment = (
            "If has_generator is False, the generator method must raise a ``GeneratorUndefinedError``."
        )
        _assert_error_raised(
            op.generator, qml.operation.GeneratorUndefinedError, failure_comment=failure_comment
        )()
def _check_copy(op):
    """Check that copies and deep copies give identical objects.

    A shallow copy must compare equal to ``op`` both with ``qml.equal`` and with ``==`` while
    being a distinct instance. A deep copy must compare equal with ``qml.equal``.

    Raises:
        AssertionError: if any of the checks fails
    """
    copied_op = copy.copy(op)
    assert qml.equal(copied_op, op), "copied op must be equal with qml.equal"
    assert copied_op == op, "copied op must be equivalent to original operation"
    assert copied_op is not op, "copied op must be a separate instance from original operation"
    assert qml.equal(copy.deepcopy(op), op), "deep copied op must also be equal"
# pylint: disable=import-outside-toplevel, protected-access
def _check_pytree(op):
    """Check that the operator is a pytree.

    Verifies that:

    * the metadata returned by ``op._flatten()`` is hashable,
    * ``type(op)._unflatten(data, metadata)`` reproduces ``op``, and
    * if ``jax`` can be imported, flattening and unflattening ``op`` with ``jax.tree_util``
      round-trips, and ``op.data`` matches the resulting leaves element-wise.

    Raises:
        AssertionError: if any of the checks fails
    """
    data, metadata = op._flatten()
    try:
        # Only hashability matters. The hash value itself may legitimately be 0
        # (e.g. hash(0) == 0), so it must not be used as a truth value.
        hash(metadata)
    except Exception as e:
        raise AssertionError(
            f"metadata output from _flatten must be hashable. Got metadata {metadata}"
        ) from e
    try:
        new_op = type(op)._unflatten(data, metadata)
    except Exception as e:
        message = (
            f"{type(op).__name__}._unflatten must be able to reproduce the original operation from "
            f"{data} and {metadata}. You may need to override either the _unflatten or _flatten method. "
            f"\nFor local testing, try type(op)._unflatten(*op._flatten())"
        )
        raise AssertionError(message) from e
    assert op == new_op, "metadata and data must be able to reproduce the original operation"

    try:
        import jax
    except ImportError:
        return
    leaves, struct = jax.tree_util.tree_flatten(op)
    unflattened_op = jax.tree_util.tree_unflatten(struct, leaves)
    assert unflattened_op == op, f"op must be a valid pytree. Got {unflattened_op} instead of {op}."
    for d1, d2 in zip(op.data, leaves):
        assert qml.math.allclose(
            d1, d2
        ), f"data must be the terminal leaves of the pytree. Got {d1}, {d2}"
def _check_capture(op):
    """Check that ``op`` survives a round trip through a jaxpr built with program capture.

    The check is skipped (returning ``None``) when ``jax`` cannot be imported or when any
    element of ``op.wires`` is not an ``int``. Otherwise program capture is enabled and ``op``
    is traced through an identity function with ``jax.make_jaxpr``. The jaxpr is then
    evaluated on the pytree leaves of ``op``, and the result must compare equal to ``op``.
    Afterwards the capture switch is returned to the state it had before the call.

    Raises:
        ValueError: if tracing, evaluating or comparing the captured operator fails
    """
    try:
        import jax
    except ImportError:
        return

    if not all(isinstance(w, int) for w in op.wires):
        return

    # Remember the caller's setting so this check does not switch capture off for a caller
    # that had deliberately enabled it.
    capture_was_enabled = qml.capture.enabled()
    qml.capture.enable()
    try:
        jaxpr = jax.make_jaxpr(lambda obj: obj)(op)
        data, _ = jax.tree_util.tree_flatten(op)
        new_op = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *data)[0]
        assert op == new_op
    except Exception as e:
        raise ValueError(
            (
                "The capture of the operation into jaxpr failed somehow."
                " This capture mechanism is currently experimental and not a core"
                " requirement, but will be necessary in the future."
                " Please see the capture module documentation for more information."
            )
        ) from e
    finally:
        if not capture_was_enabled:
            qml.capture.disable()
def _check_pickle(op):
    """Round-trip ``op`` through :mod:`pickle`; the restored object must compare equal to ``op``."""
    restored = pickle.loads(pickle.dumps(op))
    assert restored == op, "operation must be able to be pickled and unpickled"
# pylint: disable=no-member
def _check_bind_new_parameters(op):
    """Check that ``bind_new_parameters`` can create a new op with different data.

    Each element of ``op.data`` is multiplied by ``0.0``. The operator rebuilt by
    ``qml.ops.functions.bind_new_parameters`` must carry data that is ``allclose``,
    element by element, to these new values.
    """
    replacement = [value * 0.0 for value in op.data]
    rebound = qml.ops.functions.bind_new_parameters(op, replacement)
    for actual, expected in zip(rebound.data, replacement):
        assert qml.math.allclose(
            actual, expected
        ), "bind_new_parameters must be able to update the operator with new data."
def _check_differentiation(op):
    """Checks that the operator can be executed and differentiated correctly.

    Operators without parameters are skipped. Otherwise the operator is rebuilt from its
    flattened pytree leaves inside a circuit returning ``qml.probs`` on ``op.wires``. The
    Jacobian computed with ``parameter-shift`` must match the one computed with ``backprop``,
    both on ``default.qubit``. Integer leaves are passed through unchanged; every other leaf
    is wrapped in ``qml.numpy.array``.
    """
    if op.num_params == 0:
        return

    leaves, structure = qml.pytrees.flatten(op)

    def circuit(*args):
        qml.apply(qml.pytrees.unflatten(args, structure))
        return qml.probs(wires=op.wires)

    backprop_qnode = qml.QNode(circuit, qml.device("default.qubit"), diff_method="backprop")
    shift_qnode = qml.QNode(circuit, qml.device("default.qubit"), diff_method="parameter-shift")

    params = [leaf if isinstance(leaf, int) else qml.numpy.array(leaf) for leaf in leaves]
    jac_shift = qml.jacobian(shift_qnode)(*params)
    jac_backprop = qml.jacobian(backprop_qnode)(*params)

    error_msg = (
        "Parameter-shift does not produce the same Jacobian as with backpropagation. "
        "This might be a bug, or it might be expected due to the mathematical nature "
        "of backpropagation, in which case, this test can be skipped for this operator."
    )
    # multiple trainable arguments yield one Jacobian per argument; compare them pairwise
    if isinstance(jac_shift, tuple):
        pairs = zip(jac_shift, jac_backprop)
    else:
        pairs = [(jac_shift, jac_backprop)]
    for actual, expected in pairs:
        assert qml.math.allclose(actual, expected), error_msg
def _check_wires(op, skip_wire_mapping):
    """Check that wires are a ``Wires`` class and can be mapped.

    Unless ``skip_wire_mapping`` is ``True``, the i-th wire is relabelled to the i-th lowercase
    ASCII letter via ``op.map_wires``. The mapped operator must then carry exactly those
    letters, in order.
    """
    assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"
    if skip_wire_mapping:
        return

    relabel = {}
    for position, wire in enumerate(op.wires):
        relabel[wire] = ascii_lowercase[position]
    mapped_op = op.map_wires(relabel)

    expected_wires = qml.wires.Wires(list(ascii_lowercase[: len(op.wires)]))
    assert mapped_op.wires == expected_wires, "wires must be mappable with map_wires"
def assert_valid(
    op: qml.operation.Operator,
    skip_pickle=False,
    skip_wire_mapping=False,
    skip_differentiation=False,
) -> None:
    """Runs basic validation checks on an :class:`~.operation.Operator` to make
    sure it has been correctly defined.

    Args:
        op (.Operator): an operator instance to validate

    Keyword Args:
        skip_pickle (bool): If ``True``, pickling tests are not run. Set to ``True`` when
            testing a locally defined operator, as pickle cannot handle local objects.
        skip_wire_mapping (bool): If ``True``, the checks that relabel the wires with
            ``map_wires`` (on the operator itself and on its decomposition) are not run.
            These checks are always skipped for operators acting on more than 26 wires.
        skip_differentiation (bool): If ``True``, differentiation tests are not run. Set to
            ``True`` when the operator is parametrized but not differentiable.

    Raises:
        AssertionError: if one of the validation checks fails
        ValueError: if ``jax`` is installed and capturing the operator into a jaxpr fails

    **Examples:**

    .. code-block:: python

        class MyOp(qml.operation.Operator):

            def __init__(self, data, wires):
                self.data = data
                super().__init__(wires=wires)

        op = MyOp(qml.numpy.array(0.5), wires=0)

    .. code-block::

        >>> assert_valid(op)
        AssertionError: op.data must be a tuple

    .. code-block:: python

        class MyOp(qml.operation.Operator):

            def __init__(self, wires):
                self.hyperparameters["unhashable_list"] = []
                super().__init__(wires=wires)

        op = MyOp(wires = 0)
        assert_valid(op)

    .. code-block::

        AssertionError: metadata output from _flatten must be hashable. Got metadata ...

    The metadata requirement also applies to hyperparameters.
    """
    assert isinstance(op.data, tuple), "op.data must be a tuple"
    assert isinstance(op.parameters, list), "op.parameters must be a list"
    for d, p in zip(op.data, op.parameters):
        assert isinstance(d, qml.typing.TensorLike), "each data element must be tensorlike"
        assert qml.math.allclose(d, p), "data and parameters must match."

    # The wire-mapping checks (in _check_wires and _check_decomposition) relabel wires with
    # single lowercase letters, so they cannot run on more wires than there are letters.
    skip_mapping = skip_wire_mapping or len(op.wires) > len(ascii_lowercase)

    _check_wires(op, skip_mapping)
    _check_copy(op)
    _check_pytree(op)
    if not skip_pickle:
        _check_pickle(op)
    _check_bind_new_parameters(op)

    _check_decomposition(op, skip_mapping)
    _check_matrix(op)
    _check_matrix_matches_decomp(op)
    _check_sparse_matrix(op)
    _check_eigendecomposition(op)
    _check_generator(op)
    if not skip_differentiation:
        _check_differentiation(op)
    _check_capture(op)
# _modules/pennylane/ops/functions/assert_valid
# Download Python script
# Download Notebook
# View on GitHub