Source code for pennylane.devices.default_qutrit_mixed

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The default.qutrit.mixed device is PennyLane's standard qutrit simulator for mixed-state
computations."""
import logging
import warnings
from collections.abc import Callable, Sequence
from dataclasses import replace
from functools import partial
from typing import Optional, Union

import numpy as np

import pennylane as qml
from pennylane.logging import debug_logger, debug_logger_init
from pennylane.ops import _qutrit__channel__ops__ as channels
from pennylane.tape import QuantumScript, QuantumScriptOrBatch
from pennylane.transforms.core import TransformProgram
from pennylane.typing import Result, ResultBatch

from . import Device
from .default_qutrit import DefaultQutrit
from .execution_config import DefaultExecutionConfig, ExecutionConfig
from .modifiers import simulator_tracking, single_tape_support
from .preprocess import (
    decompose,
    no_sampling,
    null_postprocessing,
    validate_device_wires,
    validate_measurements,
    validate_observables,
)
from .qutrit_mixed.simulate import simulate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


observables = {
    "THermitian",
    "GellMann",
}


[docs]def observable_stopping_condition(obs: qml.operation.Operator) -> bool: """Specifies whether an observable is accepted by DefaultQutritMixed.""" if obs.name in {"Prod", "Sum"}: return all(observable_stopping_condition(observable) for observable in obs.operands) if obs.name == "LinearCombination": return all(observable_stopping_condition(observable) for observable in obs.terms()[1]) if obs.name == "SProd": return observable_stopping_condition(obs.base) return obs.name in observables
[docs]def stopping_condition(op: qml.operation.Operator) -> bool: """Specify whether an Operator object is supported by the device.""" expected_set = DefaultQutrit.operations | {"Snapshot"} | channels return op.name in expected_set
[docs]def stopping_condition_shots(op: qml.operation.Operator) -> bool: """Specify whether an Operator object is supported by the device with shots.""" return stopping_condition(op)
[docs]def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool: """Specifies whether a measurement is accepted when sampling.""" return isinstance(m, qml.measurements.SampleMeasurement)
[docs]@qml.transform def warn_readout_error_state( tape: qml.tape.QuantumTape, ) -> tuple[Sequence[qml.tape.QuantumTape], Callable]: """If a measurement in the QNode is an analytic state or density_matrix, and a readout error parameter is defined, warn that readout error will not be applied. Args: tape (QuantumTape, .QNode, Callable): a quantum circuit. Returns: qnode (pennylane.QNode) or quantum function (callable) or tuple[List[.QuantumTape], function]: The unaltered input circuit. """ if not tape.shots: for m in tape.measurements: if isinstance(m, qml.measurements.StateMP): warnings.warn(f"Measurement {m} is not affected by readout error.") return (tape,), null_postprocessing
[docs]def get_readout_errors(readout_relaxation_probs, readout_misclassification_probs): r"""Get the list of readout errors that should be applied to each measured wire. Args: readout_relaxation_probs (List[float]): Inputs for :class:`~.QutritAmplitudeDamping` channel of the form :math:`[\gamma_{10}, \gamma_{20}, \gamma_{21}]`. This error models amplitude damping associated with longer readout and varying relaxation times of transmon-based qudits. readout_misclassification_probs (List[float]): Inputs for :class:`~.TritFlip` channel of the form :math:`[p_{01}, p_{02}, p_{12}]`. This error models misclassification events in readout. Returns: readout_errors (List[Callable]): List of readout error channels that should be applied to each measured wire. """ measure_funcs = [] if readout_relaxation_probs is not None: try: with qml.queuing.QueuingManager.stop_recording(): qml.QutritAmplitudeDamping(*readout_relaxation_probs, wires=0) except Exception as e: raise qml.DeviceError("Applying damping readout error results in error:\n" + str(e)) measure_funcs.append(partial(qml.QutritAmplitudeDamping, *readout_relaxation_probs)) if readout_misclassification_probs is not None: try: with qml.queuing.QueuingManager.stop_recording(): qml.TritFlip(*readout_misclassification_probs, wires=0) except Exception as e: raise qml.DeviceError("Applying trit flip readout error results in error:\n" + str(e)) measure_funcs.append(partial(qml.TritFlip, *readout_misclassification_probs)) return None if len(measure_funcs) == 0 else measure_funcs
[docs]@simulator_tracking @single_tape_support class DefaultQutritMixed(Device): r"""A PennyLane Python-based device for mixed-state qutrit simulation. Args: wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified. shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots to use in executions involving this device. seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or a request to seed from numpy's global random number generator. The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None`` will pull a seed from the OS entropy. If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than ``numpy.random.default_rng``. readout_relaxation_probs (List[float]): Input probabilities for relaxation errors implemented with the :class:`~.QutritAmplitudeDamping` channel. The input defines the channel's parameters :math:`[\gamma_{10}, \gamma_{20}, \gamma_{21}]`. readout_misclassification_probs (List[float]): Input probabilities for state readout misclassification events implemented with the :class:`~.TritFlip` channel. The input defines the channel's parameters :math:`[p_{01}, p_{02}, p_{12}]`. **Example:** .. code-block:: python n_wires = 5 num_qscripts = 5 qscripts = [] for i in range(num_qscripts): unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs() op = qml.QutritUnitary(unitary, wires=range(n_wires)) qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))]) qscripts.append(qs) >>> dev = DefaultQutritMixed() >>> program, execution_config = dev.preprocess() >>> new_batch, post_processing_fn = program(qscripts) >>> results = dev.execute(new_batch, execution_config=execution_config) >>> post_processing_fn(results) [0.08015701503959313, 0.04521414211599359, -0.0215232130089687, 0.062120285032425865, -0.0635052317625] This device currently supports backpropagation derivatives: >>> from pennylane.devices import ExecutionConfig >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop")) True For example, we can use jax to jit computing the derivative: .. code-block:: python import jax @jax.jit def f(x): qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))]) program, execution_config = dev.preprocess() new_batch, post_processing_fn = program([qs]) results = dev.execute(new_batch, execution_config=execution_config) return post_processing_fn(results)[0] >>> f(jax.numpy.array(1.2)) DeviceArray(0.36235774, dtype=float32) >>> jax.grad(f)(jax.numpy.array(1.2)) DeviceArray(-0.93203914, dtype=float32, weak_type=True) .. details:: :title: Readout Error ``DefaultQutritMixed`` includes readout error support. Two input arguments control the parameters of error channels applied to each measured wire of the state after it has been diagonalized for measurement: * ``readout_relaxation_probs``: Input parameters of a :class:`~.QutritAmplitudeDamping` channel. This error models state relaxation error that occurs during readout of transmon-based qutrits. The motivation for this readout error is described in [`1 <https://arxiv.org/abs/2003.03307>`_] (Sec II.A). * ``readout_misclassification_probs``: Input parameters of a :class:`~.TritFlip` channel. This error models misclassification events in readout. An example of this readout error can be seen in [`2 <https://arxiv.org/abs/2309.11303>`_] (Fig 1a). In the case that both parameters are defined, relaxation error is applied first then misclassification error is applied. .. note:: The readout errors will be applied to the state after it has been diagonalized for each measurement. This may give different results depending on how the observable is defined. This is because diagonalizing gates for the same observable may return eigenvalues in different orders. For example, measuring :class:`~.THermitian` with a non-diagonal GellMann matrix will result in a different measurement result then measuring the equivalent :class:`~.GellMann` observable, as the THermitian eigenvalues are returned in increasing order when explicitly diagonalized (i.e., ``[-1, 0, 1]``), while non-diagonal GellManns provided in PennyLane have their eigenvalues hardcoded (i.e., ``[1, -1, 0]``). .. details:: :title: Tracking ``DefaultQutritMixed`` tracks: * ``executions``: the number of unique circuits that would be required on quantum hardware * ``shots``: the number of shots * ``resources``: the :class:`~.resource.Resources` for the executed circuit. * ``simulations``: the number of simulations performed. One simulation can cover multiple QPU executions, such as for non-commuting measurements and batched parameters. * ``batches``: The number of times :meth:`~.execute` is called. * ``results``: The results of each call of :meth:`~.execute` """ _device_options = ("rng", "prng_key") # tuple of string names for all the device options. @property def name(self): """The name of the device.""" return "default.qutrit.mixed" @debug_logger_init def __init__( # pylint: disable=too-many-arguments self, wires=None, shots=None, seed="global", readout_relaxation_probs=None, readout_misclassification_probs=None, ) -> None: super().__init__(wires=wires, shots=shots) seed = np.random.randint(0, high=10000000) if seed == "global" else seed if qml.math.get_interface(seed) == "jax": self._prng_key = seed self._rng = np.random.default_rng(None) else: self._prng_key = None self._rng = np.random.default_rng(seed) self._debugger = None self.readout_errors = get_readout_errors( readout_relaxation_probs, readout_misclassification_probs )
[docs] @debug_logger def supports_derivatives( self, execution_config: Optional[ExecutionConfig] = None, circuit: Optional[QuantumScript] = None, ) -> bool: """Check whether or not derivatives are available for a given configuration and circuit. ``DefaultQutritMixed`` supports backpropagation derivatives with analytic results. Args: execution_config (ExecutionConfig): The configuration of the desired derivative calculation. circuit (QuantumTape): An optional circuit to check derivatives support for. Returns: bool: Whether or not a derivative can be calculated provided the given information. """ if execution_config is None or execution_config.gradient_method in {"backprop", "best"}: return circuit is None or not circuit.shots return False
def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig: """This is a private helper for ``preprocess`` that sets up the execution config. Args: execution_config (ExecutionConfig): an unprocessed execution config. Returns: ExecutionConfig: a preprocessed execution config. """ updated_values = {} for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") if execution_config.gradient_method == "best": updated_values["gradient_method"] = "backprop" updated_values["use_device_gradient"] = False updated_values["grad_on_execution"] = False updated_values["device_options"] = dict(execution_config.device_options) # copy for option in self._device_options: if option not in updated_values["device_options"]: updated_values["device_options"][option] = getattr(self, f"_{option}") return replace(execution_config, **updated_values)
[docs] @debug_logger def preprocess( self, execution_config: ExecutionConfig = DefaultExecutionConfig, ) -> tuple[TransformProgram, ExecutionConfig]: """This function defines the device transform program to be applied and an updated device configuration. Args: execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure describing the parameters needed to fully describe the execution. Returns: TransformProgram, ExecutionConfig: A transform program that when called returns ``QuantumTape`` objects that the device can natively execute, as well as a postprocessing function to be called after execution, and a configuration with unset specifications filled in. This device: * Supports any qutrit operations that provide a matrix * Supports any qutrit channel that provides Kraus matrices """ config = self._setup_execution_config(execution_config) transform_program = TransformProgram() transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( decompose, stopping_condition=stopping_condition, stopping_condition_shots=stopping_condition_shots, name=self.name, ) transform_program.add_transform( validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name ) transform_program.add_transform( validate_observables, stopping_condition=observable_stopping_condition, name=self.name ) if config.gradient_method == "backprop": transform_program.add_transform(no_sampling, name="backprop + default.qutrit") if self.readout_errors is not None: transform_program.add_transform(warn_readout_error_state) return transform_program, config
[docs] @debug_logger def execute( self, circuits: QuantumScriptOrBatch, execution_config: ExecutionConfig = DefaultExecutionConfig, ) -> Union[Result, ResultBatch]: interface = ( execution_config.interface if execution_config.gradient_method in {"best", "backprop", None} else None ) return tuple( simulate( c, rng=self._rng, prng_key=self._prng_key, debugger=self._debugger, interface=interface, readout_errors=self.readout_errors, ) for c in circuits )