mirror of
https://github.com/sockspls/badfish
synced 2025-04-29 16:23:09 +00:00
Replicate network weights only to used NUMA nodes
On a system with multiple NUMA nodes, this patch avoids unneeded replicated (e.g. 8x for a single threaded run), reducting memory use in that case. Lazy initialization forced before search. Passed STC: https://tests.stockfishchess.org/tests/view/66a28c524ff211be9d4ecdd4 LLR: 2.96 (-2.94,2.94) <-1.75,0.25> Total: 691776 W: 179429 L: 179927 D: 332420 Ptnml(0-2): 2573, 79370, 182547, 78778, 2620 closes https://github.com/official-stockfish/Stockfish/pull/5515 No functional change
This commit is contained in:
parent
2343f71f3f
commit
8e560c4fd3
7 changed files with 152 additions and 16 deletions
|
@ -204,6 +204,7 @@ void Engine::set_numa_config_from_option(const std::string& o) {
|
||||||
|
|
||||||
// Force reallocation of threads in case affinities need to change.
|
// Force reallocation of threads in case affinities need to change.
|
||||||
resize_threads();
|
resize_threads();
|
||||||
|
threads.ensure_network_replicated();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Engine::resize_threads() {
|
void Engine::resize_threads() {
|
||||||
|
@ -212,6 +213,7 @@ void Engine::resize_threads() {
|
||||||
|
|
||||||
// Reallocate the hash with the new threadpool size
|
// Reallocate the hash with the new threadpool size
|
||||||
set_tt_size(options["Hash"]);
|
set_tt_size(options["Hash"]);
|
||||||
|
threads.ensure_network_replicated();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Engine::set_tt_size(size_t mb) {
|
void Engine::set_tt_size(size_t mb) {
|
||||||
|
@ -234,18 +236,21 @@ void Engine::load_networks() {
|
||||||
networks_.small.load(binaryDirectory, options["EvalFileSmall"]);
|
networks_.small.load(binaryDirectory, options["EvalFileSmall"]);
|
||||||
});
|
});
|
||||||
threads.clear();
|
threads.clear();
|
||||||
|
threads.ensure_network_replicated();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Engine::load_big_network(const std::string& file) {
|
void Engine::load_big_network(const std::string& file) {
|
||||||
networks.modify_and_replicate(
|
networks.modify_and_replicate(
|
||||||
[this, &file](NN::Networks& networks_) { networks_.big.load(binaryDirectory, file); });
|
[this, &file](NN::Networks& networks_) { networks_.big.load(binaryDirectory, file); });
|
||||||
threads.clear();
|
threads.clear();
|
||||||
|
threads.ensure_network_replicated();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Engine::load_small_network(const std::string& file) {
|
void Engine::load_small_network(const std::string& file) {
|
||||||
networks.modify_and_replicate(
|
networks.modify_and_replicate(
|
||||||
[this, &file](NN::Networks& networks_) { networks_.small.load(binaryDirectory, file); });
|
[this, &file](NN::Networks& networks_) { networks_.small.load(binaryDirectory, file); });
|
||||||
threads.clear();
|
threads.clear();
|
||||||
|
threads.ensure_network_replicated();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Engine::save_network(const std::pair<std::optional<std::string>, std::string> files[2]) {
|
void Engine::save_network(const std::pair<std::optional<std::string>, std::string> files[2]) {
|
||||||
|
|
|
@ -114,10 +114,10 @@ class Engine {
|
||||||
StateListPtr states;
|
StateListPtr states;
|
||||||
Square capSq;
|
Square capSq;
|
||||||
|
|
||||||
OptionsMap options;
|
OptionsMap options;
|
||||||
ThreadPool threads;
|
ThreadPool threads;
|
||||||
TranspositionTable tt;
|
TranspositionTable tt;
|
||||||
NumaReplicated<Eval::NNUE::Networks> networks;
|
LazyNumaReplicated<Eval::NNUE::Networks> networks;
|
||||||
|
|
||||||
Search::SearchManager::UpdateContext updateContext;
|
Search::SearchManager::UpdateContext updateContext;
|
||||||
};
|
};
|
||||||
|
|
112
src/numa.h
112
src/numa.h
|
@ -27,6 +27,7 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -1136,6 +1137,117 @@ class NumaReplicated: public NumaReplicatedBase {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// We force boxing with a unique_ptr. If this becomes an issue due to added
|
||||||
|
// indirection we may need to add an option for a custom boxing type.
|
||||||
|
template<typename T>
|
||||||
|
class LazyNumaReplicated: public NumaReplicatedBase {
|
||||||
|
public:
|
||||||
|
using ReplicatorFuncType = std::function<T(const T&)>;
|
||||||
|
|
||||||
|
LazyNumaReplicated(NumaReplicationContext& ctx) :
|
||||||
|
NumaReplicatedBase(ctx) {
|
||||||
|
prepare_replicate_from(T{});
|
||||||
|
}
|
||||||
|
|
||||||
|
LazyNumaReplicated(NumaReplicationContext& ctx, T&& source) :
|
||||||
|
NumaReplicatedBase(ctx) {
|
||||||
|
prepare_replicate_from(std::move(source));
|
||||||
|
}
|
||||||
|
|
||||||
|
LazyNumaReplicated(const LazyNumaReplicated&) = delete;
|
||||||
|
LazyNumaReplicated(LazyNumaReplicated&& other) noexcept :
|
||||||
|
NumaReplicatedBase(std::move(other)),
|
||||||
|
instances(std::exchange(other.instances, {})) {}
|
||||||
|
|
||||||
|
LazyNumaReplicated& operator=(const LazyNumaReplicated&) = delete;
|
||||||
|
LazyNumaReplicated& operator=(LazyNumaReplicated&& other) noexcept {
|
||||||
|
NumaReplicatedBase::operator=(*this, std::move(other));
|
||||||
|
instances = std::exchange(other.instances, {});
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
LazyNumaReplicated& operator=(T&& source) {
|
||||||
|
prepare_replicate_from(std::move(source));
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~LazyNumaReplicated() override = default;
|
||||||
|
|
||||||
|
const T& operator[](NumaReplicatedAccessToken token) const {
|
||||||
|
assert(token.get_numa_index() < instances.size());
|
||||||
|
ensure_present(token.get_numa_index());
|
||||||
|
return *(instances[token.get_numa_index()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const T& operator*() const { return *(instances[0]); }
|
||||||
|
|
||||||
|
const T* operator->() const { return instances[0].get(); }
|
||||||
|
|
||||||
|
template<typename FuncT>
|
||||||
|
void modify_and_replicate(FuncT&& f) {
|
||||||
|
auto source = std::move(instances[0]);
|
||||||
|
std::forward<FuncT>(f)(*source);
|
||||||
|
prepare_replicate_from(std::move(*source));
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_numa_config_changed() override {
|
||||||
|
// Use the first one as the source. It doesn't matter which one we use,
|
||||||
|
// because they all must be identical, but the first one is guaranteed to exist.
|
||||||
|
auto source = std::move(instances[0]);
|
||||||
|
prepare_replicate_from(std::move(*source));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable std::vector<std::unique_ptr<T>> instances;
|
||||||
|
mutable std::mutex mutex;
|
||||||
|
|
||||||
|
void ensure_present(NumaIndex idx) const {
|
||||||
|
assert(idx < instances.size());
|
||||||
|
|
||||||
|
if (instances[idx] != nullptr)
|
||||||
|
return;
|
||||||
|
|
||||||
|
assert(idx != 0);
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
|
// Check again for races.
|
||||||
|
if (instances[idx] != nullptr)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const NumaConfig& cfg = get_numa_config();
|
||||||
|
cfg.execute_on_numa_node(
|
||||||
|
idx, [this, idx]() { instances[idx] = std::make_unique<T>(*instances[0]); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void prepare_replicate_from(T&& source) {
|
||||||
|
instances.clear();
|
||||||
|
|
||||||
|
const NumaConfig& cfg = get_numa_config();
|
||||||
|
if (cfg.requires_memory_replication())
|
||||||
|
{
|
||||||
|
assert(cfg.num_numa_nodes() > 0);
|
||||||
|
|
||||||
|
// We just need to make sure the first instance is there.
|
||||||
|
// Note that we cannot move here as we need to reallocate the data
|
||||||
|
// on the correct NUMA node.
|
||||||
|
cfg.execute_on_numa_node(
|
||||||
|
0, [this, &source]() { instances.emplace_back(std::make_unique<T>(source)); });
|
||||||
|
|
||||||
|
// Prepare others for lazy init.
|
||||||
|
instances.resize(cfg.num_numa_nodes());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
assert(cfg.num_numa_nodes() == 1);
|
||||||
|
// We take advantage of the fact that replication is not required
|
||||||
|
// and reuse the source value, avoiding one copy operation.
|
||||||
|
instances.emplace_back(std::make_unique<T>(std::move(source)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class NumaReplicationContext {
|
class NumaReplicationContext {
|
||||||
public:
|
public:
|
||||||
NumaReplicationContext(NumaConfig&& cfg) :
|
NumaReplicationContext(NumaConfig&& cfg) :
|
||||||
|
|
|
@ -127,6 +127,12 @@ Search::Worker::Worker(SharedState& sharedState,
|
||||||
clear();
|
clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Search::Worker::ensure_network_replicated() {
|
||||||
|
// Access once to force lazy initialization.
|
||||||
|
// We do this because we want to avoid initialization during search.
|
||||||
|
(void) (networks[numaAccessToken]);
|
||||||
|
}
|
||||||
|
|
||||||
void Search::Worker::start_searching() {
|
void Search::Worker::start_searching() {
|
||||||
|
|
||||||
// Non-main threads go directly to iterative_deepening()
|
// Non-main threads go directly to iterative_deepening()
|
||||||
|
|
26
src/search.h
26
src/search.h
|
@ -131,19 +131,19 @@ struct LimitsType {
|
||||||
// The UCI stores the uci options, thread pool, and transposition table.
|
// The UCI stores the uci options, thread pool, and transposition table.
|
||||||
// This struct is used to easily forward data to the Search::Worker class.
|
// This struct is used to easily forward data to the Search::Worker class.
|
||||||
struct SharedState {
|
struct SharedState {
|
||||||
SharedState(const OptionsMap& optionsMap,
|
SharedState(const OptionsMap& optionsMap,
|
||||||
ThreadPool& threadPool,
|
ThreadPool& threadPool,
|
||||||
TranspositionTable& transpositionTable,
|
TranspositionTable& transpositionTable,
|
||||||
const NumaReplicated<Eval::NNUE::Networks>& nets) :
|
const LazyNumaReplicated<Eval::NNUE::Networks>& nets) :
|
||||||
options(optionsMap),
|
options(optionsMap),
|
||||||
threads(threadPool),
|
threads(threadPool),
|
||||||
tt(transpositionTable),
|
tt(transpositionTable),
|
||||||
networks(nets) {}
|
networks(nets) {}
|
||||||
|
|
||||||
const OptionsMap& options;
|
const OptionsMap& options;
|
||||||
ThreadPool& threads;
|
ThreadPool& threads;
|
||||||
TranspositionTable& tt;
|
TranspositionTable& tt;
|
||||||
const NumaReplicated<Eval::NNUE::Networks>& networks;
|
const LazyNumaReplicated<Eval::NNUE::Networks>& networks;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Worker;
|
class Worker;
|
||||||
|
@ -274,6 +274,8 @@ class Worker {
|
||||||
|
|
||||||
bool is_mainthread() const { return threadIdx == 0; }
|
bool is_mainthread() const { return threadIdx == 0; }
|
||||||
|
|
||||||
|
void ensure_network_replicated();
|
||||||
|
|
||||||
// Public because they need to be updatable by the stats
|
// Public because they need to be updatable by the stats
|
||||||
ButterflyHistory mainHistory;
|
ButterflyHistory mainHistory;
|
||||||
CapturePieceToHistory captureHistory;
|
CapturePieceToHistory captureHistory;
|
||||||
|
@ -328,10 +330,10 @@ class Worker {
|
||||||
|
|
||||||
Tablebases::Config tbConfig;
|
Tablebases::Config tbConfig;
|
||||||
|
|
||||||
const OptionsMap& options;
|
const OptionsMap& options;
|
||||||
ThreadPool& threads;
|
ThreadPool& threads;
|
||||||
TranspositionTable& tt;
|
TranspositionTable& tt;
|
||||||
const NumaReplicated<Eval::NNUE::Networks>& networks;
|
const LazyNumaReplicated<Eval::NNUE::Networks>& networks;
|
||||||
|
|
||||||
// Used by NNUE
|
// Used by NNUE
|
||||||
Eval::NNUE::AccumulatorCaches refreshTable;
|
Eval::NNUE::AccumulatorCaches refreshTable;
|
||||||
|
|
|
@ -102,6 +102,8 @@ void Thread::run_custom_job(std::function<void()> f) {
|
||||||
cv.notify_one();
|
cv.notify_one();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Thread::ensure_network_replicated() { worker->ensure_network_replicated(); }
|
||||||
|
|
||||||
// Thread gets parked here, blocked on the condition variable
|
// Thread gets parked here, blocked on the condition variable
|
||||||
// when the thread has no work to do.
|
// when the thread has no work to do.
|
||||||
|
|
||||||
|
@ -400,4 +402,9 @@ std::vector<size_t> ThreadPool::get_bound_thread_count_by_numa_node() const {
|
||||||
return counts;
|
return counts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ThreadPool::ensure_network_replicated() {
|
||||||
|
for (auto&& th : threads)
|
||||||
|
th->ensure_network_replicated();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Stockfish
|
} // namespace Stockfish
|
||||||
|
|
|
@ -83,6 +83,8 @@ class Thread {
|
||||||
void clear_worker();
|
void clear_worker();
|
||||||
void run_custom_job(std::function<void()> f);
|
void run_custom_job(std::function<void()> f);
|
||||||
|
|
||||||
|
void ensure_network_replicated();
|
||||||
|
|
||||||
// Thread has been slightly altered to allow running custom jobs, so
|
// Thread has been slightly altered to allow running custom jobs, so
|
||||||
// this name is no longer correct. However, this class (and ThreadPool)
|
// this name is no longer correct. However, this class (and ThreadPool)
|
||||||
// require further work to make them properly generic while maintaining
|
// require further work to make them properly generic while maintaining
|
||||||
|
@ -146,6 +148,8 @@ class ThreadPool {
|
||||||
|
|
||||||
std::vector<size_t> get_bound_thread_count_by_numa_node() const;
|
std::vector<size_t> get_bound_thread_count_by_numa_node() const;
|
||||||
|
|
||||||
|
void ensure_network_replicated();
|
||||||
|
|
||||||
std::atomic_bool stop, abortedSearch, increaseDepth;
|
std::atomic_bool stop, abortedSearch, increaseDepth;
|
||||||
|
|
||||||
auto cbegin() const noexcept { return threads.cbegin(); }
|
auto cbegin() const noexcept { return threads.cbegin(); }
|
||||||
|
|
Loading…
Add table
Reference in a new issue