1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-04-30 08:43:09 +00:00

add convert_bin and option for draw positions

This commit is contained in:
rqs 2020-06-27 13:06:05 +09:00 committed by nodchip
parent 2af46deede
commit 0761d9504e
2 changed files with 53 additions and 26 deletions

View file

@ -165,8 +165,8 @@ typedef float LearnFloatType;
// 引き分けに至ったとき、それを教師局面として書き出す
// これをするほうが良いかどうかは微妙。
// #define LEARN_GENSFEN_USE_DRAW_RESULT
extern bool use_draw_in_training;
extern bool use_hash_in_training;
// ======================
// configure
// ======================
@ -234,4 +234,4 @@ namespace Learner
#endif
#endif // ifndef _LEARN_H_
#endif // ifndef _LEARN_H_

View file

@ -84,6 +84,10 @@
#include <shared_mutex>
#endif
bool use_draw_in_training=false;
bool use_draw_in_validation=false;
bool use_hash_in_training=true;
using namespace std;
//// これは探索部で定義されているものとする。
@ -1248,11 +1252,8 @@ struct SfenReader
{
if (eval_limit < abs(p.score))
continue;
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
if (p.game_result == 0)
if (!use_draw_in_validation && p.game_result == 0)
continue;
#endif
sfen_for_mse.push_back(p);
} else {
break;
@ -1934,10 +1935,10 @@ void LearnerThink::thread_worker(size_t thread_id)
if (eval_limit < abs(ps.score))
goto RetryRead;
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
if (ps.game_result == 0)
if (!use_draw_in_training && ps.game_result == 0)
goto RetryRead;
#endif
// 序盤局面に関する読み飛ばし
if (ps.gamePly < prng.rand(reduction_gameply))
@ -1961,13 +1962,13 @@ void LearnerThink::thread_worker(size_t thread_id)
{
auto key = pos.key();
// rmseの計算用に使っている局面なら除外する。
if (sr.is_for_rmse(key))
if (sr.is_for_rmse(key) && use_hash_in_training)
goto RetryRead;
// 直近で用いた局面も除外する。
auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1));
auto key2 = sr.hash[hash_index];
if (key == key2)
if (key == key2 && use_hash_in_training)
goto RetryRead;
sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。
}
@ -2416,30 +2417,36 @@ void shuffle_files_on_memory(const vector<string>& filenames,const string output
std::cout << "..shuffle_on_memory done." << std::endl;
}
void convert_bin(const vector<string>& filenames , const string& output_file_name)
void convert_bin(const vector<string>& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval)
{
std::fstream fs;
uint64_t data_size=0;
uint64_t filtered_size = 0;
auto th = Threads.main();
auto &tpos = th->rootPos;
// plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する
fs.open(output_file_name, ios::app | ios::binary);
StateListPtr states;
for (auto filename : filenames) {
std::cout << "convert " << filename << " ... ";
std::string line;
ifstream ifs;
ifs.open(filename);
PackedSfenValue p;
data_size = 0;
filtered_size = 0;
p.gamePly = 1; // apery形式では含まれない。一応初期化するべし
bool ignore_flag = false;
while (std::getline(ifs, line)) {
std::stringstream ss(line);
std::string token;
std::string value;
ss >> token;
if (token == "sfen") {
StateInfo si;
tpos.set(line.substr(5), false, &si, Threads.main());
tpos.sfen_pack(p.sfen);
if (token == "fen") {
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
tpos.set(line.substr(4), false, &states->back(), Threads.main());
tpos.sfen_pack(p.sfen);
}
else if (token == "move") {
ss >> value;
@ -2451,23 +2458,37 @@ void convert_bin(const vector<string>& filenames , const string& output_file_nam
else if (token == "ply") {
int temp;
ss >> temp;
if(temp < ply_minimum || temp > ply_maximum){
ignore_flag = true;
}
p.gamePly = uint16_t(temp); // 此処のキャストいらない?
if (interpolate_eval != 0){
p.score = min(3000, interpolate_eval * temp);
}
}
else if (token == "result") {
int temp;
ss >> temp;
p.game_result = int8_t(temp); // 此処のキャストいらない?
if (interpolate_eval){
p.score = p.score * p.game_result;
}
}
else if (token == "e") {
if(!ignore_flag){
fs.write((char*)&p, sizeof(PackedSfenValue));
data_size+=1;
// debug
/*
std::cout<<tpos<<std::endl;
std::cout<<to_usi_string(Move(p.move))<<","<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
*/
// std::cout<<tpos<<std::endl;
// std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
}else{
ignore_flag = false;
filtered_size += 1;
}
}
}
std::cout << "done" << std::endl;
std::cout << "done" << data_size <<" parsed " << filtered_size<<" is filtered"<< std::endl;
ifs.close();
}
std::cout << "all done" << std::endl;
@ -2557,6 +2578,9 @@ void learn(Position&, istringstream& is)
bool use_convert_plain = false;
// plain形式の教師をやねうら王のbinに変換する
bool use_convert_bin = false;
int ply_minimum = 0;
int ply_maximum = 114514;
bool interpolate_eval = 0;
// それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin")
string output_file_name = "shuffled_sfen.bin";
@ -2636,7 +2660,9 @@ void learn(Position&, istringstream& is)
else if (option == "eta3") is >> eta3;
else if (option == "eta1_epoch") is >> eta1_epoch;
else if (option == "eta2_epoch") is >> eta2_epoch;
else if (option == "use_draw_in_training") is >> use_draw_in_training;
else if (option == "use_draw_in_validation") is >> use_draw_in_validation;
else if (option == "use_hash_in_training") is >> use_hash_in_training;
// 割引率
else if (option == "discount_rate") is >> discount_rate;
@ -2672,7 +2698,7 @@ void learn(Position&, istringstream& is)
else if (option == "eval_limit") is >> eval_limit;
else if (option == "save_only_once") save_only_once = true;
else if (option == "no_shuffle") no_shuffle = true;
#if defined(EVAL_NNUE)
else if (option == "nn_batch_size") is >> nn_batch_size;
else if (option == "newbob_decay") is >> newbob_decay;
@ -2687,6 +2713,7 @@ void learn(Position&, istringstream& is)
// 雑巾のconvert関連
else if (option == "convert_plain") use_convert_plain = true;
else if (option == "convert_bin") use_convert_bin = true;
else if (option == "interpolate_eval") is >> interpolate_eval;
// さもなくば、それはファイル名である。
else
filenames.push_back(option);
@ -2796,7 +2823,7 @@ void learn(Position&, istringstream& is)
{
is_ready(true);
cout << "convert_bin.." << endl;
convert_bin(filenames,output_file_name);
convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval);
return;
}