From 18686e29c78f46eaaa1e66398fbe6575bdf7ed88 Mon Sep 17 00:00:00 2001 From: Joost VandeVondele Date: Sun, 2 Aug 2020 15:31:51 +0200 Subject: [PATCH] Revisit NNUE initialization this revisits the initialization of NNUE, basically only changing the state on the UCI options 'Use NNUE' and 'EvalFile' calling init_NNUE(), which sets the Eval::useNNUE variable, and loads the network if needed (i.e. useNNUE is true and the same network is not yet loaded) init_NNUE is silent (i.e. no info strings), so that it can be called at startup without confusing certain GUIs. An error message on wrong setting when asking for (i.e. the net failed to load), is delayed to the point where everything must be consistent (start of search or eval). The engine will stop if the settings are wrong at that point. Also works if the default value of Use NNUE would become true. --- src/evaluate.cpp | 28 +++++++++++++++++++++++++++- src/evaluate.h | 22 +++++++++++++--------- src/main.cpp | 1 + src/nnue/evaluate_nnue.cpp | 8 +++----- src/search.cpp | 2 ++ src/uci.cpp | 36 ++++++------------------------------ src/uci.h | 4 ---- src/ucioption.cpp | 22 +++------------------- 8 files changed, 55 insertions(+), 68 deletions(-) diff --git a/src/evaluate.cpp b/src/evaluate.cpp index 5d42d554..f562339f 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -20,18 +20,44 @@ #include #include +#include #include // For std::memset #include #include +#include #include "bitboard.h" #include "evaluate.h" #include "material.h" #include "pawns.h" #include "thread.h" +#include "uci.h" namespace Eval { - bool useNNUE; + + bool useNNUE; + std::string eval_file_loaded="None"; + + void init_NNUE() { + + useNNUE = Options["Use NNUE"]; + std::string eval_file = std::string(Options["EvalFile"]); + if (useNNUE && eval_file_loaded != eval_file) + if (Eval::NNUE::load_eval_file(eval_file)) + eval_file_loaded = eval_file; + } + + void verify_NNUE() { + + std::string eval_file = std::string(Options["EvalFile"]); + if (useNNUE && eval_file_loaded != eval_file) + { + std::cerr << "Use of NNUE evaluation, but the file " << eval_file << " was not loaded successfully. " + << "These network evaluation parameters must be available, compatible with this version of the code. " + << "The UCI option EvalFile might need to specify the full path, including the directory/folder name, to the file." << std::endl; + std::exit(EXIT_FAILURE); + } + } } namespace Trace { diff --git a/src/evaluate.h b/src/evaluate.h index 293b4ce4..115bc49a 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -29,18 +29,22 @@ class Position; namespace Eval { -std::string trace(const Position& pos); -Value evaluate(const Position& pos); -extern bool useNNUE; + std::string trace(const Position& pos); + Value evaluate(const Position& pos); -namespace NNUE { + extern bool useNNUE; + extern std::string eval_file_loaded; + void init_NNUE(); + void verify_NNUE(); -Value evaluate(const Position& pos); -Value compute_eval(const Position& pos); -void update_eval(const Position& pos); -void load_eval(const std::string& evalFile); + namespace NNUE { -} // namespace NNUE + Value evaluate(const Position& pos); + Value compute_eval(const Position& pos); + void update_eval(const Position& pos); + bool load_eval_file(const std::string& evalFile); + + } // namespace NNUE } // namespace Eval diff --git a/src/main.cpp b/src/main.cpp index fafefee2..024552e8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -46,6 +46,7 @@ int main(int argc, char* argv[]) { Endgames::init(); Threads.set(size_t(Options["Threads"])); Search::clear(); // After threads are up + Eval::init_NNUE(); UCI::loop(argc, argv); diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index 9775149a..96730e67 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -127,18 +127,16 @@ namespace Eval::NNUE { } // Load the evaluation function file - void load_eval(const std::string& evalFile) { + bool load_eval_file(const std::string& evalFile) { Initialize(); fileName = evalFile; std::ifstream stream(evalFile, std::ios::binary); + const bool result = ReadParameters(stream); - if (!result) - std::cout << "Error! " << fileName << " not found or wrong format" << std::endl; - else - std::cout << "info string NNUE " << fileName << " found & loaded" << std::endl; + return result; } // Evaluation function. Perform differential calculation. diff --git a/src/search.cpp b/src/search.cpp index 17ccab92..f4562846 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -227,6 +227,8 @@ void MainThread::search() { Time.init(Limits, us, rootPos.game_ply()); TT.new_search(); + Eval::verify_NNUE(); + if (rootMoves.empty()) { rootMoves.emplace_back(MOVE_NONE); diff --git a/src/uci.cpp b/src/uci.cpp index 5a8e0cb7..31fe3e6f 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -86,6 +86,9 @@ namespace { StateListPtr states(new std::deque(1)); Position p; p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); + + Eval::verify_NNUE(); + sync_cout << "\n" << Eval::trace(p) << sync_endl; } @@ -181,12 +184,7 @@ namespace { } else if (token == "setoption") setoption(is); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") - { - init_nnue(Options["EvalFile"]); - Search::clear(); - elapsed = now(); // initialization may take some time - } + else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while } elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' @@ -224,17 +222,6 @@ namespace { } // namespace -void UCI::init_nnue(const std::string& evalFile) -{ - if (Options["Use NNUE"] && !UCI::load_eval_finished) - { - // Load evaluation function from a file - Eval::NNUE::load_eval(evalFile); - UCI::load_eval_finished = true; - } -} - - /// UCI::loop() waits for a command from stdin, parses it and calls the appropriate /// function. Also intercepts EOF from stdin to ensure gracefully exiting if the /// GUI dies unexpectedly. When called with some command line arguments, e.g. to @@ -249,9 +236,6 @@ void UCI::loop(int argc, char* argv[]) { pos.set(StartFEN, false, &states->back(), Threads.main()); - if (argc > 1) - init_nnue(Options["EvalFile"]); - for (int i = 1; i < argc; ++i) cmd += std::string(argv[i]) + " "; @@ -283,16 +267,8 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "setoption") setoption(is); else if (token == "go") go(pos, is, states); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") - { - init_nnue(Options["EvalFile"]); - Search::clear(); - } - else if (token == "isready") - { - init_nnue(Options["EvalFile"]); - sync_cout << "readyok" << sync_endl; - } + else if (token == "ucinewgame") Search::clear(); + else if (token == "isready") sync_cout << "readyok" << sync_endl; // Additional custom non-UCI commands, mainly for debugging. // Do not use these commands during a search! diff --git a/src/uci.h b/src/uci.h index e5ebe144..ad954d9f 100644 --- a/src/uci.h +++ b/src/uci.h @@ -76,10 +76,6 @@ std::string pv(const Position& pos, Depth depth, Value alpha, Value beta); std::string wdl(Value v, int ply); Move to_move(const Position& pos, std::string& str); -void init_nnue(const std::string& evalFile); - -extern bool load_eval_finished; - } // namespace UCI extern UCI::OptionsMap Options; diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 9029a3a1..b4b76d7b 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -42,23 +42,8 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); } void on_logger(const Option& o) { start_logger(o); } void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } - -void on_use_nnue(const Option& o) { - - if (o) - std::cout << "info string NNUE eval used" << std::endl; - else - std::cout << "info string classic eval used" << std::endl; - - Eval::useNNUE = o; - init_nnue(Options["EvalFile"]); -} - -void on_eval_file(const Option& o) { - - load_eval_finished = false; - init_nnue(o); -} +void on_use_NNUE(const Option& ) { Eval::init_NNUE(); } +void on_eval_file(const Option& ) { Eval::init_NNUE(); } /// Our case insensitive less() function as required by UCI protocol bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { @@ -95,7 +80,7 @@ void init(OptionsMap& o) { o["SyzygyProbeDepth"] << Option(1, 1, 100); o["Syzygy50MoveRule"] << Option(true); o["SyzygyProbeLimit"] << Option(7, 0, 7); - o["Use NNUE"] << Option(false, on_use_nnue); + o["Use NNUE"] << Option(false, on_use_NNUE); o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file); } @@ -205,5 +190,4 @@ Option& Option::operator=(const string& v) { return *this; } -bool load_eval_finished = false; } // namespace UCI