mirror of
https://github.com/sockspls/badfish
synced 2025-04-30 16:53:09 +00:00
Revisit NNUE initialization
this revisits the initialization of NNUE, basically only changing the state on the UCI options 'Use NNUE' and 'EvalFile' calling init_NNUE(), which sets the Eval::useNNUE variable, and loads the network if needed (i.e. useNNUE is true and the same network is not yet loaded) init_NNUE is silent (i.e. no info strings), so that it can be called at startup without confusing certain GUIs. An error message on wrong setting when asking for (i.e. the net failed to load), is delayed to the point where everything must be consistent (start of search or eval). The engine will stop if the settings are wrong at that point. Also works if the default value of Use NNUE would become true.
This commit is contained in:
parent
e45d4f1b65
commit
18686e29c7
8 changed files with 55 additions and 68 deletions
|
@ -20,18 +20,44 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring> // For std::memset
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "bitboard.h"
|
||||
#include "evaluate.h"
|
||||
#include "material.h"
|
||||
#include "pawns.h"
|
||||
#include "thread.h"
|
||||
#include "uci.h"
|
||||
|
||||
namespace Eval {
|
||||
bool useNNUE;
|
||||
|
||||
bool useNNUE;
|
||||
std::string eval_file_loaded="None";
|
||||
|
||||
void init_NNUE() {
|
||||
|
||||
useNNUE = Options["Use NNUE"];
|
||||
std::string eval_file = std::string(Options["EvalFile"]);
|
||||
if (useNNUE && eval_file_loaded != eval_file)
|
||||
if (Eval::NNUE::load_eval_file(eval_file))
|
||||
eval_file_loaded = eval_file;
|
||||
}
|
||||
|
||||
void verify_NNUE() {
|
||||
|
||||
std::string eval_file = std::string(Options["EvalFile"]);
|
||||
if (useNNUE && eval_file_loaded != eval_file)
|
||||
{
|
||||
std::cerr << "Use of NNUE evaluation, but the file " << eval_file << " was not loaded successfully. "
|
||||
<< "These network evaluation parameters must be available, compatible with this version of the code. "
|
||||
<< "The UCI option EvalFile might need to specify the full path, including the directory/folder name, to the file." << std::endl;
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace Trace {
|
||||
|
|
|
@ -29,18 +29,22 @@ class Position;
|
|||
|
||||
namespace Eval {
|
||||
|
||||
std::string trace(const Position& pos);
|
||||
Value evaluate(const Position& pos);
|
||||
extern bool useNNUE;
|
||||
std::string trace(const Position& pos);
|
||||
Value evaluate(const Position& pos);
|
||||
|
||||
namespace NNUE {
|
||||
extern bool useNNUE;
|
||||
extern std::string eval_file_loaded;
|
||||
void init_NNUE();
|
||||
void verify_NNUE();
|
||||
|
||||
Value evaluate(const Position& pos);
|
||||
Value compute_eval(const Position& pos);
|
||||
void update_eval(const Position& pos);
|
||||
void load_eval(const std::string& evalFile);
|
||||
namespace NNUE {
|
||||
|
||||
} // namespace NNUE
|
||||
Value evaluate(const Position& pos);
|
||||
Value compute_eval(const Position& pos);
|
||||
void update_eval(const Position& pos);
|
||||
bool load_eval_file(const std::string& evalFile);
|
||||
|
||||
} // namespace NNUE
|
||||
|
||||
} // namespace Eval
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ int main(int argc, char* argv[]) {
|
|||
Endgames::init();
|
||||
Threads.set(size_t(Options["Threads"]));
|
||||
Search::clear(); // After threads are up
|
||||
Eval::init_NNUE();
|
||||
|
||||
UCI::loop(argc, argv);
|
||||
|
||||
|
|
|
@ -127,18 +127,16 @@ namespace Eval::NNUE {
|
|||
}
|
||||
|
||||
// Load the evaluation function file
|
||||
void load_eval(const std::string& evalFile) {
|
||||
bool load_eval_file(const std::string& evalFile) {
|
||||
|
||||
Initialize();
|
||||
fileName = evalFile;
|
||||
|
||||
std::ifstream stream(evalFile, std::ios::binary);
|
||||
|
||||
const bool result = ReadParameters(stream);
|
||||
|
||||
if (!result)
|
||||
std::cout << "Error! " << fileName << " not found or wrong format" << std::endl;
|
||||
else
|
||||
std::cout << "info string NNUE " << fileName << " found & loaded" << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Evaluation function. Perform differential calculation.
|
||||
|
|
|
@ -227,6 +227,8 @@ void MainThread::search() {
|
|||
Time.init(Limits, us, rootPos.game_ply());
|
||||
TT.new_search();
|
||||
|
||||
Eval::verify_NNUE();
|
||||
|
||||
if (rootMoves.empty())
|
||||
{
|
||||
rootMoves.emplace_back(MOVE_NONE);
|
||||
|
|
36
src/uci.cpp
36
src/uci.cpp
|
@ -86,6 +86,9 @@ namespace {
|
|||
StateListPtr states(new std::deque<StateInfo>(1));
|
||||
Position p;
|
||||
p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main());
|
||||
|
||||
Eval::verify_NNUE();
|
||||
|
||||
sync_cout << "\n" << Eval::trace(p) << sync_endl;
|
||||
}
|
||||
|
||||
|
@ -181,12 +184,7 @@ namespace {
|
|||
}
|
||||
else if (token == "setoption") setoption(is);
|
||||
else if (token == "position") position(pos, is, states);
|
||||
else if (token == "ucinewgame")
|
||||
{
|
||||
init_nnue(Options["EvalFile"]);
|
||||
Search::clear();
|
||||
elapsed = now(); // initialization may take some time
|
||||
}
|
||||
else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while
|
||||
}
|
||||
|
||||
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
|
||||
|
@ -224,17 +222,6 @@ namespace {
|
|||
} // namespace
|
||||
|
||||
|
||||
void UCI::init_nnue(const std::string& evalFile)
|
||||
{
|
||||
if (Options["Use NNUE"] && !UCI::load_eval_finished)
|
||||
{
|
||||
// Load evaluation function from a file
|
||||
Eval::NNUE::load_eval(evalFile);
|
||||
UCI::load_eval_finished = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// UCI::loop() waits for a command from stdin, parses it and calls the appropriate
|
||||
/// function. Also intercepts EOF from stdin to ensure gracefully exiting if the
|
||||
/// GUI dies unexpectedly. When called with some command line arguments, e.g. to
|
||||
|
@ -249,9 +236,6 @@ void UCI::loop(int argc, char* argv[]) {
|
|||
|
||||
pos.set(StartFEN, false, &states->back(), Threads.main());
|
||||
|
||||
if (argc > 1)
|
||||
init_nnue(Options["EvalFile"]);
|
||||
|
||||
for (int i = 1; i < argc; ++i)
|
||||
cmd += std::string(argv[i]) + " ";
|
||||
|
||||
|
@ -283,16 +267,8 @@ void UCI::loop(int argc, char* argv[]) {
|
|||
else if (token == "setoption") setoption(is);
|
||||
else if (token == "go") go(pos, is, states);
|
||||
else if (token == "position") position(pos, is, states);
|
||||
else if (token == "ucinewgame")
|
||||
{
|
||||
init_nnue(Options["EvalFile"]);
|
||||
Search::clear();
|
||||
}
|
||||
else if (token == "isready")
|
||||
{
|
||||
init_nnue(Options["EvalFile"]);
|
||||
sync_cout << "readyok" << sync_endl;
|
||||
}
|
||||
else if (token == "ucinewgame") Search::clear();
|
||||
else if (token == "isready") sync_cout << "readyok" << sync_endl;
|
||||
|
||||
// Additional custom non-UCI commands, mainly for debugging.
|
||||
// Do not use these commands during a search!
|
||||
|
|
|
@ -76,10 +76,6 @@ std::string pv(const Position& pos, Depth depth, Value alpha, Value beta);
|
|||
std::string wdl(Value v, int ply);
|
||||
Move to_move(const Position& pos, std::string& str);
|
||||
|
||||
void init_nnue(const std::string& evalFile);
|
||||
|
||||
extern bool load_eval_finished;
|
||||
|
||||
} // namespace UCI
|
||||
|
||||
extern UCI::OptionsMap Options;
|
||||
|
|
|
@ -42,23 +42,8 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); }
|
|||
void on_logger(const Option& o) { start_logger(o); }
|
||||
void on_threads(const Option& o) { Threads.set(size_t(o)); }
|
||||
void on_tb_path(const Option& o) { Tablebases::init(o); }
|
||||
|
||||
void on_use_nnue(const Option& o) {
|
||||
|
||||
if (o)
|
||||
std::cout << "info string NNUE eval used" << std::endl;
|
||||
else
|
||||
std::cout << "info string classic eval used" << std::endl;
|
||||
|
||||
Eval::useNNUE = o;
|
||||
init_nnue(Options["EvalFile"]);
|
||||
}
|
||||
|
||||
void on_eval_file(const Option& o) {
|
||||
|
||||
load_eval_finished = false;
|
||||
init_nnue(o);
|
||||
}
|
||||
void on_use_NNUE(const Option& ) { Eval::init_NNUE(); }
|
||||
void on_eval_file(const Option& ) { Eval::init_NNUE(); }
|
||||
|
||||
/// Our case insensitive less() function as required by UCI protocol
|
||||
bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const {
|
||||
|
@ -95,7 +80,7 @@ void init(OptionsMap& o) {
|
|||
o["SyzygyProbeDepth"] << Option(1, 1, 100);
|
||||
o["Syzygy50MoveRule"] << Option(true);
|
||||
o["SyzygyProbeLimit"] << Option(7, 0, 7);
|
||||
o["Use NNUE"] << Option(false, on_use_nnue);
|
||||
o["Use NNUE"] << Option(false, on_use_NNUE);
|
||||
o["EvalFile"] << Option("nn-c157e0a5755b.nnue", on_eval_file);
|
||||
}
|
||||
|
||||
|
@ -205,5 +190,4 @@ Option& Option::operator=(const string& v) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
bool load_eval_finished = false;
|
||||
} // namespace UCI
|
||||
|
|
Loading…
Add table
Reference in a new issue