# Source code for pennylane.devices.default_qutrit

# Copyright 2018-2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
The default.qutrit device is PennyLane's standard qutrit-based device.

It implements the :class:`~pennylane.devices._legacy_device.Device` methods as well as some built-in
:mod:`qutrit operations <pennylane.ops.qutrit>`, and provides simple pure state
simulation of qutrit-based quantum computing.
"""
import functools
import logging

import numpy as np

import pennylane as qml  # pylint: disable=unused-import
from pennylane.devices.default_qubit_legacy import _get_slice
from pennylane.logging import debug_logger, debug_logger_init
from pennylane.wires import WireError

from .._version import __version__
from ._qutrit_device import QutritDevice

# Module-level logger named after this module; a no-op NullHandler is attached to it.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# tolerance for numerical errors
# NOTE(review): not referenced anywhere in the visible part of this module -- confirm
# whether external code imports it before removing.
tolerance = 1e-10

# exp(2*pi*i/3), a primitive cube root of unity. ``_apply_tclock`` applies OMEGA to
# index 1 and OMEGA**2 to index 2 of the target axis.
OMEGA = qml.math.exp(2 * np.pi * 1j / 3)


# pylint: disable=too-many-arguments
class DefaultQutrit(QutritDevice):
    """Default qutrit device for PennyLane.

    .. warning::

        The API of ``DefaultQutrit`` will be updated soon to follow a new device interface
        described in :class:`pennylane.devices.Device`.

        This change will not alter device behaviour for most workflows, but may have
        implications for plugin developers and users who directly interact with device
        methods. Please consult :class:`pennylane.devices.Device` and the implementation in
        :class:`pennylane.devices.DefaultQubit` for more information on what the new
        interface will look like and be prepared to make updates in a coming release. If you
        have any feedback on these changes, please create an
        `issue <https://github.com/PennyLaneAI/pennylane/issues>`_ or post in our
        `discussion forum <https://discuss.pennylane.ai/>`_.

    Args:
        wires (int, Iterable[Number, str]): Number of subsystems represented by the device,
            or iterable that contains unique labels for the subsystems as numbers
            (i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``). This argument
            is required.
        r_dtype: Real dtype forwarded to the parent device. Defaults to ``np.float64``.
        c_dtype: Complex dtype forwarded to the parent device. Defaults to
            ``np.complex128``.
        shots (None, int): How many times the circuit should be evaluated (or sampled) to
            estimate the expectation values. Defaults to ``None`` if not specified, which
            means that the device returns analytical results.
        analytic: Forwarded unchanged to the parent device constructor.
    """

    name = "Default qutrit PennyLane plugin"
    short_name = "default.qutrit"
    pennylane_requires = __version__
    version = __version__
    author = "Mudit Pandey, UBC Quantum Software and Algorithms Research Group, and Xanadu"

    # TODO: Update list of operations and observables once more are added
    operations = {
        "Identity",
        "QutritUnitary",
        "ControlledQutritUnitary",
        "TShift",
        "Adjoint(TShift)",
        "TClock",
        "Adjoint(TClock)",
        "TAdd",
        "Adjoint(TAdd)",
        "TSWAP",
        "THadamard",
        "Adjoint(THadamard)",
        "TRX",
        "TRY",
        "TRZ",
        "QutritBasisState",
    }

    # Identity is supported as an observable for qml.state() to work correctly. However, any
    # measurement types that rely on eigenvalue decomposition will not work with qml.Identity
    observables = {"THermitian", "GellMann", "Identity", "Prod"}

    # Static methods to use qml.math to allow for backprop differentiation
    _reshape = staticmethod(qml.math.reshape)
    _flatten = staticmethod(qml.math.flatten)
    _transpose = staticmethod(qml.math.transpose)
    _dot = staticmethod(qml.math.dot)
    _stack = staticmethod(qml.math.stack)
    _conj = staticmethod(qml.math.conj)
    _roll = staticmethod(qml.math.roll)
    _cast = staticmethod(qml.math.cast)
    _tensordot = staticmethod(qml.math.tensordot)
    _real = staticmethod(qml.math.real)
    _imag = staticmethod(qml.math.imag)

    @staticmethod
    def _reduce_sum(array, axes):
        """Sum ``array`` over the given ``axes`` using ``qml.math.sum``."""
        return qml.math.sum(array, tuple(axes))

    @staticmethod
    def _asarray(array, dtype=None):
        """Convert ``array`` to the requested ``dtype``.

        Objects without ``__len__`` (e.g. plain Python floats) go through
        ``np.asarray``; everything else is cast with ``qml.math.cast``.
        """
        # Support float
        if not hasattr(array, "__len__"):
            return np.asarray(array, dtype=dtype)

        return qml.math.cast(array, dtype=dtype)

    @debug_logger_init
    def __init__(
        self,
        wires,
        *,
        r_dtype=np.float64,
        c_dtype=np.complex128,
        shots=None,
        analytic=None,
    ):
        super().__init__(wires, shots, r_dtype=r_dtype, c_dtype=c_dtype, analytic=analytic)
        # TODO: add support for snapshots
        # self._debugger = None

        # Create the initial state. Internally, we store the
        # state as an array of dimension [3]*wires.
        self._state = self._create_basis_state(0)
        self._pre_rotated_state = self._state

        # TODO: Add operations
        self._apply_ops = {
            # All operations that can be applied on the `default.qutrit` device by directly
            # manipulating the internal state array will be included in this dictionary
            "TShift": self._apply_tshift,
            "TClock": self._apply_tclock,
            "TAdd": self._apply_tadd,
            "TSWAP": self._apply_tswap,
        }

    # NOTE(review): ``lru_cache`` on an instance method includes ``self`` in the cache key,
    # so cached entries keep device instances alive. Kept as-is to preserve behaviour.
    @functools.lru_cache()
    def map_wires(self, wires):
        """Translate wire labels into the consecutive integer indices used by the device.

        Args:
            wires (Wires): wires whose labels should be looked up in ``self.wire_map``

        Returns:
            list[int]: device indices, in the same order as ``wires``

        Raises:
            WireError: if any label is missing from ``self.wire_map``
        """
        # temporarily overwrite this method to bypass
        # wire map that produces Wires objects
        try:
            mapped_wires = [self.wire_map[w] for w in wires]
        except KeyError as e:
            raise WireError(
                f"Did not find some of the wires {wires.labels} on device with wires {self.wires.labels}."
            ) from e

        return mapped_wires

    def define_wire_map(self, wires):
        """Build the mapping from wire labels to consecutive integers ``0..num_wires-1``.

        Args:
            wires (Iterable): wire labels, in device order

        Returns:
            dict: ``{label: index}``
        """
        # temporarily overwrite this method to bypass
        # wire map that produces Wires objects
        consecutive_wires = range(self.num_wires)
        return dict(zip(wires, consecutive_wires))

    @debug_logger
    def apply(self, operations, rotations=None, **kwargs):  # pylint: disable=arguments-differ
        """Apply the circuit operations, then the rotations, to the stored state.

        The state reached after ``operations`` (and before ``rotations``) is kept in
        ``self._pre_rotated_state``.

        Args:
            operations (Iterable): operations to apply, in order
            rotations (Iterable, optional): operations applied after the pre-rotated
                state has been stored. ``None`` is treated as an empty list.

        Raises:
            DeviceError: if a ``QutritBasisState`` appears anywhere but first
        """
        rotations = rotations or []

        # apply the circuit operations

        # Operations are enumerated so that the order of operations can eventually be used
        # for correctly applying basis state / state vector / snapshot operations which will
        # be added later.
        for i, operation in enumerate(operations):
            if isinstance(operation, qml.QutritBasisState):
                # State preparation is only allowed as the very first operation.
                if i > 0:
                    raise qml.DeviceError(
                        f"Operation {operation.name} cannot be used after other operations have already been applied "
                        f"on a {self.short_name} device."
                    )
                self._apply_basis_state(operation.parameters[0], operation.wires)
            else:
                self._state = self._apply_operation(self._state, operation)

        # store the pre-rotated state
        self._pre_rotated_state = self._state

        # apply the circuit rotations
        for operation in rotations:
            self._state = self._apply_operation(self._state, operation)

    def _apply_basis_state(self, state, wires):
        """Initialize the state vector in a specified computational basis state.

        Args:
            state (array[int]): computational basis state of shape ``(wires,)``
                consisting of 0s, 1s and 2s.
            wires (Wires): wires that the provided computational state should be initialized on

        Raises:
            ValueError: if ``state`` contains values other than 0, 1 or 2, or if its
                length differs from the number of ``wires``

        Note: This function does not support broadcasted inputs yet.
        """
        # translate to wire labels used by device
        device_wires = self.map_wires(wires)

        # length of basis state parameter
        n_basis_state = len(state)

        if not set(state.tolist()).issubset({0, 1, 2}):
            raise ValueError("QutritBasisState parameter must consist of 0, 1 or 2 integers.")

        if n_basis_state != len(device_wires):
            raise ValueError("QutritBasisState parameter and wires must be of equal length.")

        # get computational basis state number: wire k carries the base-3 weight
        # 3 ** (num_wires - 1 - k)
        basis_states = 3 ** (self.num_wires - 1 - np.array(device_wires))
        basis_states = qml.math.convert_like(basis_states, state)
        num = int(qml.math.dot(state, basis_states))

        self._state = self._create_basis_state(num)

    def _apply_operation(self, state, operation):
        """Applies operations to the input state.

        Operations listed in ``self._apply_ops`` (or adjoints of them) are applied by
        direct manipulation of the state array; everything else goes through the
        operation's matrix.

        Args:
            state (array[complex]): input state
            operation (~.Operation): operation to apply on the device

        Returns:
            array[complex]: output state
        """
        name = operation.name
        if name == "Identity":
            return state

        wires = operation.wires

        if name in self._apply_ops:
            return self._apply_ops[name](state, self.wires.indices(wires))

        if (
            isinstance(operation, qml.ops.Adjoint)  # pylint: disable=no-member
            and operation.base.name in self._apply_ops
        ):
            axes = self.wires.indices(wires)
            return self._apply_ops[operation.base.name](state, axes, inverse=True)

        matrix = self._asarray(self._get_unitary_matrix(operation), dtype=self.C_DTYPE)
        return self._apply_unitary(state, matrix, wires)

    def _apply_tshift(self, state, axes, inverse=False):
        r"""Applies a ternary Shift gate by rolling 1 unit along the axis specified in ``axes``.

        Rolling by 1 unit along the axis means that the :math:`|0 \rangle` state with index ``0``
        is shifted to the :math:`|1 \rangle` state with index ``1``. Likewise, since rolling
        beyond the last index loops back to the first, :math:`|2 \rangle` is transformed to
        :math:`|0 \rangle`.

        Args:
            state (array[complex]): input state
            axes (List[int]): target axes to apply transformation
            inverse (bool): whether to apply the inverse operation

        Returns:
            array[complex]: output state
        """
        shift = -1 if inverse else 1
        return self._roll(state, shift, axes[0])

    def _apply_tclock(self, state, axes, inverse=False):
        """Applies a ternary Clock gate by adding appropriate phases to the 1 and 2 indices
        along the axis specified in ``axes``

        Args:
            state (array[complex]): input state
            axes (List[int]): target axes to apply transformation
            inverse (bool): whether to apply the inverse operation

        Returns:
            array[complex]: output state
        """
        partial_state = self._apply_phase(state, axes, 1, OMEGA, inverse)
        return self._apply_phase(partial_state, axes, 2, OMEGA**2, inverse)

    def _apply_tadd(self, state, axes, inverse=False):
        r"""Applies a controlled ternary add gate by slicing along the first axis specified in
        ``axes`` and applying a TShift transformation along the second axis.

        The ternary add gate acts on the computational basis states like
        :math:`\text{TAdd}\vert i, j\rangle \rightarrow \vert i, i+j \rangle`, where addition
        is taken modulo 3.

        By slicing along the first axis, we are able to select all of the amplitudes with
        corresponding :math:`|1\rangle` and :math:`|2\rangle` for the control qutrit. This
        means we just need to apply a :class:`~.TShift` gate when slicing along index 1, and
        a :class:`~.TShift` adjoint gate when slicing along index 2

        Args:
            state (array[complex]): input state
            axes (List[int]): target axes to apply transformation
            inverse (bool): whether to apply the inverse operation

        Returns:
            array[complex]: output state
        """
        # Size the slices from the state itself, as ``_apply_phase`` and ``_apply_tswap`` do.
        num_axes = len(state.shape)
        slices = [_get_slice(i, axes[0], num_axes) for i in range(3)]

        # We will be slicing into the state according to state[slices[1]] and state[slices[2]],
        # giving us all of the amplitudes with a |1> and |2> for the control qutrit. The resulting
        # array has lost an axis relative to state and we need to be careful about the axis we
        # roll. If axes[1] is larger than axes[0], then we need to shift the target axis down by
        # one, otherwise we can leave as-is. For example: a state has [0, 1, 2, 3], control=1,
        # target=3. Then, state[slices[1]] has 3 axes and target=3 now corresponds to the second axis.
        target_axes = [axes[1] - 1] if axes[1] > axes[0] else [axes[1]]

        state_1 = self._apply_tshift(state[slices[1]], axes=target_axes, inverse=inverse)
        state_2 = self._apply_tshift(state[slices[2]], axes=target_axes, inverse=not inverse)
        return self._stack([state[slices[0]], state_1, state_2], axis=axes[0])

    def _apply_tswap(self, state, axes, **kwargs):  # pylint: disable=unused-argument
        r"""Applies a ternary SWAP gate by performing a partial transposition along the
        specified axes.

        The ternary SWAP gate acts on the computational basis states like
        :math:`\vert i, j\rangle \rightarrow \vert j, i \rangle`.

        Args:
            state (array[complex]): input state
            axes (List[int]): target axes to apply transformation

        Returns:
            array[complex]: output state
        """
        all_axes = list(range(len(state.shape)))
        all_axes[axes[0]] = axes[1]
        all_axes[axes[1]] = axes[0]
        return self._transpose(state, all_axes)

    def _apply_phase(
        self, state, axes, index, phase, inverse=False
    ):  # pylint: disable=too-many-arguments
        """Applies a phase onto the specified index along the axis specified in ``axes``.

        Args:
            state (array[complex]): input state
            axes (List[int]): target axes to apply transformation
            index (int): target index of axis to apply phase to
            phase (complex): phase to apply
            inverse (bool): whether to apply the inverse phase

        Returns:
            array[complex]: output state
        """
        num_wires = len(state.shape)
        slices = [_get_slice(i, axes[0], num_wires) for i in range(3)]

        # The inverse of a phase is its complex conjugate.
        phase = self._conj(phase) if inverse else phase
        state_slices = [
            self._const_mul(phase if i == index else 1, state[slices[i]]) for i in range(3)
        ]
        return self._stack(state_slices, axis=axes[0])

    def _get_unitary_matrix(self, unitary):  # pylint: disable=no-self-use
        """Return the matrix representing a unitary operation.

        Args:
            unitary (~.Operation): a PennyLane unitary operation

        Returns:
            array[complex]: Returns a 2D matrix representation of
            the unitary in the computational basis.
        """
        return unitary.matrix()

    @classmethod
    def capabilities(cls):
        """Return a copy of the parent capabilities updated with this device's entries."""
        capabilities = super().capabilities().copy()
        capabilities.update(
            model="qutrit",
            supports_inverse_operations=True,
            supports_analytic_computation=True,
            returns_state=True,
            passthru_devices={
                "autograd": "default.qutrit",
                "tf": "default.qutrit",
                "torch": "default.qutrit",
                "jax": "default.qutrit",
            },
        )
        return capabilities

    def _create_basis_state(self, index):
        """Return a computational basis state over all wires.

        Args:
            index (int): integer representing the computational basis state

        Returns:
            array[complex]: complex array of shape ``[3]*self.num_wires``
            representing the statevector of the basis state
        """
        state = np.zeros(3**self.num_wires, dtype=np.complex128)
        state[index] = 1
        state = self._asarray(state, dtype=self.C_DTYPE)
        return self._reshape(state, [3] * self.num_wires)

    @property
    def state(self):
        """Flattened view of the state stored before any rotations were applied."""
        return self._flatten(self._pre_rotated_state)

    @debug_logger
    def density_matrix(self, wires):
        """Returns the reduced density matrix of a given set of wires.

        Args:
            wires (Wires): wires of the reduced system.

        Returns:
            array[complex]: complex tensor of shape ``(3 ** len(wires), 3 ** len(wires))``
            representing the reduced density matrix.

        Raises:
            WireError: if some of ``wires`` are not present on the device (raised by
                :meth:`map_wires`)
        """
        state = self._pre_rotated_state
        out_shape = (3 ** len(wires), 3 ** len(wires))

        # Return the full density matrix by using numpy tensor product
        if wires == self.wires:
            density_matrix = self._tensordot(state, self._conj(state), axes=0)
            return self._reshape(density_matrix, out_shape)

        # The state axes are indexed by consecutive device indices, not by user-facing
        # labels, so translate the labels before deciding which axes to trace out.
        kept_axes = set(self.map_wires(wires))
        traced_system = [axis for axis in range(self.num_wires) if axis not in kept_axes]

        # Return the reduced density matrix by using numpy tensor product
        density_matrix = self._tensordot(
            state, self._conj(state), axes=(traced_system, traced_system)
        )
        return self._reshape(density_matrix, out_shape)

    def _apply_unitary(self, state, mat, wires):
        r"""Apply multiplication of a matrix to subsystems of the quantum state.

        Args:
            state (array[complex]): input state
            mat (array): matrix to multiply
            wires (Wires): target wires

        Returns:
            array[complex]: output state
        """
        # translate to wire labels used by device
        device_wires = self.map_wires(wires)

        mat = self._cast(self._reshape(mat, [3] * len(device_wires) * 2), dtype=self.C_DTYPE)
        axes = (list(range(len(device_wires), 2 * len(device_wires))), device_wires)
        tdot = self._tensordot(mat, state, axes=axes)

        # tensordot causes the axes given in `wires` to end up in the first positions
        # of the resulting tensor. This corresponds to a (partial) transpose of
        # the correct output state
        # We'll need to invert this permutation to put the indices in the correct place
        unused_idxs = [idx for idx in range(self.num_wires) if idx not in device_wires]
        perm = list(device_wires) + unused_idxs
        inv_perm = np.argsort(perm)  # argsort gives inverse permutation
        return self._transpose(tdot, inv_perm)

    @debug_logger
    def reset(self):
        """Reset the device"""
        super().reset()

        # init the state vector to |00..0>
        self._state = self._create_basis_state(0)
        self._pre_rotated_state = self._state

    @debug_logger
    def analytic_probability(self, wires=None):
        """Return marginal probabilities computed from the current state.

        Args:
            wires: passed through to ``self.marginal_prob``

        Returns:
            The result of ``self.marginal_prob`` applied to the squared magnitudes of the
            flattened state, or ``None`` if no state is stored.
        """
        if self._state is None:
            return None

        flat_state = self._flatten(self._state)
        real_state = self._real(flat_state)
        imag_state = self._imag(flat_state)
        # |amplitude|^2 = Re^2 + Im^2
        return self.marginal_prob(real_state**2 + imag_state**2, wires)