Source code for pennylane.devices.qubit_mixed.sampling

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Submodule for sampling a qubit mixed state.
"""
# pylint: disable=too-many-positional-arguments
from typing import Callable, Union

import numpy as np

import pennylane as qml
from pennylane import math
from pennylane.devices.qubit.sampling import _group_measurements, jax_random_split, sample_probs
from pennylane.measurements import CountsMP, ExpectationMP, SampleMeasurement, Shots
from pennylane.measurements.classical_shadow import ClassicalShadowMP, ShadowExpvalMP
from pennylane.ops import LinearCombination, Sum
from pennylane.typing import TensorLike

from .apply_operation import _get_num_wires, apply_operation
from .measure import measure


def _apply_diagonalizing_gates(
    mps: list[SampleMeasurement], state: np.ndarray, is_state_batched: bool = False
):
    """
    !Note: `mps` is supposed only have qubit-wise commuting measurements
    """
    if len(mps) == 1:
        diagonalizing_gates = mps[0].diagonalizing_gates()
    elif all(mp.obs for mp in mps):
        diagonalizing_gates = qml.pauli.diagonalize_qwc_pauli_words([mp.obs for mp in mps])[0]
    else:
        diagonalizing_gates = []

    for op in diagonalizing_gates:
        state = apply_operation(op, state, is_state_batched=is_state_batched)

    return state


# pylint:disable = too-many-arguments
def _measure_with_samples_diagonalizing_gates(
    mps: list[SampleMeasurement],
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Returns the samples of the measurement process performed on the given state,
    by rotating the state into the measurement basis using the diagonalizing gates
    given by the measurement process.

    Args:
        mp (~.measurements.SampleMeasurement): The sample measurement to perform
        state (TensorLike): The density matrix to sample from
        shots (~.measurements.Shots): The number of samples to take
        is_state_batched (bool): whether the state is batched or not
        rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
            seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
            If no value is provided, a default RNG will be used.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike[Any]: Sample measurement results
    """
    # apply diagonalizing gates
    state = _apply_diagonalizing_gates(mps, state, is_state_batched)

    total_indices = _get_num_wires(state, is_state_batched)
    wires = qml.wires.Wires(range(total_indices))

    def _process_single_shot(samples):
        processed = []
        for mp in mps:
            res = mp.process_samples(samples, wires)
            if not isinstance(mp, CountsMP):
                res = math.squeeze(res)

            processed.append(res)

        return tuple(processed)

    prng_key, _ = jax_random_split(prng_key)
    samples = sample_state(
        state,
        shots=shots.total_shots,
        is_state_batched=is_state_batched,
        wires=wires,
        rng=rng,
        prng_key=prng_key,
        readout_errors=readout_errors,
    )
    processed_samples = []
    for lower, upper in shots.bins():
        shot = _process_single_shot(samples[..., lower:upper, :])
        processed_samples.append(shot)

    if shots.has_partitioned_shots:
        return tuple(zip(*processed_samples))

    return processed_samples[0]


def _measure_classical_shadow(
    mp: list[Union[ClassicalShadowMP, ShadowExpvalMP]],
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors=None,
):
    """
    Returns the result of a classical shadow measurement on the given state.

    A classical shadow measurement doesn't fit neatly into the current measurement API
    since different diagonalizing gates are used for each shot. Here it's treated as a
    state measurement with shots instead of a sample measurement.

    Args:
        mp (~.measurements.SampleMeasurement): The sample measurement to perform
        state (np.ndarray[complex]): The state vector to sample from
        shots (~.measurements.Shots): The number of samples to take
        is_state_batched (bool): Whether the state is batched or not
        rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
            seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
            If no value is provided, a default RNG will be used.

    Returns:
        TensorLike[Any]: Sample measurement results
    """
    # pylint: disable=unused-argument

    # the list contains only one element based on how we group measurements
    mp = mp[0]

    wires = qml.wires.Wires(range(_get_num_wires(state, is_state_batched)))

    if shots.has_partitioned_shots:
        return [tuple(process_state_with_shots(mp, state, wires, s, rng=rng) for s in shots)]

    return [process_state_with_shots(mp, state, wires, shots.total_shots, rng=rng)]


def process_state_with_shots(mp, state, wire_order, shots, rng=None):
    """Sample 'shots' classical shadow snapshots from the given density matrix `state`.

    Args:
        mp (ClassicalShadowMP or ShadowExpvalMP): The classical shadow measurement to perform
        state (np.ndarray): A (2^N, 2^N) density matrix for N qubits
        wire_order (qml.wires.Wires): The global wire ordering
        shots (int): Number of classical-shadow snapshots
        rng (None or int or Generator): Random seed for measurement bits

    Returns:
        np.ndarray[int]: shape (2, shots, num_shadow_qubits).
            First row: measurement outcomes (0 or 1).
            Second row: Pauli basis recipe (0=X, 1=Y, 2=Z).
    """
    return mp.process_density_matrix_with_shots(
        state,
        wire_order,
        shots,
        rng=rng,
    )


def _measure_hamiltonian_with_samples(
    mp: list[ExpectationMP],
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors=None,
):
    # the list contains only one element based on how we group measurements
    mp = mp[0]

    # if the measurement process involves a Hamiltonian, measure each
    # of the terms separately and sum
    def _sum_for_single_shot(s, prng_key=None):
        results = measure_with_samples(
            [ExpectationMP(t) for t in mp.obs.terms()[1]],
            state,
            s,
            is_state_batched=is_state_batched,
            rng=rng,
            prng_key=prng_key,
            readout_errors=readout_errors,
        )
        return sum(c * res for c, res in zip(mp.obs.terms()[0], results))

    keys = jax_random_split(prng_key, num=shots.num_copies)
    unsqueezed_results = tuple(
        _sum_for_single_shot(type(shots)(s), key) for s, key in zip(shots, keys)
    )
    return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]]


def _measure_sum_with_samples(
    mp: list[ExpectationMP],
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors: list[Callable] = None,
):
    """Compute expectation values of Sum Observables"""
    mp = mp[0]

    def _sum_for_single_shot(s, prng_key=None):
        results = measure_with_samples(
            [ExpectationMP(t) for t in mp.obs],
            state,
            s,
            is_state_batched=is_state_batched,
            rng=rng,
            prng_key=prng_key,
            readout_errors=readout_errors,
        )
        return sum(results)

    keys = jax_random_split(prng_key, num=shots.num_copies)
    unsqueezed_results = tuple(
        _sum_for_single_shot(type(shots)(s), key) for s, key in zip(shots, keys)
    )
    return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]]


[docs] def sample_state( state, shots: int, is_state_batched: bool = False, wires=None, rng=None, prng_key=None, readout_errors: list[Callable] = None, ) -> np.ndarray: """Returns a series of computational basis samples of a state. Args: state (array[complex]): A density matrix to be sampled shots (int): The number of samples to take is_state_batched (bool): whether the state is batched or not wires (Sequence[int]): The wires to sample rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. If no value is provided, a default RNG will be used prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. readout_errors (List[Callable]): List of channels to apply to each wire being measured to simulate readout errors. Returns: ndarray[int]: Sample values of the shape (shots, num_wires) """ total_indices = _get_num_wires(state, is_state_batched) state_wires = qml.wires.Wires(range(total_indices)) wires_to_sample = wires or state_wires num_wires = len(wires_to_sample) with qml.queuing.QueuingManager.stop_recording(): probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors) # After getting the correct probs, there's no difference between mixed states and pure states. # Therefore, we directly re-use the sample_probs from the module qubit. return sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key=prng_key)
[docs] def measure_with_samples( measurements: list[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], state: np.ndarray, shots: Shots, is_state_batched: bool = False, rng=None, prng_key=None, readout_errors: list[Callable] = None, ) -> TensorLike: """Returns the samples of the measurement process performed on the given state. This function assumes that the user-defined wire labels in the measurement process have already been mapped to integer wires used in the device. Args: mp (SampleMeasurement): The sample measurement to perform state (np.ndarray[complex]): The density matrix to sample from shots (Shots): The number of samples to take is_state_batched (bool): whether the state is batched or not rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. If no value is provided, a default RNG will be used. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. readout_errors (List[Callable]): List of channels to apply to each wire being measured to simulate readout errors. Returns: TensorLike[Any]: Sample measurement results """ groups, indices = _group_measurements(measurements) all_res = [] for group in groups: if isinstance(group[0], ExpectationMP) and isinstance(group[0].obs, LinearCombination): measure_fn = _measure_hamiltonian_with_samples elif isinstance(group[0], ExpectationMP) and isinstance(group[0].obs, Sum): measure_fn = _measure_sum_with_samples elif isinstance(group[0], (ClassicalShadowMP, ShadowExpvalMP)): measure_fn = _measure_classical_shadow else: # measure with the usual method (rotate into the measurement basis) measure_fn = _measure_with_samples_diagonalizing_gates prng_key, key = jax_random_split(prng_key) all_res.extend( measure_fn( group, state, shots, is_state_batched=is_state_batched, rng=rng, prng_key=key, readout_errors=readout_errors, ) ) flat_indices = [_i for i in indices for _i in i] # reorder results sorted_res = tuple( res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]]) ) # put the shot vector axis before the measurement axis if shots.has_partitioned_shots: sorted_res = tuple(zip(*sorted_res)) return sorted_res