1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-07-11 19:49:14 +00:00

[cluster] keep track of node counts cluster-wide.

This generalizes exchange of signals between the ranks using a non-blocking all-reduce. It is now used for the stop signal and the node count, but should be easily generalizable (TB hits, and ponder still missing). It avoids having long-lived outstanding non-blocking collectives (removes an early posted Ibarrier). A bit too short a test, but not worse than before:

Score of new-r4-1t vs old-r4-1t: 459 - 401 - 1505  [0.512] 2365
Elo difference: 8.52 +/- 8.43
This commit is contained in:
Joost VandeVondele 2018-12-18 00:00:06 +01:00 committed by Stéphane Nicolet
parent 2f882309d5
commit 87f0fa55a0
7 changed files with 123 additions and 65 deletions

View file

@ -359,7 +359,7 @@ endif
### 3.10 MPI
ifneq (,$(findstring mpi, $(CXX)))
mpi = yes
CXXFLAGS += -DUSE_MPI -Wno-cast-qual
CXXFLAGS += -DUSE_MPI -Wno-cast-qual -fexceptions
DEPENDFLAGS += -DUSE_MPI
endif

View file

@ -37,13 +37,21 @@ namespace Cluster {
static int world_rank = MPI_PROC_NULL;
static int world_size = 0;
static bool stop_signal = false;
static MPI_Request reqStop = MPI_REQUEST_NULL;
static MPI_Request reqSignals = MPI_REQUEST_NULL;
static uint64_t signalsCallCounter = 0;
enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_NB = 2};
static uint64_t signalsSend[SIG_NB] = {};
static uint64_t signalsRecv[SIG_NB] = {};
static uint64_t nodesSearchedOthers = 0;
static uint64_t stopSignalsPosted = 0;
static MPI_Comm InputComm = MPI_COMM_NULL;
static MPI_Comm TTComm = MPI_COMM_NULL;
static MPI_Comm MoveComm = MPI_COMM_NULL;
static MPI_Comm StopComm = MPI_COMM_NULL;
static MPI_Comm signalsComm = MPI_COMM_NULL;
static std::vector<KeyedTTEntry> TTBuff;
@ -75,14 +83,34 @@ void init() {
MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
MPI_Comm_dup(MPI_COMM_WORLD, &StopComm);
MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
}
void finalize() {
// free data tyes and communicators
MPI_Type_free(&MIDatatype);
MPI_Comm_free(&InputComm);
MPI_Comm_free(&TTComm);
MPI_Comm_free(&MoveComm);
MPI_Comm_free(&signalsComm);
MPI_Finalize();
}
int size() {
return world_size;
}
int rank() {
return world_rank;
}
bool getline(std::istream& input, std::string& str) {
int size;
@ -124,45 +152,60 @@ bool getline(std::istream& input, std::string& str) {
return state;
}
void sync_start() {
void signals_send() {
stop_signal = false;
// Start listening to stop signal
if (!is_root())
MPI_Ibarrier(StopComm, &reqStop);
signalsSend[SIG_NODES] = Threads.nodes_searched();
signalsSend[SIG_STOP] = Threads.stop;
MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
MPI_SUM, signalsComm, &reqSignals);
++signalsCallCounter;
}
void sync_stop() {
void signals_process() {
if (is_root())
nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
stopSignalsPosted = signalsRecv[SIG_STOP];
if (signalsRecv[SIG_STOP] > 0)
Threads.stop = true;
}
void signals_sync() {
while(stopSignalsPosted < uint64_t(size()))
signals_poll();
// finalize outstanding messages of the signal loops. We might have issued one call less than needed on some ranks.
uint64_t globalCounter;
MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
if (signalsCallCounter < globalCounter)
signals_send();
assert(signalsCallCounter == globalCounter);
MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
signals_process();
}
void signals_init() {
stopSignalsPosted = nodesSearchedOthers = 0;
signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
}
void signals_poll() {
int flag;
MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
if (flag)
{
if (!stop_signal && Threads.stop)
{
// Signal the cluster about stopping
stop_signal = true;
MPI_Ibarrier(StopComm, &reqStop);
MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
}
signals_process();
signals_send();
}
else
{
int flagStop;
// Check if we've received any stop signal
MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
if (flagStop)
Threads.stop = true;
}
}
int size() {
return world_size;
}
int rank() {
return world_rank;
}
void save(Thread* thread, TTEntry* tte,
@ -270,10 +313,23 @@ void pick_moves(MoveInfo& mi) {
MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
}
void sum(uint64_t& val) {
uint64_t nodes_searched() {
const uint64_t send = val;
MPI_Reduce(&send, &val, 1, MPI_UINT64_T, MPI_SUM, 0, MoveComm);
return nodesSearchedOthers + Threads.nodes_searched();
}
}
#else
#include "cluster.h"
#include "thread.h"
namespace Cluster {
uint64_t nodes_searched() {
return Threads.nodes_searched();
}
}

View file

@ -74,9 +74,10 @@ int rank();
inline bool is_root() { return rank() == 0; }
void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev);
void pick_moves(MoveInfo& mi);
void sum(uint64_t& val);
void sync_start();
void sync_stop();
uint64_t nodes_searched();
void signals_init();
void signals_poll();
void signals_sync();
#else
@ -88,9 +89,10 @@ constexpr int rank() { return 0; }
constexpr bool is_root() { return true; }
inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); }
inline void pick_moves(MoveInfo&) { }
inline void sum(uint64_t& ) { }
inline void sync_start() { }
inline void sync_stop() { }
uint64_t nodes_searched();
inline void signals_init() { }
inline void signals_poll() { }
inline void signals_sync() { }
#endif /* USE_MPI */

View file

@ -234,14 +234,14 @@ void MainThread::search() {
Threads.stopOnPonderhit = true;
while (!Threads.stop && (Threads.ponder || Limits.infinite))
{ } // Busy wait for a stop or a ponder reset
{ Cluster::signals_poll(); } // Busy wait for a stop or a ponder reset
// Stop the threads if not already stopped (also raise the stop if
// "ponderhit" just reset Threads.ponder).
Threads.stop = true;
// Finish any outstanding barriers.
Cluster::sync_stop();
// Signal and synchronize all other ranks
Cluster::signals_sync();
// Wait until all threads have finished
for (Thread* th : Threads)
@ -251,7 +251,7 @@ void MainThread::search() {
// When playing in 'nodes as time' mode, subtract the searched nodes from
// the available ones before exiting.
if (Limits.npmsec)
Time.availableNodes += Limits.inc[us] - Threads.nodes_searched();
Time.availableNodes += Limits.inc[us] - Cluster::nodes_searched();
// Check if there are threads with a better score than main thread
Thread* bestThread = this;
@ -370,7 +370,7 @@ void Thread::search() {
// Iterative deepening loop until requested to stop or the target depth is reached
while ( (rootDepth += ONE_PLY) < DEPTH_MAX
&& !Threads.stop
&& !(Limits.depth && mainThread && rootDepth / ONE_PLY > Limits.depth))
&& !(Limits.depth && mainThread && Cluster::is_root() && rootDepth / ONE_PLY > Limits.depth))
{
// Distribute search depths across the helper threads
if (idx + Cluster::rank() > 0)
@ -384,6 +384,7 @@ void Thread::search() {
if (mainThread)
mainThread->bestMoveChanges *= 0.517, failedLow = false;
// Save the last iteration's scores before first PV line is searched and
// all the move scores except the (new) PV are set to -VALUE_INFINITE.
for (RootMove& rm : rootMoves)
@ -1609,16 +1610,16 @@ void MainThread::check_time() {
dbg_print();
}
// poll on MPI signals
Cluster::signals_poll();
// We should not stop pondering until told so by the GUI
if (Threads.ponder)
return;
// Check if root has reached a stop barrier
Cluster::sync_stop();
if ( (Limits.use_time_management() && elapsed > Time.maximum() - 10)
|| (Limits.movetime && elapsed >= Limits.movetime)
|| (Limits.nodes && Threads.nodes_searched() >= (uint64_t)Limits.nodes))
|| (Limits.nodes && Cluster::nodes_searched() >= (uint64_t)Limits.nodes))
Threads.stop = true;
}
@ -1633,7 +1634,7 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
const RootMoves& rootMoves = pos.this_thread()->rootMoves;
size_t pvIdx = pos.this_thread()->pvIdx;
size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size());
uint64_t nodesSearched = Threads.nodes_searched();
uint64_t nodesSearched = Cluster::nodes_searched();
uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0);
for (size_t i = 0; i < multiPV; ++i)
@ -1661,9 +1662,8 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
if (!tb && i == pvIdx)
ss << (v >= beta ? " lowerbound" : v <= alpha ? " upperbound" : "");
// TODO fix approximate node calculation.
ss << " nodes " << nodesSearched * Cluster::size()
<< " nps " << nodesSearched * Cluster::size() * 1000 / elapsed;
ss << " nodes " << nodesSearched
<< " nps " << nodesSearched * 1000 / elapsed;
if (elapsed > 1000) // Earlier makes little sense
ss << " hashfull " << TT.hashfull();

View file

@ -163,7 +163,6 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
main()->wait_for_search_finished();
stopOnPonderhit = stop = false;
Cluster::sync_start();
ponder = ponderMode;
Search::Limits = limits;
@ -201,5 +200,7 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
setupStates->back() = tmp;
Cluster::signals_init();
main()->start_searching();
}

View file

@ -23,7 +23,7 @@
#include "misc.h"
#include "search.h"
#include "thread.h"
#include "cluster.h"
/// The TimeManagement class computes the optimal time to think depending on
/// the maximum available time, the game move number and other parameters.
@ -34,7 +34,7 @@ public:
TimePoint optimum() const { return optimumTime; }
TimePoint maximum() const { return maximumTime; }
TimePoint elapsed() const { return Search::Limits.npmsec ?
TimePoint(Threads.nodes_searched()) : now() - startTime; }
TimePoint(Cluster::nodes_searched()) : now() - startTime; }
int64_t availableNodes; // When in 'nodes as time' mode

View file

@ -162,7 +162,7 @@ namespace {
cerr << "\nPosition: " << cnt++ << '/' << num << endl;
go(pos, is, states);
Threads.main()->wait_for_search_finished();
nodes += Threads.nodes_searched();
nodes += Cluster::nodes_searched();
}
else if (token == "setoption") setoption(is);
else if (token == "position") position(pos, is, states);
@ -173,7 +173,6 @@ namespace {
dbg_print(); // Just before exiting
Cluster::sum(nodes);
if (Cluster::is_root())
cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed