diff --git a/src/cluster.cpp b/src/cluster.cpp
index e7d73e21..ea287676 100644
--- a/src/cluster.cpp
+++ b/src/cluster.cpp
@@ -65,9 +65,9 @@ static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
 // The receive buffer is used to gather information from all ranks.
 // The TTCacheCounter tracks the number of local elements that are ready to be sent.
 static MPI_Comm TTComm = MPI_COMM_NULL;
-static MPI_Request reqGather = MPI_REQUEST_NULL;
-static uint64_t gathersPosted = 0;
-static std::vector<KeyedTTEntry> TTRecvBuff;
+static std::array<std::vector<KeyedTTEntry>, 2> TTSendRecvBuffs;
+static std::array<MPI_Request, 2> reqsTTSendRecv = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
+static uint64_t sendRecvPosted = 0;
 static std::atomic<uint64_t> TTCacheCounter = {};
 
 /// Initialize MPI and associated data types. Note that the MPI library must be configured
@@ -126,11 +126,13 @@ int rank() {
 }
 
 /// The receive buffer depends on the number of MPI ranks and threads, resize as needed
-void ttRecvBuff_resize(size_t nThreads) {
-
-  TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
-  std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
-
+void ttSendRecvBuff_resize(size_t nThreads) {
+
+  for (int i : {0, 1})
+  {
+      TTSendRecvBuffs[i].resize(TTCacheSize * world_size * nThreads);
+      std::fill(TTSendRecvBuffs[i].begin(), TTSendRecvBuffs[i].end(), KeyedTTEntry());
+  }
 }
 
 /// As input is only received by the root (rank 0) of the cluster, this input must be relayed
@@ -208,6 +210,17 @@ void signals_process() {
       Threads.stop = true;
 }
 
+void sendrecv_post() {
+
+   ++sendRecvPosted;
+   MPI_Irecv(TTSendRecvBuffs[sendRecvPosted % 2].data(),
+             TTSendRecvBuffs[sendRecvPosted % 2].size() * sizeof(KeyedTTEntry), MPI_BYTE,
+             (rank() + size() - 1) % size(), 42, TTComm, &reqsTTSendRecv[0]);
+   MPI_Isend(TTSendRecvBuffs[(sendRecvPosted + 1) % 2].data(),
+             TTSendRecvBuffs[(sendRecvPosted + 1) % 2].size() * sizeof(KeyedTTEntry), MPI_BYTE,
+             (rank() + 1) % size(), 42, TTComm, &reqsTTSendRecv[1]);
+}
+
 /// During search, most message passing is asynchronous, but at the end of
 /// search it makes sense to bring them to a common, finalized state.
 void signals_sync() {
@@ -226,22 +239,17 @@ void signals_sync() {
   }
   assert(signalsCallCounter == globalCounter);
   MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
-
   signals_process();
 
-  // Finalize outstanding messages in the gather loop
-  MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
-  if (gathersPosted < globalCounter)
+  // Finalize outstanding messages in the sendRecv loop
+  MPI_Allreduce(&sendRecvPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
+  while (sendRecvPosted < globalCounter)
   {
-      size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
-      MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
-      MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
-                     TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
-                     TTComm, &reqGather);
-      ++gathersPosted;
+      MPI_Waitall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), MPI_STATUSES_IGNORE);
+      sendrecv_post();
   }
-  assert(gathersPosted == globalCounter);
-  MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
+  assert(sendRecvPosted == globalCounter);
+  MPI_Waitall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), MPI_STATUSES_IGNORE);
 }
@@ -279,7 +287,7 @@ void cluster_info(Depth depth) {
 
   sync_cout << "info depth " << depth / ONE_PLY << " cluster "
             << " signals " << signalsCallCounter << " sps " << signalsCallCounter * 1000 / elapsed
-            << " gathers " << gathersPosted << " gpps " << TTRecvBuff.size() * gathersPosted * 1000 / elapsed
+            << " sendRecvs " << sendRecvPosted << " srpps " << TTSendRecvBuffs[0].size() * sendRecvPosted * 1000 / elapsed
             << " TTSaves " << TTSaves << " TTSavesps " << TTSaves * 1000 / elapsed
             << sync_endl;
 }
@@ -312,11 +320,11 @@ void save(Thread* thread, TTEntry* tte,
 
   // Communicate on main search thread, as soon as the threads combined have collected
   // sufficient data to fill the send buffers.
-  if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
+  if (thread == Threads.main() && TTCacheCounter > recvBuffPerRankSize)
   {
      // Test communication status
      int flag;
-     MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
+     MPI_Testall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), &flag, MPI_STATUSES_IGNORE);
 
      // Current communication is complete
      if (flag)
@@ -333,7 +341,7 @@ void save(Thread* thread, TTEntry* tte,
              std::lock_guard<std::mutex> lk(th->ttCache.mutex);
 
              for (auto&& e : th->ttCache.buffer)
-                 TTRecvBuff[i++] = e;
+                 TTSendRecvBuffs[sendRecvPosted % 2][i++] = e;
 
              // Reset thread's send buffer
              th->ttCache.buffer = {};
@@ -344,7 +352,7 @@ void save(Thread* thread, TTEntry* tte,
          else // process data received from the corresponding rank.
              for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
              {
-                 auto&& e = TTRecvBuff[i];
+                 auto&& e = TTSendRecvBuffs[sendRecvPosted % 2][i];
                  bool found;
                  TTEntry* replace_tte;
                  replace_tte = TT.probe(e.first, found);
@@ -354,10 +362,7 @@ void save(Thread* thread, TTEntry* tte,
          }
 
          // Start next communication
-         MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
-                        TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
-                        TTComm, &reqGather);
-         ++gathersPosted;
+         sendrecv_post();
 
          // Force check of time on the next occasion, the above actions might have taken some time.
          static_cast<MainThread*>(thread)->callsCnt = 0;
diff --git a/src/cluster.h b/src/cluster.h
index 38ad1253..0e074554 100644
--- a/src/cluster.h
+++ b/src/cluster.h
@@ -91,7 +91,7 @@ int rank();
 inline bool is_root() { return rank() == 0; }
 void save(Thread* thread, TTEntry* tte, Key k, Value v, bool PvHit, Bound b, Depth d, Move m, Value ev);
 void pick_moves(MoveInfo& mi, std::string& PVLine);
-void ttRecvBuff_resize(size_t nThreads);
+void ttSendRecvBuff_resize(size_t nThreads);
 uint64_t nodes_searched();
 uint64_t tb_hits();
 uint64_t TT_saves();
@@ -110,7 +110,7 @@ constexpr int rank() { return 0; }
 constexpr bool is_root() { return true; }
 inline void save(Thread*, TTEntry* tte, Key k, Value v, bool PvHit, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, PvHit, b, d, m, ev); }
 inline void pick_moves(MoveInfo&, std::string&) { }
-inline void ttRecvBuff_resize(size_t) { }
+inline void ttSendRecvBuff_resize(size_t) { }
 uint64_t nodes_searched();
 uint64_t tb_hits();
 uint64_t TT_saves();
diff --git a/src/thread.cpp b/src/thread.cpp
index 29267963..69a7752f 100644
--- a/src/thread.cpp
+++ b/src/thread.cpp
@@ -141,7 +141,7 @@ void ThreadPool::set(size_t requested) {
 
       TT.resize(Options["Hash"]);
 
       // Adjust cluster buffers
-      Cluster::ttRecvBuff_resize(requested);
+      Cluster::ttSendRecvBuff_resize(requested);
   }
 }
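
The patch replaces the collective MPI_Iallgather with a ring of nonblocking point-to-point messages: each call to sendrecv_post() posts a receive from the left neighbor, (rank() + size() - 1) % size(), and forwards the buffer received on the previous hop to the right neighbor, (rank() + 1) % size(). The index sendRecvPosted % 2 alternates between the two halves of TTSendRecvBuffs, so one buffer is always free to receive while the other is in flight, and after size() - 1 hops every rank has seen every other rank's entries. The standalone program below is a minimal sketch of just this double-buffered ring pattern; the file name, ITEMS, the int payload, and the immediate MPI_Waitall are illustrative assumptions, not part of the patch.

// ring_demo.cpp (hypothetical): double-buffered neighbor exchange on a ring.
// Build/run with e.g.: mpic++ ring_demo.cpp -o ring_demo && mpirun -np 4 ./ring_demo
#include <mpi.h>
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main(int argc, char** argv) {

    MPI_Init(&argc, &argv);

    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    constexpr int ITEMS = 4;                      // payload size per hop (placeholder)
    std::array<std::vector<int>, 2> buffs;        // double buffer, like TTSendRecvBuffs
    std::array<MPI_Request, 2> reqs = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
    uint64_t posted = 0;                          // like sendRecvPosted

    for (auto& b : buffs)
        b.assign(ITEMS, rank);                    // each rank starts with its own data

    // After size - 1 hops every rank's payload has visited every other rank,
    // which is the property that lets the ring replace the allgather.
    for (int hop = 0; hop < size - 1; ++hop)
    {
        ++posted;
        MPI_Irecv(buffs[posted % 2].data(), ITEMS, MPI_INT,
                  (rank + size - 1) % size, 42, MPI_COMM_WORLD, &reqs[0]);
        MPI_Isend(buffs[(posted + 1) % 2].data(), ITEMS, MPI_INT,
                  (rank + 1) % size, 42, MPI_COMM_WORLD, &reqs[1]);

        // The engine instead polls with MPI_Testall from save() and keeps
        // searching; this demo simply blocks until the hop completes.
        MPI_Waitall(2, reqs.data(), MPI_STATUSES_IGNORE);
    }

    // With at least 2 ranks, the last buffer received originated at (rank + 1) % size.
    std::printf("rank %d: last payload originated at rank %d\n",
                rank, buffs[posted % 2].front());

    MPI_Finalize();
    return 0;
}

Compared to the allgather, each hop is a fixed-size message between neighbors only, and save() merely tests for completion with MPI_Testall, so the exchange keeps overlapping with search instead of synchronizing all ranks at once.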