1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-07-11 19:49:14 +00:00

[cluster] keep track of node counts cluster-wide.

This generalizes the exchange of signals between the ranks using a non-blocking all-reduce. It is now used for the stop signal and the node count, but should be easily generalizable (TB hits and ponder are still missing). It avoids having long-lived outstanding non-blocking collectives (it removes an early-posted Ibarrier). The test is a bit too short, but the result is not worse than before:

Score of new-r4-1t vs old-r4-1t: 459 - 401 - 1505  [0.512] 2365
Elo difference: 8.52 +/- 8.43
This commit is contained in:
Joost VandeVondele 2018-12-18 00:00:06 +01:00 committed by Stéphane Nicolet
parent 2f882309d5
commit 87f0fa55a0
7 changed files with 123 additions and 65 deletions

View file

@ -359,7 +359,7 @@ endif
### 3.10 MPI ### 3.10 MPI
ifneq (,$(findstring mpi, $(CXX))) ifneq (,$(findstring mpi, $(CXX)))
mpi = yes mpi = yes
CXXFLAGS += -DUSE_MPI -Wno-cast-qual CXXFLAGS += -DUSE_MPI -Wno-cast-qual -fexceptions
DEPENDFLAGS += -DUSE_MPI DEPENDFLAGS += -DUSE_MPI
endif endif

View file

@ -37,13 +37,21 @@ namespace Cluster {
static int world_rank = MPI_PROC_NULL; static int world_rank = MPI_PROC_NULL;
static int world_size = 0; static int world_size = 0;
static bool stop_signal = false;
static MPI_Request reqStop = MPI_REQUEST_NULL; static MPI_Request reqSignals = MPI_REQUEST_NULL;
static uint64_t signalsCallCounter = 0;
enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_NB = 2};
static uint64_t signalsSend[SIG_NB] = {};
static uint64_t signalsRecv[SIG_NB] = {};
static uint64_t nodesSearchedOthers = 0;
static uint64_t stopSignalsPosted = 0;
static MPI_Comm InputComm = MPI_COMM_NULL; static MPI_Comm InputComm = MPI_COMM_NULL;
static MPI_Comm TTComm = MPI_COMM_NULL; static MPI_Comm TTComm = MPI_COMM_NULL;
static MPI_Comm MoveComm = MPI_COMM_NULL; static MPI_Comm MoveComm = MPI_COMM_NULL;
static MPI_Comm StopComm = MPI_COMM_NULL; static MPI_Comm signalsComm = MPI_COMM_NULL;
static std::vector<KeyedTTEntry> TTBuff; static std::vector<KeyedTTEntry> TTBuff;
@ -75,14 +83,34 @@ void init() {
MPI_Comm_dup(MPI_COMM_WORLD, &InputComm); MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
MPI_Comm_dup(MPI_COMM_WORLD, &TTComm); MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm); MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
MPI_Comm_dup(MPI_COMM_WORLD, &StopComm); MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
} }
void finalize() { void finalize() {
// free data types and communicators
MPI_Type_free(&MIDatatype);
MPI_Comm_free(&InputComm);
MPI_Comm_free(&TTComm);
MPI_Comm_free(&MoveComm);
MPI_Comm_free(&signalsComm);
MPI_Finalize(); MPI_Finalize();
} }
int size() {
return world_size;
}
int rank() {
return world_rank;
}
bool getline(std::istream& input, std::string& str) { bool getline(std::istream& input, std::string& str) {
int size; int size;
@ -124,47 +152,62 @@ bool getline(std::istream& input, std::string& str) {
return state; return state;
} }
void sync_start() { void signals_send() {
stop_signal = false; signalsSend[SIG_NODES] = Threads.nodes_searched();
signalsSend[SIG_STOP] = Threads.stop;
// Start listening to stop signal MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
if (!is_root()) MPI_SUM, signalsComm, &reqSignals);
MPI_Ibarrier(StopComm, &reqStop); ++signalsCallCounter;
} }
void sync_stop() { void signals_process() {
if (is_root()) nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
{ stopSignalsPosted = signalsRecv[SIG_STOP];
if (!stop_signal && Threads.stop) if (signalsRecv[SIG_STOP] > 0)
{
// Signal the cluster about stopping
stop_signal = true;
MPI_Ibarrier(StopComm, &reqStop);
MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
}
}
else
{
int flagStop;
// Check if we've received any stop signal
MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
if (flagStop)
Threads.stop = true; Threads.stop = true;
}
void signals_sync() {
while(stopSignalsPosted < uint64_t(size()))
signals_poll();
// finalize outstanding messages of the signal loops. We might have issued one call less than needed on some ranks.
uint64_t globalCounter;
MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
if (signalsCallCounter < globalCounter)
signals_send();
assert(signalsCallCounter == globalCounter);
MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
signals_process();
}
void signals_init() {
stopSignalsPosted = nodesSearchedOthers = 0;
signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
}
void signals_poll() {
int flag;
MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
if (flag)
{
signals_process();
signals_send();
} }
} }
int size() {
return world_size;
}
int rank() {
return world_rank;
}
void save(Thread* thread, TTEntry* tte, void save(Thread* thread, TTEntry* tte,
Key k, Value v, Bound b, Depth d, Move m, Value ev) { Key k, Value v, Bound b, Depth d, Move m, Value ev) {
@ -270,10 +313,23 @@ void pick_moves(MoveInfo& mi) {
MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm); MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
} }
void sum(uint64_t& val) { uint64_t nodes_searched() {
const uint64_t send = val; return nodesSearchedOthers + Threads.nodes_searched();
MPI_Reduce(&send, &val, 1, MPI_UINT64_T, MPI_SUM, 0, MoveComm); }
}
#else
#include "cluster.h"
#include "thread.h"
namespace Cluster {
uint64_t nodes_searched() {
return Threads.nodes_searched();
} }
} }

View file

@ -74,9 +74,10 @@ int rank();
inline bool is_root() { return rank() == 0; } inline bool is_root() { return rank() == 0; }
void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev); void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev);
void pick_moves(MoveInfo& mi); void pick_moves(MoveInfo& mi);
void sum(uint64_t& val); uint64_t nodes_searched();
void sync_start(); void signals_init();
void sync_stop(); void signals_poll();
void signals_sync();
#else #else
@ -88,9 +89,10 @@ constexpr int rank() { return 0; }
constexpr bool is_root() { return true; } constexpr bool is_root() { return true; }
inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); } inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); }
inline void pick_moves(MoveInfo&) { } inline void pick_moves(MoveInfo&) { }
inline void sum(uint64_t& ) { } uint64_t nodes_searched();
inline void sync_start() { } inline void signals_init() { }
inline void sync_stop() { } inline void signals_poll() { }
inline void signals_sync() { }
#endif /* USE_MPI */ #endif /* USE_MPI */

View file

@ -234,14 +234,14 @@ void MainThread::search() {
Threads.stopOnPonderhit = true; Threads.stopOnPonderhit = true;
while (!Threads.stop && (Threads.ponder || Limits.infinite)) while (!Threads.stop && (Threads.ponder || Limits.infinite))
{ } // Busy wait for a stop or a ponder reset { Cluster::signals_poll(); } // Busy wait for a stop or a ponder reset
// Stop the threads if not already stopped (also raise the stop if // Stop the threads if not already stopped (also raise the stop if
// "ponderhit" just reset Threads.ponder). // "ponderhit" just reset Threads.ponder).
Threads.stop = true; Threads.stop = true;
// Finish any outstanding barriers. // Signal and synchronize all other ranks
Cluster::sync_stop(); Cluster::signals_sync();
// Wait until all threads have finished // Wait until all threads have finished
for (Thread* th : Threads) for (Thread* th : Threads)
@ -251,7 +251,7 @@ void MainThread::search() {
// When playing in 'nodes as time' mode, subtract the searched nodes from // When playing in 'nodes as time' mode, subtract the searched nodes from
// the available ones before exiting. // the available ones before exiting.
if (Limits.npmsec) if (Limits.npmsec)
Time.availableNodes += Limits.inc[us] - Threads.nodes_searched(); Time.availableNodes += Limits.inc[us] - Cluster::nodes_searched();
// Check if there are threads with a better score than main thread // Check if there are threads with a better score than main thread
Thread* bestThread = this; Thread* bestThread = this;
@ -370,7 +370,7 @@ void Thread::search() {
// Iterative deepening loop until requested to stop or the target depth is reached // Iterative deepening loop until requested to stop or the target depth is reached
while ( (rootDepth += ONE_PLY) < DEPTH_MAX while ( (rootDepth += ONE_PLY) < DEPTH_MAX
&& !Threads.stop && !Threads.stop
&& !(Limits.depth && mainThread && rootDepth / ONE_PLY > Limits.depth)) && !(Limits.depth && mainThread && Cluster::is_root() && rootDepth / ONE_PLY > Limits.depth))
{ {
// Distribute search depths across the helper threads // Distribute search depths across the helper threads
if (idx + Cluster::rank() > 0) if (idx + Cluster::rank() > 0)
@ -384,6 +384,7 @@ void Thread::search() {
if (mainThread) if (mainThread)
mainThread->bestMoveChanges *= 0.517, failedLow = false; mainThread->bestMoveChanges *= 0.517, failedLow = false;
// Save the last iteration's scores before first PV line is searched and // Save the last iteration's scores before first PV line is searched and
// all the move scores except the (new) PV are set to -VALUE_INFINITE. // all the move scores except the (new) PV are set to -VALUE_INFINITE.
for (RootMove& rm : rootMoves) for (RootMove& rm : rootMoves)
@ -1609,16 +1610,16 @@ void MainThread::check_time() {
dbg_print(); dbg_print();
} }
// poll on MPI signals
Cluster::signals_poll();
// We should not stop pondering until told so by the GUI // We should not stop pondering until told so by the GUI
if (Threads.ponder) if (Threads.ponder)
return; return;
// Check if root has reached a stop barrier
Cluster::sync_stop();
if ( (Limits.use_time_management() && elapsed > Time.maximum() - 10) if ( (Limits.use_time_management() && elapsed > Time.maximum() - 10)
|| (Limits.movetime && elapsed >= Limits.movetime) || (Limits.movetime && elapsed >= Limits.movetime)
|| (Limits.nodes && Threads.nodes_searched() >= (uint64_t)Limits.nodes)) || (Limits.nodes && Cluster::nodes_searched() >= (uint64_t)Limits.nodes))
Threads.stop = true; Threads.stop = true;
} }
@ -1633,7 +1634,7 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
const RootMoves& rootMoves = pos.this_thread()->rootMoves; const RootMoves& rootMoves = pos.this_thread()->rootMoves;
size_t pvIdx = pos.this_thread()->pvIdx; size_t pvIdx = pos.this_thread()->pvIdx;
size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size()); size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size());
uint64_t nodesSearched = Threads.nodes_searched(); uint64_t nodesSearched = Cluster::nodes_searched();
uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0); uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0);
for (size_t i = 0; i < multiPV; ++i) for (size_t i = 0; i < multiPV; ++i)
@ -1661,9 +1662,8 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
if (!tb && i == pvIdx) if (!tb && i == pvIdx)
ss << (v >= beta ? " lowerbound" : v <= alpha ? " upperbound" : ""); ss << (v >= beta ? " lowerbound" : v <= alpha ? " upperbound" : "");
// TODO fix approximate node calculation. ss << " nodes " << nodesSearched
ss << " nodes " << nodesSearched * Cluster::size() << " nps " << nodesSearched * 1000 / elapsed;
<< " nps " << nodesSearched * Cluster::size() * 1000 / elapsed;
if (elapsed > 1000) // Earlier makes little sense if (elapsed > 1000) // Earlier makes little sense
ss << " hashfull " << TT.hashfull(); ss << " hashfull " << TT.hashfull();

View file

@ -163,7 +163,6 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
main()->wait_for_search_finished(); main()->wait_for_search_finished();
stopOnPonderhit = stop = false; stopOnPonderhit = stop = false;
Cluster::sync_start();
ponder = ponderMode; ponder = ponderMode;
Search::Limits = limits; Search::Limits = limits;
@ -201,5 +200,7 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
setupStates->back() = tmp; setupStates->back() = tmp;
Cluster::signals_init();
main()->start_searching(); main()->start_searching();
} }

View file

@ -23,7 +23,7 @@
#include "misc.h" #include "misc.h"
#include "search.h" #include "search.h"
#include "thread.h" #include "cluster.h"
/// The TimeManagement class computes the optimal time to think depending on /// The TimeManagement class computes the optimal time to think depending on
/// the maximum available time, the game move number and other parameters. /// the maximum available time, the game move number and other parameters.
@ -34,7 +34,7 @@ public:
TimePoint optimum() const { return optimumTime; } TimePoint optimum() const { return optimumTime; }
TimePoint maximum() const { return maximumTime; } TimePoint maximum() const { return maximumTime; }
TimePoint elapsed() const { return Search::Limits.npmsec ? TimePoint elapsed() const { return Search::Limits.npmsec ?
TimePoint(Threads.nodes_searched()) : now() - startTime; } TimePoint(Cluster::nodes_searched()) : now() - startTime; }
int64_t availableNodes; // When in 'nodes as time' mode int64_t availableNodes; // When in 'nodes as time' mode

View file

@ -162,7 +162,7 @@ namespace {
cerr << "\nPosition: " << cnt++ << '/' << num << endl; cerr << "\nPosition: " << cnt++ << '/' << num << endl;
go(pos, is, states); go(pos, is, states);
Threads.main()->wait_for_search_finished(); Threads.main()->wait_for_search_finished();
nodes += Threads.nodes_searched(); nodes += Cluster::nodes_searched();
} }
else if (token == "setoption") setoption(is); else if (token == "setoption") setoption(is);
else if (token == "position") position(pos, is, states); else if (token == "position") position(pos, is, states);
@ -173,7 +173,6 @@ namespace {
dbg_print(); // Just before exiting dbg_print(); // Just before exiting
Cluster::sum(nodes);
if (Cluster::is_root()) if (Cluster::is_root())
cerr << "\n===========================" cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nTotal time (ms) : " << elapsed