[Cluster] Improve message passing part.
This rewrites part of the message passing code, using an in-place gather and collecting, rather than merging, the data of all threads.

Neutral with a single thread per rank:
Score of new-2mpi-1t vs old-2mpi-1t: 789 - 787 - 2615 [0.500] 4191
Elo difference: 0.17 +/- 6.44

Likely progress with multiple threads per rank:
Score of new-2mpi-36t vs old-2mpi-36t: 76 - 53 - 471 [0.519] 600
Elo difference: 13.32 +/- 12.85
parent 7a32d26d5f
commit ac43bef5c5

4 changed files with 76 additions and 45 deletions
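At the heart of the change is an in-place MPI_Iallgather: TTRecvBuff is partitioned into one slot per rank, each rank writes its own contribution into its slot, and MPI_IN_PLACE makes MPI take the send data directly from there, so no separate, merged send buffer is needed. A minimal sketch of one such round, with illustrative names that are not part of this commit:

    #include <mpi.h>
    #include <vector>

    // Post one non-blocking, in-place allgather round. recvBuff holds one
    // slot of perRankBytes per rank; this rank's slot has already been
    // filled with its contribution before the call.
    void post_gather_round(std::vector<char>& recvBuff, int perRankBytes,
                           MPI_Comm comm, MPI_Request* req)
    {
        MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, // send args ignored
                       recvBuff.data(), perRankBytes, MPI_BYTE,
                       comm, req);
        // Completion is checked later with MPI_Test, as save() does below.
    }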
src/cluster.cpp (107 changed lines)
@@ -54,10 +54,15 @@ static MPI_Comm TTComm = MPI_COMM_NULL;
 static MPI_Comm MoveComm = MPI_COMM_NULL;
 static MPI_Comm signalsComm = MPI_COMM_NULL;
 
-static std::vector<KeyedTTEntry> TTBuff;
+static std::vector<KeyedTTEntry> TTRecvBuff;
+static MPI_Request reqGather = MPI_REQUEST_NULL;
+static uint64_t gathersPosted = 0;
+
+static std::atomic<uint64_t> TTCacheCounter = {};
+
 static MPI_Datatype MIDatatype = MPI_DATATYPE_NULL;
 
 
 void init() {
 
   int thread_support;
@@ -72,8 +77,6 @@ void init() {
   MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
   MPI_Comm_size(MPI_COMM_WORLD, &world_size);
 
-  TTBuff.resize(TTSendBufferSize * world_size);
-
   const std::array<MPI_Aint, 4> MIdisps = {offsetof(MoveInfo, move),
                                            offsetof(MoveInfo, depth),
                                            offsetof(MoveInfo, score),
@@ -111,6 +114,13 @@ int rank() {
   return world_rank;
 }
 
+void ttRecvBuff_resize(size_t nThreads) {
+
+  TTRecvBuff.resize(TTCacheSize * world_size * nThreads);
+  std::fill(TTRecvBuff.begin(), TTRecvBuff.end(), KeyedTTEntry());
+
+}
+
 
 bool getline(std::istream& input, std::string& str) {
 
@@ -189,6 +199,18 @@ void signals_sync() {
 
   signals_process();
 
+  // finalize outstanding messages in the gather loop
+  MPI_Allreduce(&gathersPosted, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm);
+  if (gathersPosted < globalCounter)
+  {
+      size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
+      MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
+                     TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
+                     TTComm, &reqGather);
+      ++gathersPosted;
+  }
+  assert(gathersPosted == globalCounter);
+
 }
 
 void signals_init() {
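The finalization above relies on a property of the gather loop: a new MPI_Iallgather is posted only once the previous one has completed on every rank, so no rank can be more than one call behind the global maximum, and a single catch-up call restores balance (MPI requires collectives to be called equally often on all ranks of a communicator). Condensed into an illustrative helper that is not part of the patch:

    #include <cassert>
    #include <cstdint>
    #include <mpi.h>

    // Bring this rank's count of posted gathers up to the global maximum,
    // so MPI_Iallgather has been called equally often on every rank.
    void finalize_gathers(uint64_t& posted, void* recvBuff, int perRankBytes,
                          MPI_Comm ttComm, MPI_Comm syncComm, MPI_Request* req)
    {
        uint64_t globalMax;
        MPI_Allreduce(&posted, &globalMax, 1, MPI_UINT64_T, MPI_MAX, syncComm);
        if (posted < globalMax) // at most one behind, see above
        {
            MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
                           recvBuff, perRankBytes, MPI_BYTE, ttComm, req);
            ++posted;
        }
        assert(posted == globalMax);
    }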
@@ -221,59 +243,64 @@ void save(Thread* thread, TTEntry* tte,
  {
      // Try to add to thread's send buffer
      {
-         std::lock_guard<Mutex> lk(thread->ttBuffer.mutex);
-         thread->ttBuffer.buffer.replace(KeyedTTEntry(k,*tte));
-         ++thread->ttBuffer.counter;
+         std::lock_guard<Mutex> lk(thread->ttCache.mutex);
+         thread->ttCache.buffer.replace(KeyedTTEntry(k,*tte));
+         ++TTCacheCounter;
      }
 
-     // Communicate on main search thread
-     if (thread == Threads.main() && thread->ttBuffer.counter * Threads.size() > TTSendBufferSize)
-     {
-         static MPI_Request req = MPI_REQUEST_NULL;
-         static TTSendBuffer<TTSendBufferSize> send_buff = {};
-         int flag;
+     size_t recvBuffPerRankSize = Threads.size() * TTCacheSize;
 
+     // Communicate on main search thread
+     if (thread == Threads.main() && TTCacheCounter > size() * recvBuffPerRankSize)
+     {
          // Test communication status
-         MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
+         int flag;
+         MPI_Test(&reqGather, &flag, MPI_STATUS_IGNORE);
 
          // Current communication is complete
          if (flag)
          {
-             // Save all received entries (except ours)
+             // Save all received entries to TT, and store our TTCaches, ready for the next round of communication
              for (size_t irank = 0; irank < size_t(size()) ; ++irank)
              {
                  if (irank == size_t(rank()))
-                     continue;
-
-                 for (size_t i = irank * TTSendBufferSize ; i < (irank + 1) * TTSendBufferSize; ++i)
                  {
-                     auto&& e = TTBuff[i];
-                     bool found;
-                     TTEntry* replace_tte;
-                     replace_tte = TT.probe(e.first, found);
-                     replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
-                                       e.second.move(), e.second.eval());
+                     // Copy from the thread caches to the right spot in the buffer
+                     size_t i = irank * recvBuffPerRankSize;
+                     for (auto&& th : Threads)
+                     {
+                         std::lock_guard<Mutex> lk(th->ttCache.mutex);
+                         for (auto&& e : th->ttCache.buffer)
+                             TTRecvBuff[i++] = e;
+
+                         // Reset thread's send buffer
+                         th->ttCache.buffer = {};
+                     }
+
+                     TTCacheCounter = 0;
                  }
-             }
-             // Reset send buffer
-             send_buff = {};
-             // Build up new send buffer: best 16 found across all threads
-             for (auto&& th : Threads)
-             {
-                 std::lock_guard<Mutex> lk(th->ttBuffer.mutex);
-                 for (auto&& e : th->ttBuffer.buffer)
-                     send_buff.replace(e);
-                 // Reset thread's send buffer
-                 th->ttBuffer.buffer = {};
-                 th->ttBuffer.counter = 0;
+                 else
+                     for (size_t i = irank * recvBuffPerRankSize; i < (irank + 1) * recvBuffPerRankSize; ++i)
+                     {
+                         auto&& e = TTRecvBuff[i];
+                         bool found;
+                         TTEntry* replace_tte;
+                         replace_tte = TT.probe(e.first, found);
+                         replace_tte->save(e.first, e.second.value(), e.second.bound(), e.second.depth(),
+                                           e.second.move(), e.second.eval());
+                     }
              }
 
              // Start next communication
-             MPI_Iallgather(send_buff.data(), send_buff.size() * sizeof(KeyedTTEntry), MPI_BYTE,
-                            TTBuff.data(), TTSendBufferSize * sizeof(KeyedTTEntry), MPI_BYTE,
-                            TTComm, &req);
+             MPI_Iallgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL,
                            TTRecvBuff.data(), recvBuffPerRankSize * sizeof(KeyedTTEntry), MPI_BYTE,
+                            TTComm, &reqGather);
+             ++gathersPosted;
+
+             // Force check of time on the next occasion.
+             static_cast<MainThread*>(thread)->callsCnt = 0;
 
          }
      }
  }
 
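The buffer layout that save() assumes: TTRecvBuff consists of world-size contiguous slots of recvBuffPerRankSize = Threads.size() * TTCacheSize entries each, so rank r's data occupies indices [r * recvBuffPerRankSize, (r + 1) * recvBuffPerRankSize). A worked check, with sizes assumed for illustration to match the 2-rank, 36-thread test above:

    #include <cstddef>

    // Sizes are assumptions for illustration, not taken from the patch.
    constexpr std::size_t worldSize   = 2;   // ranks
    constexpr std::size_t nThreads    = 36;  // threads per rank
    constexpr std::size_t ttCacheSize = 32;  // TTCacheSize

    constexpr std::size_t perRank = nThreads * ttCacheSize; // 1152 entries
    constexpr std::size_t total   = worldSize * perRank;    // 2304 entries
    static_assert(perRank == 1152 && total == 2304, "slot arithmetic");
    // Rank 1's slot spans indices [1152, 2304) of TTRecvBuff.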
src/cluster.h

@@ -42,8 +42,8 @@ struct MoveInfo {
 #ifdef USE_MPI
 using KeyedTTEntry = std::pair<Key, TTEntry>;
 
-constexpr std::size_t TTSendBufferSize = 32;
-template <std::size_t N> class TTSendBuffer : public std::array<KeyedTTEntry, N> {
+constexpr std::size_t TTCacheSize = 32;
+template <std::size_t N> class TTCache : public std::array<KeyedTTEntry, N> {
 
   struct Compare {
       inline bool operator()(const KeyedTTEntry& lhs, const KeyedTTEntry& rhs) {
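TTCache, renamed from TTSendBuffer, is a fixed-size array holding the most valuable entries a thread has produced; replace() is expected to evict the weakest stored entry when a better candidate arrives, ordered by the Compare functor whose signature appears above. A simplified stand-in for that pattern, not the class's actual implementation:

    #include <algorithm>
    #include <array>
    #include <cstddef>

    // Simplified keep-the-best-N buffer; the real TTCache derives from
    // std::array<KeyedTTEntry, N> and uses its own Compare over TT entries.
    template <typename Entry, std::size_t N, typename Compare>
    struct BestN : public std::array<Entry, N> {
        Compare comp;

        // Overwrite the weakest stored entry if the candidate beats it.
        bool replace(const Entry& value) {
            auto worst = std::min_element(this->begin(), this->end(), comp);
            if (comp(*worst, value)) {
                *worst = value;
                return true;
            }
            return false;
        }
    };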
@@ -74,6 +74,7 @@ int rank();
 inline bool is_root() { return rank() == 0; }
 void save(Thread* thread, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev);
 void pick_moves(MoveInfo& mi);
+void ttRecvBuff_resize(size_t nThreads);
 uint64_t nodes_searched();
 uint64_t tb_hits();
 void signals_init();
@@ -90,6 +91,7 @@ constexpr int rank() { return 0; }
 constexpr bool is_root() { return true; }
 inline void save(Thread*, TTEntry* tte, Key k, Value v, Bound b, Depth d, Move m, Value ev) { tte->save(k, v, b, d, m, ev); }
 inline void pick_moves(MoveInfo&) { }
+inline void ttRecvBuff_resize(size_t) { }
 uint64_t nodes_searched();
 uint64_t tb_hits();
 inline void signals_init() { }
src/thread.cpp

@@ -139,6 +139,9 @@ void ThreadPool::set(size_t requested) {
 
       // Reallocate the hash with the new threadpool size
       TT.resize(Options["Hash"]);
+
+      // Adjust cluster buffers
+      Cluster::ttRecvBuff_resize(requested);
   }
 }
 
src/thread.h

@@ -78,9 +78,8 @@ public:
 #ifdef USE_MPI
   struct {
     Mutex mutex;
-    Cluster::TTSendBuffer<Cluster::TTSendBufferSize> buffer = {};
-    size_t counter = 0;
-  } ttBuffer;
+    Cluster::TTCache<Cluster::TTCacheSize> buffer = {};
+  } ttCache;
 #endif
 };
 