# Source code for pennylane.devices.qutrit_mixed.measure

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code relevant for performing measurements on a qutrit mixed state.
"""

from collections.abc import Callable
from string import ascii_letters as alphabet

from pennylane import math, queuing
from pennylane.measurements import (
    ExpectationMP,
    MeasurementProcess,
    ProbabilityMP,
    StateMeasurement,
    StateMP,
    VarianceMP,
)
from pennylane.ops import Hamiltonian, Sum
from pennylane.typing import TensorLike
from pennylane.wires import Wires

from .apply_operation import apply_operation
from .utils import QUDIT_DIM, get_num_wires, reshape_state_as_matrix


def calculate_expval(
    measurementprocess: ExpectationMP,
    state: TensorLike,
    is_state_batched: bool = False,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Compute the expectation value of the measured observable.

    The probabilities of each computational basis outcome (after diagonalization and
    optional readout errors) are contracted with the observable's eigenvalues.

    Args:
        measurementprocess (ExpectationMP): measurement process to apply to the state.
        state (TensorLike): the state to measure.
        is_state_batched (bool): whether the state is batched or not.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike: expectation value of observable wrt the state.
    """
    probabilities = calculate_probability(
        measurementprocess,
        state,
        is_state_batched=is_state_batched,
        readout_errors=readout_errors,
    )
    spectrum = math.asarray(measurementprocess.eigvals(), dtype="float64")
    # When broadcasting, `probabilities` carries a leading batch axis, so this dot
    # becomes a matrix-vector product yielding one expectation value per batch entry.
    return math.dot(probabilities, spectrum)


def calculate_reduced_density_matrix(
    measurementprocess: StateMeasurement,
    state: TensorLike,
    is_state_batched: bool = False,
    _readout_errors: list[Callable] = None,
) -> TensorLike:
    """Return the full density matrix, or its partial trace onto the measured wires.

    If the measurement acts on no wires, the whole state is returned reshaped as a
    matrix. Otherwise every wire not in ``measurementprocess.wires`` is traced out with
    a single ``einsum``. The result is reshaped as a matrix over the measured wires,
    in the order they are listed.

    Args:
        measurementprocess (StateMeasurement): measurement to apply to the state.
        state (TensorLike): state to apply the measurement to.
        is_state_batched (bool): whether the state is batched or not.
        _readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors. These are not applied on this type of measurement.

    Returns:
        TensorLike: state or reduced density matrix.
    """
    kept_wires = measurementprocess.wires
    n_wires = get_num_wires(state, is_state_batched)
    if not kept_wires:
        return reshape_state_as_matrix(state, n_wires)

    # NOTE(review): wire labels are used directly as integer positions, which assumes the
    # device wires are 0..n_wires-1. Two letters per wire are needed, so the 52-character
    # alphabet limits this to 26 wires.
    # Row and column indices start out sharing letters. A shared letter makes einsum sum
    # over the diagonal of that wire, which traces it out.
    subscripts = list(alphabet[:n_wires] * 2)
    open_cols = []
    for wire in kept_wires:
        col_pos = wire + n_wires
        fresh_letter = alphabet[col_pos]
        # A kept wire gets a distinct column letter so it remains an open index.
        subscripts[col_pos] = fresh_letter
        open_cols.append(fresh_letter)
    open_rows = [alphabet[wire] for wire in kept_wires]

    input_spec = "".join(subscripts)
    output_spec = "".join(open_rows + open_cols)

    # The leading ellipsis carries any batch dimension through untouched.
    reduced = math.einsum(f"...{input_spec}->...{output_spec}", state)
    return reshape_state_as_matrix(reduced, len(kept_wires))


def calculate_probability(
    measurementprocess: StateMeasurement,
    state: TensorLike,
    is_state_batched: bool = False,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Compute outcome probabilities on the measured wires.

    The processing steps are:

    1. Apply the measurement's diagonalizing gates to the state.
    2. If given, apply each readout-error channel (built by calling it with a wire)
       to every measured wire.
    3. Take the real part of the density-matrix diagonal, clipping negatives to zero.
    4. Marginalize onto the measured wires, ordered as in ``measurementprocess.wires``.

    Args:
        measurementprocess (StateMeasurement): measurement to apply to the state.
        state (TensorLike): state to apply the measurement to.
        is_state_batched (bool): whether the state is batched or not.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike: the probability of the state being in each measurable state.
    """
    for gate in measurementprocess.diagonalizing_gates():
        state = apply_operation(gate, state, is_state_batched=is_state_batched)

    measured_wires = measurementprocess.wires
    n_wires = get_num_wires(state, is_state_batched)
    device_wires = Wires(range(n_wires))

    if readout_errors is not None:
        # The error channels are applied directly to the state and must not be queued
        # onto any surrounding recording context.
        with queuing.QueuingManager.stop_recording():
            for wire in measured_wires:
                for make_error in readout_errors:
                    state = apply_operation(
                        make_error(wire), state, is_state_batched=is_state_batched
                    )

    # Probabilities are the diagonal of the density matrix. The diagonal is taken per
    # matrix and stacked, because the axis-selection argument names of `diagonal`
    # differ between interfaces.
    density_matrix = reshape_state_as_matrix(state, n_wires)
    if is_state_batched:
        diagonal = math.stack([math.diagonal(rho) for rho in density_matrix])
    else:
        diagonal = math.diagonal(density_matrix)
    probs = math.real(diagonal)

    # Round-off can produce tiny negative values. `math.clip` with None bounds fails
    # under tensorflow, so `where` is used instead.
    probs = math.where(probs < 0, 0, probs)
    if measured_wires == Wires([]):
        return probs

    summed_out_wires = Wires.unique_wires([device_wires, measured_wires])

    position_of = {label: pos for pos, label in enumerate(device_wires)}
    kept_axes = [position_of[w] for w in measured_wires]
    summed_axes = [position_of[w] for w in summed_out_wires]

    # View the flat distribution as one QUDIT_DIM-sized axis per device wire.
    tensor_shape = [QUDIT_DIM] * n_wires
    # After summing, the surviving axes sit in ascending wire order. This permutation
    # restores the order in which the wires were requested.
    permutation = math.argsort(math.argsort(kept_axes))
    output_shape = (-1,)
    full_size = QUDIT_DIM**n_wires
    batch_size = math.get_batch_size(probs, (full_size,), full_size)
    if batch_size is not None:
        # A broadcast distribution gets an extra leading batch axis, which shifts every
        # wire axis by one.
        tensor_shape = [batch_size] + tensor_shape
        summed_axes = [axis + 1 for axis in summed_axes]
        permutation = math.insert(permutation + 1, 0, 0)
        output_shape = (batch_size, -1)

    marginal = math.sum(math.reshape(probs, tensor_shape), axis=tuple(summed_axes))
    return math.reshape(math.transpose(marginal, permutation), output_shape)


def calculate_variance(
    measurementprocess: StateMeasurement,
    state: TensorLike,
    is_state_batched: bool = False,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Compute the variance of the measured observable.

    The outcome probabilities are used as weights over the observable's eigenvalues,
    and the variance is returned as ``<lambda**2> - <lambda>**2``.

    Args:
        measurementprocess (StateMeasurement): measurement to apply to the state.
        state (TensorLike): state to apply the measurement to.
        is_state_batched (bool): whether the state is batched or not.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike: the variance of the observable wrt the state.
    """
    probabilities = calculate_probability(
        measurementprocess,
        state,
        is_state_batched=is_state_batched,
        readout_errors=readout_errors,
    )
    spectrum = math.asarray(measurementprocess.eigvals(), dtype="float64")
    # When broadcasting, `probabilities` carries a leading batch axis, so each dot is a
    # matrix-vector product giving one moment per batch entry.
    second_moment = math.dot(probabilities, spectrum**2)
    mean = math.dot(probabilities, spectrum)
    return second_moment - mean**2


def calculate_expval_sum_of_terms(
    measurementprocess: ExpectationMP,
    state: TensorLike,
    is_state_batched: bool = False,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Expectation value of a ``Sum`` or ``Hamiltonian`` observable, computed term by term.

    Each term is wrapped in its own ``ExpectationMP`` and dispatched through ``measure``,
    so every term can use the measurement strategy that suits it best. This path
    supports backpropagation.

    - For a ``Sum``, the operands are measured and their results are added.
    - Otherwise (a ``Hamiltonian``), the ``(coeffs, observables)`` pair from
      ``obs.terms()`` is used, and each term's result is scaled by its coefficient.

    Args:
        measurementprocess (ExpectationMP): measurement process to apply to the state.
        state (TensorLike): the state to measure.
        is_state_batched (bool): whether the state is batched or not.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike: the expectation value of the sum of Hamiltonian observable wrt the state.
    """
    observable = measurementprocess.obs

    def _term_expval(term):
        # Recurse through the generic dispatcher for a single term.
        return measure(
            ExpectationMP(term),
            state,
            is_state_batched=is_state_batched,
            readout_errors=readout_errors,
        )

    if isinstance(observable, Sum):
        return sum(_term_expval(operand) for operand in observable)

    coefficients, sub_observables = observable.terms()
    return sum(
        coeff * _term_expval(sub_obs)
        for coeff, sub_obs in zip(coefficients, sub_observables)
    )


# pylint: disable=too-many-return-statements
def get_measurement_function(
    measurementprocess: MeasurementProcess,
) -> Callable[[MeasurementProcess, TensorLike, bool, list[Callable]], TensorLike]:
    """Select the function that evaluates ``measurementprocess`` on a state.

    Only ``StateMeasurement`` instances are supported, dispatched as follows:

    - ``ExpectationMP`` of a ``Hamiltonian``/``Sum``: term-wise evaluation.
    - ``ExpectationMP`` whose observable has a matrix: probability-based expval.
    - ``StateMP``/``ProbabilityMP``/``VarianceMP``, when the observable is absent or
      provides diagonalizing gates: reduced density matrix, probabilities or variance.

    Args:
        measurementprocess (MeasurementProcess): measurement process to apply to the state.

    Returns:
        Callable: function with signature
        ``(measurementprocess, state, is_state_batched, readout_errors)`` that returns the
        measurement result.

    Raises:
        NotImplementedError: if no function supports the given measurement process.
    """
    if not isinstance(measurementprocess, StateMeasurement):
        raise NotImplementedError

    observable = measurementprocess.obs

    if isinstance(measurementprocess, ExpectationMP):
        if isinstance(observable, (Hamiltonian, Sum)):
            return calculate_expval_sum_of_terms
        if observable.has_matrix:
            return calculate_expval

    # The remaining handlers require either no observable or one that can be diagonalized.
    if observable is not None and not observable.has_diagonalizing_gates:
        raise NotImplementedError

    dispatch = (
        (StateMP, calculate_reduced_density_matrix),
        (ProbabilityMP, calculate_probability),
        (VarianceMP, calculate_variance),
    )
    for mp_type, handler in dispatch:
        if isinstance(measurementprocess, mp_type):
            return handler

    raise NotImplementedError


def measure(
    measurementprocess: MeasurementProcess,
    state: TensorLike,
    is_state_batched: bool = False,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Apply a measurement process to a state.

    Args:
        measurementprocess (MeasurementProcess): measurement process to apply to the state.
        state (TensorLike): the state to measure.
        is_state_batched (bool): whether the state is batched or not.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike: the result of the measurement process being applied to the state.

    Raises:
        NotImplementedError: propagated from ``get_measurement_function`` when no
            measurement function supports ``measurementprocess``.
    """
    measurement_function = get_measurement_function(measurementprocess)
    return measurement_function(measurementprocess, state, is_state_batched, readout_errors)