# Source code for pennylane.measurements.mid_measure
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
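This module contains the :func:`measure` function together with the ``MidMeasureMP`` and
``MeasurementValue`` classes used to represent mid-circuit measurements and their outcomes.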
"""
import uuid
from typing import Generic, TypeVar, Optional
import numpy as np
import pennylane as qml
from pennylane.wires import Wires
from .measurements import MeasurementProcess, MidMeasure
def measure(wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None):
    r"""Perform a mid-circuit measurement in the computational basis on the
    supplied qubit.

    Measurement outcomes can be obtained and used to conditionally apply
    operations.

    If a device doesn't support mid-circuit measurements natively, then the
    QNode will apply the :func:`defer_measurements` transform.

    **Example:**

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func(x, y):
            qml.RY(x, wires=0)
            qml.CNOT(wires=[0, 1])
            m_0 = qml.measure(1)

            qml.cond(m_0, qml.RY)(y, wires=0)
            return qml.probs(wires=[0])

    Executing this QNode:

    >>> pars = np.array([0.643, 0.246], requires_grad=True)
    >>> func(*pars)
    tensor([0.90165331, 0.09834669], requires_grad=True)

    Wires can be reused after measurement. Moreover, measured wires can be reset
    to the :math:`|0 \rangle` state by setting ``reset=True``.

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func():
            qml.X(1)
            m_0 = qml.measure(1, reset=True)
            return qml.probs(wires=[1])

    Executing this QNode:

    >>> func()
    tensor([1., 0.], requires_grad=True)

    Mid-circuit measurements can be manipulated using the following arithmetic operators:
    ``+``, ``-``, ``*``, ``/``, ``~`` (not), ``&`` (and), ``|`` (or), ``==``, ``!=``, ``<=``,
    ``>=``, ``<``, ``>`` with other mid-circuit measurements or scalars.

    .. note::

        Python ``not``, ``and``, ``or`` do not work since these do not have dunder methods.
        Instead use ``~``, ``&``, ``|``.

    Args:
        wires (Wires): The wire of the qubit the measurement process applies to.
        reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle`
            state after measurement.
        postselect (Optional[int]): Which basis state to postselect after a mid-circuit
            measurement. None by default. If postselection is requested, only the post-measurement
            state that is used for postselection will be considered in the remaining circuit.

    Returns:
        MeasurementValue: the (not yet known) outcome of the newly created ``MidMeasureMP``,
        which can be combined with the operators listed above or passed to :func:`~.cond`.

    Raises:
        QuantumFunctionError: if zero or more than one wire was specified

    .. details::
        :title: Postselection

        Postselection discards outcomes that do not meet the criteria provided by the ``postselect``
        argument. For example, specifying ``postselect=1`` on wire 0 would be equivalent to projecting
        the state vector onto the :math:`|1\rangle` state on wire 0:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.sample(wires=1)

        By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the
        probability of measuring ``1`` on wire 1 after postselection should also be 1. Executing this
        QNode with 10 shots:

        >>> func(np.pi / 2, shots=10)
        array([1, 1, 1, 1, 1, 1, 1])

        Note that only 7 samples are returned. This is because samples that do not meet the
        postselection criteria are thrown away.

        If postselection is requested on a state with zero probability of being measured, the result
        may contain ``NaN`` or ``Inf`` values:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.probs(wires=1)

        >>> func(0.0)
        tensor([nan, nan], requires_grad=True)

        In the case of ``qml.sample``, an empty array will be returned:

        .. code-block:: python3

            dev = qml.device("default.qubit")

            @qml.qnode(dev)
            def func(x):
                qml.RX(x, wires=0)
                m0 = qml.measure(0, postselect=1)
                qml.cond(m0, qml.X)(wires=1)
                return qml.sample(wires=[0, 1])

        >>> func(0.0, shots=[10, 10])
        (array([], shape=(0, 2), dtype=int64), array([], shape=(0, 2), dtype=int64))

        .. note::

            Currently, postselection support is only available on ``default.qubit``. Using
            postselection on other devices will raise an error.

        .. warning::

            All measurements are supported when using postselection. However, postselection on a
            zero probability state can cause some measurements to break:

            * With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``,
              as these measurements will raise errors if there are no valid samples after
              postselection. This will occur with postselection states that have zero or close to
              zero probability.

            * With analytic execution, ``qml.mutual_info`` will raise errors when using any
              interfaces except ``jax``, and ``qml.vn_entropy`` will raise an error with the
              ``tensorflow`` interface when the postselection state has zero probability.

            * When using JIT, QNodes may have unexpected behaviour when postselection on a zero
              probability state is performed. Due to floating point precision, the zero probability
              may not be detected, thus letting execution continue as normal without ``NaN`` or
              ``Inf`` values or empty samples, leading to unexpected or incorrect results.
    """
    wire = Wires(wires)
    # Exactly one wire is required: an empty wire list would otherwise produce a
    # measurement process that measures nothing.
    if len(wire) != 1:
        raise qml.QuantumFunctionError(
            "Only a single qubit can be measured in the middle of the circuit"
        )

    # A short random id distinguishes this measurement from any other measurement on the
    # same wire; MidMeasureMP hashes on it and MeasurementValue orders merged measurements by it.
    measurement_id = str(uuid.uuid4())[:8]
    mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id)
    # The identity processing function makes the value equal to the raw 0/1 outcome.
    return MeasurementValue([mp], processing_fn=lambda v: v)
# Type variable parameterizing the generic ``MeasurementValue`` class (``Generic[T]``).
T = TypeVar("T")
class MidMeasureMP(MeasurementProcess):
    """Mid-circuit measurement.

    This class additionally stores information about unknown measurement outcomes in the qubit model.
    Measurements on a single qubit in the computational basis are assumed.

    Please refer to :func:`measure` for detailed documentation.

    Args:
        wires (.Wires): The wire the measurement process applies to.
        reset (bool): Whether to reset the wire after measurement.
        postselect (Optional[int]): Which basis state to postselect after a mid-circuit
            measurement. None by default. If postselection is requested, only the post-measurement
            state that is used for postselection will be considered in the remaining circuit.
        id (str): Custom label given to a measurement instance.
    """

    def _flatten(self):
        # Metadata entries are (keyword, value) pairs named after the ``__init__`` parameters.
        # ``postselect`` must be included so that it survives a flatten/unflatten round trip
        # (assumes the base-class ``_unflatten`` feeds these pairs back into ``__init__`` —
        # confirm against MeasurementProcess).
        metadata = (
            ("wires", self.raw_wires),
            ("reset", self.reset),
            ("postselect", self.postselect),
            ("id", self.id),
        )
        return (None, None), metadata

    def __init__(
        self,
        wires: Optional[Wires] = None,
        reset: Optional[bool] = False,
        postselect: Optional[int] = None,
        id: Optional[str] = None,
    ):
        # NOTE(review): ``batch_size`` is assigned before the base initializer runs; presumably
        # the base class relies on it being present — confirm before reordering.
        self.batch_size = None
        super().__init__(wires=Wires(wires), id=id)
        self.reset = reset
        self.postselect = postselect

    def label(self, decimals=None, base_label=None, cache=None):  # pylint: disable=unused-argument
        r"""How the mid-circuit measurement is represented in diagrams and drawings.

        Args:
            decimals=None (Int): If ``None``, no parameters are included. Else,
                how to round the parameters.
            base_label=None (Iterable[str]): overwrite the non-parameter component of the label.
                Must be same length as ``obs`` attribute.
            cache=None (dict): dictionary that carries information between label calls
                in the same drawing

        Returns:
            str: label to use in drawings
        """
        _label = "┤↗"

        # A subscript marks the postselected basis state (any non-1 value is drawn as 0).
        if self.postselect is not None:
            _label += "₁" if self.postselect == 1 else "₀"

        # Reset measurements are drawn with an explicit |0⟩ marker after the meter.
        _label += "├" if not self.reset else "│  │0⟩"

        return _label

    @property
    def return_type(self):
        """The measurement return type; always ``MidMeasure``."""
        return MidMeasure

    @property
    def samples_computational_basis(self):
        """Always ``False`` for mid-circuit measurements."""
        return False

    @property
    def _queue_category(self):
        # Mid-circuit measurements are queued alongside operations rather than as
        # terminal measurements.
        return "_ops"

    @property
    def hash(self):
        """int: Returns an integer hash uniquely representing the measurement process"""
        fingerprint = (
            self.__class__.__name__,
            tuple(self.wires.tolist()),
            self.id,
        )

        return hash(fingerprint)

    @property
    def data(self):
        """The data of the measurement. Needed to match the Operator API."""
        return []

    @property
    def name(self):
        """The name of the measurement. Needed to match the Operator API."""
        return "MidMeasureMP"
class MeasurementValue(Generic[T]):
    """A class representing unknown measurement outcomes in the qubit model.

    Measurements on a single qubit in the computational basis are assumed.

    Args:
        measurements (list[.MidMeasureMP]): The measurement(s) that this object depends on.
        processing_fn (callable): A lazy transformation applied to the measurement values.
            It is called with one 0/1 outcome per entry of ``measurements``, in the same order.
    """

    name = "MeasurementValue"

    def __init__(self, measurements, processing_fn):
        self.measurements = measurements
        self.processing_fn = processing_fn

    def _branch(self, index):
        """Return the outcome tuple (one 0/1 entry per measurement) encoded by ``index``.

        The first measurement corresponds to the most significant bit of ``index``.
        """
        return tuple(int(b) for b in np.binary_repr(index, width=len(self.measurements)))

    def _items(self):
        """A generator representing all the possible outcomes of the MeasurementValue.

        Yields:
            tuple: ``(branch, value)`` pairs, where ``branch`` is the outcome tuple and
            ``value`` is ``processing_fn`` applied to it.
        """
        for i in range(2 ** len(self.measurements)):
            branch = self._branch(i)
            yield branch, self.processing_fn(*branch)

    @property
    def wires(self):
        """Wires: all wires of the mid-circuit measurements this value depends on."""
        return Wires.all_wires([m.wires for m in self.measurements])

    @property
    def branches(self):
        """A dictionary representing all possible outcomes of the MeasurementValue.

        Keys are outcome tuples (one 0/1 entry per measurement) and values are the
        corresponding results of ``processing_fn``.
        """
        return dict(self._items())

    def map_wires(self, wire_map):
        """Returns a copy of the current ``MeasurementValue`` with the wires of each measurement changed
        according to the given wire map.

        Args:
            wire_map (dict): dictionary containing the old wires as keys and the new wires as values

        Returns:
            MeasurementValue: new ``MeasurementValue`` instance with measurement wires mapped
        """
        mapped_measurements = [m.map_wires(wire_map) for m in self.measurements]
        return MeasurementValue(mapped_measurements, self.processing_fn)

    def _transform_bin_op(self, base_bin, other):
        """Helper function for defining dunder binary operations."""
        if isinstance(other, MeasurementValue):
            # Combine both values over the union of their measurements, then apply the
            # operator to the pair of processed results.
            # pylint: disable=protected-access
            return self._merge(other)._apply(lambda t: base_bin(t[0], t[1]))
        # if `other` is not a MeasurementValue then apply it to each branch
        return self._apply(lambda v: base_bin(v, other))

    def __invert__(self):
        """Return a copy of the measurement value with an inverted control
        value."""
        return self._apply(lambda v: not v)

    def __eq__(self, other):
        return self._transform_bin_op(lambda a, b: a == b, other)

    def __ne__(self, other):
        return self._transform_bin_op(lambda a, b: a != b, other)

    def __add__(self, other):
        return self._transform_bin_op(lambda a, b: a + b, other)

    def __radd__(self, other):
        return self._apply(lambda v: other + v)

    def __sub__(self, other):
        return self._transform_bin_op(lambda a, b: a - b, other)

    def __rsub__(self, other):
        return self._apply(lambda v: other - v)

    def __mul__(self, other):
        return self._transform_bin_op(lambda a, b: a * b, other)

    def __rmul__(self, other):
        return self._apply(lambda v: other * v)

    def __truediv__(self, other):
        return self._transform_bin_op(lambda a, b: a / b, other)

    def __rtruediv__(self, other):
        return self._apply(lambda v: other / v)

    def __lt__(self, other):
        return self._transform_bin_op(lambda a, b: a < b, other)

    def __le__(self, other):
        return self._transform_bin_op(lambda a, b: a <= b, other)

    def __gt__(self, other):
        return self._transform_bin_op(lambda a, b: a > b, other)

    def __ge__(self, other):
        return self._transform_bin_op(lambda a, b: a >= b, other)

    def __and__(self, other):
        return self._transform_bin_op(lambda a, b: a and b, other)

    def __or__(self, other):
        return self._transform_bin_op(lambda a, b: a or b, other)

    def _apply(self, fn):
        """Apply a post computation to this measurement"""
        return MeasurementValue(self.measurements, lambda *x: fn(self.processing_fn(*x)))

    def concretize(self, measurements: dict):
        """Returns a concrete value given the outcomes of the measurements.

        Args:
            measurements (dict): maps each ``MidMeasureMP`` in ``self.measurements`` to its
                concrete outcome.

        Returns:
            The result of ``processing_fn`` evaluated on those outcomes.
        """
        values = tuple(measurements[meas] for meas in self.measurements)
        return self.processing_fn(*values)

    def _merge(self, other: "MeasurementValue"):
        """Merge two measurement values.

        The result depends on the union of both measurement lists (no duplicates, sorted by
        id); its processing function returns the pair ``(self_value, other_value)``.
        """
        # create a new merged list with no duplicates and in lexical ordering
        merged_measurements = list(set(self.measurements).union(set(other.measurements)))
        merged_measurements.sort(key=lambda m: m.id)

        # Positions of each operand's measurements inside the merged list. They are computed
        # once here rather than on every call to ``merged_fn``.
        self_indices = [merged_measurements.index(m) for m in self.measurements]
        other_indices = [merged_measurements.index(m) for m in other.measurements]

        def merged_fn(*x):
            out_1 = self.processing_fn(*(x[i] for i in self_indices))
            out_2 = other.processing_fn(*(x[i] for i in other_indices))
            return out_1, out_2

        return MeasurementValue(merged_measurements, merged_fn)

    def __getitem__(self, i):
        """Return the processed value of the ``i``-th branch.

        Branches are enumerated in the same order as :attr:`branches`; negative indices count
        from the end, following the Python sequence convention.

        Raises:
            IndexError: if ``i`` does not address one of the ``2**len(measurements)`` branches.
        """
        num_branches = 2 ** len(self.measurements)
        index = i + num_branches if i < 0 else i
        # Without this check, ``binary_repr`` silently widens for large indices and a merged
        # processing function would evaluate an unrelated branch.
        if not 0 <= index < num_branches:
            raise IndexError(
                f"MeasurementValue index {i} is out of range for {num_branches} branches"
            )
        branch = self._branch(index)
        return self.processing_fn(*branch)

    def __str__(self):
        lines = []
        for i in range(2 ** len(self.measurements)):
            branch = self._branch(i)
            id_branch_mapping = [f"{m.id}={b}" for m, b in zip(self.measurements, branch)]
            lines.append(
                "if " + ",".join(id_branch_mapping) + " => " + str(self.processing_fn(*branch))
            )
        return "\n".join(lines)

    def __repr__(self):
        return f"MeasurementValue(wires={self.wires.tolist()})"
# _modules/pennylane/measurements/mid_measure
# Download Python script
# Download Notebook
# View on GitHub