mirror of
https://github.com/sockspls/badfish
synced 2025-04-30 08:43:09 +00:00
add convert_bin and option for draw positions
This commit is contained in:
parent
2af46deede
commit
0761d9504e
2 changed files with 53 additions and 26 deletions
|
@ -165,8 +165,8 @@ typedef float LearnFloatType;
|
|||
// 引き分けに至ったとき、それを教師局面として書き出す
|
||||
// これをするほうが良いかどうかは微妙。
|
||||
// #define LEARN_GENSFEN_USE_DRAW_RESULT
|
||||
|
||||
|
||||
// Runtime flag: when false, the training worker skips positions whose
// game_result is 0. Defined with initial value false and settable through the
// "use_draw_in_training" option of the learn command.
extern bool use_draw_in_training;
|
||||
// Runtime flag: when true, the training worker rejects positions whose key is
// reserved for rmse calculation or matches the entry currently stored in the
// reader's hash table. Defined with initial value true and settable through the
// "use_hash_in_training" option of the learn command.
extern bool use_hash_in_training;
|
||||
// ======================
|
||||
// configure
|
||||
// ======================
|
||||
|
@ -234,4 +234,4 @@ namespace Learner
|
|||
|
||||
#endif
|
||||
|
||||
#endif // ifndef _LEARN_H_
|
||||
#endif // ifndef _LEARN_H_
|
||||
|
|
|
@ -84,6 +84,10 @@
|
|||
#include <shared_mutex>
|
||||
#endif
|
||||
|
||||
// Training-time switch, off by default. While it is false,
// LearnerThink::thread_worker retries the read for any position whose
// game_result is 0. Set by the "use_draw_in_training" learn option.
bool use_draw_in_training{false};
|
||||
// Validation-time switch, off by default. While it is false, SfenReader does
// not add positions whose game_result is 0 to sfen_for_mse.
// Set by the "use_draw_in_validation" learn option.
bool use_draw_in_validation{false};
|
||||
// Duplicate-filter switch, on by default. While it is true,
// LearnerThink::thread_worker retries the read when sr.is_for_rmse(key) holds
// or when the key equals the one already stored in sr.hash; when false both
// rejections are bypassed. Set by the "use_hash_in_training" learn option.
bool use_hash_in_training{true};
|
||||
|
||||
using namespace std;
|
||||
|
||||
//// これは探索部で定義されているものとする。
|
||||
|
@ -1248,11 +1252,8 @@ struct SfenReader
|
|||
{
|
||||
if (eval_limit < abs(p.score))
|
||||
continue;
|
||||
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
|
||||
if (p.game_result == 0)
|
||||
if (!use_draw_in_validation && p.game_result == 0)
|
||||
continue;
|
||||
#endif
|
||||
|
||||
sfen_for_mse.push_back(p);
|
||||
} else {
|
||||
break;
|
||||
|
@ -1934,10 +1935,10 @@ void LearnerThink::thread_worker(size_t thread_id)
|
|||
if (eval_limit < abs(ps.score))
|
||||
goto RetryRead;
|
||||
|
||||
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
|
||||
if (ps.game_result == 0)
|
||||
|
||||
if (!use_draw_in_training && ps.game_result == 0)
|
||||
goto RetryRead;
|
||||
#endif
|
||||
|
||||
|
||||
// 序盤局面に関する読み飛ばし
|
||||
if (ps.gamePly < prng.rand(reduction_gameply))
|
||||
|
@ -1961,13 +1962,13 @@ void LearnerThink::thread_worker(size_t thread_id)
|
|||
{
|
||||
auto key = pos.key();
|
||||
// rmseの計算用に使っている局面なら除外する。
|
||||
if (sr.is_for_rmse(key))
|
||||
if (sr.is_for_rmse(key) && use_hash_in_training)
|
||||
goto RetryRead;
|
||||
|
||||
// 直近で用いた局面も除外する。
|
||||
auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1));
|
||||
auto key2 = sr.hash[hash_index];
|
||||
if (key == key2)
|
||||
if (key == key2 && use_hash_in_training)
|
||||
goto RetryRead;
|
||||
sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。
|
||||
}
|
||||
|
@ -2416,30 +2417,36 @@ void shuffle_files_on_memory(const vector<string>& filenames,const string output
|
|||
std::cout << "..shuffle_on_memory done." << std::endl;
|
||||
}
|
||||
|
||||
void convert_bin(const vector<string>& filenames , const string& output_file_name)
|
||||
void convert_bin(const vector<string>& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval)
|
||||
{
|
||||
std::fstream fs;
|
||||
uint64_t data_size=0;
|
||||
uint64_t filtered_size = 0;
|
||||
auto th = Threads.main();
|
||||
auto &tpos = th->rootPos;
|
||||
// plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する
|
||||
fs.open(output_file_name, ios::app | ios::binary);
|
||||
|
||||
StateListPtr states;
|
||||
for (auto filename : filenames) {
|
||||
std::cout << "convert " << filename << " ... ";
|
||||
std::string line;
|
||||
ifstream ifs;
|
||||
ifs.open(filename);
|
||||
PackedSfenValue p;
|
||||
data_size = 0;
|
||||
filtered_size = 0;
|
||||
p.gamePly = 1; // apery形式では含まれない。一応初期化するべし
|
||||
bool ignore_flag = false;
|
||||
|
||||
while (std::getline(ifs, line)) {
|
||||
std::stringstream ss(line);
|
||||
std::string token;
|
||||
std::string value;
|
||||
ss >> token;
|
||||
if (token == "sfen") {
|
||||
StateInfo si;
|
||||
tpos.set(line.substr(5), false, &si, Threads.main());
|
||||
tpos.sfen_pack(p.sfen);
|
||||
if (token == "fen") {
|
||||
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
|
||||
tpos.set(line.substr(4), false, &states->back(), Threads.main());
|
||||
tpos.sfen_pack(p.sfen);
|
||||
}
|
||||
else if (token == "move") {
|
||||
ss >> value;
|
||||
|
@ -2451,23 +2458,37 @@ void convert_bin(const vector<string>& filenames , const string& output_file_nam
|
|||
else if (token == "ply") {
|
||||
int temp;
|
||||
ss >> temp;
|
||||
if(temp < ply_minimum || temp > ply_maximum){
|
||||
ignore_flag = true;
|
||||
}
|
||||
p.gamePly = uint16_t(temp); // 此処のキャストいらない?
|
||||
if (interpolate_eval != 0){
|
||||
p.score = min(3000, interpolate_eval * temp);
|
||||
}
|
||||
}
|
||||
else if (token == "result") {
|
||||
int temp;
|
||||
ss >> temp;
|
||||
p.game_result = int8_t(temp); // 此処のキャストいらない?
|
||||
if (interpolate_eval){
|
||||
p.score = p.score * p.game_result;
|
||||
}
|
||||
}
|
||||
else if (token == "e") {
|
||||
if(!ignore_flag){
|
||||
fs.write((char*)&p, sizeof(PackedSfenValue));
|
||||
data_size+=1;
|
||||
// debug
|
||||
/*
|
||||
std::cout<<tpos<<std::endl;
|
||||
std::cout<<to_usi_string(Move(p.move))<<","<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
|
||||
*/
|
||||
// std::cout<<tpos<<std::endl;
|
||||
// std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
|
||||
}else{
|
||||
ignore_flag = false;
|
||||
filtered_size += 1;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
std::cout << "done" << std::endl;
|
||||
std::cout << "done" << data_size <<" parsed " << filtered_size<<" is filtered"<< std::endl;
|
||||
ifs.close();
|
||||
}
|
||||
std::cout << "all done" << std::endl;
|
||||
|
@ -2557,6 +2578,9 @@ void learn(Position&, istringstream& is)
|
|||
bool use_convert_plain = false;
|
||||
// plain形式の教師をやねうら王のbinに変換する
|
||||
bool use_convert_bin = false;
|
||||
int ply_minimum = 0;
|
||||
int ply_maximum = 114514;
|
||||
// Multiplier read from the "interpolate_eval" option and forwarded to
// convert_bin(), which takes it as `const int` and multiplies it by the ply
// (0 disables interpolation). It must be an int: with a bool, stream
// extraction accepts only 0/1 and sets failbit for any other number, so the
// multiplier could never exceed 1.
int interpolate_eval = 0;
|
||||
// それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin")
|
||||
string output_file_name = "shuffled_sfen.bin";
|
||||
|
||||
|
@ -2636,7 +2660,9 @@ void learn(Position&, istringstream& is)
|
|||
else if (option == "eta3") is >> eta3;
|
||||
else if (option == "eta1_epoch") is >> eta1_epoch;
|
||||
else if (option == "eta2_epoch") is >> eta2_epoch;
|
||||
|
||||
else if (option == "use_draw_in_training") is >> use_draw_in_training;
|
||||
else if (option == "use_draw_in_validation") is >> use_draw_in_validation;
|
||||
else if (option == "use_hash_in_training") is >> use_hash_in_training;
|
||||
// 割引率
|
||||
else if (option == "discount_rate") is >> discount_rate;
|
||||
|
||||
|
@ -2672,7 +2698,7 @@ void learn(Position&, istringstream& is)
|
|||
else if (option == "eval_limit") is >> eval_limit;
|
||||
else if (option == "save_only_once") save_only_once = true;
|
||||
else if (option == "no_shuffle") no_shuffle = true;
|
||||
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
else if (option == "nn_batch_size") is >> nn_batch_size;
|
||||
else if (option == "newbob_decay") is >> newbob_decay;
|
||||
|
@ -2687,6 +2713,7 @@ void learn(Position&, istringstream& is)
|
|||
// 雑巾のconvert関連
|
||||
else if (option == "convert_plain") use_convert_plain = true;
|
||||
else if (option == "convert_bin") use_convert_bin = true;
|
||||
else if (option == "interpolate_eval") is >> interpolate_eval;
|
||||
// さもなくば、それはファイル名である。
|
||||
else
|
||||
filenames.push_back(option);
|
||||
|
@ -2796,7 +2823,7 @@ void learn(Position&, istringstream& is)
|
|||
{
|
||||
is_ready(true);
|
||||
cout << "convert_bin.." << endl;
|
||||
convert_bin(filenames,output_file_name);
|
||||
convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval);
|
||||
return;
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue