diff --git a/src/evaluate.cpp b/src/evaluate.cpp index 5d42d554..f562339f 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -20,18 +20,44 @@ #include #include +#include #include // For std::memset #include #include +#include #include "bitboard.h" #include "evaluate.h" #include "material.h" #include "pawns.h" #include "thread.h" +#include "uci.h" namespace Eval { - bool useNNUE; + + bool useNNUE; + std::string eval_file_loaded="None"; + + void init_NNUE() { + + useNNUE = Options["Use NNUE"]; + std::string eval_file = std::string(Options["EvalFile"]); + if (useNNUE && eval_file_loaded != eval_file) + if (Eval::NNUE::load_eval_file(eval_file)) + eval_file_loaded = eval_file; + } + + void verify_NNUE() { + + std::string eval_file = std::string(Options["EvalFile"]); + if (useNNUE && eval_file_loaded != eval_file) + { + std::cerr << "Use of NNUE evaluation, but the file " << eval_file << " was not loaded successfully. " + << "These network evaluation parameters must be available, compatible with this version of the code. " + << "The UCI option EvalFile might need to specify the full path, including the directory/folder name, to the file." << std::endl; + std::exit(EXIT_FAILURE); + } + } } namespace Trace { diff --git a/src/evaluate.h b/src/evaluate.h index 293b4ce4..115bc49a 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -29,18 +29,22 @@ class Position; namespace Eval { -std::string trace(const Position& pos); -Value evaluate(const Position& pos); -extern bool useNNUE; + std::string trace(const Position& pos); + Value evaluate(const Position& pos); -namespace NNUE { + extern bool useNNUE; + extern std::string eval_file_loaded; + void init_NNUE(); + void verify_NNUE(); -Value evaluate(const Position& pos); -Value compute_eval(const Position& pos); -void update_eval(const Position& pos); -void load_eval(const std::string& evalFile); + namespace NNUE { -} // namespace NNUE + Value evaluate(const Position& pos); + Value compute_eval(const Position& pos); + void update_eval(const Position& pos); + bool load_eval_file(const std::string& evalFile); + + } // namespace NNUE } // namespace Eval diff --git a/src/main.cpp b/src/main.cpp index fafefee2..024552e8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -46,6 +46,7 @@ int main(int argc, char* argv[]) { Endgames::init(); Threads.set(size_t(Options["Threads"])); Search::clear(); // After threads are up + Eval::init_NNUE(); UCI::loop(argc, argv); diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index 9775149a..96730e67 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -127,18 +127,16 @@ namespace Eval::NNUE { } // Load the evaluation function file - void load_eval(const std::string& evalFile) { + bool load_eval_file(const std::string& evalFile) { Initialize(); fileName = evalFile; std::ifstream stream(evalFile, std::ios::binary); + const bool result = ReadParameters(stream); - if (!result) - std::cout << "Error! " << fileName << " not found or wrong format" << std::endl; - else - std::cout << "info string NNUE " << fileName << " found & loaded" << std::endl; + return result; } // Evaluation function. Perform differential calculation. diff --git a/src/search.cpp b/src/search.cpp index 17ccab92..f4562846 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -227,6 +227,8 @@ void MainThread::search() { Time.init(Limits, us, rootPos.game_ply()); TT.new_search(); + Eval::verify_NNUE(); + if (rootMoves.empty()) { rootMoves.emplace_back(MOVE_NONE); diff --git a/src/uci.cpp b/src/uci.cpp index 5a8e0cb7..31fe3e6f 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -86,6 +86,9 @@ namespace { StateListPtr states(new std::deque(1)); Position p; p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); + + Eval::verify_NNUE(); + sync_cout << "\n" << Eval::trace(p) << sync_endl; } @@ -181,12 +184,7 @@ namespace { } else if (token == "setoption") setoption(is); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") - { - init_nnue(Options["EvalFile"]); - Search::clear(); - elapsed = now(); // initialization may take some time - } + else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while } elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' @@ -224,17 +222,6 @@ namespace { } // namespace -void UCI::init_nnue(const std::string& evalFile) -{ - if (Options["Use NNUE"] && !UCI::load_eval_finished) - { - // Load evaluation function from a file - Eval::NNUE::load_eval(evalFile); - UCI::load_eval_finished = true; - } -} - - /// UCI::loop() waits for a command from stdin, parses it and calls the appropriate /// function. Also intercepts EOF from stdin to ensure gracefully exiting if the /// GUI dies unexpectedly. When called with some command line arguments, e.g. to @@ -249,9 +236,6 @@ void UCI::loop(int argc, char* argv[]) { pos.set(StartFEN, false, &states->back(), Threads.main()); - if (argc > 1) - init_nnue(Options["EvalFile"]); - for (int i = 1; i < argc; ++i) cmd += std::string(argv[i]) + " "; @@ -283,16 +267,8 @@ void UCI::loop(int argc, char* argv[]) { else if (token == "setoption") setoption(is); else if (token == "go") go(pos, is, states); else if (token == "position") position(pos, is, states); - else if (token == "ucinewgame") - { - init_nnue(Options["EvalFile"]); - Search::clear(); - } - else if (token == "isready") - { - init_nnue(Options["EvalFile"]); - sync_cout << "readyok" << sync_endl; - } + else if (token == "ucinewgame") Search::clear(); + else if (token == "isready") sync_cout << "readyok" << sync_endl; // Additional custom non-UCI commands, mainly for debugging. // Do not use these commands during a search! diff --git a/src/uci.h b/src/uci.h index e5ebe144..ad954d9f 100644 --- a/src/uci.h +++ b/src/uci.h @@ -76,10 +76,6 @@ std::string pv(const Position& pos, Depth depth, Value alpha, Value beta); std::string wdl(Value v, int ply); Move to_move(const Position& pos, std::string& str); -void init_nnue(const std::string& evalFile); - -extern bool load_eval_finished; - } // namespace UCI extern UCI::OptionsMap Options; diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 9029a3a1..b4b76d7b 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -42,23 +42,8 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); } void on_logger(const Option& o) { start_logger(o); } void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } - -void on_use_nnue(const Option& o) { - - if (o) - std::cout << "info string NNUE eval used" << std::endl; - else - std::cout << "info string classic eval used" << std::endl; - - Eval::useNNUE = o; - init_nnue(Options["EvalFile"]); -} - -void on_eval_file(const Option& o) { - - load_eval_finished = false; - init_nnue(o); -} +void on_use_NNUE(const Option& ) { Eval::init_NNUE(); } +void on_eval_file(const Option& ) { Eval::init_NNUE(); } /// Our case insensitive less() function as required by UCI protocol bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { @@ -95,7 +80,7 @@ void init(OptionsMap& o) { o["SyzygyProbeDepth"] << Option(1, 1, 100); o["Syzygy50MoveRule"] << Option(true); o["SyzygyProbeLimit"] << Option(7, 0, 7); - o["Use NNUE"] << Option(false, on_use_nnue); + o["Use NNUE"] << Option(false, on_use_NNUE); o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file); } @@ -205,5 +190,4 @@ Option& Option::operator=(const string& v) { return *this; } -bool load_eval_finished = false; } // namespace UCI