1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-04-30 16:53:09 +00:00

Revisit NNUE initialization

this revisits the initialization of NNUE, basically only changing
the state on the UCI options 'Use NNUE' and 'EvalFile' calling init_NNUE(),
which sets the Eval::useNNUE variable, and loads the network if needed
(i.e. useNNUE is true and the same network is not yet loaded)

init_NNUE is silent (i.e. no info strings), so that it can be called at startup
without confusing certain GUIs.

An error message on wrong setting when asking for (i.e. the net failed to load),
is delayed to the point where everything must be consistent (start of search or eval).
The engine will stop if the settings are wrong at that point.

Also works if the default value of Use NNUE would become true.
This commit is contained in:
Joost VandeVondele 2020-08-02 15:31:51 +02:00
parent e45d4f1b65
commit 18686e29c7
8 changed files with 55 additions and 68 deletions

View file

@ -20,18 +20,44 @@
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring> // For std::memset
#include <iomanip>
#include <sstream>
#include <iostream>
#include "bitboard.h"
#include "evaluate.h"
#include "material.h"
#include "pawns.h"
#include "thread.h"
#include "uci.h"
namespace Eval {
bool useNNUE;
bool useNNUE;
std::string eval_file_loaded="None";
void init_NNUE() {
useNNUE = Options["Use NNUE"];
std::string eval_file = std::string(Options["EvalFile"]);
if (useNNUE && eval_file_loaded != eval_file)
if (Eval::NNUE::load_eval_file(eval_file))
eval_file_loaded = eval_file;
}
void verify_NNUE() {
std::string eval_file = std::string(Options["EvalFile"]);
if (useNNUE && eval_file_loaded != eval_file)
{
std::cerr << "Use of NNUE evaluation, but the file " << eval_file << " was not loaded successfully. "
<< "These network evaluation parameters must be available, compatible with this version of the code. "
<< "The UCI option EvalFile might need to specify the full path, including the directory/folder name, to the file." << std::endl;
std::exit(EXIT_FAILURE);
}
}
}
namespace Trace {

View file

@ -29,18 +29,22 @@ class Position;
namespace Eval {
std::string trace(const Position& pos);
Value evaluate(const Position& pos);
extern bool useNNUE;
std::string trace(const Position& pos);
Value evaluate(const Position& pos);
namespace NNUE {
extern bool useNNUE;
extern std::string eval_file_loaded;
void init_NNUE();
void verify_NNUE();
Value evaluate(const Position& pos);
Value compute_eval(const Position& pos);
void update_eval(const Position& pos);
void load_eval(const std::string& evalFile);
namespace NNUE {
} // namespace NNUE
Value evaluate(const Position& pos);
Value compute_eval(const Position& pos);
void update_eval(const Position& pos);
bool load_eval_file(const std::string& evalFile);
} // namespace NNUE
} // namespace Eval

View file

@ -46,6 +46,7 @@ int main(int argc, char* argv[]) {
Endgames::init();
Threads.set(size_t(Options["Threads"]));
Search::clear(); // After threads are up
Eval::init_NNUE();
UCI::loop(argc, argv);

View file

@ -127,18 +127,16 @@ namespace Eval::NNUE {
}
// Load the evaluation function file
void load_eval(const std::string& evalFile) {
bool load_eval_file(const std::string& evalFile) {
Initialize();
fileName = evalFile;
std::ifstream stream(evalFile, std::ios::binary);
const bool result = ReadParameters(stream);
if (!result)
std::cout << "Error! " << fileName << " not found or wrong format" << std::endl;
else
std::cout << "info string NNUE " << fileName << " found & loaded" << std::endl;
return result;
}
// Evaluation function. Perform differential calculation.

View file

@ -227,6 +227,8 @@ void MainThread::search() {
Time.init(Limits, us, rootPos.game_ply());
TT.new_search();
Eval::verify_NNUE();
if (rootMoves.empty())
{
rootMoves.emplace_back(MOVE_NONE);

View file

@ -86,6 +86,9 @@ namespace {
StateListPtr states(new std::deque<StateInfo>(1));
Position p;
p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main());
Eval::verify_NNUE();
sync_cout << "\n" << Eval::trace(p) << sync_endl;
}
@ -181,12 +184,7 @@ namespace {
}
else if (token == "setoption") setoption(is);
else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame")
{
init_nnue(Options["EvalFile"]);
Search::clear();
elapsed = now(); // initialization may take some time
}
else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while
}
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
@ -224,17 +222,6 @@ namespace {
} // namespace
void UCI::init_nnue(const std::string& evalFile)
{
if (Options["Use NNUE"] && !UCI::load_eval_finished)
{
// Load evaluation function from a file
Eval::NNUE::load_eval(evalFile);
UCI::load_eval_finished = true;
}
}
/// UCI::loop() waits for a command from stdin, parses it and calls the appropriate
/// function. Also intercepts EOF from stdin to ensure gracefully exiting if the
/// GUI dies unexpectedly. When called with some command line arguments, e.g. to
@ -249,9 +236,6 @@ void UCI::loop(int argc, char* argv[]) {
pos.set(StartFEN, false, &states->back(), Threads.main());
if (argc > 1)
init_nnue(Options["EvalFile"]);
for (int i = 1; i < argc; ++i)
cmd += std::string(argv[i]) + " ";
@ -283,16 +267,8 @@ void UCI::loop(int argc, char* argv[]) {
else if (token == "setoption") setoption(is);
else if (token == "go") go(pos, is, states);
else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame")
{
init_nnue(Options["EvalFile"]);
Search::clear();
}
else if (token == "isready")
{
init_nnue(Options["EvalFile"]);
sync_cout << "readyok" << sync_endl;
}
else if (token == "ucinewgame") Search::clear();
else if (token == "isready") sync_cout << "readyok" << sync_endl;
// Additional custom non-UCI commands, mainly for debugging.
// Do not use these commands during a search!

View file

@ -76,10 +76,6 @@ std::string pv(const Position& pos, Depth depth, Value alpha, Value beta);
std::string wdl(Value v, int ply);
Move to_move(const Position& pos, std::string& str);
void init_nnue(const std::string& evalFile);
extern bool load_eval_finished;
} // namespace UCI
extern UCI::OptionsMap Options;

View file

@ -42,23 +42,8 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); }
void on_logger(const Option& o) { start_logger(o); }
void on_threads(const Option& o) { Threads.set(size_t(o)); }
void on_tb_path(const Option& o) { Tablebases::init(o); }
void on_use_nnue(const Option& o) {
if (o)
std::cout << "info string NNUE eval used" << std::endl;
else
std::cout << "info string classic eval used" << std::endl;
Eval::useNNUE = o;
init_nnue(Options["EvalFile"]);
}
void on_eval_file(const Option& o) {
load_eval_finished = false;
init_nnue(o);
}
void on_use_NNUE(const Option& ) { Eval::init_NNUE(); }
void on_eval_file(const Option& ) { Eval::init_NNUE(); }
/// Our case insensitive less() function as required by UCI protocol
bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const {
@ -95,7 +80,7 @@ void init(OptionsMap& o) {
o["SyzygyProbeDepth"] << Option(1, 1, 100);
o["Syzygy50MoveRule"] << Option(true);
o["SyzygyProbeLimit"] << Option(7, 0, 7);
o["Use NNUE"] << Option(false, on_use_nnue);
o["Use NNUE"] << Option(false, on_use_NNUE);
o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file);
}
@ -205,5 +190,4 @@ Option& Option::operator=(const string& v) {
return *this;
}
bool load_eval_finished = false;
} // namespace UCI