
Replicate network weights only to used NUMA nodes

On a system with multiple NUMA nodes, this patch avoids unneeded replication of
the network weights (e.g. 8x for a single-threaded run), reducing memory use in that case.

Lazy initialization is forced before search, so no replication work happens while searching.

Passed STC:
https://tests.stockfishchess.org/tests/view/66a28c524ff211be9d4ecdd4
LLR: 2.96 (-2.94,2.94) <-1.75,0.25>
Total: 691776 W: 179429 L: 179927 D: 332420
Ptnml(0-2): 2573, 79370, 182547, 78778, 2620

closes https://github.com/official-stockfish/Stockfish/pull/5515

No functional change
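
The core idea is small enough to show in isolation. Below is a simplified, standalone sketch of the pattern (not the patch itself, and it deliberately omits pinning the copy to the target node's memory, which the real code does via NumaConfig::execute_on_numa_node): replicas other than node 0 start out empty and are only created on first access, behind a double-checked mutex.

    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <mutex>
    #include <vector>

    template<typename T>
    class LazyReplicas {
       public:
        LazyReplicas(T source, std::size_t numNodes) :
            instances(numNodes) {
            assert(numNodes > 0);
            // Node 0 gets an eager copy; the remaining slots stay empty until used.
            instances[0] = std::make_unique<T>(std::move(source));
        }

        // First access from a node copies the node-0 instance; later accesses
        // are plain reads of an already-filled slot.
        const T& get(std::size_t node) const {
            assert(node < instances.size());
            if (instances[node] != nullptr)  // fast path: already replicated
                return *instances[node];

            std::lock_guard<std::mutex> lock(mutex);
            if (instances[node] == nullptr)  // re-check: another thread may have raced us
                instances[node] = std::make_unique<T>(*instances[0]);
            return *instances[node];
        }

       private:
        mutable std::vector<std::unique_ptr<T>> instances;
        mutable std::mutex                      mutex;
    };

    int main() {
        // Pretend the vector is a network; with 4 "nodes", only node 0 holds a
        // copy until some thread actually asks for another node's replica.
        LazyReplicas<std::vector<float>> nets(std::vector<float>(1024, 0.0f), 4);
        const auto& onNode2 = nets.get(2);  // triggers the copy for node 2
        return onNode2.size() == 1024 ? 0 : 1;
    }
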
Tomasz Sobczyk 2024-07-25 14:37:08 +02:00 committed by Joost VandeVondele
parent 2343f71f3f
commit 8e560c4fd3
7 changed files with 152 additions and 16 deletions

View file

@ -204,6 +204,7 @@ void Engine::set_numa_config_from_option(const std::string& o) {
// Force reallocation of threads in case affinities need to change.
resize_threads();
threads.ensure_network_replicated();
}
void Engine::resize_threads() {
@ -212,6 +213,7 @@ void Engine::resize_threads() {
// Reallocate the hash with the new threadpool size
set_tt_size(options["Hash"]);
threads.ensure_network_replicated();
}
void Engine::set_tt_size(size_t mb) {
@ -234,18 +236,21 @@ void Engine::load_networks() {
networks_.small.load(binaryDirectory, options["EvalFileSmall"]);
});
threads.clear();
threads.ensure_network_replicated();
}
void Engine::load_big_network(const std::string& file) {
networks.modify_and_replicate(
[this, &file](NN::Networks& networks_) { networks_.big.load(binaryDirectory, file); });
threads.clear();
threads.ensure_network_replicated();
}
void Engine::load_small_network(const std::string& file) {
networks.modify_and_replicate(
[this, &file](NN::Networks& networks_) { networks_.small.load(binaryDirectory, file); });
threads.clear();
threads.ensure_network_replicated();
}
void Engine::save_network(const std::pair<std::optional<std::string>, std::string> files[2]) {

View file

@ -117,7 +117,7 @@ class Engine {
OptionsMap options;
ThreadPool threads;
TranspositionTable tt;
NumaReplicated<Eval::NNUE::Networks> networks;
LazyNumaReplicated<Eval::NNUE::Networks> networks;
Search::SearchManager::UpdateContext updateContext;
};

View file

@ -27,6 +27,7 @@
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
@ -1136,6 +1137,117 @@ class NumaReplicated: public NumaReplicatedBase {
}
};
// We force boxing with a unique_ptr. If this becomes an issue due to added
// indirection we may need to add an option for a custom boxing type.
template<typename T>
class LazyNumaReplicated: public NumaReplicatedBase {
public:
using ReplicatorFuncType = std::function<T(const T&)>;
LazyNumaReplicated(NumaReplicationContext& ctx) :
NumaReplicatedBase(ctx) {
prepare_replicate_from(T{});
}
LazyNumaReplicated(NumaReplicationContext& ctx, T&& source) :
NumaReplicatedBase(ctx) {
prepare_replicate_from(std::move(source));
}
LazyNumaReplicated(const LazyNumaReplicated&) = delete;
LazyNumaReplicated(LazyNumaReplicated&& other) noexcept :
NumaReplicatedBase(std::move(other)),
instances(std::exchange(other.instances, {})) {}
LazyNumaReplicated& operator=(const LazyNumaReplicated&) = delete;
LazyNumaReplicated& operator=(LazyNumaReplicated&& other) noexcept {
NumaReplicatedBase::operator=(std::move(other));
instances = std::exchange(other.instances, {});
return *this;
}
LazyNumaReplicated& operator=(T&& source) {
prepare_replicate_from(std::move(source));
return *this;
}
~LazyNumaReplicated() override = default;
const T& operator[](NumaReplicatedAccessToken token) const {
assert(token.get_numa_index() < instances.size());
ensure_present(token.get_numa_index());
return *(instances[token.get_numa_index()]);
}
const T& operator*() const { return *(instances[0]); }
const T* operator->() const { return instances[0].get(); }
template<typename FuncT>
void modify_and_replicate(FuncT&& f) {
auto source = std::move(instances[0]);
std::forward<FuncT>(f)(*source);
prepare_replicate_from(std::move(*source));
}
void on_numa_config_changed() override {
// Use the first one as the source. It doesn't matter which one we use,
// because they all must be identical, but the first one is guaranteed to exist.
auto source = std::move(instances[0]);
prepare_replicate_from(std::move(*source));
}
private:
mutable std::vector<std::unique_ptr<T>> instances;
mutable std::mutex mutex;
void ensure_present(NumaIndex idx) const {
assert(idx < instances.size());
if (instances[idx] != nullptr)
return;
assert(idx != 0);
std::unique_lock<std::mutex> lock(mutex);
// Check again for races.
if (instances[idx] != nullptr)
return;
const NumaConfig& cfg = get_numa_config();
cfg.execute_on_numa_node(
idx, [this, idx]() { instances[idx] = std::make_unique<T>(*instances[0]); });
}
void prepare_replicate_from(T&& source) {
instances.clear();
const NumaConfig& cfg = get_numa_config();
if (cfg.requires_memory_replication())
{
assert(cfg.num_numa_nodes() > 0);
// We just need to make sure the first instance is there.
// Note that we cannot move here as we need to reallocate the data
// on the correct NUMA node.
cfg.execute_on_numa_node(
0, [this, &source]() { instances.emplace_back(std::make_unique<T>(source)); });
// Prepare others for lazy init.
instances.resize(cfg.num_numa_nodes());
}
else
{
assert(cfg.num_numa_nodes() == 1);
// We take advantage of the fact that replication is not required
// and reuse the source value, avoiding one copy operation.
instances.emplace_back(std::make_unique<T>(std::move(source)));
}
}
};
class NumaReplicationContext {
public:
NumaReplicationContext(NumaConfig&& cfg) :

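As the comment in prepare_replicate_from notes, the node-0 instance is copied rather than moved because the replica's memory has to be allocated while running on the target node, so first-touch placement puts the pages in that node's local memory. A rough standalone sketch of that pattern, where bind_current_thread_to_node() is a hypothetical stand-in for what NumaConfig::execute_on_numa_node() does in the real code:

    #include <cstddef>
    #include <memory>
    #include <thread>

    // Hypothetical stub: in the engine this is handled by
    // NumaConfig::execute_on_numa_node(), which runs the functor on a thread
    // bound to the requested node (e.g. via the OS affinity APIs).
    void bind_current_thread_to_node(std::size_t /*node*/) {}

    template<typename T>
    std::unique_ptr<T> copy_on_node(const T& source, std::size_t node) {
        std::unique_ptr<T> replica;
        std::thread helper([&] {
            bind_current_thread_to_node(node);
            // The copy is constructed (and its pages first touched) on `node`,
            // so the OS allocates the backing memory locally to that node.
            replica = std::make_unique<T>(source);
        });
        helper.join();
        return replica;
    }
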
View file

@ -127,6 +127,12 @@ Search::Worker::Worker(SharedState& sharedState,
clear();
}
void Search::Worker::ensure_network_replicated() {
// Access once to force lazy initialization.
// We do this because we want to avoid initialization during search.
(void) (networks[numaAccessToken]);
}
void Search::Worker::start_searching() {
// Non-main threads go directly to iterative_deepening()

View file

@ -134,7 +134,7 @@ struct SharedState {
SharedState(const OptionsMap& optionsMap,
ThreadPool& threadPool,
TranspositionTable& transpositionTable,
const NumaReplicated<Eval::NNUE::Networks>& nets) :
const LazyNumaReplicated<Eval::NNUE::Networks>& nets) :
options(optionsMap),
threads(threadPool),
tt(transpositionTable),
@ -143,7 +143,7 @@ struct SharedState {
const OptionsMap& options;
ThreadPool& threads;
TranspositionTable& tt;
const NumaReplicated<Eval::NNUE::Networks>& networks;
const LazyNumaReplicated<Eval::NNUE::Networks>& networks;
};
class Worker;
@ -274,6 +274,8 @@ class Worker {
bool is_mainthread() const { return threadIdx == 0; }
void ensure_network_replicated();
// Public because they need to be updatable by the stats
ButterflyHistory mainHistory;
CapturePieceToHistory captureHistory;
@ -331,7 +333,7 @@ class Worker {
const OptionsMap& options;
ThreadPool& threads;
TranspositionTable& tt;
const NumaReplicated<Eval::NNUE::Networks>& networks;
const LazyNumaReplicated<Eval::NNUE::Networks>& networks;
// Used by NNUE
Eval::NNUE::AccumulatorCaches refreshTable;

View file

@ -102,6 +102,8 @@ void Thread::run_custom_job(std::function<void()> f) {
cv.notify_one();
}
void Thread::ensure_network_replicated() { worker->ensure_network_replicated(); }
// Thread gets parked here, blocked on the condition variable
// when the thread has no work to do.
@ -400,4 +402,9 @@ std::vector<size_t> ThreadPool::get_bound_thread_count_by_numa_node() const {
return counts;
}
void ThreadPool::ensure_network_replicated() {
for (auto&& th : threads)
th->ensure_network_replicated();
}
} // namespace Stockfish

View file

@ -83,6 +83,8 @@ class Thread {
void clear_worker();
void run_custom_job(std::function<void()> f);
void ensure_network_replicated();
// Thread has been slightly altered to allow running custom jobs, so
// this name is no longer correct. However, this class (and ThreadPool)
// require further work to make them properly generic while maintaining
@ -146,6 +148,8 @@ class ThreadPool {
std::vector<size_t> get_bound_thread_count_by_numa_node() const;
void ensure_network_replicated();
std::atomic_bool stop, abortedSearch, increaseDepth;
auto cbegin() const noexcept { return threads.cbegin(); }