Source code for pennylane.transforms.defer_measurements

# Copyright 2018-2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for the tape transform implementing the deferred measurement principle."""

import pennylane as qml
from pennylane.measurements import CountsMP, MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP
from pennylane.ops.op_math import ctrl
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires

# pylint: disable=too-many-branches, protected-access, too-many-statements


def _check_tape_validity(tape: QuantumScript):
    """Raise if ``tape`` contains anything that ``defer_measurements`` cannot handle.

    Args:
        tape (QuantumScript): the tape to validate

    Raises:
        ValueError: if the tape contains continuous-variable operations or observables
            (identities excepted)
        ValueError: if a ``CountsMP``, ``ProbabilityMP`` or ``SampleMP`` has neither an
            observable, nor wires, nor a mid-circuit measurement value
        ValueError: if a ``StateMP`` measurement is present
        ValueError: if a ``SampleMP`` is combined with postselected mid-circuit
            measurements on a tape that has a batch size (broadcasting)
    """
    cv_types = (qml.operation.CVOperation, qml.operation.CVObservable)

    # Continuous-variable content is rejected; identities are the only CV objects tolerated.
    has_cv_operation = any(
        op.name != "Identity" for op in tape.operations if isinstance(op, cv_types)
    )
    has_cv_observable = False
    for measurement in tape.measurements:
        obs = getattr(measurement, "obs", None)
        if isinstance(obs, cv_types) and not isinstance(obs, qml.Identity):
            has_cv_observable = True
            break
    if has_cv_operation or has_cv_observable:
        raise ValueError("Continuous variable operations and observables are not supported.")

    # These measurement types must say *what* they measure.
    needs_target = (CountsMP, ProbabilityMP, SampleMP)
    for measurement in tape.measurements:
        if isinstance(measurement, needs_target) and not (
            measurement.obs or measurement._wires or measurement.mv
        ):
            raise ValueError(
                f"Cannot use {measurement.__class__.__name__} as a measurement without specifying wires "
                "when using qml.defer_measurements. Deferred measurements can occur "
                "automatically when using mid-circuit measurements on a device that does not "
                "support them."
            )

        if measurement.__class__.__name__ == "StateMP":
            raise ValueError(
                "Cannot use StateMP as a measurement when using qml.defer_measurements. "
                "Deferred measurements can occur automatically when using mid-circuit "
                "measurements on a device that does not support them."
            )

    has_postselection = any(
        mcm.postselect is not None for mcm in tape.operations if isinstance(mcm, MidMeasureMP)
    )
    has_samples = any(isinstance(measurement, SampleMP) for measurement in tape.measurements)
    if has_postselection and has_samples and tape.batch_size is not None:
        raise ValueError(
            "Returning qml.sample is not supported when postselecting mid-circuit "
            "measurements with broadcasting"
        )


def _collect_mid_measure_info(tape: QuantumScript):
    """Collect information related to the mid-circuit measurements in ``tape``.

    Every item yielded by iterating ``tape`` is visited in order. In PennyLane this is the
    operations followed by the terminal measurements.

    Args:
        tape (QuantumScript): the tape to inspect

    Returns:
        tuple[list, set, bool, bool]:

        * ``measured_wires``: the wire of every ``MidMeasureMP``, in tape order (a wire
          appears once per measurement made on it)
        * ``reused_measurement_wires``: measured wires that are reset, or that are acted on
          by a later non-``MidMeasureMP`` item of the tape
        * ``any_repeated_measurements``: whether some wire is mid-circuit measured more than
          once
        * ``is_postselecting``: whether any mid-circuit measurement uses postselection
    """
    measured_wires = []
    # Set mirror of ``measured_wires``. It keeps the membership test and the per-item
    # intersection O(1)/O(|op.wires|), instead of rebuilding a set from the growing list
    # for every item.
    measured_wire_set = set()
    reused_measurement_wires = set()
    any_repeated_measurements = False
    is_postselecting = False

    for op in tape:
        if isinstance(op, MidMeasureMP):
            wire = op.wires[0]
            if op.postselect is not None:
                is_postselecting = True
            if op.reset:
                reused_measurement_wires.add(wire)

            if wire in measured_wire_set:
                any_repeated_measurements = True
            measured_wires.append(wire)
            measured_wire_set.add(wire)

        else:
            # Any later item touching an already-measured wire means that wire is reused.
            reused_measurement_wires |= measured_wire_set.intersection(op.wires.toset())

    return measured_wires, reused_measurement_wires, any_repeated_measurements, is_postselecting


def null_postprocessing(results):
    """Unwrap the results of a batch that contains a single tape.

    Transforms that emit exactly one tape use this as their postprocessing function. The
    executed batch holds one entry, and that entry is handed back unchanged.

    Args:
        results (Sequence): the results of executing the batch of tapes

    Returns:
        The entry at index ``0`` of ``results``.
    """
    single_result = results[0]
    return single_result


def _remap_measurement_value(mv, control_wires):
    """Return a copy of the measurement value ``mv`` acting on the outcome-storing wires.

    ``qml.map_wires`` cannot be applied to ``mv`` as a whole. The same wire can map to
    different control wires when multiple mid-circuit measurements are made on it, so the
    mapping is looked up per ``MidMeasureMP`` through its ``id``.

    Args:
        mv (MeasurementValue): the measurement value to remap
        control_wires (dict): maps ``MidMeasureMP.id`` to the wire storing its outcome

    Returns:
        MeasurementValue: a new measurement value with the same processing function
    """
    new_ms = [qml.map_wires(m, {m.wires[0]: control_wires[m.id]}) for m in mv.measurements]
    return MeasurementValue(new_ms, mv.processing_fn)


@transform
def defer_measurements(
    tape: QuantumScript, reduce_postselected: bool = True, **kwargs
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Quantum function transform that substitutes operations conditioned on
    measurement outcomes to controlled operations.

    This transform uses the `deferred measurement principle
    <https://en.wikipedia.org/wiki/Deferred_Measurement_Principle>`_ and applies to
    qubit-based quantum functions.

    Support for mid-circuit measurements is device-dependent. If a device doesn't
    support mid-circuit measurements natively, then the QNode will apply this transform.

    .. note::

        The transform uses the :func:`~.ctrl` transform to implement operations
        controlled on mid-circuit measurement outcomes. The set of operations that can
        be controlled as such depends on the set of operations supported by the chosen
        device.

    .. note::

        Devices that inherit from :class:`~pennylane.QubitDevice` **must** be initialized
        with an additional wire for each mid-circuit measurement after which the measured
        wire is reused or reset for ``defer_measurements`` to transform the quantum tape
        correctly.

    .. note::

        This transform does not change the list of terminal measurements returned by
        the quantum function.

    .. note::

        When applying the transform on a quantum function that contains the
        :class:`~.Snapshot` instruction, state information corresponding to simulating
        the transformed circuit will be obtained. No post-measurement states are
        considered.

    .. warning::

        :func:`~.pennylane.state` is not supported with the ``defer_measurements``
        transform. Additionally, :func:`~.pennylane.probs`, :func:`~.pennylane.sample`
        and :func:`~.pennylane.counts` can only be used with ``defer_measurements`` if
        wires or an observable are explicitly specified.

    .. warning::

        ``defer_measurements`` does not support using custom wire labels if any measured
        wires are reused or reset, or if any wire is measured more than once.

    Args:
        tape (QNode or QuantumTape or Callable): a quantum circuit.
        reduce_postselected (bool): Whether or not to use postselection information to
            reduce the number of operations and control wires in the output tape. Active
            by default.
        **kwargs: an optional ``device`` entry is read to reject postselection on devices
            other than ``default.qubit``. It is supplied automatically when a QNode is
            transformed.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]:

        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

    Raises:
        ValueError: If custom wire labels are used with qubit reuse, reset, or repeated
            measurements of the same wire
        ValueError: If any measurements with no wires or observable are present
        ValueError: If continuous variable operations or measurements are present
        ValueError: If using the transform with any device other than
            :class:`default.qubit <~pennylane.devices.DefaultQubit>` and postselection is used

    **Example**

    Suppose we have a quantum function with mid-circuit measurements and conditional
    operations:

    .. code-block:: python3

        def qfunc(par):
            qml.RY(0.123, wires=0)
            qml.Hadamard(wires=1)
            m_0 = qml.measure(1)
            qml.cond(m_0, qml.RY)(par, wires=0)
            return qml.expval(qml.Z(0))

    The ``defer_measurements`` transform allows executing such quantum functions without
    having to perform mid-circuit measurements:

    >>> dev = qml.device('default.qubit', wires=2)
    >>> transformed_qfunc = qml.defer_measurements(qfunc)
    >>> qnode = qml.QNode(transformed_qfunc, dev)
    >>> par = np.array(np.pi/2, requires_grad=True)
    >>> qnode(par)
    tensor(0.43487747, requires_grad=True)

    We can also differentiate parameters passed to conditional operations:

    >>> qml.grad(qnode)(par)
    tensor(-0.49622252, requires_grad=True)

    Reusing and resetting measured wires will work as expected with the
    ``defer_measurements`` transform:

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func(x, y):
            qml.RY(x, wires=0)
            qml.CNOT(wires=[0, 1])
            m_0 = qml.measure(1, reset=True)

            qml.cond(m_0, qml.RY)(y, wires=0)
            qml.RX(np.pi/4, wires=1)
            return qml.probs(wires=[0, 1])

    Executing this QNode:

    >>> pars = np.array([0.643, 0.246], requires_grad=True)
    >>> func(*pars)
    tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True)

    .. details::
        :title: Usage Details

        By default, ``defer_measurements`` makes use of postselection information of
        mid-circuit measurements in the circuit in order to reduce the number of
        controlled operations and control wires. We can explicitly switch this feature
        off and compare the created circuits with and without this optimization.
        Consider the following circuit:

        .. code-block:: python3

            @qml.qnode(qml.device("default.qubit"))
            def node(x):
                qml.RX(x, 0)
                qml.RX(x, 1)
                qml.RX(x, 2)

                mcm0 = qml.measure(0, postselect=0, reset=False)
                mcm1 = qml.measure(1, postselect=None, reset=True)
                mcm2 = qml.measure(2, postselect=1, reset=False)
                qml.cond(mcm0+mcm1+mcm2==1, qml.RX)(0.5, 3)
                return qml.expval(qml.Z(0) @ qml.Z(3))

        Without the optimization, we find three gates controlled on the three measured
        qubits. They correspond to the combinations of controls that satisfy the
        condition ``mcm0+mcm1+mcm2==1``.

        >>> print(qml.draw(qml.defer_measurements(node, reduce_postselected=False))(0.6))
        0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────────────────────────────────┤ ╭<Z@Z>
        1: ──RX(0.60)─────────│──╭●─╭X───────────────────────────────────────┤ │
        2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|─╭○────────╭○────────╭●────────┤ │
        3: ───────────────────│──│──│──────────├RX(0.50)─├RX(0.50)─├RX(0.50)─┤ ╰<Z@Z>
        4: ───────────────────╰X─│──│──────────├○────────├●────────├○────────┤
        5: ──────────────────────╰X─╰●─────────╰●────────╰○────────╰○────────┤

        If we do not explicitly deactivate the optimization, we obtain a much simpler
        circuit:

        >>> print(qml.draw(qml.defer_measurements(node))(0.6))
        0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────┤ ╭<Z@Z>
        1: ──RX(0.60)─────────│──╭●─╭X───────────┤ │
        2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|───┤ │
        3: ───────────────────│──│──│──╭RX(0.50)─┤ ╰<Z@Z>
        4: ───────────────────╰X─│──│──│─────────┤
        5: ──────────────────────╰X─╰●─╰○────────┤

        There is only one controlled gate with only one control wire.
    """
    if not any(isinstance(o, MidMeasureMP) for o in tape.operations):
        return (tape,), null_postprocessing

    _check_tape_validity(tape)

    device = kwargs.get("device", None)

    # Find wires that are reused after measurement
    (
        measured_wires,
        reused_measurement_wires,
        any_repeated_measurements,
        is_postselecting,
    ) = _collect_mid_measure_info(tape)

    if is_postselecting and device is not None and not isinstance(device, qml.devices.DefaultQubit):
        raise ValueError(f"Postselection is not supported on the {device} device.")

    # An outcome is copied to a fresh wire labelled ``max(tape.wires) + 1, + 2, ...`` when
    # its wire is reused/reset or measured again later. That labelling only works for
    # integer wires. Repeated measurements need fresh wires even when no wire counts as
    # "reused", so both conditions must be validated here. Otherwise ``max(...) + 1`` below
    # fails on custom labels with an opaque TypeError.
    needs_extra_wires = bool(reused_measurement_wires) or any_repeated_measurements
    if needs_extra_wires and not all(isinstance(w, int) for w in tape.wires):
        raise ValueError(
            "qml.defer_measurements does not support custom wire labels with qubit reuse/reset."
        )

    # Apply controlled operations to store measurement outcomes and replace
    # classically controlled operations
    control_wires = {}
    cur_wire = max(tape.wires) + 1 if needs_extra_wires else None

    new_operations = []
    for op in tape.operations:
        if isinstance(op, MidMeasureMP):
            # Drop this measurement so ``measured_wires`` lists only those still to come.
            measured_wires.pop(0)

            if op.postselect is not None:
                with QueuingManager.stop_recording():
                    new_operations.append(qml.Projector([op.postselect], wires=op.wires[0]))

            # Store measurement outcome in a new wire if the wire gets reused or is
            # measured again later
            if op.wires[0] in reused_measurement_wires or op.wires[0] in measured_wires:
                control_wires[op.id] = cur_wire

                with QueuingManager.stop_recording():
                    new_operations.append(qml.CNOT([op.wires[0], cur_wire]))

                if op.reset:
                    with QueuingManager.stop_recording():
                        # No need to manually reset if postselecting on |0>
                        if op.postselect is None:
                            new_operations.append(qml.CNOT([cur_wire, op.wires[0]]))
                        elif op.postselect == 1:
                            # We know that the measured wire will be in the |1> state if
                            # postselected |1>. So we can just apply a PauliX instead of
                            # a CNOT to reset
                            new_operations.append(qml.X(op.wires[0]))

                cur_wire += 1
            else:
                control_wires[op.id] = op.wires[0]

        elif op.__class__.__name__ == "Conditional":
            with QueuingManager.stop_recording():
                new_operations.extend(_add_control_gate(op, control_wires, reduce_postselected))
        else:
            new_operations.append(op)

    new_measurements = []
    for mp in tape.measurements:
        if mp.mv is None:
            new_measurements.append(mp)
            continue

        # ``mp.mv`` is either a single MeasurementValue or an iterable of them.
        if isinstance(mp.mv, MeasurementValue):
            new_m = _remap_measurement_value(mp.mv, control_wires)
        else:
            new_m = [_remap_measurement_value(val, control_wires) for val in mp.mv]

        with QueuingManager.stop_recording():
            new_mp = (
                CountsMP(obs=new_m, all_outcomes=mp.all_outcomes)
                if isinstance(mp, CountsMP)
                else type(mp)(obs=new_m)
            )
        new_measurements.append(new_mp)

    new_tape = type(tape)(new_operations, new_measurements, shots=tape.shots)

    if is_postselecting and new_tape.batch_size is not None:
        # Split tapes if broadcasting with postselection
        return qml.transforms.broadcast_expand(new_tape)

    return [new_tape], null_postprocessing
@defer_measurements.custom_qnode_transform
def _defer_measurements_qnode(self, qnode, targs, tkwargs):
    """Custom QNode transform for ``defer_measurements``.

    Fills the ``device`` transform keyword with the QNode's own device before delegating
    to the default QNode transform. A device passed explicitly by the user is rejected.

    Raises:
        ValueError: if ``tkwargs`` already holds a truthy ``"device"`` entry
    """
    explicit_device = tkwargs.get("device", None)
    if explicit_device:
        raise ValueError(
            "Cannot provide a 'device' value directly to the defer_measurements decorator "
            "when transforming a QNode."
        )

    tkwargs.setdefault("device", qnode.device)
    return self.default_qnode_transform(qnode, targs, tkwargs)


def _add_control_gate(op, control_wires, reduce_postselected):
    """Build the operations that replace a classically conditioned operation.

    Args:
        op: the ``Conditional`` operation. ``op.base`` is the conditioned operation and
            ``op.meas_val`` the measurement value it depends on.
        control_wires (dict): maps the ``id`` of each ``MidMeasureMP`` to the wire that
            holds its outcome
        reduce_postselected (bool): if ``True``, postselected measurements are not used as
            controls and the branches come from ``op.meas_val._postselected_items()``.
            Otherwise every measurement is a control and ``op.meas_val._items()`` is used.

    Returns:
        list: one operation per branch with a truthy value. The empty branch ``()`` yields
        ``op.base`` itself; every other branch yields ``op.base`` controlled on the
        control wires, with the branch as control values.
    """
    meas_val = op.meas_val
    if reduce_postselected:
        # Postselected measurements have a fixed outcome, so they need no control wire.
        control = [control_wires[m.id] for m in meas_val.measurements if m.postselect is None]
        items = meas_val._postselected_items()
    else:
        control = [control_wires[m.id] for m in meas_val.measurements]
        items = meas_val._items()

    new_ops = []
    for branch, value in items:
        # Branches with a falsy value are skipped. Empty sampling branches can occur when
        # using ``_postselected_items``.
        if not value:
            continue
        if branch == ():
            new_ops.append(op.base)
        else:
            new_ops.append(ctrl(op.base, control=Wires(control), control_values=branch))
    return new_ops