# Source code for pennylane.measurements.mid_measure

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.measure measurement.
"""
import uuid
from typing import Generic, TypeVar, Optional
import warnings

import pennylane as qml
import pennylane.numpy as np
from pennylane.wires import Wires

from .measurements import MeasurementProcess, MidMeasure


def measure(wires: Wires, reset: Optional[bool] = False):
    r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit.

    Measurement outcomes can be obtained and used to conditionally apply operations.

    If a device doesn't support mid-circuit measurements natively, then the
    QNode will apply the :func:`defer_measurements` transform.

    **Example:**

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func(x, y):
            qml.RY(x, wires=0)
            qml.CNOT(wires=[0, 1])
            m_0 = qml.measure(1)

            qml.cond(m_0, qml.RY)(y, wires=0)
            return qml.probs(wires=[0])

    Executing this QNode:

    >>> pars = np.array([0.643, 0.246], requires_grad=True)
    >>> func(*pars)
    tensor([0.90165331, 0.09834669], requires_grad=True)

    Wires can be reused after measurement. Moreover, measured wires can be reset
    to the :math:`|0 \rangle` state by setting ``reset=True``.

    .. code-block:: python3

        dev = qml.device("default.qubit", wires=3)

        @qml.qnode(dev)
        def func():
            qml.PauliX(1)
            m_0 = qml.measure(1, reset=True)
            return qml.probs(wires=[1])

    Executing this QNode:

    >>> func()
    tensor([1., 0.], requires_grad=True)

    Mid-circuit measurements can be manipulated using the following arithmetic
    operators: ``+``, ``-``, ``*``, ``/``, ``~`` (not), ``&`` (and), ``|`` (or),
    ``==``, ``<=``, ``>=``, ``<``, ``>`` with other mid-circuit measurements or scalars.

    .. Note ::

        Python ``not``, ``and``, ``or``, do not work since these do not have dunder
        methods. Instead use ``~``, ``&``, ``|``.

    Args:
        wires (Wires): The wire of the qubit the measurement process applies to.
        reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle`
            state after measurement.

    Returns:
        MeasurementValue: a measurement value wrapping the newly created
        :class:`MidMeasureMP` with an identity processing function

    Raises:
        QuantumFunctionError: if multiple wires were specified
    """
    measured_wire = Wires(wires)

    # Guard clause: mid-circuit measurements act on exactly one qubit.
    if len(measured_wire) > 1:
        raise qml.QuantumFunctionError(
            "Only a single qubit can be measured in the middle of the circuit"
        )

    # A short random identifier ties the measurement process to the measurement
    # value that depends on it, which supports serialization.
    short_id = uuid.uuid4().hex[:8]
    mid_measure_mp = MidMeasureMP(wires=measured_wire, reset=reset, id=short_id)

    return MeasurementValue([mid_measure_mp], processing_fn=lambda v: v)
# Type variable used to parameterize the generic ``MeasurementValue`` class.
T = TypeVar("T")
class MidMeasureMP(MeasurementProcess):
    """Mid-circuit measurement.

    This class additionally stores information about unknown measurement outcomes in the
    qubit model. Measurements on a single qubit in the computational basis are assumed.

    Please refer to :func:`measure` for detailed documentation.

    Args:
        wires (.Wires): The wires the measurement process applies to.
            This can only be specified if an observable was not provided.
        reset (bool): Whether to reset the wire after measurement.
        id (str): Custom label given to a measurement instance.
    """

    def __init__(
        self,
        wires: Optional[Wires] = None,
        reset: Optional[bool] = False,
        id: Optional[str] = None,
    ):
        # The base class receives the wires (normalized through ``Wires``) and the label;
        # the reset flag is kept on this subclass.
        super().__init__(wires=Wires(wires), id=id)
        self.reset = reset

    @property
    def return_type(self):
        """The return type of this measurement process, always ``MidMeasure``."""
        return MidMeasure

    @property
    def samples_computational_basis(self):
        """bool: Always ``False`` for a mid-circuit measurement."""
        return False

    @property
    def _queue_category(self):
        """str: The queue category of this measurement process, always ``"_ops"``."""
        return "_ops"

    @property
    def hash(self):
        """int: Returns an integer hash uniquely representing the measurement process"""
        wire_labels = tuple(self.wires.tolist())
        # The fingerprint is (class name, wire labels, id).
        return hash((type(self).__name__, wire_labels, self.id))
class MeasurementValue(Generic[T]):
    """A class representing unknown measurement outcomes in the qubit model.

    Measurements on a single qubit in the computational basis are assumed.

    A branch is a tuple of 0/1 outcomes, one per entry of ``measurements`` and in the
    same order. Integer indexing (``value[i]``) evaluates the branch given by the binary
    digits of ``i`` (most significant bit first). Indices outside
    ``[-2**n, 2**n)``, with ``n = len(measurements)``, raise ``IndexError``.

    Args:
        measurements (list[.MidMeasureMP]): The measurement(s) that this object depends on.
        processing_fn (callable): A lazy transformation applied to the measurement values.
    """

    def __init__(self, measurements, processing_fn):
        self.measurements = measurements
        self.processing_fn = processing_fn

    def _branch(self, i):
        """Return the tuple of 0/1 outcomes encoded by the binary digits of ``i``."""
        return tuple(int(b) for b in np.binary_repr(i, width=len(self.measurements)))

    def _items(self):
        """A generator representing all the possible outcomes of the MeasurementValue.

        Yields:
            tuple: ``(branch, value)`` pairs, where ``value`` is ``processing_fn``
            evaluated on the outcomes in ``branch``.
        """
        for i in range(2 ** len(self.measurements)):
            branch = self._branch(i)
            yield branch, self.processing_fn(*branch)

    @property
    def branches(self):
        """A dictionary representing all possible outcomes of the MeasurementValue."""
        return dict(self._items())

    def _transform_bin_op(self, base_bin, other):
        """Helper function for defining dunder binary operations.

        Args:
            base_bin (callable): binary function applied to the two operand values
            other: another ``MeasurementValue`` or any other object (e.g. a scalar)

        Returns:
            MeasurementValue: the lazily evaluated result of the operation
        """
        if isinstance(other, MeasurementValue):
            # pylint: disable=protected-access
            return self._merge(other)._apply(lambda t: base_bin(t[0], t[1]))
        # if `other` is not a MeasurementValue then apply it to each branch
        return self._apply(lambda v: base_bin(v, other))

    def __invert__(self):
        """Return a copy of the measurement value with an inverted control value."""
        return self._apply(lambda v: not v)

    def __eq__(self, other):
        return self._transform_bin_op(lambda a, b: a == b, other)

    def __ne__(self, other):
        return self._transform_bin_op(lambda a, b: a != b, other)

    def __add__(self, other):
        return self._transform_bin_op(lambda a, b: a + b, other)

    def __radd__(self, other):
        return self._apply(lambda v: other + v)

    def __sub__(self, other):
        return self._transform_bin_op(lambda a, b: a - b, other)

    def __rsub__(self, other):
        return self._apply(lambda v: other - v)

    def __mul__(self, other):
        return self._transform_bin_op(lambda a, b: a * b, other)

    def __rmul__(self, other):
        return self._apply(lambda v: other * v)

    def __truediv__(self, other):
        return self._transform_bin_op(lambda a, b: a / b, other)

    def __rtruediv__(self, other):
        return self._apply(lambda v: other / v)

    def __lt__(self, other):
        return self._transform_bin_op(lambda a, b: a < b, other)

    def __le__(self, other):
        return self._transform_bin_op(lambda a, b: a <= b, other)

    def __gt__(self, other):
        return self._transform_bin_op(lambda a, b: a > b, other)

    def __ge__(self, other):
        return self._transform_bin_op(lambda a, b: a >= b, other)

    def __and__(self, other):
        return self._transform_bin_op(lambda a, b: a and b, other)

    def __or__(self, other):
        return self._transform_bin_op(lambda a, b: a or b, other)

    def _apply(self, fn):
        """Apply a post computation to this measurement"""
        return MeasurementValue(self.measurements, lambda *x: fn(self.processing_fn(*x)))

    def _merge(self, other: "MeasurementValue"):
        """Merge two measurement values.

        Returns:
            MeasurementValue: a value depending on the union of both measurement lists
            (sorted by ``id``) whose processing function returns the pair
            ``(self_value, other_value)``.
        """
        with warnings.catch_warnings():
            # Using a filter because the new behaviour of MP hash will be valid here
            warnings.filterwarnings(
                "ignore",
                message="The behaviour of measurement process hashing",
                category=UserWarning,
            )
            # create a new merged list with no duplicates and in lexical ordering
            merged_measurements = list(set(self.measurements).union(set(other.measurements)))
            merged_measurements.sort(key=lambda m: m.id)

        with warnings.catch_warnings():
            # Using a filter because the new behaviour of MP equality will be valid here
            warnings.filterwarnings(
                "ignore",
                message="The behaviour of measurement process equality",
                category=UserWarning,
            )
            # Resolve once which merged positions feed each operand, rather than
            # repeating the lookups (and the warning filter) on every evaluation.
            self_positions = [merged_measurements.index(m) for m in self.measurements]
            other_positions = [merged_measurements.index(m) for m in other.measurements]

        # create a new function that selects the correct indices for each sub function
        def merged_fn(*x):
            out_1 = self.processing_fn(*(x[i] for i in self_positions))
            out_2 = other.processing_fn(*(x[i] for i in other_positions))
            return out_1, out_2

        return MeasurementValue(merged_measurements, merged_fn)

    def __getitem__(self, i):
        """Evaluate the branch whose outcomes are the binary digits of ``i``.

        Raises:
            IndexError: if ``i`` does not address one of the ``2**n`` branches. This also
                terminates Python's fallback iteration protocol, so ``list(value)`` and
                ``x in value`` visit each branch exactly once.
        """
        n_branches = 2 ** len(self.measurements)
        if not -n_branches <= i < n_branches:
            raise IndexError(
                f"branch index {i} is out of range for {len(self.measurements)} measurement(s)"
            )
        if i < 0:
            # Python-style negative indexing; this coincides with the two's-complement
            # encoding ``binary_repr`` produces when the width is sufficient.
            i += n_branches
        return self.processing_fn(*self._branch(i))

    def __str__(self):
        ids = [m.id for m in self.measurements]
        lines = []
        for branch, value in self._items():
            assignment = ",".join(f"{m_id}={outcome}" for m_id, outcome in zip(ids, branch))
            lines.append(f"if {assignment} => {value!s}")
        return "\n".join(lines)