1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-05-01 01:03:09 +00:00

Revisit NNUE initialization

this revisits the initialization of NNUE, basically only changing
the state on the UCI options 'Use NNUE' and 'EvalFile' calling init_NNUE(),
which sets the Eval::useNNUE variable, and loads the network if needed
(i.e. useNNUE is true and the same network is not yet loaded)

init_NNUE is silent (i.e. no info strings), so that it can be called at startup
without confusing certain GUIs.

An error message on wrong setting when asking for (i.e. the net failed to load),
is delayed to the point where everything must be consistent (start of search or eval).
The engine will stop if the settings are wrong at that point.

Also works if the default value of Use NNUE would become true.
This commit is contained in:
Joost VandeVondele 2020-08-02 15:31:51 +02:00
parent e45d4f1b65
commit 18686e29c7
8 changed files with 55 additions and 68 deletions

View file

@ -20,18 +20,44 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cstdlib>
#include <cstring> // For std::memset #include <cstring> // For std::memset
#include <iomanip> #include <iomanip>
#include <sstream> #include <sstream>
#include <iostream>
#include "bitboard.h" #include "bitboard.h"
#include "evaluate.h" #include "evaluate.h"
#include "material.h" #include "material.h"
#include "pawns.h" #include "pawns.h"
#include "thread.h" #include "thread.h"
#include "uci.h"
namespace Eval { namespace Eval {
bool useNNUE;
bool useNNUE;
std::string eval_file_loaded="None";
void init_NNUE() {
useNNUE = Options["Use NNUE"];
std::string eval_file = std::string(Options["EvalFile"]);
if (useNNUE && eval_file_loaded != eval_file)
if (Eval::NNUE::load_eval_file(eval_file))
eval_file_loaded = eval_file;
}
void verify_NNUE() {
std::string eval_file = std::string(Options["EvalFile"]);
if (useNNUE && eval_file_loaded != eval_file)
{
std::cerr << "Use of NNUE evaluation, but the file " << eval_file << " was not loaded successfully. "
<< "These network evaluation parameters must be available, compatible with this version of the code. "
<< "The UCI option EvalFile might need to specify the full path, including the directory/folder name, to the file." << std::endl;
std::exit(EXIT_FAILURE);
}
}
} }
namespace Trace { namespace Trace {

View file

@ -29,18 +29,22 @@ class Position;
namespace Eval { namespace Eval {
std::string trace(const Position& pos); std::string trace(const Position& pos);
Value evaluate(const Position& pos); Value evaluate(const Position& pos);
extern bool useNNUE;
namespace NNUE { extern bool useNNUE;
extern std::string eval_file_loaded;
void init_NNUE();
void verify_NNUE();
Value evaluate(const Position& pos); namespace NNUE {
Value compute_eval(const Position& pos);
void update_eval(const Position& pos);
void load_eval(const std::string& evalFile);
} // namespace NNUE Value evaluate(const Position& pos);
Value compute_eval(const Position& pos);
void update_eval(const Position& pos);
bool load_eval_file(const std::string& evalFile);
} // namespace NNUE
} // namespace Eval } // namespace Eval

View file

@ -46,6 +46,7 @@ int main(int argc, char* argv[]) {
Endgames::init(); Endgames::init();
Threads.set(size_t(Options["Threads"])); Threads.set(size_t(Options["Threads"]));
Search::clear(); // After threads are up Search::clear(); // After threads are up
Eval::init_NNUE();
UCI::loop(argc, argv); UCI::loop(argc, argv);

View file

@ -127,18 +127,16 @@ namespace Eval::NNUE {
} }
// Load the evaluation function file // Load the evaluation function file
void load_eval(const std::string& evalFile) { bool load_eval_file(const std::string& evalFile) {
Initialize(); Initialize();
fileName = evalFile; fileName = evalFile;
std::ifstream stream(evalFile, std::ios::binary); std::ifstream stream(evalFile, std::ios::binary);
const bool result = ReadParameters(stream); const bool result = ReadParameters(stream);
if (!result) return result;
std::cout << "Error! " << fileName << " not found or wrong format" << std::endl;
else
std::cout << "info string NNUE " << fileName << " found & loaded" << std::endl;
} }
// Evaluation function. Perform differential calculation. // Evaluation function. Perform differential calculation.

View file

@ -227,6 +227,8 @@ void MainThread::search() {
Time.init(Limits, us, rootPos.game_ply()); Time.init(Limits, us, rootPos.game_ply());
TT.new_search(); TT.new_search();
Eval::verify_NNUE();
if (rootMoves.empty()) if (rootMoves.empty())
{ {
rootMoves.emplace_back(MOVE_NONE); rootMoves.emplace_back(MOVE_NONE);

View file

@ -86,6 +86,9 @@ namespace {
StateListPtr states(new std::deque<StateInfo>(1)); StateListPtr states(new std::deque<StateInfo>(1));
Position p; Position p;
p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main());
Eval::verify_NNUE();
sync_cout << "\n" << Eval::trace(p) << sync_endl; sync_cout << "\n" << Eval::trace(p) << sync_endl;
} }
@ -181,12 +184,7 @@ namespace {
} }
else if (token == "setoption") setoption(is); else if (token == "setoption") setoption(is);
else if (token == "position") position(pos, is, states); else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame") else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while
{
init_nnue(Options["EvalFile"]);
Search::clear();
elapsed = now(); // initialization may take some time
}
} }
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
@ -224,17 +222,6 @@ namespace {
} // namespace } // namespace
void UCI::init_nnue(const std::string& evalFile)
{
if (Options["Use NNUE"] && !UCI::load_eval_finished)
{
// Load evaluation function from a file
Eval::NNUE::load_eval(evalFile);
UCI::load_eval_finished = true;
}
}
/// UCI::loop() waits for a command from stdin, parses it and calls the appropriate /// UCI::loop() waits for a command from stdin, parses it and calls the appropriate
/// function. Also intercepts EOF from stdin to ensure gracefully exiting if the /// function. Also intercepts EOF from stdin to ensure gracefully exiting if the
/// GUI dies unexpectedly. When called with some command line arguments, e.g. to /// GUI dies unexpectedly. When called with some command line arguments, e.g. to
@ -249,9 +236,6 @@ void UCI::loop(int argc, char* argv[]) {
pos.set(StartFEN, false, &states->back(), Threads.main()); pos.set(StartFEN, false, &states->back(), Threads.main());
if (argc > 1)
init_nnue(Options["EvalFile"]);
for (int i = 1; i < argc; ++i) for (int i = 1; i < argc; ++i)
cmd += std::string(argv[i]) + " "; cmd += std::string(argv[i]) + " ";
@ -283,16 +267,8 @@ void UCI::loop(int argc, char* argv[]) {
else if (token == "setoption") setoption(is); else if (token == "setoption") setoption(is);
else if (token == "go") go(pos, is, states); else if (token == "go") go(pos, is, states);
else if (token == "position") position(pos, is, states); else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame") else if (token == "ucinewgame") Search::clear();
{ else if (token == "isready") sync_cout << "readyok" << sync_endl;
init_nnue(Options["EvalFile"]);
Search::clear();
}
else if (token == "isready")
{
init_nnue(Options["EvalFile"]);
sync_cout << "readyok" << sync_endl;
}
// Additional custom non-UCI commands, mainly for debugging. // Additional custom non-UCI commands, mainly for debugging.
// Do not use these commands during a search! // Do not use these commands during a search!

View file

@ -76,10 +76,6 @@ std::string pv(const Position& pos, Depth depth, Value alpha, Value beta);
std::string wdl(Value v, int ply); std::string wdl(Value v, int ply);
Move to_move(const Position& pos, std::string& str); Move to_move(const Position& pos, std::string& str);
void init_nnue(const std::string& evalFile);
extern bool load_eval_finished;
} // namespace UCI } // namespace UCI
extern UCI::OptionsMap Options; extern UCI::OptionsMap Options;

View file

@ -42,23 +42,8 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); }
void on_logger(const Option& o) { start_logger(o); } void on_logger(const Option& o) { start_logger(o); }
void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_threads(const Option& o) { Threads.set(size_t(o)); }
void on_tb_path(const Option& o) { Tablebases::init(o); } void on_tb_path(const Option& o) { Tablebases::init(o); }
void on_use_NNUE(const Option& ) { Eval::init_NNUE(); }
void on_use_nnue(const Option& o) { void on_eval_file(const Option& ) { Eval::init_NNUE(); }
if (o)
std::cout << "info string NNUE eval used" << std::endl;
else
std::cout << "info string classic eval used" << std::endl;
Eval::useNNUE = o;
init_nnue(Options["EvalFile"]);
}
void on_eval_file(const Option& o) {
load_eval_finished = false;
init_nnue(o);
}
/// Our case insensitive less() function as required by UCI protocol /// Our case insensitive less() function as required by UCI protocol
bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const {
@ -95,7 +80,7 @@ void init(OptionsMap& o) {
o["SyzygyProbeDepth"] << Option(1, 1, 100); o["SyzygyProbeDepth"] << Option(1, 1, 100);
o["Syzygy50MoveRule"] << Option(true); o["Syzygy50MoveRule"] << Option(true);
o["SyzygyProbeLimit"] << Option(7, 0, 7); o["SyzygyProbeLimit"] << Option(7, 0, 7);
o["Use NNUE"] << Option(false, on_use_nnue); o["Use NNUE"] << Option(false, on_use_NNUE);
o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file); o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file);
} }
@ -205,5 +190,4 @@ Option& Option::operator=(const string& v) {
return *this; return *this;
} }
bool load_eval_finished = false;
} // namespace UCI } // namespace UCI