1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-05-01 01:03:09 +00:00

add convert_bin and option for draw positions

This commit is contained in:
rqs 2020-06-27 13:06:05 +09:00 committed by nodchip
parent 2af46deede
commit 0761d9504e
2 changed files with 53 additions and 26 deletions

View file

@ -165,8 +165,8 @@ typedef float LearnFloatType;
// 引き分けに至ったとき、それを教師局面として書き出す // 引き分けに至ったとき、それを教師局面として書き出す
// これをするほうが良いかどうかは微妙。 // これをするほうが良いかどうかは微妙。
// #define LEARN_GENSFEN_USE_DRAW_RESULT // #define LEARN_GENSFEN_USE_DRAW_RESULT
extern bool use_draw_in_training;
extern bool use_hash_in_training;
// ====================== // ======================
// configure // configure
// ====================== // ======================

View file

@ -84,6 +84,10 @@
#include <shared_mutex> #include <shared_mutex>
#endif #endif
bool use_draw_in_training=false;
bool use_draw_in_validation=false;
bool use_hash_in_training=true;
using namespace std; using namespace std;
//// これは探索部で定義されているものとする。 //// これは探索部で定義されているものとする。
@ -1248,11 +1252,8 @@ struct SfenReader
{ {
if (eval_limit < abs(p.score)) if (eval_limit < abs(p.score))
continue; continue;
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT) if (!use_draw_in_validation && p.game_result == 0)
if (p.game_result == 0)
continue; continue;
#endif
sfen_for_mse.push_back(p); sfen_for_mse.push_back(p);
} else { } else {
break; break;
@ -1934,10 +1935,10 @@ void LearnerThink::thread_worker(size_t thread_id)
if (eval_limit < abs(ps.score)) if (eval_limit < abs(ps.score))
goto RetryRead; goto RetryRead;
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
if (ps.game_result == 0) if (!use_draw_in_training && ps.game_result == 0)
goto RetryRead; goto RetryRead;
#endif
// 序盤局面に関する読み飛ばし // 序盤局面に関する読み飛ばし
if (ps.gamePly < prng.rand(reduction_gameply)) if (ps.gamePly < prng.rand(reduction_gameply))
@ -1961,13 +1962,13 @@ void LearnerThink::thread_worker(size_t thread_id)
{ {
auto key = pos.key(); auto key = pos.key();
// rmseの計算用に使っている局面なら除外する。 // rmseの計算用に使っている局面なら除外する。
if (sr.is_for_rmse(key)) if (sr.is_for_rmse(key) && use_hash_in_training)
goto RetryRead; goto RetryRead;
// 直近で用いた局面も除外する。 // 直近で用いた局面も除外する。
auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1)); auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1));
auto key2 = sr.hash[hash_index]; auto key2 = sr.hash[hash_index];
if (key == key2) if (key == key2 && use_hash_in_training)
goto RetryRead; goto RetryRead;
sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。 sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。
} }
@ -2416,30 +2417,36 @@ void shuffle_files_on_memory(const vector<string>& filenames,const string output
std::cout << "..shuffle_on_memory done." << std::endl; std::cout << "..shuffle_on_memory done." << std::endl;
} }
void convert_bin(const vector<string>& filenames , const string& output_file_name) void convert_bin(const vector<string>& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval)
{ {
std::fstream fs; std::fstream fs;
uint64_t data_size=0;
uint64_t filtered_size = 0;
auto th = Threads.main(); auto th = Threads.main();
auto &tpos = th->rootPos; auto &tpos = th->rootPos;
// plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する // plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する
fs.open(output_file_name, ios::app | ios::binary); fs.open(output_file_name, ios::app | ios::binary);
StateListPtr states;
for (auto filename : filenames) { for (auto filename : filenames) {
std::cout << "convert " << filename << " ... "; std::cout << "convert " << filename << " ... ";
std::string line; std::string line;
ifstream ifs; ifstream ifs;
ifs.open(filename); ifs.open(filename);
PackedSfenValue p; PackedSfenValue p;
data_size = 0;
filtered_size = 0;
p.gamePly = 1; // apery形式では含まれない。一応初期化するべし p.gamePly = 1; // apery形式では含まれない。一応初期化するべし
bool ignore_flag = false;
while (std::getline(ifs, line)) { while (std::getline(ifs, line)) {
std::stringstream ss(line); std::stringstream ss(line);
std::string token; std::string token;
std::string value; std::string value;
ss >> token; ss >> token;
if (token == "sfen") { if (token == "fen") {
StateInfo si; states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
tpos.set(line.substr(5), false, &si, Threads.main()); tpos.set(line.substr(4), false, &states->back(), Threads.main());
tpos.sfen_pack(p.sfen); tpos.sfen_pack(p.sfen);
} }
else if (token == "move") { else if (token == "move") {
ss >> value; ss >> value;
@ -2451,23 +2458,37 @@ void convert_bin(const vector<string>& filenames , const string& output_file_nam
else if (token == "ply") { else if (token == "ply") {
int temp; int temp;
ss >> temp; ss >> temp;
if(temp < ply_minimum || temp > ply_maximum){
ignore_flag = true;
}
p.gamePly = uint16_t(temp); // 此処のキャストいらない? p.gamePly = uint16_t(temp); // 此処のキャストいらない?
if (interpolate_eval != 0){
p.score = min(3000, interpolate_eval * temp);
}
} }
else if (token == "result") { else if (token == "result") {
int temp; int temp;
ss >> temp; ss >> temp;
p.game_result = int8_t(temp); // 此処のキャストいらない? p.game_result = int8_t(temp); // 此処のキャストいらない?
if (interpolate_eval){
p.score = p.score * p.game_result;
}
} }
else if (token == "e") { else if (token == "e") {
if(!ignore_flag){
fs.write((char*)&p, sizeof(PackedSfenValue)); fs.write((char*)&p, sizeof(PackedSfenValue));
data_size+=1;
// debug // debug
/* // std::cout<<tpos<<std::endl;
std::cout<<tpos<<std::endl; // std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
std::cout<<to_usi_string(Move(p.move))<<","<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl; }else{
*/ ignore_flag = false;
filtered_size += 1;
}
} }
} }
std::cout << "done" << std::endl; std::cout << "done" << data_size <<" parsed " << filtered_size<<" is filtered"<< std::endl;
ifs.close(); ifs.close();
} }
std::cout << "all done" << std::endl; std::cout << "all done" << std::endl;
@ -2557,6 +2578,9 @@ void learn(Position&, istringstream& is)
bool use_convert_plain = false; bool use_convert_plain = false;
// plain形式の教師をやねうら王のbinに変換する // plain形式の教師をやねうら王のbinに変換する
bool use_convert_bin = false; bool use_convert_bin = false;
int ply_minimum = 0;
int ply_maximum = 114514;
bool interpolate_eval = 0;
// それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin") // それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin")
string output_file_name = "shuffled_sfen.bin"; string output_file_name = "shuffled_sfen.bin";
@ -2636,7 +2660,9 @@ void learn(Position&, istringstream& is)
else if (option == "eta3") is >> eta3; else if (option == "eta3") is >> eta3;
else if (option == "eta1_epoch") is >> eta1_epoch; else if (option == "eta1_epoch") is >> eta1_epoch;
else if (option == "eta2_epoch") is >> eta2_epoch; else if (option == "eta2_epoch") is >> eta2_epoch;
else if (option == "use_draw_in_training") is >> use_draw_in_training;
else if (option == "use_draw_in_validation") is >> use_draw_in_validation;
else if (option == "use_hash_in_training") is >> use_hash_in_training;
// 割引率 // 割引率
else if (option == "discount_rate") is >> discount_rate; else if (option == "discount_rate") is >> discount_rate;
@ -2687,6 +2713,7 @@ void learn(Position&, istringstream& is)
// 雑巾のconvert関連 // 雑巾のconvert関連
else if (option == "convert_plain") use_convert_plain = true; else if (option == "convert_plain") use_convert_plain = true;
else if (option == "convert_bin") use_convert_bin = true; else if (option == "convert_bin") use_convert_bin = true;
else if (option == "interpolate_eval") is >> interpolate_eval;
// さもなくば、それはファイル名である。 // さもなくば、それはファイル名である。
else else
filenames.push_back(option); filenames.push_back(option);
@ -2796,7 +2823,7 @@ void learn(Position&, istringstream& is)
{ {
is_ready(true); is_ready(true);
cout << "convert_bin.." << endl; cout << "convert_bin.." << endl;
convert_bin(filenames,output_file_name); convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval);
return; return;
} }