# Source code for pennylane.ops.functions.equal
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.equal function.
"""
# pylint: disable=too-many-arguments,too-many-return-statements,too-many-branches
from collections.abc import Iterable
from functools import singledispatch
from typing import Union
import pennylane as qml
from pennylane.measurements import MeasurementProcess
from pennylane.measurements.classical_shadow import ShadowExpvalMP
from pennylane.measurements.counts import CountsMP
from pennylane.measurements.mid_measure import MeasurementValue, MidMeasureMP
from pennylane.measurements.mutual_info import MutualInfoMP
from pennylane.measurements.vn_entropy import VnEntropyMP
from pennylane.operation import Observable, Operator
from pennylane.ops import Adjoint, CompositeOp, Conditional, Controlled, Exp, Pow, SProd
from pennylane.pauli import PauliSentence, PauliWord
from pennylane.pulse.parametrized_evolution import ParametrizedEvolution
from pennylane.tape import QuantumScript
from pennylane.templates.subroutines import ControlledSequence, PrepSelPrep
# Prefix prepended to a nested mismatch message when two composite operators
# (``Prod``/``Sum``) differ in one of their operands.
OPERANDS_MISMATCH_ERROR_MESSAGE = "op1 and op2 have different operands because "
# Prefix prepended to a nested mismatch message when two wrapper operators
# (``Controlled``, ``ControlledSequence``, ``Pow``, ``Adjoint``, ``Exp``, ``SProd``)
# differ in their base operation.
BASE_OPERATION_MISMATCH_ERROR_MESSAGE = "op1 and op2 have different base operations because "
def equal(
    op1: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
    op2: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
) -> bool:
    r"""Function for determining operator, measurement, and tape equality.

    .. Warning::

        The ``qml.equal`` function is based on a comparison of the types and attributes of the
        measurements or operators, not their mathematical representations. Mathematically equivalent
        operators defined via different classes may return False when compared via ``qml.equal``.
        To be more thorough would require the matrix forms to be calculated, which may drastically
        increase runtime.

    .. Warning::

        The interfaces and trainability of data within some observables including ``Prod`` and
        ``Sum`` are sometimes ignored, regardless of what the user specifies for ``check_interface``
        and ``check_trainability``.

    Args:
        op1 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): First object to compare
        op2 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): Second object to compare
        check_interface (bool, optional): Whether to compare interfaces. Default: ``True``.
        check_trainability (bool, optional): Whether to compare trainability status. Default: ``True``.
        rtol (float, optional): Relative tolerance for parameters. Default: ``1e-5``.
        atol (float, optional): Absolute tolerance for parameters. Default: ``1e-9``.

    Returns:
        bool: ``True`` if the operators, measurement processes, or tapes are equal, else ``False``

    .. seealso:: :func:`~.assert_equal`, which raises an ``AssertionError`` describing the
        mismatch instead of returning ``False``.

    **Example**

    Given two operators or measurement processes, ``qml.equal`` determines their equality.

    >>> op1 = qml.RX(np.array(.12), wires=0)
    >>> op2 = qml.RY(np.array(1.23), wires=0)
    >>> qml.equal(op1, op1), qml.equal(op1, op2)
    (True, False)

    >>> prod1 = qml.X(0) @ qml.Y(1)
    >>> prod2 = qml.Y(1) @ qml.X(0)
    >>> prod3 = qml.X(1) @ qml.Y(0)
    >>> qml.equal(prod1, prod2), qml.equal(prod1, prod3)
    (True, False)

    >>> prod = qml.X(0) @ qml.Y(1)
    >>> ham = qml.Hamiltonian([1], [qml.X(0) @ qml.Y(1)])
    >>> qml.equal(prod, ham)
    True

    >>> H1 = qml.Hamiltonian([0.5, 0.5], [qml.Z(0) @ qml.Y(1), qml.Y(1) @ qml.Z(0) @ qml.Identity("a")])
    >>> H2 = qml.Hamiltonian([1], [qml.Z(0) @ qml.Y(1)])
    >>> H3 = qml.Hamiltonian([2], [qml.Z(0) @ qml.Y(1)])
    >>> qml.equal(H1, H2), qml.equal(H1, H3)
    (True, False)

    >>> qml.equal(qml.expval(qml.X(0)), qml.expval(qml.X(0)))
    True
    >>> qml.equal(qml.probs(wires=(0,1)), qml.probs(wires=(1,2)))
    False
    >>> qml.equal(qml.classical_shadow(wires=[0,1]), qml.classical_shadow(wires=[0,1]))
    True

    >>> tape1 = qml.tape.QuantumScript([qml.RX(1.2, wires=0)], [qml.expval(qml.Z(0))])
    >>> tape2 = qml.tape.QuantumScript([qml.RX(1.2 + 1e-6, wires=0)], [qml.expval(qml.Z(0))])
    >>> qml.equal(tape1, tape2, rtol=0, atol=1e-7)
    False
    >>> qml.equal(tape1, tape2, rtol=0, atol=1e-5)
    True

    .. details::
        :title: Usage Details

        You can use the optional arguments to get more specific results:

        >>> op1 = qml.RX(torch.tensor(1.2), wires=0)
        >>> op2 = qml.RX(jax.numpy.array(1.2), wires=0)
        >>> qml.equal(op1, op2)
        False

        >>> qml.equal(op1, op2, check_interface=False, check_trainability=False)
        True

        >>> op3 = qml.RX(np.array(1.2, requires_grad=True), wires=0)
        >>> op4 = qml.RX(np.array(1.2, requires_grad=False), wires=0)
        >>> qml.equal(op3, op4)
        False

        >>> qml.equal(op3, op4, check_trainability=False)
        True

        >>> qml.equal(Controlled(op3, control_wires=1), Controlled(op4, control_wires=1))
        False

        >>> qml.equal(Controlled(op3, control_wires=1), Controlled(op4, control_wires=1), check_trainability=False)
        True
    """
    dispatch_result = _equal(
        op1,
        op2,
        check_interface=check_interface,
        check_trainability=check_trainability,
        atol=atol,
        rtol=rtol,
    )
    # ``_equal`` reports a mismatch either as ``False`` or as a string explaining it;
    # ``equal`` only exposes the boolean outcome.
    if isinstance(dispatch_result, str):
        return False
    return dispatch_result
def assert_equal(
    op1: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
    op2: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
) -> None:
    """Function to assert that two operators, measurements, tapes, Pauli words, or Pauli
    sentences are equal.

    Args:
        op1 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): First object to compare
        op2 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): Second object to compare
        check_interface (bool, optional): Whether to compare interfaces. Default: ``True``.
        check_trainability (bool, optional): Whether to compare trainability status. Default: ``True``.
        rtol (float, optional): Relative tolerance for parameters. Default: ``1e-5``.
        atol (float, optional): Absolute tolerance for parameters. Default: ``1e-9``.

    Returns:
        None

    Raises:
        AssertionError: An ``AssertionError`` is raised if the two objects are not equal. Its
            message describes the mismatch when the underlying comparison provides one.

    .. seealso:: :func:`~.equal`

    **Example**

    >>> op1 = qml.RX(np.array(0.12), wires=0)
    >>> op2 = qml.RX(np.array(1.23), wires=0)
    >>> qml.assert_equal(op1, op2)
    AssertionError: op1 and op2 have different data.
    Got (array(0.12),) and (array(1.23),)

    >>> h1 = qml.Hamiltonian([1, 2], [qml.PauliX(0), qml.PauliY(1)])
    >>> h2 = qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliY(1)])
    >>> qml.assert_equal(h1, h2)
    AssertionError: op1 and op2 have different operands because op1 and op2 have different scalars. Got 2 and 1
    """
    dispatch_result = _equal(
        op1,
        op2,
        check_interface=check_interface,
        check_trainability=check_trainability,
        atol=atol,
        rtol=rtol,
    )
    # A string result already explains the mismatch; surface it verbatim.
    if isinstance(dispatch_result, str):
        raise AssertionError(dispatch_result)
    # Some comparisons only return ``False`` without an explanation.
    if not dispatch_result:
        raise AssertionError(f"{op1} and {op2} are not equal for an unspecified reason.")
def _equal(
    op1,
    op2,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
) -> Union[bool, str]:
    """Check type compatibility, then hand the pair to the matching ``_equal_dispatch`` overload.

    Returns:
        bool or str: the dispatched result (truthy when equal), or a string describing why
        the two objects differ.
    """
    # ``Observable`` instances are let through even when the types differ; the dispatched
    # comparison (e.g. ``_equal_operators``) then applies its own type check.
    compatible = isinstance(op2, type(op1)) or isinstance(op1, Observable)
    if not compatible:
        return f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}."

    options = {
        "check_interface": check_interface,
        "check_trainability": check_trainability,
        "atol": atol,
        "rtol": rtol,
    }
    return _equal_dispatch(op1, op2, **options)
@singledispatch
def _equal_dispatch(
op1,
op2,
check_interface=True,
check_trainability=True,
rtol=1e-5,
atol=1e-9,
) -> Union[bool, str]: # pylint: disable=unused-argument
raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented")
@_equal_dispatch.register
def _equal_circuit(
    op1: qml.tape.QuantumScript,
    op2: qml.tape.QuantumScript,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
):
    """Compare two quantum scripts (tapes).

    Operations, and then measurements, are compared pairwise and in order with
    :func:`qml.equal`. The number of each must match, as must the shots and the trainable
    parameter indices.

    Returns:
        bool: ``True`` if every check passes, else ``False``.
    """
    options = {
        "check_interface": check_interface,
        "check_trainability": check_trainability,
        "rtol": rtol,
        "atol": atol,
    }

    # Operations first, then measurements; each pair of sequences must agree in length
    # before its elements are compared.
    for items1, items2 in (
        (op1.operations, op2.operations),
        (op1.measurements, op2.measurements),
    ):
        if len(items1) != len(items2):
            return False
        if not all(qml.equal(first, second, **options) for first, second in zip(items1, items2)):
            return False

    if op1.shots != op2.shots:
        return False
    if op1.trainable_params != op2.trainable_params:
        return False
    return True
@_equal_dispatch.register
def _equal_operators(
    op1: Operator,
    op2: Operator,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
):
    """Default function to determine whether two Operator objects are equal.

    Checks, in order: exact type, arithmetic depth, wires, hyperparameters, number of
    parameters, parameter values (within ``rtol``/``atol``), and optionally the
    trainability and interface of every parameter.

    Returns:
        bool or str: ``True`` if the operators are equal, otherwise a string describing the
        first mismatch found.
    """
    if not isinstance(
        op2, type(op1)
    ):  # clarifies cases involving PauliX/Y/Z (Observable/Operation)
        return f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}"

    if isinstance(op1, qml.Identity):
        # All Identities are equivalent, independent of wires.
        # We already know op1 and op2 are of the same type, so no need to check
        # that op2 is also an Identity
        return True

    if op1.arithmetic_depth != op2.arithmetic_depth:
        return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}"

    if op1.arithmetic_depth > 0:
        # Other dispatches cover cases of operations with arithmetic depth > 0.
        # If any new operations are added with arithmetic depth > 0, a new dispatch
        # should be created for them.
        return f"op1 and op2 have arithmetic depth > 0. Got arithmetic depth {op1.arithmetic_depth}"

    if op1.wires != op2.wires:
        return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}."

    if op1.hyperparameters != op2.hyperparameters:
        return (
            "The hyperparameters are not equal for op1 and op2.\n"
            f"Got {op1.hyperparameters}\n and {op2.hyperparameters}."
        )

    # The data comparisons below use ``zip``, which silently truncates; operators of the
    # same type can still carry a different number of parameters.
    if len(op1.data) != len(op2.data):
        return (
            "op1 and op2 have a different number of parameters. "
            f"Got {len(op1.data)} and {len(op2.data)}."
        )

    if any(qml.math.is_abstract(d) for d in op1.data + op2.data):
        # assume all tracers are independent
        return "Data contains a tracer. Abstract tracers are assumed to be unique."

    if not all(
        qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data)
    ):
        return f"op1 and op2 have different data.\nGot {op1.data} and {op2.data}"

    if check_trainability:
        for params1, params2 in zip(op1.data, op2.data):
            params1_train = qml.math.requires_grad(params1)
            params2_train = qml.math.requires_grad(params2)
            if params1_train != params2_train:
                return (
                    "Parameters have different trainability.\n "
                    f"{params1} trainability is {params1_train} and {params2} trainability is {params2_train}"
                )

    if check_interface:
        for params1, params2 in zip(op1.data, op2.data):
            params1_interface = qml.math.get_interface(params1)
            params2_interface = qml.math.get_interface(params2)
            if params1_interface != params2_interface:
                return (
                    "Parameters have different interfaces.\n "
                    f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}"
                )

    return True
# pylint: disable=unused-argument
@_equal_dispatch.register
def _equal_pauliword(
    op1: PauliWord,
    op2: PauliWord,
    **kwargs,
):
    """Compare two Pauli words.

    Returns:
        bool or str: ``True`` if the words are equal. Otherwise, a message listing either the
        wires present in only one of the words, or the wires whose Pauli operators differ.
    """
    differs = op1 != op2
    if not differs:
        return True

    wires1, wires2 = set(op1), set(op2)
    if wires1 != wires2:
        message = "Different wires in Pauli words."
        only_in_op1 = wires1.difference(wires2)
        only_in_op2 = wires2.difference(wires1)
        if only_in_op1:
            message += f" op1 has {only_in_op1} not present in op2."
        if only_in_op2:
            message += f" op2 has {only_in_op2} not present in op1."
        return message

    # Same wires, so the words must disagree on the Pauli acting on some wire.
    pauli_diff = {
        wire: f"{op1[wire]} != {op2[wire]}" for wire in op1 if op1[wire] != op2[wire]
    }
    return f"Pauli words agree on wires but differ in Paulis: {pauli_diff}"
@_equal_dispatch.register
def _equal_paulisentence(
    op1: PauliSentence,
    op2: PauliSentence,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
):
    """Compare two Pauli sentences.

    Both sentences must contain the same Pauli words. For each word, the coefficients are
    optionally checked for trainability and interface, then compared within ``rtol``/``atol``.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the first mismatch.
    """
    words1, words2 = set(op1), set(op2)
    if words1 != words2:
        message = "Different Pauli words in PauliSentences."
        only_in_op1 = words1.difference(words2)
        only_in_op2 = words2.difference(words1)
        if only_in_op1:
            message += f" op1 has {only_in_op1} not present in op2."
        if only_in_op2:
            message += f" op2 has {only_in_op2} not present in op1."
        return message

    for word in op1:
        coeff1, coeff2 = op1[word], op2[word]

        if check_trainability:
            grad1 = qml.math.requires_grad(coeff1)
            grad2 = qml.math.requires_grad(coeff2)
            if grad1 != grad2:
                return (
                    "Parameters have different trainability.\n "
                    f"{coeff1} trainability is {grad1} and {coeff2} trainability is {grad2}"
                )

        if check_interface:
            iface1 = qml.math.get_interface(coeff1)
            iface2 = qml.math.get_interface(coeff2)
            if iface1 != iface2:
                return (
                    "Parameters have different interfaces.\n "
                    f"{coeff1} interface is {iface1} and {coeff2} interface is {iface2}"
                )

        if not qml.math.allclose(coeff1, coeff2, rtol=rtol, atol=atol):
            return f"The coefficients of the PauliSentences for {word} differ: {coeff1}; {coeff2}"

    return True
@_equal_dispatch.register
# pylint: disable=unused-argument, protected-access
def _equal_prod_and_sum(op1: CompositeOp, op2: CompositeOp, **kwargs):
    """Determine whether two Prod or Sum objects are equal.

    If both have the same (non-``None``) Pauli representation they are equal. Otherwise the
    operands are sorted and compared pairwise.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep):  # shortcut check
        return True

    if len(op1.operands) != len(op2.operands):
        return f"op1 and op2 have different number of operands. Got {len(op1.operands)} and {len(op2.operands)}"

    # organizes by wire indices while respecting commutation relations
    sorted_ops1 = op1._sort(op1.operands)
    sorted_ops2 = op2._sort(op2.operands)

    for o1, o2 in zip(sorted_ops1, sorted_ops2):
        op_check = _equal(o1, o2, **kwargs)
        if isinstance(op_check, str):
            return OPERANDS_MISMATCH_ERROR_MESSAGE + op_check
        if not op_check:
            # Some comparisons (e.g. ParametrizedEvolution, HilbertSchmidt) report a
            # mismatch as a bare ``False``; it must not be treated as a match.
            return OPERANDS_MISMATCH_ERROR_MESSAGE + f"{o1} and {o2} are not equal."

    return True
@_equal_dispatch.register
def _equal_controlled(op1: Controlled, op2: Controlled, **kwargs):
    """Determine whether two Controlled or ControlledOp objects are equal.

    The arithmetic depth, work wires, and mapping from control wire to control value must
    match. The base operations are then compared recursively.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    if op1.arithmetic_depth != op2.arithmetic_depth:
        return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}"

    # op.base.wires compared in the base comparison below
    if op1.work_wires != op2.work_wires:
        return f"op1 and op2 have different work wires. Got {op1.work_wires} and {op2.work_wires}"

    # control_wire/control_value combinations compared here; as dictionaries, the order in
    # which control wires were listed does not matter
    op1_control_dict = dict(zip(op1.control_wires, op1.control_values))
    op2_control_dict = dict(zip(op2.control_wires, op2.control_values))
    if op1_control_dict != op2_control_dict:
        return f"op1 and op2 have different control dictionaries. Got {op1_control_dict} and {op2_control_dict}"

    base_equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(base_equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check
    if not base_equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."

    return True
@_equal_dispatch.register
def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence, **kwargs):
    """Determine whether two ControlledSequences are equal.

    Wires and arithmetic depth must match, and the base operations must compare equal.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    if op1.wires != op2.wires:
        return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}."
    if op1.arithmetic_depth != op2.arithmetic_depth:
        return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}"

    base_equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(base_equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check
    if not base_equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."

    return True
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_pow(op1: Pow, op2: Pow, **kwargs):
    """Determine whether two Pow objects are equal.

    The exponents ``z`` are optionally checked for interface and trainability, then compared
    exactly. The base operations are then compared recursively.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"]

    if check_interface:
        interface1 = qml.math.get_interface(op1.z)
        interface2 = qml.math.get_interface(op2.z)
        if interface1 != interface2:
            return (
                "Exponent have different interfaces.\n"
                f"{op1.z} interface is {interface1} and {op2.z} interface is {interface2}"
            )

    if check_trainability:
        grad1 = qml.math.requires_grad(op1.z)
        grad2 = qml.math.requires_grad(op2.z)
        if grad1 != grad2:
            return (
                "Exponent have different trainability.\n"
                f"{op1.z} trainability is {grad1} and {op2.z} trainability is {grad2}"
            )

    if op1.z != op2.z:
        return f"Exponent are different. Got {op1.z} and {op2.z}"

    base_equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(base_equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check
    if not base_equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."

    return True
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
    """Determine whether two Adjoint objects are equal.

    Returns:
        bool or str: ``True`` if the bases are equal, otherwise a string describing the mismatch.
    """
    # first line of top-level equal function already confirms both are Adjoint - only need to compare bases
    base_equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(base_equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check
    if not base_equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."
    return True
@_equal_dispatch.register
def _equal_conditional(op1: Conditional, op2: Conditional, **kwargs):
    """Determine whether two Conditional objects are equal.

    The type check in ``_equal`` already ensures both are Conditionals, so only the base
    operations and the measurement values they are conditioned on need to be compared.

    Returns:
        bool: result of :func:`qml.equal` on the bases, and then on the measurement values.
    """
    bases_match = qml.equal(op1.base, op2.base, **kwargs)
    if not bases_match:
        return bases_match
    return qml.equal(op1.meas_val, op2.meas_val, **kwargs)
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_measurement_value(op1: MeasurementValue, op2: MeasurementValue, **kwargs):
    """Determine whether two MeasurementValue objects are equal.

    Two measurement values are equal when their ``measurements`` compare equal. The
    tolerance and interface options are ignored.
    """
    measurements1 = op1.measurements
    measurements2 = op2.measurements
    return measurements1 == measurements2
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
    """Determine whether two Exp objects are equal.

    The operator data is optionally checked for interface and trainability. The
    coefficients are compared within ``rtol``/``atol``, and the bases are compared
    recursively.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    check_interface, check_trainability, rtol, atol = (
        kwargs["check_interface"],
        kwargs["check_trainability"],
        kwargs["rtol"],
        kwargs["atol"],
    )

    if check_interface:
        for params1, params2 in zip(op1.data, op2.data):
            params1_interface = qml.math.get_interface(params1)
            params2_interface = qml.math.get_interface(params2)
            if params1_interface != params2_interface:
                return (
                    "Parameters have different interfaces.\n"
                    f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}"
                )

    if check_trainability:
        for params1, params2 in zip(op1.data, op2.data):
            params1_trainability = qml.math.requires_grad(params1)
            params2_trainability = qml.math.requires_grad(params2)
            if params1_trainability != params2_trainability:
                return (
                    "Parameters have different trainability.\n"
                    f"{params1} trainability is {params1_trainability} and {params2} trainability is {params2_trainability}"
                )

    if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
        return f"op1 and op2 have different coefficients. Got {op1.coeff} and {op2.coeff}"

    equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + equal_check
    if not equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."

    return True
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
    """Determine whether two SProd objects are equal.

    The operator data is optionally checked for interface and trainability. If both
    operators share the same (non-``None``) Pauli representation, they are equal. Otherwise
    the scalars are compared within ``rtol``/``atol`` and the bases are compared recursively.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    check_interface, check_trainability, rtol, atol = (
        kwargs["check_interface"],
        kwargs["check_trainability"],
        kwargs["rtol"],
        kwargs["atol"],
    )

    if check_interface:
        for params1, params2 in zip(op1.data, op2.data):
            params1_interface = qml.math.get_interface(params1)
            params2_interface = qml.math.get_interface(params2)
            if params1_interface != params2_interface:
                return (
                    "Parameters have different interfaces.\n "
                    f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}"
                )

    if check_trainability:
        for params1, params2 in zip(op1.data, op2.data):
            params1_train = qml.math.requires_grad(params1)
            params2_train = qml.math.requires_grad(params2)
            if params1_train != params2_train:
                return (
                    "Parameters have different trainability.\n "
                    f"{params1} trainability is {params1_train} and {params2} trainability is {params2_train}"
                )

    if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep):  # shortcut check
        return True

    if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
        return f"op1 and op2 have different scalars. Got {op1.scalar} and {op2.scalar}"

    equal_check = _equal(op1.base, op2.base, **kwargs)
    if isinstance(equal_check, str):
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + equal_check
    if not equal_check:
        # A bare ``False`` from the base comparison is still a mismatch.
        return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + f"{op1.base} and {op2.base} are not equal."

    return True
@_equal_dispatch.register
def _equal_parametrized_evolution(op1: ParametrizedEvolution, op2: ParametrizedEvolution, **kwargs):
    """Determine whether two ParametrizedEvolution objects are equal.

    Compares, in order:

    * the evolution times ``t`` (both ``None``, or the same shape and close within ``rtol``/``atol``);
    * the operator wires, hyperparameters and parameters (via ``_equal_operators``);
    * the coefficients and the operators of the underlying Hamiltonians ``H``, term by term.

    Returns:
        bool: ``True`` if equal, else ``False``.
    """
    rtol = kwargs.get("rtol", 1e-5)
    atol = kwargs.get("atol", 1e-9)

    # check times match
    if op1.t is None or op2.t is None:
        if not (op1.t is None and op2.t is None):
            return False
    else:
        # ``allclose`` raises on incompatible shapes, so reject differing shapes up front.
        if qml.math.shape(op1.t) != qml.math.shape(op2.t):
            return False
        if not qml.math.allclose(op1.t, op2.t, rtol=rtol, atol=atol):
            return False

    # check parameters passed to operator match
    operator_check = _equal_operators(op1, op2, **kwargs)
    if isinstance(operator_check, str):
        return False

    coeffs1, coeffs2 = op1.H.coeffs, op2.H.coeffs
    ops1, ops2 = op1.H.ops, op2.H.ops

    # ``zip`` silently truncates, so Hamiltonians with a different number of terms must be
    # rejected explicitly.
    if len(coeffs1) != len(coeffs2) or len(ops1) != len(ops2):
        return False

    # check H.coeffs match
    if not all(c1 == c2 for c1, c2 in zip(coeffs1, coeffs2)):
        return False

    # check that all the base operators on op1.H and op2.H match
    return all(equal(o1, o2, **kwargs) for o1, o2 in zip(ops1, ops2))
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_measurements(
    op1: MeasurementProcess,
    op2: MeasurementProcess,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
):
    """Determine whether two MeasurementProcess objects are equal.

    The checks run in this order:

    * If both measurements have an observable, the observables are compared with :func:`equal`.
    * If both have a measurement value ``mv``, the values are compared. Abstract values are
      equal only when they are the identical object.
    * Otherwise the wires must match. When neither measurement has an observable, the
      eigenvalues must also match within ``rtol``/``atol``, or both be ``None``.

    Returns:
        bool: whether the two measurement processes are equal.
    """
    if op1.obs is not None and op2.obs is not None:
        return equal(
            op1.obs,
            op2.obs,
            check_interface=check_interface,
            check_trainability=check_trainability,
            rtol=rtol,
            atol=atol,
        )

    if op1.mv is not None and op2.mv is not None:
        if isinstance(op1.mv, MeasurementValue) and isinstance(op2.mv, MeasurementValue):
            return qml.equal(op1.mv, op2.mv)

        if qml.math.is_abstract(op1.mv) or qml.math.is_abstract(op2.mv):
            # abstract values can only be identified by identity
            return op1.mv is op2.mv

        if isinstance(op1.mv, Iterable) and isinstance(op2.mv, Iterable):
            if len(op1.mv) == len(op2.mv):
                return all(mv1.measurements == mv2.measurements for mv1, mv2 in zip(op1.mv, op2.mv))
            return False

    if op1.wires != op2.wires:
        return False

    if op1.obs is None and op2.obs is None:
        # Only compare eigvals if both observables are None. Eigenvalues can be expensive
        # to compute, so each is evaluated once.
        eigvals1 = op1.eigvals()
        eigvals2 = op2.eigvals()
        if eigvals1 is not None and eigvals2 is not None:
            return qml.math.allclose(eigvals1, eigvals2, rtol=rtol, atol=atol)
        return eigvals1 is None and eigvals2 is None

    return False
@_equal_dispatch.register
def _equal_mid_measure(op1: MidMeasureMP, op2: MidMeasureMP, **_):
    """Determine whether two mid-circuit measurements are equal.

    The wires, ``id``, ``reset`` flag and ``postselect`` value are compared in that order.
    The first falsy comparison result is returned; otherwise the result of the last one.
    """
    result = None
    for attribute in ("wires", "id", "reset", "postselect"):
        result = getattr(op1, attribute) == getattr(op2, attribute)
        if not result:
            return result
    return result
@_equal_dispatch.register
# pylint: disable=unused-argument
def _(op1: VnEntropyMP, op2: VnEntropyMP, **kwargs):
    """Determine whether two VnEntropyMP measurements are equal.

    The generic measurement comparison must succeed, and the ``log_base`` values must match.
    """
    measurement_match = _equal_measurements(op1, op2, **kwargs)
    return measurement_match and (op1.log_base == op2.log_base)
@_equal_dispatch.register
# pylint: disable=unused-argument
def _(op1: MutualInfoMP, op2: MutualInfoMP, **kwargs):
    """Determine whether two MutualInfoMP measurements are equal.

    The generic measurement comparison must succeed, and the ``log_base`` values must match.
    """
    measurement_match = _equal_measurements(op1, op2, **kwargs)
    return measurement_match and (op1.log_base == op2.log_base)
@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_shadow_measurements(op1: ShadowExpvalMP, op2: ShadowExpvalMP, **kwargs):
    """Determine whether two ShadowExpvalMP objects are equal.

    The wires, the observable(s) ``H`` and ``k`` must all match. ``H`` may be a single
    Operator or an iterable of observables. Iterables must have the same length and match
    element by element. The comparison options (interface, trainability, tolerances) are
    forwarded to :func:`equal`.

    Returns:
        bool: ``True`` if equal, else ``False``.
    """
    wires_match = op1.wires == op2.wires

    if isinstance(op1.H, Operator) and isinstance(op2.H, Operator):
        H_match = equal(op1.H, op2.H, **kwargs)
    elif isinstance(op1.H, Iterable) and isinstance(op2.H, Iterable):
        H1, H2 = list(op1.H), list(op2.H)
        # ``zip`` silently truncates, so lists of different lengths must be rejected first.
        H_match = len(H1) == len(H2) and all(equal(o1, o2, **kwargs) for o1, o2 in zip(H1, H2))
    else:
        return False

    k_match = op1.k == op2.k
    return wires_match and H_match and k_match
@_equal_dispatch.register
def _equal_counts(op1: CountsMP, op2: CountsMP, **kwargs):
    """Determine whether two CountsMP measurements are equal.

    The generic measurement comparison must succeed, and ``all_outcomes`` must match.
    """
    measurements_match = _equal_measurements(op1, op2, **kwargs)
    if not measurements_match:
        return measurements_match
    return op1.all_outcomes == op2.all_outcomes
@_equal_dispatch.register
def _equal_hilbert_schmidt(
    op1: qml.HilbertSchmidt,
    op2: qml.HilbertSchmidt,
    check_interface=True,
    check_trainability=True,
    rtol=1e-5,
    atol=1e-9,
):
    """Determine whether two HilbertSchmidt operators are equal.

    The parameters are compared first: values within ``rtol``/``atol``, then optionally
    trainability and interface. Then the ``v_wires``, ``u_tape``, ``v_tape`` and
    ``v_function`` hyperparameters are compared. The tapes use :func:`qml.equal` rather
    than ``==``.

    Returns:
        bool: ``True`` if every check passes, else ``False``.
    """
    param_pairs = list(zip(op1.data, op2.data))

    if not all(qml.math.allclose(p1, p2, rtol=rtol, atol=atol) for p1, p2 in param_pairs):
        return False
    if check_trainability and any(
        qml.math.requires_grad(p1) != qml.math.requires_grad(p2) for p1, p2 in param_pairs
    ):
        return False
    if check_interface and any(
        qml.math.get_interface(p1) != qml.math.get_interface(p2) for p1, p2 in param_pairs
    ):
        return False

    tape_options = {
        "check_interface": check_interface,
        "check_trainability": check_trainability,
        "atol": atol,
        "rtol": rtol,
    }
    hyper1, hyper2 = op1.hyperparameters, op2.hyperparameters

    if hyper1["v_wires"] != hyper2["v_wires"]:
        return False
    for tape_key in ("u_tape", "v_tape"):
        if not qml.equal(hyper1[tape_key], hyper2[tape_key], **tape_options):
            return False
    if hyper1["v_function"] != hyper2["v_function"]:
        return False
    return True
@_equal_dispatch.register
def _equal_prep_sel_prep(
    op1: PrepSelPrep, op2: PrepSelPrep, **kwargs
):  # pylint: disable=unused-argument
    """Determine whether two PrepSelPrep are equal.

    The control wires, the wires and the linear combination of unitaries (``lcu``) must all
    match. The ``lcu`` comparison honours the caller's interface, trainability and tolerance
    options.

    Returns:
        bool or str: ``True`` if equal, otherwise a string describing the mismatch.
    """
    if op1.control != op2.control:
        return f"op1 and op2 have different control wires. Got {op1.control} and {op2.control}."
    if op1.wires != op2.wires:
        return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}."
    # Forward the comparison options so that rtol/atol etc. apply to the lcu coefficients.
    if not qml.equal(op1.lcu, op2.lcu, **kwargs):
        return f"op1 and op2 have different lcu. Got {op1.lcu} and {op2.lcu}"
    return True
# _modules/pennylane/ops/functions/equal
# Download Python script
# Download Notebook
# View on GitHub