# Copyright 2018-2024 Xanadu Quantum Technologies Inc.# Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at# http://www.apache.org/licenses/LICENSE-2.0# Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""The default.qubit device is PennyLane's standard qubit-based device."""importconcurrent.futuresimportloggingimportwarningsfromdataclassesimportreplacefromfunctoolsimportpartialfromnumbersimportNumberfromtypingimportOptional,Sequence,Unionimportnumpyasnpimportpennylaneasqmlfrompennylane.loggingimportdebug_logger,debug_logger_initfrompennylane.measurementsimportClassicalShadowMP,ShadowExpvalMPfrompennylane.measurements.mid_measureimportMidMeasureMPfrompennylane.ops.op_mathimportConditionalfrompennylane.tapeimportQuantumScript,QuantumScriptBatch,QuantumScriptOrBatchfrompennylane.transformsimportconvert_to_numpy_parametersfrompennylane.transforms.coreimportTransformProgramfrompennylane.typingimportPostprocessingFn,Result,ResultBatch,TensorLikefrom.importDevicefrom.execution_configimportDefaultExecutionConfig,ExecutionConfigfrom.modifiersimportsimulator_tracking,single_tape_supportfrom.preprocessimport(decompose,mid_circuit_measurements,no_sampling,validate_adjoint_trainable_params,validate_device_wires,validate_measurements,validate_multiprocessing_workers,validate_observables,)from.qubit.adjoint_jacobianimportadjoint_jacobian,adjoint_jvp,adjoint_vjpfrom.qubit.samplingimportjax_random_splitfrom.qubit.simulateimportget_final_state,measure_final_state,simulatelogger=logging.getLogger(__name__)logger.addHandler(logging.NullHandler())
def stopping_condition(op: qml.operation.Operator) -> bool:
    """Decide whether ``op`` can be executed natively by the device.

    Operators for which this returns ``False`` are decomposed further during
    preprocessing.

    * ``QFT`` on 6 or more wires and ``GroverOperator`` on 13 or more wires are
      always decomposed.
    * ``Snapshot`` is always accepted.
    * Trainable operators whose class name starts with ``Pow`` are decomposed.
    * Otherwise, conditionals with a supported base, mid-circuit measurements,
      and any operator that defines a matrix or sparse matrix are accepted.

    Args:
        op (Operator): the operator to inspect

    Returns:
        bool: whether the device supports ``op`` without decomposition
    """
    wire_limits = {"QFT": 6, "GroverOperator": 13}
    limit = wire_limits.get(op.name)
    if limit is not None and len(op.wires) >= limit:
        return False

    if op.name == "Snapshot":
        return True

    if op.__class__.__name__.startswith("Pow") and qml.operation.is_trainable(op):
        return False

    if isinstance(op, Conditional) and stopping_condition(op.base):
        return True

    return isinstance(op, MidMeasureMP) or op.has_matrix or op.has_sparse_matrix
def stopping_condition_shots(op: qml.operation.Operator) -> bool:
    """Decide whether ``op`` is supported natively when executing with shots.

    Conditionals are accepted if their base passes this same check. Mid-circuit
    measurements are always accepted. Everything else defers to
    ``stopping_condition``.

    Args:
        op (Operator): the operator to inspect

    Returns:
        bool: whether the device supports ``op`` with finite shots
    """
    if isinstance(op, Conditional) and stopping_condition_shots(op.base):
        return True
    if isinstance(op, MidMeasureMP):
        return True
    return stopping_condition(op)
def observable_accepts_sampling(obs: qml.operation.Operator) -> bool:
    """Check whether an observable can be measured by sampling.

    Composite observables qualify only if every operand does. Symbolic
    observables defer to their base. Any other observable must provide
    diagonalizing gates.

    Args:
        obs (Operator): the observable to inspect

    Returns:
        bool: whether ``obs`` supports sample-based measurement
    """
    if isinstance(obs, qml.ops.CompositeOp):
        return all(map(observable_accepts_sampling, obs.operands))
    if isinstance(obs, qml.ops.SymbolicOp):
        return observable_accepts_sampling(obs.base)
    return obs.has_diagonalizing_gates
def observable_accepts_analytic(obs: qml.operation.Operator, is_expval=False) -> bool:
    """Check whether an observable can be measured analytically.

    Composite observables qualify only if every operand does. Symbolic
    observables defer to their base. For expectation values,
    ``SparseHamiltonian`` and ``Hermitian`` are accepted even without
    diagonalizing gates. Any other observable must provide diagonalizing gates.

    Args:
        obs (Operator): the observable to inspect
        is_expval (bool): whether the enclosing measurement is an expectation value

    Returns:
        bool: whether ``obs`` supports analytic measurement
    """
    if isinstance(obs, qml.ops.CompositeOp):
        return all(
            observable_accepts_analytic(operand, is_expval) for operand in obs.operands
        )
    if isinstance(obs, qml.ops.SymbolicOp):
        return observable_accepts_analytic(obs.base, is_expval)

    if is_expval and isinstance(obs, (qml.ops.SparseHamiltonian, qml.ops.Hermitian)):
        return True

    return obs.has_diagonalizing_gates
def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool:
    """Check whether a measurement process is supported when sampling.

    The measurement must be a sample measurement or a classical-shadow
    measurement. If it carries an observable, that observable must also
    support sampling.

    Args:
        m (MeasurementProcess): the measurement to inspect

    Returns:
        bool: whether ``m`` is accepted with finite shots
    """
    sampling_types = (
        qml.measurements.SampleMeasurement,
        qml.measurements.ClassicalShadowMP,
        qml.measurements.ShadowExpvalMP,
    )
    if isinstance(m, sampling_types):
        return m.obs is None or observable_accepts_sampling(m.obs)
    return False
def accepted_analytic_measurement(m: qml.measurements.MeasurementProcess) -> bool:
    """Check whether a measurement process is supported in analytic mode.

    Only state measurements are accepted. If the measurement carries an
    observable, that observable must pass ``observable_accepts_analytic``, with
    the expectation-value relaxation applied for ``ExpectationMP``.

    Args:
        m (MeasurementProcess): the measurement to inspect

    Returns:
        bool: whether ``m`` is accepted without shots
    """
    if not isinstance(m, qml.measurements.StateMeasurement):
        return False
    if m.obs is None:
        return True
    is_expval = isinstance(m, qml.measurements.ExpectationMP)
    return observable_accepts_analytic(m.obs, is_expval)
def all_state_postprocessing(results, measurements, wire_order):
    """Turn a single state result back into the originally requested measurements.

    Each measurement's ``process_state`` is applied to ``results[0]``.

    Args:
        results (Sequence): batch of results; only the first entry is used
        measurements (Sequence[MeasurementProcess]): the measurements to reconstruct
        wire_order (Wires): wire order of the state

    Returns:
        The single processed value when exactly one measurement is given,
        otherwise a tuple with one entry per measurement.
    """
    outputs = []
    for mp in measurements:
        outputs.append(mp.process_state(results[0], wire_order=wire_order))
    if len(measurements) == 1:
        return outputs[0]
    return tuple(outputs)
@qml.transform
def _conditional_broastcast_expand(tape):
    """Apply ``broadcast_expand`` only when the tape has shadow measurements.

    default.qubit does not natively support parameter broadcasting together
    with ``ClassicalShadowMP`` / ``ShadowExpvalMP``. Such tapes are therefore
    expanded into one tape per broadcasted parameter. All other tapes pass
    through unchanged.

    Note: the (misspelled) name is kept as-is because ``preprocess`` refers to it.
    """
    shadow_types = (ShadowExpvalMP, ClassicalShadowMP)
    needs_expansion = any(isinstance(mp, shadow_types) for mp in tape.measurements)
    if not needs_expansion:
        return (tape,), null_postprocessing
    return qml.transforms.broadcast_expand(tape)
@qml.transform
def no_counts(tape):
    """Reject tapes that contain ``qml.counts`` measurements.

    Added to the transform program by ``preprocess`` when the interface is
    JAX-JIT.

    Raises:
        NotImplementedError: if any measurement is a ``CountsMP``.
    """
    for mp in tape.measurements:
        if isinstance(mp, qml.measurements.CountsMP):
            raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.")
    return (tape,), null_postprocessing
@qml.transform
def adjoint_state_measurements(
    tape: QuantumScript, device_vjp=False
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Perform adjoint measurement preprocessing.

    * Tapes containing only expectation values pass through unmodified.
    * If a non-expectation measurement is present and any measurement has
      diagonalizing gates, an error is raised.
    * Otherwise the circuit is rewritten to measure the full state, and the
      original measurements are recovered by classical postprocessing.

    Args:
        tape (QuantumTape): the input circuit
        device_vjp (bool): when ``True``, trainable parameters are cast to
            complex before the state tape is built, and trainable TensorFlow
            parameters in single precision are rejected.

    Returns:
        tuple[QuantumScriptBatch, PostprocessingFn]: the (possibly rewritten)
        tape and the matching postprocessing function.

    Raises:
        DeviceError: if measurements mix non-expectation values with
            observables that have diagonalizing gates.
        ValueError: if ``device_vjp`` is set and a trainable TensorFlow
            parameter is ``float32``/``complex64``.
    """
    if all(isinstance(mp, qml.measurements.ExpectationMP) for mp in tape.measurements):
        return (tape,), null_postprocessing

    if any(len(mp.diagonalizing_gates()) > 0 for mp in tape.measurements):
        raise qml.DeviceError(
            "adjoint diff supports either all expectation values or only measurements without observables."
        )

    params = tape.get_parameters()

    if device_vjp:

        def _is_single_precision_tf(param):
            return (
                qml.math.requires_grad(param)
                and qml.math.get_interface(param) == "tensorflow"
                and qml.math.get_dtype_name(param) in {"float32", "complex64"}
            )

        if any(_is_single_precision_tf(param) for param in params):
            raise ValueError(
                "tensorflow with adjoint differentiation of the state requires float64 or complex128 parameters."
            )
        tape = tape.bind_new_parameters(
            [qml.math.cast(param, complex) for param in params],
            list(range(len(params))),
        )

    state_tape = tape.copy(measurements=[qml.measurements.StateMP(wires=tape.wires)])
    postprocessing = partial(
        all_state_postprocessing, measurements=tape.measurements, wire_order=tape.wires
    )
    return (state_tape,), postprocessing
def adjoint_ops(op: qml.operation.Operator) -> bool:
    """Check whether an operator is supported by adjoint differentiation.

    Conditionals and mid-circuit measurements are never supported. Otherwise
    the operator is accepted if it has no parameters, is not trainable, or has
    exactly one parameter and a generator.

    Args:
        op (Operator): the operator to inspect

    Returns:
        bool: whether adjoint differentiation can handle ``op`` directly
    """
    if isinstance(op, (Conditional, MidMeasureMP)):
        return False
    if op.num_params == 0 or not qml.operation.is_trainable(op):
        return True
    return op.num_params == 1 and op.has_generator
def adjoint_observables(obs: qml.operation.Operator) -> bool:
    """Check whether an observable can be used with adjoint differentiation on DefaultQubit.

    Only observables that define a matrix are accepted.

    Args:
        obs (Operator): the observable to inspect

    Returns:
        bool: ``obs.has_matrix``
    """
    defines_matrix = obs.has_matrix
    return defines_matrix
def _supports_adjoint(circuit, device_wires, device_name):
    """Check whether ``circuit`` survives the adjoint-differentiation preprocessing.

    A ``None`` circuit is considered supported. Otherwise wire validation and
    the adjoint transforms are run on the circuit. Any
    ``DecompositionUndefinedError``, ``DeviceError`` or ``AttributeError``
    raised along the way marks the circuit as unsupported.

    Args:
        circuit (QuantumScript or None): the circuit to check
        device_wires (Wires or None): the device's wires
        device_name (str): name used in validation error messages

    Returns:
        bool: whether adjoint differentiation supports the circuit
    """
    if circuit is None:
        return True

    validation_program = TransformProgram()
    validation_program.add_transform(validate_device_wires, device_wires, name=device_name)
    _add_adjoint_transforms(validation_program)

    try:
        validation_program((circuit,))
    except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
        return False
    else:
        return True


def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None:
    """Private helper for ``preprocess`` that appends the adjoint-differentiation transforms.

    Args:
        program (TransformProgram): where the adjoint differentiation transforms are added
        device_vjp (bool): forwarded to ``adjoint_state_measurements``

    Side Effects:
        Adds transforms to the input program, in the order listed below.
    """
    stage = "adjoint + default.qubit"
    # (transform, positional args, keyword args); the order here is the order of application
    steps = (
        (no_sampling, (), {"name": stage}),
        (
            decompose,
            (),
            {"stopping_condition": adjoint_ops, "name": stage, "skip_initial_state_prep": False},
        ),
        (validate_observables, (adjoint_observables,), {"name": stage}),
        (validate_measurements, (), {"name": stage}),
        (adjoint_state_measurements, (), {"device_vjp": device_vjp}),
        (qml.transforms.broadcast_expand, (), {}),
        (validate_adjoint_trainable_params, (), {}),
    )
    for transform, args, kwargs in steps:
        program.add_transform(transform, *args, **kwargs)
@simulator_tracking
@single_tape_support
class DefaultQubit(Device):
    """A PennyLane device written in Python and capable of backpropagation derivatives.

    Args:
        wires (int, Iterable[Number, str]): Number of wires present on the device, or iterable that
            contains unique labels for the wires as numbers (i.e., ``[-1, 0, 2]``) or strings
            (``['ancilla', 'q1', 'q2']``). Default ``None`` if not specified.
        shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots
            to use in executions involving this device.
        seed (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator, jax.random.PRNGKey]): A
            seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``, or
            a request to seed from numpy's global random number generator.
            The default, ``seed="global"``, pulls a seed from NumPy's global generator. ``seed=None``
            will pull a seed from the OS entropy.
            If a ``jax.random.PRNGKey`` is passed as the seed, a JAX-specific sampling function using
            ``jax.random.choice`` and the ``PRNGKey`` will be used for sampling rather than
            ``numpy.random.default_rng``.
        max_workers (int): A ``ProcessPoolExecutor`` executes tapes asynchronously
            using a pool of at most ``max_workers`` processes. If ``max_workers`` is ``None``,
            only the current process executes tapes. If you experience any issues, for example
            when using JAX, TensorFlow or Torch, try setting ``max_workers`` to ``None``.

    **Example:**

    .. code-block:: python

        n_layers = 5
        n_wires = 10
        num_qscripts = 5

        shape = qml.StronglyEntanglingLayers.shape(n_layers=n_layers, n_wires=n_wires)
        rng = qml.numpy.random.default_rng(seed=42)

        qscripts = []
        for i in range(num_qscripts):
            params = rng.random(shape)
            op = qml.StronglyEntanglingLayers(params, wires=range(n_wires))
            qs = qml.tape.QuantumScript([op], [qml.expval(qml.Z(0))])
            qscripts.append(qs)

    >>> dev = DefaultQubit()
    >>> program, execution_config = dev.preprocess()
    >>> new_batch, post_processing_fn = program(qscripts)
    >>> results = dev.execute(new_batch, execution_config=execution_config)
    >>> post_processing_fn(results)
    [-0.0006888975950537501, 0.025576307134457577, -0.0038567269892757494, 0.1339705146860149, -0.03780669772690448]

    This device currently supports backpropagation derivatives:

    >>> from pennylane.devices import ExecutionConfig
    >>> dev.supports_derivatives(ExecutionConfig(gradient_method="backprop"))
    True

    For example, we can use jax to jit computing the derivative:

    .. code-block:: python

        import jax

        @jax.jit
        def f(x):
            qs = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.expval(qml.Z(0))])
            program, execution_config = dev.preprocess()
            new_batch, post_processing_fn = program([qs])
            results = dev.execute(new_batch, execution_config=execution_config)
            return post_processing_fn(results)

    >>> f(jax.numpy.array(1.2))
    DeviceArray(0.36235774, dtype=float32)
    >>> jax.grad(f)(jax.numpy.array(1.2))
    DeviceArray(-0.93203914, dtype=float32, weak_type=True)

    .. details::
        :title: Tracking

        ``DefaultQubit`` tracks:

        * ``executions``: the number of unique circuits that would be required on quantum hardware
        * ``shots``: the number of shots
        * ``resources``: the :class:`~.resource.Resources` for the executed circuit.
        * ``simulations``: the number of simulations performed. One simulation can cover multiple QPU executions,
          such as for non-commuting measurements and batched parameters.
        * ``batches``: The number of times :meth:`~.execute` is called.
        * ``results``: The results of each call of :meth:`~.execute`
        * ``derivative_batches``: How many times :meth:`~.compute_derivatives` is called.
        * ``execute_and_derivative_batches``: How many times :meth:`~.execute_and_compute_derivatives` is called
        * ``vjp_batches``: How many times :meth:`~.compute_vjp` is called
        * ``execute_and_vjp_batches``: How many times :meth:`~.execute_and_compute_vjp` is called
        * ``jvp_batches``: How many times :meth:`~.compute_jvp` is called
        * ``execute_and_jvp_batches``: How many times :meth:`~.execute_and_compute_jvp` is called
        * ``derivatives``: How many circuits are submitted to :meth:`~.compute_derivatives` or
          :meth:`~.execute_and_compute_derivatives`.
        * ``vjps``: How many circuits are submitted to :meth:`~.compute_vjp` or :meth:`~.execute_and_compute_vjp`
        * ``jvps``: How many circuits are submitted to :meth:`~.compute_jvp` or :meth:`~.execute_and_compute_jvp`

    .. details::
        :title: Accelerate calculations with multiprocessing

        Suppose one has a processor with 5 cores or more; these scripts can be executed in
        parallel as follows:

        >>> dev = DefaultQubit(max_workers=5)
        >>> program, execution_config = dev.preprocess()
        >>> new_batch, post_processing_fn = program(qscripts)
        >>> results = dev.execute(new_batch, execution_config=execution_config)
        >>> post_processing_fn(results)

        If you monitor your CPU usage, you should see 5 new Python processes pop up to
        crunch through those ``QuantumScript`` objects. Beware of oversubscribing your machine.
        This may happen if a single device already uses many cores, for example if NumPy uses a
        multi-threaded BLAS library like MKL or OpenBLAS. The number of threads per process
        times the number of processes should not exceed the number of cores on your machine.
        You can control the number of threads per process with the environment variables:

        * ``OMP_NUM_THREADS``
        * ``MKL_NUM_THREADS``
        * ``OPENBLAS_NUM_THREADS``

        where the last two are specific to the MKL and OpenBLAS libraries, respectively.

        .. warning::

            Multiprocessing may fail depending on your platform and environment (Python shell,
            script with a protected entry point, Jupyter notebook, etc.). This may be solved
            by changing the so-called start method. The supported start methods are the following:

            * Windows (win32): spawn (default).
            * macOS (darwin): spawn (default), fork, forkserver.
            * Linux (unix): spawn, fork (default), forkserver.

            The start method can be changed with ``multiprocessing.set_start_method()``. For example,
            if multiprocessing fails on macOS in your Jupyter notebook environment, try
            restarting the session and adding the following at the beginning of the file:

            .. code-block:: python

                import multiprocessing
                multiprocessing.set_start_method("fork")

            Additional information can be found in the
            `multiprocessing doc <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods>`_.
    """

    @property
    def name(self):
        """The name of the device."""
        return "default.qubit"
[docs]defget_prng_keys(self,num:int=1):"""Get ``num`` new keys with ``jax.random.split``. A user may provide a ``jax.random.PRNGKey`` as a random seed. It will be used by the device when executing circuits with finite shots. The JAX RNG is notably different than the NumPy RNG as highlighted in the `JAX documentation <https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html>`_. JAX does not keep track of a global seed or key, but needs one anytime it draws from a random number distribution. Generating randomness therefore requires changing the key every time, which is done by "splitting" the key. For example, when executing ``n`` circuits, the ``PRNGkey`` is split ``n`` times into 2 new keys using ``jax.random.split`` to simulate a non-deterministic behaviour. The device seed is modified in-place using the first key, and the second key is fed to the circuit, and hence can be discarded after returning the results. This same key may be split further down the stack if necessary so that no one key is ever reused. """ifnum<1:raiseValueError("Argument num must be a positive integer.")ifnum>1:return[self.get_prng_keys()[0]for_inrange(num)]self._prng_key,*keys=jax_random_split(self._prng_key)returnkeys
[docs]defreset_prng_key(self):"""Reset the RNG key to its initial value."""self._prng_key=self._prng_seed
_state_cache:Optional[dict]=None""" A cache to store the "pre-rotated state" for reuse between the forward pass call to ``execute`` and subsequent calls to ``compute_vjp``. ``None`` indicates that no caching is required. """_device_options=("max_workers","rng","prng_key")""" tuple of string names for all the device options. """# pylint:disable = too-many-arguments@debug_logger_initdef__init__(self,wires=None,shots=None,seed="global",max_workers=None,)->None:super().__init__(wires=wires,shots=shots)self._max_workers=max_workersseed=np.random.randint(0,high=10000000)ifseed=="global"elseseedifqml.math.get_interface(seed)=="jax":self._prng_seed=seedself._prng_key=seedself._rng=np.random.default_rng(None)else:self._prng_seed=Noneself._prng_key=Noneself._rng=np.random.default_rng(seed)self._debugger=None
[docs]@debug_loggerdefsupports_derivatives(self,execution_config:Optional[ExecutionConfig]=None,circuit:Optional[QuantumScript]=None,)->bool:"""Check whether or not derivatives are available for a given configuration and circuit. ``DefaultQubit`` supports backpropagation derivatives with analytic results, as well as adjoint differentiation. Args: execution_config (ExecutionConfig): The configuration of the desired derivative calculation circuit (QuantumTape): An optional circuit to check derivatives support for. Returns: Bool: Whether or not a derivative can be calculated provided the given information """ifexecution_configisNone:returnTrueno_max_workers=(execution_config.device_options.get("max_workers",self._max_workers)isNone)ifexecution_config.gradient_methodin{"backprop","best"}andno_max_workers:ifcircuitisNone:returnTruereturnnotcircuit.shotsandnotany(isinstance(m.obs,qml.SparseHamiltonian)formincircuit.measurements)ifexecution_config.gradient_methodin{"adjoint","best"}:return_supports_adjoint(circuit,device_wires=self.wires,device_name=self.name)returnFalse
    @debug_logger
    def preprocess(
        self,
        execution_config: ExecutionConfig = DefaultExecutionConfig,
    ) -> tuple[TransformProgram, ExecutionConfig]:
        """This function defines the device transform program to be applied and an updated device configuration.

        Args:
            execution_config (ExecutionConfig): A data structure describing the
                parameters needed to fully describe the execution.

        Returns:
            TransformProgram, ExecutionConfig: A transform program that when called returns ``QuantumTape``'s that the
            device can natively execute as well as a postprocessing function to be called after execution, and a configuration
            with unset specifications filled in.

        Operations are decomposed until they satisfy ``stopping_condition`` (or
        ``stopping_condition_shots`` with shots), e.g. until they provide a matrix
        or sparse matrix.
        """
        config = self._setup_execution_config(execution_config)
        # The order in which transforms are added below is the order in which they are applied.
        transform_program = TransformProgram()

        if config.interface == qml.math.Interface.JAX_JIT:
            # counts measurements are rejected under JAX-JIT
            transform_program.add_transform(no_counts)
        transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
        transform_program.add_transform(
            mid_circuit_measurements, device=self, mcm_config=config.mcm_config
        )
        transform_program.add_transform(
            decompose,
            stopping_condition=stopping_condition,
            stopping_condition_shots=stopping_condition_shots,
            name=self.name,
        )
        transform_program.add_transform(
            validate_measurements,
            analytic_measurements=accepted_analytic_measurement,
            sample_measurements=accepted_sample_measurement,
            name=self.name,
        )
        # Expands broadcasting only for tapes with shadow measurements.
        transform_program.add_transform(_conditional_broastcast_expand)
        if config.mcm_config.mcm_method == "tree-traversal":
            # NOTE(review): tree-traversal presumably requires unbroadcasted tapes — confirm.
            transform_program.add_transform(qml.transforms.broadcast_expand)
        # Validate multi processing
        max_workers = config.device_options.get("max_workers", self._max_workers)
        if max_workers:
            transform_program.add_transform(validate_multiprocessing_workers, max_workers, self)

        if config.gradient_method == "backprop":
            transform_program.add_transform(no_sampling, name="backprop + default.qubit")

        if config.gradient_method == "adjoint":
            _add_adjoint_transforms(
                transform_program, device_vjp=config.use_device_jacobian_product
            )

        return transform_program, config
def_setup_execution_config(self,execution_config:ExecutionConfig)->ExecutionConfig:"""This is a private helper for ``preprocess`` that sets up the execution config. Args: execution_config (ExecutionConfig) Returns: ExecutionConfig: a preprocessed execution config """updated_values={}# uncomment once compilation overhead with jitting improved# TODO: [sc-82874]# jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}# updated_values["convert_to_numpy"] = (# execution_config.interface not in jax_interfaces# or execution_config.gradient_method == "adjoint"# # need numpy to use caching, and need caching higher order derivatives# or execution_config.derivative_order > 1# )# If PRNGKey is present, we can't use a pure_callback, because that would cause leaked tracers# we assume that if someone provides a PRNGkey, they want to jit end-to-endjax_interfaces={qml.math.Interface.JAX,qml.math.Interface.JAX_JIT}updated_values["convert_to_numpy"]=not(self._prng_keyisnotNoneandexecution_config.interfaceinjax_interfacesandexecution_config.gradient_method!="adjoint"# need numpy to use caching, and need caching higher order derivativesandexecution_config.derivative_order==1)foroptioninexecution_config.device_options:ifoptionnotinself._device_options:raiseqml.DeviceError(f"device option {option} not present on 
{self}")gradient_method=execution_config.gradient_methodifexecution_config.gradient_method=="best":no_max_workers=(execution_config.device_options.get("max_workers",self._max_workers)isNone)gradient_method="backprop"ifno_max_workerselse"adjoint"updated_values["gradient_method"]=gradient_methodifexecution_config.use_device_gradientisNone:updated_values["use_device_gradient"]=gradient_methodin{"adjoint","backprop",}ifexecution_config.use_device_jacobian_productisNone:updated_values["use_device_jacobian_product"]=gradient_method=="adjoint"ifexecution_config.grad_on_executionisNone:updated_values["grad_on_execution"]=gradient_method=="adjoint"updated_values["device_options"]=dict(execution_config.device_options)# copyforoptioninself._device_options:ifoptionnotinupdated_values["device_options"]:updated_values["device_options"][option]=getattr(self,f"_{option}")returnreplace(execution_config,**updated_values)
    @debug_logger
    def execute(
        self,
        circuits: QuantumScriptOrBatch,
        execution_config: ExecutionConfig = DefaultExecutionConfig,
    ) -> Union[Result, ResultBatch]:
        """Execute a batch of circuits and return one result per circuit.

        Circuits are simulated in the current process when no ``max_workers`` is
        configured. Otherwise they are converted to numpy parameters and simulated in a
        ``ProcessPoolExecutor``.

        Args:
            circuits (QuantumScriptOrBatch): the circuits to simulate
            execution_config (ExecutionConfig): resolved execution configuration

        Returns:
            tuple: the results, in the same order as ``circuits``
        """
        # Start every call from the initial PRNG key, so the keys split off below
        # derive from the same starting key each time.
        self.reset_prng_key()

        max_workers = execution_config.device_options.get("max_workers", self._max_workers)
        # A fresh cache for pre-rotated states is only needed when the device computes VJPs.
        self._state_cache = {} if execution_config.use_device_jacobian_product else None
        # The ML interface is only forwarded to the simulator for backprop (or no gradient).
        interface = (
            execution_config.interface
            if execution_config.gradient_method in {"backprop", None}
            else None
        )

        # One key per circuit; these are None unless a JAX PRNGKey seed was given.
        prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]

        if (
            not execution_config.convert_to_numpy
            and execution_config.interface == qml.math.Interface.JAX_JIT
            and len(circuits) > 10
        ):
            warnings.warn(
                (
                    "Jitting executions with many circuits may have substantial classical overhead."
                    " To disable end-to-end jitting, please specify a integer seed instead of a PRNGKey."
                ),
                UserWarning,
            )

        if max_workers is None:
            return tuple(
                _simulate_wrapper(
                    c,
                    {
                        "rng": self._rng,
                        "debugger": self._debugger,
                        "interface": interface,
                        "state_cache": self._state_cache,
                        "prng_key": _key,
                        "mcm_method": execution_config.mcm_config.mcm_method,
                        "postselect_mode": execution_config.mcm_config.postselect_mode,
                    },
                )
                for c, _key in zip(circuits, prng_keys)
            )

        # Worker processes receive numpy-parameter circuits and an integer seed each
        # (not the shared generator); no interface, debugger or state cache is passed.
        vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
        seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
        simulate_kwargs = [
            {
                "rng": _rng,
                "prng_key": _key,
                "mcm_method": execution_config.mcm_config.mcm_method,
                "postselect_mode": execution_config.mcm_config.postselect_mode,
            }
            for _rng, _key in zip(seeds, prng_keys)
        ]

        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            exec_map = executor.map(_simulate_wrapper, vanilla_circuits, simulate_kwargs)
            results = tuple(exec_map)

        # reset _rng to mimic serial behaviour
        self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

        return results
[docs]@debug_loggerdefcompute_derivatives(self,circuits:QuantumScriptOrBatch,execution_config:ExecutionConfig=DefaultExecutionConfig,):max_workers=execution_config.device_options.get("max_workers",self._max_workers)ifmax_workersisNone:returntuple(adjoint_jacobian(circuit)forcircuitincircuits)vanilla_circuits=convert_to_numpy_parameters(circuits)[0]withconcurrent.futures.ProcessPoolExecutor(max_workers=max_workers)asexecutor:exec_map=executor.map(adjoint_jacobian,vanilla_circuits)res=tuple(exec_map)# reset _rng to mimic serial behaviourself._rng=np.random.default_rng(self._rng.integers(2**31-1))returnres
[docs]@debug_loggerdefsupports_jvp(self,execution_config:Optional[ExecutionConfig]=None,circuit:Optional[QuantumScript]=None,)->bool:"""Whether or not this device defines a custom jacobian vector product. ``DefaultQubit`` supports backpropagation derivatives with analytic results, as well as adjoint differentiation. Args: execution_config (ExecutionConfig): The configuration of the desired derivative calculation circuit (QuantumTape): An optional circuit to check derivatives support for. Returns: bool: Whether or not a derivative can be calculated provided the given information """returnself.supports_derivatives(execution_config,circuit)
[docs]@debug_loggerdefcompute_jvp(self,circuits:QuantumScriptOrBatch,tangents:tuple[Number,...],execution_config:ExecutionConfig=DefaultExecutionConfig,):max_workers=execution_config.device_options.get("max_workers",self._max_workers)ifmax_workersisNone:returntuple(adjoint_jvp(circuit,tans)forcircuit,tansinzip(circuits,tangents))vanilla_circuits=convert_to_numpy_parameters(circuits)[0]withconcurrent.futures.ProcessPoolExecutor(max_workers=max_workers)asexecutor:res=tuple(executor.map(adjoint_jvp,vanilla_circuits,tangents))# reset _rng to mimic serial behaviourself._rng=np.random.default_rng(self._rng.integers(2**31-1))returnres
[docs]@debug_loggerdefsupports_vjp(self,execution_config:Optional[ExecutionConfig]=None,circuit:Optional[QuantumScript]=None,)->bool:"""Whether or not this device defines a custom vector jacobian product. ``DefaultQubit`` supports backpropagation derivatives with analytic results, as well as adjoint differentiation. Args: execution_config (ExecutionConfig): A description of the hyperparameters for the desired computation. circuit (None, QuantumTape): A specific circuit to check differentation for. Returns: bool: Whether or not a derivative can be calculated provided the given information """returnself.supports_derivatives(execution_config,circuit)
[docs]@debug_loggerdefcompute_vjp(self,circuits:QuantumScriptOrBatch,cotangents:tuple[Number,...],execution_config:ExecutionConfig=DefaultExecutionConfig,):r"""The vector jacobian product used in reverse-mode differentiation. ``DefaultQubit`` uses the adjoint differentiation method to compute the VJP. Args: circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits cotangents (Tuple[Number, Tuple[Number]]): Gradient-output vector. Must have shape matching the output shape of the corresponding circuit. If the circuit has a single output, `cotangents` may be a single number, not an iterable of numbers. execution_config (ExecutionConfig): a datastructure with all additional information required for execution Returns: tensor-like: A numeric result of computing the vector jacobian product **Definition of vjp:** If we have a function with jacobian: .. math:: \vec{y} = f(\vec{x}) \qquad J_{i,j} = \frac{\partial y_i}{\partial x_j} The vector jacobian product is the inner product of the derivatives of the output ``y`` with the Jacobian matrix. The derivatives of the output vector are sometimes called the **cotangents**. .. math:: \text{d}x_i = \Sigma_{i} \text{d}y_i J_{i,j} **Shape of cotangents:** The value provided to ``cotangents`` should match the output of :meth:`~.execute`. For computing the full Jacobian, the cotangents can be batched to vectorize the computation. In this case, the cotangents can have the following shapes. ``batch_size`` below refers to the number of entries in the Jacobian: * For a state measurement, the cotangents must have shape ``(batch_size, 2 ** n_wires)`` * For ``n`` expectation values, the cotangents must have shape ``(n, batch_size)``. If ``n = 1``, then the shape must be ``(batch_size,)``. 
"""max_workers=execution_config.device_options.get("max_workers",self._max_workers)ifmax_workersisNone:def_state(circuit):return(Noneifself._state_cacheisNoneelseself._state_cache.get(circuit.hash,None))returntuple(adjoint_vjp(circuit,cots,state=_state(circuit))forcircuit,cotsinzip(circuits,cotangents))vanilla_circuits=convert_to_numpy_parameters(circuits)[0]withconcurrent.futures.ProcessPoolExecutor(max_workers=max_workers)asexecutor:res=tuple(executor.map(adjoint_vjp,vanilla_circuits,cotangents))# reset _rng to mimic serial behaviourself._rng=np.random.default_rng(self._rng.integers(2**31-1))returnres
[docs]@debug_loggerdefeval_jaxpr(self,jaxpr:"jax.core.Jaxpr",consts:list[TensorLike],*args,execution_config=None)->list[TensorLike]:from.qubit.dq_interpreterimportDefaultQubitInterpreterifself.wiresisNone:raiseqml.DeviceError("Device wires are required for jaxpr execution.")ifself.shots.has_partitioned_shots:raiseqml.DeviceError("Shot vectors are unsupported with jaxpr execution.")ifself._prng_keyisnotNone:key=self.get_prng_keys()[0]else:importjaxkey=jax.random.PRNGKey(self._rng.integers(100000))interpreter=DefaultQubitInterpreter(num_wires=len(self.wires),shots=self.shots.total_shots,key=key,execution_config=execution_config,)returninterpreter.eval(jaxpr,consts,*args)
[docs]@debug_loggerdefjaxpr_jvp(self,jaxpr,args:Sequence[TensorLike],tangents:Sequence[TensorLike],execution_config=None,)->tuple[Sequence[TensorLike],Sequence[TensorLike]]:gradient_method=getattr(execution_config,"gradient_method","backprop")ifgradient_method=="backprop":returnself._backprop_jvp(jaxpr,args,tangents,execution_config=execution_config)ifgradient_method=="adjoint":from.qubit.jaxpr_adjointimportexecute_and_jvpreturnexecute_and_jvp(jaxpr,args,tangents,num_wires=len(self.wires))raiseNotImplementedError(f"DefaultQubit does not support gradient_method={gradient_method}")