diff --git a/examples/extended/user_owned_mpi.c b/examples/extended/user_owned_mpi.c new file mode 100644 index 000000000..4e3c766f4 --- /dev/null +++ b/examples/extended/user_owned_mpi.c @@ -0,0 +1,49 @@ +/** @file + * + * An example of using QuEST's experimental + * initCustomMpiQuESTEnv() function, to + * initialise QuEST in an environment where + * MPI is owned and controlled by the user. + * + * @author Oliver Brown + * @author Tyson Jones (doc) + */ + +#include "quest.h" +#include + + +// This example requires linking with MPI, which the CMake +// build only enables when QUEST_ENABLE_SUBCOMM is ON, which +// results in quest.h defining QUEST_COMPILE_SUBCOMM. To +// enable this example to always be compilable (like during +// our CI), we guard against when QUEST_ENABLE_SUBCOMM is OFF. +#if ! QUEST_COMPILE_SUBCOMM +int main(void) +{ + printf("Example skipped since MPI is not linked.\n"); + return 0; +} +#else + + +#include + +int main(void) +{ + const int USE_DISTRIB = 1; + const bool USER_MPI = 1; + const int USE_OPENMP = 1; + const int USE_GPU = 0; + + MPI_Init(NULL, NULL); + initCustomMpiQuESTEnv(USE_DISTRIB, USER_MPI, USE_GPU, USE_OPENMP); + reportQuESTEnv(); + finalizeQuESTEnv(); + MPI_Finalize(); + + return 0; +} + + +#endif // QUEST_COMPILE_SUBCOMM diff --git a/examples/extended/user_owned_mpi.cpp b/examples/extended/user_owned_mpi.cpp new file mode 100644 index 000000000..54345d576 --- /dev/null +++ b/examples/extended/user_owned_mpi.cpp @@ -0,0 +1,49 @@ +/** @file + * + * An example of using QuEST's experimental + * initCustomMpiQuESTEnv() function to + * initialise QuEST in an environment where + * MPI is owned and controlled by the user. + * + * @author Oliver Brown + * @author Tyson Jones (doc) + */ + +#include "quest.h" +#include + + +// This example requires linking with MPI, which the CMake +// build only enables when QUEST_ENABLE_SUBCOMM is ON, which +// results in quest.h defining QUEST_COMPILE_SUBCOMM. 
To +// enable this example to always be compilable (like during +// our CI), we guard against when QUEST_ENABLE_SUBCOMM is OFF. +#if ! QUEST_COMPILE_SUBCOMM +int main(void) +{ + std::printf("Example skipped since MPI is not linked.\n"); + return 0; +} +#else + + +#include + +int main(void) +{ + const int USE_DISTRIB = 1; + const bool USER_MPI = 1; + const int USE_OPENMP = 1; + const int USE_GPU = 0; + + MPI_Init(NULL, NULL); + initCustomMpiQuESTEnv(USE_DISTRIB, USER_MPI, USE_GPU, USE_OPENMP); + reportQuESTEnv(); + finalizeQuESTEnv(); + MPI_Finalize(); + + return 0; +} + + +#endif // QUEST_COMPILE_SUBCOMM diff --git a/examples/extended/user_owned_submpi.c b/examples/extended/user_owned_submpi.c new file mode 100644 index 000000000..6f2ea6290 --- /dev/null +++ b/examples/extended/user_owned_submpi.c @@ -0,0 +1,84 @@ +/** @file + * + * An example of using QuEST's experimental + * initCustomMpiCommQuESTEnv() function to + * dedicate only some user-owned MPI processes + * to QuEST, and dedicate the remainder to + * other tasks. + * + * @author Oliver Brown + * @author Tyson Jones (doc) + */ + +#include "quest.h" +#include + + +// This example requires linking with MPI, which the CMake +// build only enables when QUEST_ENABLE_SUBCOMM is ON, which +// results in quest.h defining QUEST_COMPILE_SUBCOMM. To +// enable this example to always be compilable (like during +// our CI), we guard against when QUEST_ENABLE_SUBCOMM is OFF. +#if ! 
QUEST_COMPILE_SUBCOMM +int main() +{ + printf("Example skipped since MPI is not linked.\n"); + return 0; +} +#else + + +#include + +int main (void) +{ + int nprocs, quest_nprocs, world_rank, quest_rank; + MPI_Comm comm_split, comm_quantum, comm_classical; + + MPI_Init(NULL, NULL); + + MPI_Comm_size(MPI_COMM_WORLD, &nprocs); + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + const int I_AM_QUANTUM = world_rank % 2; + + printf("[%d] Hello from rank %d of %d in MPI_COMM_WORLD.\n", world_rank, world_rank, nprocs); + + MPI_Comm_split(MPI_COMM_WORLD, I_AM_QUANTUM, world_rank, &comm_split); + + if (I_AM_QUANTUM) { + MPI_Comm_dup(comm_split, &comm_quantum); + MPI_Comm_size(comm_quantum, &quest_nprocs); + MPI_Comm_rank(comm_quantum, &quest_rank); + printf("[%d] Hello from rank %d of %d in comm_quantum.\n", world_rank, quest_rank, quest_nprocs); + } else { + MPI_Comm_dup(comm_split, &comm_classical); + quest_rank = -1; + quest_nprocs = -1; + } + + // only procs in quantum comm initialise QuEST + if (I_AM_QUANTUM) { + printf("[%d] Initialising QuEST.\n", world_rank); + initCustomMpiCommQuESTEnv(comm_quantum, -1, -1); // -1 = auto-deployments + + reportQuESTEnv(); + + printf("[%d] Finalising QuEST.\n", world_rank); + finalizeQuESTEnv(); + } + + MPI_Comm_free(&comm_split); + if (I_AM_QUANTUM) { + MPI_Comm_free(&comm_quantum); + } else { + MPI_Comm_free(&comm_classical); + } + + MPI_Finalize(); + + return 0; +} + + +#endif // QUEST_COMPILE_SUBCOMM diff --git a/examples/extended/user_owned_submpi.cpp b/examples/extended/user_owned_submpi.cpp new file mode 100644 index 000000000..ea82a4f9d --- /dev/null +++ b/examples/extended/user_owned_submpi.cpp @@ -0,0 +1,84 @@ +/** @file + * + * An example of using QuEST's experimental + * initCustomMpiCommQuESTEnv() function to + * dedicate only some user-owned MPI processes + * to QuEST, and dedicate the remainder to + * other tasks. 
+ * + * @author Oliver Brown + * @author Tyson Jones (doc) + */ + +#include "quest.h" +#include + + +// This example requires linking with MPI, which the CMake +// build only enables when QUEST_ENABLE_SUBCOMM is ON, which +// results in quest.h defining QUEST_COMPILE_SUBCOMM. To +// enable this example to always be compilable (like during +// our CI), we guard against when QUEST_ENABLE_SUBCOMM is OFF. +#if ! QUEST_COMPILE_SUBCOMM +int main() +{ + std::printf("Example skipped since MPI is not linked.\n"); + return 0; +} +#else + + +#include + +int main (void) +{ + int nprocs, quest_nprocs, world_rank, quest_rank; + MPI_Comm comm_split, comm_quantum, comm_classical; + + MPI_Init(NULL, NULL); + + MPI_Comm_size(MPI_COMM_WORLD, &nprocs); + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + const int I_AM_QUANTUM = world_rank % 2; + + std::printf("[%d] Hello from rank %d of %d in MPI_COMM_WORLD.\n", world_rank, world_rank, nprocs); + + MPI_Comm_split(MPI_COMM_WORLD, I_AM_QUANTUM, world_rank, &comm_split); + + if (I_AM_QUANTUM) { + MPI_Comm_dup(comm_split, &comm_quantum); + MPI_Comm_size(comm_quantum, &quest_nprocs); + MPI_Comm_rank(comm_quantum, &quest_rank); + std::printf("[%d] Hello from rank %d of %d in comm_quantum.\n", world_rank, quest_rank, quest_nprocs); + } else { + MPI_Comm_dup(comm_split, &comm_classical); + quest_rank = -1; + quest_nprocs = -1; + } + + // only procs in quantum comm initialise QuEST + if (I_AM_QUANTUM) { + std::printf("[%d] Initialising QuEST.\n", world_rank); + initCustomMpiCommQuESTEnv(comm_quantum, modeflag::USE_AUTO, modeflag::USE_AUTO); + + reportQuESTEnv(); + + std::printf("[%d] Finalising QuEST.\n", world_rank); + finalizeQuESTEnv(); + } + + MPI_Comm_free(&comm_split); + if (I_AM_QUANTUM) { + MPI_Comm_free(&comm_quantum); + } else { + MPI_Comm_free(&comm_classical); + } + + MPI_Finalize(); + + return 0; +} + + +#endif // QUEST_COMPILE_SUBCOMM diff --git a/quest/include/environment.h b/quest/include/environment.h index 608829912..a584192d7 
100644 --- a/quest/include/environment.h +++ b/quest/include/environment.h @@ -35,16 +35,16 @@ extern "C" { typedef struct { // deployment modes which can be runtime disabled - int isMultithreaded; - int isGpuAccelerated; - int isDistributed; - bool userOwnsMpi; + bool isMultithreaded; + bool isGpuAccelerated; + bool isDistributed; + bool isMpiUserOwned; // deployment modes which cannot be directly changed after compilation - int isCuQuantumEnabled; + bool isCuQuantumEnabled; // deployment configurations which can be changed via environment variables - int isGpuSharingEnabled; + bool isGpuSharingEnabled; // distributed configuration int rank; diff --git a/quest/src/api/environment.cpp b/quest/src/api/environment.cpp index 1cc2f6862..ab18ced91 100644 --- a/quest/src/api/environment.cpp +++ b/quest/src/api/environment.cpp @@ -48,7 +48,7 @@ using std::string; */ -static QuESTEnv* globalEnvPtr = nullptr; +static QuESTEnv* global_envPtr = nullptr; @@ -62,7 +62,7 @@ static QuESTEnv* globalEnvPtr = nullptr; */ -static bool hasEnvBeenFinalized = false; +static bool global_hasEnvBeenFinalized = false; @@ -74,9 +74,12 @@ static bool hasEnvBeenFinalized = false; void validateAndInitCustomQuESTEnv(int useDistrib, bool userOwnsMpi, int useGpuAccel, int useMultithread, const char* caller) { // ensure that we are never re-initialising QuEST (even after finalize) because - // this leads to undefined behaviour in distributed mode, as per the MPI - validate_envNeverInit(globalEnvPtr != nullptr, hasEnvBeenFinalized, caller); + // this leads to undefined behaviour in distributed mode, as per the MPI std, + // regardless of whether the user owns MPI + validate_envNeverInit(global_envPtr != nullptr, global_hasEnvBeenFinalized, caller); + // load env-vars before validating deployment mode, because some env vars can + // affect validation (such as QUEST_PERMIT_NODES_TO_SHARE_GPU) envvars_validateAndLoadEnvVars(caller); validateconfig_setEpsilonToDefault(); @@ -86,14 +89,19 @@ void 
validateAndInitCustomQuESTEnv(int useDistrib, bool userOwnsMpi, int useGpuA // by mpirun believe they are each the main rank. This seems unavoidable. validate_newEnvDeploymentMode(useDistrib, useGpuAccel, useMultithread, caller); - // overwrite deployments left as modeflag::USE_AUTO + // overwrite deployments (left as modeflag::USE_AUTO=-1) with 0,1 (a bool), + // which crucially, resolves useDistrib, permitting its consultation below autodep_chooseQuESTEnvDeployment(useDistrib, useGpuAccel, useMultithread); + // ensure that current state of MPI is valid + validate_mpiInitStatus(useDistrib, userOwnsMpi, caller); + // optionally initialise MPI; necessary before completing validation, // and before any GPU initialisation and validation, since we will // perform that specifically upon the MPI-process-bound GPU(s). Further, // we can make sure validation errors are reported only by the root node. - comm_init(useDistrib, userOwnsMpi); + if (useDistrib) + comm_init(userOwnsMpi); validate_newEnvDistributedBetweenPower2Nodes(caller); @@ -134,30 +142,25 @@ void validateAndInitCustomQuESTEnv(int useDistrib, bool userOwnsMpi, int useGpuA rand_setSeedsToDefault(); // allocate space for the global QuESTEnv singleton (overwriting nullptr, unless malloc fails) - globalEnvPtr = (QuESTEnv*) malloc(sizeof(QuESTEnv)); + global_envPtr = (QuESTEnv*) malloc(sizeof(QuESTEnv)); // pedantically check that teeny tiny malloc just succeeded - if (globalEnvPtr == nullptr) + if (global_envPtr == nullptr) error_allocOfQuESTEnvFailed(); - // bind deployment info to global instance - globalEnvPtr->isMultithreaded = useMultithread; - globalEnvPtr->isGpuAccelerated = useGpuAccel; - globalEnvPtr->isDistributed = useDistrib; - globalEnvPtr->userOwnsMpi = userOwnsMpi; - globalEnvPtr->isCuQuantumEnabled = useCuQuantum; - globalEnvPtr->isGpuSharingEnabled = permitGpuSharing; + // bind deployment info to global instance (autocasting int to bool) + global_envPtr->isMultithreaded = useMultithread; + 
global_envPtr->isGpuAccelerated = useGpuAccel; + global_envPtr->isDistributed = useDistrib; + global_envPtr->isMpiUserOwned = userOwnsMpi; + global_envPtr->isCuQuantumEnabled = useCuQuantum; + global_envPtr->isGpuSharingEnabled = permitGpuSharing; // bind distributed info - globalEnvPtr->rank = (useDistrib)? comm_getRank() : 0; - globalEnvPtr->numNodes = (useDistrib)? comm_getNumNodes() : 1; + global_envPtr->rank = (useDistrib)? comm_getRank() : 0; + global_envPtr->numNodes = (useDistrib)? comm_getNumNodes() : 1; } -void updateQuESTEnvDistInfo() { - globalEnvPtr->rank = (globalEnvPtr->isDistributed)? comm_getRank() : 0; - globalEnvPtr->numNodes = (globalEnvPtr->isDistributed)? comm_getNumNodes() : 1; - return; -} /* @@ -193,7 +196,7 @@ void printCompilationInfo() { print_table( "compilation", { {"isMpiCompiled", comm_isMpiCompiled()}, - {"isMpiSubCommunicatorCompiled", comm_isMpiSubCommunicatorCompiled()}, + {"isMpiSubCommCompiled", comm_isMpiSubCommCompiled()}, {"isGpuCompiled", gpu_isGpuCompiled()}, {"isOmpCompiled", cpu_isOpenmpCompiled()}, {"isCuQuantumCompiled", gpu_isCuQuantumCompiled()}, @@ -205,12 +208,12 @@ void printDeploymentInfo() { print_table( "deployment", { - {"isMpiEnabled", globalEnvPtr->isDistributed}, - {"doesUserOwnMpi", globalEnvPtr->userOwnsMpi}, - {"isGpuEnabled", globalEnvPtr->isGpuAccelerated}, - {"isOmpEnabled", globalEnvPtr->isMultithreaded}, - {"isCuQuantumEnabled", globalEnvPtr->isCuQuantumEnabled}, - {"isGpuSharingEnabled", globalEnvPtr->isGpuSharingEnabled}, + {"isMpiEnabled", global_envPtr->isDistributed}, + {"isMpiUserOwned", global_envPtr->isMpiUserOwned}, + {"isGpuEnabled", global_envPtr->isGpuAccelerated}, + {"isOmpEnabled", global_envPtr->isMultithreaded}, + {"isCuQuantumEnabled", global_envPtr->isCuQuantumEnabled}, + {"isGpuSharingEnabled", global_envPtr->isGpuSharingEnabled}, }); } @@ -270,7 +273,7 @@ void printDistributionInfo() { print_table( "distribution", { {"isMpiGpuAware", (comm_isMpiCompiled())? 
printer_toStr(comm_isMpiGpuAware()) : na}, - {"numMpiNodes", printer_toStr(globalEnvPtr->numNodes)}, + {"numMpiNodes", printer_toStr(global_envPtr->numNodes)}, }); } @@ -280,7 +283,7 @@ void printQuregSizeLimits(bool isDensMatr) { using namespace printer_substrings; // for brevity - int numNodes = globalEnvPtr->numNodes; + int numNodes = global_envPtr->numNodes; // by default, CPU limits are unknown (because memory query might fail) string maxQbForCpu = un; @@ -292,7 +295,7 @@ void printQuregSizeLimits(bool isDensMatr) { maxQbForCpu = printer_toStr(mem_getMaxNumQuregQubitsWhichCanFitInMemory(isDensMatr, 1, cpuMem)); // and the max MPI sizes are only relevant when env is distributed - if (globalEnvPtr->isDistributed) + if (global_envPtr->isDistributed) maxQbForMpiCpu = printer_toStr(mem_getMaxNumQuregQubitsWhichCanFitInMemory(isDensMatr, numNodes, cpuMem)); // when MPI irrelevant, change their status from "unknown" to "N/A" @@ -307,12 +310,12 @@ void printQuregSizeLimits(bool isDensMatr) { string maxQbForMpiGpu = na; // max GPU registers only relevant if env is GPU-accelerated - if (globalEnvPtr->isGpuAccelerated) { + if (global_envPtr->isGpuAccelerated) { qindex gpuMem = gpu_getCurrentAvailableMemoryInBytes(); maxQbForGpu = printer_toStr(mem_getMaxNumQuregQubitsWhichCanFitInMemory(isDensMatr, 1, gpuMem)); // and the max MPI sizes are further only relevant when env is distributed - if (globalEnvPtr->isDistributed) + if (global_envPtr->isDistributed) maxQbForMpiGpu = printer_toStr(mem_getMaxNumQuregQubitsWhichCanFitInMemory(isDensMatr, numNodes, gpuMem)); } @@ -349,7 +352,7 @@ void printQuregAutoDeployments(bool isDensMatr) { // test to theoretically max #qubits, surpassing max that can fit in RAM and GPUs, because // auto-deploy will still try to deploy there to (then subsequent validation will fail) - int maxQubits = mem_getMaxNumQuregQubitsBeforeGlobalMemSizeofOverflow(isDensMatr, globalEnvPtr->numNodes); + int maxQubits = 
mem_getMaxNumQuregQubitsBeforeGlobalMemSizeofOverflow(isDensMatr, global_envPtr->numNodes); for (int numQubits=1; numQubitsisGpuAccelerated) + if (global_envPtr->isGpuAccelerated) gpu_clearCache(); // syncs first - if (globalEnvPtr->isGpuAccelerated && gpu_isCuQuantumCompiled()) + if (global_envPtr->isGpuAccelerated && gpu_isCuQuantumCompiled()) gpu_finalizeCuQuantum(); - if (globalEnvPtr->isDistributed) { + if (global_envPtr->isDistributed) { comm_sync(); - comm_end(globalEnvPtr->userOwnsMpi); + comm_end(global_envPtr->isMpiUserOwned); } // free global env's heap memory and flag it as unallocated - free(globalEnvPtr); - globalEnvPtr = nullptr; + free(global_envPtr); + global_envPtr = nullptr; // flag that the environment was finalised, to ensure it is never re-initialised - hasEnvBeenFinalized = true; + global_hasEnvBeenFinalized = true; } void syncQuESTEnv() { validate_envIsInit(__func__); - if (globalEnvPtr->isGpuAccelerated) + if (global_envPtr->isGpuAccelerated) gpu_sync(); - if (globalEnvPtr->isDistributed) { + if (global_envPtr->isDistributed) comm_sync(); - #if QUEST_COMPILE_SUBCOMM - updateQuESTEnvDistInfo(); - #endif - } } @@ -509,19 +508,17 @@ void reportQuESTEnv() { void getQuESTEnvironmentString(char str[200]) { validate_envIsInit(__func__); - QuESTEnv env = getQuESTEnv(); - int numThreads = cpu_isOpenmpCompiled()? 
cpu_getAvailableNumThreads() : 1; - int cuQuantum = env.isGpuAccelerated && gpu_isCuQuantumCompiled(); - int gpuDirect = env.isGpuAccelerated && gpu_isDirectGpuCommPossible(); + int cuQuantum = global_envPtr->isGpuAccelerated && gpu_isCuQuantumCompiled(); + int gpuDirect = global_envPtr->isGpuAccelerated && gpu_isDirectGpuCommPossible(); snprintf(str, 200, "CUDA=%d OpenMP=%d MPI=%d userOwnsMPI=%d threads=%d ranks=%d cuQuantum=%d gpuDirect=%d", - env.isGpuAccelerated, - env.isMultithreaded, - env.isDistributed, - env.userOwnsMpi, + global_envPtr->isGpuAccelerated, + global_envPtr->isMultithreaded, + global_envPtr->isDistributed, + global_envPtr->isMpiUserOwned, numThreads, - env.numNodes, + global_envPtr->numNodes, cuQuantum, gpuDirect); } diff --git a/quest/src/api/subcommunicator.cpp b/quest/src/api/subcommunicator.cpp index e248f0dba..74c05293a 100644 --- a/quest/src/api/subcommunicator.cpp +++ b/quest/src/api/subcommunicator.cpp @@ -2,30 +2,48 @@ #include "quest/include/environment.h" #include "quest/include/subcommunicator.h" +#include "quest/src/core/validation.hpp" #include "quest/src/comm/comm_config.hpp" -#include "quest/src/core/errors.hpp" #if QUEST_COMPILE_MPI && QUEST_COMPILE_SUBCOMM -#include #include + +// TODO: +// We must resolve this communicator function which contains an MPI type +// and ergo should not be leaked outside comm_config.cpp. For now, we cheat! +extern bool comm_setMpiComm(MPI_Comm newComm); + + +// TODO: +// We must resolve this inner function of QuEST initialisation, but which is +// private to api/environment.cpp, and so cannot be exposed in the user-facing +// include/environment.hpp. Grr! 
For now, we here just cheekily extern it c: +extern void validateAndInitCustomQuESTEnv( + int useDistrib, bool userOwnsMpi, int useGpuAccel, int useMultithread, const char* caller); + + void initCustomMpiCommQuESTEnv(MPI_Comm userQuestComm, int useGpuAccel, int useMultithread) { + // useDistrib and userOwnsMpi are implied by the use of this initialiser const int useDistrib = 1; const bool userOwnsMpi = true; - // set mpiCommQuest to user provided communicator - if (comm_isInit()) { - comm_setMpiComm(userQuestComm); - } else { - error_commNotInit(); - } + // pre-validate that we are able to set the MPI communicator + validate_mpiInitStatus(useDistrib, userOwnsMpi, __func__); + validate_mpiSubCommIsNonNull(userQuestComm != MPI_COMM_NULL, __func__); - // initialise QuEST around that communicator - initCustomMpiQuESTEnv(useDistrib, userOwnsMpi, useGpuAccel, useMultithread); + // avoid re-setting the MPI comm (to avoid an internal error), which happens + // if a user illegally re-calls this function, which will be subsequently + // caught by the validation in validateAndInitCustomQuESTEnv() below + if (!comm_isMpiCommSet()) { + bool success = comm_setMpiComm(userQuestComm); + validate_mpiSubCommSetSucceeded(success, __func__); + } - return; + // perform remaining validation (some is harmlessly repeated) and init QuEST env + validateAndInitCustomQuESTEnv(useDistrib, userOwnsMpi, useGpuAccel, useMultithread, __func__); } #endif diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 9da8f34e1..c69e72919 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -21,7 +21,7 @@ #if QUEST_COMPILE_MPI #include - static MPI_Comm mpiCommQuest = MPI_COMM_NULL; + static MPI_Comm global_mpiComm = MPI_COMM_NULL; #endif @@ -30,6 +30,7 @@ * WARN ABOUT CUDA-AWARENESS */ + #if QUEST_COMPILE_MPI && QUEST_COMPILE_CUDA // this check is OpenMPI specific @@ -54,6 +55,7 @@ /* * MPI ENVIRONMENT MANAGEMENT + * * all of which is safely 
callable in non-distributed mode */ @@ -62,7 +64,7 @@ bool comm_isMpiCompiled() { return (bool) QUEST_COMPILE_MPI; } -bool comm_isMpiSubCommunicatorCompiled() { +bool comm_isMpiSubCommCompiled() { return (bool) QUEST_COMPILE_SUBCOMM; } @@ -103,59 +105,27 @@ bool comm_isInit() { } -void comm_init(int useDistrib, bool userOwnsMpi) { +void comm_init(bool userOwnsMpi) { #if QUEST_COMPILE_MPI - // error if user owns MPI but has not initialised - if (userOwnsMpi && !comm_isInit()) { + // re-assert prior user-validations for robustness + if (userOwnsMpi && !comm_isInit()) error_commNotInit(); - } + if (!userOwnsMpi && comm_isInit()) + error_commAlreadyInit(); - // Overall mpiCommQuest should be set in the following ways - // however only useDistrib = 1 and userOwnsMpi = false - // and useDistrib = 0 and userOwnsMpi = true - // require action here - // - // | useDistrib | userOwnsMpi | mpiCommQuest | - // | ---------- | ----------- | -------------- | - // | 0 | false | MPI_COMM_NULL | - // | ---------- | ----------- | -------------- | - // | 1 | false | MPI_COMM_WORLD | - // | ---------- | ----------- | -------------- | - // | 0 | true | MPI_COMM_SELF | - // | ---------- | ----------- | -------------- | - // | | | MPI_COMM_WORLD | - // | 1 | true | or | - // | | | userQuestComm | - // | ---------- | ----------- | -------------- | - + // init MPI only when it's not the user's responsibility + if (!userOwnsMpi) + MPI_Init(NULL, NULL); - if (useDistrib && !userOwnsMpi) { - // error if attempting re-initialisation - if (comm_isInit()) { - error_commAlreadyInit(); - } else { - MPI_Init(NULL, NULL); - // The user wants MPI and is leaving it to QuEST - MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); - } - } else if (!useDistrib && userOwnsMpi) { - // The user has initialised MPI but wants QuEST to ignore it - MPI_Comm_dup(MPI_COMM_SELF, &mpiCommQuest); - } else if (useDistrib && userOwnsMpi) { - // if mpiCommQuEST is still MPI_COMM_NULL the user is not - // providing their own 
MPI_Comm and we should set mpiCommQuest - // to MPI_COMM_WORLD - if (mpiCommQuest == MPI_COMM_NULL) - MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); - } + // choose communicator only when the user hasn't + if (global_mpiComm == MPI_COMM_NULL) + MPI_Comm_dup(MPI_COMM_WORLD, &global_mpiComm); #endif - return; } - void comm_end(bool userOwnsMpi) { #if QUEST_COMPILE_MPI @@ -163,8 +133,15 @@ void comm_end(bool userOwnsMpi) { if (!comm_isInit()) return; - MPI_Barrier(mpiCommQuest); - MPI_Comm_free(&mpiCommQuest); + // gracefully handle when the communicator is still NULL, because comm_end() may be + // triggered by "bad MPI init" validation, during which, the communicator may not yet + // have been set. We choose NOT to divert to MPI_COMM_WORLD, which is likely just to + // stall at MPI_Barrier, and instead let the user's communicator live on; then crash! + if (global_mpiComm == MPI_COMM_NULL) + return; + + MPI_Barrier(global_mpiComm); + MPI_Comm_free(&global_mpiComm); // QuEST must finalise MPI if the user does not own it if (!userOwnsMpi) @@ -183,8 +160,18 @@ int comm_getRank() { if (!comm_isInit()) return ROOT_RANK; + // Consult the (potentially sub-) communicator for rank; if it is still + // NULL, as can only validly happen during failed QuESTEnv init validation + // (which triggers root-only error printing and ergo this function), we + // fall back to every process believing it is root and so attempting to + // print. This safely avoids consulting a potentially bugged MPI communicator + // and losing the message. We once tried to fallback to MPI_COMM_WORLD here, + // to avoid duplicate output, but it is not worth the risk of msg loss! 
+ if (global_mpiComm == MPI_COMM_NULL) + return ROOT_RANK; + int rank; - MPI_Comm_rank(mpiCommQuest, &rank); + MPI_Comm_rank(global_mpiComm, &rank); return rank; #else @@ -213,7 +200,7 @@ int comm_getNumNodes() { return 1; int numNodes; - MPI_Comm_size(mpiCommQuest, &numNodes); + MPI_Comm_size(global_mpiComm, &numNodes); return numNodes; #else @@ -231,31 +218,58 @@ void comm_sync() { if (!comm_isInit()) return; - MPI_Barrier(mpiCommQuest); + // gracefully handle when the communicator is still NULL, because comm_sync() is + // triggered by "bad MPI init" validation (during the error message printing) + // during which, the communicator may not yet have been overridden + if (global_mpiComm == MPI_COMM_NULL) + return; + + MPI_Barrier(global_mpiComm); #endif } + + +/* + * MPI COMMUNICATOR MANAGEMENT + * + * some of which requires exposing MPI_Comm in external-facing signatures. + * In lieu of leaking these into comm_config.hpp, callers must extern them. + */ + +bool comm_isMpiCommSet() { #if QUEST_COMPILE_MPI - MPI_Comm comm_getMpiComm() { - return mpiCommQuest; - } - - #if QUEST_COMPILE_SUBCOMM - void comm_setMpiComm(MPI_Comm newComm) { - - // error if mpiCommQuEST is already set! 
- if (mpiCommQuest != MPI_COMM_NULL) { - MPI_Barrier(mpiCommQuest); - MPI_Comm_free(&mpiCommQuest); - error_commDoubleSetMpiComm(); - } - - int mpi_err = MPI_Comm_dup(newComm, &mpiCommQuest); - if (mpi_err != MPI_SUCCESS) { - error_commInvalidMpiComm(); - } - - return; - } - #endif + + // once comm_init() or comm_setMpiComm() overwrite + // the communicator, it can never return to NULL + return (global_mpiComm != MPI_COMM_NULL); +# else + return false; #endif +} + +#if QUEST_COMPILE_MPI + +MPI_Comm comm_getMpiComm() { + + if (global_mpiComm == MPI_COMM_NULL) + error_commMpiCommIsNull(); + + return global_mpiComm; +} + +bool comm_setMpiComm(MPI_Comm newComm) { + + // this is called prior to QuEST initialisation, + // and merely seeks to overwrite global_mpiComm + + if (global_mpiComm != MPI_COMM_NULL) + error_commAlreadyHasSetMpiComm(); + if (newComm == MPI_COMM_NULL) + error_commMpiCommIsNull(); + + auto status = MPI_Comm_dup(newComm, &global_mpiComm); + return status == MPI_SUCCESS; +} + +#endif // QUEST_COMPILE_MPI diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index b2d038cd5..8441dbc23 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -10,19 +10,13 @@ #ifndef COMM_CONFIG_HPP #define COMM_CONFIG_HPP -#include "quest/include/config.h" - -#if QUEST_COMPILE_MPI - #include -#endif - constexpr int ROOT_RANK = 0; bool comm_isMpiCompiled(); -bool comm_isMpiSubCommunicatorCompiled(); +bool comm_isMpiSubCommCompiled(); bool comm_isMpiGpuAware(); -void comm_init(int useDistrib, bool userOwnsMpi); +void comm_init(bool userOwnsMpi); void comm_end(bool userOwnsMpi); void comm_sync(); @@ -33,11 +27,10 @@ bool comm_isInit(); bool comm_isRootNode(); bool comm_isRootNode(int rank); -#if QUEST_COMPILE_MPI - MPI_Comm comm_getMpiComm(); - #if QUEST_COMPILE_SUBCOMM - void comm_setMpiComm(MPI_Comm newComm); - #endif -#endif +bool comm_isMpiCommSet(); + +// Signatures containing MPI types which callers must extern: +// 
extern MPI_Comm comm_getMpiComm() +// extern bool comm_setMpiComm(MPI_Comm newComm) #endif // COMM_CONFIG_HPP diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index 0bc90563b..cf6956454 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -6,7 +6,7 @@ * * @author Tyson Jones * @author Jakub Adamski (sped-up large comm by asynch messages) - * @author Oliver Brown (patched max-message inference, consulted on AR and MPICH support) + * @author Oliver Brown (added custom communicators, patched max-message inference, consulted on AR and MPICH support) * @author Ania (Anna) Brown (developed QuEST v1 logic) */ @@ -24,6 +24,7 @@ #if QUEST_COMPILE_MPI #include + extern MPI_Comm comm_getMpiComm(); // comm_config.cpp does not leak MPI_Comm #endif #include @@ -149,8 +150,7 @@ int getMaxNumMessages() { // messages. Beware the max is obtained via a void pointer and might be unset... void* tagUpperBoundPtr; int isAttribSet; - MPI_Comm mpiCommQuest = comm_getMpiComm(); - MPI_Comm_get_attr(mpiCommQuest, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); + MPI_Comm_get_attr(comm_getMpiComm(), MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); // if something went wrong with obtaining the tag bound, return the safe minimum if (!isAttribSet) @@ -217,7 +217,7 @@ std::array dividePayloadIntoMessages(qindex numAmps) { void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm mpiComm = comm_getMpiComm(); // each message is asynchronously dispatched with a final wait, as per arxiv.org/abs/2308.07402 @@ -229,8 +229,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, 
mpiCommQuest, &requests[2*m]); - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m+1]); + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiComm, &requests[2*m]); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiComm, &requests[2*m+1]); } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -251,7 +251,7 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm mpiComm = comm_getMpiComm(); // we will not track nor wait for the asynch send; instead, the caller will later comm_sync() MPI_Request nullReq = MPI_REQUEST_NULL; @@ -262,7 +262,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { // asynchronously send the uniquely-tagged messages for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &nullReq); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiComm, &nullReq); } #else @@ -274,7 +274,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { void receiveArray(qcomp* dest, qindex numElems, int pairRank) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm mpiComm = comm_getMpiComm(); // expect the data in multiple messages auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems); @@ -285,7 +285,7 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) { // listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[m]); + 
MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiComm, &requests[m]); } // receivers wait for all messages to be received (while sender asynch proceeds) @@ -310,8 +310,7 @@ void globallyCombineNonUniformSubArrays( ) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - + auto mpiComm = comm_getMpiComm(); int myRank = comm_getRank(); int numNodes = comm_getNumNodes(); @@ -345,14 +344,14 @@ void globallyCombineNonUniformSubArrays( for (int m=0; m 0) { qindex recvInd = globalRecvIndPerRank[sendRank] + (numBigMsgs * bigMsgSize); requests.push_back(MPI_REQUEST_NULL); - MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, mpiCommQuest, &requests.back()); + MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, mpiComm, &requests.back()); } } @@ -648,9 +647,7 @@ void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank) { void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - - MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, mpiCommQuest); + MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -661,7 +658,7 @@ void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm mpiComm = comm_getMpiComm(); // only the sender and root nodes need to continue int recvRank = ROOT_RANK; @@ -678,8 +675,8 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) for (qindex m=0; m(m); (myRank == sendRank)? 
- MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, mpiCommQuest, &requests[m]): // sender - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, mpiCommQuest, &requests[m]); // root + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, mpiComm, &requests[m]): // sender + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, mpiComm, &requests[m]); // root } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -692,13 +689,10 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) void comm_broadcastIntsFromRoot(int* arr, qindex length) { - #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_INT, sendRank, mpiCommQuest); + MPI_Bcast(arr, length, MPI_INT, sendRank, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -709,10 +703,8 @@ void comm_broadcastIntsFromRoot(int* arr, qindex length) { void comm_broadcastUnsignedsFromRoot(unsigned* arr, qindex length) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, mpiCommQuest); + MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -739,9 +731,7 @@ void comm_combineSubArrays(qcomp* recv, vector recvInds, vector void comm_reduceAmp(qcomp* localAmp) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - - MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -752,9 +742,7 @@ void comm_reduceAmp(qcomp* localAmp) { void comm_reduceReal(qreal* localReal) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - - MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, 
mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -765,9 +753,7 @@ void comm_reduceReal(qreal* localReal) { void comm_reduceReals(qreal* localReals, qindex numLocalReals) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - - MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, comm_getMpiComm()); #else error_commButEnvNotDistributed(); @@ -778,12 +764,10 @@ void comm_reduceReals(qreal* localReals, qindex numLocalReals) { bool comm_isTrueOnAllNodes(bool val) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - // perform global AND and broadcast result back to all nodes int local = (int) val; int global; - MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, mpiCommQuest); + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, comm_getMpiComm()); return (bool) global; #else @@ -819,8 +803,6 @@ bool comm_isTrueOnRootNode(bool val) { vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) { #if QUEST_COMPILE_MPI - MPI_Comm mpiCommQuest = comm_getMpiComm(); - // no need to validate array sizes and memory alloc successes; // these are trivial O(#nodes)-size arrays containing <20 chars int numNodes = comm_getNumNodes(); @@ -831,7 +813,7 @@ vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) // all nodes send root all their local chars int recvRank = ROOT_RANK; MPI_Gather(localChars, maxNumLocalChars, MPI_CHAR, allChars.data(), - maxNumLocalChars, MPI_CHAR, recvRank, mpiCommQuest); + maxNumLocalChars, MPI_CHAR, recvRank, comm_getMpiComm()); // divide allChars into stings, delimited by each node's terminal char vector out(numNodes); diff --git a/quest/src/core/errors.cpp b/quest/src/core/errors.cpp index 7b624a2f7..862136a9c 100644 --- a/quest/src/core/errors.cpp +++ b/quest/src/core/errors.cpp @@ 
-156,11 +156,6 @@ void error_commAlreadyInit() { raiseInternalError("The MPI communication environment was attemptedly re-initialised despite the QuEST environment already existing."); } -void error_commInvalidMpiComm() { - - raiseInternalError("The supplied MPI communicator was MPI_COMM_NULL, or duplication failed."); -} - void error_commButEnvNotDistributed() { raiseInternalError("A function attempted to invoke communication despite QuEST being compiled in non-distributed mode."); @@ -186,9 +181,14 @@ void error_commNumMessagesExceedTagMax() { raiseInternalError("A function attempted to communicate via more messages than permitted (since there would be more uniquely-tagged messages than the tag upperbound)."); } -void error_commDoubleSetMpiComm() { +void error_commAlreadyHasSetMpiComm() { - raiseInternalError("An attempt was made to set mpiCommQuest after it had already been set, as indicated by mpiCommQuest != MPI_COMM_NULL."); + raiseInternalError("An attempt was made to set the QuEST MPI communicator after it had already been set (and changed from MPI_COMM_NULL)."); +} + +void error_commMpiCommIsNull() { + + raiseInternalError("The MPI communicator was queried (or set) but was unexpectedly MPI_COMM_NULL (or set to be)."); } void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps) { diff --git a/quest/src/core/errors.hpp b/quest/src/core/errors.hpp index f276c06ad..33cc0661d 100644 --- a/quest/src/core/errors.hpp +++ b/quest/src/core/errors.hpp @@ -81,8 +81,6 @@ void error_commNotInit(); void error_commAlreadyInit(); -void error_commInvalidMpiComm(); - void error_commButEnvNotDistributed(); void error_commOutOfBounds(); @@ -93,7 +91,9 @@ void error_commGivenInconsistentNumSubArraysANodes(); void error_commNumMessagesExceedTagMax(); -void error_commDoubleSetMpiComm(); +void error_commAlreadyHasSetMpiComm(); + +void error_commMpiCommIsNull(); void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex 
numAmps); diff --git a/quest/src/core/validation.cpp b/quest/src/core/validation.cpp index 959acb61e..e1df0af76 100644 --- a/quest/src/core/validation.cpp +++ b/quest/src/core/validation.cpp @@ -107,6 +107,21 @@ namespace report { string CUQUANTUM_DEPLOYED_ON_GPU_WITHOUT_MEM_POOLS = "Cannot use cuQuantum since your GPU does not support memory pools. Recompile with cuQuantum disabled to fall-back to using Thrust and custom kernels."; + string USER_OWNED_MPI_WAS_NOT_INIT = + "User owns MPI but did not prior initialise MPI before initialising QuEST."; + + string USER_GIVEN_MPI_COMMUNICATOR_IS_NULL = + "The provided MPI communicator was null (MPI_COMM_NULL)."; + + string USER_GIVEN_MPI_COMMUNICATOR_FAILED_TO_SET = + "The provided MPI communicator could not be used; MPI_Comm_dup() was not successful."; + + string QUEST_OWNED_MPI_WAS_PRE_INIT = + "MPI was already initialised prior to QuESTEnv initialisation, but the user did not declare MPI ownership."; + + string QUEST_IS_NON_DISTRIBUTED_BUT_MPI_WAS_INIT = + "QuESTEnv was initialised to be non-distributed but MPI was externally initialised - this is presently unsupported due to a (very minor) technical limitation. If you need this facility, please raise a Github issue!"; + /* * EXISTING QUESTENV @@ -1155,10 +1170,10 @@ void default_inputErrorHandler(const char* func, const char* msg) { // will then attempt to instantly abort all nodes, losing the error message. 
comm_sync(); - // finalise MPI before error-exit to avoid scaring user with giant MPI error message + // finalise MPI before error-exit to avoid scaring user with giant MPI error message; // we always "take ownership" of MPI here since we're about to kill the whole program if (comm_isInit()) - comm_end(0); + comm_end(/*userOwnsMpi=*/false); // simply exit, interrupting any other process (potentially leaking) exit(EXIT_FAILURE); @@ -1482,6 +1497,60 @@ void validate_gpuIsCuQuantumCompatible(const char* caller) { assertAllNodesAgreeThat(hasMemPools, report::CUQUANTUM_DEPLOYED_ON_GPU_WITHOUT_MEM_POOLS, caller); } +void validate_mpiInitStatus(bool useDistrib, bool userOwnsMpi, const char* caller) { + + if (!global_isValidationEnabled) + return; + + // Validation prior to this function confirms init(Custom*)QuESTEnv is only ever called + // once, but we must additionally confirm the user has interacted with MPI legally + + bool isMpiInit = comm_isInit(); + + // (A) If the user does not declare ownership of MPI, they are forbidden to initialise it + if (!userOwnsMpi) + assertThat(!isMpiInit, report::QUEST_OWNED_MPI_WAS_PRE_INIT, caller); + + // (B) If QuEST is instructed not to use distribution, we must demand the user is not + // using MPI, because we internally consult comm_isInit() to detect QuEST distribution + // in many functions, and that will give a false positive when the user inits MPI directly. + if (!useDistrib) + assertThat(!isMpiInit, report::QUEST_IS_NON_DISTRIBUTED_BUT_MPI_WAS_INIT, caller); + + // TODO: we can relax above, permitting the user to play with MPI directly while + // disabling it for QuEST, by replacing internal comm_isInit() with e.g. 
env_isDistributed() + + // (C) If QuEST will use MPI owned by the user, the user must have pre-initialised it + if (useDistrib && userOwnsMpi) + assertThat(isMpiInit, report::USER_OWNED_MPI_WAS_NOT_INIT, caller); + + // Confirmation that all 8 scenarios are handled: + // useDistrib=0, userOwnsMpi=0, isMpiInit=0 (legal: nobody wants MPI) + // (A) useDistrib=0, userOwnsMpi=0, isMpiInit=1 (illegal: user lied about ownership) + // useDistrib=0, userOwnsMpi=1, isMpiInit=0 (legal: user owns MPI but does nothing!) + // (B) useDistrib=0, userOwnsMpi=1, isMpiInit=1 (illegal: comm_isInit() limitation as above) + // useDistrib=1, userOwnsMpi=0, isMpiInit=0 (legal: QuEST will init MPI) + // (A) useDistrib=1, userOwnsMpi=0, isMpiInit=1 (illegal: user lied about ownership) + // (C) useDistrib=1, userOwnsMpi=1, isMpiInit=0 (illegal: user has responsibility to pre-init) + // useDistrib=1, userOwnsMpi=1, isMpiInit=1 (legal: user fulfilled responsibility to pre-init) +} + +void validate_mpiSubCommIsNonNull(bool isNonNull, const char* caller) { + + if (!global_isValidationEnabled) + return; + + assertThat(isNonNull, report::USER_GIVEN_MPI_COMMUNICATOR_IS_NULL, caller); +} + +void validate_mpiSubCommSetSucceeded(bool success, const char* caller) { + + if (!global_isValidationEnabled) + return; + + assertThat(success, report::USER_GIVEN_MPI_COMMUNICATOR_FAILED_TO_SET, caller); +} + /* diff --git a/quest/src/core/validation.hpp b/quest/src/core/validation.hpp index 66fb8f546..787316326 100644 --- a/quest/src/core/validation.hpp +++ b/quest/src/core/validation.hpp @@ -77,6 +77,12 @@ void validate_newEnvNodesEachHaveUniqueGpu(const char* caller); void validate_gpuIsCuQuantumCompatible(const char* caller); +void validate_mpiInitStatus(bool useDistrib, bool userOwnsMpi, const char* caller); + +void validate_mpiSubCommIsNonNull(bool isNonNull, const char* caller); + +void validate_mpiSubCommSetSucceeded(bool success, const char* caller); + /* diff --git a/tests/unit/environment.cpp
b/tests/unit/environment.cpp index 344ac5864..9ecf8e376 100644 --- a/tests/unit/environment.cpp +++ b/tests/unit/environment.cpp @@ -158,13 +158,6 @@ TEST_CASE( "getQuESTEnv", TEST_CATEGORY ) { QuESTEnv env = getQuESTEnv(); - REQUIRE( (env.isMultithreaded == 0 || env.isMultithreaded == 1) ); - REQUIRE( (env.isGpuAccelerated == 0 || env.isGpuAccelerated == 1) ); - REQUIRE( (env.isDistributed == 0 || env.isDistributed == 1) ); - REQUIRE( (env.userOwnsMpi == 0 || env.userOwnsMpi == 1) ); - REQUIRE( (env.isCuQuantumEnabled == 0 || env.isCuQuantumEnabled == 1) ); - REQUIRE( (env.isGpuSharingEnabled == 0 || env.isGpuSharingEnabled == 1) ); - REQUIRE( env.rank >= 0 ); REQUIRE( env.numNodes >= 0 );