qml.defer_measurements
- defer_measurements(tape, reduce_postselected=True, allow_postselect=True, num_wires=None)
Quantum function transform that replaces operations conditioned on measurement outcomes with controlled operations.
This transform uses the deferred measurement principle and applies to qubit-based quantum functions.
Support for mid-circuit measurements is device-dependent. If a device doesn’t support mid-circuit measurements natively, then the QNode will apply this transform.
Note
The transform uses the ctrl() transform to implement operations controlled on mid-circuit measurement outcomes. The set of operations that can be controlled in this way depends on the set of operations supported by the chosen device.
Note
For defer_measurements to transform the quantum tape correctly, devices that inherit from QubitDevice must be initialized with one additional wire for each mid-circuit measurement after which the measured wire is reused or reset.
Note
This transform does not change the list of terminal measurements returned by the quantum function.
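The wire-count rule for QubitDevice stated in the earlier note can be illustrated with a small counting sketch. The op-list format and the `required_wires` helper are hypothetical and exist only for illustration. PennyLane does not expose such a function. In this sketch, a measured wire counts as reused when any later operation, including another measurement, acts on it.

```python
def required_wires(ops):
    """Count wires needed after deferring measurements (illustrative sketch).

    `ops` is a hypothetical, simplified circuit description, not a PennyLane
    data structure: ("gate", [wires...]) or ("measure", wire, reset).
    """
    circuit_wires = set()
    extra = 0
    measured = set()  # wires measured without reset and not yet reused
    for op in ops:
        if op[0] == "measure":
            _, wire, reset = op
            touched = [wire]
        else:
            touched = list(op[1])
        # Acting on a wire that was measured earlier (without reset) is a reuse.
        reused = measured.intersection(touched)
        extra += len(reused)
        measured -= reused
        circuit_wires.update(touched)
        if op[0] == "measure":
            if reset:
                extra += 1  # a reset measurement always needs one extra wire
            else:
                measured.add(wire)
    return len(circuit_wires) + extra


# qfunc from the Example section: wire 1 is measured but never reused.
print(required_wires([("gate", [0]), ("gate", [1]), ("measure", 1, False), ("gate", [0])]))
# func from the Example section: wire 1 is measured with reset=True.
print(required_wires([("gate", [0]), ("gate", [0, 1]), ("measure", 1, True),
                      ("gate", [0]), ("gate", [1])]))
```

For the two circuits in the Example section, the sketch gives 2 and 3 wires. These are the same wire counts those examples pass to default.qubit.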
Note
When the transform is applied to a quantum function that contains the Snapshot instruction, the recorded state information corresponds to simulating the transformed circuit. No post-measurement states are considered.
Warning
state() is not supported with the defer_measurements transform. Additionally, probs(), sample() and counts() can only be used with defer_measurements if wires or an observable are explicitly specified.
Warning
defer_measurements does not support using custom wire labels if any measured wires are reused or reset.
- Parameters
tape (QNode or QuantumTape or Callable) – a quantum circuit.
reduce_postselected (bool) – Whether to use postselection information to reduce the number of operations and control wires in the output tape. Active by default. This is currently ignored if program capture is enabled.
allow_postselect (bool) – Whether postselection is allowed. In order to perform postselection with defer_measurements, the device must support the Projector operation. Defaults to True. This is currently ignored if program capture is enabled.
num_wires (int) – Optional argument to specify the total number of available wires. This is only used, and then required, if program capture is enabled.
- Returns
- The transformed circuit as described in qml.transform.
- Return type
qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]
- Raises
ValueError – If custom wire labels are used with qubit reuse or reset
ValueError – If any measurements with no wires or observable are present
ValueError – If continuous variable operations or measurements are present
ValueError – If postselection is used with any device other than default.qubit
Example
Suppose we have a quantum function with mid-circuit measurements and conditional operations:
import pennylane as qml
from pennylane import numpy as np

def qfunc(par):
    qml.RY(0.123, wires=0)
    qml.Hadamard(wires=1)
    m_0 = qml.measure(1)
    qml.cond(m_0, qml.RY)(par, wires=0)
    return qml.expval(qml.Z(0))
The defer_measurements transform allows executing such quantum functions without having to perform mid-circuit measurements:
>>> dev = qml.device('default.qubit', wires=2)
>>> transformed_qfunc = qml.defer_measurements(qfunc)
>>> qnode = qml.QNode(transformed_qfunc, dev)
>>> par = np.array(np.pi/2, requires_grad=True)
>>> qnode(par)
tensor(0.43487747, requires_grad=True)
We can also differentiate parameters passed to conditional operations:
>>> qml.grad(qnode)(par)
tensor(-0.49622252, requires_grad=True)
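The deferred measurement principle behind these numbers can be checked by hand. The following pure-Python sketch is a toy state-vector simulator written for illustration, not PennyLane code. It evaluates qfunc twice. The first evaluation explicitly branches on the outcome of the measurement on wire 1. The second replaces the conditional RY with an RY controlled on wire 1. Both should give the expectation value above. A central finite difference, used here in place of PennyLane's differentiation methods, should reproduce the gradient.

```python
import math

N = 2  # number of qubits; wire 0 is the most significant bit of a basis index
HADAMARD = [[1 / math.sqrt(2), 1 / math.sqrt(2)], [1 / math.sqrt(2), -1 / math.sqrt(2)]]


def ry(theta):
    """Matrix of the RY rotation."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def bit(index, wire):
    """Value of `wire` in a basis-state index."""
    return (index >> (N - 1 - wire)) & 1


def apply(state, gate, target, controls=()):
    """Apply a 2x2 gate to `target` on basis states matching every (wire, value) in `controls`."""
    new = list(state)
    for i in range(len(state)):
        if bit(i, target) or any(bit(i, w) != v for w, v in controls):
            continue  # each amplitude pair is updated once, from its target=0 member
        j = i | (1 << (N - 1 - target))
        new[i] = gate[0][0] * state[i] + gate[0][1] * state[j]
        new[j] = gate[1][0] * state[i] + gate[1][1] * state[j]
    return new


def expval_z(state, wire):
    return sum(abs(a) ** 2 * (1 - 2 * bit(i, wire)) for i, a in enumerate(state))


def prepare():
    state = [1.0, 0.0, 0.0, 0.0]
    state = apply(state, ry(0.123), 0)
    return apply(state, HADAMARD, 1)


def with_mid_circuit_measurement(par):
    """Measure wire 1, apply RY(par) to wire 0 only if the outcome is 1, average over outcomes."""
    state = prepare()
    total = 0.0
    for outcome in (0, 1):
        branch = [a if bit(i, 1) == outcome else 0.0 for i, a in enumerate(state)]
        prob = sum(abs(a) ** 2 for a in branch)
        branch = [a / math.sqrt(prob) for a in branch]  # post-measurement state
        if outcome == 1:
            branch = apply(branch, ry(par), 0)
        total += prob * expval_z(branch, 0)
    return total


def deferred(par):
    """Deferred version: RY(par) on wire 0 controlled on wire 1 being |1>, no measurement."""
    state = apply(prepare(), ry(par), 0, controls=[(1, 1)])
    return expval_z(state, 0)


par = math.pi / 2
print(with_mid_circuit_measurement(par), deferred(par))  # both ≈ 0.43487747
h = 1e-6
print((deferred(par + h) - deferred(par - h)) / (2 * h))  # ≈ -0.49622252
```

Analytically, both versions equal (cos(0.123) + cos(0.123 + par)) / 2. The gradient at par = π/2 is therefore -cos(0.123) / 2.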
Reusing and resetting measured wires will work as expected with the defer_measurements transform:
dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def func(x, y):
    qml.RY(x, wires=0)
    qml.CNOT(wires=[0, 1])
    m_0 = qml.measure(1, reset=True)
    qml.cond(m_0, qml.RY)(y, wires=0)
    qml.RX(np.pi/4, wires=1)
    return qml.probs(wires=[0, 1])
Executing this QNode:
>>> pars = np.array([0.643, 0.246], requires_grad=True)
>>> func(*pars)
tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True)
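The reset in func can be deferred using one extra wire. The sketch below is again a toy simulator written for illustration only. Wire 2 plays the role of the extra wire. A CNOT copies the outcome of wire 1 onto it, and a second CNOT controlled on wire 2 returns wire 1 to |0⟩. The conditional RY then becomes an RY controlled on wire 2. This is the same CNOT pair that appears on wires 1 and 5 in the circuit drawings under Usage Details. Marginalizing over wire 2 should recover the probabilities above.

```python
import math

N = 3  # wires 0 and 1 from func, plus wire 2 as the extra wire for the reset measurement
X = [[0, 1], [1, 0]]


def bit(index, wire):
    """Value of `wire` in a basis-state index (wire 0 is the most significant bit)."""
    return (index >> (N - 1 - wire)) & 1


def apply(state, gate, target, controls=()):
    """Apply a 2x2 gate to `target` on basis states matching every (wire, value) in `controls`."""
    new = list(state)
    for i in range(len(state)):
        if bit(i, target) or any(bit(i, w) != v for w, v in controls):
            continue  # each amplitude pair is updated once, from its target=0 member
        j = i | (1 << (N - 1 - target))
        new[i] = gate[0][0] * state[i] + gate[0][1] * state[j]
        new[j] = gate[1][0] * state[i] + gate[1][1] * state[j]
    return new


def ry(theta):
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def rx(theta):
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


def deferred_func(x, y):
    state = [1.0] + [0.0] * (2 ** N - 1)
    state = apply(state, ry(x), 0)
    state = apply(state, X, 1, controls=[(0, 1)])  # CNOT(0, 1)
    # measure(1, reset=True), deferred: copy wire 1 onto the fresh wire 2 ...
    state = apply(state, X, 2, controls=[(1, 1)])  # CNOT(1, 2)
    # ... then use that copy to flip wire 1 back to |0>
    state = apply(state, X, 1, controls=[(2, 1)])  # CNOT(2, 1)
    # cond(m_0, qml.RY) becomes an RY controlled on wire 2
    state = apply(state, ry(y), 0, controls=[(2, 1)])
    state = apply(state, rx(math.pi / 4), 1)
    # probs(wires=[0, 1]): sum |amplitude|^2 over wire 2
    probs = [0.0] * 4
    for i, amp in enumerate(state):
        probs[2 * bit(i, 0) + bit(i, 1)] += abs(amp) ** 2
    return probs


print(deferred_func(0.643, 0.246))
# ≈ [0.76960924, 0.13204407, 0.08394415, 0.01440254]
```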
Usage Details
By default, defer_measurements makes use of postselection information of mid-circuit measurements in the circuit in order to reduce the number of controlled operations and control wires. We can explicitly switch this feature off and compare the created circuits with and without this optimization. Consider the following circuit:
@qml.qnode(qml.device("default.qubit"))
def node(x):
    qml.RX(x, 0)
    qml.RX(x, 1)
    qml.RX(x, 2)
    mcm0 = qml.measure(0, postselect=0, reset=False)
    mcm1 = qml.measure(1, postselect=None, reset=True)
    mcm2 = qml.measure(2, postselect=1, reset=False)
    qml.cond(mcm0+mcm1+mcm2==1, qml.RX)(0.5, 3)
    return qml.expval(qml.Z(0) @ qml.Z(3))
Without the optimization, we find three gates controlled on the three measured qubits. They correspond to the combinations of controls that satisfy the condition mcm0+mcm1+mcm2==1.
>>> print(qml.draw(qml.defer_measurements(node, reduce_postselected=False))(0.6))
0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────────────────────────────────┤ ╭<Z@Z>
1: ──RX(0.60)─────────│──╭●─╭X───────────────────────────────────────┤ │
2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|─╭○────────╭○────────╭●────────┤ │
3: ───────────────────│──│──│──────────├RX(0.50)─├RX(0.50)─├RX(0.50)─┤ ╰<Z@Z>
4: ───────────────────╰X─│──│──────────├○────────├●────────├○────────┤
5: ──────────────────────╰X─╰●─────────╰●────────╰○────────╰○────────┤
If we do not explicitly deactivate the optimization, we obtain a much simpler circuit:
>>> print(qml.draw(qml.defer_measurements(node))(0.6))
0: ──RX(0.60)──|0⟩⟨0|─╭●─────────────────┤ ╭<Z@Z>
1: ──RX(0.60)─────────│──╭●─╭X───────────┤ │
2: ──RX(0.60)─────────│──│──│───|1⟩⟨1|───┤ │
3: ───────────────────│──│──│──╭RX(0.50)─┤ ╰<Z@Z>
4: ───────────────────╰X─│──│──│─────────┤
5: ──────────────────────╰X─╰●─╰○────────┤
There is only one controlled gate with only one control wire.
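The difference in gate counts follows from counting outcome assignments. The sketch below uses a hypothetical helper and is not PennyLane's implementation. It enumerates all outcomes of the three measurements that satisfy mcm0+mcm1+mcm2==1. When postselected values are supplied, it keeps only the assignments consistent with them and drops the postselected measurements from the controls.

```python
from itertools import product


def control_patterns(condition, num_mcms, postselect=None):
    """Return the MCM outcome assignments that satisfy `condition`.

    Each returned dict maps an MCM index to the control value it needs.
    `postselect` (MCM index -> postselected value) fixes those outcomes, so
    they are used for filtering and removed from the control pattern.
    """
    postselect = postselect or {}
    free = [i for i in range(num_mcms) if i not in postselect]
    patterns = []
    for values in product((0, 1), repeat=num_mcms):
        if any(values[i] != v for i, v in postselect.items()):
            continue  # inconsistent with postselection, can never occur
        if condition(*values):
            patterns.append({i: values[i] for i in free})
    return patterns


def sum_is_one(m0, m1, m2):
    return m0 + m1 + m2 == 1


print(control_patterns(sum_is_one, 3))
# [{0: 0, 1: 0, 2: 1}, {0: 0, 1: 1, 2: 0}, {0: 1, 1: 0, 2: 0}]
print(control_patterns(sum_is_one, 3, postselect={0: 0, 2: 1}))
# [{1: 0}]
```

Without postselection information, there are three assignments with three controls each, matching the three controlled RX gates. With mcm0 and mcm2 fixed by postselection, only mcm1 = 0 remains. This corresponds to the single ○ control on wire 5.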
Deferred measurements with program capture
qml.defer_measurements can be applied to callables when program capture is enabled. To do so, the num_wires argument must be provided; it should be an integer equal to the total number of available wires. For m mid-circuit measurements, range(num_wires - m, num_wires) is the range of wires used to map mid-circuit measurements to CNOT gates.
Warning
While the transform includes validation to avoid overlap between wires of the original circuit and mid-circuit measurement target wires, the validation may not catch overlaps if any wires of the original circuit are traced, i.e. dependent on dynamic arguments to the transformed workflow. Consider the following example:
from functools import partial
import jax
import pennylane as qml

qml.capture.enable()

@qml.capture.expand_plxpr_transforms
@partial(qml.defer_measurements, num_wires=1)
def f(n):
    qml.measure(n)
>>> jax.make_jaxpr(f)(0)
{ lambda ; a:i64[]. let
    _:AbstractOperator() = CNOT[n_wires=2] a 0
  in () }
The circuit is transformed without issue because the concrete value of the measured wire is unknown at transform time. However, execution with n = 0 would raise an error, as the CNOT wires would be (0, 0).
Thus, users must be cautious when transforming a circuit. For ``num_wires`` total wires and ``c`` circuit wires, the number of mid-circuit measurements allowed is ``num_wires - c``.
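The overlap check and the num_wires - c budget can be sketched as follows. `mcm_wires` and `find_overlaps` are hypothetical helpers written for illustration and are not the transform's actual validation code. Traced wires are modeled as None because their values are unknown when the transform runs.

```python
def mcm_wires(num_wires, num_mcms):
    """Wires used for MCM outcomes, following range(num_wires - m, num_wires)."""
    return list(range(num_wires - num_mcms, num_wires))


def find_overlaps(circuit_wires, num_wires, num_mcms):
    """Return the concrete circuit wires that collide with MCM target wires.

    A traced (dynamic) wire is given as None; it cannot be checked, so it
    never shows up as an overlap even if its runtime value would collide.
    """
    reserved = set(mcm_wires(num_wires, num_mcms))
    return sorted(w for w in circuit_wires if w is not None and w in reserved)


# The example above at transform time: the measured wire n is traced.
print(find_overlaps([None], num_wires=1, num_mcms=1))  # [] -> nothing is flagged
# At execution with n = 0, the wire is concrete and collides with MCM wire 0.
print(find_overlaps([0], num_wires=1, num_mcms=1))  # [0]

# Budget: with num_wires = 3 and c = 1 circuit wire, at most 3 - 1 = 2 MCMs fit.
print(find_overlaps([0], num_wires=3, num_mcms=2))  # []
print(find_overlaps([0], num_wires=3, num_mcms=3))  # [0]
```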
Using defer_measurements with program capture enabled introduces new features and restrictions:
New features
- Arbitrary classical processing of mid-circuit measurement values is now possible. With program capture disabled, only limited classical processing is supported, as detailed in the documentation for measure(). With program capture enabled, any unary or binary jax.numpy functions that can be applied to scalars can be used with mid-circuit measurements.
- Using mid-circuit measurement values as gate parameters is now possible. This feature currently has the following restrictions:
  * Mid-circuit measurement values cannot be used for multiple parameters of the same gate.
  * Mid-circuit measurement values cannot be used as wires.
from functools import partial
import jax
import jax.numpy as jnp
import pennylane as qml

qml.capture.enable()

@qml.capture.expand_plxpr_transforms
@partial(qml.defer_measurements, num_wires=10)
def f():
    m0 = qml.measure(0)
    phi = jnp.sin(jnp.pi * m0)
    qml.RX(phi, 0)
    return qml.expval(qml.PauliZ(0))
>>> jax.make_jaxpr(f)()
{ lambda ; . let
    _:AbstractOperator() = CNOT[n_wires=2] 0 9
    a:f64[] = mul 0.0 3.141592653589793
    b:f64[] = sin a
    c:AbstractOperator() = RX[n_wires=1] b 0
    _:AbstractOperator() = Controlled[
      control_values=(False,)
      work_wires=Wires([])
    ] c 9
    d:f64[] = mul 1.0 3.141592653589793
    e:f64[] = sin d
    f:AbstractOperator() = RX[n_wires=1] e 0
    _:AbstractOperator() = Controlled[
      control_values=(True,)
      work_wires=Wires([])
    ] f 9
    g:AbstractOperator() = PauliZ[n_wires=1] 0
    h:AbstractMeasurement(n_wires=None) = expval_obs g
  in (h,) }
The above dummy example showcases how the transform is applied when the aforementioned features are used.
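The jaxpr above follows a simple pattern. For each possible outcome of m0, the angle jnp.sin(jnp.pi * m0) is evaluated classically, and RX is applied controlled on wire 9 with that outcome as the control value. Wire 9 is range(num_wires - m, num_wires) for num_wires=10 and m=1. A hypothetical plain-Python sketch of this bookkeeping for a single measurement follows; it is not PennyLane code.

```python
import math


def expand_mcm_parameter(param_fn, target, num_wires):
    """Expand RX(param_fn(m), target) for a single MCM value m (illustrative sketch).

    With one measurement, its outcome lives on wire num_wires - 1, i.e.
    range(num_wires - 1, num_wires). For each possible outcome, the gate
    parameter is computed classically and the RX is controlled on that outcome.
    """
    control_wire = num_wires - 1
    return [
        ("RX", param_fn(outcome), target, control_wire, bool(outcome))
        for outcome in (0, 1)
    ]


ops = expand_mcm_parameter(lambda m: math.sin(math.pi * m), target=0, num_wires=10)
for name, angle, target, control, value in ops:
    print(f"{name}({angle:.3g}) on wire {target}, controlled on wire {control} == {value}")
```

Because every outcome needs its own controlled copy of the gate, k measurement values feeding one parameter would lead to 2**k copies in this scheme. The sketch handles only the single-measurement case shown in the example.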
What doesn’t work
- Mid-circuit measurement values cannot be used in the condition for a while_loop().
- measure() cannot be used inside the body of loop primitives (while_loop(), for_loop()).
- If a branch of cond() uses mid-circuit measurements as its predicate, then all other branches must also use mid-circuit measurement values as predicates.
- For an n-parameter gate, mid-circuit measurement values can only be used for 1 of the n parameters.
- measure() can only be used in the bodies of branches of cond() if none of the branches use mid-circuit measurements as predicates.
- measure() cannot be used inside the body of functions being transformed with adjoint() or ctrl().