diff --git a/src/evaluate.cpp b/src/evaluate.cpp index e3f60f9c..e220b92a 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -23,10 +23,10 @@ #include #include #include -#include #include #include #include +#include #include #include "incbin/incbin.h" @@ -62,9 +62,10 @@ namespace Stockfish { namespace Eval { -std::string currentEvalFileName[2] = {"None", "None"}; -const std::string EvFiles[2] = {"EvalFile", "EvalFileSmall"}; -const std::string EvFileNames[2] = {EvalFileDefaultNameBig, EvalFileDefaultNameSmall}; +std::unordered_map EvalFiles = { + {NNUE::Big, {"EvalFile", EvalFileDefaultNameBig, "None"}}, + {NNUE::Small, {"EvalFileSmall", EvalFileDefaultNameSmall, "None"}}}; + // Tries to load a NNUE network at startup time, or when the engine // receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue" @@ -75,13 +76,16 @@ const std::string EvFileNames[2] = {EvalFileDefaultNameBig, EvalFileDefa // variable to have the engine search in a special directory in their distro. void NNUE::init() { - for (NetSize netSize : {Big, Small}) + for (auto& [netSize, evalFile] : EvalFiles) { - // change after fishtest supports EvalFileSmall - std::string eval_file = - std::string(netSize == Small ? EvalFileDefaultNameSmall : Options[EvFiles[netSize]]); - if (eval_file.empty()) - eval_file = EvFileNames[netSize]; + // Replace with + // Options[evalFile.option_name] + // once fishtest supports the uci option EvalFileSmall + std::string user_eval_file = + netSize == Small ? evalFile.default_name : Options[evalFile.option_name]; + + if (user_eval_file.empty()) + user_eval_file = evalFile.default_name; #if defined(DEFAULT_NNUE_DIRECTORY) std::vector dirs = {"", "", CommandLine::binaryDirectory, @@ -92,16 +96,16 @@ void NNUE::init() { for (const std::string& directory : dirs) { - if (currentEvalFileName[netSize] != eval_file) + if (evalFile.selected_name != user_eval_file) { if (directory != "") { - std::ifstream stream(directory + eval_file, std::ios::binary); - if (NNUE::load_eval(eval_file, stream, netSize)) - currentEvalFileName[netSize] = eval_file; + std::ifstream stream(directory + user_eval_file, std::ios::binary); + if (NNUE::load_eval(user_eval_file, stream, netSize)) + evalFile.selected_name = user_eval_file; } - if (directory == "" && eval_file == EvFileNames[netSize]) + if (directory == "" && user_eval_file == evalFile.default_name) { // C++ way to prepare a buffer for a memory stream class MemoryBuffer: public std::basic_streambuf { @@ -120,8 +124,8 @@ void NNUE::init() { (void) gEmbeddedNNUESmallEnd; std::istream stream(&buffer); - if (NNUE::load_eval(eval_file, stream, netSize)) - currentEvalFileName[netSize] = eval_file; + if (NNUE::load_eval(user_eval_file, stream, netSize)) + evalFile.selected_name = user_eval_file; } } } @@ -131,24 +135,27 @@ void NNUE::init() { // Verifies that the last net used was loaded successfully void NNUE::verify() { - for (NetSize netSize : {Big, Small}) + for (const auto& [netSize, evalFile] : EvalFiles) { - // change after fishtest supports EvalFileSmall - std::string eval_file = - std::string(netSize == Small ? EvalFileDefaultNameSmall : Options[EvFiles[netSize]]); - if (eval_file.empty()) - eval_file = EvFileNames[netSize]; + // Replace with + // Options[evalFile.option_name] + // once fishtest supports the uci option EvalFileSmall + std::string user_eval_file = + netSize == Small ? evalFile.default_name : Options[evalFile.option_name]; + if (user_eval_file.empty()) + user_eval_file = evalFile.default_name; - if (currentEvalFileName[netSize] != eval_file) + if (evalFile.selected_name != user_eval_file) { std::string msg1 = "Network evaluation parameters compatible with the engine must be available."; - std::string msg2 = "The network file " + eval_file + " was not loaded successfully."; + std::string msg2 = + "The network file " + user_eval_file + " was not loaded successfully."; std::string msg3 = "The UCI option EvalFile might need to specify the full path, " "including the directory name, to the network file."; std::string msg4 = "The default net can be downloaded from: " "https://tests.stockfishchess.org/api/nn/" - + std::string(EvFileNames[netSize]); + + evalFile.default_name; std::string msg5 = "The engine will be terminated now."; sync_cout << "info string ERROR: " << msg1 << sync_endl; @@ -160,7 +167,7 @@ void NNUE::verify() { exit(EXIT_FAILURE); } - sync_cout << "info string NNUE evaluation using " << eval_file << sync_endl; + sync_cout << "info string NNUE evaluation using " << user_eval_file << sync_endl; } } } diff --git a/src/evaluate.h b/src/evaluate.h index ce608735..f712d8e6 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -20,6 +20,7 @@ #define EVALUATE_H_INCLUDED #include +#include #include "types.h" @@ -34,8 +35,6 @@ std::string trace(Position& pos); int simple_eval(const Position& pos, Color c); Value evaluate(const Position& pos); -extern std::string currentEvalFileName[2]; - // The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue // for the build process (profile-build and fishtest) to work. Do not change the // name of the macro, as it is used in the Makefile. @@ -44,11 +43,21 @@ extern std::string currentEvalFileName[2]; namespace NNUE { +enum NetSize : int; + void init(); void verify(); } // namespace NNUE +struct EvalFile { + std::string option_name; + std::string default_name; + std::string selected_name; +}; + +extern std::unordered_map EvalFiles; + } // namespace Eval } // namespace Stockfish diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index 7a3f6877..86fe5230 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include "../evaluate.h" #include "../misc.h" @@ -449,7 +450,7 @@ bool save_eval(const std::optional& filename, NetSize netSize) { actualFilename = filename.value(); else { - if (currentEvalFileName[netSize] + if (EvalFiles.at(netSize).selected_name != (netSize == Small ? EvalFileDefaultNameSmall : EvalFileDefaultNameBig)) { msg = "Failed to export a net. " diff --git a/src/nnue/nnue_architecture.h b/src/nnue/nnue_architecture.h index 949f2d86..b222ab99 100644 --- a/src/nnue/nnue_architecture.h +++ b/src/nnue/nnue_architecture.h @@ -37,7 +37,7 @@ namespace Stockfish::Eval::NNUE { // Input features used in evaluation function using FeatureSet = Features::HalfKAv2_hm; -enum NetSize { +enum NetSize : int { Big, Small };