diff --git a/src/cluster.cpp b/src/cluster.cpp index fd5c9daf..a09f2da6 100644 --- a/src/cluster.cpp +++ b/src/cluster.cpp @@ -37,10 +37,13 @@ namespace Cluster { static int world_rank = MPI_PROC_NULL; static int world_size = 0; +static bool stop_signal = false; +static MPI_Request reqStop = MPI_REQUEST_NULL; static MPI_Comm InputComm = MPI_COMM_NULL; static MPI_Comm TTComm = MPI_COMM_NULL; static MPI_Comm MoveComm = MPI_COMM_NULL; +static MPI_Comm StopComm = MPI_COMM_NULL; static MPI_Datatype TTEntryDatatype = MPI_DATATYPE_NULL; static std::vector TTBuff; @@ -104,6 +107,7 @@ void init() { MPI_Comm_dup(MPI_COMM_WORLD, &InputComm); MPI_Comm_dup(MPI_COMM_WORLD, &TTComm); MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm); + MPI_Comm_dup(MPI_COMM_WORLD, &StopComm); } void finalize() { @@ -131,6 +135,32 @@ bool getline(std::istream& input, std::string& str) { return state; } +void sync_start() { + stop_signal = false; + + // Start listening to stop signal + if (!is_root()) + MPI_Ibarrier(StopComm, &reqStop); +} + +void sync_stop() { + if (is_root()) { + if (!stop_signal && Threads.stop) { + // Signal the cluster about stopping + stop_signal = true; + MPI_Ibarrier(StopComm, &reqStop); + MPI_Wait(&reqStop, MPI_STATUS_IGNORE); + } + } + else { + int flagStop; + // Check if we've received any stop signal + MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE); + if (flagStop) + Threads.stop = true; + } +} + int size() { return world_size; } diff --git a/src/cluster.h b/src/cluster.h index bbd06875..ea0b1bbf 100644 --- a/src/cluster.h +++ b/src/cluster.h @@ -69,6 +69,8 @@ inline bool is_root() { return rank() == 0; } void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev); void reduce_moves(MoveInfo& mi); +void sync_start(); +void sync_stop(); #else @@ -86,6 +88,8 @@ inline void save(Thread* thread, TTEntry* tte, tte->save(k, v, b, d, m, ev); } inline void reduce_moves(MoveInfo&) { } +inline void sync_start() { } +inline void sync_stop() { } #endif /* USE_MPI */ diff --git a/src/main.cpp b/src/main.cpp index 1624a91f..c566bd2a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -27,7 +27,6 @@ #include "tt.h" #include "uci.h" #include "syzygy/tbprobe.h" -#include "cluster.h" namespace PSQT { void init(); diff --git a/src/search.cpp b/src/search.cpp index 221b57e1..9e6e5f54 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -234,12 +234,15 @@ void MainThread::search() { Threads.stopOnPonderhit = true; while (!Threads.stop && (Threads.ponder || Limits.infinite)) - {} // Busy wait for a stop or a ponder reset + { } // Busy wait for a stop or a ponder reset // Stop the threads if not already stopped (also raise the stop if // "ponderhit" just reset Threads.ponder). Threads.stop = true; + // Finish any outstanding barriers. + Cluster::sync_stop(); + // Wait until all threads have finished for (Thread* th : Threads) if (th != this) @@ -292,8 +295,8 @@ void MainThread::search() { previousScore = static_cast(mi.score); - // Send again PV info if we have a new best thread if (Cluster::is_root()) { + // Send again PV info if we have a new best thread if (bestThread != this) sync_cout << UCI::pv(bestThread->rootPos, bestThread->completedDepth, -VALUE_INFINITE, VALUE_INFINITE) << sync_endl; @@ -1608,6 +1611,9 @@ void MainThread::check_time() { if (Threads.ponder) return; + // Check if root has reached a stop barrier + Cluster::sync_stop(); + if ( (Limits.use_time_management() && elapsed > Time.maximum() - 10) || (Limits.movetime && elapsed >= Limits.movetime) || (Limits.nodes && Threads.nodes_searched() >= (uint64_t)Limits.nodes)) @@ -1653,8 +1659,8 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) { if (!tb && i == pvIdx) ss << (v >= beta ? " lowerbound" : v <= alpha ? " upperbound" : ""); - ss << " nodes " << nodesSearched - << " nps " << nodesSearched * 1000 / elapsed; + ss << " nodes " << nodesSearched * Cluster::size() + << " nps " << nodesSearched * Cluster::size() * 1000 / elapsed; if (elapsed > 1000) // Earlier makes little sense ss << " hashfull " << TT.hashfull(); diff --git a/src/search.h b/src/search.h index 92e124fc..87241374 100644 --- a/src/search.h +++ b/src/search.h @@ -26,6 +26,7 @@ #include "misc.h" #include "movepick.h" #include "types.h" +#include "cluster.h" class Position; @@ -89,7 +90,7 @@ struct LimitsType { } bool use_time_management() const { - return !(mate | movetime | depth | nodes | perft | infinite); + return Cluster::is_root() && !(mate | movetime | depth | nodes | perft | infinite); } std::vector searchmoves; diff --git a/src/thread.cpp b/src/thread.cpp index f88e359b..d9d8fd24 100644 --- a/src/thread.cpp +++ b/src/thread.cpp @@ -163,6 +163,8 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states, main()->wait_for_search_finished(); stopOnPonderhit = stop = false; + Cluster::sync_start(); + ponder = ponderMode; Search::Limits = limits; Search::RootMoves rootMoves;