Source code for pennylane.devices.default_qutrit_mixed
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.# Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at# http://www.apache.org/licenses/LICENSE-2.0# Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""The default.qutrit.mixed device is PennyLane's standard qutrit simulator for mixed-statecomputations."""importloggingimportwarningsfromcollections.abcimportCallable,SequencefromdataclassesimportreplacefromfunctoolsimportpartialfromtypingimportOptional,Unionimportnumpyasnpimportpennylaneasqmlfrompennylane.loggingimportdebug_logger,debug_logger_initfrompennylane.opsimport_qutrit__channel__ops__aschannelsfrompennylane.tapeimportQuantumScript,QuantumScriptOrBatchfrompennylane.transforms.coreimportTransformProgramfrompennylane.typingimportResult,ResultBatchfrom.importDevicefrom.default_qutritimportDefaultQutritfrom.execution_configimportDefaultExecutionConfig,ExecutionConfigfrom.modifiersimportsimulator_tracking,single_tape_supportfrom.preprocessimport(decompose,no_sampling,null_postprocessing,validate_device_wires,validate_measurements,validate_observables,)from.qutrit_mixed.simulateimportsimulatelogger=logging.getLogger(__name__)logger.addHandler(logging.NullHandler())observables={"THermitian","GellMann",}
def observable_stopping_condition(obs: qml.operation.Operator) -> bool:
    """Decide whether ``obs`` can be measured natively by DefaultQutritMixed.

    Composite observables are accepted only if every constituent is accepted:
    the operands of a ``Prod`` or ``Sum``, the operators returned as the second
    entry of ``terms()`` for a ``LinearCombination``, and the base of an
    ``SProd``. Any other observable is accepted when its name appears in the
    module-level ``observables`` set.

    Args:
        obs (Operator): the observable to check.

    Returns:
        bool: ``True`` if the observable is accepted, ``False`` otherwise.
    """
    kind = obs.name

    # A scalar product is as acceptable as the operator it scales.
    if kind == "SProd":
        return observable_stopping_condition(obs.base)

    if kind in {"Prod", "Sum"}:
        components = obs.operands
    elif kind == "LinearCombination":
        components = obs.terms()[1]
    else:
        # Leaf observable: accept only the explicitly listed names.
        return kind in observables

    return all(observable_stopping_condition(component) for component in components)
def stopping_condition(op: qml.operation.Operator) -> bool:
    """Report whether the device can apply ``op`` without decomposing it.

    An operator is supported when its name is one of the ``DefaultQutrit``
    operations, ``"Snapshot"``, or one of the qutrit channel names.

    Args:
        op (Operator): the operator to check.

    Returns:
        bool: ``True`` if the operator name is in the supported set.
    """
    native_ops = DefaultQutrit.operations | {"Snapshot"}
    return op.name in (native_ops | channels)
def stopping_condition_shots(op: qml.operation.Operator) -> bool:
    """Report whether the device can apply ``op`` when executing with shots.

    The finite-shot criterion is the same as the analytic one, so this simply
    delegates to :func:`stopping_condition`.

    Args:
        op (Operator): the operator to check.

    Returns:
        bool: the result of ``stopping_condition(op)``.
    """
    is_supported = stopping_condition(op)
    return is_supported
def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool:
    """Report whether measurement ``m`` is allowed in a sampling execution.

    Args:
        m (MeasurementProcess): the measurement process to check.

    Returns:
        bool: ``True`` exactly when ``m`` is a ``SampleMeasurement`` instance.
    """
    sample_type = qml.measurements.SampleMeasurement
    return isinstance(m, sample_type)
@qml.transform
def warn_readout_error_state(
    tape: qml.tape.QuantumTape,
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
    """Warn about analytic state measurements that bypass readout error.

    When the tape has no shots, a warning is issued for every measurement that
    is a ``StateMP`` instance. The tape itself is passed through unchanged.

    Args:
        tape (QuantumTape, .QNode, Callable): a quantum circuit.

    Returns:
        qnode (pennylane.QNode) or quantum function (callable) or
        tuple[List[.QuantumTape], function]: The unaltered input circuit.
    """
    is_analytic = not tape.shots
    if is_analytic:
        state_measurements = [
            mp for mp in tape.measurements if isinstance(mp, qml.measurements.StateMP)
        ]
        for m in state_measurements:
            warnings.warn(f"Measurement {m} is not affected by readout error.")

    unchanged_batch = (tape,)
    return unchanged_batch, null_postprocessing
def get_readout_errors(readout_relaxation_probs, readout_misclassification_probs):
    r"""Get the list of readout errors that should be applied to each measured wire.

    Each set of probabilities that is provided is validated by constructing the
    corresponding channel once on wire ``0`` with queuing disabled. If that
    construction fails, the failure is re-raised as a ``DeviceError`` chained
    to the original exception.

    Args:
        readout_relaxation_probs (List[float]): Inputs for :class:`~.QutritAmplitudeDamping` channel
            of the form :math:`[\gamma_{10}, \gamma_{20}, \gamma_{21}]`. This error models
            amplitude damping associated with longer readout and varying relaxation times of
            transmon-based qudits. ``None`` disables this error.
        readout_misclassification_probs (List[float]): Inputs for :class:`~.TritFlip` channel
            of the form :math:`[p_{01}, p_{02}, p_{12}]`. This error models misclassification
            events in readout. ``None`` disables this error.

    Returns:
        readout_errors (List[Callable] or None): List of readout error channels that should be
        applied to each measured wire, relaxation first and misclassification second. ``None``
        if neither argument is provided.

    Raises:
        DeviceError: if a channel cannot be constructed from the given probabilities.
    """
    measure_funcs = []

    if readout_relaxation_probs is not None:
        try:
            # Build the channel once, unqueued, purely to validate the parameters.
            with qml.queuing.QueuingManager.stop_recording():
                qml.QutritAmplitudeDamping(*readout_relaxation_probs, wires=0)
        except Exception as e:
            # Chain explicitly so the underlying validation error stays attached as the cause.
            raise qml.DeviceError(
                "Applying damping readout error results in error:\n" + str(e)
            ) from e
        measure_funcs.append(partial(qml.QutritAmplitudeDamping, *readout_relaxation_probs))

    if readout_misclassification_probs is not None:
        try:
            with qml.queuing.QueuingManager.stop_recording():
                qml.TritFlip(*readout_misclassification_probs, wires=0)
        except Exception as e:
            raise qml.DeviceError(
                "Applying trit flip readout error results in error:\n" + str(e)
            ) from e
        measure_funcs.append(partial(qml.TritFlip, *readout_misclassification_probs))

    # Callers test ``readout_errors is not None``, so an empty list is reported as ``None``.
    return measure_funcs or None
@simulator_tracking
@single_tape_support
class DefaultQutritMixed(Device):
    r"""A PennyLane Python-based device for mixed-state qutrit simulation.

    Args:
        wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that
            contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings
            (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified.
        shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots
            to use in executions involving this device.
        seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]):
            A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or
            a request to seed from numpy's global random number generator.
            The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None``
            will pull a seed from the OS entropy.
            If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using
            ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than
            ``numpy.random.default_rng``.
        readout_relaxation_probs (List[float]): Input probabilities for relaxation errors implemented
            with the :class:`~.QutritAmplitudeDamping` channel. The input defines the
            channel's parameters :math:`[\gamma_{10}, \gamma_{20}, \gamma_{21}]`.
        readout_misclassification_probs (List[float]): Input probabilities for state readout
            misclassification events implemented with the :class:`~.TritFlip` channel. The input
            defines the channel's parameters :math:`[p_{01}, p_{02}, p_{12}]`.

    **Example:**

    .. code-block:: python

        n_wires = 5
        num_qscripts = 5
        qscripts = []
        for i in range(num_qscripts):
            unitary = scipy.stats.unitary_group(dim=3**n_wires, seed=(42 + i)).rvs()
            op = qml.QutritUnitary(unitary, wires=range(n_wires))
            qs = qml.tape.QuantumScript([op], [qml.expval(qml.GellMann(0, 3))])
            qscripts.append(qs)

    >>> dev = DefaultQutritMixed()
    >>> program, execution_config = dev.preprocess()
    >>> new_batch, post_processing_fn = program(qscripts)
    >>> results = dev.execute(new_batch, execution_config=execution_config)
    >>> post_processing_fn(results)
    [0.08015701503959313,
    0.04521414211599359,
    -0.0215232130089687,
    0.062120285032425865,
    -0.0635052317625]

    This device currently supports backpropagation derivatives:

    >>> from pennylane.devices import ExecutionConfig
    >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop"))
    True

    For example, we can use jax to jit computing the derivative:

    .. code-block:: python

        import jax

        @jax.jit
        def f(x):
            qs = qml.tape.QuantumScript([qml.TRX(x, 0)], [qml.expval(qml.GellMann(0, 3))])
            program, execution_config = dev.preprocess()
            new_batch, post_processing_fn = program([qs])
            results = dev.execute(new_batch, execution_config=execution_config)
            return post_processing_fn(results)[0]

    >>> f(jax.numpy.array(1.2))
    DeviceArray(0.36235774, dtype=float32)
    >>> jax.grad(f)(jax.numpy.array(1.2))
    DeviceArray(-0.93203914, dtype=float32, weak_type=True)

    .. details::
        :title: Readout Error

        ``DefaultQutritMixed`` includes readout error support. Two input arguments control
        the parameters of error channels applied to each measured wire of the state after
        it has been diagonalized for measurement:

        * ``readout_relaxation_probs``: Input parameters of a :class:`~.QutritAmplitudeDamping`
          channel. This error models state relaxation error that occurs during readout of
          transmon-based qutrits. The motivation for this readout error is described in
          [`1 <https://arxiv.org/abs/2003.03307>`_] (Sec II.A).

        * ``readout_misclassification_probs``: Input parameters of a :class:`~.TritFlip` channel.
          This error models misclassification events in readout. An example of this readout
          error can be seen in [`2 <https://arxiv.org/abs/2309.11303>`_] (Fig 1a).

        In the case that both parameters are defined, relaxation error is applied first then
        misclassification error is applied.

        .. note::
            The readout errors will be applied to the state after it has been diagonalized for
            each measurement. This may give different results depending on how the observable
            is defined. This is because diagonalizing gates for the same observable may return
            eigenvalues in different orders. For example, measuring :class:`~.THermitian` with
            a non-diagonal GellMann matrix will result in a different measurement result than
            measuring the equivalent :class:`~.GellMann` observable, as the THermitian
            eigenvalues are returned in increasing order when explicitly diagonalized
            (i.e., ``[-1, 0, 1]``), while non-diagonal GellManns provided in PennyLane have
            their eigenvalues hardcoded (i.e., ``[1, -1, 0]``).

    .. details::
        :title: Tracking

        ``DefaultQutritMixed`` tracks:

        * ``executions``: the number of unique circuits that would be required on quantum hardware
        * ``shots``: the number of shots
        * ``resources``: the :class:`~.resource.Resources` for the executed circuit.
        * ``simulations``: the number of simulations performed. One simulation can cover multiple
          QPU executions, such as for non-commuting measurements and batched parameters.
        * ``batches``: The number of times :meth:`~.execute` is called.
        * ``results``: The results of each call of :meth:`~.execute`

    """

    # Tuple of string names for all the device options. Each option ``name`` is read from the
    # private attribute ``_name`` when the execution config is filled in.
    _device_options = ("rng", "prng_key")

    @property
    def name(self):
        """The name of the device."""
        return "default.qutrit.mixed"

    @debug_logger_init
    def __init__(  # pylint: disable=too-many-arguments
        self,
        wires=None,
        shots=None,
        seed="global",
        readout_relaxation_probs=None,
        readout_misclassification_probs=None,
    ) -> None:
        super().__init__(wires=wires, shots=shots)

        # Only a string can be the "global" sentinel. Checking the type first avoids an
        # elementwise ``==`` against array-like seeds, whose truth value is ambiguous.
        if isinstance(seed, str) and seed == "global":
            seed = np.random.randint(0, high=10000000)

        if qml.math.get_interface(seed) == "jax":
            # A JAX key is kept as-is; the NumPy generator is then seeded independently.
            self._prng_key = seed
            self._rng = np.random.default_rng(None)
        else:
            self._prng_key = None
            self._rng = np.random.default_rng(seed)

        self._debugger = None
        self.readout_errors = get_readout_errors(
            readout_relaxation_probs, readout_misclassification_probs
        )

    @debug_logger
    def supports_derivatives(
        self,
        execution_config: Optional[ExecutionConfig] = None,
        circuit: Optional[QuantumScript] = None,
    ) -> bool:
        """Check whether or not derivatives are available for a given configuration and circuit.

        ``DefaultQutritMixed`` supports backpropagation derivatives with analytic results.

        Args:
            execution_config (ExecutionConfig): The configuration of the desired derivative
                calculation. ``None`` is treated like a supported gradient method.
            circuit (QuantumTape): An optional circuit to check derivatives support for.

        Returns:
            bool: Whether or not a derivative can be calculated provided the given information.
            ``True`` only for the ``"backprop"`` or ``"best"`` gradient methods (or no config)
            and a circuit that is absent or has no shots.

        """
        if execution_config is None or execution_config.gradient_method in {"backprop", "best"}:
            return circuit is None or not circuit.shots
        return False

    def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig:
        """This is a private helper for ``preprocess`` that sets up the execution config.

        Args:
            execution_config (ExecutionConfig): an unprocessed execution config.

        Returns:
            ExecutionConfig: a preprocessed execution config.

        Raises:
            DeviceError: if the config contains a device option this device does not define.

        """
        updated_values = {}
        for option in execution_config.device_options:
            if option not in self._device_options:
                raise qml.DeviceError(f"device option {option} not present on {self}")

        if execution_config.gradient_method == "best":
            updated_values["gradient_method"] = "backprop"
        updated_values["use_device_gradient"] = False
        updated_values["grad_on_execution"] = False

        # Copy so that filling in defaults does not mutate the caller's config.
        updated_values["device_options"] = dict(execution_config.device_options)
        for option in self._device_options:
            if option not in updated_values["device_options"]:
                updated_values["device_options"][option] = getattr(self, f"_{option}")
        return replace(execution_config, **updated_values)

    @debug_logger
    def preprocess(
        self,
        execution_config: ExecutionConfig = DefaultExecutionConfig,
    ) -> tuple[TransformProgram, ExecutionConfig]:
        """This function defines the device transform program to be applied and an updated device
        configuration.

        Args:
            execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure
                describing the parameters needed to fully describe the execution.

        Returns:
            TransformProgram, ExecutionConfig: A transform program that when called returns
            ``QuantumTape`` objects that the device can natively execute, as well as a postprocessing
            function to be called after execution, and a configuration with unset
            specifications filled in.

        This device:

        * Supports any qutrit operations that provide a matrix
        * Supports any qutrit channel that provides Kraus matrices

        """
        config = self._setup_execution_config(execution_config)
        transform_program = TransformProgram()

        transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
        transform_program.add_transform(
            decompose,
            stopping_condition=stopping_condition,
            stopping_condition_shots=stopping_condition_shots,
            name=self.name,
        )
        transform_program.add_transform(
            validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name
        )
        transform_program.add_transform(
            validate_observables, stopping_condition=observable_stopping_condition, name=self.name
        )

        if config.gradient_method == "backprop":
            # Label the check with this device's own name so the message points at the
            # right device.
            transform_program.add_transform(no_sampling, name=f"backprop + {self.name}")

        if self.readout_errors is not None:
            # State measurements skip readout error, so let the user know.
            transform_program.add_transform(warn_readout_error_state)

        return transform_program, config