Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions quest/src/comm/comm_routines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
#include "quest/src/core/errors.hpp"
#include "quest/src/core/bitwise.hpp"
#include "quest/src/cpu/cpu_config.hpp"
#include "quest/src/cpu/cpu_subroutines.hpp"
#include "quest/src/gpu/gpu_config.hpp"
#include "quest/src/comm/comm_config.hpp"
#include "quest/src/comm/comm_indices.hpp"

#include <map>
#include <vector>

#if QUEST_COMPILE_MPI
#include <mpi.h>
extern MPI_Comm comm_getMpiComm(); // comm_config.cpp does not leak MPI_Comm
Expand Down Expand Up @@ -827,3 +831,52 @@ vector<string> comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars)
return {};
#endif
}

void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls, ConstList64 ctrlStates, const std::map<int, int>& swapMap) {
assert_commQuregIsDistributed(qureg);

#if QUEST_COMPILE_MPI
int k = swapMap.size();
if (k == 0) return;

// GPU fallback: sync GPU amps to CPU, perform fused swap on CPU, sync back
if (qureg.isGpuAccelerated)
syncQuregFromGpu(qureg);

qindex chunkSize = qureg.numAmpsPerNode >> k;

std::vector<int> prefixTargs;
for (auto const& [s, p] : swapMap) {
prefixTargs.push_back(p);
}

int myPrefixBits = 0;
for (int i = 0; i < k; i++) {
if (util_getRankBitOfQubit(prefixTargs[i], qureg)) {
myPrefixBits |= (1 << i);
}
}

qcomp* sendBuffer = qureg.cpuCommBuffer;
qcomp* recvBuffer = qureg.cpuCommBuffer + chunkSize;

for (int s = 1; s < (1 << k); s++) {
int target_m = myPrefixBits ^ s;
int pairRank = qureg.rank;
for (int i = 0; i < k; i++) {
if ((s >> i) & 1) {
pairRank = flipBit(pairRank, util_getPrefixInd(prefixTargs[i], qureg));
}
}

cpu_statevec_packFusedMultiSwapBuffers(qureg, swapMap, target_m, sendBuffer);
exchangeArrays(sendBuffer, recvBuffer, chunkSize, pairRank);
cpu_statevec_unpackFusedMultiSwapBuffers(qureg, swapMap, target_m, recvBuffer);
}

if (qureg.isGpuAccelerated)
syncQuregToGpu(qureg);
#else
error_commButEnvNotDistributed();
#endif
}
52 changes: 27 additions & 25 deletions quest/src/comm/comm_routines.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,36 @@
* Signatures for communicating and exchanging amplitudes between compute
* nodes, when running in distributed mode, using the C MPI standard.
* Calling these functions when QUEST_COMPILE_MPI=0, or when the passed Quregs
* are not distributed, will throw a runtime internal error.
*
* are not distributed, will throw a runtime internal error.
*
* @author Tyson Jones
*/

#ifndef COMM_ROUTINES_HPP
#define COMM_ROUTINES_HPP

#include "quest/include/types.h"
#include "quest/include/qureg.h"
#include "quest/include/matrices.h"
#include "quest/include/qureg.h"
#include "quest/include/types.h"
#include "quest/src/core/utilities.hpp"

#include <vector>
#include <map>
#include <string>
#include <vector>

using std::vector;



/*
* STATE EXCHANGE METHODS
*/

void comm_exchangeAmpsToBuffers(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps, int pairRank);
void comm_exchangeAmpsToBuffers(Qureg qureg, qindex sendInd, qindex recvInd,
qindex numAmps, int pairRank);

void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank);

void comm_exchangeSubBuffers(Qureg qureg, qindex numAmpsAndRecvInd, int pairRank);
void comm_exchangeSubBuffers(Qureg qureg, qindex numAmpsAndRecvInd,
int pairRank);

void comm_asynchSendSubBuffer(Qureg qureg, qindex numElems, int pairRank);

Expand All @@ -39,46 +41,46 @@ void comm_combineAmpsIntoBuffer(Qureg receiver, Qureg sender);

void comm_combineElemsIntoBuffer(Qureg receiver, FullStateDiagMatr sender);


void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls,
ConstList64 ctrlStates,
const std::map<int, int> &swapMap);

/*
* MISC COMMUNICATION METHODS
*/

void comm_broadcastAmp(int sendRank, qcomp* sendAmp);

void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps);
void comm_broadcastAmp(int sendRank, qcomp *sendAmp);

void comm_broadcastIntsFromRoot(int* arr, qindex length);
void comm_sendAmpsToRoot(int sendRank, qcomp *send, qcomp *recv,
qindex numAmps);

void comm_broadcastUnsignedsFromRoot(unsigned* arr, qindex length);

void comm_combineSubArrays(qcomp* recv, vector<qindex> globalRecvInds, vector<qindex> localSendInds, vector<qindex> numAmpsPerRank);
void comm_broadcastIntsFromRoot(int *arr, qindex length);

void comm_broadcastUnsignedsFromRoot(unsigned *arr, qindex length);

void comm_combineSubArrays(qcomp *recv, vector<qindex> globalRecvInds,
vector<qindex> localSendInds,
vector<qindex> numAmpsPerRank);

/*
* REDUCTION METHODS
*/

void comm_reduceAmp(qcomp* localAmp);
void comm_reduceAmp(qcomp *localAmp);

void comm_reduceReal(qreal* localReal);
void comm_reduceReal(qreal *localReal);

void comm_reduceReals(qreal* localReals, qindex numLocalReals);
void comm_reduceReals(qreal *localReals, qindex numLocalReals);

bool comm_isTrueOnAllNodes(bool val);

bool comm_isTrueOnRootNode(bool val);



/*
* GATHER METHODS
*/

vector<std::string> comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars);


vector<std::string> comm_gatherStringsToRoot(char *localChars,
int maxNumLocalChars);

#endif // COMM_ROUTINES_HPP
Loading