qml.defer_measurements

defer_measurements(tape, reduce_postselected=True, allow_postselect=True, num_wires=None)

Quantum function transform that replaces operations conditioned on measurement outcomes with controlled operations.

This transform uses the deferred measurement principle and applies to qubit-based quantum functions.
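As a plain-NumPy illustration of the principle (independent of PennyLane; the helper names and random seed are arbitrary choices for this sketch), the code below runs the same two-qubit circuit in two ways. In the first, wire 0 is measured mid-circuit and a unitary U is applied to wire 1 only in the branch where the outcome is 1. In the second, a controlled-U is applied and both wires are measured only at the end. The joint outcome distributions agree.

```python
import numpy as np

rng = np.random.default_rng(seed=7)

def random_state(n_qubits):
    """Random normalized state vector (wire 0 is the most significant bit)."""
    vec = rng.normal(size=2**n_qubits) + 1j * rng.normal(size=2**n_qubits)
    return vec / np.linalg.norm(vec)

def random_unitary():
    """Random 2x2 unitary from the QR decomposition of a complex Gaussian matrix."""
    z = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
    q, _ = np.linalg.qr(z)
    return q

psi = random_state(2)  # wire 0: measured control, wire 1: target
U = random_unitary()

# (a) Mid-circuit measurement: branch on the outcome m of wire 0 and
#     apply U to the target only in the m = 1 branch.
amps = psi.reshape(2, 2)  # amps[m, t]
joint_mcm = np.zeros((2, 2))
for m in (0, 1):
    branch = amps[m]  # unnormalized target state given outcome m
    if m == 1:
        branch = U @ branch
    joint_mcm[m] = np.abs(branch) ** 2

# (b) Deferred measurement: controlled-U, then measure both wires at the end.
P0 = np.diag([1, 0])
P1 = np.diag([0, 1])
CU = np.kron(P0, np.eye(2)) + np.kron(P1, U)
joint_deferred = (np.abs(CU @ psi) ** 2).reshape(2, 2)

print(np.allclose(joint_mcm, joint_deferred))  # True
```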

Support for mid-circuit measurements is device-dependent. If a device doesn’t support mid-circuit measurements natively, then the QNode will apply this transform.

Note

The transform uses the ctrl() transform to implement operations controlled on mid-circuit measurement outcomes. The set of operations that can be controlled as such depends on the set of operations supported by the chosen device.

Note

For defer_measurements to transform the quantum tape correctly, devices that inherit from QubitDevice must be initialized with one additional wire for each mid-circuit measurement after which the measured wire is reused or reset.

Note

This transform does not change the list of terminal measurements returned by the quantum function.

Note

When applying the transform on a quantum function that contains the Snapshot instruction, state information corresponding to simulating the transformed circuit will be obtained. No post-measurement states are considered.

Warning

state() is not supported with the defer_measurements transform. Additionally, probs(), sample() and counts() can only be used with defer_measurements if wires or an observable are explicitly specified.

Warning

defer_measurements does not support using custom wire labels if any measured wires are reused or reset.

Parameters
  • tape (QNode or QuantumTape or Callable) – a quantum circuit.

  • reduce_postselected (bool) – Whether to use postselection information to reduce the number of operations and control wires in the output tape. Active by default. This is currently ignored if program capture is enabled.

  • allow_postselect (bool) – Whether postselection is allowed. In order to perform postselection with defer_measurements, the device must support the Projector operation. Defaults to True. This is currently ignored if program capture is enabled.

  • num_wires (int) – Optional argument to specify the total number of circuit wires. This is only used if program capture is enabled.
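Conceptually, postselecting a measured wire on an outcome v amounts to applying the projector |v⟩⟨v| to the state and renormalizing; the |0⟩⟨0| and |1⟩⟨1| boxes in the circuit drawings further down this page are such projectors. The NumPy sketch below shows only that linear-algebra step for the single-qubit state RX(0.6)|0⟩. It is an illustration, not PennyLane's implementation.

```python
import numpy as np

theta = 0.6
# RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>
state = np.array([np.cos(theta / 2), -1j * np.sin(theta / 2)])

postselect = 0
projector = np.zeros((2, 2))
projector[postselect, postselect] = 1.0  # |0><0|

projected = projector @ state
p_keep = np.vdot(projected, projected).real  # probability that postselection succeeds
post_state = projected / np.sqrt(p_keep)     # renormalized post-measurement state

print(p_keep)  # cos(0.3)**2, about 0.9127
```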

Returns

The transformed circuit as described in qml.transform.

Return type

qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]

Raises
  • ValueError – If custom wire labels are used with qubit reuse or reset

  • ValueError – If any measurements with no wires or observable are present

  • ValueError – If continuous variable operations or measurements are present

  • ValueError – If postselection is used with any device other than default.qubit

Example

Suppose we have a quantum function with mid-circuit measurements and conditional operations:

import pennylane as qml
from pennylane import numpy as np

def qfunc(par):
    qml.RY(0.123, wires=0)
    qml.Hadamard(wires=1)
    m_0 = qml.measure(1)
    qml.cond(m_0, qml.RY)(par, wires=0)
    return qml.expval(qml.Z(0))

The defer_measurements transform allows executing such quantum functions without having to perform mid-circuit measurements:

>>> dev = qml.device('default.qubit', wires=2)
>>> transformed_qfunc = qml.defer_measurements(qfunc)
>>> qnode = qml.QNode(transformed_qfunc, dev)
>>> par = np.array(np.pi/2, requires_grad=True)
>>> qnode(par)
tensor(0.43487747, requires_grad=True)

We can also differentiate parameters passed to conditional operations:

>>> qml.grad(qnode)(par)
tensor(-0.49622252, requires_grad=True)
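Both numbers can be checked by hand. The Hadamard on wire 1 makes m_0 = 0 and m_0 = 1 equally likely, and because RY rotations about the same axis add, wire 0 ends in RY(0.123)|0⟩ or RY(0.123 + par)|0⟩. Hence E(par) = [cos(0.123) + cos(0.123 + par)] / 2 and dE/dpar = -sin(0.123 + par) / 2. A quick check using only Python's math module:

```python
import math

def expval(par):
    # <Z> on wire 0, averaged over the two equally likely outcomes of m_0
    return 0.5 * (math.cos(0.123) + math.cos(0.123 + par))

def grad(par):
    # analytic derivative of expval with respect to par
    return -0.5 * math.sin(0.123 + par)

par = math.pi / 2
print(f"{expval(par):.8f}")  # 0.43487747
print(f"{grad(par):.8f}")    # -0.49622252
```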

Reusing and resetting measured wires will work as expected with the defer_measurements transform:

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def func(x, y):
    qml.RY(x, wires=0)
    qml.CNOT(wires=[0, 1])
    m_0 = qml.measure(1, reset=True)

    qml.cond(m_0, qml.RY)(y, wires=0)
    qml.RX(np.pi/4, wires=1)
    return qml.probs(wires=[0, 1])

Executing this QNode:

>>> pars = np.array([0.643, 0.246], requires_grad=True)
>>> func(*pars)
tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True)
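The extra device wire is what makes the reset possible. The NumPy state-vector sketch below defers the measurement-with-reset in func under an assumption suggested by the circuit drawings later on this page: the outcome of wire 1 is copied onto an auxiliary wire (wire 2 here, matching the three-wire device) with a CNOT, and wire 1 is then reset by a CNOT controlled on that auxiliary wire. The helper functions are hypothetical and only reproduce the probabilities printed above.

```python
import numpy as np

N = 3  # wires 0 and 1 from the circuit, plus auxiliary wire 2 for the measurement

def ry(t):
    return np.array([[np.cos(t / 2), -np.sin(t / 2)], [np.sin(t / 2), np.cos(t / 2)]])

def rx(t):
    return np.array([[np.cos(t / 2), -1j * np.sin(t / 2)], [-1j * np.sin(t / 2), np.cos(t / 2)]])

X = np.array([[0, 1], [1, 0]])

def apply(state, gate, target, control=None):
    """Apply a 1-qubit gate to `target`, optionally only where `control` is |1>."""
    psi = state.reshape([2] * N).copy()
    idx = [slice(None)] * N
    if control is not None:
        idx[control] = 1  # restrict to the control = 1 subspace
    sub = psi[tuple(idx)]
    # position of the target axis inside the (possibly reduced) sub-array
    axis = target if control is None or target < control else target - 1
    sub = np.moveaxis(np.tensordot(gate, np.moveaxis(sub, axis, 0), axes=1), 0, axis)
    psi[tuple(idx)] = sub
    return psi.reshape(-1)

def deferred_func(x, y):
    state = np.zeros(2**N, dtype=complex)
    state[0] = 1.0
    state = apply(state, ry(x), target=0)
    state = apply(state, X, target=1, control=0)      # CNOT(0, 1)
    state = apply(state, X, target=2, control=1)      # copy the outcome of wire 1 onto wire 2
    state = apply(state, X, target=1, control=2)      # reset wire 1 (reset=True)
    state = apply(state, ry(y), target=0, control=2)  # qml.cond(m_0, qml.RY)(y, wires=0)
    state = apply(state, rx(np.pi / 4), target=1)
    probs = np.abs(state.reshape(2, 2, 2)) ** 2
    return probs.sum(axis=2).reshape(-1)  # marginalize out the auxiliary wire

print(deferred_func(0.643, 0.246))  # should match the probabilities printed above
```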

By default, defer_measurements uses the postselection information of mid-circuit measurements to reduce the number of controlled operations and control wires. We can switch this optimization off explicitly and compare the circuits created with and without it. Consider the following circuit:

@qml.qnode(qml.device("default.qubit"))
def node(x):
    qml.RX(x, 0)
    qml.RX(x, 1)
    qml.RX(x, 2)

    mcm0 = qml.measure(0, postselect=0, reset=False)
    mcm1 = qml.measure(1, postselect=None, reset=True)
    mcm2 = qml.measure(2, postselect=1, reset=False)
    qml.cond(mcm0+mcm1+mcm2==1, qml.RX)(0.5, 3)
    return qml.expval(qml.Z(0) @ qml.Z(3))

Without the optimization, we find three gates controlled on the three measured qubits. They correspond to the combinations of controls that satisfy the condition mcm0+mcm1+mcm2==1.

>>> print(qml.draw(qml.defer_measurements(node, reduce_postselected=False))(0.6))
0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────────────────────────────────┤ ╭<Z@Z>
1: ──RX(0.60)─────────│──╭●─╭X───────────────────────────────────────┤ │
2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|─╭○────────╭○────────╭●────────┤ │
3: ───────────────────│──│──│──────────├RX(0.50)─├RX(0.50)─├RX(0.50)─┤ ╰<Z@Z>
4: ───────────────────╰X─│──│──────────├○────────├●────────├○────────┤
5: ──────────────────────╰X─╰●─────────╰●────────╰○────────╰○────────┤

If we do not explicitly deactivate the optimization, we obtain a much simpler circuit:

>>> print(qml.draw(qml.defer_measurements(node))(0.6))
0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────┤ ╭<Z@Z>
1: ──RX(0.60)─────────│──╭●─╭X───────────┤ │
2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|───┤ │
3: ───────────────────│──│──│──╭RX(0.50)─┤ ╰<Z@Z>
4: ───────────────────╰X─│──│──│─────────┤
5: ──────────────────────╰X─╰●─╰○────────┤

There is only one controlled gate with only one control wire.
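The gate counts in both drawings can be reproduced by enumerating measurement outcomes in plain Python. This is an illustration, not the transform's actual code, and satisfying_controls is a hypothetical helper. Without postselection information, every outcome tuple that satisfies mcm0+mcm1+mcm2==1 becomes one controlled RX. Fixing mcm0 = 0 and mcm2 = 1 through postselection leaves a single tuple in which only mcm1 is free, with control value 0, consistent with the single ○ control in the optimized drawing.

```python
from itertools import product

def satisfying_controls(condition, n_mcms, postselect=None):
    """Outcome tuples satisfying `condition`, with postselected MCMs dropped.

    `postselect` maps an MCM index to its fixed outcome. Each returned dict
    maps a remaining (free) MCM index to the control value it must take.
    """
    postselect = postselect or {}
    results = []
    for outcomes in product((0, 1), repeat=n_mcms):
        # skip outcome tuples that contradict the postselection
        if any(outcomes[i] != v for i, v in postselect.items()):
            continue
        if condition(*outcomes):
            results.append({i: outcomes[i] for i in range(n_mcms) if i not in postselect})
    return results

def cond(m0, m1, m2):
    return m0 + m1 + m2 == 1

print(satisfying_controls(cond, 3))
# [{0: 0, 1: 0, 2: 1}, {0: 0, 1: 1, 2: 0}, {0: 1, 1: 0, 2: 0}]
print(satisfying_controls(cond, 3, postselect={0: 0, 2: 1}))
# [{1: 0}]
```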

qml.defer_measurements can be applied to callables when program capture is enabled. To do so, the num_wires argument must be provided; it should be an integer equal to the total number of available wires. For m mid-circuit measurements, the wires in range(num_wires - m, num_wires) are used as the targets of the CNOT gates that replace the mid-circuit measurements.
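The reserved wires follow directly from that range. The helper below is hypothetical (it is not a PennyLane function) and simply spells out the arithmetic. With num_wires=10 and one measurement it reserves wire 9, which is consistent with the CNOT onto wire 9 in the program-capture example further down this page.

```python
def reserved_mcm_wires(num_wires, num_mcms):
    """Hypothetical helper: wires used as CNOT targets for num_mcms MCMs."""
    return list(range(num_wires - num_mcms, num_wires))

print(reserved_mcm_wires(10, 1))  # [9]
print(reserved_mcm_wires(10, 3))  # [7, 8, 9]
print(reserved_mcm_wires(1, 1))   # [0]
```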

Warning

While the transform validates that the wires of the original circuit do not overlap with the mid-circuit measurement target wires, this validation may not catch overlaps if any wires of the original circuit are traced, i.e. dependent on dynamic arguments to the transformed workflow. Consider the following example:

from functools import partial
import jax
import pennylane as qml

qml.capture.enable()

@qml.capture.expand_plxpr_transforms
@partial(qml.defer_measurements, num_wires=1)
def f(n):
    qml.measure(n)

>>> jax.make_jaxpr(f)(0)
{ lambda ; a:i64[]. let _:AbstractOperator() = CNOT[n_wires=2] a 0 in () }

The circuit is transformed without error because the concrete value of the measured wire is unknown at trace time. However, executing it with n = 0 would raise an error, because the CNOT wires would be (0, 0).

Thus, users must be cautious when transforming a circuit. For num_wires total wires and c circuit wires, at most num_wires - c mid-circuit measurements are allowed.
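Assuming integer wire labels and a circuit on wires 0 to c - 1, this rule can be checked ahead of time. The helper below is hypothetical and not part of PennyLane; it only sketches the arithmetic. It accepts 7 measurements for a 10-wire budget and a 3-wire circuit, and it rejects the traced example above with n = 0.

```python
def check_mcm_budget(num_wires, circuit_wires, num_mcms):
    """Hypothetical pre-flight check (not a PennyLane function).

    Returns the reserved MCM target wires, or raises ValueError if any of them
    coincides with a wire already used by the original circuit.
    """
    reserved = set(range(num_wires - num_mcms, num_wires))
    clash = reserved & set(circuit_wires)
    if clash:
        raise ValueError(f"MCM target wires {sorted(clash)} overlap circuit wires")
    return sorted(reserved)

# 10 wires, circuit on wires 0-2: up to 10 - 3 = 7 measurements fit
print(check_mcm_budget(10, range(3), 7))  # [3, 4, 5, 6, 7, 8, 9]

# The traced example with n = 0: num_wires=1 leaves no room for one MCM
try:
    check_mcm_budget(1, [0], 1)
except ValueError as err:
    print(err)  # MCM target wires [0] overlap circuit wires
```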

Using defer_measurements with program capture enabled introduces new features and restrictions:

New features

  • Arbitrary classical processing of mid-circuit measurement values is now possible. With program capture disabled, only limited classical processing is supported, as detailed in the documentation for measure(). With program capture enabled, any unary or binary jax.numpy function that can be applied to scalars can be used with mid-circuit measurements.

  • Using mid-circuit measurements as gate parameters is now possible. This feature currently has the following restrictions:

      ◦ Mid-circuit measurement values cannot be used for multiple parameters of the same gate.

      ◦ Mid-circuit measurement values cannot be used as wires.

    from functools import partial
    import jax
    import jax.numpy as jnp
    import pennylane as qml
    
    qml.capture.enable()
    
    @qml.capture.expand_plxpr_transforms
    @partial(qml.defer_measurements, num_wires=10)
    def f():
        m0 = qml.measure(0)
    
        phi = jnp.sin(jnp.pi * m0)
        qml.RX(phi, 0)
        return qml.expval(qml.PauliZ(0))
    
    >>> jax.make_jaxpr(f)()
    { lambda ; . let
        _:AbstractOperator() = CNOT[n_wires=2] 0 9
        a:f64[] = mul 0.0 3.141592653589793
        b:f64[] = sin a
        c:AbstractOperator() = RX[n_wires=1] b 0
        _:AbstractOperator() = Controlled[
          control_values=(False,)
          work_wires=Wires([])
        ] c 9
        d:f64[] = mul 1.0 3.141592653589793
        e:f64[] = sin d
        f:AbstractOperator() = RX[n_wires=1] e 0
        _:AbstractOperator() = Controlled[
          control_values=(True,)
          work_wires=Wires([])
        ] f 9
        g:AbstractOperator() = PauliZ[n_wires=1] 0
        h:AbstractMeasurement(n_wires=None) = expval_obs g
      in (h,) }
    

The toy example above shows how the transform is applied when these features are used.
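In the jaxpr above, phi = jnp.sin(jnp.pi * m0) is evaluated once per possible outcome (m0 = 0 and m0 = 1), and each resulting RX is controlled on wire 9 with the matching control value. The NumPy sketch below, which is an illustration and not PennyLane code, shows why this is equivalent to applying RX(f(m0)) after the measurement. The product of the two controlled gates is block-diagonal, acting as RX(f(0)) when the control wire is |0⟩ and as RX(f(1)) when it is |1⟩. The control wire is placed first in the two-qubit ordering, and f is the function from the example. Both of its angles happen to be essentially zero, but the block structure holds for any f.

```python
import numpy as np

def rx(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def f(m):
    # classical processing of the measurement value, as in the example
    return np.sin(np.pi * m)

P0, P1, I = np.diag([1.0, 0.0]), np.diag([0.0, 1.0]), np.eye(2)

# control_values=(False,): RX(f(0)) acts when the control wire is |0>
c_false = np.kron(P0, rx(f(0))) + np.kron(P1, I)
# control_values=(True,): RX(f(1)) acts when the control wire is |1>
c_true = np.kron(P0, I) + np.kron(P1, rx(f(1)))

combined = c_true @ c_false
expected = np.block([[rx(f(0)), np.zeros((2, 2))], [np.zeros((2, 2)), rx(f(1))]])
print(np.allclose(combined, expected))  # True
```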

What doesn’t work

  • Mid-circuit measurement values cannot be used in the condition for a while_loop().

  • measure() cannot be used inside the body of loop primitives (while_loop(), for_loop()).

  • If a branch of cond() uses mid-circuit measurements as its predicate, then all other branches must also use mid-circuit measurement values as predicates.

  • For an n-parameter gate, mid-circuit measurement values can only be used for one of the n parameters.

  • measure() can only be used in the bodies of branches of cond() if none of the branches use mid-circuit measurements as predicates.

  • measure() cannot be used inside the body of functions being transformed with adjoint() or ctrl().