# Source code for pennylane.devices.qubit.apply_operation
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply an operation to a state vector."""
# pylint: disable=unused-argument, too-many-arguments
from functools import singledispatch
from string import ascii_letters as alphabet
import numpy as np
import pennylane as qml
from pennylane import math
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional
# 1/sqrt(2); not referenced by any kernel visible in this module.
SQRT2INV = 1 / math.sqrt(2)
# ``_apply_operation_default`` picks the einsum kernel only when the operator acts on
# fewer wires than this threshold ...
EINSUM_OP_WIRECOUNT_PERF_THRESHOLD = 3
# ... and the state tensor has fewer axes than this threshold. Otherwise tensordot is
# used, unless both the operator and the state are batched.
EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD = 13
def _get_slice(index, axis, num_axes):
"""Allows slicing along an arbitrary axis of an array or tensor.
Args:
index (int): the index to access
axis (int): the axis to slice into
num_axes (int): total number of axes
Returns:
tuple[slice or int]: a tuple that can be used to slice into an array or tensor
**Example:**
Accessing the 2 index along axis 1 of a 3-axis array:
>>> sl = _get_slice(2, 1, 3)
>>> sl
(slice(None, None, None), 2, slice(None, None, None))
>>> a = np.arange(27).reshape((3, 3, 3))
>>> a[sl]
array([[ 6, 7, 8],
[15, 16, 17],
[24, 25, 26]])
"""
idx = [slice(None)] * num_axes
idx[axis] = index
return tuple(idx)
def apply_operation_einsum(op: qml.operation.Operator, state, is_state_batched: bool = False):
    """Apply ``Operator`` to ``state`` using ``einsum``. This is more efficient at lower qubit
    numbers.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state
    """
    # Implicit casting (adding 0j) is used because autograd raises ComplexWarnings when
    # backpropagating through an explicit cast. Some cast is still required to avoid
    # ComplexWarnings during backpropagation with the other interfaces.
    if qml.math.get_interface(state) == "tensorflow":
        matrix = qml.math.cast_like(op.matrix(), state)
    else:
        matrix = op.matrix() + 0j

    n_state_axes = len(state.shape) - is_state_batched
    n_op_wires = len(op.wires)

    # One letter per (non-batch) state axis, the letters of the axes the operator acts
    # on, and fresh letters for the operator's output axes.
    in_labels = alphabet[:n_state_axes]
    contracted = "".join(alphabet[w] for w in op.wires)
    fresh = alphabet[n_state_axes : n_state_axes + n_op_wires]

    # The output carries the fresh letters in the positions of the contracted ones.
    relabel = dict(zip(contracted, fresh))
    out_labels = "".join(relabel.get(letter, letter) for letter in in_labels)

    subscripts = f"...{fresh}{contracted},...{in_labels}->...{out_labels}"

    dim = 2**n_op_wires
    target_shape = [2] * (2 * n_op_wires)
    batch_size = math.get_batch_size(matrix, (dim, dim), dim**2)
    if batch_size is not None:
        # Prepend the broadcasting dimension
        target_shape = [batch_size, *target_shape]
        if op.batch_size is None:
            op._batch_size = batch_size  # pylint:disable=protected-access

    return math.einsum(subscripts, math.reshape(matrix, target_shape), state)
def apply_operation_tensordot(op: qml.operation.Operator, state, is_state_batched: bool = False):
    """Apply ``Operator`` to ``state`` using ``math.tensordot``. This is more efficient at higher
    qubit numbers.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state
    """
    # Implicit casting (adding 0j) is used because autograd raises ComplexWarnings when
    # backpropagating through an explicit cast. Some cast is still required to avoid
    # ComplexWarnings during backpropagation with the other interfaces.
    if qml.math.get_interface(state) == "tensorflow":
        matrix = qml.math.cast_like(op.matrix(), state)
    else:
        matrix = op.matrix() + 0j

    n_state_axes = len(state.shape) - is_state_batched
    n_op_wires = len(op.wires)

    dim = 2**n_op_wires
    target_shape = [2] * (2 * n_op_wires)
    batch_size = math.get_batch_size(matrix, (dim, dim), dim**2)
    is_mat_batched = batch_size is not None
    if is_mat_batched:
        # Prepend the broadcasting dimension
        target_shape = [batch_size, *target_shape]
        if op.batch_size is None:
            op._batch_size = batch_size  # pylint:disable=protected-access
    reshaped = math.reshape(matrix, target_shape)

    # Contract the matrix's trailing (input) axes with the state axes of the op wires.
    mat_axes = list(range(-n_op_wires, 0))
    state_axes = [w + is_state_batched for w in op.wires]
    contracted = math.tensordot(reshaped, state, axes=(mat_axes, state_axes))

    # tensordot leaves the axes of the op wires in the leading positions of the result,
    # i.e. the result is a (partial) transpose of the wanted state. Build that
    # permutation and apply its inverse to restore the wire order.
    untouched = [i for i in range(n_state_axes) if i not in op.wires]
    perm = list(op.wires) + untouched
    if is_mat_batched:
        perm = [0] + [i + 1 for i in perm]
    if is_state_batched:
        # NOTE(review): with a batched matrix *and* a batched state this insert position
        # looks off by one; ``_apply_operation_default`` sends that case to einsum instead.
        perm.insert(n_op_wires, -1)

    return math.transpose(contracted, math.argsort(perm))
@singledispatch
def apply_operation(
    op: qml.operation.Operator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply an operator to a given state.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use
        **execution_kwargs (Optional[dict]): Optional keyword arguments needed for applying
            some operations described below.

    Keyword Arguments:
        mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the
            sampled value
        interface (str): The machine learning interface of the state
        postselect_mode (str): Configuration for handling shots with mid-circuit measurement
            postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
            keep the same number of shots. ``None`` by default.
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.
        tape_shots (Shots): the shots object of the tape

    Returns:
        ndarray: output state

    .. warning::

        ``apply_operation`` is an internal function, and thus subject to change without a
        deprecation cycle.

    .. warning::

        ``apply_operation`` applies no validation to its inputs.

        This function assumes that the wires of the operator correspond to indices
        of the state. See :func:`~.map_wires` to convert operations to integer wire labels.

        The shape of state should be ``[2]*num_wires``.

    This is a ``functools.singledispatch`` function, so additional specialized kernels
    for specific operations can be registered like:

    .. code-block:: python

        @apply_operation.register
        def _(op: type_op, state):
            # custom op application method here

    **Example:**

    >>> state = np.zeros((2,2))
    >>> state[0][0] = 1
    >>> state
    tensor([[1., 0.],
            [0., 0.]], requires_grad=True)
    >>> apply_operation(qml.X(0), state)
    tensor([[0., 0.],
            [1., 0.]], requires_grad=True)

    """
    # The generic fallback ignores the extra execution kwargs; only the specialised
    # kernels registered on this dispatcher consume them.
    return _apply_operation_default(op, state, is_state_batched, debugger)
def _apply_operation_default(op, state, is_state_batched, debugger):
    """Generic kernel selection for ``apply_operation``.

    This is what the un-specialised dispatch of ``apply_operation`` runs, and what the
    specialised kernels fall back to when their custom path does not apply. ``debugger``
    is accepted for signature parity but not used here.
    """
    # einsum is chosen for small operators on small states, and whenever both the
    # operator and the state carry a batch dimension.
    is_small = (
        len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD
        and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
    )
    if is_small or (op.batch_size and is_state_batched):
        kernel = apply_operation_einsum
    else:
        kernel = apply_operation_tensordot
    return kernel(op, state, is_state_batched=is_state_batched)
@apply_operation.register
def apply_conditional(
    op: Conditional,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **execution_kwargs,
):
    """Applies a conditional operation.

    The wrapped ``op.base`` is applied only when ``op.meas_val``, evaluated against the
    recorded mid-circuit measurements, is true; otherwise ``state`` is returned unchanged.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use

    Keyword Arguments:
        mid_measurements (dict, None): Mid-circuit measurement dictionary used to evaluate
            the condition
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.

    Returns:
        ndarray: output state
    """
    mid_measurements = execution_kwargs.get("mid_measurements", None)
    rng = execution_kwargs.get("rng", None)
    prng_key = execution_kwargs.get("prng_key", None)

    def _apply_base(current_state):
        # Only these three execution kwargs are forwarded to the wrapped operation.
        return apply_operation(
            op.base,
            current_state,
            is_state_batched=is_state_batched,
            debugger=debugger,
            mid_measurements=mid_measurements,
            rng=rng,
            prng_key=prng_key,
        )

    if qml.math.get_deep_interface(state) == "jax":
        # pylint: disable=import-outside-toplevel
        from jax.lax import cond

        # With JAX the branch is selected through ``jax.lax.cond`` rather than a Python ``if``.
        return cond(
            op.meas_val.concretize(mid_measurements),
            _apply_base,
            lambda x: x,
            state,
        )

    if op.meas_val.concretize(mid_measurements):
        return _apply_base(state)
    return state
@apply_operation.register
def apply_mid_measure(
    op: MidMeasureMP, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
    """Applies a native mid-circuit measurement.

    The measured wire is sampled (or forced to the postselected value), the state is
    projected onto the sampled outcome and renormalised, and, if ``op.reset`` is set and
    the outcome was 1, the wire is flipped back.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use

    Keyword Arguments:
        mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record the
            sampled value. When ``None``, the sample is not recorded.
        postselect_mode (str): Configuration for handling shots with mid-circuit measurement
            postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
            keep the same number of shots. ``None`` by default.
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.

    Returns:
        ndarray: output state

    Raises:
        ValueError: if ``is_state_batched`` is true
    """
    mid_measurements = execution_kwargs.get("mid_measurements", None)
    rng = execution_kwargs.get("rng", None)
    prng_key = execution_kwargs.get("prng_key", None)
    postselect_mode = execution_kwargs.get("postselect_mode", None)

    if is_state_batched:
        raise ValueError("MidMeasureMP cannot be applied to batched states.")

    wire = op.wires
    interface = qml.math.get_deep_interface(state)

    if postselect_mode == "fill-shots" and op.postselect is not None:
        # Force the postselected outcome instead of sampling
        sample = op.postselect
    else:
        # Probability of outcome 0 is the squared norm of the index-0 slice on the wire's axis
        axis = wire.toarray()[0]
        zero_slice = _get_slice(0, axis, qml.math.ndim(state))
        prob0 = qml.math.real(qml.math.norm(state[zero_slice])) ** 2

        if prng_key is not None:
            # pylint: disable=import-outside-toplevel
            from jax.random import binomial

            def binomial_fn(n, p):
                return binomial(prng_key, n, p).astype(int)

        else:
            binomial_fn = np.random.binomial if rng is None else rng.binomial
        sample = binomial_fn(1, 1 - prob0)

    # ``mid_measurements`` is optional (defaults to None); only record when a dict is given
    if mid_measurements is not None:
        mid_measurements[op] = sample

    # Using apply_operation(qml.QubitUnitary,...) instead of
    # apply_operation(qml.Projector([sample], wire),...) to select the sample branch
    # enables jax.jit and prevents it from using Python callbacks
    matrix = qml.math.array([[(sample + 1) % 2, 0.0], [0.0, (sample) % 2]], like=interface)
    state = apply_operation(
        qml.QubitUnitary(matrix, wire),
        state,
        is_state_batched=is_state_batched,
        debugger=debugger,
    )
    state = state / qml.math.norm(state)

    # Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.X(wire), ...)
    # to reset enables jax.jit and prevents it from using Python callbacks.
    # ``element`` is truthy only when a reset is requested and the outcome was 1; the
    # matrix below is then Pauli-X, and the identity otherwise.
    element = op.reset and sample == 1
    matrix = qml.math.array(
        [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
        like=interface,
        dtype=float,
    )
    state = apply_operation(
        qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
    )
    return state
@apply_operation.register
def apply_identity(op: qml.Identity, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.Identity` operation, which is a no-op: the input state is
    handed back as is."""
    return state
@apply_operation.register
def apply_global_phase(
    op: qml.GlobalPhase, state, is_state_batched: bool = False, debugger=None, **_
):
    """Applies a :class:`~.GlobalPhase` operation by multiplying the state by
    ``exp(-1j * op.data[0])``.

    Note the negative sign in the exponent: the phase angle stored in ``op.data[0]`` is
    cast to complex and the state is scaled by ``exp(-1j * angle)``.
    """
    return qml.math.exp(-1j * qml.math.cast(op.data[0], complex)) * state
@apply_operation.register
def apply_paulix(op: qml.X, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply :class:`pennylane.PauliX` operator to the quantum state.

    Rolling the target axis by one position swaps its two entries.
    """
    # Skip over the leading batch axis when the state is batched
    target_axis = op.wires[0] + is_state_batched
    return math.roll(state, 1, target_axis)
@apply_operation.register
def apply_pauliz(op: qml.Z, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply PauliZ to the state by negating the index-1 half along the target axis."""
    n_dim = math.ndim(state)

    # For tensorflow states with 9 or more axes, defer to the tensordot kernel
    if n_dim >= 9 and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    axis = op.wires[0] + is_state_batched
    unchanged = state[_get_slice(0, axis, n_dim)]
    # The operand order (state first, then -1) matters: it breaks otherwise
    negated = math.multiply(state[_get_slice(1, axis, n_dim)], -1)
    return math.stack([unchanged, negated], axis=axis)
@apply_operation.register
def apply_phaseshift(op: qml.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply PhaseShift to the state by multiplying the index-1 half along the target
    axis with ``exp(1j * phi)``, including support for a batched phase parameter."""
    n_dim = math.ndim(state)

    # For tensorflow states with 9 or more axes, defer to the tensordot kernel
    if n_dim >= 9 and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    axis = op.wires[0] + is_state_batched
    phase = math.cast(op.parameters[0], dtype=complex)
    lower = state[_get_slice(0, axis, n_dim)]
    upper = state[_get_slice(1, axis, n_dim)]

    if op.batch_size is not None and len(phase) > 1:
        interface = math.get_interface(state)
        if interface == "torch":
            phase = math.array(phase, like=interface)
        if is_state_batched:
            # Align the parameter batch with the existing batch axis of the slices
            phase = math.reshape(phase, (-1,) + (1,) * (n_dim - 2))
        else:
            # The result gains a leading batch axis, so the stacking axis shifts by one
            axis = axis + 1
            phase = math.reshape(phase, (-1,) + (1,) * (n_dim - 1))
            # Adding zeros shaped like the parameters broadcasts the unchanged half
            # across the new batch axis
            lower = math.expand_dims(lower, 0) + math.zeros_like(phase)
            upper = math.expand_dims(upper, 0)

    upper = math.multiply(math.cast(upper, dtype=complex), math.exp(1.0j * phase))
    state = math.stack([lower, upper], axis=axis)

    if not is_state_batched and op.batch_size == 1:
        # A parameter batch of size one still yields an explicit leading batch axis
        state = math.stack([state], axis=0)
    return state
@apply_operation.register
def apply_T(op: qml.T, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply T to the state by multiplying the index-1 half along the target axis with
    ``exp(0.25j * pi)``."""
    n_dim = math.ndim(state)

    # For tensorflow states with 9 or more axes, defer to the tensordot kernel
    if n_dim >= 9 and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    axis = op.wires[0] + is_state_batched
    unchanged = state[_get_slice(0, axis, n_dim)]
    phased = math.multiply(
        math.cast(state[_get_slice(1, axis, n_dim)], dtype=complex), math.exp(0.25j * np.pi)
    )
    return math.stack([unchanged, phased], axis=axis)
@apply_operation.register
def apply_S(op: qml.S, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply S to the state by multiplying the index-1 half along the target axis with ``1j``."""
    n_dim = math.ndim(state)

    # For tensorflow states with 9 or more axes, defer to the tensordot kernel
    if n_dim >= 9 and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    axis = op.wires[0] + is_state_batched
    unchanged = state[_get_slice(0, axis, n_dim)]
    phased = math.multiply(math.cast(state[_get_slice(1, axis, n_dim)], dtype=complex), 1j)
    return math.stack([unchanged, phased], axis=axis)
@apply_operation.register
def apply_cnot(op: qml.CNOT, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply the CNOT gate to the state.

    The index-1 half along the control axis is rolled along the target axis (a PauliX),
    and then re-stacked with the untouched index-0 half.
    """
    n_dim = math.ndim(state)

    # For tensorflow states with 9 or more axes, defer to the tensordot kernel
    if n_dim >= 9 and math.get_interface(state) == "tensorflow":
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    control_wire = op.wires[0]
    target_wire = op.wires[1]
    control_axis = control_wire + is_state_batched
    # Indexing into the control axis removes it, so a target axis behind it moves up by one
    if target_wire > control_wire:
        target_axis = (target_wire - 1) + is_state_batched
    else:
        target_axis = target_wire + is_state_batched

    control_off = state[_get_slice(0, control_axis, n_dim)]
    control_on = math.roll(state[_get_slice(1, control_axis, n_dim)], 1, target_axis)
    return math.stack([control_off, control_on], axis=control_axis)
@apply_operation.register
def apply_multicontrolledx(
    op: qml.MultiControlledX,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    r"""Apply MultiControlledX to a state with the default einsum/tensordot choice
    for 8 operation wires or less. Otherwise, apply a custom kernel based on
    composing transpositions, rolling of control axes and the CNOT logic above.

    Args:
        op (MultiControlledX): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use

    Returns:
        TensorLike: output state, with the same shape as ``state``
    """
    if len(op.wires) < 9:
        return _apply_operation_default(op, state, is_state_batched, debugger)

    ctrl_wires = [w + is_state_batched for w in op.control_wires]

    # Apply x on all control wires with control value 0. A truthiness test is used so
    # that ``0`` and NumPy booleans are recognised as well as the ``False`` singleton.
    roll_axes = [w for val, w in zip(op.control_values, ctrl_wires) if not val]
    for ax in roll_axes:
        state = math.roll(state, 1, ax)

    orig_shape = math.shape(state)

    # Move the axes into the order [(batch), other, target, controls].
    # The comprehension works in wire labels (axis minus batch offset); a batch axis
    # shows up as -1 there, is never an op wire, and is mapped back to 0 by the final add.
    transpose_axes = (
        np.array(
            [
                w - is_state_batched
                for w in range(len(orig_shape))
                if w - is_state_batched not in op.wires
            ]
            + [op.wires[-1]]
            + op.wires[:-1].tolist()
        )
        + is_state_batched
    )
    state = math.transpose(state, transpose_axes)

    # Reshape the state into 3-dimensional array with axes [batch+other, target, controls]
    state = math.reshape(state, (-1, 2, 2 ** (len(op.wires) - 1)))

    # The part of the state to which we want to apply PauliX is now in the last entry along the
    # third axis. Extract it, apply the PauliX along the target axis (1), and append a dummy axis
    state_x = math.roll(state[:, :, -1], 1, 1)[:, :, np.newaxis]

    # Stack the transformed part of the state with the unmodified rest of the state
    state = math.concatenate([state[:, :, :-1], state_x], axis=2)

    # Reshape into original shape and undo the transposition
    state = math.transpose(math.reshape(state, orig_shape), np.argsort(transpose_axes))

    # Revert x on all "wrong" controls
    for ax in roll_axes:
        state = math.roll(state, 1, ax)
    return state
@apply_operation.register
def apply_grover(
    op: qml.GroverOperator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply GroverOperator, choosing between the standard matrix-based kernels (8
    operation wires or fewer) and a custom matrix-free method (more than 8)."""
    use_matrix_free = len(op.wires) >= 9
    if use_matrix_free:
        return _apply_grover_without_matrix(state, op.wires, is_state_batched)
    return _apply_operation_default(op, state, is_state_batched, debugger)
def _apply_grover_without_matrix(state, op_wires, is_state_batched):
    r"""Apply GroverOperator to state without building its matrix.

    The operator equals :math:`2*P-\mathbb{I}`, where :math:`P` is the projector onto
    the all-plus state. The new state is therefore obtained by summing the state over
    all axes on which the operation acts, "filling in" the all-plus state on those axes
    of the resulting lower-dimensional state via a Kronecker product, and subtracting
    the original state.
    """
    n_op_wires = len(op_wires)

    # 2 * squared normalization of the all-plus state on the op wires
    # (squared, because the normalization is skipped when summing; 2* from the definition
    # of the Grover operator)
    scale = 2 ** (1 - n_op_wires)

    # Summing over the op-wire axes yields <+|\psi>, where <+| only acts on the op wires.
    op_axes = [w + is_state_batched for w in op_wires]
    overlap = math.sum(state, axis=tuple(op_axes))

    if n_op_wires == (len(qml.math.shape(state)) - is_state_batched):
        # The operation acts on all wires: broadcasting replaces the tensor product
        # with the all-ones state
        singleton_axes = (1,) * n_op_wires
        broadcast_shape = (-1,) + singleton_axes if is_state_batched else singleton_axes
        return scale * math.reshape(overlap, broadcast_shape) - state
        # [todo]: Once Tensorflow support expand_dims with multiple axes in the second
        # argument, use the following line instead of the reshape above.
        # return prefactor * math.expand_dims(collapsed, sum_axes) - state

    plus_block = math.cast_like(math.full([2] * n_op_wires, scale), state)

    # The Kronecker product (tensordot with axes=0) appends the new axes at the end; they
    # are then moved to the positions of the summed-away axes before subtracting the
    # original state.
    appended_axes = list(range(math.ndim(overlap), math.ndim(state)))

    # math.full or math.tile would probably be a better way to create this outer
    # product, but Tensorflow and Torch do not fully support them.
    outer = math.tensordot(overlap, plus_block, axes=0)
    return math.moveaxis(outer, appended_axes, op_axes) - state
@apply_operation.register
def apply_snapshot(
    op: qml.Snapshot, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
    """Take a snapshot of the state.

    When an active debugger is supplied, the snapshot's measurement is evaluated on
    ``state`` and stored in ``debugger.snapshots`` under ``op.tag`` (or under the current
    number of snapshots if there is no tag). The state itself is always returned unchanged.
    """
    if debugger is None or not debugger.active:
        return state

    measurement = op.hyperparameters["measurement"]
    shots = execution_kwargs.get("tape_shots")

    # State measurements, and any measurement without shots, are evaluated analytically
    if isinstance(measurement, qml.measurements.StateMP) or not shots:
        snapshot = qml.devices.qubit.measure(measurement, state, is_state_batched)
    else:
        sampled = qml.devices.qubit.measure_with_samples(
            [measurement],
            state,
            shots,
            is_state_batched,
            execution_kwargs.get("rng"),
            execution_kwargs.get("prng_key"),
        )
        snapshot = sampled[0]

    key = op.tag or len(debugger.snapshots)
    debugger.snapshots[key] = snapshot
    return state
# pylint:disable = no-value-for-parameter, import-outside-toplevel
@apply_operation.register
def apply_parametrized_evolution(
    op: qml.pulse.ParametrizedEvolution,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply ParametrizedEvolution, evolving the state vector instead of the operator
    matrix when the operation acts on more than half of the subsystem."""
    # shape(state) is static (not a tracer), so a plain Python branch is fine
    num_wires = len(qml.math.shape(state)) - is_state_batched
    state = qml.math.cast(state, complex)

    # Evolve the matrix when
    #   - the subsystem operated on is at most half as big as the total system, or
    #   - complementary time evolution is requested, or
    #   - both the state and the operation have a batch dimension.
    evolve_matrix = (
        2 * len(op.wires) <= num_wires
        or op.hyperparameters["complementary"]
        or (is_state_batched and op.hyperparameters["return_intermediate"])
    )
    if not evolve_matrix:
        return _evolve_state_vector_under_parametrized_evolution(
            op, state, num_wires, is_state_batched
        )
    return _apply_operation_default(op, state, is_state_batched, debugger)
def _evolve_state_vector_under_parametrized_evolution(
    operation: qml.pulse.ParametrizedEvolution, state, num_wires, is_state_batched
):
    """Uses an odeint solver to compute the evolution of the input ``state`` under the given
    ``ParametrizedEvolution`` operation.

    Args:
        operation (ParametrizedEvolution): operation to apply on the state
        state (array[complex]): input state
        num_wires (int): number of wires of the state, not counting a batch dimension
        is_state_batched (bool): whether ``state`` carries a leading batch dimension

    Raises:
        ImportError: If jax cannot be imported.
        ValueError: If the parameters and time windows of the ``ParametrizedEvolution`` are
            not defined.

    Returns:
        TensorLike[complex]: output state
    """
    try:
        import jax
        from jax.experimental.ode import odeint

        from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree

    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "Module jax is required for the ``ParametrizedEvolution`` class. "
            "You can install jax via: pip install jax"
        ) from e

    if operation.data is None or operation.t is None:
        raise ValueError(
            "The parameters and the time window are required to execute a ParametrizedEvolution "
            "You can update these values by calling the ParametrizedEvolution class: EV(params, t)."
        )

    wire_shape = [2] * num_wires
    if is_state_batched:
        # Flatten the wire axes and put the batch axis last so the Hamiltonian can act
        # on the leading axis.
        batch_dim = state.shape[0]
        state = qml.math.moveaxis(state.reshape((batch_dim, 2**num_wires)), 1, 0)
        out_shape = wire_shape + [batch_dim]  # shape before the batch axis is moved back
    else:
        state = state.flatten()
        out_shape = wire_shape

    with jax.ensure_compile_time_eval():
        H_jax = ParametrizedHamiltonianPytree.from_hamiltonian(  # pragma: no cover
            operation.H,
            dense=operation.dense,
            wire_order=list(np.arange(num_wires)),
        )

    def schroedinger_rhs(y, t):
        """dy/dt = -i H(t) y"""
        return (-1j * H_jax(operation.data, t=t)) @ y

    trajectory = odeint(schroedinger_rhs, state, operation.t, **operation.odeint_kwargs)

    if operation.hyperparameters["return_intermediate"]:
        # Keep every time step; the leading axis enumerates them
        return qml.math.reshape(trajectory, [-1] + out_shape)

    final_state = qml.math.reshape(trajectory[-1], out_shape)
    if is_state_batched:
        return qml.math.moveaxis(final_state, -1, 0)
    return final_state
# _modules/pennylane/devices/qubit/apply_operation
# Download Python script
# Download Notebook
# View on GitHub