# Source code for pennylane.ops.op_math.controlled
# Copyright 2018-2022 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines the symbolic operation that indicates the control of an operator.
"""
import functools
import warnings
from collections.abc import Callable, Sequence
from copy import copy
from functools import wraps
from inspect import signature
from typing import Any, Optional, overload
import numpy as np
from scipy import sparse
import pennylane as qml
from pennylane import math as qmlmath
from pennylane import operation
from pennylane.capture.capture_diff import create_non_jvp_primitive
from pennylane.compiler import compiler
from pennylane.operation import Operator
from pennylane.wires import Wires
from .controlled_decompositions import ctrl_decomp_bisect, ctrl_decomp_zyz
from .symbolicop import SymbolicOp
@overload
def ctrl(
    op: Operator,
    control: Any,
    control_values: Optional[Sequence[bool]] = None,
    work_wires: Optional[Any] = None,
) -> Operator:
    """Typing overload: controlling an ``Operator`` instance yields an ``Operator``."""


@overload
def ctrl(
    op: Callable,
    control: Any,
    control_values: Optional[Sequence[bool]] = None,
    work_wires: Optional[Any] = None,
) -> Callable:
    """Typing overload: controlling a callable yields a callable with the same signature."""
def ctrl(op, control: Any, control_values=None, work_wires=None):
    """Create a method that applies a controlled version of the provided op.

    :func:`~.qjit` compatible.

    .. note::

        When used with :func:`~.qjit`, this function only supports the Catalyst compiler.
        See :func:`catalyst.ctrl` for more details.

        Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
        as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
        page for an overview of the differences between Catalyst and PennyLane.

    Args:
        op (function or :class:`~.operation.Operator`): A single operator or a function that
            applies PennyLane operators.
        control (Wires): The control wire(s).
        control_values (bool or list[bool]): The value(s) the control wire(s) should take.
            Integers other than 0 or 1 will be treated as ``int(bool(x))``.
        work_wires (Any): Any auxiliary wires that can be used in the decomposition.

    Returns:
        function or :class:`~.operation.Operator`: If an Operator is provided, returns a
        Controlled version of the Operator. If a function is provided, returns a function
        with the same call signature that creates a controlled version of the provided
        function.

    .. seealso:: :class:`~.Controlled`.

    **Example**

    .. code-block:: python3

        @qml.qnode(qml.device('default.qubit', wires=range(4)))
        def circuit(x):
            qml.X(2)
            qml.ctrl(qml.RX, (1,2,3), control_values=(0,1,0))(x, wires=0)
            return qml.expval(qml.Z(0))

    >>> print(qml.draw(circuit)("x"))
    0: ────╭RX(x)─┤ <Z>
    1: ────├○─────┤
    2: ──X─├●─────┤
    3: ────╰○─────┤
    >>> x = np.array(1.2)
    >>> circuit(x)
    tensor(0.36235775, requires_grad=True)
    >>> qml.grad(circuit)(x)
    tensor(-0.93203909, requires_grad=True)

    :func:`~.ctrl` works on both callables like ``qml.RX`` or a quantum function
    and individual :class:`~.operation.Operator`'s.

    >>> qml.ctrl(qml.Hadamard(0), (1,2))
    Controlled(Hadamard(wires=[0]), control_wires=[1, 2])

    Controlled operations work with all other forms of operator math and simplification:

    >>> op = qml.ctrl(qml.RX(1.2, wires=0) ** 2 @ qml.RY(0.1, wires=0), control=1)
    >>> qml.simplify(qml.adjoint(op))
    Controlled(RY(12.466370614359173, wires=[0]) @ RX(10.166370614359172, wires=[0]), control_wires=[1])

    **Example with compiler**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=2)

        @qml.qjit
        @qml.qnode(dev)
        def workflow(theta, w, cw):
            qml.Hadamard(wires=[0])
            qml.Hadamard(wires=[1])

            def func(arg):
                qml.RX(theta, wires=arg)

            def cond_fn():
                qml.RY(theta, wires=w)

            qml.ctrl(func, control=[cw])(w)
            qml.ctrl(qml.cond(theta > 0.0, cond_fn), control=[cw])()
            qml.ctrl(qml.RZ, control=[cw])(theta, wires=w)
            qml.ctrl(qml.RY(theta, wires=w), control=[cw])
            return qml.probs()

    >>> workflow(jnp.pi/4, 1, 0)
    array([0.25, 0.25, 0.03661165, 0.46338835])
    """
    # When a compiler is active (e.g. under qjit), defer to that compiler's own ``ctrl``.
    active_jit = compiler.active_compiler()
    if active_jit:
        entrypoints = compiler.AvailableCompilers.names_entrypoints
        compiler_ops = entrypoints[active_jit]["ops"].load()
        return compiler_ops.ctrl(
            op, control, control_values=control_values, work_wires=work_wires
        )

    # Concrete inputs go through the full dispatch (custom controlled classes, flattening,
    # callables); abstract inputs are wrapped in ``Controlled`` directly.
    if not qml.math.is_abstract(op):
        return create_controlled_op(
            op, control, control_values=control_values, work_wires=work_wires
        )
    return Controlled(op, control, control_values=control_values, work_wires=work_wires)
def create_controlled_op(op, control, control_values=None, work_wires=None):
    """Default ``qml.ctrl`` implementation, allowing other implementations to call it when needed.

    Args:
        op (Operator or Callable): the operator to control, or a quantum function whose
            queued operators should all be controlled.
        control (Any): the control wire(s); converted to :class:`~.wires.Wires`.
        control_values (None, bool, int, or Iterable): the value(s) to control on.
            ``None`` means ``True`` for every control wire. A scalar (Python or NumPy
            integer/bool) is wrapped in a single-element list. Any other iterable (list,
            tuple, NumPy array, ...) is converted to a new ``list``.
        work_wires (Any): auxiliary wires available to the decomposition.

    Returns:
        Operator or Callable: the controlled operator, or a wrapped quantum function when
        ``op`` is a callable.

    Raises:
        ValueError: if ``op`` is neither an ``Operator`` nor callable.
    """
    control = qml.wires.Wires(control)

    # Normalize ``control_values`` to a plain ``list``. It is later concatenated with other
    # control-value lists via ``+``; a NumPy array would broadcast element-wise there instead
    # of concatenating, so every non-scalar iterable is converted, not just tuples.
    if isinstance(control_values, (int, bool, np.integer, np.bool_)):
        control_values = [control_values]
    elif control_values is None:
        control_values = [True] * len(control)
    else:
        control_values = list(control_values)

    ctrl_op = _try_wrap_in_custom_ctrl_op(
        op, control, control_values=control_values, work_wires=work_wires
    )
    if ctrl_op is not None:
        return ctrl_op

    pauli_x_based_ctrl_ops = _get_pauli_x_based_ops()

    # Special handling for PauliX-based controlled operations
    if isinstance(op, pauli_x_based_ctrl_ops):
        qml.QueuingManager.remove(op)
        return _handle_pauli_x_based_controlled_ops(op, control, control_values, work_wires)

    # Flatten nested controlled operations to a multi-controlled operation for better
    # decomposition algorithms. This includes special cases like CRX, CRot, etc.
    if isinstance(op, Controlled):
        work_wires = work_wires or []
        return ctrl(
            op.base,
            control=control + op.control_wires,
            control_values=control_values + op.control_values,
            work_wires=work_wires + op.work_wires,
        )

    if isinstance(op, Operator):
        return Controlled(
            op, control_wires=control, control_values=control_values, work_wires=work_wires
        )

    if not callable(op):
        raise ValueError(
            f"The object {op} of type {type(op)} is not an Operator or callable. "
            "This error might occur if you apply ctrl to a list "
            "of operations instead of a function or Operator."
        )

    if qml.capture.enabled():
        return _capture_ctrl_transform(op, control, control_values, work_wires)
    return _ctrl_transform(op, control, control_values, work_wires)
def _ctrl_transform(op, control, control_values, work_wires):
    """Wrap the quantum function ``op`` so that every operator it queues is controlled.

    Args:
        op (Callable): the quantum function to control.
        control (Wires): the control wires.
        control_values (list[bool] or None): the values to control on.
        work_wires (Any): auxiliary wires forwarded to :func:`ctrl`.

    Returns:
        Callable: a function with the same signature as ``op``. It queues the controlled
        versions of ``op``'s operators, re-applies ``op``'s measurements when recording,
        and returns those measurements.
    """

    @wraps(op)
    def wrapper(*args, **kwargs):
        qscript = qml.tape.make_qscript(op)(*args, **kwargs)

        # Remove any Operator instances passed in as arguments from the active queue.
        # Presumably this prevents them from also being applied uncontrolled.
        leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
        for leaf in leaves:
            if isinstance(leaf, Operator):
                qml.QueuingManager.remove(leaf)

        # For a script with more than one entry, flip the zero-valued control wires once
        # around the whole block. The individual operators are then controlled on all-ones
        # (``control_values=None``) instead of each handling the flips itself.
        flip_control_on_zero = (len(qscript) > 1) and (control_values is not None)
        op_control_values = None if flip_control_on_zero else control_values
        zero_controls = (
            [w for w, val in zip(control, control_values) if not val]
            if flip_control_on_zero
            else []
        )

        for wire in zero_controls:
            qml.X(wire)
        for inner_op in qscript.operations:
            ctrl(
                inner_op,
                control=control,
                control_values=op_control_values,
                work_wires=work_wires,
            )
        for wire in zero_controls:
            qml.X(wire)

        if qml.QueuingManager.recording():
            for measurement in qscript.measurements:
                qml.apply(measurement)
        return qscript.measurements

    return wrapper
@functools.lru_cache  # only create the first time requested
def _get_ctrl_qfunc_prim():
    """Create (once) the ``ctrl_transform`` higher-order primitive used under program capture.

    See capture/explanations.md : Higher Order primitives for more information on this code.

    Returns:
        the primitive, with its concrete implementation and abstract evaluation registered.
    """
    # if capture is enabled, jax should be installed
    import jax  # pylint: disable=import-outside-toplevel

    ctrl_prim = create_non_jvp_primitive()("ctrl_transform")
    ctrl_prim.multiple_results = True

    @ctrl_prim.def_impl
    def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):
        # ``args`` is laid out as (*consts, *jaxpr_args, *control_wires), the order used by
        # ``_capture_ctrl_transform`` when binding. Split at an explicit index:
        # ``args[-n_control:]`` would select *all* of ``args`` when ``n_control == 0``.
        split = len(args) - n_control
        consts = args[:n_consts]
        control_wires = args[split:]
        args = args[n_consts:split]

        with qml.queuing.AnnotatedQueue() as q:
            jax.core.eval_jaxpr(jaxpr, consts, *args)
        ops, _ = qml.queuing.process_queue(q)

        for op in ops:
            ctrl(op, control_wires, control_values, work_wires)
        return []

    @ctrl_prim.def_abstract_eval
    def _(*_, **__):
        return []

    return ctrl_prim
def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable:
    """Capture compatible way of performing an ctrl transform.

    The returned function traces ``qfunc`` (with its keyword arguments fixed) to a jaxpr
    and binds that jaxpr to the ``ctrl_transform`` primitive. The bound inputs are the
    jaxpr constants, the positional arguments, and then the control wires.
    """
    # note that this logic is tested in `tests/capture/test_nested_plxpr.py`
    import jax  # pylint: disable=import-outside-toplevel

    ctrl_prim = _get_ctrl_qfunc_prim()

    @wraps(qfunc)
    def new_qfunc(*args, **kwargs):
        # keyword arguments are closed over; only positional arguments become jaxpr inputs
        closed_jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs))(*args)
        wires = qml.wires.Wires(control)  # make sure is iterable
        consts = closed_jaxpr.consts
        ctrl_prim.bind(
            *consts,
            *args,
            *wires,
            jaxpr=closed_jaxpr.jaxpr,
            n_control=len(wires),
            control_values=control_values,
            work_wires=work_wires,
            n_consts=len(consts),
        )

    return new_qfunc
@functools.lru_cache()
def _get_special_ops():
    """Return a mapping ``(base operator class, number of control wires) -> controlled class``.

    Defined inside a function (and cached) so that the ``qml`` operator classes are only
    looked up on first use, avoiding circular imports.
    """
    table = (
        (qml.PauliZ, 1, qml.CZ),
        (qml.PauliZ, 2, qml.CCZ),
        (qml.PauliY, 1, qml.CY),
        (qml.CZ, 1, qml.CCZ),
        (qml.SWAP, 1, qml.CSWAP),
        (qml.Hadamard, 1, qml.CH),
        (qml.RX, 1, qml.CRX),
        (qml.RY, 1, qml.CRY),
        (qml.RZ, 1, qml.CRZ),
        (qml.Rot, 1, qml.CRot),
        (qml.PhaseShift, 1, qml.ControlledPhaseShift),
    )
    return {(base_cls, n_controls): ctrl_cls for base_cls, n_controls, ctrl_cls in table}
@functools.lru_cache()
def _get_pauli_x_based_ops():
    """Return the tuple ``(qml.X, qml.CNOT, qml.Toffoli, qml.MultiControlledX)``.

    The lookup happens inside a (cached) function to avoid circular imports.
    """
    pauli_x_family = (qml.X, qml.CNOT, qml.Toffoli, qml.MultiControlledX)
    return pauli_x_family
def _try_wrap_in_custom_ctrl_op(op, control, control_values=None, work_wires=None):
    """Wrap ``op`` in a dedicated controlled class if one applies, else return ``None``.

    A class from ``_get_special_ops`` is used when ``(type(op), len(control))`` has an entry
    and all control values are truthy; ``op`` is then removed from the queue. A
    ``QubitUnitary`` is wrapped in ``ControlledQubitUnitary`` regardless of control values.
    """
    custom_cls = _get_special_ops().get((type(op), len(control)))
    if custom_cls is not None and all(control_values):
        qml.QueuingManager.remove(op)
        return custom_cls(*op.data, control + op.wires)

    if not isinstance(op, qml.QubitUnitary):
        return None
    return qml.ControlledQubitUnitary(
        op, control_wires=control, control_values=control_values, work_wires=work_wires
    )
def _handle_pauli_x_based_controlled_ops(op, control, control_values, work_wires):
    """Build the controlled version of a PauliX-based operator.

    ``CNOT``/``Toffoli`` are used when the ``(type(op), len(control))`` pair allows it and
    every control value is truthy. Otherwise a ``MultiControlledX`` is built; for operators
    that are not a bare ``PauliX``, their own control values and work wires are merged in.
    """
    shortcut_cls = {
        (qml.PauliX, 1): qml.CNOT,
        (qml.PauliX, 2): qml.Toffoli,
        (qml.CNOT, 1): qml.Toffoli,
    }.get((type(op), len(control)))
    all_wires = control + op.wires

    if shortcut_cls is not None and all(control_values):
        qml.QueuingManager.remove(op)
        return shortcut_cls(wires=all_wires)

    if isinstance(op, qml.PauliX):
        return qml.MultiControlledX(
            wires=all_wires, control_values=control_values, work_wires=work_wires
        )

    # ``op`` carries its own controls (per the caller's filter, presumably CNOT, Toffoli or
    # MultiControlledX): prepend the new control values / work wires to its existing ones.
    extra_work_wires = work_wires or []
    return qml.MultiControlledX(
        wires=all_wires,
        control_values=control_values + op.control_values,
        work_wires=extra_work_wires + op.work_wires,
    )
# pylint: disable=too-many-arguments, too-many-public-methods
class Controlled(SymbolicOp):
    """Symbolic operator denoting a controlled operator.

    Args:
        base (~.operation.Operator): the operator that is controlled
        control_wires (Any): The wires to control on.

    Keyword Args:
        control_values (Iterable[Bool]): The values to control on. Must be the same
            length as ``control_wires``. Defaults to ``True`` for all control wires.
            Provided values are converted to `Bool` internally.
        work_wires (Any): Any auxiliary wires that can be used in the decomposition

    .. note::
        This class, ``Controlled``, denotes a controlled version of any individual operation.
        :class:`~.ControlledOp` adds :class:`~.Operation` specific methods and properties to the
        more general ``Controlled`` class.

    .. seealso:: :class:`~.ControlledOp`, and :func:`~.ctrl`

    **Example**

    >>> base = qml.RX(1.234, 1)
    >>> Controlled(base, (0, 2, 3), control_values=[True, False, True])
    Controlled(RX(1.234, wires=[1]), control_wires=[0, 2, 3], control_values=[True, False, True])
    >>> op = Controlled(base, 0, control_values=[0])
    >>> op
    Controlled(RX(1.234, wires=[1]), control_wires=[0], control_values=[False])

    The operation has both standard :class:`~.operation.Operator` properties
    and ``Controlled`` specific properties:

    >>> op.base
    RX(1.234, wires=[1])
    >>> op.data
    (1.234,)
    >>> op.wires
    Wires([0, 1])
    >>> op.control_wires
    Wires([0])
    >>> op.target_wires
    Wires([1])

    Control values are lists of booleans, indicating for each control wire whether to
    control on its ``0`` (``False``) or ``1`` (``True``) state.

    >>> op.control_values
    [False]

    Provided control values are converted to booleans internally, so
    any "truthy" or "falsy" objects work.

    >>> Controlled(base, ("a", "b", "c"), control_values=["", None, 5]).control_values
    [False, False, True]

    Representations for an operator are available if the base class defines them.
    Sparse matrices are available if the base class defines either a sparse matrix
    or only a dense matrix.

    >>> np.set_printoptions(precision=4) # easier to read the matrix
    >>> qml.matrix(op)
    array([[0.8156+0.j , 0. -0.5786j, 0. +0.j , 0. +0.j ],
           [0. -0.5786j, 0.8156+0.j , 0. +0.j , 0. +0.j ],
           [0. +0.j , 0. +0.j , 1. +0.j , 0. +0.j ],
           [0. +0.j , 0. +0.j , 0. +0.j , 1. +0.j ]])

    The eigenvalues of the base operator appear at the position of its block in the matrix:

    >>> qml.eigvals(op)
    array([0.8156+0.5786j, 0.8156-0.5786j, 1. +0.j , 1. +0.j ])
    >>> print(qml.generator(op, format='observable'))
    (-0.5) [Projector0 X1]
    >>> op.sparse_matrix()
    <4x4 sparse matrix of type '<class 'numpy.complex128'>'
        with 6 stored elements in Compressed Sparse Row format>

    If the provided base operator is an :class:`~.operation.Operation`, then the created
    object will be of type :class:`~.ops.op_math.ControlledOp`. This class adds some additional
    methods and properties to the basic :class:`~.ops.op_math.Controlled` class.

    >>> type(op)
    <class 'pennylane.ops.op_math.controlled.ControlledOp'>
    >>> op.parameter_frequencies
    [(0.5, 1.0)]
    """

    def _flatten(self):
        return (self.base,), (self.control_wires, tuple(self.control_values), self.work_wires)

    @classmethod
    def _unflatten(cls, data, metadata):
        return cls(
            data[0], control_wires=metadata[0], control_values=metadata[1], work_wires=metadata[2]
        )

    # pylint: disable=no-self-argument
    @operation.classproperty
    def __signature__(cls):  # pragma: no cover
        # this method is defined so inspect.signature returns __init__ signature
        # instead of __new__ signature
        # See PEP 362

        # use __init__ signature instead of __new__ signature
        sig = signature(cls.__init__)
        # get rid of self from signature
        new_parameters = tuple(sig.parameters.values())[1:]
        new_sig = sig.replace(parameters=new_parameters)
        return new_sig

    # pylint: disable=unused-argument
    def __new__(cls, base, *_, **__):
        """If base is an ``Operation``, then a ``ControlledOp`` should be used instead."""
        if isinstance(base, operation.Operation):
            return object.__new__(ControlledOp)
        return object.__new__(Controlled)

    # pylint: disable=arguments-differ
    @classmethod
    def _primitive_bind_call(
        cls, base, control_wires, control_values=None, work_wires=None, id=None
    ):
        # control wires are unpacked into individual positional inputs of the primitive
        control_wires = Wires(control_wires)
        return cls._primitive.bind(
            base, *control_wires, control_values=control_values, work_wires=work_wires
        )

    # pylint: disable=too-many-function-args
    def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None):
        control_wires = Wires(control_wires)
        work_wires = Wires([]) if work_wires is None else Wires(work_wires)

        if control_values is None:
            control_values = [True] * len(control_wires)
        else:
            # a single int/bool is wrapped; any other iterable is converted element-wise
            control_values = (
                [bool(control_values)]
                if isinstance(control_values, int)
                else [bool(control_value) for control_value in control_values]
            )

            if len(control_values) != len(control_wires):
                raise ValueError("control_values should be the same length as control_wires")

        if len(Wires.shared_wires([base.wires, control_wires])) != 0:
            raise ValueError("The control wires must be different from the base operation wires.")

        if len(Wires.shared_wires([work_wires, base.wires + control_wires])) != 0:
            raise ValueError(
                "Work wires must be different the control_wires and base operation wires."
            )

        self.hyperparameters["control_wires"] = control_wires
        self.hyperparameters["control_values"] = control_values
        self.hyperparameters["work_wires"] = work_wires

        self._name = f"C({base.name})"

        super().__init__(base, id)

    @property
    def hash(self):
        # these gates do not consider global phases in their hash
        # (parameters are reduced modulo 4*pi before hashing)
        if self.base.name in ("RX", "RY", "RZ", "Rot"):
            base_params = str(
                [
                    (
                        id(d)
                        if qml.math.is_abstract(d)
                        else qml.math.round(qml.math.real(d) % (4 * np.pi), 10)
                    )
                    for d in self.base.data
                ]
            )
            base_hash = hash(
                (
                    str(self.base.name),
                    tuple(self.base.wires.tolist()),
                    base_params,
                )
            )
        else:
            base_hash = self.base.hash
        return hash(
            (
                "Controlled",
                base_hash,
                tuple(self.control_wires.tolist()),
                tuple(self.control_values),
                tuple(self.work_wires.tolist()),
            )
        )

    # pylint: disable=arguments-renamed, invalid-overridden-method
    @property
    def has_matrix(self):
        return self.base.has_matrix

    @property
    def batch_size(self):
        return self.base.batch_size

    @property
    def ndim_params(self):
        return self.base.ndim_params

    # Properties on the control values ######################
    @property
    def control_values(self):
        """Iterable[Bool]. For each control wire, denotes whether to control on ``True`` or
        ``False``."""
        return self.hyperparameters["control_values"]

    @property
    def _control_int(self):
        """Int. Conversion of ``control_values`` to an integer."""
        # the first control wire is the most significant bit
        return sum(2**i for i, val in enumerate(reversed(self.control_values)) if val)

    # Properties on the wires ##########################

    @property
    def control_wires(self):
        """The control wires."""
        return self.hyperparameters["control_wires"]

    @property
    def target_wires(self):
        """The wires of the target operator."""
        return self.base.wires

    @property
    def work_wires(self):
        """Additional wires that can be used in the decomposition. Not modified by the operation."""
        return self.hyperparameters["work_wires"]

    @property
    def wires(self):
        return self.control_wires + self.target_wires

    def map_wires(self, wire_map: dict):
        new_base = self.base.map_wires(wire_map=wire_map)
        new_control_wires = Wires([wire_map.get(wire, wire) for wire in self.control_wires])
        new_work_wires = Wires([wire_map.get(wire, wire) for wire in self.work_wires])

        return ctrl(
            op=new_base,
            control=new_control_wires,
            control_values=self.control_values,
            work_wires=new_work_wires,
        )

    # Methods ##########################################

    def __repr__(self):
        params = [f"control_wires={self.control_wires.tolist()}"]
        if self.work_wires:
            params.append(f"work_wires={self.work_wires.tolist()}")
        if self.control_values and not all(self.control_values):
            params.append(f"control_values={self.control_values}")
        return f"Controlled({self.base}, {', '.join(params)})"

    def label(self, decimals=None, base_label=None, cache=None):
        return self.base.label(decimals=decimals, base_label=base_label, cache=cache)

    def _compute_matrix_from_base(self):
        base_matrix = self.base.matrix()
        interface = qmlmath.get_interface(base_matrix)

        num_target_states = 2 ** len(self.target_wires)
        num_control_states = 2 ** len(self.control_wires)
        total_matrix_size = num_control_states * num_target_states

        # the base matrix occupies diagonal block number ``_control_int``; identity elsewhere
        padding_left = self._control_int * num_target_states
        padding_right = total_matrix_size - padding_left - num_target_states

        left_pad = qmlmath.convert_like(
            qmlmath.cast_like(qmlmath.eye(padding_left, like=interface), 1j), base_matrix
        )
        right_pad = qmlmath.convert_like(
            qmlmath.cast_like(qmlmath.eye(padding_right, like=interface), 1j), base_matrix
        )

        shape = qml.math.shape(base_matrix)
        if len(shape) == 3:  # stack if batching
            return qml.math.stack(
                [qml.math.block_diag([left_pad, _U, right_pad]) for _U in base_matrix]
            )

        return qmlmath.block_diag([left_pad, base_matrix, right_pad])

    def matrix(self, wire_order=None):
        if self.compute_matrix is not Operator.compute_matrix:
            canonical_matrix = self.compute_matrix(*self.data)
        else:
            canonical_matrix = self._compute_matrix_from_base()

        wire_order = wire_order or self.wires
        return qml.math.expand_matrix(canonical_matrix, wires=self.wires, wire_order=wire_order)

    # pylint: disable=arguments-differ
    def sparse_matrix(self, wire_order=None, format="csr"):
        if wire_order is not None:
            raise NotImplementedError("wire_order argument is not yet implemented.")

        try:
            target_mat = self.base.sparse_matrix()
        except operation.SparseMatrixUndefinedError as e:
            # fall back to a sparse copy of the dense base matrix when one exists
            if self.base.has_matrix:
                target_mat = sparse.lil_matrix(self.base.matrix())
            else:
                raise operation.SparseMatrixUndefinedError from e

        num_target_states = 2 ** len(self.target_wires)
        num_control_states = 2 ** len(self.control_wires)
        total_states = num_target_states * num_control_states

        start_ind = self._control_int * num_target_states
        end_ind = start_ind + num_target_states

        m = sparse.eye(total_states, format="lil", dtype=target_mat.dtype)

        m[start_ind:end_ind, start_ind:end_ind] = target_mat

        return m.asformat(format=format)

    def eigvals(self):
        base_eigvals = self.base.eigvals()
        num_target_states = 2 ** len(self.target_wires)
        num_control_states = 2 ** len(self.control_wires)
        total = num_target_states * num_control_states

        # Put the base eigenvalues at the same position as the base block in ``matrix()`` /
        # ``sparse_matrix()`` (block index ``_control_int``); every other control state
        # contributes eigenvalue 1. With all control values ``True`` the base block is the
        # last one, so this reduces to ``[ones, base_eigvals]``.
        num_after = (num_control_states - 1 - self._control_int) * num_target_states
        num_before = total - len(base_eigvals) - num_after

        pieces = [np.ones(num_before), base_eigvals]
        if num_after:
            pieces.append(np.ones(num_after))
        return qmlmath.concatenate(pieces)

    @property
    def has_diagonalizing_gates(self):
        return self.base.has_diagonalizing_gates

    @property
    def has_decomposition(self):
        if self.compute_decomposition is not Operator.compute_decomposition:
            return True
        if not all(self.control_values):
            return True
        if len(self.control_wires) == 1 and hasattr(self.base, "_controlled"):
            return True
        if _is_single_qubit_special_unitary(self.base):
            return True
        if self.base.has_decomposition:
            return True

        return False

    def decomposition(self):
        if self.compute_decomposition is not Operator.compute_decomposition:
            return self.compute_decomposition(*self.data, self.wires)

        if all(self.control_values):
            decomp = _decompose_no_control_values(self)
            if decomp is None:
                raise qml.operation.DecompositionUndefinedError
            return decomp

        # We need to add paulis to flip some control wires
        d = [qml.X(w) for w, val in zip(self.control_wires, self.control_values) if not val]

        decomp = _decompose_no_control_values(self)
        if decomp is None:
            # no decomposition available: keep a copy controlled on all-ones between the flips
            no_control_values = copy(self).queue()
            no_control_values.hyperparameters["control_values"] = [1] * len(self.control_wires)
            d.append(no_control_values)
        else:
            d += decomp

        d += [qml.X(w) for w, val in zip(self.control_wires, self.control_values) if not val]
        return d

    # pylint: disable=arguments-renamed, invalid-overridden-method
    @property
    def has_generator(self):
        return self.base.has_generator

    def generator(self):
        sub_gen = self.base.generator()
        projectors = (
            qml.Projector([val], wires=w) for val, w in zip(self.control_values, self.control_wires)
        )
        # needs to return a new_opmath instance regardless of whether new_opmath is enabled, because
        # it otherwise can't handle ControlledGlobalPhase, see PR #5194
        return qml.prod(*projectors, sub_gen)

    @property
    def has_adjoint(self):
        return self.base.has_adjoint

    def adjoint(self):
        return ctrl(
            self.base.adjoint(),
            self.control_wires,
            control_values=self.control_values,
            work_wires=self.work_wires,
        )

    def pow(self, z):
        base_pow = self.base.pow(z)
        return [
            ctrl(
                op,
                self.control_wires,
                control_values=self.control_values,
                work_wires=self.work_wires,
            )
            for op in base_pow
        ]

    def simplify(self) -> "Operator":
        if isinstance(self.base, Controlled):
            # merge the nested controls into a single controlled operator
            base = self.base.base.simplify()
            return ctrl(
                base,
                control=self.control_wires + self.base.control_wires,
                control_values=self.control_values + self.base.control_values,
                work_wires=self.work_wires + self.base.work_wires,
            )

        simplified_base = self.base.simplify()
        if isinstance(simplified_base, qml.Identity):
            return simplified_base

        return ctrl(
            op=simplified_base,
            control=self.control_wires,
            control_values=self.control_values,
            work_wires=self.work_wires,
        )
def _is_single_qubit_special_unitary(op):
    """Return whether ``op`` acts on one wire and has a matrix whose determinant is close to 1."""
    if not (op.has_matrix and len(op.wires) == 1):
        return False
    u = op.matrix()
    # 2x2 determinant: ad - bc
    determinant = u[0, 0] * u[1, 1] - u[0, 1] * u[1, 0]
    return qmlmath.allclose(determinant, 1)
def _decompose_pauli_x_based_no_control_values(op: Controlled):
    """Decompose a controlled PauliX-based operator, ignoring its control values.

    Singly-controlled ``PauliX`` becomes a ``CNOT``. Doubly-controlled ``PauliX`` and
    singly-controlled ``CNOT`` use the ``Toffoli`` decomposition. Everything else uses the
    ``MultiControlledX`` decomposition with the operator's work wires.
    """
    n_ctrl = len(op.control_wires)
    if isinstance(op.base, qml.PauliX):
        if n_ctrl == 1:
            return [qml.CNOT(wires=op.wires)]
        if n_ctrl == 2:
            return qml.Toffoli.compute_decomposition(wires=op.wires)
    elif isinstance(op.base, qml.CNOT) and n_ctrl == 1:
        return qml.Toffoli.compute_decomposition(wires=op.wires)

    return qml.MultiControlledX.compute_decomposition(
        wires=op.wires,
        work_wires=op.work_wires,
    )
def _decompose_custom_ops(op: Controlled) -> list["operation.Operator"]:
    """Custom handling for decomposing a controlled operation.

    Control values are not considered here. Returns the decomposition of the first matching
    special case, or ``None`` when no special case applies.
    """
    x_family = _get_pauli_x_based_ops()
    special_ops = _get_special_ops()
    base = op.base
    n_ctrl = len(op.control_wires)

    custom_op_cls = special_ops.get((type(base), n_ctrl))
    if custom_op_cls is not None:
        return custom_op_cls.compute_decomposition(*op.data, op.wires)

    if isinstance(base, x_family):
        # has some special case handling of its own for further decomposition
        return _decompose_pauli_x_based_no_control_values(op)

    if isinstance(base, qml.GlobalPhase) and n_ctrl == 1:
        # use Lemma 5.2 from https://arxiv.org/pdf/quant-ph/9503016
        return [qml.PhaseShift(phi=-op.data[0], wires=op.control_wires)]

    # A multi-wire controlled PhaseShift should be decomposed first using the decomposition
    # of ControlledPhaseShift. This is because the decomposition of PhaseShift contains a
    # GlobalPhase that we do not have a handling for.
    # TODO: remove this special case when we support ControlledGlobalPhase [sc-44933]
    if isinstance(base, qml.PhaseShift):
        phase_decomp = qml.ControlledPhaseShift.compute_decomposition(*op.data, op.wires[-2:])
        return [
            ctrl(piece, op.control_wires[:-1], work_wires=op.work_wires)
            for piece in phase_decomp
        ]

    # TODO: will be removed in the second part of the controlled rework [sc-37951]
    if n_ctrl == 1 and hasattr(base, "_controlled"):
        candidate = base._controlled(op.control_wires[0])  # pylint: disable=protected-access
        # disallow decomposing to itself
        # pylint: disable=unidiomatic-typecheck
        if type(candidate) != type(op):
            return [candidate]
        qml.QueuingManager.remove(candidate)

    return None
def _decompose_no_control_values(op: Controlled) -> Optional[list["operation.Operator"]]:
    """Decompose without considering control values. Returns None if no decomposition.

    Tries, in order: the custom special cases; the controlled decompositions for
    single-qubit special unitaries (bisect for >= 2 controls with NumPy data, ZYZ otherwise);
    and finally controlling each operator of the base decomposition.
    """
    special = _decompose_custom_ops(op)
    if special is not None:
        return special

    base = op.base
    if _is_single_qubit_special_unitary(base):
        use_bisect = len(op.control_wires) >= 2 and qmlmath.get_interface(*op.data) == "numpy"
        if use_bisect:
            return ctrl_decomp_bisect(base, op.control_wires)
        return ctrl_decomp_zyz(base, op.control_wires, work_wires=op.work_wires)

    if not base.has_decomposition:
        return None

    base_decomp = base.decomposition()
    if len(base_decomp) == 0 and isinstance(base, qml.GlobalPhase) and len(op.control_wires) > 1:
        warnings.warn(
            "Multi-Controlled-GlobalPhase currently decomposes to nothing, and this will likely "
            "produce incorrect results. Consider implementing your circuit with a different set "
            "of operations, or use a device that natively supports GlobalPhase.",
            UserWarning,
        )

    return [ctrl(piece, op.control_wires, work_wires=op.work_wires) for piece in base_decomp]
class ControlledOp(Controlled, operation.Operation):
    """Operation-specific methods and properties for the :class:`~.ops.op_math.Controlled` class.

    When an :class:`~.operation.Operation` is provided to the :class:`~.ops.op_math.Controlled`
    class, this type is constructed instead. It adds some additional :class:`~.operation.Operation`
    specific methods and properties.

    When we no longer rely on certain functionality through ``Operation``, we can get rid of this
    class.

    .. seealso:: :class:`~.Controlled`
    """

    def __new__(cls, *_, **__):
        # overrides dispatch behaviour of ``Controlled``
        return object.__new__(cls)

    # pylint: disable=too-many-function-args
    def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None):
        super().__init__(base, control_wires, control_values, work_wires, id)
        # ``grad_recipe`` must be an iterable with one entry per parameter, never ``None``
        if self.grad_recipe is None:
            self.grad_recipe = [None] * self.num_params

    @property
    def name(self):
        return self._name

    @property
    def grad_method(self):
        return self.base.grad_method

    @property
    def parameter_frequencies(self):
        # frequencies are only derived from the generator for single-parameter bases
        if self.base.num_params != 1:
            raise operation.ParameterFrequenciesUndefinedError(
                f"Operation {self.name} does not have parameter frequencies defined, "
                "and parameter frequencies can not be computed via generator for more than one "
                "parameter."
            )

        try:
            base_gen = qml.generator(self.base, format="observable")
        except operation.GeneratorUndefinedError as e:
            raise operation.ParameterFrequenciesUndefinedError(
                f"Operation {self.base.name} does not have parameter frequencies defined."
            ) from e

        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="ignore", message=r".+ eigenvalues will be computed numerically\."
            )
            base_gen_eigvals = qml.eigvals(base_gen, k=2**self.base.num_wires)

        # The projectors in the full generator add an eigenvalue of ``0`` to
        # the eigenvalues of the base generator.
        full_eigvals = np.append(base_gen_eigvals, 0)
        rounded_eigvals = tuple(np.round(full_eigvals, 8))
        return [qml.gradients.eigvals_to_frequencies(rounded_eigvals)]
# Program capture with controlled ops needs to unpack and re-pack the control wires to support
# dynamic wires. See the capture module for more information on primitives.
# ``_primitive`` is None when jax isn't installed, so nothing is registered in that case.
if Controlled._primitive is not None:  # pylint: disable=protected-access

    def _(base, *control_wires, control_values=None, work_wires=None, id=None):
        # Re-pack the unpacked control wires. ``type.__call__`` runs ``Controlled.__new__``
        # (which picks ``Controlled`` or ``ControlledOp``) and ``__init__`` directly.
        # NOTE(review): presumably this bypasses a metaclass ``__call__`` that would bind
        # the primitive again — confirm.
        return type.__call__(
            Controlled,
            base,
            control_wires,
            control_values=control_values,
            work_wires=work_wires,
            id=id,
        )

    _ = Controlled._primitive.def_impl(_)  # pylint: disable=protected-access

    # easier to just keep the same primitive for both versions
    # dispatch between the two types happens inside instance creation anyway
    ControlledOp._primitive = Controlled._primitive  # pylint: disable=protected-access
# _modules/pennylane/ops/op_math/controlled
# Download Python script
# Download Notebook
# View on GitHub