Source code for pennylane.devices.qubit_mixed.apply_operation
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply operations to a qubit mixed state."""
# pylint: disable=unused-argument
from functools import singledispatch
from string import ascii_letters as alphabet
from typing import Union
import numpy as np
import pennylane as qml
from pennylane import math
from pennylane.devices.qubit.apply_operation import _apply_grover_without_matrix
from pennylane.operation import Channel
from pennylane.ops.qubit.attributes import diagonal_in_z_basis
from .einsum_manpulation import get_einsum_mapping
# ASCII letters stored as an array so that a list of axis positions can be fancy-indexed
# into einsum subscript characters (this is how ``apply_diagonal_unitary`` uses it).
alphabet_array = math.array(list(alphabet))
# Cut-off shared by several kernels in this module: the phase-type kernels compare it with
# the number of axes of the state, while the symmetric-real-op and Grover kernels compare
# it with the number of wires of the operator.
TENSORDOT_STATE_NDIM_PERF_THRESHOLD = 9
def _get_slice(index, axis, num_axes):
"""Allows slicing along an arbitrary axis of an array or tensor.
Args:
index (int): the index to access
axis (int): the axis to slice into
num_axes (int): total number of axes
Returns:
tuple[slice or int]: a tuple that can be used to slice into an array or tensor
**Example:**
Accessing the 2 index along axis 1 of a 3-axis array:
>>> sl = _get_slice(2, 1, 3)
>>> sl
(slice(None, None, None), 2, slice(None, None, None))
>>> a = np.arange(27).reshape((3, 3, 3))
>>> a[sl]
array([[ 6, 7, 8],
[15, 16, 17],
[24, 25, 26]])
"""
idx = [slice(None)] * num_axes
idx[axis] = index
return tuple(idx)
def _phase_shift(state, axis, phase_factor=-1, debugger=None, **_):
    """Multiply the index-1 slice of ``state`` along ``axis`` by ``phase_factor``.

    The slice at index 0 along ``axis`` is kept as is, the slice at index 1 is scaled, and
    the two are stacked back together along the same axis. Only the single axis given is
    touched: to conjugate a density matrix with a diagonal single-qubit gate, callers in
    this module invoke the function twice, once on the row axis of the target wire and
    once on the matching column axis (with the conjugated factor).

    Args:
        state (TensorLike): the state tensor; ``axis`` must address an axis of size 2
        axis (int): the axis whose index-1 slice is scaled
        phase_factor (complex): the factor applied to the index-1 slice. Defaults to ``-1``
            (Pauli-Z); ``1j`` corresponds to S and ``exp(1j * pi / 4)`` to T.
        debugger: unused, accepted for signature compatibility
        **_: additional unused keyword arguments

    Returns:
        TensorLike: a tensor of the same shape as ``state``, assembled with ``math.stack``
        (the result is not written into ``state``).

    **Example:**

    >>> plus_state = np.array([[0.5, 0.5],
    ...                        [0.5, 0.5]])
    >>> rows_done = _phase_shift(plus_state, axis=0)
    >>> print(rows_done)
    [[ 0.5  0.5]
     [-0.5 -0.5]]
    >>> print(_phase_shift(rows_done, axis=1))
    [[ 0.5 -0.5]
     [-0.5  0.5]]
    """
    ndim = math.ndim(state)
    untouched = state[_get_slice(0, axis, ndim)]
    scaled = math.multiply(state[_get_slice(1, axis, ndim)], phase_factor)
    return math.stack([untouched, scaled], axis=axis)
def _get_dagger_symmetric_real_op(op, num_wires):
    """Return ``op`` with every wire label shifted up by ``num_wires``.

    For a real, symmetric operation this shifted copy plays the role of the conjugate
    transpose acting on the column axes of a density matrix, so the helper should only be
    used for such operations.
    """
    shifted_labels = {wire: wire + num_wires for wire in op.wires}
    return qml.map_wires(op, shifted_labels)
def _get_num_wires(state, is_state_batched):
    """Infer the number of wires of a density matrix from the size of ``state``.

    Once the leading batch dimension (if any) is divided out, the state is taken to hold
    ``2 ** (2 * num_wires)`` entries, hence ``num_wires = log2(entries) / 2``.

    Args:
        state (TensorLike): the density-matrix state
        is_state_batched (bool): whether axis 0 of ``state`` is a batch dimension

    Returns:
        int: the inferred number of wires
    """
    dims = qml.math.shape(state)
    entries_per_sample = math.prod(dims) // (dims[0] if is_state_batched else 1)
    return int(math.log2(entries_per_sample) / 2)
def _conjugate_state_with(k, state, axes_left, axes_right):
"""Perform the double tensor product k @ state @ k.conj(), with given, single matrix k.
The `axes_left` and `axes_right` arguments are taken from the ambient variable space
and `axes_right` is assumed to incorporate the tensor product and the transposition
of k.conj() simultaneously."""
return math.tensordot(
math.tensordot(k, state, axes_left),
math.conj(k),
axes_right,
)
def apply_operation_einsum(
    op: qml.operation.Operator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    r"""Apply a quantum channel, given by its Kraus operators, to subsystems of the state
    with a single ``einsum`` call. A unitary gate is treated as a channel with exactly one
    Kraus operator.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state
    """
    n_op_wires = len(op.wires)
    qubit_axes = [2] * n_op_wires * 2

    if isinstance(op, Channel):
        kraus = op.kraus_matrices()
        kraus_shape = [len(kraus)] + qubit_axes
    else:
        kraus = [math.cast_like(op.matrix(), state)]
        kraus_shape = [1] + qubit_axes
        # A broadcasted (parameter-batched) gate gets an extra leading axis.
        dim = 2**n_op_wires
        batch_size = math.get_batch_size(op.matrix() + 0j, (dim, dim), dim**2)
        if batch_size is not None:
            kraus_shape = [batch_size] + kraus_shape

    kraus = math.stack(kraus)
    # Swap the two matrix axes first and conjugate afterwards: torch raises an error if
    # ``math.conj`` is applied before ``math.stack``.
    kraus_dagger = math.conj(math.stack(math.moveaxis(kraus, source=-1, destination=-2)))

    kraus = math.cast(math.reshape(kraus, kraus_shape), complex)
    kraus_dagger = math.reshape(kraus_dagger, kraus_shape)

    # The subscript string is produced by the helper; see its definition for the layout.
    subscripts = get_einsum_mapping(op, state, is_state_batched)
    return math.einsum(subscripts, kraus, state, kraus_dagger)
def apply_operation_tensordot(
    op: qml.operation.Operator, state, is_state_batched: bool = False, debugger=None, **_
):
    """Apply ``Operator`` to ``state`` using ``math.tensordot``. This is more efficient at
    higher qubit numbers.

    A non-channel operator is handled as a Kraus list of length one. Batched operators are
    not treated here; ``_apply_operation_default`` only dispatches to this function when
    ``op.batch_size`` is ``None``.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state
    """
    op_wires = op.wires
    n_op_wires = len(op_wires)
    n_state_wires = _get_num_wires(state, is_state_batched)

    # Shape of one Kraus operator in tensor form; the length of the Kraus list is not
    # part of this shape.
    tensor_shape = [2] * n_op_wires * 2
    if isinstance(op, Channel):
        kraus = [
            math.cast_like(math.reshape(k, tensor_shape), state) for k in op.kraus_matrices()
        ]
    else:
        kraus = [math.reshape(op.matrix() + 0j, tensor_shape)]
    kraus = math.array(kraus)  # Necessary for Jax

    # Small trick: for the contraction on the right-hand side the column indices of the
    # channel are contracted directly instead of its rows, which saves transposing the
    # Kraus operators.
    row_axes = [w + is_state_batched for w in op_wires.tolist()]
    col_axes = [axis + n_state_wires for axis in row_axes]
    kraus_col_axes = list(range(-n_op_wires, 0))
    trailing_col_axes = [-n_state_wires + w for w in op_wires]

    contracted = _conjugate_state_with(
        kraus,
        state,
        [kraus_col_axes, row_axes],
        [[0] + trailing_col_axes, [0] + kraus_col_axes],
    )

    # Move the freshly produced leading (row) and trailing (column) axes back to the
    # positions of the wires the operator acted on.
    produced_axes = list(range(n_op_wires)) + list(range(-n_op_wires, 0))
    return math.moveaxis(contracted, produced_axes, row_axes + col_axes)
[docs]
@singledispatch
def apply_operation(
    op: qml.operation.Operator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply an operation to a given state.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not.
        debugger (_Debugger): The debugger to use.

    Keyword Arguments:
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``.
            This is the key to the JAX pseudo random number generator. Only for simulation
            using JAX. If None, a ``numpy.random.default_rng`` will be used for sampling.
        tape_shots (Shots): The shots object of the tape.

    Returns:
        ndarray: The output state.

    .. warning::

        ``apply_operation`` is an internal function, and thus subject to change without a
        deprecation cycle.

    .. warning::

        ``apply_operation`` applies no validation to its inputs.

        This function assumes that the wires of the operator correspond to indices
        of the state. See :func:`~.map_wires` to convert operations to integer wire labels.

        The shape of the state should be ``[2] * (num_wires * 2)`` (the original tensor
        form) or ``[2**num_wires, 2**num_wires]`` (the expanded matrix form), where ``2``
        is the dimension of the system.

    This is a ``functools.singledispatch`` function, so additional specialized kernels
    for specific operations can be registered like:

    .. code-block:: python

        @apply_operation.register
        def _(op: type_op, state, is_state_batched=False, **kwargs):
            # custom op application method here

    **Example:**

    >>> state = np.zeros((2, 2, 2, 2))
    >>> state[0][0][0][0] = 1
    >>> state
    array([[[[1., 0.],
             [0., 0.]],
            [[0., 0.],
             [0., 0.]]],
           [[[0., 0.],
             [0., 0.]],
            [[0., 0.],
             [0., 0.]]]])
    >>> apply_operation(qml.PauliX(0), state)
    array([[[[0., 0.],
             [0., 0.]],
            [[0., 0.],
             [0., 0.]]],
           [[[0., 0.],
             [1., 0.]],
            [[0., 0.],
             [0., 0.]]]])
    """
    # Operators without a dedicated registration fall through to the generic kernel chooser.
    return _apply_operation_default(op, state, is_state_batched, debugger, **_)
def _apply_operation_default(op, state, is_state_batched, debugger, **_):
    """The default behaviour of ``apply_operation``.

    It is reached through the standard dispatch of ``apply_operation`` and is also called
    conditionally from other, specialised dispatches. The kernel is chosen as follows:

    * operators listed in ``diagonal_in_z_basis`` use ``apply_diagonal_unitary``;
    * non-batched operators on more than 7 wires, or on more than 2 wires when the state
      interface is ``"autograd"`` or ``"numpy"``, use ``apply_operation_tensordot``;
    * everything else uses ``apply_operation_einsum``.
    """
    if op in diagonal_in_z_basis:
        return apply_diagonal_unitary(op, state, is_state_batched, debugger, **_)

    n_op_wires = len(op.wires)
    interface = math.get_interface(state)

    # Batched (non-channel) operators are never routed to the tensordot kernel.
    if op.batch_size is None:
        prefers_tensordot = n_op_wires > 2 and interface in {"autograd", "numpy"}
        if prefers_tensordot or n_op_wires > 7:
            return apply_operation_tensordot(op, state, is_state_batched, debugger, **_)

    return apply_operation_einsum(op, state, is_state_batched, debugger, **_)
@apply_operation.register
def apply_identity(op: qml.Identity, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.Identity` operation, which leaves the state untouched: the input
    ``state`` object itself is handed back."""
    return state
@apply_operation.register
def apply_global_phase(
    op: qml.GlobalPhase, state, is_state_batched: bool = False, debugger=None, **_
):
    """Applies a :class:`~.GlobalPhase` operation to a density matrix, which is a no-op.

    A global phase is a scalar factor on the basis state vectors. In the density matrix
    it appears once on each side together with its complex conjugate, so the two factors
    cancel and the input ``state`` is returned unchanged.
    """
    return state
@apply_operation.register
def apply_paulix(op: qml.X, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.PauliX` operation as a bit flip on both sides of the density
    matrix.

    Rolling an axis of size 2 by one position swaps its two entries, so the Pauli-X matrix
    never has to be built: the row axis of the target wire is rolled first, then the
    matching column axis.
    """
    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)
    row_axis = op.wires[0] + is_state_batched
    rows_flipped = math.roll(state, 1, row_axis)
    return math.roll(rows_flipped, 1, row_axis + n_wires)
@apply_operation.register
def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.PauliZ` operation by negating the index-1 slices of the target
    wire on both sides of the density matrix.

    For tensorflow states with at least ``TENSORDOT_STATE_NDIM_PERF_THRESHOLD`` axes the
    tensordot kernel is used instead.
    """
    if (
        math.ndim(state) >= TENSORDOT_STATE_NDIM_PERF_THRESHOLD
        and math.get_interface(state) == "tensorflow"
    ):
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)
    row_axis = op.wires[0] + is_state_batched

    # Left (row) side first, then the right (column) side; the factor -1 is real, so the
    # default ``phase_factor`` serves both sides.
    state = _phase_shift(state, row_axis)
    return _phase_shift(state, row_axis + n_wires)
@apply_operation.register
def apply_T(op: qml.T, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.T` operation by scaling the index-1 slices of the target wire:
    by ``exp(1j * pi / 4)`` on the row side and by its conjugate on the column side.

    For tensorflow states with at least ``TENSORDOT_STATE_NDIM_PERF_THRESHOLD`` axes the
    tensordot kernel is used instead.
    """
    if (
        math.ndim(state) >= TENSORDOT_STATE_NDIM_PERF_THRESHOLD
        and math.get_interface(state) == "tensorflow"
    ):
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)
    row_axis = op.wires[0] + is_state_batched

    # Left (row) side carries the phase, right (column) side its complex conjugate.
    state = _phase_shift(state, row_axis, phase_factor=math.exp(0.25j * np.pi))
    return _phase_shift(state, row_axis + n_wires, phase_factor=math.exp(-0.25j * np.pi))
@apply_operation.register
def apply_S(op: qml.S, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.S` operation by scaling the index-1 slices of the target wire:
    by ``1j`` on the row side and by ``-1j`` on the column side.

    For tensorflow states with at least ``TENSORDOT_STATE_NDIM_PERF_THRESHOLD`` axes the
    tensordot kernel is used instead.
    """
    if (
        math.ndim(state) >= TENSORDOT_STATE_NDIM_PERF_THRESHOLD
        and math.get_interface(state) == "tensorflow"
    ):
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)
    row_axis = op.wires[0] + is_state_batched

    # Left (row) side carries the phase, right (column) side its complex conjugate.
    state = _phase_shift(state, row_axis, phase_factor=1.0j)
    return _phase_shift(state, row_axis + n_wires, phase_factor=-1.0j)
@apply_operation.register
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.PhaseShift` operation by scaling the index-1 slices of the target
    wire: by ``exp(1j * phi)`` on the row side and by ``exp(-1j * phi)`` on the column side.

    For tensorflow states with at least ``TENSORDOT_STATE_NDIM_PERF_THRESHOLD`` axes the
    tensordot kernel is used instead.

    Args:
        op (qml.PhaseShift): the operation; its first parameter may be broadcasted
        state (TensorLike): the density-matrix state
        is_state_batched (bool): whether axis 0 of ``state`` is a batch dimension
        debugger: unused, accepted for signature compatibility

    Returns:
        TensorLike: the transformed state. When the operation is broadcasted but the input
        state is not, the result gains a leading batch axis.
    """
    n_dim = math.ndim(state)
    if n_dim >= TENSORDOT_STATE_NDIM_PERF_THRESHOLD and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    # Offset between the row axis of a wire and its column axis.
    num_wires = _get_num_wires(state, is_state_batched)

    # --- Left (row) side ---
    axis = op.wires[0] + is_state_batched
    params = math.cast(op.parameters[0], dtype=complex)
    state0 = state[_get_slice(0, axis, n_dim)]
    state1 = state[_get_slice(1, axis, n_dim)]

    if op.batch_size is not None and len(params) > 1:
        interface = math.get_interface(state)
        if interface == "torch":
            params = math.array(params, like=interface)
        if is_state_batched:
            # If both op and state are batched, they have to have the same batch size.
            # The slices have ``n_dim - 1`` axes, the first being the batch axis.
            params = math.reshape(params, (-1,) + (1,) * (n_dim - 2))
        else:
            # Op is batched, state is not: give the state a leading batch axis. Adding
            # zeros shaped like ``params`` broadcasts the untouched slice to full size.
            params = math.reshape(params, (-1,) + (1,) * (n_dim - 1))
            state0 = math.expand_dims(state0, 0) + math.zeros_like(params)
            state1 = math.expand_dims(state1, 0)
            # From here on the state is batched, so every axis index moves up by one.
            is_state_batched = True
            axis = axis + 1
            n_dim = n_dim + 1

    state1 = math.multiply(math.cast(state1, dtype=complex), math.exp(1.0j * params))
    state = math.stack([state0, state1], axis=axis)

    # --- Right (column, conjugate) side ---
    axis += num_wires
    state0 = state[_get_slice(0, axis, n_dim)]
    state1 = state[_get_slice(1, axis, n_dim)]
    # No expansion needed here: the left side already produced the batched layout.
    state1 = math.multiply(math.cast(state1, dtype=complex), math.exp(-1.0j * params))
    return math.stack([state0, state1], axis=axis)
# Operator classes that are dispatched to ``apply_symmetric_real_op`` (they are registered
# in a loop below that function's definition).
# TODO: investigate whether there are other operations satisfying the same condition
# (real and symmetric) that are still missing from this tuple.
SYMMETRIC_REAL_OPS = (
    qml.CNOT,
    qml.MultiControlledX,
    qml.Toffoli,
    qml.SWAP,
    qml.CSWAP,
    qml.CZ,
    qml.CH,
)
def apply_symmetric_real_op(
    op: Union[
        qml.CNOT,
        qml.MultiControlledX,
        qml.Toffoli,
        qml.SWAP,
        qml.CSWAP,
        qml.CZ,
        qml.CH,
    ],
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    r"""Apply a real, symmetric operator (e.g. CNOT and related controlled-X variants) to a
    density matrix state.

    CZ, CH, CNOT, CSWAP, SWAP, Toffoli and general MultiControlledX operations share this
    kernel because they are all real and symmetric. Operations acting on fewer than
    ``TENSORDOT_STATE_NDIM_PERF_THRESHOLD`` wires are forwarded to the default kernel
    chooser. Larger ones are applied with the state-vector kernel
    ``qml.devices.qubit.apply_operation`` twice: once with the operator itself (row side)
    and once with a copy whose wires are shifted by ``num_wires`` (column side), which for
    a real, symmetric operator implements the adjoint.

    Args:
        op (.Operation): CZ, CH, CNOT, CSWAP, SWAP, Toffoli, and general MultiControlledX
            operation
        state (tensor_like): The density matrix state to apply the operation to
        is_state_batched (bool): Whether the state has a batch dimension
        debugger (optional): A debugger instance, passed on to the kernels that are called

    Returns:
        tensor_like: The transformed density matrix state

    Note:
        This is not a final version. Two possible improvements are:

        1. More existing real, symmetric ops to include in this dispatch
        2. A more general approach to handle other types of ops but following
           similar logic as in this function.
    """
    if len(op.wires) < TENSORDOT_STATE_NDIM_PERF_THRESHOLD:
        return _apply_operation_default(op, state, is_state_batched, debugger)

    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)

    # Row side: the operator as given.
    state = qml.devices.qubit.apply_operation(op, state, is_state_batched, debugger)
    # Column side: the same operator on the shifted wires.
    shifted_op = _get_dagger_symmetric_real_op(op, n_wires)
    return qml.devices.qubit.apply_operation(shifted_op, state, is_state_batched, debugger)
# NOTE: the kernel is registered in a loop, rather than by stacking one ``register`` call
# per op class on the function, so that the documentation of ``apply_operation`` renders
# nicely. Registering each class directly causes severe rendering issues, as discussed in
# https://github.com/PennyLaneAI/pennylane/pull/6684#pullrequestreview-2565634328
for op_class in SYMMETRIC_REAL_OPS:
    apply_operation.register(op_class)(apply_symmetric_real_op)
@apply_operation.register
def apply_grover(
    op: qml.GroverOperator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply GroverOperator either via a custom matrix-free method (more than 8 operation
    wires) or via standard matrix based methods (else).

    In the matrix-free case the state-vector routine is run twice: on the operator wires
    (row side) and on the same wires shifted by the number of state wires (column side).
    """
    op_wires = op.wires
    if len(op_wires) < TENSORDOT_STATE_NDIM_PERF_THRESHOLD:
        return _apply_operation_default(op, state, is_state_batched, debugger)

    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)
    column_wires = [w + n_wires for w in op_wires]

    state = _apply_grover_without_matrix(state, op_wires, is_state_batched)
    return _apply_grover_without_matrix(state, column_wires, is_state_batched)
def apply_diagonal_unitary(op, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply an operator that is diagonal in the computational basis using its eigenvalues.

    Instead of contracting full matrices, the eigenvalues of ``op`` are reshaped to one
    axis of size 2 per operator wire and multiplied element-wise onto the row axes of the
    state, while their complex conjugates are multiplied onto the matching column axes
    (``lambda_a * rho_ab * conj(lambda_b)``).

    Args:
        op (Operator): the operator to apply; ``op.eigvals()`` supplies the diagonal
        state (TensorLike): the density-matrix state
        is_state_batched (bool, optional): whether axis 0 of ``state`` is a batch
            dimension. Defaults to False.
        debugger (optional): unused, accepted for signature compatibility. Defaults to None.

    Returns:
        TensorLike: the transformed state, with the same axis layout as ``state``
    """
    op_wires = op.wires
    n_wires = int((len(math.shape(state)) - is_state_batched) / 2)

    diagonal = math.stack(op.eigvals())
    diagonal = math.reshape(diagonal, [2] * len(op_wires))
    diagonal = math.cast_like(diagonal, state)

    # One subscript letter per axis of the state (batch axis included when present).
    state_subscripts = alphabet[: 2 * n_wires + is_state_batched]

    row_axes = [w + is_state_batched for w in op_wires.tolist()]
    col_axes = [axis + n_wires for axis in row_axes]
    row_subscripts = "".join(alphabet_array[row_axes].tolist())
    col_subscripts = "".join(alphabet_array[col_axes].tolist())

    # The output subscripts equal the state subscripts, so nothing is summed over.
    subscripts = f"{row_subscripts},{state_subscripts},{col_subscripts}->{state_subscripts}"
    return math.einsum(subscripts, diagonal, state, math.conj(diagonal))
@apply_operation.register
def apply_snapshot(
    op: qml.Snapshot, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
    """Take a snapshot of the mixed state

    Nothing is recorded unless an active debugger is supplied. Otherwise the snapshot's
    measurement is evaluated on the current state and stored in ``debugger.snapshots``
    under the snapshot's tag or, when there is no tag, under the current number of stored
    snapshots.

    Args:
        op (qml.Snapshot): the snapshot operation
        state (array): current quantum state
        is_state_batched (bool): whether the state is batched
        debugger: the debugger instance for storing snapshots

    Keyword Arguments:
        tape_shots: shots used when the snapshot's ``shots`` hyperparameter is ``"workflow"``
        rng: forwarded to ``measure_with_samples``
        prng_key: forwarded to ``measure_with_samples``

    Returns:
        array: the unchanged quantum state
    """
    if not (debugger and debugger.active):
        return state

    measurement = op.hyperparameters.get("measurement")
    if op.hyperparameters["shots"] == "workflow":
        shots = execution_kwargs.get("tape_shots")
    else:
        shots = op.hyperparameters["shots"]

    # State snapshots, and any snapshot without shots, are evaluated analytically.
    evaluate_analytically = isinstance(measurement, qml.measurements.StateMP) or not shots
    if evaluate_analytically:
        snapshot = qml.devices.qubit_mixed.measure(measurement, state, is_state_batched)
    else:
        sampled = qml.devices.qubit_mixed.measure_with_samples(
            [measurement],
            state,
            shots,
            is_state_batched,
            execution_kwargs.get("rng"),
            execution_kwargs.get("prng_key"),
        )
        snapshot = sampled[0]

    # Store snapshot with optional tag
    storage_key = op.tag if op.tag else len(debugger.snapshots)
    debugger.snapshots[storage_key] = snapshot
    return state
@apply_operation.register
def apply_density_matrix(
    op: qml.QubitDensityMatrix,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **execution_kwargs,
):
    """
    Applies a QubitDensityMatrix operation by initializing or replacing
    the quantum state with the provided density matrix.

    - If the QubitDensityMatrix covers as many wires as the state has, the provided density
      matrix (broadcast over the batch when the state is batched) is reshaped to the shape
      of the state and returned.
    - If only a subset of the wires is covered, we:

      1. Partial trace out those wires from the current state to get the density matrix of
         the complement wires.
      2. Take the Kronecker product of the complement density matrix and the provided
         density matrix.
      3. Reshape to the tensor form and move the wires back to their original order.

    No validation of the provided density matrix is performed here.

    Args:
        op (qml.QubitDensityMatrix): The QubitDensityMatrix operation.
        state (array-like): The current quantum state.
        is_state_batched (bool): Whether the state is batched.
        debugger: unused, accepted for signature compatibility.
        **execution_kwargs: Additional kwargs (unused).

    Returns:
        array-like: The updated quantum state.
    """
    target_wires = op.wires
    n_state_wires = _get_num_wires(state, is_state_batched)
    # Cast the provided density matrix like the running state before combining the two.
    new_dm = math.cast_like(op.parameters[0], state)

    # Full-system update: the provided matrix simply becomes the state.
    if len(target_wires) == n_state_wires:
        if is_state_batched:
            n_batch = math.shape(state)[0]
            new_dm = math.broadcast_to(new_dm, (n_batch,) + math.shape(new_dm))
        return math.reshape(new_dm, math.shape(state))

    # Partial update, step 1: reduce the state to the wires that are not replaced.
    untouched_wires = [w for w in range(n_state_wires) if w not in target_wires]
    reduced = qml.math.partial_trace(state, indices=target_wires)

    reduced_dim = 2 ** len(untouched_wires)
    target_dim = 2 ** len(target_wires)
    new_dm = math.reshape(new_dm, (target_dim, target_dim))

    # Step 2: kron(reduced, new_dm), evaluated per batch entry when the state is batched.
    if is_state_batched:
        n_batch = math.shape(reduced)[0]
        reduced = math.reshape(reduced, (n_batch, reduced_dim, reduced_dim))
        rho = math.stack([math.kron(reduced[b], new_dm) for b in range(n_batch)], axis=0)
        tensor_shape = [n_batch] + [2] * (2 * n_state_wires)
    else:
        reduced = math.reshape(reduced, (reduced_dim, reduced_dim))
        rho = math.kron(reduced, new_dm)
        tensor_shape = [2] * (2 * n_state_wires)

    # Step 3: back to tensor form; the wire order is now untouched wires + target wires,
    # so the axes still have to be permuted back.
    rho = math.reshape(rho, tensor_shape)
    return reorder_after_kron(rho, untouched_wires, target_wires, is_state_batched)
def reorder_after_kron(rho, complement_wires, op_wires, is_state_batched):
    """
    Reorder the wires of `rho` from [complement_wires + op_wires] back to [0,1,...,N-1].

    `rho` is laid out as ``[batch?] + [2] * N`` (row side) ``+ [2] * N`` (column side),
    with the wires on each side currently ordered as ``complement_wires + op_wires``. The
    same permutation is applied to the row axes and to the column axes; a batch axis, if
    present, stays at position 0.

    Args:
        rho (tensor): The density matrix after kron(sigma, density_matrix).
        complement_wires (list[int]): The wires not affected by the QubitDensityMatrix update.
        op_wires (Wires): The wires affected by the QubitDensityMatrix.
        is_state_batched (bool): Whether the state is batched.

    Returns:
        tensor: The density matrix with wires in the original order.
    """
    current_order = complement_wires + list(op_wires)
    n_wires = len(current_order)
    offset = 1 if is_state_batched else 0

    # Where each wire label currently sits within one side of the tensor.
    position_of = {wire: position for position, wire in enumerate(current_order)}

    # ``permutation[i]`` names the axis of the old tensor that becomes axis ``i``.
    permutation = [None] * rho.ndim
    if is_state_batched:
        permutation[0] = 0
    for wire in range(n_wires):
        old_row_axis = offset + position_of[wire]
        permutation[offset + wire] = old_row_axis
        permutation[offset + n_wires + wire] = old_row_axis + n_wires

    return math.transpose(rho, axes=tuple(permutation))
_modules/pennylane/devices/qubit_mixed/apply_operation
Download Python script
Download Notebook
View on GitHub