Source code for pennylane.devices.qutrit_mixed.sampling

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code relevant for sampling a qutrit mixed state.
"""
import functools
from typing import Callable

import numpy as np

import pennylane as qml
from pennylane import math
from pennylane.measurements import (
    CountsMP,
    ExpectationMP,
    SampleMeasurement,
    SampleMP,
    Shots,
    VarianceMP,
)
from pennylane.ops import Sum
from pennylane.typing import TensorLike

from .apply_operation import apply_operation
from .measure import measure
from .utils import QUDIT_DIM, get_num_wires


def _apply_diagonalizing_gates(
    mp: SampleMeasurement, state: np.ndarray, is_state_batched: bool = False
):
    """Applies diagonalizing gates when necessary"""
    if mp.obs:
        for op in mp.diagonalizing_gates():
            state = apply_operation(op, state, is_state_batched=is_state_batched)

    return state


def _process_samples(
    mp,
    samples,
    wire_order,
):
    """Processes samples like SampleMP.process_samples, but fixed for qutrits"""
    wire_map = dict(zip(wire_order, range(len(wire_order))))
    mapped_wires = [wire_map[w] for w in mp.wires]

    if mapped_wires:
        # if wires are provided, then we only return samples from those wires
        samples = samples[..., mapped_wires]

    num_wires = samples.shape[-1]  # wires is the last dimension

    if mp.obs is None:
        # if no observable was provided then return the raw samples
        return samples

    # Replace the basis state in the computational basis with the correct eigenvalue.
    # Extract only the columns of the basis samples required based on ``wires``.
    powers_of_three = QUDIT_DIM ** qml.math.arange(num_wires)[::-1]
    indices = qml.math.array(samples @ powers_of_three)
    return mp.eigvals()[indices]


def _process_counts_samples(processed_sample, mp_has_obs):
    """Processes a set of samples and counts the results."""
    observables, counts = math.unique(processed_sample, return_counts=True, axis=0)
    if not mp_has_obs:
        observables = ["".join(observable.astype("str")) for observable in observables]
    return dict(zip(observables, counts))


def _process_expval_samples(processed_sample):
    """Processes a set of samples and returns the expectation value of an observable."""
    eigvals, counts = math.unique(processed_sample, return_counts=True)
    probs = counts / math.sum(counts)
    return math.dot(probs, eigvals)


def _process_variance_samples(processed_sample):
    """Processes a set of samples and returns the variance of an observable."""
    eigvals, counts = math.unique(processed_sample, return_counts=True)
    probs = counts / math.sum(counts)
    return math.dot(probs, (eigvals**2)) - math.dot(probs, eigvals) ** 2


# pylint:disable = too-many-arguments
def _measure_with_samples_diagonalizing_gates(
    mp: SampleMeasurement,
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors: list[Callable] = None,
) -> TensorLike:
    """Returns the samples of the measurement process performed on the given state,
    by rotating the state into the measurement basis using the diagonalizing gates
    given by the measurement process.

    Args:
        mp (~.measurements.SampleMeasurement): The sample measurement to perform
        state (np.ndarray[complex]): The state vector to sample from
        shots (~.measurements.Shots): The number of samples to take
        is_state_batched (bool): whether the state is batched or not
        rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
            seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
            If no value is provided, a default RNG will be used.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        TensorLike[Any]: Sample measurement results
    """
    # apply diagonalizing gates
    state = _apply_diagonalizing_gates(mp, state, is_state_batched)

    total_indices = get_num_wires(state, is_state_batched)
    wires = qml.wires.Wires(range(total_indices))

    def _process_single_shot(samples):
        samples_processed = _process_samples(mp, samples, wires)
        if isinstance(mp, SampleMP):
            return math.squeeze(samples_processed)
        if isinstance(mp, CountsMP):
            process_func = functools.partial(_process_counts_samples, mp_has_obs=mp.obs is not None)
        elif isinstance(mp, ExpectationMP):
            process_func = _process_expval_samples
        elif isinstance(mp, VarianceMP):
            process_func = _process_variance_samples
        else:
            raise NotImplementedError

        if is_state_batched:
            ret = []
            for processed_sample in samples_processed:
                ret.append(process_func(processed_sample))
            return math.squeeze(ret)
        return process_func(samples_processed)

    # if there is a shot vector, build a list containing results for each shot entry
    if shots.has_partitioned_shots:
        processed_samples = []
        for s in shots:
            # Like default.qubit currently calling sample_state for each shot entry,
            # but it may be better to call sample_state just once with total_shots,
            # then use the shot_range keyword argument
            samples = sample_state(
                state,
                shots=s,
                is_state_batched=is_state_batched,
                wires=wires,
                rng=rng,
                prng_key=prng_key,
                readout_errors=readout_errors,
            )
            processed_samples.append(_process_single_shot(samples))

        return tuple(processed_samples)

    samples = sample_state(
        state,
        shots=shots.total_shots,
        is_state_batched=is_state_batched,
        wires=wires,
        rng=rng,
        prng_key=prng_key,
        readout_errors=readout_errors,
    )

    return _process_single_shot(samples)


def _measure_sum_with_samples(
    mp: SampleMeasurement,
    state: np.ndarray,
    shots: Shots,
    is_state_batched: bool = False,
    rng=None,
    prng_key=None,
    readout_errors: list[Callable] = None,
):
    """Compute expectation values of Sum Observables"""

    def _sum_for_single_shot(s):
        results = []
        for term in mp.obs:
            results.append(
                measure_with_samples(
                    ExpectationMP(term),
                    state,
                    s,
                    is_state_batched=is_state_batched,
                    rng=rng,
                    prng_key=prng_key,
                    readout_errors=readout_errors,
                )
            )

        return sum(results)

    if shots.has_partitioned_shots:
        return tuple(_sum_for_single_shot(type(shots)(s)) for s in shots)

    return _sum_for_single_shot(shots)


def _sample_state_jax(
    state,
    shots: int,
    prng_key,
    is_state_batched: bool = False,
    wires=None,
    readout_errors: list[Callable] = None,
) -> np.ndarray:
    """Returns a series of samples of a state for the JAX interface based on the PRNG.

    Args:
        state (array[complex]): A state vector to be sampled
        shots (int): The number of samples to take
        prng_key (jax.random.PRNGKey): A``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator.
        is_state_batched (bool): whether the state is batched or not
        wires (Sequence[int]): The wires to sample
        readout_errors (List[Callable]): List of channels to apply to each wire being measured
        to simulate readout errors.

    Returns:
        ndarray[int]: Sample values of the shape (shots, num_wires)
    """
    # pylint: disable=import-outside-toplevel

    total_indices = get_num_wires(state, is_state_batched)
    state_wires = qml.wires.Wires(range(total_indices))

    wires_to_sample = wires or state_wires
    num_wires = len(wires_to_sample)

    with qml.queuing.QueuingManager.stop_recording():
        probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors)

    state_len = len(state)

    return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)


def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len):
    """
    Sample from a probability distribution for a qutrit system using JAX.

    This function generates samples based on the given probability distribution
    for a qutrit system with a specified number of wires. It can handle both
    batched and non-batched probability distributions. This function uses JAX
    for potential GPU acceleration and improved performance.

    Args:
        probs (jnp.ndarray): Probability distribution to sample from. For non-batched
            input, this should be a 1D array of length QUDIT_DIM**num_wires. For
            batched input, this should be a 2D array where each row is a separate
            probability distribution.
        shots (int): Number of samples to generate.
        num_wires (int): Number of wires in the qutrit system.
        is_state_batched (bool): Whether the input probabilities are batched.
        prng_key (jax.random.PRNGKey): JAX PRNG key for random number generation.
        state_len (int): Length of the state (relevant for batched inputs).

    Returns:
        jnp.ndarray: An array of samples. For non-batched input, the shape is
        (shots, num_wires). For batched input, the shape is
        (batch_size, shots, num_wires).

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> probs = jnp.array([0.2, 0.3, 0.5])  # For a single-wire qutrit system
        >>> shots = 1000
        >>> num_wires = 1
        >>> is_state_batched = False
        >>> prng_key = jax.random.PRNGKey(42)
        >>> state_len = 1
        >>> samples = _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)
        >>> samples.shape
        (1000, 1)

    Note:
        This function requires JAX to be installed. It internally imports JAX
        and its numpy module (jnp).
    """
    # pylint: disable=import-outside-toplevel
    import jax
    import jax.numpy as jnp

    key = prng_key

    basis_states = np.arange(QUDIT_DIM**num_wires)
    if is_state_batched:
        # Produce separate keys for each of the probabilities along the broadcasted axis
        keys = []
        for _ in range(state_len):
            key, subkey = jax.random.split(key)
            keys.append(subkey)
        samples = jnp.array(
            [
                jax.random.choice(_key, basis_states, shape=(shots,), p=prob)
                for _key, prob in zip(keys, probs)
            ]
        )
    else:
        samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

    res = np.zeros(samples.shape + (num_wires,), dtype=np.int64)
    for i in range(num_wires):
        res[..., -(i + 1)] = (samples // (QUDIT_DIM**i)) % QUDIT_DIM
    return res


[docs]def sample_state( state, shots: int, is_state_batched: bool = False, wires=None, rng=None, prng_key=None, readout_errors: list[Callable] = None, ) -> np.ndarray: """Returns a series of computational basis samples of a state. Args: state (array[complex]): A state vector to be sampled shots (int): The number of samples to take is_state_batched (bool): whether the state is batched or not wires (Sequence[int]): The wires to sample rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. If no value is provided, a default RNG will be used prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. readout_errors (List[Callable]): List of channels to apply to each wire being measured to simulate readout errors. Returns: ndarray[int]: Sample values of the shape (shots, num_wires) """ if prng_key is not None: return _sample_state_jax( state, shots, prng_key, is_state_batched=is_state_batched, wires=wires, readout_errors=readout_errors, ) total_indices = get_num_wires(state, is_state_batched) state_wires = qml.wires.Wires(range(total_indices)) wires_to_sample = wires or state_wires num_wires = len(wires_to_sample) with qml.queuing.QueuingManager.stop_recording(): probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors) return sample_probs(probs, shots, num_wires, is_state_batched, rng)
[docs]def sample_probs(probs, shots, num_wires, is_state_batched, rng): """ Sample from a probability distribution for a qutrit system. This function generates samples based on the given probability distribution for a qutrit system with a specified number of wires. It can handle both batched and non-batched probability distributions. Args: probs (ndarray): Probability distribution to sample from. For non-batched input, this should be a 1D array of length QUDIT_DIM**num_wires. For batched input, this should be a 2D array where each row is a separate probability distribution. shots (int): Number of samples to generate. num_wires (int): Number of wires in the qutrit system. is_state_batched (bool): Whether the input probabilities are batched. rng (Optional[Generator]): Random number generator to use. If None, a new generator will be created. Returns: ndarray: An array of samples. For non-batched input, the shape is (shots, num_wires). For batched input, the shape is (batch_size, shots, num_wires). Example: >>> probs = np.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system >>> shots = 1000 >>> num_wires = 1 >>> is_state_batched = False >>> rng = np.random.default_rng(42) >>> samples = sample_probs(probs, shots, num_wires, is_state_batched, rng) >>> samples.shape (1000, 1) """ rng = np.random.default_rng(rng) basis_states = np.arange(QUDIT_DIM**num_wires) if is_state_batched: # rng.choice doesn't support broadcasting samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs]) else: samples = rng.choice(basis_states, shots, p=probs) res = np.zeros(samples.shape + (num_wires,), dtype=np.int64) for i in range(num_wires): res[..., -(i + 1)] = (samples // (QUDIT_DIM**i)) % QUDIT_DIM return res
[docs]def measure_with_samples( mp: SampleMeasurement, state: np.ndarray, shots: Shots, is_state_batched: bool = False, rng=None, prng_key=None, readout_errors: list[Callable] = None, ) -> TensorLike: """Returns the samples of the measurement process performed on the given state. This function assumes that the user-defined wire labels in the measurement process have already been mapped to integer wires used in the device. Args: mp (SampleMeasurement): The sample measurement to perform state (np.ndarray[complex]): The state vector to sample from shots (Shots): The number of samples to take is_state_batched (bool): whether the state is batched or not rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. If no value is provided, a default RNG will be used. prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the JAX pseudo random number generator. Only for simulation using JAX. readout_errors (List[Callable]): List of channels to apply to each wire being measured to simulate readout errors. Returns: TensorLike[Any]: Sample measurement results """ if isinstance(mp, ExpectationMP) and isinstance(mp.obs, Sum): measure_fn = _measure_sum_with_samples else: # measure with the usual method (rotate into the measurement basis) measure_fn = _measure_with_samples_diagonalizing_gates return measure_fn( mp, state, shots, is_state_batched=is_state_batched, rng=rng, prng_key=prng_key, readout_errors=readout_errors, )