mirror of
https://github.com/sockspls/badfish
synced 2025-07-11 19:49:14 +00:00
[cluster] keep track of node counts cluster-wide.
This generalizes the exchange of signals between the ranks using a non-blocking all-reduce. It is now used for the stop signal and the node count, but should be easily generalizable (TB hits and ponder are still missing). It avoids having long-lived outstanding non-blocking collectives (it removes an early-posted Ibarrier). The test below is a bit too short, but the result is not worse than before: Score of new-r4-1t vs old-r4-1t: 459 - 401 - 1505 [0.512] 2365 Elo difference: 8.52 +/- 8.43
This commit is contained in:
parent
2f882309d5
commit
87f0fa55a0
7 changed files with 123 additions and 65 deletions
|
@ -359,7 +359,7 @@ endif
|
|||
### 3.10 MPI
|
||||
ifneq (,$(findstring mpi, $(CXX)))
|
||||
mpi = yes
|
||||
CXXFLAGS += -DUSE_MPI -Wno-cast-qual
|
||||
CXXFLAGS += -DUSE_MPI -Wno-cast-qual -fexceptions
|
||||
DEPENDFLAGS += -DUSE_MPI
|
||||
endif
|
||||
|
||||
|
|
136
src/cluster.cpp
136
src/cluster.cpp
|
@ -37,13 +37,21 @@ namespace Cluster {
|
|||
|
||||
static int world_rank = MPI_PROC_NULL;
|
||||
static int world_size = 0;
|
||||
static bool stop_signal = false;
|
||||
static MPI_Request reqStop = MPI_REQUEST_NULL;
|
||||
|
||||
static MPI_Request reqSignals = MPI_REQUEST_NULL;
|
||||
static uint64_t signalsCallCounter = 0;
|
||||
|
||||
enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_NB = 2};
|
||||
static uint64_t signalsSend[SIG_NB] = {};
|
||||
static uint64_t signalsRecv[SIG_NB] = {};
|
||||
|
||||
static uint64_t nodesSearchedOthers = 0;
|
||||
static uint64_t stopSignalsPosted = 0;
|
||||
|
||||
static MPI_Comm InputComm = MPI_COMM_NULL;
|
||||
static MPI_Comm TTComm = MPI_COMM_NULL;
|
||||
static MPI_Comm MoveComm = MPI_COMM_NULL;
|
||||
static MPI_Comm StopComm = MPI_COMM_NULL;
|
||||
static MPI_Comm signalsComm = MPI_COMM_NULL;
|
||||
|
||||
static std::vector<KeyedTTEntry> TTBuff;
|
||||
|
||||
|
@ -75,14 +83,34 @@ void init() {
|
|||
MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &StopComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
|
||||
}
|
||||
|
||||
void finalize() {
|
||||
|
||||
|
||||
// free data types and communicators
|
||||
MPI_Type_free(&MIDatatype);
|
||||
|
||||
MPI_Comm_free(&InputComm);
|
||||
MPI_Comm_free(&TTComm);
|
||||
MPI_Comm_free(&MoveComm);
|
||||
MPI_Comm_free(&signalsComm);
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
|
||||
int size() {
|
||||
|
||||
return world_size;
|
||||
}
|
||||
|
||||
int rank() {
|
||||
|
||||
return world_rank;
|
||||
}
|
||||
|
||||
|
||||
bool getline(std::istream& input, std::string& str) {
|
||||
|
||||
int size;
|
||||
|
@ -124,45 +152,60 @@ bool getline(std::istream& input, std::string& str) {
|
|||
return state;
|
||||
}
|
||||
|
||||
void sync_start() {
|
||||
void signals_send() {
|
||||
|
||||
stop_signal = false;
|
||||
|
||||
// Start listening to stop signal
|
||||
if (!is_root())
|
||||
MPI_Ibarrier(StopComm, &reqStop);
|
||||
signalsSend[SIG_NODES] = Threads.nodes_searched();
|
||||
signalsSend[SIG_STOP] = Threads.stop;
|
||||
MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
|
||||
MPI_SUM, signalsComm, &reqSignals);
|
||||
++signalsCallCounter;
|
||||
}
|
||||
|
||||
void sync_stop() {
|
||||
void signals_process() {
|
||||
|
||||
if (is_root())
|
||||
nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
|
||||
stopSignalsPosted = signalsRecv[SIG_STOP];
|
||||
if (signalsRecv[SIG_STOP] > 0)
|
||||
Threads.stop = true;
|
||||
}
|
||||
|
||||
void signals_sync() {
|
||||
|
||||
while(stopSignalsPosted < uint64_t(size()))
|
||||
signals_poll();
|
||||
|
||||
// Finalize outstanding messages of the signal loops. Some ranks might have issued one call fewer than needed.
|
||||
uint64_t globalCounter;
|
||||
MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
|
||||
if (signalsCallCounter < globalCounter)
|
||||
signals_send();
|
||||
|
||||
assert(signalsCallCounter == globalCounter);
|
||||
|
||||
MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
|
||||
|
||||
signals_process();
|
||||
|
||||
}
|
||||
|
||||
void signals_init() {
|
||||
|
||||
stopSignalsPosted = nodesSearchedOthers = 0;
|
||||
|
||||
signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
|
||||
signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
|
||||
|
||||
}
|
||||
|
||||
void signals_poll() {
|
||||
|
||||
int flag;
|
||||
MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
|
||||
if (flag)
|
||||
{
|
||||
if (!stop_signal && Threads.stop)
|
||||
{
|
||||
// Signal the cluster about stopping
|
||||
stop_signal = true;
|
||||
MPI_Ibarrier(StopComm, &reqStop);
|
||||
MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
|
||||
}
|
||||
signals_process();
|
||||
signals_send();
|
||||
}
|
||||
else
|
||||
{
|
||||
int flagStop;
|
||||
// Check if we've received any stop signal
|
||||
MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
|
||||
if (flagStop)
|
||||
Threads.stop = true;
|
||||
}
|
||||
}
|
||||
|
||||
int size() {
|
||||
|
||||
return world_size;
|
||||
}
|
||||
|
||||
int rank() {
|
||||
|
||||
return world_rank;
|
||||
}
|
||||
|
||||
void save(Thread* thread, TTEntry* tte,
|
||||
|
@ -270,10 +313,23 @@ void pick_moves(MoveInfo& mi) {
|
|||
MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
|
||||
}
|
||||
|
||||
void sum(uint64_t& val) {
|
||||
uint64_t nodes_searched() {
|
||||
|
||||
const uint64_t send = val;
|
||||
MPI_Reduce(&send, &val, 1, MPI_UINT64_T, MPI_SUM, 0, MoveComm);
|
||||
return nodesSearchedOthers + Threads.nodes_searched();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include "cluster.h"
|
||||
#include "thread.h"
|
||||
|
||||
namespace Cluster {
|
||||
|
||||
uint64_t nodes_searched() {
|
||||
|
||||
return Threads.nodes_searched();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -74,9 +74,10 @@ int rank();
|
|||
inline bool is_root() { return rank() == 0; }
|
||||
void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev);
|
||||
void pick_moves(MoveInfo& mi);
|
||||
void sum(uint64_t& val);
|
||||
void sync_start();
|
||||
void sync_stop();
|
||||
uint64_t nodes_searched();
|
||||
void signals_init();
|
||||
void signals_poll();
|
||||
void signals_sync();
|
||||
|
||||
#else
|
||||
|
||||
|
@ -88,9 +89,10 @@ constexpr int rank() { return 0; }
|
|||
constexpr bool is_root() { return true; }
|
||||
inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); }
|
||||
inline void pick_moves(MoveInfo&) { }
|
||||
inline void sum(uint64_t& ) { }
|
||||
inline void sync_start() { }
|
||||
inline void sync_stop() { }
|
||||
uint64_t nodes_searched();
|
||||
inline void signals_init() { }
|
||||
inline void signals_poll() { }
|
||||
inline void signals_sync() { }
|
||||
|
||||
#endif /* USE_MPI */
|
||||
|
||||
|
|
|
@ -234,14 +234,14 @@ void MainThread::search() {
|
|||
Threads.stopOnPonderhit = true;
|
||||
|
||||
while (!Threads.stop && (Threads.ponder || Limits.infinite))
|
||||
{ } // Busy wait for a stop or a ponder reset
|
||||
{ Cluster::signals_poll(); } // Busy wait for a stop or a ponder reset
|
||||
|
||||
// Stop the threads if not already stopped (also raise the stop if
|
||||
// "ponderhit" just reset Threads.ponder).
|
||||
Threads.stop = true;
|
||||
|
||||
// Finish any outstanding barriers.
|
||||
Cluster::sync_stop();
|
||||
// Signal and synchronize all other ranks
|
||||
Cluster::signals_sync();
|
||||
|
||||
// Wait until all threads have finished
|
||||
for (Thread* th : Threads)
|
||||
|
@ -251,7 +251,7 @@ void MainThread::search() {
|
|||
// When playing in 'nodes as time' mode, subtract the searched nodes from
|
||||
// the available ones before exiting.
|
||||
if (Limits.npmsec)
|
||||
Time.availableNodes += Limits.inc[us] - Threads.nodes_searched();
|
||||
Time.availableNodes += Limits.inc[us] - Cluster::nodes_searched();
|
||||
|
||||
// Check if there are threads with a better score than main thread
|
||||
Thread* bestThread = this;
|
||||
|
@ -370,7 +370,7 @@ void Thread::search() {
|
|||
// Iterative deepening loop until requested to stop or the target depth is reached
|
||||
while ( (rootDepth += ONE_PLY) < DEPTH_MAX
|
||||
&& !Threads.stop
|
||||
&& !(Limits.depth && mainThread && rootDepth / ONE_PLY > Limits.depth))
|
||||
&& !(Limits.depth && mainThread && Cluster::is_root() && rootDepth / ONE_PLY > Limits.depth))
|
||||
{
|
||||
// Distribute search depths across the helper threads
|
||||
if (idx + Cluster::rank() > 0)
|
||||
|
@ -384,6 +384,7 @@ void Thread::search() {
|
|||
if (mainThread)
|
||||
mainThread->bestMoveChanges *= 0.517, failedLow = false;
|
||||
|
||||
|
||||
// Save the last iteration's scores before first PV line is searched and
|
||||
// all the move scores except the (new) PV are set to -VALUE_INFINITE.
|
||||
for (RootMove& rm : rootMoves)
|
||||
|
@ -1609,16 +1610,16 @@ void MainThread::check_time() {
|
|||
dbg_print();
|
||||
}
|
||||
|
||||
// poll on MPI signals
|
||||
Cluster::signals_poll();
|
||||
|
||||
// We should not stop pondering until told so by the GUI
|
||||
if (Threads.ponder)
|
||||
return;
|
||||
|
||||
// Check if root has reached a stop barrier
|
||||
Cluster::sync_stop();
|
||||
|
||||
if ( (Limits.use_time_management() && elapsed > Time.maximum() - 10)
|
||||
|| (Limits.movetime && elapsed >= Limits.movetime)
|
||||
|| (Limits.nodes && Threads.nodes_searched() >= (uint64_t)Limits.nodes))
|
||||
|| (Limits.nodes && Cluster::nodes_searched() >= (uint64_t)Limits.nodes))
|
||||
Threads.stop = true;
|
||||
}
|
||||
|
||||
|
@ -1633,7 +1634,7 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
|
|||
const RootMoves& rootMoves = pos.this_thread()->rootMoves;
|
||||
size_t pvIdx = pos.this_thread()->pvIdx;
|
||||
size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size());
|
||||
uint64_t nodesSearched = Threads.nodes_searched();
|
||||
uint64_t nodesSearched = Cluster::nodes_searched();
|
||||
uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0);
|
||||
|
||||
for (size_t i = 0; i < multiPV; ++i)
|
||||
|
@ -1661,9 +1662,8 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) {
|
|||
if (!tb && i == pvIdx)
|
||||
ss << (v >= beta ? " lowerbound" : v <= alpha ? " upperbound" : "");
|
||||
|
||||
// TODO fix approximate node calculation.
|
||||
ss << " nodes " << nodesSearched * Cluster::size()
|
||||
<< " nps " << nodesSearched * Cluster::size() * 1000 / elapsed;
|
||||
ss << " nodes " << nodesSearched
|
||||
<< " nps " << nodesSearched * 1000 / elapsed;
|
||||
|
||||
if (elapsed > 1000) // Earlier makes little sense
|
||||
ss << " hashfull " << TT.hashfull();
|
||||
|
|
|
@ -163,7 +163,6 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
|
|||
main()->wait_for_search_finished();
|
||||
|
||||
stopOnPonderhit = stop = false;
|
||||
Cluster::sync_start();
|
||||
|
||||
ponder = ponderMode;
|
||||
Search::Limits = limits;
|
||||
|
@ -201,5 +200,7 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states,
|
|||
|
||||
setupStates->back() = tmp;
|
||||
|
||||
Cluster::signals_init();
|
||||
|
||||
main()->start_searching();
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#include "misc.h"
|
||||
#include "search.h"
|
||||
#include "thread.h"
|
||||
#include "cluster.h"
|
||||
|
||||
/// The TimeManagement class computes the optimal time to think depending on
|
||||
/// the maximum available time, the game move number and other parameters.
|
||||
|
@ -34,7 +34,7 @@ public:
|
|||
TimePoint optimum() const { return optimumTime; }
|
||||
TimePoint maximum() const { return maximumTime; }
|
||||
TimePoint elapsed() const { return Search::Limits.npmsec ?
|
||||
TimePoint(Threads.nodes_searched()) : now() - startTime; }
|
||||
TimePoint(Cluster::nodes_searched()) : now() - startTime; }
|
||||
|
||||
int64_t availableNodes; // When in 'nodes as time' mode
|
||||
|
||||
|
|
|
@ -162,7 +162,7 @@ namespace {
|
|||
cerr << "\nPosition: " << cnt++ << '/' << num << endl;
|
||||
go(pos, is, states);
|
||||
Threads.main()->wait_for_search_finished();
|
||||
nodes += Threads.nodes_searched();
|
||||
nodes += Cluster::nodes_searched();
|
||||
}
|
||||
else if (token == "setoption") setoption(is);
|
||||
else if (token == "position") position(pos, is, states);
|
||||
|
@ -173,7 +173,6 @@ namespace {
|
|||
|
||||
dbg_print(); // Just before exiting
|
||||
|
||||
Cluster::sum(nodes);
|
||||
if (Cluster::is_root())
|
||||
cerr << "\n==========================="
|
||||
<< "\nTotal time (ms) : " << elapsed
|
||||
|
|
Loading…
Add table
Reference in a new issue