diff --git a/src/learn/learn.h b/src/learn/learn.h index 58a017bd..8e3172d3 100644 --- a/src/learn/learn.h +++ b/src/learn/learn.h @@ -165,8 +165,8 @@ typedef float LearnFloatType; // 引き分けに至ったとき、それを教師局面として書き出す // これをするほうが良いかどうかは微妙。 // #define LEARN_GENSFEN_USE_DRAW_RESULT - - +extern bool use_draw_in_training; +extern bool use_hash_in_training; // ====================== // configure // ====================== @@ -234,4 +234,4 @@ namespace Learner #endif -#endif // ifndef _LEARN_H_ \ No newline at end of file +#endif // ifndef _LEARN_H_ diff --git a/src/learn/learner.cpp b/src/learn/learner.cpp index 221a561e..09af98d3 100644 --- a/src/learn/learner.cpp +++ b/src/learn/learner.cpp @@ -84,6 +84,10 @@ #include #endif +bool use_draw_in_training=false; +bool use_draw_in_validation=false; +bool use_hash_in_training=true; + using namespace std; //// これは探索部で定義されているものとする。 @@ -1248,11 +1252,8 @@ struct SfenReader { if (eval_limit < abs(p.score)) continue; -#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT) - if (p.game_result == 0) + if (!use_draw_in_validation && p.game_result == 0) continue; -#endif - sfen_for_mse.push_back(p); } else { break; @@ -1934,10 +1935,10 @@ void LearnerThink::thread_worker(size_t thread_id) if (eval_limit < abs(ps.score)) goto RetryRead; -#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT) - if (ps.game_result == 0) + + if (!use_draw_in_training && ps.game_result == 0) goto RetryRead; -#endif + // 序盤局面に関する読み飛ばし if (ps.gamePly < prng.rand(reduction_gameply)) @@ -1961,13 +1962,13 @@ void LearnerThink::thread_worker(size_t thread_id) { auto key = pos.key(); // rmseの計算用に使っている局面なら除外する。 - if (sr.is_for_rmse(key)) + if (sr.is_for_rmse(key) && use_hash_in_training) goto RetryRead; // 直近で用いた局面も除外する。 auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1)); auto key2 = sr.hash[hash_index]; - if (key == key2) + if (key == key2 && use_hash_in_training) goto RetryRead; sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。 } @@ -2416,30 +2417,36 @@ void shuffle_files_on_memory(const vector& filenames,const string output std::cout << "..shuffle_on_memory done." << std::endl; } -void convert_bin(const vector& filenames , const string& output_file_name) +void convert_bin(const vector& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval) { std::fstream fs; + uint64_t data_size=0; + uint64_t filtered_size = 0; auto th = Threads.main(); auto &tpos = th->rootPos; // plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する fs.open(output_file_name, ios::app | ios::binary); - + StateListPtr states; for (auto filename : filenames) { std::cout << "convert " << filename << " ... "; std::string line; ifstream ifs; ifs.open(filename); PackedSfenValue p; + data_size = 0; + filtered_size = 0; p.gamePly = 1; // apery形式では含まれない。一応初期化するべし + bool ignore_flag = false; + while (std::getline(ifs, line)) { std::stringstream ss(line); std::string token; std::string value; ss >> token; - if (token == "sfen") { - StateInfo si; - tpos.set(line.substr(5), false, &si, Threads.main()); - tpos.sfen_pack(p.sfen); + if (token == "fen") { + states = StateListPtr(new std::deque(1)); // Drop old and create a new one + tpos.set(line.substr(4), false, &states->back(), Threads.main()); + tpos.sfen_pack(p.sfen); } else if (token == "move") { ss >> value; @@ -2451,23 +2458,37 @@ void convert_bin(const vector& filenames , const string& output_file_nam else if (token == "ply") { int temp; ss >> temp; + if(temp < ply_minimum || temp > ply_maximum){ + ignore_flag = true; + } p.gamePly = uint16_t(temp); // 此処のキャストいらない? + if (interpolate_eval != 0){ + p.score = min(3000, interpolate_eval * temp); + } } else if (token == "result") { int temp; ss >> temp; p.game_result = int8_t(temp); // 此処のキャストいらない? + if (interpolate_eval){ + p.score = p.score * p.game_result; + } } else if (token == "e") { + if(!ignore_flag){ fs.write((char*)&p, sizeof(PackedSfenValue)); + data_size+=1; // debug - /* - std::cout<> eta3; else if (option == "eta1_epoch") is >> eta1_epoch; else if (option == "eta2_epoch") is >> eta2_epoch; - + else if (option == "use_draw_in_training") is >> use_draw_in_training; + else if (option == "use_draw_in_validation") is >> use_draw_in_validation; + else if (option == "use_hash_in_training") is >> use_hash_in_training; // 割引率 else if (option == "discount_rate") is >> discount_rate; @@ -2672,7 +2698,7 @@ void learn(Position&, istringstream& is) else if (option == "eval_limit") is >> eval_limit; else if (option == "save_only_once") save_only_once = true; else if (option == "no_shuffle") no_shuffle = true; - + #if defined(EVAL_NNUE) else if (option == "nn_batch_size") is >> nn_batch_size; else if (option == "newbob_decay") is >> newbob_decay; @@ -2687,6 +2713,7 @@ void learn(Position&, istringstream& is) // 雑巾のconvert関連 else if (option == "convert_plain") use_convert_plain = true; else if (option == "convert_bin") use_convert_bin = true; + else if (option == "interpolate_eval") is >> interpolate_eval; // さもなくば、それはファイル名である。 else filenames.push_back(option); @@ -2796,7 +2823,7 @@ void learn(Position&, istringstream& is) { is_ready(true); cout << "convert_bin.." << endl; - convert_bin(filenames,output_file_name); + convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval); return; }