Source code for pennylane.devices.qubit.simulate
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simulate a quantum script."""
import logging
# pylint: disable=protected-access
from collections import Counter
from functools import partial, singledispatch
from typing import Optional
import numpy as np
from numpy.random import default_rng
import pennylane as qml
from pennylane.logging import debug_logger
from pennylane.measurements import (
CountsMP,
ExpectationMP,
MidMeasureMP,
ProbabilityMP,
SampleMP,
VarianceMP,
find_post_processed_mcms,
)
from pennylane.transforms.dynamic_one_shot import gather_mcm
from pennylane.typing import Result
from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
from .sampling import jax_random_split, measure_with_samples
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
INTERFACE_TO_LIKE = {
# map interfaces known by autoray to themselves
None: None,
"numpy": "numpy",
"autograd": "autograd",
"jax": "jax",
"torch": "torch",
"tensorflow": "tensorflow",
# map non-standard interfaces to those known by autoray
"auto": None,
"scipy": "numpy",
"jax-jit": "jax",
"jax-python": "jax",
"JAX": "jax",
"pytorch": "torch",
"tf": "tensorflow",
"tensorflow-autograph": "tensorflow",
"tf-autograph": "tensorflow",
}
class TreeTraversalStack:
"""This class is used to record various data used during the
depth-first tree-traversal procedure for simulating dynamic circuits."""
counts: list
probs: list
results_0: list
results_1: list
states: list
def __init__(self, max_depth):
self.counts = [None] * max_depth
self.probs = [None] * max_depth
self.results_0 = [None] * max_depth
self.results_1 = [None] * max_depth
self.states = [None] * max_depth
def any_is_empty(self, depth):
"""Return True if any result at ``depth`` is ``None`` and False otherwise."""
return self.results_0[depth] is None or self.results_1[depth] is None
def is_full(self, depth):
"""Return True if the results at ``depth`` are both not ``None`` and False otherwise."""
return self.results_0[depth] is not None and self.results_1[depth] is not None
def prune(self, depth):
"""Reset all stack entries at ``depth`` to ``None``."""
self.counts[depth] = None
self.probs[depth] = None
self.results_0[depth] = None
self.results_1[depth] = None
self.states[depth] = None
class _FlexShots(qml.measurements.Shots):
"""Shots class that allows zero shots."""
# pylint: disable=super-init-not-called
def __init__(self, shots=None):
if isinstance(shots, int):
self.total_shots = shots
self.shot_vector = (qml.measurements.ShotCopies(shots, 1),)
elif isinstance(shots, self.__class__):
return # self already _is_ shots as defined by __new__
else:
self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots])
self._frozen = True
def _postselection_postprocess(state, is_state_batched, shots, **execution_kwargs):
"""Update state after projector is applied."""
if is_state_batched:
raise ValueError(
"Cannot postselect on circuits with broadcasting. Use the "
"qml.transforms.broadcast_expand transform to split a broadcasted "
"tape into multiple non-broadcasted tapes before executing if "
"postselection is used."
)
rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
postselect_mode = execution_kwargs.get("postselect_mode", None)
# The floor function is being used here so that a norm very close to zero becomes exactly
# equal to zero so that the state can become invalid. This way, execution can continue, and
# bad postselection gives results that are invalid rather than results that look valid but
# are incorrect.
norm = qml.math.norm(state)
if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0):
norm = 0.0
if shots:
# Clip the number of shots using a binomial distribution using the probability of
# measuring the postselected state.
if prng_key is not None:
# pylint: disable=import-outside-toplevel
from jax.random import binomial
binomial_fn = partial(binomial, prng_key)
else:
binomial_fn = np.random.binomial if rng is None else rng.binomial
postselected_shots = (
shots
if postselect_mode == "fill-shots" or qml.math.is_abstract(norm)
else [int(binomial_fn(s, float(norm**2))) for s in shots]
)
# _FlexShots is used here since the binomial distribution could result in zero
# valid samples
shots = _FlexShots(postselected_shots)
state = state / norm
return state, shots
@debug_logger
def get_final_state(circuit, debugger=None, **execution_kwargs):
"""
Get the final state that results from executing the given quantum script.
This is an internal function that will be called by the successor to ``default.qubit``.
Args:
circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have
non-negative integer wire labels
debugger (._Debugger): The debugger to use
interface (str): The machine learning interface to create the initial state with
mid_measurements (None, dict): Dictionary of mid-circuit measurements
rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
If None, a ``numpy.random.default_rng`` will be used for sampling.
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.
Returns:
Tuple[TensorLike, bool]: A tuple containing the final state of the quantum script and
whether the state has a batch dimension.
"""
prng_key = execution_kwargs.pop("prng_key", None)
interface = execution_kwargs.get("interface", None)
prep = None
if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase):
prep = circuit[0]
state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface])
# initial state is batched only if the state preparation (if it exists) is batched
is_state_batched = bool(prep and prep.batch_size is not None)
key = prng_key
for op in circuit.operations[bool(prep) :]:
if isinstance(op, MidMeasureMP):
prng_key, key = jax_random_split(prng_key)
state = apply_operation(
op,
state,
is_state_batched=is_state_batched,
debugger=debugger,
prng_key=key,
tape_shots=circuit.shots,
**execution_kwargs,
)
# Handle postselection on mid-circuit measurements
if isinstance(op, qml.Projector):
prng_key, key = jax_random_split(prng_key)
state, new_shots = _postselection_postprocess(
state, is_state_batched, circuit.shots, prng_key=key, **execution_kwargs
)
circuit._shots = new_shots
# new state is batched if i) the old state is batched, or ii) the new op adds a batch dim
is_state_batched = is_state_batched or (op.batch_size is not None)
for _ in range(circuit.num_wires - len(circuit.op_wires)):
# if any measured wires are not operated on, we pad the state with zeros.
# We know they belong at the end because the circuit is in standard wire-order
state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1)
return state, is_state_batched
# pylint: disable=too-many-arguments
@debug_logger
def measure_final_state(circuit, state, is_state_batched, **execution_kwargs) -> Result:
"""
Perform the measurements required by the circuit on the provided state.
This is an internal function that will be called by the successor to ``default.qubit``.
Args:
circuit (.QuantumScript): The single circuit to simulate. This circuit is assumed to have
non-negative integer wire labels
state (TensorLike): The state to perform measurement on
is_state_batched (bool): Whether the state has a batch dimension or not.
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. Only for simulation using JAX.
If None, the default ``sample_state`` function and a ``numpy.random.default_rng``
will be used for sampling.
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
Tuple[TensorLike]: The measurement results
"""
rng = execution_kwargs.get("rng", None)
prng_key = execution_kwargs.get("prng_key", None)
mid_measurements = execution_kwargs.get("mid_measurements", None)
# analytic case
if not circuit.shots:
if mid_measurements is not None:
raise TypeError("Native mid-circuit measurements are only supported with finite shots.")
if len(circuit.measurements) == 1:
return measure(circuit.measurements[0], state, is_state_batched=is_state_batched)
return tuple(
measure(mp, state, is_state_batched=is_state_batched) for mp in circuit.measurements
)
# finite-shot case
rng = default_rng(rng)
results = measure_with_samples(
circuit.measurements,
state,
shots=circuit.shots,
is_state_batched=is_state_batched,
rng=rng,
prng_key=prng_key,
mid_measurements=mid_measurements,
)
if len(circuit.measurements) == 1:
if circuit.shots.has_partitioned_shots:
return tuple(res[0] for res in results)
return results[0]
return results
[docs]@debug_logger
def simulate(
circuit: qml.tape.QuantumScript,
debugger=None,
state_cache: Optional[dict] = None,
**execution_kwargs,
) -> Result:
"""Simulate a single quantum script.
This is an internal function that is used by``default.qubit``.
Args:
circuit (QuantumTape): The single circuit to simulate
debugger (_Debugger): The debugger to use
state_cache=None (Optional[dict]): A dictionary mapping the hash of a circuit to
the pre-rotated state. Used to pass the state between forward passes and vjp
calculations.
rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. If None, a random key will be
generated. Only for simulation using JAX.
interface (str): The machine learning interface to create the initial state with
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.
mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements.
``"deferred"`` is ignored. If mid-circuit measurements are found in the circuit,
the device will use ``"tree-traversal"`` if specified and the ``"one-shot"`` method
otherwise. For usage details, please refer to the
:doc:`dynamic quantum circuits page </introduction/dynamic_quantum_circuits>`.
Returns:
tuple(TensorLike): The results of the simulation
Note that this function can return measurements for non-commuting observables simultaneously.
This function assumes that all operations provide matrices.
>>> qs = qml.tape.QuantumScript([qml.RX(1.2, wires=0)], [qml.expval(qml.Z(0)), qml.probs(wires=(0,1))])
>>> simulate(qs)
(0.36235775447667357,
tensor([0.68117888, 0. , 0.31882112, 0. ], requires_grad=True))
"""
prng_key = execution_kwargs.pop("prng_key", None)
circuit = circuit.map_to_standard_wires()
has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if has_mcm:
if execution_kwargs.get("mcm_method", None) == "tree-traversal":
return simulate_tree_mcm(circuit, prng_key=prng_key, **execution_kwargs)
results = []
aux_circ = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=[1],
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None:
# pylint: disable=import-outside-toplevel
import jax
def simulate_partial(k):
return simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, prng_key=k, **execution_kwargs
)
results = jax.vmap(simulate_partial, in_axes=(0,))(keys)
results = tuple(zip(*results))
else:
for i in range(circuit.shots.total_shots):
results.append(
simulate_one_shot_native_mcm(
aux_circ, debugger=debugger, prng_key=keys[i], **execution_kwargs
)
)
return tuple(results)
ops_key, meas_key = jax_random_split(prng_key)
state, is_state_batched = get_final_state(
circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs
)
if state_cache is not None:
state_cache[circuit.hash] = state
return measure_final_state(
circuit, state, is_state_batched, prng_key=meas_key, **execution_kwargs
)
# pylint: disable=too-many-branches,too-many-statements
def simulate_tree_mcm(
circuit: qml.tape.QuantumScript,
debugger=None,
**execution_kwargs,
) -> Result:
"""Simulate a single quantum script with native mid-circuit measurements using the tree-traversal algorithm.
The tree-traversal algorithm recursively explores all combinations of mid-circuit measurement
outcomes using a depth-first approach. The depth-first approach requires ``n_mcm`` copies
of the state vector (``n_mcm + 1`` state vectors in total) and records ``n_mcm`` vectors
of mid-circuit measurement samples. It is generally more efficient than ``one-shot`` because it takes all samples
at a leaf at once and stops exploring more branches when a single shot is allocated to a sub-tree.
Args:
circuit (QuantumTape): The single circuit to simulate
rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. If None, a random key will be
generated. Only for simulation using JAX.
debugger (_Debugger): The debugger to use
interface (str): The machine learning interface to create the initial state with
Returns:
tuple(TensorLike): The results of the simulation
"""
PROBS_TOL = 0.0
interface = execution_kwargs.get("interface", None)
postselect_mode = execution_kwargs.get("postselect_mode", None)
##########################
# shot vector processing #
##########################
if circuit.shots.has_partitioned_shots:
prng_key = execution_kwargs.pop("prng_key", None)
keys = jax_random_split(prng_key, num=circuit.shots.num_copies)
results = []
for k, s in zip(keys, circuit.shots):
aux_circuit = qml.tape.QuantumScript(
circuit.operations,
circuit.measurements,
shots=s,
)
results.append(simulate_tree_mcm(aux_circuit, debugger, prng_key=k, **execution_kwargs))
return tuple(results)
#######################
# main implementation #
#######################
# `var` measurements cannot be aggregated on the fly as they require the global `expval`
# variance_transform replaces `var` measurements with `expval` and `expval**2` measurements
[circuit], variance_post_processing = variance_transform(circuit)
finite_shots = bool(circuit.shots)
##################
# Parse MCM info #
##################
# mcms is the list of all mid-circuit measurement operations
# mcms[d] is the parent MCM (node) of a circuit segment (edge) at depth `d`
# The first element is None because there is no parent MCM at depth 0
mcms = tuple([None] + [op for op in circuit.operations if isinstance(op, MidMeasureMP)])
n_mcms = len(mcms) - 1
# We obtain `measured_mcms_indices`, the list of MCMs which require post-processing:
# either as requested by terminal measurements or post-selection
measured_mcms = find_post_processed_mcms(circuit)
measured_mcms_indices = [i for i, mcm in enumerate(mcms[1:]) if mcm in measured_mcms]
# `mcm_samples` is a register of MCMs. It is necessary to correctly keep track of
# correlated MCM values which may be requested by terminal measurements.
mcm_samples = {
k + 1: qml.math.empty((circuit.shots.total_shots,), dtype=bool) if finite_shots else None
for k in measured_mcms_indices
}
#############################
# Initialize tree-traversal #
#############################
# mcm_current[:d+1] is the active branch at depth `d`
# The first entry is always 0 as the first edge does not stem from an MCM.
# For example, if `d = 2` and `mcm_current = [0, 1, 1, 0]` we are on the 11-branch,
# i.e. the first two MCMs had outcome 1. The last entry isn't meaningful until we are
# at depth `d=3`.
mcm_current = qml.math.zeros(n_mcms + 1, dtype=int)
# `mid_measurements` maps the elements of `mcm_current` to their respective MCMs
# This is used by `get_final_state::apply_operation` for `Conditional` operations
mid_measurements = dict(zip(mcms[1:], mcm_current[1:].tolist()))
# Split circuit into segments
circuits = split_circuit_at_mcms(circuit)
circuits[0] = prepend_state_prep(circuits[0], None, interface, circuit.wires)
terminal_measurements = circuits[-1].measurements if finite_shots else circuit.measurements
# Initialize stacks
cumcounts = [0] * (n_mcms + 1)
stack = TreeTraversalStack(n_mcms + 1)
# The goal is to obtain the measurements of the zero-branch and one-branch
# and to combine them into the final result. Exit the loop once the
# zero-branch and one-branch measurements are available.
depth = 0
while stack.any_is_empty(1):
###########################################
# Combine measurements & step up the tree #
###########################################
# Combine two leaves once measurements are available
if stack.is_full(depth):
# Call `combine_measurements` to count-average measurements
measurement_dicts = get_measurement_dicts(terminal_measurements, stack, depth)
measurements = combine_measurements(
terminal_measurements, measurement_dicts, mcm_samples
)
mcm_current[depth:] = 0 # Reset current branch
stack.prune(depth) # Clear stacks
# Go up one level to explore alternate subtree of the same depth
depth -= 1
if mcm_current[depth] == 1:
stack.results_1[depth] = measurements
mcm_current[depth] = 0
else:
stack.results_0[depth] = measurements
mcm_current[depth] = 1
# Update MCM values
mid_measurements.update(
(k, v) for k, v in zip(mcms[depth:], mcm_current[depth:].tolist())
)
continue
################################################
# Determine whether to execute the active edge #
################################################
# Parse shots for the current branch
if finite_shots:
if stack.counts[depth]:
shots = stack.counts[depth][mcm_current[depth]]
else:
shots = circuits[depth].shots.total_shots
skip_subtree = not bool(shots)
else:
shots = None
skip_subtree = (
stack.probs[depth] is not None
and float(stack.probs[depth][mcm_current[depth]]) <= PROBS_TOL
)
# Update active branch dict
invalid_postselect = (
depth > 0
and mcms[depth].postselect is not None
and mcm_current[depth] != mcms[depth].postselect
)
###########################################
# Obtain measurements for the active edge #
###########################################
# If num_shots is zero or postselecting on the wrong branch, update measurements with an empty tuple
if skip_subtree or invalid_postselect:
# Adjust counts if `invalid_postselect`
if invalid_postselect:
if finite_shots:
# Bump downstream cumulative counts before zeroing-out counts
for d in range(depth + 1, n_mcms + 1):
cumcounts[d] += stack.counts[depth][mcm_current[depth]]
stack.counts[depth][mcm_current[depth]] = 0
else:
stack.probs[depth][mcm_current[depth]] = 0
measurements = tuple()
else:
# If num_shots is non-zero, simulate the current depth circuit segment
if depth == 0:
initial_state = stack.states[0]
else:
initial_state = branch_state(stack.states[depth], mcm_current[depth], mcms[depth])
circtmp = qml.tape.QuantumScript(
circuits[depth].operations,
circuits[depth].measurements,
qml.measurements.shots.Shots(shots),
)
circtmp = prepend_state_prep(circtmp, initial_state, interface, circuit.wires)
state, is_state_batched = get_final_state(
circtmp,
debugger=debugger,
mid_measurements=mid_measurements,
**execution_kwargs,
)
measurements = measure_final_state(circtmp, state, is_state_batched, **execution_kwargs)
#####################################
# Update stack & step down the tree #
#####################################
# If not at a leaf, project on the zero-branch and increase depth by one
if depth < n_mcms and (not skip_subtree and not invalid_postselect):
depth += 1
# Update the active branch samples with `update_mcm_samples`
if finite_shots:
if (
mcms[depth]
and mcms[depth].postselect is not None
and postselect_mode == "fill-shots"
):
samples = mcms[depth].postselect * qml.math.ones_like(measurements)
else:
samples = qml.math.atleast_1d(measurements)
stack.counts[depth] = samples_to_counts(samples)
stack.probs[depth] = counts_to_probs(stack.counts[depth])
else:
stack.probs[depth] = dict(zip([False, True], measurements))
samples = None
# Store a copy of the state-vector to project on the one-branch
stack.states[depth] = state
mcm_samples, cumcounts = update_mcm_samples(samples, mcm_samples, depth, cumcounts)
continue
################################################
# Update terminal measurements & step sideways #
################################################
if not skip_subtree and not invalid_postselect:
measurements = insert_mcms(circuit, measurements, mid_measurements)
# If at a zero-branch leaf, update measurements and switch to the one-branch
if mcm_current[depth] == 0:
stack.results_0[depth] = measurements
mcm_current[depth] = True
mid_measurements[mcms[depth]] = True
continue
# If at a one-branch leaf, update measurements
stack.results_1[depth] = measurements
##################################################
# Finalize terminal measurements post-processing #
##################################################
measurement_dicts = get_measurement_dicts(terminal_measurements, stack, depth)
if finite_shots:
terminal_measurements = circuit.measurements
mcm_samples = {mcms[i]: v for i, v in mcm_samples.items()}
mcm_samples = prune_mcm_samples(mcm_samples)
results = combine_measurements(terminal_measurements, measurement_dicts, mcm_samples)
return variance_post_processing((results,))
def split_circuit_at_mcms(circuit):
"""Return a list of circuits segments (one for each mid-circuit measurement in the
original circuit) where the terminal measurements probe the MCM statistics. Only
the last segment retains the original terminal measurements.
Args:
circuit (QuantumTape): The circuit to simulate
Returns:
Sequence[QuantumTape]: Circuit segments.
"""
mcm_gen = ((i, op) for i, op in enumerate(circuit) if isinstance(op, MidMeasureMP))
circuits = []
first = 0
for last, op in mcm_gen:
new_operations = circuit.operations[first:last]
new_measurements = (
[qml.sample(wires=op.wires)] if circuit.shots else [qml.probs(wires=op.wires)]
)
circuits.append(
qml.tape.QuantumScript(new_operations, new_measurements, shots=circuit.shots)
)
first = last + 1
last_circuit_operations = circuit.operations[first:]
last_circuit_measurements = []
for m in circuit.measurements:
if m.mv is None:
last_circuit_measurements.append(m)
circuits.append(
qml.tape.QuantumScript(
last_circuit_operations, last_circuit_measurements, shots=circuit.shots
)
)
return circuits
def prepend_state_prep(circuit, state, interface, wires):
"""Prepend a ``StatePrep`` operation with the prescribed ``wires`` to the circuit.
``get_final_state`` executes a circuit on a subset of wires found in operations
or measurements. This function makes sure that an initial state with the correct size is created
on the first invocation of ``simulate_tree_mcm``. ``wires`` should be the wires attribute
of the original circuit (which included all wires)."""
if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase):
return circuit
state = (
create_initial_state(wires, None, like=INTERFACE_TO_LIKE[interface])
if state is None
else state
)
return qml.tape.QuantumScript(
[qml.StatePrep(qml.math.ravel(state), wires=wires, validate_norm=False)]
+ circuit.operations,
circuit.measurements,
shots=circuit.shots,
)
def insert_mcms(circuit, results, mid_measurements):
"""Inserts terminal measurements of MCMs if the circuit is evaluated in analytic mode."""
if circuit.shots or all(m.mv is None for m in circuit.measurements):
return results
results = list(results)
new_results = []
mid_measurements = {k: qml.math.array([[v]]) for k, v in mid_measurements.items()}
for m in circuit.measurements:
if m.mv is None:
new_results.append(results.pop(0))
else:
new_results.append(gather_mcm(m, mid_measurements, qml.math.array([[True]])))
return new_results
def get_measurement_dicts(measurements, stack, depth):
"""Combine a probs dictionary and two tuples of measurements into a
tuple of dictionaries storing the probs and measurements of both branches."""
# We use `circuits[-1].measurements` since it contains the
# target measurements (this is the only tape segment with
# unmodified measurements)
probs, results_0, results_1 = stack.probs[depth], stack.results_0[depth], stack.results_1[depth]
measurement_dicts = [{} for _ in measurements]
# Special treatment for single measurements
single_measurement = len(measurements) == 1
# Store each measurement in a dictionary `{branch: (prob, measure)}`
for branch, prob in probs.items():
meas = results_0 if branch == 0 else results_1
if single_measurement:
meas = [meas]
for i, m in enumerate(meas):
measurement_dicts[i][branch] = (prob, m)
return measurement_dicts
def branch_state(state, branch, mcm):
"""Collapse the state on a given branch.
Args:
state (TensorLike): The initial state
branch (int): The branch on which the state is collapsed
mcm (MidMeasureMP): Mid-circuit measurement object used to obtain the wires and ``reset``
Returns:
TensorLike: The collapsed state
"""
if isinstance(state, np.ndarray):
# FASTER
state = state.copy()
slices = [slice(None)] * qml.math.ndim(state)
axis = mcm.wires.toarray()[0]
slices[axis] = int(not branch)
state[tuple(slices)] = 0.0
state /= qml.math.norm(state)
else:
# SLOWER
state = apply_operation(qml.Projector([branch], mcm.wires), state)
state = state / qml.math.norm(state)
if mcm.reset and branch == 1:
state = apply_operation(qml.PauliX(mcm.wires), state)
return state
def samples_to_counts(samples):
"""Converts samples to counts.
This function forces integer keys and values which are required by ``simulate_tree_mcm``.
"""
counts_1 = int(qml.math.count_nonzero(samples))
return {0: samples.size - counts_1, 1: counts_1}
def counts_to_probs(counts):
"""Converts counts to probs."""
probs = qml.math.array(list(counts.values()))
probs = probs / qml.math.sum(probs)
return dict(zip(counts.keys(), probs))
def prune_mcm_samples(mcm_samples):
"""Removes invalid mid-measurement samples.
Post-selection on a given mid-circuit measurement leads to ignoring certain branches
of the tree and samples. The corresponding samples in all other mid-circuit measurement
must be deleted accordingly. We need to find which samples are
corresponding to the current branch by looking at all parent nodes.
"""
if not mcm_samples or all(v is None for v in mcm_samples.values()):
return mcm_samples
mask = qml.math.ones(list(mcm_samples.values())[0].shape, dtype=bool)
for mcm, s in mcm_samples.items():
if mcm.postselect is None:
continue
mask = qml.math.logical_and(mask, s == mcm.postselect)
return {k: v[mask] for k, v in mcm_samples.items()}
def update_mcm_samples(samples, mcm_samples, depth, cumcounts):
"""Updates the depth-th mid-measurement samples.
To illustrate how the function works, let's take an example. Suppose there are
``2**20`` shots in total and the computation is midway through the circuit at the
7th MCM, the active branch is ``[0,1,1,0,0,1]``, and at each MCM everything happened to
split the counts 50/50, so there are ``2**14`` samples to update.
These samples are correlated with the parent
branches, so where do they go? They must update the ``2**14`` elements whose parent
sequence corresponds to ``[0,1,1,0,0,1]``. ``cumcounts`` is used for this job and
increased by the size of ``samples`` each time this function is called.
"""
if depth not in mcm_samples or mcm_samples[depth] is None:
return mcm_samples, cumcounts
count1 = qml.math.sum(samples)
count0 = samples.size - count1
mcm_samples[depth][cumcounts[depth] : cumcounts[depth] + count0] = 0
cumcounts[depth] += count0
mcm_samples[depth][cumcounts[depth] : cumcounts[depth] + count1] = 1
cumcounts[depth] += count1
return mcm_samples, cumcounts
@qml.transform
def variance_transform(circuit):
"""Replace variance measurements by expectation value measurements of both the observable and the observable square.
This is necessary since computing the variance requires the global expectation value which is not available from measurements on subtrees.
"""
skip_transform = not any(isinstance(m, VarianceMP) for m in circuit.measurements)
if skip_transform:
return (circuit,), lambda x: x[0]
def variance_post_processing(results):
"""Compute the global variance from expectation value measurements of both the observable and the observable square."""
new_results = list(results[0])
offset = len(circuit.measurements)
for i, (r, m) in enumerate(zip(new_results, circuit.measurements)):
if isinstance(m, VarianceMP):
expval = new_results.pop(offset)
new_results[i] = r - expval**2
return new_results[0] if len(new_results) == 1 else new_results
new_measurements = []
extra_measurements = []
for m in circuit.measurements:
if isinstance(m, VarianceMP):
obs2 = m.mv * m.mv if m.mv is not None else m.obs @ m.obs
new_measurements.append(ExpectationMP(obs=obs2))
extra_measurements.append(ExpectationMP(obs=m.mv if m.mv is not None else m.obs))
else:
new_measurements.append(m)
new_measurements.extend(extra_measurements)
return (
(
qml.tape.QuantumScript(
circuit.operations,
new_measurements,
shots=circuit.shots,
),
),
variance_post_processing,
)
def measurement_with_no_shots(measurement):
"""Returns a NaN scalar or array of the correct size when executing an all-invalid-shot circuit."""
if isinstance(measurement, ProbabilityMP):
return np.nan * qml.math.ones(2 ** len(measurement.wires))
return np.nan
def combine_measurements(terminal_measurements, results, mcm_samples):
"""Returns combined measurement values of various types."""
empty_mcm_samples = False
need_mcm_samples = not all(v is None for v in mcm_samples.values())
need_mcm_samples = need_mcm_samples and any(
circ_meas.mv is not None for circ_meas in terminal_measurements
)
if need_mcm_samples:
empty_mcm_samples = len(next(iter(mcm_samples.values()))) == 0
if empty_mcm_samples and any(len(m) != 0 for m in mcm_samples.values()): # pragma: no cover
raise ValueError("mcm_samples have inconsistent shapes.")
final_measurements = []
for circ_meas in terminal_measurements:
if need_mcm_samples and circ_meas.mv is not None and empty_mcm_samples:
comb_meas = measurement_with_no_shots(circ_meas)
elif need_mcm_samples and circ_meas.mv is not None:
mcm_samples = {k: v.reshape((-1, 1)) for k, v in mcm_samples.items()}
is_valid = qml.math.ones(list(mcm_samples.values())[0].shape[0], dtype=bool)
comb_meas = gather_mcm(circ_meas, mcm_samples, is_valid)
elif not results or not results[0]:
if len(results) > 0:
_ = results.pop(0)
comb_meas = measurement_with_no_shots(circ_meas)
else:
comb_meas = combine_measurements_core(circ_meas, results.pop(0))
if isinstance(circ_meas, SampleMP):
comb_meas = qml.math.squeeze(comb_meas)
final_measurements.append(comb_meas)
return final_measurements[0] if len(final_measurements) == 1 else tuple(final_measurements)
@singledispatch
def combine_measurements_core(original_measurement, measures): # pylint: disable=unused-argument
"""Returns the combined measurement value of a given type."""
raise TypeError(
f"Native mid-circuit measurement mode does not support {type(original_measurement).__name__}"
)
@combine_measurements_core.register
def _(original_measurement: CountsMP, measures): # pylint: disable=unused-argument
"""The counts are accumulated using a ``Counter`` object."""
keys = list(measures.keys())
new_counts = Counter()
for k in keys:
if not measures[k][0]:
continue
new_counts.update(measures[k][1])
return dict(sorted(new_counts.items()))
@combine_measurements_core.register
def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused-argument
"""The expectation value of two branches is a weighted sum of expectation values."""
cum_value = 0
total_counts = 0
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts
@combine_measurements_core.register
def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused-argument
"""The combined probability of two branches is a weighted sum of the probabilities. Note the implementation is the same as for ``ExpectationMP``."""
cum_value = 0
total_counts = 0
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts
@combine_measurements_core.register
def _(original_measurement: SampleMP, measures): # pylint: disable=unused-argument
"""The combined samples of two branches is obtained by concatenating the sample of each branch."""
new_sample = tuple(
qml.math.atleast_1d(m[1]) for m in measures.values() if m[0] and not m[1] is tuple()
)
return qml.math.squeeze(qml.math.concatenate(new_sample))
@debug_logger
def simulate_one_shot_native_mcm(
circuit: qml.tape.QuantumScript, debugger=None, **execution_kwargs
) -> Result:
"""Simulate a single shot of a single quantum script with native mid-circuit measurements.
Assumes that the circuit has been transformed by `dynamic_one_shot`.
Args:
circuit (QuantumTape): The single circuit to simulate
debugger (_Debugger): The debugger to use
rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
the key to the JAX pseudo random number generator. If None, a random key will be
generated. Only for simulation using JAX.
interface (str): The machine learning interface to create the initial state with
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.
Returns:
Result: The results of the simulation
"""
prng_key = execution_kwargs.pop("prng_key", None)
ops_key, meas_key = jax_random_split(prng_key)
mid_measurements = {}
state, is_state_batched = get_final_state(
circuit,
debugger=debugger,
mid_measurements=mid_measurements,
prng_key=ops_key,
**execution_kwargs,
)
return measure_final_state(
circuit,
state,
is_state_batched,
prng_key=meas_key,
mid_measurements=mid_measurements,
**execution_kwargs,
)
_modules/pennylane/devices/qubit/simulate
Download Python script
Download Notebook
View on GitHub