# Source code for pennylane.transforms.dynamic_one_shot
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the ``dynamic_one_shot`` transform, which executes circuits with mid-circuit
measurements as a batch of single-shot tapes and recombines the results."""

import itertools

# pylint: disable=import-outside-toplevel
from collections import Counter
from collections.abc import Sequence

import numpy as np

import pennylane as qml
from pennylane.measurements import (
    CountsMP,
    ExpectationMP,
    MeasurementValue,
    MidMeasureMP,
    ProbabilityMP,
    SampleMP,
    VarianceMP,
)
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn, TensorLike

from .core import transform

# Sentinel written into ``qml.sample`` results for shots that fail post-selection when
# ``postselect_mode="pad-invalid-samples"`` (see ``gather_non_mcm``).
fill_in_value = np.iinfo(np.int32).min


def is_mcm(operation):
    """Returns True if the operation is a mid-circuit measurement and False otherwise.

    Besides PennyLane's own ``MidMeasureMP``, any operation whose type name contains
    ``"MidCircuitMeasure"`` is accepted. This is presumably Catalyst's mid-circuit measurement
    type under ``qjit``, which cannot be imported here, so it is matched by name.
    """
    mcm = isinstance(operation, MidMeasureMP)
    return mcm or "MidCircuitMeasure" in str(type(operation))


def null_postprocessing(results):
    """A postprocessing function returned by a transform that only converts the batch of results
    into a result for a single ``QuantumTape``.
    """
    return results[0]
@transform
def dynamic_one_shot(tape: QuantumScript, **kwargs) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Transform a QNode into several one-shot tapes to support dynamic circuit execution.

    This transform enables the ``"one-shot"`` mid-circuit measurement method. The ``"one-shot"``
    method prompts the device to perform a series of one-shot executions, where in each execution,
    the ``qml.measure`` operation applies a probabilistic mid-circuit measurement to the circuit.
    This is in contrast with ``qml.defer_measurements``, which instead introduces an extra wire
    for each mid-circuit measurement. The ``"one-shot"`` method is favourable in the few-shots and
    several-mid-circuit-measurements limit, whereas ``qml.defer_measurements`` is favourable in
    the opposite limit.

    Args:
        tape (QNode or QuantumScript or Callable): a quantum circuit.

    Keyword Args:
        postselect_mode (str or None): forwarded to the result post-processing. With
            ``"pad-invalid-samples"``, ``qml.sample`` entries from shots that fail
            post-selection are padded with a sentinel value instead of being dropped.
        device: not read by this function. When transforming a QNode, the QNode-level
            transform fills it in with the QNode's device.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumScript], function]:

        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
        This circuit will provide the results of a dynamic execution.

    Raises:
        TypeError: if a terminal measurement is not a counts, expval, probs, sample or var
            measurement.
        QuantumFunctionError: if the tape has no (finite) shots.
        ValueError: if ``qml.sample`` is returned while post-selecting with broadcasting.

    **Example**

    Most devices that support mid-circuit measurements will include this transform in its
    preprocessing automatically when applicable. When this is the case, any user-applied
    ``dynamic_one_shot`` transforms will be ignored. The recommended way to use dynamic one
    shot is to specify ``mcm_method="one-shot"`` in the ``qml.qnode`` decorator.

    .. code-block:: python

        dev = qml.device("default.qubit", shots=100)
        params = np.pi / 4 * np.ones(2)

        @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots")
        def func(x, y):
            qml.RX(x, wires=0)
            m0 = qml.measure(0)
            qml.cond(m0, qml.RY)(y, wires=1)
            return qml.expval(op=m0)
    """
    # Nothing dynamic to do: pass the tape through untouched.
    if not any(is_mcm(o) for o in tape.operations):
        return (tape,), null_postprocessing

    for m in tape.measurements:
        if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)):
            raise TypeError(
                f"Native mid-circuit measurement mode does not support {type(m).__name__} "
                "measurements."
            )

    if not tape.shots:
        raise qml.QuantumFunctionError("dynamic_one_shot is only supported with finite shots.")

    samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
    postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
    if postselect_present and samples_present and tape.batch_size is not None:
        raise ValueError(
            "Returning qml.sample is not supported when postselecting mid-circuit "
            "measurements with broadcasting"
        )

    # Broadcasted tapes are split into one tape per batch element; ``broadcast_fn`` re-stacks
    # the per-element results at the end.
    if (batch_size := tape.batch_size) is not None:
        tapes, broadcast_fn = qml.transforms.broadcast_expand(tape)
    else:
        tapes = [tape]
        broadcast_fn = None

    aux_tapes = [init_auxiliary_tape(t) for t in tapes]

    postselect_mode = kwargs.get("postselect_mode", None)

    def reshape_data(array):
        """Stack per-shot values along a new leading axis and drop singleton dimensions."""
        return qml.math.squeeze(qml.math.vstack(array))

    def processing_fn(results, has_partitioned_shots=None, batched_results=None):
        """Recombine single-shot results into results for the original ``tape``.

        Recurses once per broadcast element and once per shot-vector bin. The flags mark
        which of those splits have already been handled.
        """
        if batched_results is None and batch_size is not None:
            # If broadcasting, recursively process the results for each batch. For each batch
            # there are tape.shots.total_shots results. The length of the first axis of
            # final_results will be batch_size.
            final_results = [
                processing_fn((result,), batched_results=False) for result in results
            ]
            return broadcast_fn(final_results)

        if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
            # If using shot vectors, recursively process the results for each shot bin. The
            # length of the first axis of final_results will be the length of the shot vector.
            per_shot = list(results[0])
            final_results = []
            start = 0
            for s in tape.shots:
                final_results.append(
                    processing_fn(
                        per_shot[start : start + s],
                        has_partitioned_shots=False,
                        batched_results=False,
                    )
                )
                start += s
            return tuple(final_results)

        if not tape.shots.has_partitioned_shots:
            results = results[0]

        # Transpose from "one tuple of measurements per shot" to "one array per measurement".
        is_scalar = not isinstance(results[0], Sequence)
        if is_scalar:
            results = [reshape_data(tuple(results))]
        else:
            results = [
                reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
            ]
        return parse_native_mid_circuit_measurements(
            tape, aux_tapes, results, postselect_mode=postselect_mode
        )

    return aux_tapes, processing_fn
def get_legacy_capabilities(dev):
    """Gets the capabilities dictionary of a legacy device.

    Args:
        dev: a ``qml.devices.LegacyDeviceFacade`` wrapping a legacy device, or a bare
            ``qml.devices.LegacyDevice``.

    Returns:
        dict: the result of the legacy device's ``capabilities()`` method.
    """
    if isinstance(dev, qml.devices.LegacyDeviceFacade):
        return dev.target_device.capabilities()
    # ``_supports_one_shot`` routes bare ``LegacyDevice`` instances here. Previously only the
    # facade was accepted, so that call always failed the assertion.
    assert isinstance(dev, qml.devices.LegacyDevice)
    return dev.capabilities()


def _supports_one_shot(dev: "qml.devices.Device"):
    """Checks whether a device supports one-shot.

    Legacy devices are queried for their ``"supports_mid_measure"`` capability. Any other
    device qualifies if it is ``default.qubit`` or ``lightning.qubit``, or if it declares
    ``"one-shot"`` among its ``capabilities.supported_mcm_methods``.
    """
    if isinstance(dev, qml.devices.LegacyDevice):
        return get_legacy_capabilities(dev).get("supports_mid_measure", False)

    return dev.name in ("default.qubit", "lightning.qubit") or (
        dev.capabilities is not None and "one-shot" in dev.capabilities.supported_mcm_methods
    )


@dynamic_one_shot.custom_qnode_transform
def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
    """Custom qnode transform for ``dynamic_one_shot``.

    Rejects an explicit ``device`` keyword, checks that the QNode's device supports one-shot
    execution, and forwards the QNode's device to the tape transform via ``tkwargs``.
    """
    if tkwargs.get("device", None):
        raise ValueError(
            "Cannot provide a 'device' value directly to the dynamic_one_shot decorator "
            "when transforming a QNode."
        )
    if qnode.device is not None:
        if not _supports_one_shot(qnode.device):
            raise TypeError(
                f"Device {qnode.device.name} does not support mid-circuit measurements and/or "
                "one-shot execution mode natively, and hence it does not support the "
                "dynamic_one_shot transform. 'default.qubit' and 'lightning.qubit' currently "
                "support mid-circuit measurements and the dynamic_one_shot transform."
            )
    tkwargs.setdefault("device", qnode.device)
    return self.default_qnode_transform(qnode, targs, tkwargs)


def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
    """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.

    Terminal measurements that do not depend on mid-circuit measurements are kept, except
    ``VarianceMP`` measurements, which are replaced by ``SampleMP`` on the same observable.
    Measurements of mid-circuit measurement values are dropped. Instead, one ``qml.sample``
    per mid-circuit measurement is appended after the kept measurements. The shots are split
    into ``total_shots`` single-shot copies.

    Args:
        circuit (QuantumTape): The original QuantumScript

    Returns:
        QuantumScript: A copy of the circuit with modified measurements
    """
    new_measurements = []
    for m in circuit.measurements:
        if m.mv is None:
            if isinstance(m, VarianceMP):
                # The variance is rebuilt from per-shot samples during post-processing.
                new_measurements.append(SampleMP(obs=m.obs))
            else:
                new_measurements.append(m)
    for op in circuit.operations:
        if "MidCircuitMeasure" in str(type(op)):  # pragma: no cover
            new_measurements.append(qml.sample(op.out_classical_tracers[0]))
        elif isinstance(op, MidMeasureMP):
            new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
    return qml.tape.QuantumScript(
        circuit.operations,
        new_measurements,
        shots=[1] * circuit.shots.total_shots,
        trainable_params=circuit.trainable_params,
    )


# pylint: disable=too-many-branches,too-many-statements
def parse_native_mid_circuit_measurements(
    circuit: qml.tape.QuantumScript,
    aux_tapes: Sequence[qml.tape.QuantumScript],
    results: TensorLike,
    postselect_mode=None,
):
    """Combines, gathers and normalizes the results of native mid-circuit measurement runs.

    Args:
        circuit (QuantumTape): The original ``QuantumScript``.
        aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects. Only the
            operations of the first one are inspected here.
        results (TensorLike): Measurement results, one entry per auxiliary-tape measurement.
            The last ``n_mcms`` entries hold the mid-circuit measurement samples.
        postselect_mode (str or None): Forwarded to the ``gather_*`` helpers.

    Returns:
        tuple(TensorLike) or TensorLike: The results of the simulation; a tuple when the
        circuit has more than one measurement, otherwise the single result.

    Raises:
        TypeError: if a measurement is not a counts, expval, probs, sample or var measurement.
    """

    def measurement_with_no_shots(measurement):
        """NaN placeholder used when no shot survives post-selection."""
        return (
            np.nan * np.ones_like(measurement.eigvals())
            if isinstance(measurement, ProbabilityMP)
            else np.nan
        )

    interface = qml.math.get_deep_interface(results)
    interface = "numpy" if interface == "builtins" else interface
    interface = "tensorflow" if interface == "tf" else interface
    active_qjit = qml.compiler.active()

    all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
    n_mcms = len(all_mcms)
    # One column per mid-circuit measurement, one row per shot.
    mcm_samples = qml.math.hstack(
        tuple(qml.math.reshape(res, (-1, 1)) for res in results[-n_mcms:])
    )
    mcm_samples = qml.math.array(mcm_samples, like=interface)
    # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
    has_postselect = qml.math.array(
        [[op.postselect is not None for op in all_mcms]],
        like=interface,
        dtype=mcm_samples.dtype,
    )
    postselect = qml.math.array(
        [[0 if op.postselect is None else op.postselect for op in all_mcms]],
        like=interface,
        dtype=mcm_samples.dtype,
    )
    # A shot is valid when every post-selected MCM produced its requested value. Columns of
    # MCMs without post-selection are zeroed on both sides, so they always compare equal.
    is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
    has_valid = qml.math.any(is_valid)
    mid_meas = [op for op in circuit.operations if is_mcm(op)]
    mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
    mcm_samples = dict(zip(mid_meas, mcm_samples))

    normalized_meas = []
    m_count = 0  # index into ``results`` for measurements that do not depend on MCMs
    for m in circuit.measurements:
        if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)):
            raise TypeError(
                f"Native mid-circuit measurement mode does not support {type(m).__name__} measurements."
            )
        # ``not has_valid`` would force concretization of a traced value, hence the jax guard.
        if interface != "jax" and m.mv is not None and not has_valid:
            meas = measurement_with_no_shots(m)
        elif m.mv is not None and active_qjit:
            meas = gather_mcm_qjit(
                m, mcm_samples, is_valid, postselect_mode=postselect_mode
            )  # pragma: no cover
        elif m.mv is not None:
            meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode)
        elif interface != "jax" and not has_valid:
            meas = measurement_with_no_shots(m)
            m_count += 1
        else:
            result = results[m_count]
            if not isinstance(m, CountsMP):
                # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
                # as it assumes all elements of the input are of builtin python types and not belonging
                # to any particular interface
                result = qml.math.array(result, like=interface)

            if active_qjit:  # pragma: no cover
                # `result` contains (bases, counts) need to return (basis, sum(counts)) where `is_valid`
                # Any row of `result[0]` contains basis, so we return `result[0][0]`
                # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples
                if isinstance(m, CountsMP):
                    normalized_meas.append(
                        (
                            result[0][0],
                            qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0),
                        )
                    )
                    m_count += 1
                    continue
                result = qml.math.squeeze(result)
            meas = gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode)
            m_count += 1
        if isinstance(m, SampleMP):
            meas = qml.math.squeeze(meas)
        normalized_meas.append(meas)

    return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None):  # pragma: no cover
    """Process MCM measurements when the Catalyst compiler is active.

    Args:
        measurement (MeasurementProcess): measurement
        samples (dict): Mid-circuit measurement samples, keyed by mid-circuit measurement
            operation. The matching entry is found through the key's
            ``out_classical_tracers[0]``.
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each
            index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): Forwarded to ``gather_non_mcm``.

    Returns:
        TensorLike: The combined measurement outcome

    Raises:
        LookupError: if no entry of ``samples`` corresponds to ``measurement.mv``.
    """
    for k, meas in samples.items():
        if measurement.mv is k.out_classical_tracers[0]:
            break
    else:
        raise LookupError("MCM not found")
    meas = qml.math.squeeze(meas)
    if isinstance(measurement, (CountsMP, ProbabilityMP)):
        interface = qml.math.get_interface(is_valid)
        sum_valid = qml.math.sum(is_valid)
        # Number of valid shots whose outcome was 1; the rest of the valid shots read 0.
        count_1 = qml.math.sum(meas * is_valid)
        if isinstance(measurement, CountsMP):
            return qml.math.array([0, 1], like=interface), qml.math.array(
                [sum_valid - count_1, count_1], like=interface
            )
        if isinstance(measurement, ProbabilityMP):
            counts = qml.math.array([sum_valid - count_1, count_1], like=interface)
            return counts / sum_valid
    return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode)


def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None):
    """Combines, gathers and normalizes several measurements with trivial measurement values.

    Args:
        measurement (MeasurementProcess): measurement
        samples (TensorLike): Post-processed measurement samples. For ``CountsMP`` this is an
            iterable of per-shot ``{outcome: count}`` dicts.
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each
            index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): With ``"pad-invalid-samples"``, ``SampleMP`` entries of
            invalid shots are replaced by ``fill_in_value`` instead of being dropped.

    Returns:
        TensorLike: The combined measurement outcome
    """
    if isinstance(measurement, CountsMP):
        tmp = Counter()

        if measurement.all_outcomes:
            # Pre-seed every possible outcome with a zero count.
            if isinstance(measurement.mv, Sequence):
                values = [list(m.branches.values()) for m in measurement.mv]
                values = list(itertools.product(*values))
                tmp = Counter({"".join(map(str, v)): 0 for v in values})
            else:
                values = [list(measurement.mv.branches.values())]
                values = list(itertools.product(*values))
                tmp = Counter({float(*v): 0 for v in values})

        for i, d in enumerate(samples):
            # Invalid shots contribute a count of zero.
            tmp.update(
                {k if isinstance(k, str) else float(k): v * is_valid[i] for k, v in d.items()}
            )

        if not measurement.all_outcomes:
            tmp = Counter({k: v for k, v in tmp.items() if v > 0})
        return dict(sorted(tmp.items()))

    if isinstance(measurement, SampleMP):
        if postselect_mode == "pad-invalid-samples" and samples.ndim == 2:
            is_valid = qml.math.reshape(is_valid, (-1, 1))
        if postselect_mode == "pad-invalid-samples":
            return qml.math.where(is_valid, samples, fill_in_value)
        if qml.math.shape(samples) == ():  # single shot case
            samples = qml.math.reshape(samples, (-1, 1))
        return samples[is_valid]

    if (interface := qml.math.get_interface(is_valid)) == "tensorflow":
        # Tensorflow requires arrays that are used for arithmetic with each other to have the
        # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to
        # index other tf.Tensors (is_valid is used to index valid samples).
        is_valid = qml.math.cast_like(is_valid, samples)

    if isinstance(measurement, ExpectationMP):
        return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
    if isinstance(measurement, ProbabilityMP):
        return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum(
            is_valid
        )

    # VarianceMP
    expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
    if interface == "tensorflow":
        # Casting needed for tensorflow
        samples = qml.math.cast_like(samples, expval)
        is_valid = qml.math.cast_like(is_valid, expval)
    return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid)


def gather_mcm(measurement, samples, is_valid, postselect_mode=None):
    """Combines, gathers and normalizes several measurements with non-trivial measurement values.

    Args:
        measurement (MeasurementProcess): measurement
        samples (dict): Mid-circuit measurement samples, keyed by mid-circuit measurement
            operation and concretized through ``MeasurementValue.concretize``.
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value at each
            index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): Forwarded to ``gather_non_mcm``.

    Returns:
        TensorLike: The combined measurement outcome
    """
    interface = qml.math.get_deep_interface(is_valid)
    mv = measurement.mv

    # The following block handles measurement value lists, like ``qml.counts(op=[mcm0, mcm1, mcm2])``.
    if isinstance(measurement, (CountsMP, ProbabilityMP, SampleMP)) and isinstance(mv, Sequence):
        mcm_samples = [m.concretize(samples) for m in mv]
        mcm_samples = qml.math.concatenate(mcm_samples, axis=1)
        if isinstance(measurement, ProbabilityMP):
            values = [list(m.branches.values()) for m in mv]
            values = list(itertools.product(*values))
            values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values]
            # Need to use boolean functions explicitly as Tensorflow does not allow integer math
            # on boolean arrays
            counts = [
                qml.math.count_nonzero(
                    qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid)
                )
                for v in values
            ]
            counts = qml.math.array(counts, like=interface)
            return counts / qml.math.sum(counts)
        if isinstance(measurement, CountsMP):
            # Encode each shot's joint outcome as a bit-string key, e.g. "101".
            mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
        return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)

    mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
    if isinstance(measurement, ProbabilityMP):
        # Need to use boolean functions explicitly as Tensorflow does not allow integer math
        # on boolean arrays
        counts = [
            qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid))
            for v in list(mv.branches.values())
        ]
        counts = qml.math.array(counts, like=interface)
        return counts / qml.math.sum(counts)
    if isinstance(measurement, CountsMP):
        mcm_samples = [{float(s): 1} for s in mcm_samples]
    return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)