1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-05-01 01:03:09 +00:00

Rework loading the net.

This commit is contained in:
joergoster 2020-07-06 17:10:31 +02:00 committed by nodchip
parent 288fdc5597
commit a5af8510a5
6 changed files with 52 additions and 73 deletions

View file

@ -233,32 +233,28 @@ void prefetch_evalhash(const Key key) {
// Save and restore Options with bench command etc., so EvalDir is changed at this time, // Save and restore Options with bench command etc., so EvalDir is changed at this time,
// This function may be called twice to flag that the evaluation function needs to be reloaded. // This function may be called twice to flag that the evaluation function needs to be reloaded.
void load_eval() { void load_eval() {
if (Options["SkipLoadingEval"])
{
std::cout << "info string SkipLoadingEval set to true, Net not loaded!" << std::endl;
return;
}
NNUE::Initialize(); NNUE::Initialize();
if (!Options["SkipLoadingEval"])
{
const std::string dir_name = Options["EvalDir"]; const std::string dir_name = Options["EvalDir"];
const std::string file_name = Path::Combine(dir_name, NNUE::kFileName); const std::string file_name = Path::Combine(dir_name, NNUE::kFileName);
//{
// std::ofstream stream(file_name, std::ios::binary);
// NNUE::WriteParameters(stream);
//}
std::ifstream stream(file_name, std::ios::binary); std::ifstream stream(file_name, std::ios::binary);
const bool result = NNUE::ReadParameters(stream); const bool result = NNUE::ReadParameters(stream);
// ASSERT(result);
if (!result) if (!result)
{
// It's a problem if it doesn't finish when there is a read error. // It's a problem if it doesn't finish when there is a read error.
std::cout << "Error! " << NNUE::kFileName << " not found or wrong format" << std::endl; std::cout << "Error! " << NNUE::kFileName << " not found or wrong format" << std::endl;
//my_exit();
}
else else
std::cout << "info string NNUE " << NNUE::kFileName << " found & loaded" << std::endl; std::cout << "info string NNUE " << NNUE::kFileName << " found & loaded" << std::endl;
} }
else
std::cout << "info string NNUE " << NNUE::kFileName << " not loaded" << std::endl;
}
// Initialization // Initialization
void init() { void init() {

View file

@ -3092,7 +3092,7 @@ void learn(Position&, istringstream& is)
//} //}
if (use_convert_bin) if (use_convert_bin)
{ {
is_ready(true); init_nnue(true);
cout << "convert_bin.." << endl; cout << "convert_bin.." << endl;
convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval); convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval);
return; return;
@ -3100,7 +3100,7 @@ void learn(Position&, istringstream& is)
} }
if (use_convert_bin_from_pgn_extract) if (use_convert_bin_from_pgn_extract)
{ {
is_ready(true); init_nnue(true);
cout << "convert_bin_from_pgn-extract.." << endl; cout << "convert_bin_from_pgn-extract.." << endl;
convert_bin_from_pgn_extract(filenames, output_file_name, pgn_eval_side_to_move); convert_bin_from_pgn_extract(filenames, output_file_name, pgn_eval_side_to_move);
return; return;
@ -3166,7 +3166,7 @@ void learn(Position&, istringstream& is)
cout << "init.." << endl; cout << "init.." << endl;
// Read evaluation function parameters // Read evaluation function parameters
is_ready(true); init_nnue(true);
#if !defined(EVAL_NNUE) #if !defined(EVAL_NNUE)
cout << "init_grad.." << endl; cout << "init_grad.." << endl;

View file

@ -20,7 +20,7 @@ void MultiThink::go_think()
// Read evaluation function, etc. // Read evaluation function, etc.
// In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so // In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so
// Skip memory corruption check. // Skip memory corruption check.
is_ready(true); init_nnue(true);
// Call the derived class's init(). // Call the derived class's init().
init(); init();

View file

@ -73,7 +73,7 @@ namespace Learner
void test_cmd(Position& pos, istringstream& is) void test_cmd(Position& pos, istringstream& is)
{ {
// Initialize as it may be searched. // Initialize as it may be searched.
is_ready(); init_nnue();
std::string param; std::string param;
is >> param; is >> param;
@ -209,7 +209,14 @@ namespace {
} }
else if (token == "setoption") setoption(is); else if (token == "setoption") setoption(is);
else if (token == "position") position(pos, is, states); else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while else if (token == "ucinewgame")
{
#if defined(EVAL_NNUE)
init_nnue();
#endif
Search::clear();
elapsed = now(); // Search::clear() may take some while
}
} }
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
@ -250,7 +257,7 @@ namespace {
// Make is_ready_cmd() callable from outside. (Because I want to call it from the bench command etc.) // Make is_ready_cmd() callable from outside. (Because I want to call it from the bench command etc.)
// Note that the phase is not initialized. // Note that the phase is not initialized.
void is_ready(bool skipCorruptCheck) void init_nnue(bool skipCorruptCheck)
{ {
#if defined(EVAL_NNUE) #if defined(EVAL_NNUE)
// After receiving "isready", modify so that a line feed is sent every 5 seconds until "readyok" is returned. (keep alive processing) // After receiving "isready", modify so that a line feed is sent every 5 seconds until "readyok" is returned. (keep alive processing)
@ -260,20 +267,6 @@ void is_ready(bool skipCorruptCheck)
// -Shogi GUI already does so, so MyShogi will follow along. // -Shogi GUI already does so, so MyShogi will follow along.
//-Also, the engine side of Yaneura King modifies it so that after "isready" is received, a line feed is sent every 5 seconds until "readyok" is returned. //-Also, the engine side of Yaneura King modifies it so that after "isready" is received, a line feed is sent every 5 seconds until "readyok" is returned.
auto ended = false;
auto th = std::thread([&ended] {
int count = 0;
while (!ended)
{
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (++count >= 50 /* 5 seconds */)
{
count = 0;
sync_cout << sync_endl; // Send a line break.
}
}
});
// Perform processing that may take time, such as reading the evaluation function, at this timing. // Perform processing that may take time, such as reading the evaluation function, at this timing.
// If you do a time-consuming process at startup, Shogi place will make a timeout judgment and retire the recognition as a thinking engine. // If you do a time-consuming process at startup, Shogi place will make a timeout judgment and retire the recognition as a thinking engine.
if (!UCI::load_eval_finished) if (!UCI::load_eval_finished)
@ -288,7 +281,6 @@ void is_ready(bool skipCorruptCheck)
Eval::print_softname(eval_sum); Eval::print_softname(eval_sum);
UCI::load_eval_finished = true; UCI::load_eval_finished = true;
} }
else else
{ {
@ -297,22 +289,7 @@ void is_ready(bool skipCorruptCheck)
if (!skipCorruptCheck && eval_sum != Eval::calc_check_sum()) if (!skipCorruptCheck && eval_sum != Eval::calc_check_sum())
sync_cout << "Error! : EVAL memory is corrupted" << sync_endl; sync_cout << "Error! : EVAL memory is corrupted" << sync_endl;
} }
// For isready, it is promised that the next command will not come until it returns readyok.
// Initialize various variables at this timing.
TT.resize(Options["Hash"]);
Search::clear();
Time.availableNodes = 0;
Threads.stop = false;
// Terminate the thread created to send keep alive and wait.
ended = true;
th.join();
#endif // defined(EVAL_NNUE) #endif // defined(EVAL_NNUE)
sync_cout << "readyok" << sync_endl;
} }
@ -399,8 +376,14 @@ void UCI::loop(int argc, char* argv[]) {
else if (token == "setoption") setoption(is); else if (token == "setoption") setoption(is);
else if (token == "go") go(pos, is, states); else if (token == "go") go(pos, is, states);
else if (token == "position") position(pos, is, states); else if (token == "position") position(pos, is, states);
else if (token == "ucinewgame") Search::clear(); else if (token == "ucinewgame")
else if (token == "isready") is_ready(); {
#if defined(EVAL_NNUE)
init_nnue();
#endif
Search::clear();
}
else if (token == "isready") sync_cout << "readyok" << sync_endl;
// Additional custom non-UCI commands, mainly for debugging. // Additional custom non-UCI commands, mainly for debugging.
// Do not use these commands during a search! // Do not use these commands during a search!

View file

@ -87,7 +87,7 @@ extern UCI::OptionsMap Options;
// If skipCorruptCheck == true, skip memory corruption check by check sum when reading the evaluation function a second time. // If skipCorruptCheck == true, skip memory corruption check by check sum when reading the evaluation function a second time.
// * This function is inconvenient if it is not available in Stockfish, so add it. // * This function is inconvenient if it is not available in Stockfish, so add it.
void is_ready(bool skipCorruptCheck = false); void init_nnue(bool skipCorruptCheck = false);
extern const char* StartFEN; extern const char* StartFEN;

View file

@ -42,7 +42,7 @@ void on_hash_size(const Option& o) { TT.resize(size_t(o)); }
void on_logger(const Option& o) { start_logger(o); } void on_logger(const Option& o) { start_logger(o); }
void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_threads(const Option& o) { Threads.set(size_t(o)); }
void on_tb_path(const Option& o) { Tablebases::init(o); } void on_tb_path(const Option& o) { Tablebases::init(o); }
void on_eval_dir(const Option& o) { load_eval_finished = false; } void on_eval_dir(const Option& o) { load_eval_finished = false; init_nnue(); }
/// Our case insensitive less() function as required by UCI protocol /// Our case insensitive less() function as required by UCI protocol