# Source code for catalyst.api_extensions.quantum_operators
# Copyright 2022-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains public API functions for quantum operators which are not
included in PennyLane, or whose behaviour needs to be adapted for Catalyst.
"""
import copy
from collections.abc import Sized
from typing import Any, Callable, List, Optional, Union
import jax
import pennylane as qml
from jax._src.tree_util import tree_flatten
from jax.core import get_aval
from pennylane import QueuingManager
from pennylane.operation import Operator
from pennylane.ops.op_math.controlled import create_controlled_op
from pennylane.tape import QuantumTape
from catalyst.api_extensions.control_flow import cond
from catalyst.jax_extras import (
ClosedJaxpr,
DynamicJaxprTracer,
_input_type_to_tracers,
convert_constvars_jaxpr,
deduce_avals,
new_inner_tracer,
)
from catalyst.jax_primitives import AbstractQreg, adjoint_p, qmeasure_p
from catalyst.jax_tracer import (
HybridOp,
HybridOpRegion,
QRegPromise,
has_nested_tapes,
trace_quantum_tape,
)
from catalyst.tracing.contexts import EvaluationContext
## API ##
def measure(
    wires, reset: Optional[bool] = False, postselect: Optional[int] = None
) -> DynamicJaxprTracer:
    r"""Measure a single qubit in the middle of a circuit, in a :func:`qjit`-compatible way.

    .. important::

        :func:`qml.measure() <pennylane.measure>` cannot be used under QJIT; use
        :func:`catalyst.measure` instead.

    Args:
        wires (int): The wire on which the projective measurement is performed. A one-element
            list/tuple, or a JAX array of shape ``()`` or ``(1,)``, is also accepted.
        reset (Optional[bool]): If ``True``, the wire is returned to the :math:`|0\rangle`
            state after it has been measured.
        postselect (Optional[int]): The basis state (``0`` or ``1``) to postselect on after
            the mid-circuit measurement.

    Returns:
        A JAX tracer holding the mid-circuit measurement outcome.

    Raises:
        ValueError: Called outside the tape context.
        TypeError: More than one wire is given, a wire array has an unsupported shape, or
            ``postselect`` is neither ``None``, ``0`` nor ``1``.

    **Example**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=2)

        @qjit
        @qml.qnode(dev)
        def circuit(x: float):
            qml.RX(x, wires=0)
            m1 = measure(wires=0)

            qml.RX(m1 * jnp.pi, wires=1)
            m2 = measure(wires=1)

            qml.RZ(m2 * jnp.pi / 2, wires=0)
            return qml.expval(qml.PauliZ(0)), m2

    >>> circuit(0.43)
    [array(1.), array(False)]
    >>> circuit(0.43)
    [array(-1.), array(True)]

    **Example with post-selection**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qjit
        @qml.qnode(dev)
        def circuit():
            qml.Hadamard(0)
            m = measure(0, postselect=1)
            return qml.expval(qml.PauliZ(0))

    >>> circuit()
    -1.0

    **Example with reset**

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qjit
        @qml.qnode(dev)
        def circuit():
            qml.Hadamard(0)
            m = measure(0, reset=True)
            return qml.expval(qml.PauliZ(0))

    >>> circuit()
    1.0
    """
    EvaluationContext.check_is_tracing("catalyst.measure can only be used from within @qjit.")
    EvaluationContext.check_is_quantum_tracing(
        "catalyst.measure can only be used from within a qml.qnode."
    )
    tracing_ctx = EvaluationContext.get_main_tracing_context()

    # Treat a bare wire and a one-element sequence the same way.
    if isinstance(wires, (list, tuple)):
        wires = list(wires)
    else:
        wires = [wires]

    if len(wires) != 1:
        raise TypeError(f"Only one element is supported for the 'wires' parameter, got {wires}.")

    target = wires[0]
    if isinstance(target, jax.Array) and target.shape not in ((), (1,)):
        raise TypeError(
            f"Measure is only supported on 1 qubit, got array of shape {target.shape}."
        )

    if postselect is not None and postselect not in [0, 1]:
        raise TypeError(f"postselect must be '0' or '1', got {postselect}")

    # Classical operands are [wire, postselect]. Build a new list so ``wires`` stays untouched.
    operands = [*wires, postselect]

    outcome = new_inner_tracer(tracing_ctx.trace, get_aval(True))
    MidCircuitMeasure(
        in_classical_tracers=operands,
        out_classical_tracers=[outcome],
        regions=[],
    )

    if not reset:
        return outcome

    # Return the qubit to |0>: flip it only when the measured value was 1.
    @cond(outcome)
    def _flip_back():
        qml.PauliX(wires=wires)

    _flip_back()
    return outcome
def adjoint(f: Union[Callable, Operator]) -> Union[Callable, Operator]:
    """Return the adjoint of a quantum function or operator, in a :func:`~.qjit`-compatible way.

    Outside of a ``qjit`` trace this defers to ``qml.adjoint``. While tracing, the operations
    produced by ``f`` are recorded on a nested tape and wrapped in a Catalyst ``Adjoint``
    hybrid operation.

    .. warning::

        Taking the adjoint of a quantum function that performs mid-circuit measurements
        is not supported.

    Args:
        f (Callable or Operator): A PennyLane operation, or a Python function that applies
            PennyLane quantum operations.

    Returns:
        For an Operator, an Operator representing its adjoint. For a function, a new function
        with the same call signature that applies the adjoint of the provided function.

    Raises:
        ValueError: invalid parameter values

    **Example 1 (basic usage)**

    .. code-block:: python

        @qjit
        @qml.qnode(qml.device("lightning.qubit", wires=1))
        def workflow(theta, wires):
            catalyst.adjoint(qml.RZ)(theta, wires=wires)
            catalyst.adjoint(qml.RZ(theta, wires=wires))

            def func():
                qml.RX(theta, wires=wires)
                qml.RY(theta, wires=wires)

            catalyst.adjoint(func)()
            return qml.probs()

    >>> workflow(jnp.pi/2, wires=0)
    array([0.5, 0.5])

    **Example 2 (with Catalyst control flow)**

    .. code-block:: python

        @qjit
        @qml.qnode(qml.device("lightning.qubit", wires=1))
        def workflow(theta, n, wires):
            def func():
                @catalyst.for_loop(0, n, 1)
                def loop_fn(i):
                    qml.RX(theta, wires=wires)

                loop_fn()

            catalyst.adjoint(func)()
            return qml.probs()

    >>> workflow(jnp.pi/2, 3, 0)
    [1.00000000e+00 7.39557099e-32]
    """
    if not EvaluationContext.is_tracing():
        return qml.adjoint(f)

    def _call_handler(*args, _callee: Callable, **kwargs):
        EvaluationContext.check_is_quantum_tracing(
            "catalyst.adjoint can only be used from within a qml.qnode."
        )
        main_ctx = EvaluationContext.get_main_tracing_context()
        # Trace the callee inside a dedicated frame so its body becomes a separate region.
        with EvaluationContext.frame_tracing_context(main_ctx) as body_trace:
            in_classical_tracers, _ = tree_flatten((args, kwargs))
            wrapped_fn, in_avals, _, _ = deduce_avals(_callee, args, kwargs)
            body_args = _input_type_to_tracers(body_trace.new_arg, in_avals)
            body_tape = QuantumTape()
            with QueuingManager.stop_recording(), body_tape:
                # FIXME: move all full_raise calls into a separate function
                body_results = [
                    body_trace.full_raise(value)
                    for value in wrapped_fn.call_wrapped(*body_args)
                    if isinstance(value, DynamicJaxprTracer)
                ]
            _check_no_measurements(body_tape)
            body_region = HybridOpRegion(body_trace, body_tape, body_args, body_results)
        return Adjoint(
            in_classical_tracers=in_classical_tracers,
            out_classical_tracers=[],
            regions=[body_region],
        )

    if isinstance(f, Callable):

        def _callable(*args, **kwargs):
            return _call_handler(*args, _callee=f, **kwargs)

        return _callable

    if isinstance(f, Operator):
        # Re-queue the operator inside the adjoint region instead of the surrounding tape.
        QueuingManager.remove(f)

        def _callee():
            QueuingManager.append(f)

        return _call_handler(_callee=_callee)

    raise ValueError(f"Expected a callable or a qml.Operator, not {f}")
def ctrl(
    f: Union[Callable, Operator],
    control: List[Any],
    control_values: Optional[List[Any]] = None,
    work_wires: Optional[List[Any]] = None,
) -> Callable:
    """Create a method that applies a controlled version of the provided op. This function is the
    Catalyst version of the ``qml.ctrl`` that supports Catalyst hybrid operations such as loops and
    conditionals.

    Outside of a ``qjit`` trace this simply defers to ``qml.ctrl``.

    Args:
        f (Callable or Operator): A PennyLane operation or a Python function
            containing PennyLane quantum operations.
        control (Wires): The control wire(s).
        control_values (List[bool], optional): The value(s) the control wire(s) should take.
            Integers other than 0 or 1 will be treated as ``int(bool(x))``.
        work_wires (Any): Any auxiliary wires that can be used in the decomposition

    Returns:
        (function or :class:`~.operation.Operator`): If an Operator is provided, returns a
        Controlled version of the Operator. If a function is provided, returns a function with the
        same call signature that creates a controlled version of the provided function.

    Raises:
        ValueError: invalid parameter values (including ``control_values`` whose length differs
            from that of ``control``), or measurements are among the controlled operations.

    **Example**

    .. code-block:: python

        @qjit
        @qml.qnode(qml.device("lightning.qubit", wires=2))
        def workflow(theta, w, cw):
            qml.Hadamard(wires=[0])
            qml.Hadamard(wires=[1])

            def func(arg):
                qml.RX(theta, wires=arg)

            @cond(theta > 0.0)
            def cond_fn():
                qml.RY(theta, wires=w)

            catalyst.ctrl(func, control=[cw])(w)
            catalyst.ctrl(cond_fn, control=[cw])()
            catalyst.ctrl(qml.RZ, control=[cw])(theta, wires=w)
            catalyst.ctrl(qml.RY(theta, wires=w), control=[cw])
            return qml.probs()

    >>> workflow(jnp.pi/4, 1, 0)
    array([0.25, 0.25, 0.03661165, 0.46338835])
    """
    if not EvaluationContext.is_tracing():
        return qml.ctrl(f, control, control_values, work_wires)

    if control_values is not None:
        # A scalar control wire or control value counts as a single element. Compute both
        # lengths once so the error message never calls len() on a scalar argument.
        num_control = len(control) if isinstance(control, Sized) else 1
        num_values = len(control_values) if isinstance(control_values, Sized) else 1
        if num_control != num_values:
            raise ValueError(
                f"Length of the control_values ({num_values}) must be None or equal "
                f"to the length of control ({num_control})"
            )

    def _call_handler(*args, _callee: Callable, **kwargs):
        EvaluationContext.check_is_quantum_tracing(
            "catalyst.ctrl can only be used from within a qml.qnode."
        )
        in_classical_tracers, _ = tree_flatten((args, kwargs))
        # Record the callee's operations on a nested tape instead of the surrounding one.
        quantum_tape = QuantumTape()
        with QueuingManager.stop_recording(), quantum_tape:
            res = _callee(*args, **kwargs)
        out_classical_tracers, _ = tree_flatten(res)

        _check_no_measurements(quantum_tape)

        region = HybridOpRegion(None, quantum_tape, [], [])

        # Return the operation instance since PL expects this for qml.ctrl(op).
        return QCtrl(
            control_wires=control,
            control_values=control_values,
            work_wires=work_wires,
            in_classical_tracers=in_classical_tracers,
            out_classical_tracers=out_classical_tracers,
            regions=[region],
        )

    if isinstance(f, Callable):

        def _callable(*args, **kwargs):
            return _call_handler(*args, _callee=f, **kwargs)

        return _callable

    elif isinstance(f, Operator):
        # Re-queue the operator inside the controlled region instead of the surrounding tape.
        QueuingManager.remove(f)

        def _callee():
            QueuingManager.append(f)

        return _call_handler(_callee=_callee)

    else:
        raise ValueError(f"Expected a callable or a qml.Operator, not {f}")  # pragma: no cover
## IMPL ##
class MidCircuitMeasure(HybridOp):
    """Hybrid operation recording a single-qubit mid-circuit measurement.

    ``in_classical_tracers[0]`` is the measured wire and ``in_classical_tracers[1]`` is the
    postselect value (possibly ``None``).
    """

    binder = qmeasure_p.bind

    def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
        """Bind the measurement on the extracted qubit and store the resulting qubit back."""
        wire = self.in_classical_tracers[0]
        postselect = self.in_classical_tracers[1]
        measured_qubit = qrp.extract([wire])[0]
        updated_qubit = self.bind_overwrite_classical_tracers(
            ctx, trace, measured_qubit, postselect=postselect
        )
        qrp.insert([wire], [updated_qubit])
        return qrp
class Adjoint(HybridOp):
    """Catalyst hybrid operation for PennyLane's adjoint, holding a single body region."""

    binder = adjoint_p.bind

    def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
        """Trace the body region into a jaxpr and bind it through ``adjoint_p``.

        Returns a new ``QRegPromise`` wrapping the quantum register produced by the primitive.
        """
        region = self.regions[0]
        body_trace = region.trace
        body_tape = region.quantum_tape
        body_results = region.res_classical_tracers

        # Trace the body in its own frame, threading a fresh quantum-register argument through.
        with EvaluationContext.frame_tracing_context(ctx, body_trace):
            qreg_arg = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg()])[0]
            body_qrp = trace_quantum_tape(body_tape, device, qreg_arg, ctx, body_trace)
            qreg_result = body_qrp.actualize()
            body_jaxpr, _, body_consts = ctx.frames[body_trace].to_jaxpr2(
                body_results + [qreg_result]
            )

        outer_qreg = qrp.actualize()
        flat_args, args_tree = tree_flatten((body_consts, self.in_classical_tracers, [outer_qreg]))
        closed_body = ClosedJaxpr(convert_constvars_jaxpr(body_jaxpr), ())
        results = adjoint_p.bind(*flat_args, args_tree=args_tree, jaxpr=closed_body)
        # The final result is the updated quantum register.
        return QRegPromise(results[-1])

    @property
    def wires(self):
        """The list of all static wires."""
        assert len(self.regions) == 1, "Adjoint is expected to have one region"
        body_ops = self.regions[0].quantum_tape.operations
        return sum((op.wires for op in body_ops), [])
# TODO: This class needs to be made interoperable with qml.Controlled since qml.ctrl dispatches
# to this class whenever a qjit context is active.
class QCtrl(HybridOp):
    """Catalyst quantum ctrl operation.

    Holds one region whose quantum tape contains the operations to control, together with the
    control wires, the control values and optional work wires.
    """

    def __init__(self, *args, control_wires, control_values=None, work_wires=None, **kwargs):
        self._control_wires = qml.wires.Wires(control_wires)
        self._work_wires = qml.wires.Wires([] if work_wires is None else work_wires)
        if control_values is None:
            # Default: condition on True for every control wire.
            self._control_values = [True] * len(self._control_wires)
        elif isinstance(control_values, (int, bool)):
            # Promote a scalar control value to a one-element list.
            self._control_values = [control_values]
        else:
            self._control_values = control_values
        super().__init__(*args, **kwargs)

    def trace_quantum(self, ctx, device, trace, qrp) -> QRegPromise:
        raise NotImplementedError("QCtrl does not support JAX quantum tracing")  # pragma: no cover

    def decomposition(self):
        """Compute quantum decomposition of the gate by recursively scanning the nested tape and
        distributing the quantum control operation over the tape operations.

        NOTE(review): ``qctrl_distribute`` notes that it modifies operations of the tape it is
        given, so calling this more than once may not be idempotent. Confirm before reusing.
        """
        assert len(self.regions) == 1, "Qctrl is expected to have one region"

        _check_no_measurements(self.regions[0].quantum_tape)
        new_tape = qctrl_distribute(
            self.regions[0].quantum_tape,
            self._control_wires,
            self._control_values,
            self._work_wires,
        )
        return new_tape.operations

    @property
    def wires(self):
        """The list of all control-wires, work-wires, and active-wires."""
        assert len(self.regions) == 1, "Qctrl is expected to have one region"

        total_wires = sum(
            (op.wires for op in self.regions[0].quantum_tape.operations),
            self._control_wires,
        )
        # Non-mutating addition on purpose. With an empty tape, ``sum`` returns the
        # ``self._control_wires`` object itself, and an in-place ``+=`` on a list would append
        # the work wires to the stored control wires.
        return total_wires + self._work_wires

    @property
    def control_wires(self):
        """Wires used in quantum conditioning."""
        return self._control_wires

    @property
    def control_values(self):
        """(Boolean) Values upon which to condition on."""
        return self._control_values

    @property
    def work_wires(self):
        """Optional wires that can be used in the expansion of this op."""
        return self._work_wires

    def map_wires(self, wire_map):
        """Map wires to new wires according to ``wire_map``.

        The operation is updated in place and ``self`` is returned.
        """
        region = self.regions[0]
        new_ops = [op.map_wires(wire_map) for op in region.quantum_tape.operations]
        region.quantum_tape = QuantumTape(new_ops, [])
        # Re-wrap in ``Wires`` so these attributes keep the type established in ``__init__``.
        self._control_wires = qml.wires.Wires(
            [wire_map.get(wire, wire) for wire in self._control_wires]
        )
        self._work_wires = qml.wires.Wires([wire_map.get(wire, wire) for wire in self._work_wires])
        return self
def qctrl_distribute(
    tape: QuantumTape,
    control_wires: List[Any],
    control_values: List[Any],
    work_wires: Optional[List[Any]] = None,
) -> QuantumTape:
    """Distribute the quantum control operation, described by ``control_wires`` and
    ``control_values``, over all the operations on the nested quantum tape.

    Plain operations are wrapped with ``create_controlled_op``. A nested ``QCtrl`` is flattened:
    its body is distributed with the outer and inner control wires/values (and work wires)
    concatenated. Other operations with nested tapes keep their identity, but each of their
    regions is rewritten recursively.

    Args:
        tape (QuantumTape): tape whose operations are to be controlled.
        control_wires (List[Any]): non-empty list of control wires.
        control_values (List[Any]): control values, one per control wire.
        work_wires (Optional[List[Any]]): auxiliary wires passed to the controlled operations.

    Returns:
        QuantumTape: a new tape with the controlled operations and the original measurements.
    """
    # Note: The transformation modifies operations in the source quantum tape, so we must not use it
    # after we called this function.
    assert len(control_wires) > 0, "This transformation expects a non-empty list of control_wires"
    assert len(control_wires) == len(control_values), (
        f"Length of the control_values ({len(control_values)}) must be equal "
        f"to the length of control_wires ({len(control_wires)})"
    )

    ctx = EvaluationContext.get_main_tracing_context()
    ops2 = []
    for op in tape.operations:
        if has_nested_tapes(op):
            if isinstance(op, QCtrl):
                # ``work_wires`` may be None (its default), which cannot be concatenated with
                # the inner operation's work wires. Fall back to an empty list.
                outer_work_wires = [] if work_wires is None else work_wires
                for region in [region for region in op.regions if region.quantum_tape is not None]:
                    tape2 = qctrl_distribute(
                        region.quantum_tape,
                        control_wires + op.control_wires,
                        control_values + op.control_values,
                        outer_work_wires + op.work_wires,
                    )
                    ops2.extend(tape2.operations)
            else:
                for region in [region for region in op.regions if region.quantum_tape is not None]:
                    # Rewrite the region inside its own tracing frame.
                    with EvaluationContext.frame_tracing_context(ctx, region.trace):
                        region.quantum_tape = qctrl_distribute(
                            region.quantum_tape, control_wires, control_values, work_wires
                        )
                ops2.append(op)
        else:
            ops2.append(
                create_controlled_op(
                    copy.copy(op),
                    control=control_wires,
                    control_values=control_values,
                    work_wires=work_wires,
                )
            )
    return QuantumTape(ops2, tape.measurements)
## PRIVATE ##
def _check_no_measurements(tape: QuantumTape) -> None:
    """Ensure the tape, and every tape nested in its operations, is free of quantum measurements.

    Both terminal measurements on a tape and ``MidCircuitMeasure`` operations are rejected.

    Raises:
        ValueError: a measurement of any kind was found.
    """
    msg = "Quantum measurements are not allowed"
    if tape.measurements:
        raise ValueError(msg)

    for op in tape.operations:
        if not has_nested_tapes(op):
            if isinstance(op, MidCircuitMeasure):
                raise ValueError(msg)
            continue
        # Recurse into every region that actually carries a quantum tape.
        for region in op.regions:
            if region.quantum_tape is not None:
                _check_no_measurements(region.quantum_tape)
# _modules/catalyst/api_extensions/quantum_operators
# Download Python script
# Download Notebook
# View on GitHub