# Source code for pennylane.devices.qubit.apply_operation

# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to apply an operation to a state vector."""
# pylint: disable=unused-argument

from functools import singledispatch
from string import ascii_letters as alphabet

import numpy as np
import scipy as sp

import pennylane as qml
from pennylane import math, ops
from pennylane.measurements import MidMeasureMP
from pennylane.operation import Operator
from pennylane.ops import Conditional

# Precomputed value of 1/sqrt(2).
# NOTE(review): not referenced by the code visible in this module chunk — presumably used by
# other simulator modules; confirm before removing.
SQRT2INV = 1 / math.sqrt(2)

# Heuristic thresholds read by ``_apply_operation_default`` to pick a kernel: the einsum kernel
# is chosen when the operator acts on fewer than EINSUM_OP_WIRECOUNT_PERF_THRESHOLD wires AND the
# state has fewer than EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD dimensions (measured with
# ``math.ndim``, so a batch axis counts as one dimension). Otherwise the tensordot kernel is used,
# unless both the operator and the state are batched (then einsum is always used).
EINSUM_OP_WIRECOUNT_PERF_THRESHOLD = 3
EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD = 13


def _get_slice(index, axis, num_axes):
    """Build an indexing tuple that selects ``index`` along ``axis`` and keeps all other axes.

    Args:
        index (int): the index to pick along ``axis``
        axis (int): the axis to slice into (negative values count from the end, as usual)
        num_axes (int): total number of axes of the array being indexed

    Returns:
        tuple[slice or int]: a tuple usable as ``array[result]`` on an array or tensor

    **Example:**

    Accessing the 2 index along axis 1 of a 3-axis array:

    >>> sl = _get_slice(2, 1, 3)
    >>> sl
    (slice(None, None, None), 2, slice(None, None, None))
    >>> a = np.arange(27).reshape((3, 3, 3))
    >>> a[sl]
    array([[ 6,  7,  8],
           [15, 16, 17],
           [24, 25, 26]])
    """
    # Start from "take everything" on every axis ...
    selector = list((slice(None),) * num_axes)
    # ... then pin the requested axis (raises IndexError if ``axis`` is out of range).
    selector[axis] = index
    return tuple(selector)


def apply_operation_einsum(op: Operator, state, is_state_batched: bool = False):
    """Apply ``Operator`` to ``state`` by contracting its matrix with ``einsum``.

    This kernel is the more efficient choice at lower qubit numbers.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state

    .. note::
        If the operator matrix is detected to be batched while ``op.batch_size`` is ``None``,
        the detected batch size is stored on ``op._batch_size`` as a side effect.
    """
    # Casting is done implicitly (by adding 0j) because autograd raises ComplexWarnings when
    # backpropagating through an explicit cast; some cast is still needed so that the other
    # interfaces do not emit ComplexWarnings during backpropagation either.
    if (
        math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        mat = math.cast_like(op.matrix(), state)
    else:
        mat = op.matrix() + 0j

    n_state_wires = len(state.shape) - is_state_batched
    n_op_wires = len(op.wires)

    # One letter per (unbatched) state axis; a batch axis, if present, is absorbed by "...".
    in_letters = alphabet[:n_state_wires]
    # Letters of the state axes the operator acts on (contracted with the matrix inputs).
    contracted_letters = "".join(alphabet[w] for w in op.wires)
    # Fresh, unused letters for the operator's output axes.
    fresh_letters = alphabet[n_state_wires : n_state_wires + n_op_wires]

    # The output keeps the state's axis order, with each acted-on axis renamed to its fresh letter.
    out_letters = in_letters
    for old_letter, new_letter in zip(contracted_letters, fresh_letters):
        out_letters = out_letters.replace(old_letter, new_letter)

    subscripts = f"...{fresh_letters}{contracted_letters},...{in_letters}->...{out_letters}"

    dim = 2**n_op_wires
    batch_size = math.get_batch_size(mat, (dim, dim), dim**2)
    tensor_shape = [2] * (2 * n_op_wires)
    if batch_size is not None:
        # Keep the broadcasting dimension in front of the 2 x 2 x ... x 2 operator tensor.
        tensor_shape.insert(0, batch_size)
        if op.batch_size is None:
            op._batch_size = batch_size  # pylint:disable=protected-access
    mat_tensor = math.reshape(mat, tensor_shape)

    return math.einsum(subscripts, mat_tensor, state)


def apply_operation_tensordot(op: Operator, state, is_state_batched: bool = False):
    """Apply ``Operator`` to ``state`` by contracting its matrix with ``math.tensordot``.

    This kernel is the more efficient choice at higher qubit numbers.

    Args:
        op (Operator): Operator to apply to the quantum state
        state (array[complex]): Input quantum state
        is_state_batched (bool): Boolean representing whether the state is batched or not

    Returns:
        array[complex]: output_state

    .. note::
        If the operator matrix is detected to be batched while ``op.batch_size`` is ``None``,
        the detected batch size is stored on ``op._batch_size`` as a side effect.
    """
    # Casting is done implicitly (by adding 0j) because autograd raises ComplexWarnings when
    # backpropagating through an explicit cast; some cast is still needed so that the other
    # interfaces do not emit ComplexWarnings during backpropagation either.
    if (
        math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        mat = math.cast_like(op.matrix(), state)
    else:
        mat = op.matrix() + 0j

    n_state_wires = len(state.shape) - is_state_batched
    n_op_wires = len(op.wires)

    dim = 2**n_op_wires
    batch_size = math.get_batch_size(mat, (dim, dim), dim**2)
    mat_is_batched = batch_size is not None
    tensor_shape = [2] * (2 * n_op_wires)
    if mat_is_batched:
        # Keep the broadcasting dimension in front of the 2 x 2 x ... x 2 operator tensor.
        tensor_shape.insert(0, batch_size)
        if op.batch_size is None:
            op._batch_size = batch_size  # pylint:disable=protected-access
    mat_tensor = math.reshape(mat, tensor_shape)

    # Contract the matrix's trailing (input) axes with the state axes the operator acts on.
    contraction = (
        list(range(-n_op_wires, 0)),
        [w + is_state_batched for w in op.wires],
    )
    contracted = math.tensordot(mat_tensor, state, axes=contraction)

    # tensordot leaves the operator's output axes in front, followed by the untouched state axes.
    # ``order[k]`` names the final position of result axis ``k``; a state batch axis is tagged -1
    # so that it sorts first. Its argsort is the inverse permutation we need to transpose by.
    untouched = [ax for ax in range(n_state_wires) if ax not in op.wires]
    order = [*op.wires, *untouched]
    if mat_is_batched:
        order = [0, *(ax + 1 for ax in order)]
    if is_state_batched:
        order.insert(n_op_wires, -1)

    return math.transpose(contracted, math.argsort(order))


@singledispatch
def apply_operation(
    op: Operator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply an operator to a given state.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use
        **execution_kwargs (Optional[dict]): Optional keyword arguments needed for applying
            some operations described below. The default kernel collects them as ``**_``
            and ignores them; registered kernels (e.g. for conditionals, mid-circuit
            measurements and snapshots) consume them.

    Keyword Arguments:
        mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record
            the sampled value
        interface (str): The machine learning interface of the state
        postselect_mode (str): Configuration for handling shots with mid-circuit measurement
            postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
            keep the same number of shots. ``None`` by default.
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.
        tape_shots (Shots): the shots object of the tape

    Returns:
        ndarray: output state

    .. warning::

        ``apply_operation`` is an internal function, and thus subject to change without a
        deprecation cycle.

    .. warning::

        ``apply_operation`` applies no validation to its inputs.

        This function assumes that the wires of the operator correspond to indices
        of the state. See :func:`~.map_wires` to convert operations to integer wire labels.

        The shape of state should be ``[2]*num_wires``.

    This is a ``functools.singledispatch`` function, so additional specialized kernels
    for specific operations can be registered like:

    .. code-block:: python

        @apply_operation.register
        def _(op: type_op, state):
            # custom op application method here

    **Example:**

    >>> state = np.zeros((2,2))
    >>> state[0][0] = 1
    >>> state
    tensor([[1., 0.],
        [0., 0.]], requires_grad=True)
    >>> apply_operation(qml.X(0), state)
    tensor([[0., 0.],
        [1., 0.]], requires_grad=True)

    """
    # Generic fallback: pick a matrix-based kernel (sparse / einsum / tensordot).
    return _apply_operation_default(op, state, is_state_batched, debugger)
def apply_operation_csr_matrix(op, state, is_state_batched: bool = False):
    """Apply ``op`` to a dense ``state`` using the operator's sparse matrix.

    Args:
        op (Operator): operator providing ``sparse_matrix``
        state (TensorLike): dense input state of shape ``[(batch,)] + [2] * num_wires``
        is_state_batched (bool): whether ``state`` has a leading batch dimension

    Returns:
        TensorLike: the output state, reshaped back to the input state's shape

    Raises:
        TypeError: if ``state`` itself is a scipy sparse matrix
    """
    # The state must be dense here; only the operator is sparse.
    if sp.sparse.issparse(state):
        raise TypeError("State should not be sparse in default qubit pipeline")
    # Remember the initial shape so it can be restored at the end.
    original_shape = math.shape(state)
    num_wires = len(original_shape) - int(is_state_batched)
    full_state = math.reshape(state, [-1, 2**num_wires])  # expected: [batch_size, 2**num_wires]
    # Rows of ``full_state`` are state vectors, so (M psi)^T = psi^T M^T is applied row-wise.
    state_opT = full_state @ op.sparse_matrix(wire_order=range(num_wires)).T
    state_reshaped = math.reshape(state_opT, original_shape)
    return state_reshaped


def _apply_operation_default(op, state, is_state_batched, debugger):
    """The default behaviour of apply_operation, accessed through the standard dispatch
    of apply_operation, as well as conditionally in other dispatches.

    Kernel choice:

    * operators that only provide a sparse matrix -> :func:`apply_operation_csr_matrix`
    * few operator wires on a small state, or a batched operator on a batched state
      -> :func:`apply_operation_einsum`
    * everything else -> :func:`apply_operation_tensordot`
    """
    if op.has_sparse_matrix and not op.has_matrix:
        return apply_operation_csr_matrix(op, state, is_state_batched=is_state_batched)
    if (
        len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD
        and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
    ) or (op.batch_size and is_state_batched):
        return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
    return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)


@apply_operation.register
def apply_conditional(
    op: Conditional,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **execution_kwargs,
):
    """Applies a conditional operation.

    The base operation is applied only when ``op.meas_val`` evaluates to a truthy value given
    the recorded mid-circuit measurements. For JAX states this is done with ``jax.lax.cond``.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use
        mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record
            the sampled value
        interface (str): The machine learning interface of the state (not read here; the
            interface is inferred from ``state``)
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.

    Returns:
        ndarray: output state
    """
    # Only these keyword arguments are forwarded to the base operation.
    mid_measurements = execution_kwargs.get("mid_measurements", None)
    rng = execution_kwargs.get("rng", None)
    prng_key = execution_kwargs.get("prng_key", None)
    interface = math.get_deep_interface(state)

    if interface == "jax":
        # pylint: disable=import-outside-toplevel
        from jax.lax import cond

        return cond(
            op.meas_val.concretize(mid_measurements),
            lambda x: apply_operation(
                op.base,
                x,
                is_state_batched=is_state_batched,
                debugger=debugger,
                mid_measurements=mid_measurements,
                rng=rng,
                prng_key=prng_key,
            ),
            lambda x: x,
            state,
        )
    if op.meas_val.concretize(mid_measurements):
        return apply_operation(
            op.base,
            state,
            is_state_batched=is_state_batched,
            debugger=debugger,
            mid_measurements=mid_measurements,
            rng=rng,
            prng_key=prng_key,
        )
    return state


@apply_operation.register
def apply_mid_measure(
    op: MidMeasureMP, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
    """Applies a native mid-circuit measurement.

    A 0/1 outcome is sampled from the probability of the measured wire, stored in
    ``mid_measurements[op]``, the state is projected onto that outcome and renormalized, and,
    if ``op.reset`` is set and the outcome was 1, the wire is flipped back to 0.

    Args:
        op (Operator): The operation to apply to ``state``
        state (TensorLike): The starting state.
        is_state_batched (bool): Boolean representing whether the state is batched or not
        debugger (_Debugger): The debugger to use
        mid_measurements (dict, None): Mid-circuit measurement dictionary mutated to record
            the sampled value
        postselect_mode (str): Configuration for handling shots with mid-circuit measurement
            postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
            keep the same number of shots. ``None`` by default. (Not read by this kernel.)
        rng (Optional[numpy.random._generator.Generator]): A NumPy random number generator.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
            the key to the JAX pseudo random number generator. Only for simulation using JAX.
            If None, a ``numpy.random.default_rng`` will be used for sampling.

    Returns:
        ndarray: output state

    Raises:
        ValueError: if ``is_state_batched`` is True
    """
    mid_measurements = execution_kwargs.get("mid_measurements", None)
    rng = execution_kwargs.get("rng", None)
    prng_key = execution_kwargs.get("prng_key", None)

    if is_state_batched:
        raise ValueError("MidMeasureMP cannot be applied to batched states.")
    wire = op.wires
    interface = math.get_deep_interface(state)

    # Probability of outcome 0: squared norm of the slice where the measured wire is 0.
    axis = wire.toarray()[0]
    slices = [slice(None)] * math.ndim(state)
    slices[axis] = 0
    prob0 = math.real(math.norm(state[tuple(slices)])) ** 2

    if prng_key is not None:
        # pylint: disable=import-outside-toplevel
        from jax.random import binomial

        def binomial_fn(n, p):
            return binomial(prng_key, n, p).astype(int)

    else:
        binomial_fn = np.random.binomial if rng is None else rng.binomial
    sample = binomial_fn(1, 1 - prob0)
    assert mid_measurements is not None
    mid_measurements[op] = sample

    # Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.Projector([sample], wire),...)
    # to select the sample branch enables jax.jit and prevents it from using Python callbacks
    matrix = math.array([[(sample + 1) % 2, 0.0], [0.0, (sample) % 2]], like=interface)
    state = apply_operation(
        ops.QubitUnitary(matrix, wire),
        state,
        is_state_batched=is_state_batched,
        debugger=debugger,
    )
    state = state / math.norm(state)

    # Using apply_operation(qml.QubitUnitary,...) instead of apply_operation(qml.X(wire), ...)
    # to reset enables jax.jit and prevents it from using Python callbacks
    element = op.reset and sample == 1
    matrix = math.array(
        [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
        like=interface,
        dtype=float,
    )
    state = apply_operation(
        ops.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
    )

    return state


@apply_operation.register
def apply_identity(op: ops.Identity, state, is_state_batched: bool = False, debugger=None, **_):
    """Applies a :class:`~.Identity` operation by just returning the input state."""
    return state


@apply_operation.register
def apply_global_phase(
    op: ops.GlobalPhase, state, is_state_batched: bool = False, debugger=None, **_
):
    """Applies a :class:`~.GlobalPhase` operation by multiplying the state by
    ``exp(-1j * op.data[0])``."""
    return math.exp(-1j * math.cast(op.data[0], complex)) * state


@apply_operation.register
def apply_paulix(op: ops.X, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply :class:`pennylane.PauliX` operator to the quantum state"""
    # Rolling by one along the wire's axis swaps the |0> and |1> slices.
    axis = op.wires[0] + is_state_batched
    return math.roll(state, 1, axis)


@apply_operation.register
def apply_pauliz(op: ops.Z, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply pauliz to state."""

    axis = op.wires[0] + is_state_batched
    n_dim = math.ndim(state)

    if (
        n_dim >= 9 and math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    sl_0 = _get_slice(0, axis, n_dim)
    sl_1 = _get_slice(1, axis, n_dim)

    # must be first state and then -1 because it breaks otherwise
    state1 = math.multiply(state[sl_1], -1)
    return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_phaseshift(op: ops.PhaseShift, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply PhaseShift to state."""

    n_dim = math.ndim(state)

    if (
        n_dim >= 9 and math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    axis = op.wires[0] + is_state_batched
    sl_0 = _get_slice(0, axis, n_dim)
    sl_1 = _get_slice(1, axis, n_dim)

    params = math.cast(op.parameters[0], dtype=complex)
    state0 = state[sl_0]
    state1 = state[sl_1]
    if op.batch_size is not None:
        interface = math.get_interface(state)
        if interface == "torch":
            params = math.array(params, like=interface)
        if is_state_batched:
            # The state already has a leading batch axis: broadcast params against it.
            params = math.reshape(params, (-1,) + (1,) * (n_dim - 2))
        else:
            # Introduce a new leading batch axis for the output.
            axis = axis + 1
            params = math.reshape(params, (-1,) + (1,) * (n_dim - 1))
            state0 = math.expand_dims(state0, 0) + math.zeros_like(params)
            state1 = math.expand_dims(state1, 0)
    state1 = math.multiply(math.cast(state1, dtype=complex), math.exp(1.0j * params))
    state = math.stack([state0, state1], axis=axis)
    return state


@apply_operation.register
def apply_T(op: ops.T, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply T to state."""

    axis = op.wires[0] + is_state_batched
    n_dim = math.ndim(state)

    if (
        n_dim >= 9 and math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    sl_0 = _get_slice(0, axis, n_dim)
    sl_1 = _get_slice(1, axis, n_dim)

    state1 = math.multiply(math.cast(state[sl_1], dtype=complex), math.exp(0.25j * np.pi))
    return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_S(op: ops.S, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply S to state."""

    axis = op.wires[0] + is_state_batched
    n_dim = math.ndim(state)

    if (
        n_dim >= 9 and math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    sl_0 = _get_slice(0, axis, n_dim)
    sl_1 = _get_slice(1, axis, n_dim)

    state1 = math.multiply(math.cast(state[sl_1], dtype=complex), 1j)
    return math.stack([state[sl_0], state1], axis=axis)


@apply_operation.register
def apply_cnot(op: ops.CNOT, state, is_state_batched: bool = False, debugger=None, **_):
    """Apply cnot gate to state."""
    # Slicing out the control axis removes one dimension, so a target axis that came after
    # the control axis shifts down by one.
    target_axes = (op.wires[1] - 1 if op.wires[1] > op.wires[0] else op.wires[1]) + is_state_batched
    control_axes = op.wires[0] + is_state_batched
    n_dim = math.ndim(state)

    if (
        n_dim >= 9 and math.get_interface(state) == "tensorflow"
    ):  # pragma: no cover (TensorFlow tests were disabled during deprecation)
        return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

    sl_0 = _get_slice(0, control_axes, n_dim)
    sl_1 = _get_slice(1, control_axes, n_dim)

    state_x = math.roll(state[sl_1], 1, target_axes)
    return math.stack([state[sl_0], state_x], axis=control_axes)


@apply_operation.register
def apply_multicontrolledx(
    op: ops.MultiControlledX,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    r"""Apply MultiControlledX to a state with the default einsum/tensordot choice
    for 8 operation wires or less. Otherwise, apply a custom kernel based on
    composing transpositions, rolling of control axes and the CNOT logic above."""
    if len(op.wires) < 9:
        return _apply_operation_default(op, state, is_state_batched, debugger)
    ctrl_wires = [w + is_state_batched for w in op.control_wires]
    # apply x on all control wires with control value 0 (falsy values such as False or 0)
    roll_axes = [w for val, w in zip(op.control_values, ctrl_wires) if not val]
    for ax in roll_axes:
        state = math.roll(state, 1, ax)

    orig_shape = math.shape(state)
    # Move the axes into the order [(batch), other, target, controls]
    transpose_axes = (
        np.array(
            [
                w - is_state_batched
                for w in range(len(orig_shape))
                if w - is_state_batched not in op.wires
            ]
            + [op.wires[-1]]
            + op.wires[:-1].tolist()
        )
        + is_state_batched
    )
    state = math.transpose(state, transpose_axes)

    # Reshape the state into 3-dimensional array with axes [batch+other, target, controls]
    state = math.reshape(state, (-1, 2, 2 ** (len(op.wires) - 1)))

    # The part of the state to which we want to apply PauliX is now in the last entry along the
    # third axis. Extract it, apply the PauliX along the target axis (1), and append a dummy axis
    state_x = math.roll(state[:, :, -1], 1, 1)[:, :, np.newaxis]

    # Stack the transformed part of the state with the unmodified rest of the state
    state = math.concatenate([state[:, :, :-1], state_x], axis=2)

    # Reshape into original shape and undo the transposition
    state = math.transpose(math.reshape(state, orig_shape), np.argsort(transpose_axes))

    # revert x on all "wrong" controls
    for ax in roll_axes:
        state = math.roll(state, 1, ax)
    return state


@apply_operation.register
def apply_grover(
    op: qml.GroverOperator,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply GroverOperator either via a custom matrix-free method (more than 8 operation
    wires) or via standard matrix based methods (else)."""
    if len(op.wires) < 9:
        return _apply_operation_default(op, state, is_state_batched, debugger)
    return _apply_grover_without_matrix(state, op.wires, is_state_batched)


def _apply_grover_without_matrix(state, op_wires, is_state_batched):
    r"""Apply GroverOperator to state. This method uses that this operator
    is :math:`2*P-\mathbb{I}`, where :math:`P` is the projector onto the
    all-plus state. This allows us to compute the new state by summing over
    all axes on which the operation acts, and "filling in" the all-plus state
    in the resulting lower-dimensional state via a Kronecker product.
    """
    num_wires = len(op_wires)
    # 2 * Squared normalization of the all-plus state on the op wires
    # (squared, because we skipped the normalization when summing, 2* because of the def of Grover)
    prefactor = 2 ** (1 - num_wires)

    # The axes to sum over in order to obtain <+|\psi>, where <+| only acts on the op wires.
    sum_axes = [w + is_state_batched for w in op_wires]
    collapsed = math.sum(state, axis=tuple(sum_axes))

    if num_wires == (len(math.shape(state)) - is_state_batched):
        # If the operation acts on all wires, we can skip the tensor product with all-ones state
        # [todo]: Once Tensorflow supports expand_dims with multiple axes in the second argument,
        # replace the two lines below with:
        # return prefactor * math.expand_dims(collapsed, sum_axes) - state
        new_shape = (-1,) + (1,) * num_wires if is_state_batched else (1,) * num_wires
        return prefactor * math.reshape(collapsed, new_shape) - state

    all_plus = math.cast_like(math.full([2] * num_wires, prefactor), state)
    # After the Kronecker product (realized with tensordot with axes=0), we need to move
    # the new axes to the summed-away axes' positions. Finally, subtract the original state.
    source = list(range(math.ndim(collapsed), math.ndim(state)))
    # Probably it will be better to use math.full or math.tile to create the outer product
    # here computed with math.tensordot. However, Tensorflow and Torch do not have full support
    return math.moveaxis(math.tensordot(collapsed, all_plus, axes=0), source, sum_axes) - state


@apply_operation.register
def apply_snapshot(
    op: ops.Snapshot, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
    """Take a snapshot of the state.

    The state is returned unchanged. When an active debugger is provided, the snapshot
    measurement is evaluated (with samples if shots are configured, analytically otherwise)
    and recorded in ``debugger.snapshots`` under ``op.tag`` (or the next integer key if the
    snapshot has no tag). Repeated tags accumulate their results in a list.
    """
    if debugger is None or not debugger.active:
        return state

    measurement = op.hyperparameters["measurement"]

    if op.hyperparameters["shots"] == "workflow":
        shots = execution_kwargs.get("tape_shots")
    else:
        shots = op.hyperparameters["shots"]

    if shots:
        snapshot = qml.devices.qubit.measure_with_samples(
            [measurement],
            state,
            shots,
            is_state_batched,
            execution_kwargs.get("rng"),
            execution_kwargs.get("prng_key"),
        )[0]
    else:
        snapshot = qml.devices.qubit.measure(measurement, state, is_state_batched)

    if op.tag is None:
        debugger.snapshots[len(debugger.snapshots)] = snapshot
    elif op.tag not in debugger.snapshots:
        debugger.snapshots[op.tag] = snapshot
    elif isinstance(debugger.snapshots[op.tag], list):
        debugger.snapshots[op.tag].append(snapshot)
    else:
        debugger.snapshots[op.tag] = [debugger.snapshots[op.tag], snapshot]

    return state


# pylint:disable=import-outside-toplevel
@apply_operation.register
def apply_parametrized_evolution(
    op: qml.pulse.ParametrizedEvolution,
    state,
    is_state_batched: bool = False,
    debugger=None,
    **_,
):
    """Apply ParametrizedEvolution by evolving the state rather than the operator matrix
    if we are operating on more than half of the subsystem"""

    # shape(state) is static (not a tracer), we can use an if statement
    num_wires = len(math.shape(state)) - is_state_batched
    state = math.cast(state, complex)
    if (
        2 * len(op.wires) <= num_wires
        or op.hyperparameters["complementary"]
        or (is_state_batched and op.hyperparameters["return_intermediate"])
    ):
        # the subsystem operated on is half as big as the total system, or less
        # or we want complementary time evolution
        # or both the state and the operation have a batch dimension
        # --> evolve matrix
        return _apply_operation_default(op, state, is_state_batched, debugger)
    # otherwise --> evolve state
    return _evolve_state_vector_under_parametrized_evolution(op, state, num_wires, is_state_batched)


def _evolve_state_vector_under_parametrized_evolution(
    operation: qml.pulse.ParametrizedEvolution, state, num_wires, is_state_batched
):
    """Uses an odeint solver to compute the evolution of the input ``state`` under the given
    ``ParametrizedEvolution`` operation.

    Args:
        operation (ParametrizedEvolution): operation to apply on the state
        state (array[complex]): input state
        num_wires (int): number of wires of ``state`` (excluding any batch axis)
        is_state_batched (bool): whether ``state`` has a leading batch dimension

    Raises:
        ImportError: If jax is not installed.
        ValueError: If the parameters and time windows of the ``ParametrizedEvolution`` are
            not defined.

    Returns:
        TensorLike[complex]: output state. If ``return_intermediate`` is set, the solutions at
        all entries of ``operation.t`` are returned stacked along a new leading axis.
    """

    try:
        import jax
        from jax.experimental.ode import odeint

        from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree

    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "Module jax is required for the ``ParametrizedEvolution`` class. "
            "You can install jax via: pip install jax~=0.6.0"
        ) from e

    if operation.data is None or operation.t is None:
        raise ValueError(
            "The parameters and the time window are required to execute a ParametrizedEvolution. "
            "You can update these values by calling the ParametrizedEvolution class: EV(params, t)."
        )

    if is_state_batched:
        batch_dim = state.shape[0]
        # odeint evolves column vectors: flatten each batch entry and move the batch axis last.
        state = math.moveaxis(state.reshape((batch_dim, 2**num_wires)), 1, 0)
        out_shape = [2] * num_wires + [batch_dim]  # this shape is before moving the batch_dim back
    else:
        state = state.flatten()
        out_shape = [2] * num_wires

    with jax.ensure_compile_time_eval():
        H_jax = ParametrizedHamiltonianPytree.from_hamiltonian(  # pragma: no cover
            operation.H,
            dense=operation.dense,
            wire_order=list(np.arange(num_wires)),
        )

    def fun(y, t):
        """dy/dt = -i H(t) y"""
        return (-1j * H_jax(operation.data, t=t)) @ y

    result = odeint(fun, state, operation.t, **operation.odeint_kwargs)
    if operation.hyperparameters["return_intermediate"]:
        return math.reshape(result, [-1] + out_shape)
    result = math.reshape(result[-1], out_shape)
    if is_state_batched:
        return math.moveaxis(result, -1, 0)
    return result