Program Listing for File StateVectorLQubit.hpp

Return to documentation for file (pennylane_lightning/core/simulators/lightning_qubit/StateVectorLQubit.hpp)

// Copyright 2018-2023 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "CPUMemoryModel.hpp"
#include "GateOperation.hpp"
#include "KernelMap.hpp"
#include "KernelType.hpp"
#include "StateVectorBase.hpp"
#include "Threading.hpp"
#include "cpu_kernels/GateImplementationsLM.hpp"

// Short aliases for utility types/functions and the gate-kernel namespace,
// used unqualified throughout this header.
// NOTE(review): an unnamed namespace in a header gives every including
// translation unit its own copy of these declarations, and (via the implicit
// using-directive for unnamed namespaces) makes the names visible
// unqualified at global scope in those TUs. Consider moving them inside
// `Pennylane::LightningQubit` once no dependent relies on the global
// visibility — TODO confirm against other headers/sources.
namespace {
using Pennylane::LightningQubit::Util::Threading;
using Pennylane::Util::CPUMemoryModel;
using Pennylane::Util::exp2;
using Pennylane::Util::squaredNorm;
using namespace Pennylane::LightningQubit::Gates;
} // namespace

namespace Pennylane::LightningQubit {
/**
 * @brief CRTP base for Lightning-Qubit state vectors.
 *
 * Holds the threading strategy and memory model chosen at construction and,
 * for that configuration, caches which kernel implements each gate,
 * generator, matrix and sparse-matrix operation (plain and controlled).
 * All `apply*` methods forward to `DynamicDispatcher` with the cached kernel.
 *
 * @tparam PrecisionT Floating point precision of the amplitudes.
 * @tparam Derived Concrete state-vector class (CRTP).
 */
template <class PrecisionT, class Derived>
class StateVectorLQubit : public StateVectorBase<PrecisionT, Derived> {
  public:
    using ComplexT = std::complex<PrecisionT>;
    using MemoryStorageT = Pennylane::Util::MemoryStorageLocation::Undefined;

  protected:
    /// Threading strategy this state vector was constructed with.
    const Threading threading_;
    /// Memory model this state vector was constructed with.
    const CPUMemoryModel memory_model_;

  private:
    using BaseType = StateVectorBase<PrecisionT, Derived>;
    using GateKernelMap = std::unordered_map<GateOperation, KernelType>;
    using GeneratorKernelMap =
        std::unordered_map<GeneratorOperation, KernelType>;
    using MatrixKernelMap = std::unordered_map<MatrixOperation, KernelType>;
    using SparseMatrixKernelMap =
        std::unordered_map<SparseMatrixOperation, KernelType>;
    using ControlledGateKernelMap =
        std::unordered_map<ControlledGateOperation, KernelType>;
    using ControlledGeneratorKernelMap =
        std::unordered_map<ControlledGeneratorOperation, KernelType>;
    using ControlledMatrixKernelMap =
        std::unordered_map<ControlledMatrixOperation, KernelType>;
    using ControlledSparseMatrixKernelMap =
        std::unordered_map<ControlledSparseMatrixOperation, KernelType>;

    GateKernelMap kernel_for_gates_;
    GeneratorKernelMap kernel_for_generators_;
    MatrixKernelMap kernel_for_matrices_;
    SparseMatrixKernelMap kernel_for_sparse_matrices_;
    ControlledGateKernelMap kernel_for_controlled_gates_;
    ControlledGeneratorKernelMap kernel_for_controlled_generators_;
    ControlledMatrixKernelMap kernel_for_controlled_matrices_;
    ControlledSparseMatrixKernelMap kernel_for_controlled_sparse_matrices_;

    /**
     * @brief Fill every operation -> kernel table from the corresponding
     * `OperationKernelMap` singleton for the given configuration.
     */
    void setKernels(std::size_t num_qubits, Threading threading,
                    CPUMemoryModel memory_model) {
        using KernelMap::OperationKernelMap;
        kernel_for_gates_ =
            OperationKernelMap<GateOperation>::getInstance().getKernelMap(
                num_qubits, threading, memory_model);
        kernel_for_generators_ =
            OperationKernelMap<GeneratorOperation>::getInstance().getKernelMap(
                num_qubits, threading, memory_model);
        kernel_for_matrices_ =
            OperationKernelMap<MatrixOperation>::getInstance().getKernelMap(
                num_qubits, threading, memory_model);
        kernel_for_sparse_matrices_ =
            OperationKernelMap<SparseMatrixOperation>::getInstance()
                .getKernelMap(num_qubits, threading, memory_model);
        kernel_for_controlled_gates_ =
            OperationKernelMap<ControlledGateOperation>::getInstance()
                .getKernelMap(num_qubits, threading, memory_model);
        kernel_for_controlled_generators_ =
            OperationKernelMap<ControlledGeneratorOperation>::getInstance()
                .getKernelMap(num_qubits, threading, memory_model);
        kernel_for_controlled_matrices_ =
            OperationKernelMap<ControlledMatrixOperation>::getInstance()
                .getKernelMap(num_qubits, threading, memory_model);
        kernel_for_controlled_sparse_matrices_ =
            OperationKernelMap<ControlledSparseMatrixOperation>::getInstance()
                .getKernelMap(num_qubits, threading, memory_model);
    }

    // Kernel lookups. Each uses `unordered_map::at`, which throws
    // `std::out_of_range` when no kernel is registered for the operation.

    [[nodiscard]] inline auto getKernelForGate(GateOperation gate_op) const
        -> KernelType {
        return kernel_for_gates_.at(gate_op);
    }

    [[nodiscard]] inline auto
    getKernelForControlledGate(ControlledGateOperation gate_op) const
        -> KernelType {
        return kernel_for_controlled_gates_.at(gate_op);
    }

    [[nodiscard]] inline auto
    getKernelForGenerator(GeneratorOperation gen_op) const -> KernelType {
        return kernel_for_generators_.at(gen_op);
    }

    [[nodiscard]] inline auto
    getKernelForControlledGenerator(ControlledGeneratorOperation gen_op) const
        -> KernelType {
        return kernel_for_controlled_generators_.at(gen_op);
    }

    [[nodiscard]] inline auto getKernelForMatrix(MatrixOperation mat_op) const
        -> KernelType {
        return kernel_for_matrices_.at(mat_op);
    }

    [[nodiscard]] inline auto
    getKernelForControlledMatrix(ControlledMatrixOperation mat_op) const
        -> KernelType {
        return kernel_for_controlled_matrices_.at(mat_op);
    }

    [[nodiscard]] inline auto
    getKernelForSparseMatrix(SparseMatrixOperation mat_op) const -> KernelType {
        return kernel_for_sparse_matrices_.at(mat_op);
    }

    [[nodiscard]] inline auto getKernelForControlledSparseMatrix(
        ControlledSparseMatrixOperation mat_op) const -> KernelType {
        return kernel_for_controlled_sparse_matrices_.at(mat_op);
    }

    // Kernel-table accessors: the `const &` overloads hand out references,
    // the `&&` overloads move the table out of an expiring object.

    [[nodiscard]] inline auto
    getGateKernelMap() const & -> const GateKernelMap & {
        return kernel_for_gates_;
    }

    [[nodiscard]] inline auto getGateKernelMap() && -> GateKernelMap {
        return std::move(kernel_for_gates_);
    }

    [[nodiscard]] inline auto
    getControlledGateKernelMap() const & -> const ControlledGateKernelMap & {
        return kernel_for_controlled_gates_;
    }

    [[nodiscard]] inline auto
    getControlledGateKernelMap() && -> ControlledGateKernelMap {
        return std::move(kernel_for_controlled_gates_);
    }

    [[nodiscard]] inline auto
    getGeneratorKernelMap() const & -> const GeneratorKernelMap & {
        return kernel_for_generators_;
    }

    [[nodiscard]] inline auto getGeneratorKernelMap() && -> GeneratorKernelMap {
        return std::move(kernel_for_generators_);
    }

    [[nodiscard]] inline auto getControlledGeneratorKernelMap() const & -> const
        ControlledGeneratorKernelMap & {
        return kernel_for_controlled_generators_;
    }

    [[nodiscard]] inline auto
    getControlledGeneratorKernelMap() && -> ControlledGeneratorKernelMap {
        return std::move(kernel_for_controlled_generators_);
    }

    [[nodiscard]] inline auto
    getMatrixKernelMap() const & -> const MatrixKernelMap & {
        return kernel_for_matrices_;
    }

    [[nodiscard]] inline auto getMatrixKernelMap() && -> MatrixKernelMap {
        return std::move(kernel_for_matrices_);
    }

    [[nodiscard]] inline auto getControlledMatrixKernelMap() const & -> const
        ControlledMatrixKernelMap & {
        return kernel_for_controlled_matrices_;
    }

    [[nodiscard]] inline auto
    getControlledMatrixKernelMap() && -> ControlledMatrixKernelMap {
        return std::move(kernel_for_controlled_matrices_);
    }

  protected:
    /**
     * @brief Store the configuration and build the kernel tables for it.
     *
     * @param num_qubits Number of qubits of the state vector.
     * @param threading Threading strategy used for kernel selection.
     * @param memory_model Memory model used for kernel selection.
     */
    explicit StateVectorLQubit(std::size_t num_qubits, Threading threading,
                               CPUMemoryModel memory_model)
        : BaseType(num_qubits), threading_{threading},
          memory_model_{memory_model} {
        setKernels(num_qubits, threading, memory_model);
    }

  public:
    /// Memory model this state vector was constructed with.
    [[nodiscard]] inline CPUMemoryModel memoryModel() const {
        return memory_model_;
    }

    /// Threading strategy this state vector was constructed with.
    [[nodiscard]] inline Threading threading() const { return threading_; }

    /**
     * @brief References to the gate, generator, matrix and their controlled
     * kernel tables (sparse tables are not included).
     */
    [[nodiscard]] auto getSupportedKernels() const & -> std::tuple<
        const GateKernelMap &, const GeneratorKernelMap &,
        const MatrixKernelMap &, const ControlledGateKernelMap &,
        const ControlledGeneratorKernelMap &,
        const ControlledMatrixKernelMap &> {
        return {
            getGateKernelMap(),
            getGeneratorKernelMap(),
            getMatrixKernelMap(),
            getControlledGateKernelMap(),
            getControlledGeneratorKernelMap(),
            getControlledMatrixKernelMap(),
        };
    }

    /**
     * @brief Move the same six kernel tables out of an expiring object.
     *
     * Returned by value: handing out rvalue references into `*this` would
     * dangle once the temporary is destroyed. (Inside this function `*this`
     * is an lvalue, so the previous `return {getGateKernelMap(), ...}` picked
     * the `const &` getters and could not bind to `&&` tuple elements.)
     */
    [[nodiscard]] auto getSupportedKernels() && -> std::tuple<
        GateKernelMap, GeneratorKernelMap, MatrixKernelMap,
        ControlledGateKernelMap, ControlledGeneratorKernelMap,
        ControlledMatrixKernelMap> {
        return {
            std::move(kernel_for_gates_),
            std::move(kernel_for_generators_),
            std::move(kernel_for_matrices_),
            std::move(kernel_for_controlled_gates_),
            std::move(kernel_for_controlled_generators_),
            std::move(kernel_for_controlled_matrices_),
        };
    }

    /**
     * @brief Apply the named operation with an explicitly chosen kernel.
     *
     * @param kernel Kernel to dispatch to.
     * @param opName Operation name understood by `DynamicDispatcher`.
     * @param wires Target wires.
     * @param inverse Apply the adjoint if true.
     * @param params Gate parameters.
     */
    void applyOperation(Pennylane::Gates::KernelType kernel,
                        const std::string &opName,
                        const std::vector<std::size_t> &wires,
                        bool inverse = false,
                        const std::vector<PrecisionT> &params = {}) {
        auto *arr = BaseType::getData();
        DynamicDispatcher<PrecisionT>::getInstance().applyOperation(
            kernel, arr, BaseType::getNumQubits(), opName, wires, inverse,
            params);
    }

    /**
     * @brief Apply the named gate using the cached kernel for that gate.
     *
     * "Identity" returns immediately without touching the data.
     */
    void applyOperation(const std::string &opName,
                        const std::vector<std::size_t> &wires,
                        bool inverse = false,
                        const std::vector<PrecisionT> &params = {}) {
        if (opName == "Identity") {
            return;
        }
        auto *arr = BaseType::getData();
        auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        const auto gate_op = dispatcher.strToGateOp(opName);
        dispatcher.applyOperation(getKernelForGate(gate_op), arr,
                                  BaseType::getNumQubits(), gate_op, wires,
                                  inverse, params);
    }

    /**
     * @brief Apply the named controlled gate using its cached kernel.
     *
     * Aborts if `controlled_wires` and `wires` overlap or if
     * `controlled_values` does not have one entry per control wire.
     */
    void applyOperation(const std::string &opName,
                        const std::vector<std::size_t> &controlled_wires,
                        const std::vector<bool> &controlled_values,
                        const std::vector<std::size_t> &wires,
                        bool inverse = false,
                        const std::vector<PrecisionT> &params = {}) {
        PL_ABORT_IF_NOT(
            areVecsDisjoint<std::size_t>(controlled_wires, wires),
            "`controlled_wires` and `target wires` must be disjoint.");

        PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(),
                        "`controlled_wires` must have the same size as "
                        "`controlled_values`.");
        auto *arr = BaseType::getData();
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        const auto gate_op = dispatcher.strToControlledGateOp(opName);
        const auto kernel = getKernelForControlledGate(gate_op);
        dispatcher.applyControlledGate(
            kernel, arr, BaseType::getNumQubits(), opName, controlled_wires,
            controlled_values, wires, inverse, params);
    }

    /**
     * @brief Apply `opName` as a named gate if the dispatcher knows it,
     * otherwise apply `matrix` to `wires` (params are then unused).
     */
    template <typename Alloc>
    void applyOperation(const std::string &opName,
                        const std::vector<std::size_t> &wires, bool inverse,
                        const std::vector<PrecisionT> &params,
                        const std::vector<ComplexT, Alloc> &matrix) {
        auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        if (dispatcher.hasGateOp(opName)) {
            applyOperation(opName, wires, inverse, params);
        } else {
            applyMatrix(matrix, wires, inverse);
        }
    }

    /**
     * @brief Controlled variant of the named-gate-or-matrix overload.
     *
     * With non-empty `controlled_wires` the named controlled gate is applied
     * and `matrix` is ignored.
     * NOTE(review): a custom controlled matrix whose name is not a known
     * controlled gate cannot be applied through this path — confirm intent.
     */
    template <typename Alloc>
    void applyOperation(const std::string &opName,
                        const std::vector<std::size_t> &controlled_wires,
                        const std::vector<bool> &controlled_values,
                        const std::vector<std::size_t> &wires, bool inverse,
                        const std::vector<PrecisionT> &params,
                        const std::vector<ComplexT, Alloc> &matrix) {
        PL_ABORT_IF_NOT(
            areVecsDisjoint<std::size_t>(controlled_wires, wires),
            "`controlled_wires` and `target wires` must be disjoint.");

        PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(),
                        "`controlled_wires` must have the same size as "
                        "`controlled_values`.");
        if (!controlled_wires.empty()) {
            applyOperation(opName, controlled_wires, controlled_values, wires,
                           inverse, params);
            return;
        }
        auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        if (dispatcher.hasGateOp(opName)) {
            applyOperation(opName, wires, inverse, params);
        } else {
            applyMatrix(matrix, wires, inverse);
        }
    }

    /**
     * @brief Apply a Pauli-word rotation with angle `params[0]`.
     *
     * Aborts if `word` and `wires` differ in length or `params` is empty.
     */
    void applyPauliRot(const std::vector<std::size_t> &wires,
                       const bool inverse,
                       const std::vector<PrecisionT> &params,
                       const std::string &word) {
        PL_ABORT_IF_NOT(wires.size() == word.size(),
                        "wires and word have incompatible dimensions.");
        // params[0] is read unconditionally below.
        PL_ABORT_IF(params.empty(),
                    "applyPauliRot requires a rotation angle in params.");
        GateImplementationsLM::applyPauliRot<PrecisionT>(
            BaseType::getData(), BaseType::getNumQubits(), wires, inverse,
            params[0], word);
    }

    /**
     * @brief Apply the named generator with an explicit kernel.
     * @return Scaling factor returned by the dispatcher.
     */
    [[nodiscard]] inline auto applyGenerator(
        Pennylane::Gates::KernelType kernel, const std::string &opName,
        const std::vector<std::size_t> &wires, bool adj = false) -> PrecisionT {
        auto *arr = BaseType::getData();
        return DynamicDispatcher<PrecisionT>::getInstance().applyGenerator(
            kernel, arr, BaseType::getNumQubits(), opName, wires, adj);
    }

    /**
     * @brief Apply the named generator with its cached kernel.
     * @return Scaling factor returned by the dispatcher.
     */
    [[nodiscard]] auto applyGenerator(const std::string &opName,
                                      const std::vector<std::size_t> &wires,
                                      bool adj = false) -> PrecisionT {
        auto *arr = BaseType::getData();
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        const auto gen_op = dispatcher.strToGeneratorOp(opName);
        return dispatcher.applyGenerator(getKernelForGenerator(gen_op), arr,
                                         BaseType::getNumQubits(), opName,
                                         wires, adj);
    }

    /**
     * @brief Apply the named controlled generator with its cached kernel.
     * @return Scaling factor returned by the dispatcher.
     */
    [[nodiscard]] auto
    applyGenerator(const std::string &opName,
                   const std::vector<std::size_t> &controlled_wires,
                   const std::vector<bool> &controlled_values,
                   const std::vector<std::size_t> &wires, bool adj = false)
        -> PrecisionT {
        auto *arr = BaseType::getData();
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        const auto generator_op = dispatcher.strToControlledGeneratorOp(opName);
        const auto kernel = getKernelForControlledGenerator(generator_op);
        return dispatcher.applyControlledGenerator(
            kernel, arr, BaseType::getNumQubits(), opName, controlled_wires,
            controlled_values, wires, adj);
    }

    /**
     * @brief Apply a dense matrix on `wires`, conditioned on the control
     * wires taking `controlled_values`.
     *
     * The kernel is picked by target count (1, 2, or more wires). Aborts if
     * `wires` is empty, if the control wires and values differ in length, or
     * if the control and target wires overlap.
     *
     * @param matrix Pointer to the row-major target-wire matrix.
     */
    inline void
    applyControlledMatrix(const ComplexT *matrix,
                          const std::vector<std::size_t> &controlled_wires,
                          const std::vector<bool> &controlled_values,
                          const std::vector<std::size_t> &wires,
                          bool inverse = false) {
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        auto *arr = BaseType::getData();
        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");
        PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(),
                        "`controlled_wires` must have the same size as "
                        "`controlled_values`.");
        PL_ABORT_IF_NOT(
            areVecsDisjoint<std::size_t>(controlled_wires, wires),
            "`controlled_wires` and `target wires` must be disjoint.");
        const auto kernel = [n_wires = wires.size(), this]() {
            switch (n_wires) {
            case 1:
                return getKernelForControlledMatrix(
                    ControlledMatrixOperation::NCSingleQubitOp);
            case 2:
                return getKernelForControlledMatrix(
                    ControlledMatrixOperation::NCTwoQubitOp);
            default:
                return getKernelForControlledMatrix(
                    ControlledMatrixOperation::NCMultiQubitOp);
            }
        }();
        dispatcher.applyControlledMatrix(kernel, arr, BaseType::getNumQubits(),
                                         matrix, controlled_wires,
                                         controlled_values, wires, inverse);
    }

    /**
     * @brief Vector overload of `applyControlledMatrix`.
     *
     * Takes the matrix by const reference (no copy) and checks that it is
     * 2^n x 2^n for n target wires, as `applyMatrix` does.
     */
    inline void
    applyControlledMatrix(const std::vector<ComplexT> &matrix,
                          const std::vector<std::size_t> &controlled_wires,
                          const std::vector<bool> &controlled_values,
                          const std::vector<std::size_t> &wires,
                          bool inverse = false) {
        PL_ABORT_IF(matrix.size() != exp2(2 * wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");
        applyControlledMatrix(matrix.data(), controlled_wires,
                              controlled_values, wires, inverse);
    }

    /**
     * @brief Apply a dense matrix on `wires` with an explicit kernel.
     * Aborts if `wires` is empty.
     */
    inline void applyMatrix(Pennylane::Gates::KernelType kernel,
                            const ComplexT *matrix,
                            const std::vector<std::size_t> &wires,
                            bool inverse = false) {
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        auto *arr = BaseType::getData();

        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");

        dispatcher.applyMatrix(kernel, arr, BaseType::getNumQubits(), matrix,
                               wires, inverse);
    }

    /**
     * @brief Vector overload with an explicit kernel; checks the matrix is
     * 2^n x 2^n for n wires.
     */
    inline void applyMatrix(Pennylane::Gates::KernelType kernel,
                            const std::vector<ComplexT> &matrix,
                            const std::vector<std::size_t> &wires,
                            bool inverse = false) {
        PL_ABORT_IF(matrix.size() != exp2(2 * wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");

        applyMatrix(kernel, matrix.data(), wires, inverse);
    }

    /**
     * @brief Apply a dense matrix on `wires`, choosing the cached kernel by
     * wire count (1, 2, or more).
     */
    inline void applyMatrix(const ComplexT *matrix,
                            const std::vector<std::size_t> &wires,
                            bool inverse = false) {
        using Pennylane::Gates::MatrixOperation;

        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");

        const auto kernel = [n_wires = wires.size(), this]() {
            switch (n_wires) {
            case 1:
                return getKernelForMatrix(MatrixOperation::SingleQubitOp);
            case 2:
                return getKernelForMatrix(MatrixOperation::TwoQubitOp);
            default:
                return getKernelForMatrix(MatrixOperation::MultiQubitOp);
            }
        }();
        applyMatrix(kernel, matrix, wires, inverse);
    }

    /**
     * @brief Vector overload (any allocator); checks the matrix is
     * 2^n x 2^n for n wires.
     */
    template <typename Alloc>
    inline void applyMatrix(const std::vector<ComplexT, Alloc> &matrix,
                            const std::vector<std::size_t> &wires,
                            bool inverse = false) {
        PL_ABORT_IF(matrix.size() != exp2(2 * wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");

        applyMatrix(matrix.data(), wires, inverse);
    }

    /**
     * @brief Apply a CSR sparse matrix on `wires`, conditioned on the
     * control wires taking `controlled_values`.
     *
     * Aborts if `wires` is empty, if the control wires and values differ in
     * length, or if the control and target wires overlap.
     */
    template <typename IndexT = std::size_t>
    inline void applyControlledSparseMatrix(
        const IndexT *row_map_ptr, const IndexT *col_idx_ptr,
        const ComplexT *values_ptr,
        const std::vector<std::size_t> &controlled_wires,
        const std::vector<bool> &controlled_values,
        const std::vector<std::size_t> &wires, bool inverse = false) {
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        auto *arr = BaseType::getData();
        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");
        PL_ABORT_IF_NOT(controlled_wires.size() == controlled_values.size(),
                        "`controlled_wires` must have the same size as "
                        "`controlled_values`.");
        PL_ABORT_IF_NOT(
            areVecsDisjoint<std::size_t>(controlled_wires, wires),
            "`controlled_wires` and `target wires` must be disjoint.");
        // A single kernel handles every target-wire count.
        const auto kernel = getKernelForControlledSparseMatrix(
            ControlledSparseMatrixOperation::NCSparseMultiQubitOp);
        dispatcher.applyControlledSparseMatrix(
            kernel, arr, BaseType::getNumQubits(), row_map_ptr, col_idx_ptr,
            values_ptr, controlled_wires, controlled_values, wires, inverse);
    }

    /**
     * @brief Vector overload of `applyControlledSparseMatrix`; checks that
     * `row_map` has 2^n + 1 entries for n target wires, like
     * `applySparseMatrix`.
     */
    template <typename Alloc, typename IndexT = std::size_t>
    inline void applyControlledSparseMatrix(
        const std::vector<IndexT> &row_map, const std::vector<IndexT> &col_idx,
        const std::vector<ComplexT, Alloc> &values,
        const std::vector<std::size_t> &controlled_wires,
        const std::vector<bool> &controlled_values,
        const std::vector<std::size_t> &wires, bool inverse = false) {
        PL_ABORT_IF(row_map.size() - 1 != exp2(wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");
        applyControlledSparseMatrix(row_map.data(), col_idx.data(),
                                    values.data(), controlled_wires,
                                    controlled_values, wires, inverse);
    }

    /**
     * @brief Apply a CSR sparse matrix on `wires` with an explicit kernel.
     * Aborts if `wires` is empty.
     */
    template <typename IndexT = std::size_t>
    inline void applySparseMatrix(Pennylane::Gates::KernelType kernel,
                                  const IndexT *row_map, const IndexT *col_idx,
                                  const ComplexT *values,
                                  const std::vector<std::size_t> &wires,
                                  bool inverse = false) {
        const auto &dispatcher = DynamicDispatcher<PrecisionT>::getInstance();
        auto *arr = BaseType::getData();

        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");

        dispatcher.applySparseMatrix(kernel, arr, BaseType::getNumQubits(),
                                     row_map, col_idx, values, wires, inverse);
    }

    /**
     * @brief Vector overload with an explicit kernel; checks that `row_map`
     * has 2^n + 1 entries for n wires.
     */
    template <typename Alloc, typename IndexT = std::size_t>
    inline void applySparseMatrix(Pennylane::Gates::KernelType kernel,
                                  const std::vector<IndexT> &row_map,
                                  const std::vector<IndexT> &col_idx,
                                  const std::vector<ComplexT, Alloc> &values,
                                  const std::vector<std::size_t> &wires,
                                  bool inverse = false) {
        PL_ABORT_IF(row_map.size() - 1 != exp2(wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");

        applySparseMatrix(kernel, row_map.data(), col_idx.data(), values.data(),
                          wires, inverse);
    }

    /**
     * @brief Apply a CSR sparse matrix on `wires` with the cached sparse
     * kernel.
     */
    template <typename IndexT = std::size_t>
    inline void applySparseMatrix(const IndexT *row_map, const IndexT *col_idx,
                                  const ComplexT *values,
                                  const std::vector<std::size_t> &wires,
                                  bool inverse = false) {
        using Pennylane::Gates::SparseMatrixOperation;

        PL_ABORT_IF(wires.empty(), "Number of wires must be larger than 0");

        // A single kernel handles every wire count.
        const auto kernel =
            getKernelForSparseMatrix(SparseMatrixOperation::SparseMultiQubitOp);
        applySparseMatrix(kernel, row_map, col_idx, values, wires, inverse);
    }

    /**
     * @brief Vector overload with the cached kernel; checks that `row_map`
     * has 2^n + 1 entries for n wires.
     */
    template <typename Alloc, typename IndexT = std::size_t>
    inline void applySparseMatrix(const std::vector<IndexT> &row_map,
                                  const std::vector<IndexT> &col_idx,
                                  const std::vector<ComplexT, Alloc> &values,
                                  const std::vector<std::size_t> &wires,
                                  bool inverse = false) {
        PL_ABORT_IF(row_map.size() - 1 != exp2(wires.size()),
                    "The size of matrix does not match with the given "
                    "number of wires");

        applySparseMatrix(row_map.data(), col_idx.data(), values.data(), wires,
                          inverse);
    }

    /**
     * @brief Project `wire` onto `branch` and renormalize.
     *
     * Zeroes every amplitude whose bit for `wire` differs from `branch`
     * (wire 0 is the most significant bit of the index), then calls
     * `normalize()`. Aborts if `wire` is not a valid qubit index.
     */
    void collapse(const std::size_t wire, const bool branch) {
        const std::size_t num_qubits = BaseType::getNumQubits();
        PL_ABORT_IF(wire >= num_qubits, "Invalid wire index");

        auto *arr = BaseType::getData();
        // Integer powers of two; floating-point pow is unnecessary here.
        const std::size_t stride = exp2(num_qubits - (1 + wire));
        const std::size_t vec_size = exp2(num_qubits);
        const std::size_t half_section_size = vec_size / stride / 2;

        // zero half the entries
        // the "half" entries depend on the stride
        // *_*_*_*_ for stride 1
        // **__**__ for stride 2
        // ****____ for stride 4
        const std::size_t k = branch ? 0 : 1;
        for (std::size_t idx = 0; idx < half_section_size; idx++) {
            const std::size_t offset = stride * (k + 2 * idx);
            for (std::size_t ids = 0; ids < stride; ids++) {
                arr[offset + ids] = {0., 0.};
            }
        }

        normalize();
    }

    /**
     * @brief Rescale the state to unit 2-norm.
     *
     * Aborts if the norm is below 100 * machine epsilon.
     */
    void normalize() {
        auto *arr = BaseType::getData();
        const auto length = BaseType::getLength();
        const PrecisionT norm = std::sqrt(squaredNorm(arr, length));

        // TODO: Waiting the decision from PL core about how to solve the issue
        // https://github.com/PennyLaneAI/pennylane/issues/6504
        PL_ABORT_IF(norm < std::numeric_limits<PrecisionT>::epsilon() * 1e2,
                    "Vector has norm close to zero and cannot be normalized");

        // Real scale factor: complex * real avoids a full complex product.
        const PrecisionT inv_norm = PrecisionT{1} / norm;
        for (std::size_t k = 0; k < length; k++) {
            arr[k] *= inv_norm;
        }
    }

    /**
     * @brief Set the state to the computational basis state `index`.
     * Aborts if `index` is not smaller than the state length.
     */
    void setBasisState(const std::size_t index) {
        const auto length = BaseType::getLength();
        // `index >= length` rather than `index > length - 1`, which would
        // underflow for an empty state.
        PL_ABORT_IF(index >= length, "Invalid index");

        auto *arr = BaseType::getData();
        std::fill(arr, arr + length, 0.0);
        arr[index] = {1.0, 0.0};
    }

    /**
     * @brief Set the basis state whose bit on `wires[k]` is `state[k]`;
     * wires not listed are 0. Wire 0 is the most significant index bit.
     *
     * Aborts if `state` and `wires` differ in length or a wire is out of
     * range (which would otherwise produce an out-of-range shift).
     */
    void setBasisState(const std::vector<std::size_t> &state,
                       const std::vector<std::size_t> &wires) {
        PL_ABORT_IF_NOT(state.size() == wires.size(),
                        "state and wires must have the same size.");
        const auto n_wires = wires.size();
        const auto num_qubits = BaseType::getNumQubits();
        std::size_t index{0U};
        for (std::size_t k = 0; k < n_wires; k++) {
            PL_ABORT_IF(wires[k] >= num_qubits, "Invalid wire index");
            index |= state[k] << (num_qubits - 1 - wires[k]);
        }
        setBasisState(index);
    }

    /// Reset to the |0...0> basis state (no-op on an empty state).
    void resetStateVector() {
        if (BaseType::getLength() > 0) {
            setBasisState(0U);
        }
    }

    /**
     * @brief Zero the state, then write `values[i]` at `indices[i]`.
     *
     * All indices are validated before any amplitude is modified; aborts on
     * a length mismatch or an index outside the state.
     */
    void setStateVector(const std::vector<std::size_t> &indices,
                        const std::vector<ComplexT> &values) {
        const auto num_indices = indices.size();
        PL_ABORT_IF(num_indices != values.size(),
                    "Indices and values length must match");

        const auto length = BaseType::getLength();
        // Check the target index itself (not the loop counter) so an invalid
        // index can never cause an out-of-bounds write.
        for (const auto target : indices) {
            PL_ABORT_IF(target >= length, "Invalid index");
        }

        auto *arr = BaseType::getData();
        std::fill(arr, arr + length, 0.0);
        for (std::size_t i = 0; i < num_indices; i++) {
            arr[indices[i]] = values[i];
        }
    }

    /**
     * @brief Vector overload of the sub-register `setStateVector`; checks
     * that `state` has 2^n entries for n wires.
     */
    void setStateVector(const std::vector<ComplexT> &state,
                        const std::vector<std::size_t> &wires) {
        PL_ABORT_IF_NOT(state.size() == exp2(wires.size()),
                        "Inconsistent state and wires dimensions.");
        setStateVector(state.data(), wires);
    }

    /**
     * @brief Write `state` (2^n amplitudes) into the sub-register `wires`,
     * for the entries where all other wires are 0 (via `applyNCN` with
     * every non-target wire as a 0-valued control).
     *
     * Aborts if a wire is out of range or repeated.
     */
    void setStateVector(const ComplexT *state,
                        const std::vector<std::size_t> &wires) {
        const std::size_t num_state = exp2(wires.size());
        const auto total_wire_count = BaseType::getNumQubits();

        // Descending order: erasing by position from the iota list below
        // must start at the back so earlier positions stay valid.
        std::vector<std::size_t> reversed_sorted_wires(wires);
        std::sort(reversed_sorted_wires.rbegin(), reversed_sorted_wires.rend());

        const bool has_duplicates =
            std::adjacent_find(reversed_sorted_wires.begin(),
                               reversed_sorted_wires.end()) !=
            reversed_sorted_wires.end();
        PL_ABORT_IF(has_duplicates, "Wires must be unique.");
        PL_ABORT_IF(!reversed_sorted_wires.empty() &&
                        reversed_sorted_wires.front() >= total_wire_count,
                    "Invalid wire index");

        std::vector<std::size_t> controlled_wires(total_wire_count);
        std::iota(std::begin(controlled_wires), std::end(controlled_wires), 0);
        for (auto wire : reversed_sorted_wires) {
            controlled_wires.erase(controlled_wires.begin() + wire);
        }

        const std::vector<bool> controlled_values(controlled_wires.size(),
                                                  false);
        auto core_function = [num_state,
                              state](ComplexT *arr,
                                     const std::vector<std::size_t> &indices,
                                     const std::size_t offset) {
            for (std::size_t i = 0; i < num_state; i++) {
                const std::size_t index = indices[i] + offset;
                arr[index] = state[i];
            }
        };
        GateImplementationsLM::applyNCN(BaseType::getData(), total_wire_count,
                                        controlled_wires, controlled_values,
                                        wires, core_function);
    }
};
} // namespace Pennylane::LightningQubit