// Program listing for file MPIManager.hpp
//
// Return to documentation for file (pennylane_lightning/core/utils/MPIManager.hpp)

// Copyright 2022-2023 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <bit>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <mpi.h>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>

#include "Error.hpp"

// Make the names of Pennylane::Util usable without qualification. Because this
// using-directive lives in a header, it takes effect in every translation unit
// that includes this file.
namespace {
using namespace Pennylane::Util;
} // namespace

namespace Pennylane::Util {
// LCOV_EXCL_START
inline void errhandler(int errcode, const char *str) {
    char msg[MPI_MAX_ERROR_STRING];
    int resultlen;
    MPI_Error_string(errcode, msg, &resultlen);
    fprintf(stderr, "%s: %s\n", str, msg);
    MPI_Abort(MPI_COMM_WORLD, 1);
}
// LCOV_EXCL_STOP

/**
 * @brief Evaluate an MPI call once and hand any failure to errhandler.
 *
 * The expression `fn` is evaluated exactly once. When its value differs from
 * MPI_SUCCESS, errhandler is called with that value and the stringified
 * expression.
 *
 * The local variable uses a deliberately unusual name so that an argument
 * mentioning a variable called `errcode` refers to the caller's variable and
 * not to the macro's own local.
 *
 * The macro expands to a plain brace block (not do/while), so call sites with
 * or without a trailing semicolon compile; avoid using it as the unbraced
 * body of an `if` that is followed by `else`.
 */
#define PL_MPI_IS_SUCCESS(fn)                                                  \
    {                                                                          \
        const int pl_mpi_errcode_ = (fn);                                      \
        if (pl_mpi_errcode_ != MPI_SUCCESS) {                                  \
            errhandler(pl_mpi_errcode_, #fn);                                  \
        }                                                                      \
    }

/**
 * @brief Return the implementation-defined name of type `T` as a string.
 *
 * The text is whatever `typeid(T).name()` yields for the compiler in use, so
 * it is only meaningful as a lookup key within one build.
 *
 * @tparam T Type whose name is requested.
 * @return The type name reported by the runtime type information.
 */
template <typename T> auto cppTypeToString() -> const std::string {
    return std::string(typeid(T).name());
}

class MPIManager {
  private:
    bool isExternalComm_;
    std::size_t rank_;
    std::size_t size_per_node_;
    std::size_t size_;
    MPI_Comm communicator_;

    std::string vendor_;
    std::size_t version_;
    std::size_t subversion_;

    std::unordered_map<std::string, MPI_Op> cpp_mpi_op_map = {
        {"op_null", MPI_OP_NULL}, {"max", MPI_MAX},
        {"min", MPI_MIN},         {"sum", MPI_SUM},
        {"prod", MPI_PROD},       {"land", MPI_LAND},
        {"band", MPI_BAND},       {"lor", MPI_LOR},
        {"bor", MPI_BOR},         {"lxor", MPI_LXOR},
        {"bxor", MPI_BXOR},       {"minloc", MPI_MINLOC},
        {"maxloc", MPI_MAXLOC},   {"replace", MPI_REPLACE},
    };

    std::unordered_map<std::string, MPI_Datatype> cpp_mpi_type_map = {
        {cppTypeToString<char>(), MPI_CHAR},
        {cppTypeToString<signed char>(), MPI_SIGNED_CHAR},
        {cppTypeToString<unsigned char>(), MPI_UNSIGNED_CHAR},
        {cppTypeToString<wchar_t>(), MPI_WCHAR},
        {cppTypeToString<short>(), MPI_SHORT},
        {cppTypeToString<unsigned short>(), MPI_UNSIGNED_SHORT},
        {cppTypeToString<int>(), MPI_INT},
        {cppTypeToString<unsigned int>(), MPI_UNSIGNED},
        {cppTypeToString<long>(), MPI_LONG},
        {cppTypeToString<unsigned long>(), MPI_UNSIGNED_LONG},
        {cppTypeToString<long long>(), MPI_LONG_LONG_INT},
        {cppTypeToString<float>(), MPI_FLOAT},
        {cppTypeToString<double>(), MPI_DOUBLE},
        {cppTypeToString<long double>(), MPI_LONG_DOUBLE},
        {cppTypeToString<int8_t>(), MPI_INT8_T},
        {cppTypeToString<int16_t>(), MPI_INT16_T},
        {cppTypeToString<int32_t>(), MPI_INT32_T},
        {cppTypeToString<int64_t>(), MPI_INT64_T},
        {cppTypeToString<uint8_t>(), MPI_UINT8_T},
        {cppTypeToString<uint16_t>(), MPI_UINT16_T},
        {cppTypeToString<uint32_t>(), MPI_UINT32_T},
        {cppTypeToString<uint64_t>(), MPI_UINT64_T},
        {cppTypeToString<bool>(), MPI_C_BOOL},
        {cppTypeToString<std::complex<float>>(), MPI_C_FLOAT_COMPLEX},
        {cppTypeToString<std::complex<double>>(), MPI_C_DOUBLE_COMPLEX},
        {cppTypeToString<std::complex<long double>>(),
         MPI_C_LONG_DOUBLE_COMPLEX}};

    void setVendor() {
        char version[MPI_MAX_LIBRARY_VERSION_STRING];
        int resultlen;

        PL_MPI_IS_SUCCESS(MPI_Get_library_version(version, &resultlen));

        std::string version_str = version;

        if (version_str.find("Open MPI") != std::string::npos) {
            vendor_ = "Open MPI";
        } else if (version_str.find("MPICH") != std::string::npos) {
            vendor_ = "MPICH";
        } else {
            PL_ABORT("Unsupported MPI implementation.\n");
        }
    }

    void setVersion() {
        int version_int, subversion_int;
        PL_MPI_IS_SUCCESS(MPI_Get_version(&version_int, &subversion_int));
        version_ = static_cast<std::size_t>(version_int);
        subversion_ = static_cast<std::size_t>(subversion_int);
    }

    void setNumProcsPerNode() {
        MPI_Comm node_comm;
        int size_per_node_int;
        PL_MPI_IS_SUCCESS(
            MPI_Comm_split_type(this->getComm(), MPI_COMM_TYPE_SHARED,
                                this->getRank(), MPI_INFO_NULL, &node_comm));
        PL_MPI_IS_SUCCESS(MPI_Comm_size(node_comm, &size_per_node_int));
        size_per_node_ = static_cast<std::size_t>(size_per_node_int);
        int compare;
        PL_MPI_IS_SUCCESS(
            MPI_Comm_compare(MPI_COMM_WORLD, node_comm, &compare));
        if (compare != MPI_IDENT)
            PL_MPI_IS_SUCCESS(MPI_Comm_free(&node_comm));
        this->Barrier();
    }

    void check_mpi_config() {
        // check if number of processes is power of two.
        PL_ABORT_IF(std::has_single_bit(
                        static_cast<unsigned int>(this->getSize())) != true,
                    "Processes number is not power of two.");
        PL_ABORT_IF(std::has_single_bit(
                        static_cast<unsigned int>(size_per_node_)) != true,
                    "Number of processes per node is not power of two.");
    }

  public:
    MPIManager(MPI_Comm communicator = MPI_COMM_WORLD)
        : communicator_(communicator) {
        int status = 0;
        MPI_Initialized(&status);
        if (!status) {
            PL_MPI_IS_SUCCESS(MPI_Init(nullptr, nullptr));
        }
        isExternalComm_ = true;
        int rank_int;
        int size_int;
        PL_MPI_IS_SUCCESS(MPI_Comm_rank(communicator_, &rank_int));
        PL_MPI_IS_SUCCESS(MPI_Comm_size(communicator_, &size_int));

        rank_ = static_cast<std::size_t>(rank_int);
        size_ = static_cast<std::size_t>(size_int);

        setVendor();
        setVersion();
        setNumProcsPerNode();
        check_mpi_config();
    }

    MPIManager(int argc, char **argv) {
        int status = 0;
        MPI_Initialized(&status);
        if (!status) {
            PL_MPI_IS_SUCCESS(MPI_Init(&argc, &argv));
        }
        isExternalComm_ = false;
        communicator_ = MPI_COMM_WORLD;
        int rank_int;
        int size_int;
        PL_MPI_IS_SUCCESS(MPI_Comm_rank(communicator_, &rank_int));
        PL_MPI_IS_SUCCESS(MPI_Comm_size(communicator_, &size_int));

        rank_ = static_cast<std::size_t>(rank_int);
        size_ = static_cast<std::size_t>(size_int);

        setVendor();
        setVersion();
        setNumProcsPerNode();
        check_mpi_config();
    }

    MPIManager(const MPIManager &other) {
        int status = 0;
        MPI_Initialized(&status);
        if (!status) {
            PL_MPI_IS_SUCCESS(MPI_Init(nullptr, nullptr));
        }
        isExternalComm_ = true;
        rank_ = other.rank_;
        size_ = other.size_;
        MPI_Comm_dup(
            other.communicator_,
            &communicator_); // Avoid freeing other.communicator_ in ~MPIManager
        vendor_ = other.vendor_;
        version_ = other.version_;
        subversion_ = other.subversion_;
        size_per_node_ = other.size_per_node_;
    }

    // LCOV_EXCL_START
    ~MPIManager() {
        if (!isExternalComm_) {
            int initflag;
            int finflag;
            PL_MPI_IS_SUCCESS(MPI_Initialized(&initflag));
            PL_MPI_IS_SUCCESS(MPI_Finalized(&finflag));
            if (initflag && !finflag) {
                PL_MPI_IS_SUCCESS(MPI_Finalize());
            }
        } else {
            int compare;
            PL_MPI_IS_SUCCESS(
                MPI_Comm_compare(MPI_COMM_WORLD, communicator_, &compare));
            if (compare != MPI_IDENT)
                PL_MPI_IS_SUCCESS(MPI_Comm_free(&communicator_));
        }
    }
    // LCOV_EXCL_STOP

    template <typename T> auto getMPIDatatype() const -> MPI_Datatype {
        const auto cpp_mpi_type_map = get_cpp_mpi_type_map();
        auto it = cpp_mpi_type_map.find(cppTypeToString<T>());
        if (it != cpp_mpi_type_map.end()) {
            return it->second;
        } else {
            PL_ABORT("Type not supported for MPIManager");
        }
    }

    // General MPI operations
    auto getRank() const -> std::size_t { return rank_; }

    auto getSize() const -> std::size_t { return size_; }

    auto getSizeNode() const -> std::size_t { return size_per_node_; }

    MPI_Comm getComm() const { return communicator_; }

    double getTime() const { return MPI_Wtime(); }

    auto getVendor() const -> const std::string & { return vendor_; }

    auto getVersion() const -> std::tuple<std::size_t, std::size_t> {
        return {version_, subversion_};
    }

    virtual auto get_cpp_mpi_type_map() const
        -> const std::unordered_map<std::string, MPI_Datatype> & {
        return cpp_mpi_type_map;
    }

    auto getMPIOpType(const std::string &op_str) const -> MPI_Op {
        auto it = cpp_mpi_op_map.find(op_str);
        if (it != cpp_mpi_op_map.end()) {
            return it->second;
        } else {
            PL_ABORT("Op not supported");
        }
    }

    template <typename T>
    void Allgather(T &sendBuf, std::vector<T> &recvBuf,
                   std::size_t sendCount = 1) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        PL_ABORT_IF(recvBuf.size() != this->getSize(),
                    "Incompatible size of sendBuf and recvBuf.");

        int sendCountInt = static_cast<int>(sendCount);
        PL_MPI_IS_SUCCESS(MPI_Allgather(&sendBuf, sendCountInt, datatype,
                                        recvBuf.data(), sendCountInt, datatype,
                                        this->getComm()));
    }

    template <typename T> auto allgather(T &sendBuf) -> std::vector<T> {
        MPI_Datatype datatype = getMPIDatatype<T>();
        std::vector<T> recvBuf(this->getSize());
        PL_MPI_IS_SUCCESS(MPI_Allgather(&sendBuf, 1, datatype, recvBuf.data(),
                                        1, datatype, this->getComm()));
        return recvBuf;
    }

    template <typename T>
    void Allgather(std::vector<T> &sendBuf, std::vector<T> &recvBuf) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        PL_ABORT_IF(recvBuf.size() != sendBuf.size() * this->getSize(),
                    "Incompatible size of sendBuf and recvBuf.");
        PL_MPI_IS_SUCCESS(MPI_Allgather(
            sendBuf.data(), sendBuf.size(), datatype, recvBuf.data(),
            sendBuf.size(), datatype, this->getComm()));
    }

    template <typename T>
    auto allgather(std::vector<T> &sendBuf) -> std::vector<T> {
        MPI_Datatype datatype = getMPIDatatype<T>();
        std::vector<T> recvBuf(sendBuf.size() * this->getSize());
        PL_MPI_IS_SUCCESS(MPI_Allgather(
            sendBuf.data(), sendBuf.size(), datatype, recvBuf.data(),
            sendBuf.size(), datatype, this->getComm()));
        return recvBuf;
    }

    template <typename T>
    void Allreduce(T &sendBuf, T &recvBuf, const std::string &op_str) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        PL_MPI_IS_SUCCESS(MPI_Allreduce(&sendBuf, &recvBuf, 1, datatype, op,
                                        this->getComm()));
    }

    template <typename T>
    auto allreduce(const T &sendBuf, const std::string &op_str) const -> T {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        T recvBuf;
        PL_MPI_IS_SUCCESS(MPI_Allreduce(&sendBuf, &recvBuf, 1, datatype, op,
                                        this->getComm()));
        return recvBuf;
    }

    template <typename T>
    void Allreduce(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                   const std::string &op_str) {
        PL_ABORT_IF(recvBuf.size() != sendBuf.size(),
                    "Incompatible size of sendBuf and recvBuf.");
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        PL_MPI_IS_SUCCESS(MPI_Allreduce(sendBuf.data(), recvBuf.data(),
                                        sendBuf.size(), datatype, op,
                                        this->getComm()));
    }

    template <typename T>
    auto allreduce(std::vector<T> &sendBuf, const std::string &op_str)
        -> std::vector<T> {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        std::vector<T> recvBuf(sendBuf.size());
        PL_MPI_IS_SUCCESS(MPI_Allreduce(sendBuf.data(), recvBuf.data(),
                                        sendBuf.size(), datatype, op,
                                        this->getComm()));
        return recvBuf;
    }

    template <typename T>
    void Reduce(T &sendBuf, T &recvBuf, std::size_t root,
                const std::string &op_str) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        PL_MPI_IS_SUCCESS(MPI_Reduce(&sendBuf, &recvBuf, 1, datatype, op, root,
                                     this->getComm()));
    }

    template <typename T>
    void Reduce(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                std::size_t root, const std::string &op_str) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        PL_MPI_IS_SUCCESS(MPI_Reduce(sendBuf.data(), recvBuf.data(),
                                     sendBuf.size(), datatype, op, root,
                                     this->getComm()));
    }

    template <typename T>
    void Reduce(T *sendBuf, T *recvBuf, std::size_t length, std::size_t root,
                const std::string &op_str) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);
        PL_MPI_IS_SUCCESS(MPI_Reduce(sendBuf, recvBuf, length, datatype, op,
                                     root, this->getComm()));
    }

    template <typename T>
    void Gather(T &sendBuf, std::vector<T> &recvBuf, std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        PL_MPI_IS_SUCCESS(MPI_Gather(&sendBuf, 1, datatype, recvBuf.data(), 1,
                                     datatype, root, this->getComm()));
    }

    template <typename T>
    void Gather(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        PL_MPI_IS_SUCCESS(MPI_Gather(sendBuf.data(), sendBuf.size(), datatype,
                                     recvBuf.data(), sendBuf.size(), datatype,
                                     root, this->getComm()));
    }

    void Barrier() { PL_MPI_IS_SUCCESS(MPI_Barrier(this->getComm())); }

    template <typename T> void Bcast(T &sendBuf, std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int rootInt = static_cast<int>(root);
        PL_MPI_IS_SUCCESS(
            MPI_Bcast(&sendBuf, 1, datatype, rootInt, this->getComm()));
    }

    template <typename T>
    void Bcast(std::vector<T> &sendBuf, std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int rootInt = static_cast<int>(root);
        PL_MPI_IS_SUCCESS(MPI_Bcast(sendBuf.data(), sendBuf.size(), datatype,
                                    rootInt, this->getComm()));
    }

    template <typename T>
    void Scatter(T *sendBuf, T *recvBuf, std::size_t dataSize,
                 std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int rootInt = static_cast<int>(root);
        PL_MPI_IS_SUCCESS(MPI_Scatter(sendBuf, dataSize, datatype, recvBuf,
                                      dataSize, datatype, rootInt,
                                      this->getComm()));
    }

    template <typename T>
    void Scatter(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                 std::size_t root) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        PL_ABORT_IF(sendBuf.size() != recvBuf.size() * this->getSize(),
                    "Incompatible size of sendBuf and recvBuf.");
        int rootInt = static_cast<int>(root);
        PL_MPI_IS_SUCCESS(MPI_Scatter(sendBuf.data(), recvBuf.size(), datatype,
                                      recvBuf.data(), recvBuf.size(), datatype,
                                      rootInt, this->getComm()));
    }

    template <typename T>
    auto scatter(std::vector<T> &sendBuf, std::size_t root) -> std::vector<T> {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int recvBufSize;
        if (this->getRank() == root) {
            recvBufSize = sendBuf.size() / this->getSize();
        }
        this->Bcast<int>(recvBufSize, root);
        std::vector<T> recvBuf(recvBufSize);
        int rootInt = static_cast<int>(root);
        PL_MPI_IS_SUCCESS(MPI_Scatter(sendBuf.data(), recvBuf.size(), datatype,
                                      recvBuf.data(), recvBuf.size(), datatype,
                                      rootInt, this->getComm()));
        return recvBuf;
    }

    template <typename T> void Send(std::vector<T> &sendBuf, std::size_t dest) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        const int tag = 6789;

        PL_MPI_IS_SUCCESS(MPI_Send(sendBuf.data(), sendBuf.size(), datatype,
                                   static_cast<int>(dest), tag,
                                   this->getComm()));
    }

    template <typename T>
    void Recv(std::vector<T> &recvBuf, std::size_t source) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Status status;
        const int tag = MPI_ANY_TAG;

        PL_MPI_IS_SUCCESS(MPI_Recv(recvBuf.data(), recvBuf.size(), datatype,
                                   static_cast<int>(source), tag,
                                   this->getComm(), &status));
    }

    template <typename T>
    void Sendrecv(T &sendBuf, std::size_t dest, T &recvBuf, std::size_t source,
                  std::size_t tag = 0) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Status status;
        int sendtag = static_cast<int>(tag);
        int recvtag = sendtag;
        int destInt = static_cast<int>(dest);
        int sourceInt = static_cast<int>(source);
        PL_MPI_IS_SUCCESS(MPI_Sendrecv(&sendBuf, 1, datatype, destInt, sendtag,
                                       &recvBuf, 1, datatype, sourceInt,
                                       recvtag, this->getComm(), &status));
    }

    template <typename T>
    void Sendrecv(std::vector<T> &sendBuf, std::size_t dest,
                  std::vector<T> &recvBuf, std::size_t source,
                  std::size_t tag = 0) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Status status;
        int sendtag = static_cast<int>(tag);
        int recvtag = sendtag;
        int destInt = static_cast<int>(dest);
        int sourceInt = static_cast<int>(source);
        PL_MPI_IS_SUCCESS(MPI_Sendrecv(sendBuf.data(), sendBuf.size(), datatype,
                                       destInt, sendtag, recvBuf.data(),
                                       recvBuf.size(), datatype, sourceInt,
                                       recvtag, this->getComm(), &status));
    }

    template <typename T>
    void GatherV(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                 std::size_t root, std::vector<int> &displacements) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int rootInt = static_cast<int>(root);

        std::vector<int> recvcount(getSize(), sendBuf.size());

        PL_MPI_IS_SUCCESS(MPI_Gatherv(sendBuf.data(), sendBuf.size(), datatype,
                                      recvBuf.data(), recvcount.data(),
                                      displacements.data(), datatype, rootInt,
                                      this->getComm()));
    }

    template <typename T>
    void GatherV(std::vector<T> &sendBuf, std::vector<T> &recvBuf,
                 std::vector<int> &recvCounts, std::size_t root,
                 std::vector<int> &displacements) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        int rootInt = static_cast<int>(root);

        PL_MPI_IS_SUCCESS(MPI_Gatherv(sendBuf.data(), sendBuf.size(), datatype,
                                      recvBuf.data(), recvCounts.data(),
                                      displacements.data(), datatype, rootInt,
                                      this->getComm()));
    }

    template <typename T>
    void Scan(T &sendBuf, T &recvBuf, const std::string &op_str) {
        MPI_Datatype datatype = getMPIDatatype<T>();
        MPI_Op op = getMPIOpType(op_str);

        PL_MPI_IS_SUCCESS(
            MPI_Scan(&sendBuf, &recvBuf, 1, datatype, op, this->getComm()));
    }

    auto split(std::size_t color, std::size_t key) -> MPIManager {
        MPI_Comm newcomm;
        int colorInt = static_cast<int>(color);
        int keyInt = static_cast<int>(key);
        PL_MPI_IS_SUCCESS(
            MPI_Comm_split(this->getComm(), colorInt, keyInt, &newcomm));
        return MPIManager(newcomm);
    }
};
} // namespace Pennylane::Util