Source code for pennylane.transforms.dynamic_one_shot

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the ``dynamic_one_shot`` transform and the helpers that post-process its one-shot results.
"""

import itertools
from collections import Counter
from collections.abc import Sequence

import numpy as np

import pennylane as qml
from pennylane.exceptions import QuantumFunctionError
from pennylane.measurements import (
    CountsMP,
    ExpectationMP,
    MeasurementValue,
    MidMeasureMP,
    ProbabilityMP,
    SampleMP,
    VarianceMP,
)
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn, TensorLike

from .core import transform

# Placeholder written in place of invalid (post-selected away) samples when
# ``gather_non_mcm`` runs with ``postselect_mode="pad-invalid-samples"``.
# It is the smallest value representable by a 32-bit signed integer.
fill_in_value = np.iinfo(np.int32).min


def is_mcm(operation):
    """Tell whether ``operation`` is a mid-circuit measurement.

    An operation qualifies when it is a ``MidMeasureMP`` instance, or when the
    string form of its type contains ``"MidCircuitMeasure"``.

    Args:
        operation: the object to inspect.

    Returns:
        bool: ``True`` for a mid-circuit measurement, ``False`` otherwise.
    """
    if isinstance(operation, MidMeasureMP):
        return True
    # Fall back to a name-based check for classes that are not MidMeasureMP subclasses.
    return "MidCircuitMeasure" in str(type(operation))


def null_postprocessing(results):
    """Unwrap a single-entry batch of results.

    Used as the postprocessing function when the transform leaves the tape
    untouched: the batch holds exactly one result, which is returned as-is.

    Args:
        results: indexable batch of results.

    Returns:
        The first entry of ``results``.
    """
    only_result = results[0]
    return only_result


@transform
def dynamic_one_shot(tape: QuantumScript, **kwargs) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Transform a QNode into several one-shot tapes to support dynamic circuit execution.

    This transform enables the ``"one-shot"`` mid-circuit measurement method. The ``"one-shot"``
    method prompts the device to perform a series of one-shot executions, where in each execution,
    the ``qml.measure`` operation applies a probabilistic mid-circuit measurement to the circuit.
    This is in contrast with ``qml.defer_measurements``, which instead introduces an extra
    wire for each mid-circuit measurement. The ``"one-shot"`` method is favourable in the few-shots
    and several-mid-circuit-measurements limit, whereas ``qml.defer_measurements`` is favourable in
    the opposite limit.

    Args:
        tape (QNode or QuantumScript or Callable): a quantum circuit.

    Keyword Args:
        postselect_mode (str or None): forwarded unchanged to
            ``parse_native_mid_circuit_measurements`` when the one-shot results are gathered.
            Defaults to ``None``.
        device: accepted for compatibility with the custom QNode transform, which stores the
            QNode's device under this key. The tape-level transform itself does not use it.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumScript], function]:

        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
        This circuit will provide the results of a dynamic execution.

    Raises:
        TypeError: if the tape contains mid-circuit measurements and a terminal measurement that
            is not a counts, expectation value, probability, sample or variance measurement.
        QuantumFunctionError: if the tape contains mid-circuit measurements but has no finite
            shots.
        ValueError: if ``qml.sample`` is returned from a broadcasted tape that postselects
            mid-circuit measurements.

    **Example**

    Most devices that support mid-circuit measurements will include this transform in their
    preprocessing automatically when applicable. When this is the case, any user-applied
    ``dynamic_one_shot`` transforms will be ignored. The recommended way to use dynamic one
    shot is to specify ``mcm_method="one-shot"`` in the ``qml.qnode`` decorator.

    .. code-block:: python

        from functools import partial

        dev = qml.device("default.qubit")
        params = np.pi / 4 * np.ones(2)

        @partial(qml.set_shots, shots=100)
        @qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots")
        def func(x, y):
            qml.RX(x, wires=0)
            m0 = qml.measure(0)
            qml.cond(m0, qml.RY)(y, wires=1)
            return qml.expval(op=m0)

    """
    # Nothing to do for circuits without mid-circuit measurements.
    if not any(is_mcm(o) for o in tape.operations):
        return (tape,), null_postprocessing

    for m in tape.measurements:
        if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)):
            raise TypeError(
                f"Native mid-circuit measurement mode does not support {type(m).__name__} "
                "measurements."
            )

    if not tape.shots:
        raise QuantumFunctionError("dynamic_one_shot is only supported with finite shots.")

    samples_present = any(isinstance(mp, SampleMP) for mp in tape.measurements)
    postselect_present = any(op.postselect is not None for op in tape.operations if is_mcm(op))
    if postselect_present and samples_present and tape.batch_size is not None:
        raise ValueError(
            "Returning qml.sample is not supported when postselecting mid-circuit "
            "measurements with broadcasting"
        )

    if (batch_size := tape.batch_size) is not None:
        tapes, broadcast_fn = qml.transforms.broadcast_expand(tape)
    else:
        tapes = [tape]
        broadcast_fn = None

    aux_tapes = [init_auxiliary_tape(t) for t in tapes]

    postselect_mode = kwargs.get("postselect_mode", None)

    def reshape_data(array):
        """Stack the per-shot values vertically and drop singleton dimensions."""
        return qml.math.squeeze(qml.math.vstack(array))

    def processing_fn(results, has_partitioned_shots=None, batched_results=None):
        """Gather the one-shot results into the results expected for ``tape``.

        ``has_partitioned_shots`` and ``batched_results`` are ``None`` on the outermost call and
        are set to ``False`` on the recursive calls so each level is unpacked only once.
        """
        if batched_results is None and batch_size is not None:
            # If broadcasting, recursively process the results for each batch. For each batch
            # there are tape.shots.total_shots results. The length of the first axis of
            # final_results will be batch_size.
            final_results = [
                processing_fn((result,), batched_results=False) for result in results
            ]
            return broadcast_fn(final_results)

        if has_partitioned_shots is None and tape.shots.has_partitioned_shots:
            # If using shot vectors, recursively process the results for each shot bin. The
            # length of the first axis of final_results will be the length of the shot vector.
            shot_results = list(results[0])
            final_results = []
            start = 0
            for s in tape.shots:
                # Each bin consumes the next ``s`` one-shot results.
                final_results.append(
                    processing_fn(
                        shot_results[start : start + s],
                        has_partitioned_shots=False,
                        batched_results=False,
                    )
                )
                start += s
            return tuple(final_results)

        if not tape.shots.has_partitioned_shots:
            results = results[0]

        # A one-shot result is either a single value or a sequence with one entry per
        # measurement of the auxiliary tape; regroup the values per measurement.
        is_scalar = not isinstance(results[0], Sequence)
        if is_scalar:
            results = [reshape_data(tuple(results))]
        else:
            results = [
                reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
            ]
        return parse_native_mid_circuit_measurements(
            tape, aux_tapes, results, postselect_mode=postselect_mode
        )

    return aux_tapes, processing_fn
def get_legacy_capabilities(dev):
    """Gets the capabilities dictionary of a device.

    Args:
        dev: a ``qml.devices.LegacyDeviceFacade`` instance.

    Returns:
        The value returned by ``dev.target_device.capabilities()``.
    """
    assert isinstance(dev, qml.devices.LegacyDeviceFacade)
    return dev.target_device.capabilities()


def _supports_one_shot(dev: "qml.devices.Device"):
    """Checks whether a device supports one-shot.

    Args:
        dev: the device to inspect.

    Returns:
        bool: for legacy devices, the ``"supports_mid_measure"`` capability (``False`` when
        absent). Otherwise ``True`` when the device is named ``"default.qubit"`` or
        ``"lightning.qubit"``, or when its capabilities list ``"one-shot"`` among the supported
        mid-circuit measurement methods.
    """
    # NOTE(review): this guard tests ``LegacyDevice`` while ``get_legacy_capabilities`` asserts
    # ``LegacyDeviceFacade`` -- confirm that an object can satisfy both checks.
    if isinstance(dev, qml.devices.LegacyDevice):
        return get_legacy_capabilities(dev).get("supports_mid_measure", False)

    return dev.name in ("default.qubit", "lightning.qubit") or (
        dev.capabilities is not None and "one-shot" in dev.capabilities.supported_mcm_methods
    )


@dynamic_one_shot.custom_qnode_transform
def _dynamic_one_shot_qnode(self, qnode, targs, tkwargs):
    """Custom qnode transform for ``dynamic_one_shot``.

    Rejects a user-supplied ``device`` keyword, checks that the QNode's device supports one-shot
    execution, then stores that device in ``tkwargs`` before delegating to the default QNode
    transform.

    Raises:
        ValueError: if ``tkwargs`` already holds a truthy ``"device"`` entry.
        TypeError: if the QNode's device does not support one-shot execution.
    """
    if tkwargs.get("device", None):
        raise ValueError(
            "Cannot provide a 'device' value directly to the dynamic_one_shot decorator "
            "when transforming a QNode."
        )
    if qnode.device is not None:
        if not _supports_one_shot(qnode.device):
            raise TypeError(
                f"Device {qnode.device.name} does not support mid-circuit measurements and/or "
                "one-shot execution mode natively, and hence it does not support the "
                "dynamic_one_shot transform. 'default.qubit' and 'lightning.qubit' currently "
                "support mid-circuit measurements and the dynamic_one_shot transform."
            )
    tkwargs.setdefault("device", qnode.device)
    return self.default_qnode_transform(qnode, targs, tkwargs)


def init_auxiliary_tape(circuit: qml.tape.QuantumScript):
    """Creates an auxiliary circuit to perform one-shot mid-circuit measurement calculations.

    Terminal measurements that do not act on measurement values are kept, except variance
    measurements, which are replaced by a ``SampleMP`` of the same observable. Terminal
    measurements that act on measurement values are dropped. One sample measurement is then
    appended per mid-circuit measurement found in the operations, in operation order.

    Args:
        circuit (QuantumTape): The original QuantumScript

    Returns:
        QuantumScript: A copy of the circuit with modified measurements, configured with one
        single-shot entry per shot of the original circuit.
    """
    new_measurements = []
    for m in circuit.measurements:
        if m.mv is None:
            if isinstance(m, VarianceMP):
                new_measurements.append(SampleMP(obs=m.obs))
            else:
                new_measurements.append(m)
    for op in circuit.operations:
        if "MidCircuitMeasure" in str(type(op)):  # pragma: no cover
            new_measurements.append(qml.sample(op.out_classical_tracers[0]))
        elif isinstance(op, MidMeasureMP):
            new_measurements.append(qml.sample(MeasurementValue([op], lambda res: res)))
    return qml.tape.QuantumScript(
        circuit.operations,
        new_measurements,
        shots=[1] * circuit.shots.total_shots,
        trainable_params=circuit.trainable_params,
    )


# pylint: disable=too-many-branches
def parse_native_mid_circuit_measurements(
    circuit: qml.tape.QuantumScript,
    aux_tapes: Sequence[qml.tape.QuantumScript],
    results: TensorLike,
    postselect_mode=None,
):
    """Combines, gathers and normalizes the results of native mid-circuit measurement runs.

    Args:
        circuit (QuantumTape): The original ``QuantumScript``.
        aux_tapes (List[QuantumTape]): List of auxiliary ``QuantumScript`` objects. Only the
            first one is inspected, to list the mid-circuit measurements.
        results (TensorLike): Array of measurement results. The trailing entries, one per
            mid-circuit measurement, hold the mid-circuit measurement samples; the leading
            entries hold the results of the terminal measurements that do not act on
            measurement values, in order.
        postselect_mode (str or None): forwarded to the ``gather_*`` helpers.

    Returns:
        tuple(TensorLike): The results of the simulation. A single result is returned unwrapped.

    Raises:
        TypeError: if ``circuit`` contains an unsupported measurement type.
    """

    def measurement_with_no_shots(measurement):
        """Return the NaN placeholder used when no shot is valid."""
        return (
            np.nan * np.ones_like(measurement.eigvals())
            if isinstance(measurement, ProbabilityMP)
            else np.nan
        )

    interface = qml.math.get_deep_interface(results)
    interface = "numpy" if interface == "builtins" else interface
    interface = "tensorflow" if interface == "tf" else interface
    active_qjit = qml.compiler.active()

    all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
    n_mcms = len(all_mcms)
    # One column of samples per mid-circuit measurement.
    mcm_samples = qml.math.hstack(
        tuple(qml.math.reshape(res, (-1, 1)) for res in results[-n_mcms:])
    )
    mcm_samples = qml.math.array(mcm_samples, like=interface)
    # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
    has_postselect = qml.math.array(
        [[op.postselect is not None for op in all_mcms]],
        like=interface,
        dtype=mcm_samples.dtype,
    )
    postselect = qml.math.array(
        [[0 if op.postselect is None else op.postselect for op in all_mcms]],
        like=interface,
        dtype=mcm_samples.dtype,
    )
    # A shot is valid when every postselected measurement produced its postselected value.
    is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
    has_valid = qml.math.any(is_valid)
    mid_meas = [op for op in circuit.operations if is_mcm(op)]
    mcm_samples = [mcm_samples[:, i : i + 1] for i in range(n_mcms)]
    mcm_samples = dict(zip(mid_meas, mcm_samples))

    normalized_meas = []
    # Index into ``results`` of the next terminal measurement without measurement values.
    m_count = 0
    for m in circuit.measurements:
        if not isinstance(m, (CountsMP, ExpectationMP, ProbabilityMP, SampleMP, VarianceMP)):
            raise TypeError(
                f"Native mid-circuit measurement mode does not support {type(m).__name__} measurements."
            )
        if interface != "jax" and m.mv is not None and not has_valid:
            meas = measurement_with_no_shots(m)
        elif m.mv is not None and active_qjit:
            meas = gather_mcm_qjit(
                m, mcm_samples, is_valid, postselect_mode=postselect_mode
            )  # pragma: no cover
        elif m.mv is not None:
            meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode)
        elif interface != "jax" and not has_valid:
            meas = measurement_with_no_shots(m)
            m_count += 1
        else:
            result = results[m_count]
            if not isinstance(m, CountsMP):
                # We don't need to cast to arrays when using qml.counts. qml.math.array is not
                # viable as it assumes all elements of the input are of builtin python types
                # and not belonging to any particular interface
                result = qml.math.array(result, like=interface)
            if active_qjit:  # pragma: no cover
                # `result` contains (bases, counts) need to return (basis, sum(counts)) where
                # `is_valid`. Any row of `result[0]` contains basis, so we return `result[0][0]`
                # We return the sum of counts (`result[1]`) weighting by `is_valid`, which is
                # `0` for invalid samples
                if isinstance(m, CountsMP):
                    normalized_meas.append(
                        (
                            result[0][0],
                            qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0),
                        )
                    )
                    m_count += 1
                    continue
                result = qml.math.squeeze(result)
            meas = gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode)
            m_count += 1
        if isinstance(m, SampleMP):
            meas = qml.math.squeeze(meas)
        normalized_meas.append(meas)

    return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None):  # pragma: no cover
    """Process MCM measurements when the Catalyst compiler is active.

    Args:
        measurement (MeasurementProcess): measurement
        samples (dict): Mid-circuit measurement samples, keyed by mid-circuit measurement
            operation
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value
            at each index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): forwarded to ``gather_non_mcm``.

    Returns:
        TensorLike: The combined measurement outcome

    Raises:
        LookupError: if no key of ``samples`` matches the measurement value of ``measurement``.
    """
    found, meas = False, None
    for k, meas in samples.items():
        # Match by identity against the classical tracer produced by the measurement.
        if measurement.mv is k.out_classical_tracers[0]:
            found = True
            break
    if not found:
        raise LookupError("MCM not found")
    meas = qml.math.squeeze(meas)
    if isinstance(measurement, (CountsMP, ProbabilityMP)):
        interface = qml.math.get_interface(is_valid)
        sum_valid = qml.math.sum(is_valid)
        count_1 = qml.math.sum(meas * is_valid)
        if isinstance(measurement, CountsMP):
            return qml.math.array([0, 1], like=interface), qml.math.array(
                [sum_valid - count_1, count_1], like=interface
            )
        if isinstance(measurement, ProbabilityMP):
            counts = qml.math.array([sum_valid - count_1, count_1], like=interface)
            return counts / sum_valid
    return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode)


def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None):
    """Combines, gathers and normalizes several measurements with trivial measurement values.

    Args:
        measurement (MeasurementProcess): measurement
        samples (TensorLike): Post-processed measurement samples
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value
            at each index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): when equal to ``"pad-invalid-samples"``, sample
            measurements keep every shot and invalid entries are replaced by ``fill_in_value``;
            otherwise invalid samples are dropped. Only used for sample measurements.

    Returns:
        TensorLike: The combined measurement outcome
    """
    if isinstance(measurement, CountsMP):
        tmp = Counter()
        # Seeding every outcome with a zero count is only possible when the measurement acts on
        # measurement values; otherwise (``mv is None``) there are no branches to enumerate and
        # the keys present in the per-shot dictionaries are aggregated as-is.
        if measurement.all_outcomes and measurement.mv is not None:
            if isinstance(measurement.mv, Sequence):
                values = [list(m.branches.values()) for m in measurement.mv]
                values = list(itertools.product(*values))
                tmp = Counter({"".join(map(str, v)): 0 for v in values})
            else:
                values = [list(measurement.mv.branches.values())]
                values = list(itertools.product(*values))
                tmp = Counter({float(*v): 0 for v in values})
        for i, d in enumerate(samples):
            # Invalid shots contribute a count of zero.
            tmp.update(
                {k if isinstance(k, str) else float(k): v * is_valid[i] for k, v in d.items()}
            )

        if not measurement.all_outcomes:
            tmp = Counter({k: v for k, v in tmp.items() if v > 0})
        return dict(sorted(tmp.items()))

    if isinstance(measurement, SampleMP):
        if postselect_mode == "pad-invalid-samples":
            if samples.ndim == 2:
                # Make the mask broadcast over the columns of ``samples``.
                is_valid = qml.math.reshape(is_valid, (-1, 1))
            return qml.math.where(is_valid, samples, fill_in_value)
        if qml.math.shape(samples) == ():  # single shot case
            samples = qml.math.reshape(samples, (-1, 1))
        return samples[is_valid]

    if (interface := qml.math.get_interface(is_valid)) == "tensorflow":
        # Tensorflow requires arrays that are used for arithmetic with each other to have the
        # same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to
        # index other tf.Tensors (is_valid is used to index valid samples).
        is_valid = qml.math.cast_like(is_valid, samples)

    if isinstance(measurement, ExpectationMP):
        return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
    if isinstance(measurement, ProbabilityMP):
        return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum(
            is_valid
        )

    # VarianceMP
    expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
    if interface == "tensorflow":
        # Casting needed for tensorflow
        samples = qml.math.cast_like(samples, expval)
        is_valid = qml.math.cast_like(is_valid, expval)
    return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid)


def gather_mcm(measurement, samples, is_valid, postselect_mode=None):
    """Combines, gathers and normalizes several measurements with non-trivial measurement values.

    Args:
        measurement (MeasurementProcess): measurement
        samples (dict): Mid-circuit measurement samples, as passed to ``concretize`` of the
            measurement values
        is_valid (TensorLike): Boolean array with the same shape as ``samples`` where the value
            at each index specifies whether or not the respective sample is valid.
        postselect_mode (str or None): forwarded to ``gather_non_mcm``.

    Returns:
        TensorLike: The combined measurement outcome
    """
    interface = qml.math.get_deep_interface(is_valid)
    mv = measurement.mv
    # The following block handles measurement value lists, like
    # ``qml.counts(op=[mcm0, mcm1, mcm2])``.
    if isinstance(measurement, (CountsMP, ProbabilityMP, SampleMP)) and isinstance(mv, Sequence):
        mcm_samples = [m.concretize(samples) for m in mv]
        mcm_samples = qml.math.concatenate(mcm_samples, axis=1)
        if isinstance(measurement, ProbabilityMP):
            values = [list(m.branches.values()) for m in mv]
            values = list(itertools.product(*values))
            values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values]
            # Need to use boolean functions explicitly as Tensorflow does not allow integer math
            # on boolean arrays
            counts = [
                qml.math.count_nonzero(
                    qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid)
                )
                for v in values
            ]
            counts = qml.math.array(counts, like=interface)
            return counts / qml.math.sum(counts)
        if isinstance(measurement, CountsMP):
            # One single-count dictionary per shot, keyed by the outcome bitstring.
            mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
        return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)

    mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
    if isinstance(measurement, ProbabilityMP):
        # Need to use boolean functions explicitly as Tensorflow does not allow integer math
        # on boolean arrays
        counts = [
            qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid))
            for v in mv.branches.values()
        ]
        counts = qml.math.array(counts, like=interface)
        return counts / qml.math.sum(counts)
    if isinstance(measurement, CountsMP):
        # One single-count dictionary per shot, keyed by the outcome as a float.
        mcm_samples = [{float(s): 1} for s in mcm_samples]
    return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)