mirror of
https://github.com/sockspls/badfish
synced 2025-05-01 01:03:09 +00:00
add convert_bin and option for draw positions
This commit is contained in:
parent
2af46deede
commit
0761d9504e
2 changed files with 53 additions and 26 deletions
|
@ -165,8 +165,8 @@ typedef float LearnFloatType;
|
||||||
// 引き分けに至ったとき、それを教師局面として書き出す
|
// 引き分けに至ったとき、それを教師局面として書き出す
|
||||||
// これをするほうが良いかどうかは微妙。
|
// これをするほうが良いかどうかは微妙。
|
||||||
// #define LEARN_GENSFEN_USE_DRAW_RESULT
|
// #define LEARN_GENSFEN_USE_DRAW_RESULT
|
||||||
|
extern bool use_draw_in_training;
|
||||||
|
extern bool use_hash_in_training;
|
||||||
// ======================
|
// ======================
|
||||||
// configure
|
// configure
|
||||||
// ======================
|
// ======================
|
||||||
|
|
|
@ -84,6 +84,10 @@
|
||||||
#include <shared_mutex>
|
#include <shared_mutex>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
bool use_draw_in_training=false;
|
||||||
|
bool use_draw_in_validation=false;
|
||||||
|
bool use_hash_in_training=true;
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
//// これは探索部で定義されているものとする。
|
//// これは探索部で定義されているものとする。
|
||||||
|
@ -1248,11 +1252,8 @@ struct SfenReader
|
||||||
{
|
{
|
||||||
if (eval_limit < abs(p.score))
|
if (eval_limit < abs(p.score))
|
||||||
continue;
|
continue;
|
||||||
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
|
if (!use_draw_in_validation && p.game_result == 0)
|
||||||
if (p.game_result == 0)
|
|
||||||
continue;
|
continue;
|
||||||
#endif
|
|
||||||
|
|
||||||
sfen_for_mse.push_back(p);
|
sfen_for_mse.push_back(p);
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
@ -1934,10 +1935,10 @@ void LearnerThink::thread_worker(size_t thread_id)
|
||||||
if (eval_limit < abs(ps.score))
|
if (eval_limit < abs(ps.score))
|
||||||
goto RetryRead;
|
goto RetryRead;
|
||||||
|
|
||||||
#if !defined (LEARN_GENSFEN_USE_DRAW_RESULT)
|
|
||||||
if (ps.game_result == 0)
|
if (!use_draw_in_training && ps.game_result == 0)
|
||||||
goto RetryRead;
|
goto RetryRead;
|
||||||
#endif
|
|
||||||
|
|
||||||
// 序盤局面に関する読み飛ばし
|
// 序盤局面に関する読み飛ばし
|
||||||
if (ps.gamePly < prng.rand(reduction_gameply))
|
if (ps.gamePly < prng.rand(reduction_gameply))
|
||||||
|
@ -1961,13 +1962,13 @@ void LearnerThink::thread_worker(size_t thread_id)
|
||||||
{
|
{
|
||||||
auto key = pos.key();
|
auto key = pos.key();
|
||||||
// rmseの計算用に使っている局面なら除外する。
|
// rmseの計算用に使っている局面なら除外する。
|
||||||
if (sr.is_for_rmse(key))
|
if (sr.is_for_rmse(key) && use_hash_in_training)
|
||||||
goto RetryRead;
|
goto RetryRead;
|
||||||
|
|
||||||
// 直近で用いた局面も除外する。
|
// 直近で用いた局面も除外する。
|
||||||
auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1));
|
auto hash_index = size_t(key & (sr.READ_SFEN_HASH_SIZE - 1));
|
||||||
auto key2 = sr.hash[hash_index];
|
auto key2 = sr.hash[hash_index];
|
||||||
if (key == key2)
|
if (key == key2 && use_hash_in_training)
|
||||||
goto RetryRead;
|
goto RetryRead;
|
||||||
sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。
|
sr.hash[hash_index] = key; // 今回のkeyに入れ替えておく。
|
||||||
}
|
}
|
||||||
|
@ -2416,30 +2417,36 @@ void shuffle_files_on_memory(const vector<string>& filenames,const string output
|
||||||
std::cout << "..shuffle_on_memory done." << std::endl;
|
std::cout << "..shuffle_on_memory done." << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void convert_bin(const vector<string>& filenames , const string& output_file_name)
|
void convert_bin(const vector<string>& filenames, const string& output_file_name, const int ply_minimum, const int ply_maximum, const int interpolate_eval)
|
||||||
{
|
{
|
||||||
std::fstream fs;
|
std::fstream fs;
|
||||||
|
uint64_t data_size=0;
|
||||||
|
uint64_t filtered_size = 0;
|
||||||
auto th = Threads.main();
|
auto th = Threads.main();
|
||||||
auto &tpos = th->rootPos;
|
auto &tpos = th->rootPos;
|
||||||
// plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する
|
// plain形式の雑巾をやねうら王用のpackedsfenvalueに変換する
|
||||||
fs.open(output_file_name, ios::app | ios::binary);
|
fs.open(output_file_name, ios::app | ios::binary);
|
||||||
|
StateListPtr states;
|
||||||
for (auto filename : filenames) {
|
for (auto filename : filenames) {
|
||||||
std::cout << "convert " << filename << " ... ";
|
std::cout << "convert " << filename << " ... ";
|
||||||
std::string line;
|
std::string line;
|
||||||
ifstream ifs;
|
ifstream ifs;
|
||||||
ifs.open(filename);
|
ifs.open(filename);
|
||||||
PackedSfenValue p;
|
PackedSfenValue p;
|
||||||
|
data_size = 0;
|
||||||
|
filtered_size = 0;
|
||||||
p.gamePly = 1; // apery形式では含まれない。一応初期化するべし
|
p.gamePly = 1; // apery形式では含まれない。一応初期化するべし
|
||||||
|
bool ignore_flag = false;
|
||||||
|
|
||||||
while (std::getline(ifs, line)) {
|
while (std::getline(ifs, line)) {
|
||||||
std::stringstream ss(line);
|
std::stringstream ss(line);
|
||||||
std::string token;
|
std::string token;
|
||||||
std::string value;
|
std::string value;
|
||||||
ss >> token;
|
ss >> token;
|
||||||
if (token == "sfen") {
|
if (token == "fen") {
|
||||||
StateInfo si;
|
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
|
||||||
tpos.set(line.substr(5), false, &si, Threads.main());
|
tpos.set(line.substr(4), false, &states->back(), Threads.main());
|
||||||
tpos.sfen_pack(p.sfen);
|
tpos.sfen_pack(p.sfen);
|
||||||
}
|
}
|
||||||
else if (token == "move") {
|
else if (token == "move") {
|
||||||
ss >> value;
|
ss >> value;
|
||||||
|
@ -2451,23 +2458,37 @@ void convert_bin(const vector<string>& filenames , const string& output_file_nam
|
||||||
else if (token == "ply") {
|
else if (token == "ply") {
|
||||||
int temp;
|
int temp;
|
||||||
ss >> temp;
|
ss >> temp;
|
||||||
|
if(temp < ply_minimum || temp > ply_maximum){
|
||||||
|
ignore_flag = true;
|
||||||
|
}
|
||||||
p.gamePly = uint16_t(temp); // 此処のキャストいらない?
|
p.gamePly = uint16_t(temp); // 此処のキャストいらない?
|
||||||
|
if (interpolate_eval != 0){
|
||||||
|
p.score = min(3000, interpolate_eval * temp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (token == "result") {
|
else if (token == "result") {
|
||||||
int temp;
|
int temp;
|
||||||
ss >> temp;
|
ss >> temp;
|
||||||
p.game_result = int8_t(temp); // 此処のキャストいらない?
|
p.game_result = int8_t(temp); // 此処のキャストいらない?
|
||||||
|
if (interpolate_eval){
|
||||||
|
p.score = p.score * p.game_result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (token == "e") {
|
else if (token == "e") {
|
||||||
|
if(!ignore_flag){
|
||||||
fs.write((char*)&p, sizeof(PackedSfenValue));
|
fs.write((char*)&p, sizeof(PackedSfenValue));
|
||||||
|
data_size+=1;
|
||||||
// debug
|
// debug
|
||||||
/*
|
// std::cout<<tpos<<std::endl;
|
||||||
std::cout<<tpos<<std::endl;
|
// std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
|
||||||
std::cout<<to_usi_string(Move(p.move))<<","<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
|
}else{
|
||||||
*/
|
ignore_flag = false;
|
||||||
|
filtered_size += 1;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << "done" << std::endl;
|
std::cout << "done" << data_size <<" parsed " << filtered_size<<" is filtered"<< std::endl;
|
||||||
ifs.close();
|
ifs.close();
|
||||||
}
|
}
|
||||||
std::cout << "all done" << std::endl;
|
std::cout << "all done" << std::endl;
|
||||||
|
@ -2557,6 +2578,9 @@ void learn(Position&, istringstream& is)
|
||||||
bool use_convert_plain = false;
|
bool use_convert_plain = false;
|
||||||
// plain形式の教師をやねうら王のbinに変換する
|
// plain形式の教師をやねうら王のbinに変換する
|
||||||
bool use_convert_bin = false;
|
bool use_convert_bin = false;
|
||||||
|
int ply_minimum = 0;
|
||||||
|
int ply_maximum = 114514;
|
||||||
|
bool interpolate_eval = 0;
|
||||||
// それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin")
|
// それらのときに書き出すファイル名(デフォルトでは"shuffled_sfen.bin")
|
||||||
string output_file_name = "shuffled_sfen.bin";
|
string output_file_name = "shuffled_sfen.bin";
|
||||||
|
|
||||||
|
@ -2636,7 +2660,9 @@ void learn(Position&, istringstream& is)
|
||||||
else if (option == "eta3") is >> eta3;
|
else if (option == "eta3") is >> eta3;
|
||||||
else if (option == "eta1_epoch") is >> eta1_epoch;
|
else if (option == "eta1_epoch") is >> eta1_epoch;
|
||||||
else if (option == "eta2_epoch") is >> eta2_epoch;
|
else if (option == "eta2_epoch") is >> eta2_epoch;
|
||||||
|
else if (option == "use_draw_in_training") is >> use_draw_in_training;
|
||||||
|
else if (option == "use_draw_in_validation") is >> use_draw_in_validation;
|
||||||
|
else if (option == "use_hash_in_training") is >> use_hash_in_training;
|
||||||
// 割引率
|
// 割引率
|
||||||
else if (option == "discount_rate") is >> discount_rate;
|
else if (option == "discount_rate") is >> discount_rate;
|
||||||
|
|
||||||
|
@ -2687,6 +2713,7 @@ void learn(Position&, istringstream& is)
|
||||||
// 雑巾のconvert関連
|
// 雑巾のconvert関連
|
||||||
else if (option == "convert_plain") use_convert_plain = true;
|
else if (option == "convert_plain") use_convert_plain = true;
|
||||||
else if (option == "convert_bin") use_convert_bin = true;
|
else if (option == "convert_bin") use_convert_bin = true;
|
||||||
|
else if (option == "interpolate_eval") is >> interpolate_eval;
|
||||||
// さもなくば、それはファイル名である。
|
// さもなくば、それはファイル名である。
|
||||||
else
|
else
|
||||||
filenames.push_back(option);
|
filenames.push_back(option);
|
||||||
|
@ -2796,7 +2823,7 @@ void learn(Position&, istringstream& is)
|
||||||
{
|
{
|
||||||
is_ready(true);
|
is_ready(true);
|
||||||
cout << "convert_bin.." << endl;
|
cout << "convert_bin.." << endl;
|
||||||
convert_bin(filenames,output_file_name);
|
convert_bin(filenames,output_file_name, ply_minimum, ply_maximum, interpolate_eval);
|
||||||
return;
|
return;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue