[Cluster] Use a sendrecv ring instead of allgather
Using point-to-point communication instead of a collective improves performance and might be more flexible for future improvements. Also corrects the condition for the number of elements required to fill the send buffer.

The actual Elo gain depends a bit on the setup used for testing:
8mpi x 32t yields 141 - 102 - 957 ~ 11 Elo
8mpi x 1t yields 70 +- 9 Elo
parent 5e7777e9d0
commit bf17a410ec
3 changed files with 36 additions and 31 deletions
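For orientation before the diff: the patch replaces a collective MPI_Iallgather of transposition-table entries with a double-buffered ring exchange, where each round a rank posts a nonblocking receive from its left neighbour into one buffer while the other buffer (its own fresh data, or whatever arrived in the previous round) travels to its right neighbour. The standalone sketch below illustrates that pattern with plain MPI and a dummy uint64_t payload; all names (ringBuffs, ringReqs, ringPosted, ring_post) and the payload size are illustrative only, though the neighbour arithmetic, the message tag and the posted-round bookkeeping mirror the sendrecv_post() added further down.

// Minimal sketch of a double-buffered nonblocking ring exchange (illustrative
// names, not taken from the patch). Compile with an MPI C++ compiler, e.g. mpicxx.
#include <mpi.h>
#include <array>
#include <cstdint>
#include <vector>

static std::array<std::vector<std::uint64_t>, 2> ringBuffs;  // buffers alternate between send and receive roles
static std::array<MPI_Request, 2> ringReqs = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
static std::uint64_t ringPosted = 0;                         // number of rounds posted so far

// Post one round: receive from the left neighbour into one buffer while the
// other buffer (filled locally or received in the previous round) goes right.
void ring_post(MPI_Comm comm, int rank, int size) {
    ++ringPosted;
    MPI_Irecv(ringBuffs[ringPosted % 2].data(),
              static_cast<int>(ringBuffs[ringPosted % 2].size() * sizeof(std::uint64_t)), MPI_BYTE,
              (rank + size - 1) % size, 42, comm, &ringReqs[0]);
    MPI_Isend(ringBuffs[(ringPosted + 1) % 2].data(),
              static_cast<int>(ringBuffs[(ringPosted + 1) % 2].size() * sizeof(std::uint64_t)), MPI_BYTE,
              (rank + 1) % size, 42, comm, &ringReqs[1]);
}

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    for (auto& b : ringBuffs)
        b.assign(1024, std::uint64_t(rank));  // dummy payload identifying the origin rank

    // After size - 1 rounds every rank has seen every other rank's payload once,
    // because each round forwards what was received in the round before.
    for (int round = 0; round < size - 1; ++round) {
        ring_post(MPI_COMM_WORLD, rank, size);
        MPI_Waitall(2, ringReqs.data(), MPI_STATUSES_IGNORE);
        // ... consume ringBuffs[ringPosted % 2] here (data that came from the left) ...
    }

    MPI_Finalize();
    return 0;
}

With this structure each rank only ever exchanges with its two neighbours instead of participating in a global collective, which presumably gives the flexibility for future improvements the commit message mentions (for example, filtering or re-ranking entries before passing them on).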
@@ -65,9 +65,9 @@ static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
 // The receive buffer is used to gather information from all ranks.
 // THe TTCacheCounter tracks the number of local elements that are ready to be sent.
 static MPI_Comm TTComm = MPI_COMM_NULL;
-static MPI_Request reqGather = MPI_REQUEST_NULL;
-static uint64_t gathersPosted = 0;
-static std::vector<KeyedTTEntry> TTRecvBuff;
+static std::array<std::vector<KeyedTTEntry>, 2> TTSendRecvBuffs;
+static std::array<MPI_Request, 2> reqsTTSendRecv = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
+static uint64_t sendRecvPosted = 0;
 static std::atomic<uint64_t> TTCacheCounter = {};
 
 /// Initialize MPI and associated data types. Note that the MPI library must be configured
@@ -126,11 +126,13 @@ int rank() {
 }
 
 /// The receive buffer depends on the number of MPI ranks and threads, resize as needed
-void ttRecvBuff_resize(size_t nThreads) {
+void ttSendRecvBuff_resize(size_t nThreads) {
 
-  TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
-  std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
-
+  for (int i : {0, 1})
+  {
+      TTSendRecvBuffs[i].resize(TTCacheSize * world_size * nThreads);
+      std::fill(TTSendRecvBuffs[i].begin(), TTSendRecvBuffs[i].end(), KeyedTTEntry());
+  }
 }
 
 /// As input is only received by the root (rank 0) of the cluster, this input must be relayed
@@ -208,6 +210,17 @@ void signals_process() {
       Threads.stop = true;
 }
 
+void sendrecv_post() {
+
+  ++sendRecvPosted;
+  MPI_Irecv(TTSendRecvBuffs[sendRecvPosted % 2].data(),
+            TTSendRecvBuffs[sendRecvPosted % 2].size() * sizeof(KeyedTTEntry), MPI_BYTE,
+            (rank() + size() - 1) % size(), 42, TTComm, &reqsTTSendRecv[0]);
+  MPI_Isend(TTSendRecvBuffs[(sendRecvPosted + 1) % 2].data(),
+            TTSendRecvBuffs[(sendRecvPosted + 1) % 2].size() * sizeof(KeyedTTEntry), MPI_BYTE,
+            (rank() + 1 ) % size(), 42, TTComm, &reqsTTSendRecv[1]);
+}
+
 /// During search, most message passing is asynchronous, but at the end of
 /// search it makes sense to bring them to a common, finalized state.
 void signals_sync() {
@@ -226,22 +239,17 @@ void signals_sync() {
  }
  assert(signalsCallCounter == globalCounter);
  MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
 
  signals_process();
 
-  // Finalize outstanding messages in the gather loop
-  MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
-  if (gathersPosted < globalCounter)
+  // Finalize outstanding messages in the sendRecv loop
+  MPI_Allreduce(&sendRecvPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
+  while (sendRecvPosted < globalCounter)
  {
-      size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
-      MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
-      MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
-                     TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
-                     TTComm, &reqGather);
-      ++gathersPosted;
+      MPI_Waitall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), MPI_STATUSES_IGNORE);
+      sendrecv_post();
  }
-  assert(gathersPosted == globalCounter);
-  MPI_Wait(&reqGather, MPI_STATUS_IGNORE);
+  assert(sendRecvPosted == globalCounter);
+  MPI_Waitall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), MPI_STATUSES_IGNORE);
 
 }
@@ -279,7 +287,7 @@ void cluster_info(Depth depth) {
 
  sync_cout << "info depth " << depth / ONE_PLY << " cluster "
            << " signals " << signalsCallCounter << " sps " << signalsCallCounter * 1000 / elapsed
-            << " gathers " << gathersPosted << " gpps " << TTRecvBuff.size() * gathersPosted * 1000 / elapsed
+            << " sendRecvs " << sendRecvPosted << " srpps " << TTSendRecvBuffs[0].size() * sendRecvPosted * 1000 / elapsed
            << " TTSaves " << TTSaves << " TTSavesps " << TTSaves * 1000 / elapsed
            << sync_endl;
 }
@@ -312,11 +320,11 @@ void save(Thread* thread, TTEntry* tte,
 
      // Communicate on main search thread, as soon the threads combined have collected
      // sufficient data to fill the send buffers.
-     if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
+     if (thread == Threads.main() && TTCacheCounter > recvBuffPerRankSize)
      {
          // Test communication status
          int flag;
-         MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
+         MPI_Testall(reqsTTSendRecv.size(), reqsTTSendRecv.data(), &flag, MPI_STATUSES_IGNORE);
 
          // Current communication is complete
          if (flag)
@@ -333,7 +341,7 @@ void save(Thread* thread, TTEntry* tte,
                  std::lock_guard<Mutex> lk(th->ttCache.mutex);
 
                  for (auto&& e : th->ttCache.buffer)
-                     TTRecvBuff[i++] = e;
+                     TTSendRecvBuffs[sendRecvPosted % 2][i++] = e;
 
                  // Reset thread's send buffer
                  th->ttCache.buffer = {};
@@ -344,7 +352,7 @@ void save(Thread* thread, TTEntry* tte,
              else // process data received from the corresponding rank.
                  for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
                  {
-                     auto&& e = TTRecvBuff[i];
+                     auto&& e = TTSendRecvBuffs[sendRecvPosted % 2][i];
                      bool found;
                      TTEntry* replace_tte;
                      replace_tte = TT.probe(e.first, found);
@@ -354,10 +362,7 @@ void save(Thread* thread, TTEntry* tte,
              }
 
          // Start next communication
-         MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
-                        TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
-                        TTComm, &reqGather);
-         ++gathersPosted;
+         sendrecv_post();
 
          // Force check of time on the next occasion, the above actions might have taken some time.
          static_cast<MainThread*>(thread)->callsCnt = 0;
@@ -91,7 +91,7 @@ int rank();
 inline bool is_root() { return rank() == 0; }
 void save(Thread* thread, TTEntry* tte, Key k, Value v, bool PvHit, Bound b, Depth d, Move m, Value ev);
 void pick_moves(MoveInfo& mi, std::string& PVLine);
-void ttRecvBuff_resize(size_t nThreads);
+void ttSendRecvBuff_resize(size_t nThreads);
 uint64_t nodes_searched();
 uint64_t tb_hits();
 uint64_t TT_saves();
@@ -110,7 +110,7 @@ constexpr int rank() { return 0; }
 constexpr bool is_root() { return true; }
 inline void save(Thread*, TTEntry* tte, Key k, Value v, bool PvHit, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, PvHit, b, d, m, ev); }
 inline void pick_moves(MoveInfo&, std::string&) { }
-inline void ttRecvBuff_resize(size_t) { }
+inline void ttSendRecvBuff_resize(size_t) { }
 uint64_t nodes_searched();
 uint64_t tb_hits();
 uint64_t TT_saves();
@@ -141,7 +141,7 @@ void ThreadPool::set(size_t requested) {
      TT.resize(Options["Hash"]);
 
      // Adjust cluster buffers
-     Cluster::ttRecvBuff_resize(requested);
+     Cluster::ttSendRecvBuff_resize(requested);
  }
 }
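A closing note on the signals_sync() hunk above: because the neighbour exchanges are posted independently on every rank, a rank can be more than one round behind the fastest rank when the search stops. The patch therefore first agrees on the largest round count with MPI_Allreduce(MPI_MAX) and then loops (a while loop where the old code had an if) until the local count catches up, so that every posted send finds a matching receive. A hypothetical finalization step in the same spirit, continuing the illustrative sketch near the top of this page (same ringPosted, ringReqs and ring_post):

// Hypothetical catch-up step, reusing ringPosted, ringReqs and ring_post()
// from the earlier sketch: bring every rank to the same number of posted
// rounds so that all outstanding sends and receives are matched.
void ring_finalize(MPI_Comm comm, int rank, int size) {
    std::uint64_t globalPosted = 0;

    // Every rank learns the largest round count any rank has reached.
    MPI_Allreduce(&ringPosted, &globalPosted, 1, MPI_UINT64_T, MPI_MAX, comm);

    // A lagging rank may be several rounds behind, hence a loop rather than an if.
    while (ringPosted < globalPosted) {
        MPI_Waitall(2, ringReqs.data(), MPI_STATUSES_IGNORE);
        ring_post(comm, rank, size);
    }

    // All ranks have now posted the same number of rounds; drain the final pair.
    MPI_Waitall(2, ringReqs.data(), MPI_STATUSES_IGNORE);
}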