mirror of
https://github.com/sockspls/badfish
synced 2025-05-01 01:03:09 +00:00
Merged the training data generator and the machine learning logic from YaneuraOu.
This commit is contained in:
parent
87445881ec
commit
bcd6985871
37 changed files with 6306 additions and 139 deletions
83
src/eval/evaluate_common.h
Normal file
83
src/eval/evaluate_common.h
Normal file
|
@ -0,0 +1,83 @@
|
|||
#ifndef _EVALUATE_COMMON_H_
|
||||
#define _EVALUATE_COMMON_H_
|
||||
|
||||
// いまどきの手番つき評価関数(EVAL_KPPTとEVAL_KPP_KKPT)の共用header的なもの。
|
||||
|
||||
#if defined (EVAL_KPPT) || defined(EVAL_KPP_KKPT) || defined(EVAL_NNUE)
|
||||
#include <functional>
|
||||
|
||||
// KKファイル名
|
||||
#define KK_BIN "KK_synthesized.bin"
|
||||
|
||||
// KKPファイル名
|
||||
#define KKP_BIN "KKP_synthesized.bin"
|
||||
|
||||
// KPPファイル名
|
||||
#define KPP_BIN "KPP_synthesized.bin"
|
||||
|
||||
namespace Eval
|
||||
{
|
||||
|
||||
#if defined(USE_EVAL_HASH)
|
||||
// prefetchする関数
|
||||
void prefetch_evalhash(const Key key);
|
||||
#endif
|
||||
|
||||
// 評価関数のそれぞれのパラメーターに対して関数fを適用してくれるoperator。
|
||||
// パラメーターの分析などに用いる。
|
||||
// typeは調査対象を表す。
|
||||
// type = -1 : KK,KKP,KPPすべて
|
||||
// type = 0 : KK のみ
|
||||
// type = 1 : KKPのみ
|
||||
// type = 2 : KPPのみ
|
||||
void foreach_eval_param(std::function<void(int32_t, int32_t)>f, int type = -1);
|
||||
|
||||
// --------------------------
|
||||
// 学習用
|
||||
// --------------------------
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
// 学習のときの勾配配列の初期化
|
||||
// 学習率を引数に渡しておく。0.0なら、defaultの値を採用する。
|
||||
// update_weights()のepochが、eta_epochまでetaから徐々にeta2に変化する。
|
||||
// eta2_epoch以降は、eta2から徐々にeta3に変化する。
|
||||
void init_grad(double eta1, uint64_t eta_epoch, double eta2, uint64_t eta2_epoch, double eta3);
|
||||
|
||||
// 現在の局面で出現している特徴すべてに対して、勾配の差分値を勾配配列に加算する。
|
||||
// freeze[0] : kkは学習させないフラグ
|
||||
// freeze[1] : kkpは学習させないフラグ
|
||||
// freeze[2] : kppは学習させないフラグ
|
||||
// freeze[3] : kpppは学習させないフラグ
|
||||
void add_grad(Position& pos, Color rootColor, double delt_grad, const std::array<bool, 4>& freeze);
|
||||
|
||||
// 現在の勾配をもとにSGDかAdaGradか何かする。
|
||||
// epoch : 世代カウンター(0から始まる)
|
||||
// freeze[0] : kkは学習させないフラグ
|
||||
// freeze[1] : kkpは学習させないフラグ
|
||||
// freeze[2] : kppは学習させないフラグ
|
||||
// freeze[3] : kpppは学習させないフラグ
|
||||
void update_weights(uint64_t epoch, const std::array<bool,4>& freeze);
|
||||
|
||||
// 評価関数パラメーターをファイルに保存する。
|
||||
// ファイルの末尾につける拡張子を指定できる。
|
||||
void save_eval(std::string suffix);
|
||||
|
||||
// 現在のetaを取得する。
|
||||
double get_eta();
|
||||
|
||||
// -- 学習に関連したコマンド
|
||||
|
||||
// KKを正規化する関数。元の評価関数と完全に等価にはならないので注意。
|
||||
// kkp,kppの値をなるべくゼロに近づけることで、学習中に出現しなかった特徴因子の値(ゼロになっている)が
|
||||
// 妥当であることを保証しようという考え。
|
||||
void regularize_kk();
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
#endif // _EVALUATE_KPPT_COMMON_H_
|
186
src/eval/evaluate_mir_inv_tools.cpp
Normal file
186
src/eval/evaluate_mir_inv_tools.cpp
Normal file
|
@ -0,0 +1,186 @@
|
|||
#include "evaluate_mir_inv_tools.h"
|
||||
|
||||
namespace Eval
|
||||
{
|
||||
|
||||
// --- tables
|
||||
|
||||
// あるBonaPieceを相手側から見たときの値
|
||||
// BONA_PIECE_INITが-1なので符号型で持つ必要がある。
|
||||
// KPPTを拡張しても当面、BonaPieceが2^15を超えることはないのでint16_tで良しとする。
|
||||
int16_t inv_piece_[Eval::fe_end];
|
||||
|
||||
// 盤面上のあるBonaPieceをミラーした位置にあるものを返す。
|
||||
int16_t mir_piece_[Eval::fe_end];
|
||||
|
||||
|
||||
// --- methods
|
||||
|
||||
// あるBonaPieceを相手側から見たときの値を返す
|
||||
Eval::BonaPiece inv_piece(Eval::BonaPiece p) { return (Eval::BonaPiece)inv_piece_[p]; }
|
||||
|
||||
// 盤面上のあるBonaPieceをミラーした位置にあるものを返す。
|
||||
Eval::BonaPiece mir_piece(Eval::BonaPiece p) { return (Eval::BonaPiece)mir_piece_[p]; }
|
||||
|
||||
std::function<void()> mir_piece_init_function;
|
||||
|
||||
void init_mir_inv_tables()
|
||||
{
|
||||
// mirrorとinverseのテーブルの初期化。
|
||||
|
||||
// 初期化は1回に限る。
|
||||
static bool first = true;
|
||||
if (!first) return;
|
||||
first = false;
|
||||
|
||||
// fとeとの交換
|
||||
int t[] = {
|
||||
f_pawn , e_pawn ,
|
||||
f_knight , e_knight ,
|
||||
f_bishop , e_bishop ,
|
||||
f_rook , e_rook ,
|
||||
f_queen , e_queen ,
|
||||
};
|
||||
|
||||
// 未初期化の値を突っ込んでおく。
|
||||
for (BonaPiece p = BONA_PIECE_ZERO; p < fe_end; ++p)
|
||||
{
|
||||
inv_piece_[p] = BONA_PIECE_NOT_INIT;
|
||||
|
||||
// mirrorは手駒に対しては機能しない。元の値を返すだけ。
|
||||
mir_piece_[p] = (p < f_pawn) ? p : BONA_PIECE_NOT_INIT;
|
||||
}
|
||||
|
||||
for (BonaPiece p = BONA_PIECE_ZERO; p < fe_end; ++p)
|
||||
{
|
||||
for (int i = 0; i < 32 /* t.size() */; i += 2)
|
||||
{
|
||||
if (t[i] <= p && p < t[i + 1])
|
||||
{
|
||||
Square sq = (Square)(p - t[i]);
|
||||
|
||||
// 見つかった!!
|
||||
BonaPiece q = (p < fe_hand_end) ? BonaPiece(sq + t[i + 1]) : (BonaPiece)(Inv(sq) + t[i + 1]);
|
||||
inv_piece_[p] = q;
|
||||
inv_piece_[q] = p;
|
||||
|
||||
/*
|
||||
ちょっとトリッキーだが、pに関して盤上の駒は
|
||||
p >= fe_hand_end
|
||||
のとき。
|
||||
|
||||
このpに対して、nを整数として(上のコードのiは偶数しかとらない)、
|
||||
a) t[2n + 0] <= p < t[2n + 1] のときは先手の駒
|
||||
b) t[2n + 1] <= p < t[2n + 2] のときは後手の駒
|
||||
である。
|
||||
|
||||
ゆえに、a)の範囲にあるpをq = Inv(p-t[2n+0]) + t[2n+1] とすると180度回転させた升にある後手の駒となる。
|
||||
そこでpとqをswapさせてinv_piece[ ]を初期化してある。
|
||||
*/
|
||||
|
||||
// 手駒に関してはmirrorなど存在しない。
|
||||
if (p < fe_hand_end)
|
||||
continue;
|
||||
|
||||
BonaPiece r1 = (BonaPiece)(Mir(sq) + t[i]);
|
||||
mir_piece_[p] = r1;
|
||||
mir_piece_[r1] = p;
|
||||
|
||||
BonaPiece p2 = (BonaPiece)(sq + t[i + 1]);
|
||||
BonaPiece r2 = (BonaPiece)(Mir(sq) + t[i + 1]);
|
||||
mir_piece_[p2] = r2;
|
||||
mir_piece_[r2] = p2;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mir_piece_init_function)
|
||||
mir_piece_init_function();
|
||||
|
||||
for (BonaPiece p = BONA_PIECE_ZERO; p < fe_end; ++p)
|
||||
{
|
||||
// 未初期化のままになっている。上のテーブルの初期化コードがおかしい。
|
||||
assert(mir_piece_[p] != BONA_PIECE_NOT_INIT && mir_piece_[p] < fe_end);
|
||||
assert(inv_piece_[p] != BONA_PIECE_NOT_INIT && inv_piece_[p] < fe_end);
|
||||
|
||||
// mirとinvは、2回適用したら元の座標に戻る。
|
||||
assert(mir_piece_[mir_piece_[p]] == p);
|
||||
assert(inv_piece_[inv_piece_[p]] == p);
|
||||
|
||||
// mir->inv->mir->invは元の場所でなければならない。
|
||||
assert(p == inv_piece(mir_piece(inv_piece(mir_piece(p)))));
|
||||
|
||||
// inv->mir->inv->mirは元の場所でなければならない。
|
||||
assert(p == mir_piece(inv_piece(mir_piece(inv_piece(p)))));
|
||||
}
|
||||
|
||||
#if 0
|
||||
// 評価関数のミラーをしても大丈夫であるかの事前検証
|
||||
// 値を書き込んだときにassertionがあるので、ミラーしてダメである場合、
|
||||
// そのassertに引っかかるはず。
|
||||
|
||||
// AperyのWCSC26の評価関数、kppのp1==0とかp1==20(後手の0枚目の歩)とかの
|
||||
// ところにゴミが入っていて、これを回避しないとassertに引っかかる。
|
||||
|
||||
std::unordered_set<BonaPiece> s;
|
||||
vector<int> a = {
|
||||
f_hand_pawn - 1,e_hand_pawn - 1,
|
||||
f_hand_lance - 1, e_hand_lance - 1,
|
||||
f_hand_knight - 1, e_hand_knight - 1,
|
||||
f_hand_silver - 1, e_hand_silver - 1,
|
||||
f_hand_gold - 1, e_hand_gold - 1,
|
||||
f_hand_bishop - 1, e_hand_bishop - 1,
|
||||
f_hand_rook - 1, e_hand_rook - 1,
|
||||
};
|
||||
for (auto b : a)
|
||||
s.insert((BonaPiece)b);
|
||||
|
||||
// さらに出現しない升の盤上の歩、香、桂も除外(Aperyはここにもゴミが入っている)
|
||||
for (Rank r = RANK_1; r <= RANK_2; ++r)
|
||||
for (File f = FILE_1; f <= FILE_9; ++f)
|
||||
{
|
||||
if (r == RANK_1)
|
||||
{
|
||||
// 1段目の歩
|
||||
BonaPiece b1 = BonaPiece(f_pawn + (f | r));
|
||||
s.insert(b1);
|
||||
s.insert(inv_piece[b1]);
|
||||
|
||||
// 1段目の香
|
||||
BonaPiece b2 = BonaPiece(f_lance + (f | r));
|
||||
s.insert(b2);
|
||||
s.insert(inv_piece[b2]);
|
||||
}
|
||||
|
||||
// 1,2段目の桂
|
||||
BonaPiece b = BonaPiece(f_knight + (f | r));
|
||||
s.insert(b);
|
||||
s.insert(inv_piece[b]);
|
||||
}
|
||||
|
||||
cout << "\nchecking kpp_write()..";
|
||||
for (auto sq : SQ)
|
||||
{
|
||||
cout << sq << ' ';
|
||||
for (BonaPiece p1 = BONA_PIECE_ZERO; p1 < fe_end; ++p1)
|
||||
for (BonaPiece p2 = BONA_PIECE_ZERO; p2 < fe_end; ++p2)
|
||||
if (!s.count(p1) && !s.count(p2))
|
||||
kpp_write(sq, p1, p2, kpp[sq][p1][p2]);
|
||||
}
|
||||
cout << "\nchecking kkp_write()..";
|
||||
|
||||
for (auto sq1 : SQ)
|
||||
{
|
||||
cout << sq1 << ' ';
|
||||
for (auto sq2 : SQ)
|
||||
for (BonaPiece p1 = BONA_PIECE_ZERO; p1 < fe_end; ++p1)
|
||||
if (!s.count(p1))
|
||||
kkp_write(sq1, sq2, p1, kkp[sq1][sq2][p1]);
|
||||
}
|
||||
cout << "..done!" << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
43
src/eval/evaluate_mir_inv_tools.h
Normal file
43
src/eval/evaluate_mir_inv_tools.h
Normal file
|
@ -0,0 +1,43 @@
|
|||
#ifndef _EVALUATE_MIR_INV_TOOLS_
|
||||
#define _EVALUATE_MIR_INV_TOOLS_
|
||||
|
||||
// BonaPieceのmirror(左右反転)やinverse(盤上の180度回転)させた駒を得るためのツール類。
|
||||
|
||||
#include "../types.h"
|
||||
#include "../evaluate.h"
|
||||
#include <functional>
|
||||
|
||||
namespace Eval
|
||||
{
|
||||
// -------------------------------------------------
|
||||
// tables
|
||||
// -------------------------------------------------
|
||||
|
||||
// --- BonaPieceに対してMirrorとInverseを提供する。
|
||||
|
||||
// これらの配列は、init()かinit_mir_inv_tables();を呼び出すと初期化される。
|
||||
// このテーブルのみを評価関数のほうから使いたいときは、評価関数の初期化のときに
|
||||
// init_mir_inv_tables()を呼び出すと良い。
|
||||
// これらの配列は、以下のKK/KKP/KPPクラスから参照される。
|
||||
|
||||
// あるBonaPieceを相手側から見たときの値を返す
|
||||
extern Eval::BonaPiece inv_piece(Eval::BonaPiece p);
|
||||
|
||||
// 盤面上のあるBonaPieceをミラーした位置にあるものを返す。
|
||||
extern Eval::BonaPiece mir_piece(Eval::BonaPiece p);
|
||||
|
||||
|
||||
// mir_piece/inv_pieceの初期化のときに呼び出されるcallback
|
||||
// fe_endをユーザー側で拡張するときに用いる。
|
||||
// この初期化のときに必要なのでinv_piece_とinv_piece_を公開している。
|
||||
// mir_piece_init_functionが呼び出されたタイミングで、fe_old_endまでは
|
||||
// これらのテーブルの初期化が完了していることが保証されている。
|
||||
extern std::function<void()> mir_piece_init_function;
|
||||
extern int16_t mir_piece_[Eval::fe_end];
|
||||
extern int16_t inv_piece_[Eval::fe_end];
|
||||
|
||||
// この関数を明示的に呼び出すか、init()を呼び出すかしたときに、上のテーブルが初期化される。
|
||||
extern void init_mir_inv_tables();
|
||||
}
|
||||
|
||||
#endif
|
|
@ -224,8 +224,8 @@ EvaluateHashTable g_evalTable;
|
|||
|
||||
// prefetchする関数も用意しておく。
|
||||
void prefetch_evalhash(const Key key) {
|
||||
constexpr auto mask = ~((u64)0x1f);
|
||||
prefetch((void*)((u64)g_evalTable[key] & mask));
|
||||
constexpr auto mask = ~((uint64_t)0x1f);
|
||||
prefetch((void*)((uint64_t)g_evalTable[key] & mask));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -269,7 +269,7 @@ Value compute_eval(const Position& pos) {
|
|||
}
|
||||
|
||||
// 評価関数
|
||||
Value NNUE::evaluate(const Position& pos) {
|
||||
Value evaluate(const Position& pos) {
|
||||
const auto& accumulator = pos.state()->accumulator;
|
||||
if (accumulator.computed_score) {
|
||||
return accumulator.score;
|
||||
|
|
|
@ -55,8 +55,6 @@ bool ReadParameters(std::istream& stream);
|
|||
// 評価関数パラメータを書き込む
|
||||
bool WriteParameters(std::ostream& stream);
|
||||
|
||||
Value evaluate(const Position& pos);
|
||||
|
||||
} // namespace NNUE
|
||||
|
||||
} // namespace Eval
|
||||
|
|
|
@ -9,8 +9,9 @@
|
|||
#include "../../learn/learning_tools.h"
|
||||
|
||||
#include "../../position.h"
|
||||
#include "../../usi.h"
|
||||
#include "../../uci.h"
|
||||
#include "../../misc.h"
|
||||
#include "../../thread_win32_osx.h"
|
||||
|
||||
#include "../evaluate_common.h"
|
||||
|
||||
|
@ -37,7 +38,7 @@ std::vector<Example> examples;
|
|||
Mutex examples_mutex;
|
||||
|
||||
// ミニバッチのサンプル数
|
||||
u64 batch_size;
|
||||
uint64_t batch_size;
|
||||
|
||||
// 乱数生成器
|
||||
std::mt19937 rng;
|
||||
|
@ -57,20 +58,20 @@ double GetGlobalLearningRateScale() {
|
|||
void SendMessages(std::vector<Message> messages) {
|
||||
for (auto& message : messages) {
|
||||
trainer->SendMessage(&message);
|
||||
ASSERT_LV3(message.num_receivers > 0);
|
||||
assert(message.num_receivers > 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// 学習の初期化を行う
|
||||
void InitializeTraining(double eta1, u64 eta1_epoch,
|
||||
double eta2, u64 eta2_epoch, double eta3) {
|
||||
void InitializeTraining(double eta1, uint64_t eta1_epoch,
|
||||
double eta2, uint64_t eta2_epoch, double eta3) {
|
||||
std::cout << "Initializing NN training for "
|
||||
<< GetArchitectureString() << std::endl;
|
||||
|
||||
ASSERT(feature_transformer);
|
||||
ASSERT(network);
|
||||
assert(feature_transformer);
|
||||
assert(network);
|
||||
trainer = Trainer<Network>::Create(network.get(), feature_transformer.get());
|
||||
|
||||
if (Options["SkipLoadingEval"]) {
|
||||
|
@ -82,8 +83,8 @@ void InitializeTraining(double eta1, u64 eta1_epoch,
|
|||
}
|
||||
|
||||
// ミニバッチのサンプル数を設定する
|
||||
void SetBatchSize(u64 size) {
|
||||
ASSERT_LV3(size > 0);
|
||||
void SetBatchSize(uint64_t size) {
|
||||
assert(size > 0);
|
||||
batch_size = size;
|
||||
}
|
||||
|
||||
|
@ -97,7 +98,7 @@ void SetOptions(const std::string& options) {
|
|||
std::vector<Message> messages;
|
||||
for (const auto& option : Split(options, ',')) {
|
||||
const auto fields = Split(option, '=');
|
||||
ASSERT_LV3(fields.size() == 1 || fields.size() == 2);
|
||||
assert(fields.size() == 1 || fields.size() == 2);
|
||||
if (fields.size() == 1) {
|
||||
messages.emplace_back(fields[0]);
|
||||
} else {
|
||||
|
@ -112,7 +113,7 @@ void RestoreParameters(const std::string& dir_name) {
|
|||
const std::string file_name = Path::Combine(dir_name, NNUE::kFileName);
|
||||
std::ifstream stream(file_name, std::ios::binary);
|
||||
bool result = ReadParameters(stream);
|
||||
ASSERT(result);
|
||||
assert(result);
|
||||
|
||||
SendMessages({{"reset"}});
|
||||
}
|
||||
|
@ -136,7 +137,7 @@ void AddExample(Position& pos, Color rootColor,
|
|||
if (pos.side_to_move() != BLACK) {
|
||||
active_indices[0].swap(active_indices[1]);
|
||||
}
|
||||
for (const auto color : COLOR) {
|
||||
for (const auto color : Colors) {
|
||||
std::vector<TrainingFeature> training_features;
|
||||
for (const auto base_index : active_indices[color]) {
|
||||
static_assert(Features::Factorizer<RawFeatures>::GetDimensions() <
|
||||
|
@ -162,8 +163,8 @@ void AddExample(Position& pos, Color rootColor,
|
|||
}
|
||||
|
||||
// 評価関数パラメーターを更新する
|
||||
void UpdateParameters(u64 epoch) {
|
||||
ASSERT_LV3(batch_size > 0);
|
||||
void UpdateParameters(uint64_t epoch) {
|
||||
assert(batch_size > 0);
|
||||
|
||||
EvalLearningTools::Weight::calc_eta(epoch);
|
||||
const auto learning_rate = static_cast<LearnFloatType>(
|
||||
|
@ -215,7 +216,7 @@ void save_eval(std::string dir_name) {
|
|||
const std::string file_name = Path::Combine(eval_dir, NNUE::kFileName);
|
||||
std::ofstream stream(file_name, std::ios::binary);
|
||||
const bool result = NNUE::WriteParameters(stream);
|
||||
ASSERT(result);
|
||||
assert(result);
|
||||
|
||||
std::cout << "save_eval() finished. folder = " << eval_dir << std::endl;
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _EVALUATE_NNUE_LEARNER_H_
|
||||
#define _EVALUATE_NNUE_LEARNER_H_
|
||||
|
||||
#include "../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../learn/learn.h"
|
||||
|
@ -14,11 +12,11 @@ namespace Eval {
|
|||
namespace NNUE {
|
||||
|
||||
// 学習の初期化を行う
|
||||
void InitializeTraining(double eta1, u64 eta1_epoch,
|
||||
double eta2, u64 eta2_epoch, double eta3);
|
||||
void InitializeTraining(double eta1, uint64_t eta1_epoch,
|
||||
double eta2, uint64_t eta2_epoch, double eta3);
|
||||
|
||||
// ミニバッチのサンプル数を設定する
|
||||
void SetBatchSize(u64 size);
|
||||
void SetBatchSize(uint64_t size);
|
||||
|
||||
// 学習率のスケールを設定する
|
||||
void SetGlobalLearningRateScale(double scale);
|
||||
|
@ -34,7 +32,7 @@ void AddExample(Position& pos, Color rootColor,
|
|||
const Learner::PackedSfenValue& psv, double weight);
|
||||
|
||||
// 評価関数パラメータを更新する
|
||||
void UpdateParameters(u64 epoch);
|
||||
void UpdateParameters(uint64_t epoch);
|
||||
|
||||
// 学習に問題が生じていないかチェックする
|
||||
void CheckHealth();
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_LAYERS_SUM_H_
|
||||
#define _NNUE_LAYERS_SUM_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
|
||||
#include "../nnue_common.h"
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_FEATURES_FACTORIZER_H_
|
||||
#define _NNUE_TRAINER_FEATURES_FACTORIZER_H_
|
||||
|
||||
#include "../../../../config.h"
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
|
||||
#include "../../nnue_common.h"
|
||||
|
@ -29,7 +27,7 @@ class Factorizer {
|
|||
// 学習用特徴量のインデックスと学習率のスケールを取得する
|
||||
static void AppendTrainingFeatures(
|
||||
IndexType base_index, std::vector<TrainingFeature>* training_features) {
|
||||
ASSERT_LV5(base_index < FeatureType::kDimensions);
|
||||
assert(base_index < FeatureType::kDimensions);
|
||||
training_features->emplace_back(base_index);
|
||||
}
|
||||
};
|
||||
|
@ -45,8 +43,8 @@ template <typename FeatureType>
|
|||
IndexType AppendBaseFeature(
|
||||
FeatureProperties properties, IndexType base_index,
|
||||
std::vector<TrainingFeature>* training_features) {
|
||||
ASSERT_LV5(properties.dimensions == FeatureType::kDimensions);
|
||||
ASSERT_LV5(base_index < FeatureType::kDimensions);
|
||||
assert(properties.dimensions == FeatureType::kDimensions);
|
||||
assert(base_index < FeatureType::kDimensions);
|
||||
training_features->emplace_back(base_index);
|
||||
return properties.dimensions;
|
||||
}
|
||||
|
@ -59,14 +57,14 @@ IndexType InheritFeaturesIfRequired(
|
|||
if (!properties.active) {
|
||||
return 0;
|
||||
}
|
||||
ASSERT_LV5(properties.dimensions == Factorizer<FeatureType>::GetDimensions());
|
||||
ASSERT_LV5(base_index < FeatureType::kDimensions);
|
||||
assert(properties.dimensions == Factorizer<FeatureType>::GetDimensions());
|
||||
assert(base_index < FeatureType::kDimensions);
|
||||
const auto start = training_features->size();
|
||||
Factorizer<FeatureType>::AppendTrainingFeatures(
|
||||
base_index, training_features);
|
||||
for (auto i = start; i < training_features->size(); ++i) {
|
||||
auto& feature = (*training_features)[i];
|
||||
ASSERT_LV5(feature.GetIndex() < Factorizer<FeatureType>::GetDimensions());
|
||||
assert(feature.GetIndex() < Factorizer<FeatureType>::GetDimensions());
|
||||
feature.ShiftIndex(index_offset);
|
||||
}
|
||||
return properties.dimensions;
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_FEATURES_FACTORIZER_FEATURE_SET_H_
|
||||
#define _NNUE_TRAINER_FEATURES_FACTORIZER_FEATURE_SET_H_
|
||||
|
||||
#include "../../../../config.h"
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
|
||||
#include "../../features/feature_set.h"
|
||||
|
@ -38,7 +36,7 @@ class Factorizer<FeatureSet<FirstFeatureType, RemainingFeatureTypes...>> {
|
|||
static void AppendTrainingFeatures(
|
||||
IndexType base_index, std::vector<TrainingFeature>* training_features,
|
||||
IndexType base_dimensions = kBaseDimensions) {
|
||||
ASSERT_LV5(base_index < kBaseDimensions);
|
||||
assert(base_index < kBaseDimensions);
|
||||
constexpr auto boundary = FeatureSet<RemainingFeatureTypes...>::kDimensions;
|
||||
if (base_index < boundary) {
|
||||
Tail::AppendTrainingFeatures(
|
||||
|
@ -50,7 +48,7 @@ class Factorizer<FeatureSet<FirstFeatureType, RemainingFeatureTypes...>> {
|
|||
for (auto i = start; i < training_features->size(); ++i) {
|
||||
auto& feature = (*training_features)[i];
|
||||
const auto index = feature.GetIndex();
|
||||
ASSERT_LV5(index < Head::GetDimensions() ||
|
||||
assert(index < Head::GetDimensions() ||
|
||||
(index >= base_dimensions &&
|
||||
index < base_dimensions +
|
||||
Head::GetDimensions() - Head::kBaseDimensions));
|
||||
|
@ -81,13 +79,13 @@ public:
|
|||
static void AppendTrainingFeatures(
|
||||
IndexType base_index, std::vector<TrainingFeature>* training_features,
|
||||
IndexType base_dimensions = kBaseDimensions) {
|
||||
ASSERT_LV5(base_index < kBaseDimensions);
|
||||
assert(base_index < kBaseDimensions);
|
||||
const auto start = training_features->size();
|
||||
Factorizer<FeatureType>::AppendTrainingFeatures(
|
||||
base_index, training_features);
|
||||
for (auto i = start; i < training_features->size(); ++i) {
|
||||
auto& feature = (*training_features)[i];
|
||||
ASSERT_LV5(feature.GetIndex() < Factorizer<FeatureType>::GetDimensions());
|
||||
assert(feature.GetIndex() < Factorizer<FeatureType>::GetDimensions());
|
||||
if (feature.GetIndex() >= kBaseDimensions) {
|
||||
feature.ShiftIndex(base_dimensions - kBaseDimensions);
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_FEATURES_FACTORIZER_HALF_KP_H_
|
||||
#define _NNUE_TRAINER_FEATURES_FACTORIZER_HALF_KP_H_
|
||||
|
||||
#include "../../../../config.h"
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
|
||||
#include "../../features/half_kp.h"
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_H_
|
||||
#define _NNUE_TRAINER_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../nnue_common.h"
|
||||
|
@ -36,11 +34,11 @@ class TrainingFeature {
|
|||
|
||||
explicit TrainingFeature(IndexType index) :
|
||||
index_and_count_((index << kCountBits) | 1) {
|
||||
ASSERT_LV3(index < (1 << kIndexBits));
|
||||
assert(index < (1 << kIndexBits));
|
||||
}
|
||||
TrainingFeature& operator+=(const TrainingFeature& other) {
|
||||
ASSERT_LV3(other.GetIndex() == GetIndex());
|
||||
ASSERT_LV3(other.GetCount() + GetCount() < (1 << kCountBits));
|
||||
assert(other.GetIndex() == GetIndex());
|
||||
assert(other.GetCount() + GetCount() < (1 << kCountBits));
|
||||
index_and_count_ += other.GetCount();
|
||||
return *this;
|
||||
}
|
||||
|
@ -48,7 +46,7 @@ class TrainingFeature {
|
|||
return static_cast<IndexType>(index_and_count_ >> kCountBits);
|
||||
}
|
||||
void ShiftIndex(IndexType offset) {
|
||||
ASSERT_LV3(GetIndex() + offset < (1 << kIndexBits));
|
||||
assert(GetIndex() + offset < (1 << kIndexBits));
|
||||
index_and_count_ += offset << kCountBits;
|
||||
}
|
||||
IndexType GetCount() const {
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_AFFINE_TRANSFORM_H_
|
||||
#define _NNUE_TRAINER_AFFINE_TRANSFORM_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../../learn/learn.h"
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_CLIPPED_RELU_H_
|
||||
#define _NNUE_TRAINER_CLIPPED_RELU_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../../learn/learn.h"
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_FEATURE_TRANSFORMER_H_
|
||||
#define _NNUE_TRAINER_FEATURE_TRANSFORMER_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../../learn/learn.h"
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_INPUT_SLICE_H_
|
||||
#define _NNUE_TRAINER_INPUT_SLICE_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../../learn/learn.h"
|
||||
|
@ -35,7 +33,7 @@ class SharedInputTrainer {
|
|||
current_operation_ = Operation::kSendMessage;
|
||||
feature_transformer_trainer_->SendMessage(message);
|
||||
}
|
||||
ASSERT_LV3(current_operation_ == Operation::kSendMessage);
|
||||
assert(current_operation_ == Operation::kSendMessage);
|
||||
if (++num_calls_ == num_referrers_) {
|
||||
num_calls_ = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
|
@ -49,7 +47,7 @@ class SharedInputTrainer {
|
|||
current_operation_ = Operation::kInitialize;
|
||||
feature_transformer_trainer_->Initialize(rng);
|
||||
}
|
||||
ASSERT_LV3(current_operation_ == Operation::kInitialize);
|
||||
assert(current_operation_ == Operation::kInitialize);
|
||||
if (++num_calls_ == num_referrers_) {
|
||||
num_calls_ = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
|
@ -66,7 +64,7 @@ class SharedInputTrainer {
|
|||
current_operation_ = Operation::kPropagate;
|
||||
output_ = feature_transformer_trainer_->Propagate(batch);
|
||||
}
|
||||
ASSERT_LV3(current_operation_ == Operation::kPropagate);
|
||||
assert(current_operation_ == Operation::kPropagate);
|
||||
if (++num_calls_ == num_referrers_) {
|
||||
num_calls_ = 0;
|
||||
current_operation_ = Operation::kNone;
|
||||
|
@ -90,7 +88,7 @@ class SharedInputTrainer {
|
|||
}
|
||||
}
|
||||
}
|
||||
ASSERT_LV3(current_operation_ == Operation::kBackPropagate);
|
||||
assert(current_operation_ == Operation::kBackPropagate);
|
||||
for (IndexType b = 0; b < batch_size_; ++b) {
|
||||
const IndexType batch_offset = kInputDimensions * b;
|
||||
for (IndexType i = 0; i < kInputDimensions; ++i) {
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#ifndef _NNUE_TRAINER_SUM_H_
|
||||
#define _NNUE_TRAINER_SUM_H_
|
||||
|
||||
#include "../../../config.h"
|
||||
|
||||
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
|
||||
|
||||
#include "../../../learn/learn.h"
|
||||
|
|
|
@ -864,13 +864,11 @@ namespace {
|
|||
/// evaluate() is the evaluator for the outer world. It returns a static
|
||||
/// evaluation of the position from the point of view of the side to move.
|
||||
|
||||
#if !defined(EVAL_NNUE)
|
||||
Value Eval::evaluate(const Position& pos) {
|
||||
#if defined(EVAL_NNUE)
|
||||
return Eval::NNUE::evaluate(pos);
|
||||
#else
|
||||
return Evaluation<NO_TRACE>(pos).value();
|
||||
#endif // defined(EVAL_NNUE)
|
||||
}
|
||||
#endif // defined(EVAL_NNUE)
|
||||
|
||||
|
||||
/// trace() is like evaluate(), but instead of returning a value, it returns
|
||||
|
|
|
@ -35,6 +35,8 @@ std::string trace(const Position& pos);
|
|||
|
||||
Value evaluate(const Position& pos);
|
||||
|
||||
void evaluate_with_no_return(const Position& pos);
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
// 評価関数ファイルを読み込む。
|
||||
// これは、"is_ready"コマンドの応答時に1度だけ呼び出される。2度呼び出すことは想定していない。
|
||||
|
@ -85,6 +87,13 @@ enum BonaPiece : int32_t
|
|||
fe_end2 = e_king + SQUARE_NB, // 玉も含めた末尾の番号。
|
||||
};
|
||||
|
||||
#define ENABLE_INCR_OPERATORS_ON(T) \
|
||||
inline T& operator++(T& d) { return d = T(int(d) + 1); } \
|
||||
inline T& operator--(T& d) { return d = T(int(d) - 1); }
|
||||
|
||||
ENABLE_INCR_OPERATORS_ON(BonaPiece)
|
||||
|
||||
#undef ENABLE_INCR_OPERATORS_ON
|
||||
|
||||
// BonaPieceを後手から見たとき(先手の39の歩を後手から見ると後手の71の歩)の番号とを
|
||||
// ペアにしたものをExtBonaPiece型と呼ぶことにする。
|
||||
|
@ -132,7 +141,7 @@ struct EvalList
|
|||
|
||||
// 盤上のsqの升にpiece_noのpcの駒を配置する
|
||||
void put_piece(PieceNumber piece_no, Square sq, Piece pc) {
|
||||
set_piece_on_board(piece_no, BonaPiece(kpp_board_index[pc].fw + sq), BonaPiece(kpp_board_index[pc].fb + inverse(sq)), sq);
|
||||
set_piece_on_board(piece_no, BonaPiece(kpp_board_index[pc].fw + sq), BonaPiece(kpp_board_index[pc].fb + Inv(sq)), sq);
|
||||
}
|
||||
|
||||
// 盤上のある升sqに対応するPieceNumberを返す。
|
||||
|
@ -181,8 +190,8 @@ public:
|
|||
static const int MAX_LENGTH = 40;
|
||||
|
||||
// 盤上の駒に対して、その駒番号(PieceNumber)を保持している配列
|
||||
// 玉がSQ_NBに移動しているとき用に+1まで保持しておくが、
|
||||
// SQ_NBの玉を移動させないので、この値を使うことはないはず。
|
||||
// 玉がSQUARE_NBに移動しているとき用に+1まで保持しておくが、
|
||||
// SQUARE_NBの玉を移動させないので、この値を使うことはないはず。
|
||||
PieceNumber piece_no_list_board[SQUARE_NB_PLUS1];
|
||||
private:
|
||||
|
||||
|
|
444
src/extra/sfen_packer.cpp
Normal file
444
src/extra/sfen_packer.cpp
Normal file
|
@ -0,0 +1,444 @@
|
|||
#if defined (EVAL_LEARN)
|
||||
|
||||
#include "../misc.h"
|
||||
#include "../position.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <cstring> // std::memset()
|
||||
|
||||
using namespace std;
|
||||
|
||||
// -----------------------------------
|
||||
// 局面の圧縮・解凍
|
||||
// -----------------------------------
|
||||
|
||||
// ビットストリームを扱うクラス
|
||||
// 局面の符号化を行なうときに、これがあると便利
|
||||
struct BitStream
|
||||
{
|
||||
// データを格納するメモリを事前にセットする。
|
||||
// そのメモリは0クリアされているものとする。
|
||||
void set_data(uint8_t* data_) { data = data_; reset(); }
|
||||
|
||||
// set_data()で渡されたポインタの取得。
|
||||
uint8_t* get_data() const { return data; }
|
||||
|
||||
// カーソルの取得。
|
||||
int get_cursor() const { return bit_cursor; }
|
||||
|
||||
// カーソルのリセット
|
||||
void reset() { bit_cursor = 0; }
|
||||
|
||||
// ストリームに1bit書き出す。
|
||||
// bは非0なら1を書き出す。0なら0を書き出す。
|
||||
void write_one_bit(int b)
|
||||
{
|
||||
if (b)
|
||||
data[bit_cursor / 8] |= 1 << (bit_cursor & 7);
|
||||
|
||||
++bit_cursor;
|
||||
}
|
||||
|
||||
// ストリームから1ビット取り出す。
|
||||
int read_one_bit()
|
||||
{
|
||||
int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1;
|
||||
++bit_cursor;
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
// nビットのデータを書き出す
|
||||
// データはdの下位から順に書き出されるものとする。
|
||||
void write_n_bit(int d, int n)
|
||||
{
|
||||
for (int i = 0; i < n; ++i)
|
||||
write_one_bit(d & (1 << i));
|
||||
}
|
||||
|
||||
// nビットのデータを読み込む
|
||||
// write_n_bit()の逆変換。
|
||||
int read_n_bit(int n)
|
||||
{
|
||||
int result = 0;
|
||||
for (int i = 0; i < n; ++i)
|
||||
result |= read_one_bit() ? (1 << i) : 0;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
// 次に読み書きすべきbit位置。
|
||||
int bit_cursor;
|
||||
|
||||
// データの実体
|
||||
uint8_t* data;
|
||||
};
|
||||
|
||||
|
||||
// ハフマン符号化
|
||||
// ※ なのはminiの符号化から、変換が楽になるように単純化。
|
||||
//
|
||||
// 盤上の1升(NO_PIECE以外) = 2~6bit ( + 成りフラグ1bit+ 先後1bit )
|
||||
// 手駒の1枚 = 1~5bit ( + 成りフラグ1bit+ 先後1bit )
|
||||
//
|
||||
// 空 xxxxx0 + 0 (none)
|
||||
// 歩 xxxx01 + 2 xxxx0 + 2
|
||||
// 香 xx0011 + 2 xx001 + 2
|
||||
// 桂 xx1011 + 2 xx101 + 2
|
||||
// 銀 xx0111 + 2 xx011 + 2
|
||||
// 金 x01111 + 1 x0111 + 1 // 金は成りフラグはない。
|
||||
// 角 011111 + 2 01111 + 2
|
||||
// 飛 111111 + 2 11111 + 2
|
||||
//
|
||||
// すべての駒が盤上にあるとして、
|
||||
// 空 81 - 40駒 = 41升 = 41bit
|
||||
// 歩 4bit*18駒 = 72bit
|
||||
// 香 6bit* 4駒 = 24bit
|
||||
// 桂 6bit* 4駒 = 24bit
|
||||
// 銀 6bit* 4駒 = 24bit
|
||||
// 金 6bit* 4駒 = 24bit
|
||||
// 角 8bit* 2駒 = 16bit
|
||||
// 飛 8bit* 2駒 = 16bit
|
||||
// -------
|
||||
// 241bit + 1bit(手番) + 7bit×2(王の位置先後) = 256bit
|
||||
//
|
||||
// 盤上の駒が手駒に移動すると盤上の駒が空になるので盤上のその升は1bitで表現でき、
|
||||
// 手駒は、盤上の駒より1bit少なく表現できるので結局、全体のbit数に変化はない。
|
||||
// ゆえに、この表現において、どんな局面でもこのbit数で表現できる。
|
||||
// 手駒に成りフラグは不要だが、これも含めておくと盤上の駒のbit数-1になるので
|
||||
// 全体のbit数が固定化できるのでこれも含めておくことにする。
|
||||
|
||||
// Huffman Encoding
|
||||
//
|
||||
// Empty xxxxxxx0
|
||||
// Pawn xxxxx001 + 1 bit (Side to move)
|
||||
// Knight xxxxx011 + 1 bit (Side to move)
|
||||
// Bishop xxxxx101 + 1 bit (Side to move)
|
||||
// Rook xxxxx111 + 1 bit (Side to move)
|
||||
|
||||
struct HuffmanedPiece
|
||||
{
|
||||
int code; // どうコード化されるか
|
||||
int bits; // 何bit専有するのか
|
||||
};
|
||||
|
||||
HuffmanedPiece huffman_table[] =
|
||||
{
|
||||
{0b000,1}, // NO_PIECE
|
||||
{0b001,3}, // PAWN
|
||||
{0b011,3}, // KNIGHT
|
||||
{0b101,3}, // BISHOP
|
||||
{0b111,3}, // ROOK
|
||||
};
|
||||
|
||||
// sfenを圧縮/解凍するためのクラス
|
||||
// sfenはハフマン符号化をすることで256bit(32bytes)にpackできる。
|
||||
// このことはなのはminiにより証明された。上のハフマン符号化である。
|
||||
//
|
||||
// 内部フォーマット = 手番1bit+王の位置7bit*2 + 盤上の駒(ハフマン符号化) + 手駒(ハフマン符号化)
|
||||
// Side to move (White = 0, Black = 1) (1bit)
|
||||
// White King Position (6 bits)
|
||||
// Black King Position (6 bits)
|
||||
// Huffman Encoding of the board
|
||||
// Castling availability (1 bit x 4)
|
||||
// En passant square (1 or 1 + 6 bits)
|
||||
// Rule 50 (6 bits)
|
||||
// Game play (8 bits)
|
||||
//
|
||||
// TODO(someone): Rename SFEN to FEN.
|
||||
//
|
||||
struct SfenPacker
|
||||
{
|
||||
// sfenをpackしてdata[32]に格納する。
|
||||
void pack(const Position& pos)
|
||||
{
|
||||
// cout << pos;
|
||||
|
||||
memset(data, 0, 32 /* 256bit */);
|
||||
stream.set_data(data);
|
||||
|
||||
// 手番
|
||||
// Side to move.
|
||||
stream.write_one_bit((int)(pos.side_to_move()));
|
||||
|
||||
// 先手玉、後手玉の位置、それぞれ7bit
|
||||
// White king and black king, 6 bits for each.
|
||||
for(auto c : Colors)
|
||||
stream.write_n_bit(pos.king_square(c), 6);
|
||||
|
||||
// Write the pieces on the board other than the kings.
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
Piece pc = pos.piece_on(make_square(f, r));
|
||||
if (type_of(pc) == KING)
|
||||
continue;
|
||||
write_board_piece_to_stream(pc);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(someone): Support chess960.
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OO));
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OOO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OOO));
|
||||
|
||||
if (pos.ep_square() == SQ_NONE) {
|
||||
stream.write_one_bit(0);
|
||||
}
|
||||
else {
|
||||
stream.write_one_bit(1);
|
||||
stream.write_n_bit(static_cast<int>(pos.ep_square()), 6);
|
||||
}
|
||||
|
||||
stream.write_n_bit(pos.state()->rule50, 6);
|
||||
|
||||
stream.write_n_bit(pos.game_ply(), 8);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
}
|
||||
|
||||
// pack()でpackされたsfen(256bit = 32bytes)
|
||||
// もしくはunpack()でdecodeするsfen
|
||||
uint8_t *data; // uint8_t[32];
|
||||
|
||||
//private:
|
||||
// Position::set_from_packed_sfen(uint8_t data[32])でこれらの関数を使いたいので筋は悪いがpublicにしておく。
|
||||
|
||||
BitStream stream;
|
||||
|
||||
// 盤面の駒をstreamに出力する。
|
||||
void write_board_piece_to_stream(Piece pc)
|
||||
{
|
||||
// 駒種
|
||||
PieceType pr = type_of(pc);
|
||||
auto c = huffman_table[pr];
|
||||
stream.write_n_bit(c.code, c.bits);
|
||||
|
||||
if (pc == NO_PIECE)
|
||||
return;
|
||||
|
||||
// 先後フラグ
|
||||
stream.write_one_bit(color_of(pc));
|
||||
}
|
||||
|
||||
// 盤面の駒を1枚streamから読み込む
|
||||
Piece read_board_piece_from_stream()
|
||||
{
|
||||
PieceType pr = NO_PIECE_TYPE;
|
||||
int code = 0, bits = 0;
|
||||
while (true)
|
||||
{
|
||||
code |= stream.read_one_bit() << bits;
|
||||
++bits;
|
||||
|
||||
assert(bits <= 6);
|
||||
|
||||
for (pr = NO_PIECE_TYPE; pr < KING; ++pr)
|
||||
if (huffman_table[pr].code == code
|
||||
&& huffman_table[pr].bits == bits)
|
||||
goto Found;
|
||||
}
|
||||
Found:;
|
||||
if (pr == NO_PIECE_TYPE)
|
||||
return NO_PIECE;
|
||||
|
||||
// 先後フラグ
|
||||
Color c = (Color)stream.read_one_bit();
|
||||
|
||||
return make_piece(c, pr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// -----------------------------------
|
||||
// Positionクラスに追加
|
||||
// -----------------------------------
|
||||
|
||||
// 高速化のために直接unpackする関数を追加。かなりしんどい。
|
||||
// packer::unpack()とPosition::set()とを合体させて書く。
|
||||
// 渡された局面に問題があって、エラーのときは非0を返す。
|
||||
int Position::set_from_packed_sfen(const PackedSfen& sfen , StateInfo * si, Thread* th, bool mirror)
|
||||
{
|
||||
SfenPacker packer;
|
||||
auto& stream = packer.stream;
|
||||
stream.set_data((uint8_t*)&sfen);
|
||||
|
||||
std::memset(this, 0, sizeof(Position));
|
||||
std::memset(si, 0, sizeof(StateInfo));
|
||||
st = si;
|
||||
|
||||
// Active color
|
||||
sideToMove = (Color)stream.read_one_bit();
|
||||
|
||||
// evalListのclear。上でmemsetでゼロクリアしたときにクリアされているが…。
|
||||
evalList.clear();
|
||||
|
||||
// PieceListを更新する上で、どの駒がどこにあるかを設定しなければならないが、
|
||||
// それぞれの駒をどこまで使ったかのカウンター
|
||||
PieceNumber piece_no_count[KING] = {
|
||||
PIECE_NUMBER_ZERO,
|
||||
PIECE_NUMBER_PAWN,
|
||||
PIECE_NUMBER_KNIGHT,
|
||||
PIECE_NUMBER_BISHOP,
|
||||
PIECE_NUMBER_ROOK,
|
||||
};
|
||||
|
||||
pieceList[W_KING][0] = SQUARE_NB;
|
||||
pieceList[B_KING][0] = SQUARE_NB;
|
||||
|
||||
// まず玉の位置
|
||||
if (mirror)
|
||||
{
|
||||
for (auto c : Colors)
|
||||
board[Mir((Square)stream.read_n_bit(7))] = make_piece(c, KING);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto c : Colors)
|
||||
board[stream.read_n_bit(7)] = make_piece(c, KING);
|
||||
}
|
||||
|
||||
// Piece placement
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
auto sq = make_square(f, r);
|
||||
if (mirror) {
|
||||
sq = Mir(sq);
|
||||
}
|
||||
|
||||
// すでに玉がいるようだ
|
||||
Piece pc;
|
||||
if (type_of(board[sq]) != KING)
|
||||
{
|
||||
assert(board[sq] == NO_PIECE);
|
||||
pc = packer.read_board_piece_from_stream();
|
||||
}
|
||||
else
|
||||
{
|
||||
pc = board[sq];
|
||||
board[sq] = NO_PIECE; // いっかい取り除いておかないとput_piece()でASSERTに引っかかる。
|
||||
}
|
||||
|
||||
// 駒がない場合もあるのでその場合はスキップする。
|
||||
if (pc == NO_PIECE)
|
||||
continue;
|
||||
|
||||
put_piece(Piece(pc), sq);
|
||||
|
||||
// evalListの更新
|
||||
PieceNumber piece_no =
|
||||
(pc == B_KING) ? PIECE_NUMBER_BKING : // 先手玉
|
||||
(pc == W_KING) ? PIECE_NUMBER_WKING : // 後手玉
|
||||
piece_no_count[type_of(pc)]++; // それ以外
|
||||
|
||||
evalList.put_piece(piece_no, sq, pc); // sqの升にpcの駒を配置する
|
||||
|
||||
//cout << sq << ' ' << board[sq] << ' ' << stream.get_cursor() << endl;
|
||||
|
||||
if (stream.get_cursor() > 256)
|
||||
return 1;
|
||||
//assert(stream.get_cursor() <= 256);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Castling availability.
|
||||
// TODO(someone): Support chess960.
|
||||
st->castlingRights = 0;
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_H1); piece_on(rsq) != W_ROOK; --rsq) {}
|
||||
set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_A1); piece_on(rsq) != W_ROOK; ++rsq) {}
|
||||
set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_H1); piece_on(rsq) != W_ROOK; --rsq) {}
|
||||
set_castling_right(BLACK, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_A1); piece_on(rsq) != W_ROOK; ++rsq) {}
|
||||
set_castling_right(BLACK, rsq);
|
||||
}
|
||||
|
||||
// En passant square. Ignore if no pawn capture is possible
|
||||
if (stream.read_one_bit()) {
|
||||
Square ep_square = static_cast<Square>(stream.read_n_bit(6));
|
||||
st->epSquare = ep_square;
|
||||
|
||||
if (!(attackers_to(st->epSquare) & pieces(sideToMove, PAWN))
|
||||
|| !(pieces(~sideToMove, PAWN) & (st->epSquare + pawn_push(~sideToMove))))
|
||||
st->epSquare = SQ_NONE;
|
||||
}
|
||||
|
||||
// Halfmove clock
|
||||
st->rule50 = static_cast<Square>(stream.read_n_bit(6));
|
||||
|
||||
// Fullmove number
|
||||
gamePly = static_cast<Square>(stream.read_n_bit(8));
|
||||
// Convert from fullmove starting from 1 to gamePly starting from 0,
|
||||
// handle also common incorrect FEN with fullmove = 0.
|
||||
gamePly = std::max(2 * (gamePly - 1), 0) + (sideToMove == BLACK);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
|
||||
chess960 = false;
|
||||
thisThread = th;
|
||||
set_state(st);
|
||||
|
||||
assert(pos_is_ok());
|
||||
#if defined(EVAL_NNUE)
|
||||
assert(evalList.is_valid(*this));
|
||||
#endif // defined(EVAL_NNUE)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 盤面と手駒、手番を与えて、そのsfenを返す。
|
||||
//std::string Position::sfen_from_rawdata(Piece board[81], Hand hands[2], Color turn, int gamePly_)
|
||||
//{
|
||||
// // 内部的な構造体にコピーして、sfen()を呼べば、変換過程がそこにしか依存していないならば
|
||||
// // これで正常に変換されるのでは…。
|
||||
// Position pos;
|
||||
//
|
||||
// memcpy(pos.board, board, sizeof(Piece) * 81);
|
||||
// memcpy(pos.hand, hands, sizeof(Hand) * 2);
|
||||
// pos.sideToMove = turn;
|
||||
// pos.gamePly = gamePly_;
|
||||
//
|
||||
// return pos.sfen();
|
||||
//
|
||||
// // ↑の実装、美しいが、いかんせん遅い。
|
||||
// // 棋譜を大量に読み込ませて学習させるときにここがボトルネックになるので直接unpackする関数を書く。
|
||||
//}
|
||||
|
||||
// packされたsfenを得る。引数に指定したバッファに返す。
|
||||
void Position::sfen_pack(PackedSfen& sfen)
|
||||
{
|
||||
SfenPacker sp;
|
||||
sp.data = (uint8_t*)&sfen;
|
||||
sp.pack(*this);
|
||||
}
|
||||
|
||||
//// packされたsfenを解凍する。sfen文字列が返る。
|
||||
//std::string Position::sfen_unpack(const PackedSfen& sfen)
|
||||
//{
|
||||
// SfenPacker sp;
|
||||
// sp.data = (uint8_t*)&sfen;
|
||||
// return sp.unpack();
|
||||
//}
|
||||
|
||||
|
||||
#endif // USE_SFEN_PACKER
|
||||
|
1
src/learn/gensfen2019.cpp
Normal file
1
src/learn/gensfen2019.cpp
Normal file
|
@ -0,0 +1 @@
|
|||
// just a place holder
|
133
src/learn/half_float.h
Normal file
133
src/learn/half_float.h
Normal file
|
@ -0,0 +1,133 @@
|
|||
#ifndef __HALF_FLOAT_H__
|
||||
#define __HALF_FLOAT_H__
|
||||
|
||||
// Half Float Library by yaneurao
|
||||
// (16-bit float)
|
||||
|
||||
// 16bit型による浮動小数点演算
|
||||
// コンパイラの生成するfloat型のコードがIEEE 754の形式であると仮定して、それを利用する。
|
||||
|
||||
#include "../types.h"
|
||||
|
||||
namespace HalfFloat
|
||||
{
|
||||
// IEEE 754 float 32 format is :
|
||||
// sign(1bit) + exponent(8bits) + fraction(23bits) = 32bits
|
||||
//
|
||||
// Our float16 format is :
|
||||
// sign(1bit) + exponent(5bits) + fraction(10bits) = 16bits
|
||||
union float32_converter
|
||||
{
|
||||
int32_t n;
|
||||
float f;
|
||||
};
|
||||
|
||||
|
||||
// 16-bit float
|
||||
struct float16
|
||||
{
|
||||
// --- constructors
|
||||
|
||||
float16() {}
|
||||
float16(int16_t n) { from_float((float)n); }
|
||||
float16(int32_t n) { from_float((float)n); }
|
||||
float16(float n) { from_float(n); }
|
||||
float16(double n) { from_float((float)n); }
|
||||
|
||||
// build from a float
|
||||
void from_float(float f) { *this = to_float16(f); }
|
||||
|
||||
// --- implicit converters
|
||||
|
||||
operator int32_t() const { return (int32_t)to_float(*this); }
|
||||
operator float() const { return to_float(*this); }
|
||||
operator double() const { return double(to_float(*this)); }
|
||||
|
||||
// --- operators
|
||||
|
||||
float16 operator += (float16 rhs) { from_float(to_float(*this) + to_float(rhs)); return *this; }
|
||||
float16 operator -= (float16 rhs) { from_float(to_float(*this) - to_float(rhs)); return *this; }
|
||||
float16 operator *= (float16 rhs) { from_float(to_float(*this) * to_float(rhs)); return *this; }
|
||||
float16 operator /= (float16 rhs) { from_float(to_float(*this) / to_float(rhs)); return *this; }
|
||||
float16 operator + (float16 rhs) const { return float16(*this) += rhs; }
|
||||
float16 operator - (float16 rhs) const { return float16(*this) -= rhs; }
|
||||
float16 operator * (float16 rhs) const { return float16(*this) *= rhs; }
|
||||
float16 operator / (float16 rhs) const { return float16(*this) /= rhs; }
|
||||
float16 operator - () const { return float16(-to_float(*this)); }
|
||||
bool operator == (float16 rhs) const { return this->v_ == rhs.v_; }
|
||||
bool operator != (float16 rhs) const { return !(*this == rhs); }
|
||||
|
||||
static void UnitTest() { unit_test(); }
|
||||
|
||||
private:
|
||||
|
||||
// --- entity
|
||||
|
||||
uint16_t v_;
|
||||
|
||||
// --- conversion between float and float16
|
||||
|
||||
static float16 to_float16(float f)
|
||||
{
|
||||
float32_converter c;
|
||||
c.f = f;
|
||||
u32 n = c.n;
|
||||
|
||||
// The sign bit is MSB in common.
|
||||
uint16_t sign_bit = (n >> 16) & 0x8000;
|
||||
|
||||
// The exponent of IEEE 754's float 32 is biased +127 , so we change this bias into +15 and limited to 5-bit.
|
||||
uint16_t exponent = (((n >> 23) - 127 + 15) & 0x1f) << 10;
|
||||
|
||||
// The fraction is limited to 10-bit.
|
||||
uint16_t fraction = (n >> (23-10)) & 0x3ff;
|
||||
|
||||
float16 f_;
|
||||
f_.v_ = sign_bit | exponent | fraction;
|
||||
|
||||
return f_;
|
||||
}
|
||||
|
||||
static float to_float(float16 v)
|
||||
{
|
||||
u32 sign_bit = (v.v_ & 0x8000) << 16;
|
||||
u32 exponent = ((((v.v_ >> 10) & 0x1f) - 15 + 127) & 0xff) << 23;
|
||||
u32 fraction = (v.v_ & 0x3ff) << (23 - 10);
|
||||
|
||||
float32_converter c;
|
||||
c.n = sign_bit | exponent | fraction;
|
||||
return c.f;
|
||||
}
|
||||
|
||||
// unit testになってないが、一応計算が出来ることは確かめた。コードはあとでなおす(かも)。
|
||||
static void unit_test()
|
||||
{
|
||||
float16 a, b, c, d;
|
||||
a = 1;
|
||||
std::cout << (float)a << std::endl;
|
||||
b = -118.625;
|
||||
std::cout << (float)b << std::endl;
|
||||
c = 2.5;
|
||||
std::cout << (float)c << std::endl;
|
||||
d = a + c;
|
||||
std::cout << (float)d << std::endl;
|
||||
|
||||
c *= 1.5;
|
||||
std::cout << (float)c << std::endl;
|
||||
|
||||
b /= 3;
|
||||
std::cout << (float)b << std::endl;
|
||||
|
||||
float f1 = 1.5;
|
||||
a += f1;
|
||||
std::cout << (float)a << std::endl;
|
||||
|
||||
a += f1 * (float)a;
|
||||
std::cout << (float)a << std::endl;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // __HALF_FLOAT_H__
|
237
src/learn/learn.h
Normal file
237
src/learn/learn.h
Normal file
|
@ -0,0 +1,237 @@
|
|||
#ifndef _LEARN_H_
|
||||
#define _LEARN_H_
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
|
||||
#include <vector>
|
||||
|
||||
// =====================
|
||||
// 学習時の設定
|
||||
// =====================
|
||||
|
||||
// 以下のいずれかを選択すれば、そのあとの細々したものは自動的に選択される。
|
||||
// いずれも選択しない場合は、そのあとの細々したものをひとつひとつ設定する必要がある。
|
||||
|
||||
// elmo方式での学習設定。これをデフォルト設定とする。
|
||||
// 標準の雑巾絞りにするためにはlearnコマンドで "lambda 1"を指定してやれば良い。
|
||||
#define LEARN_ELMO_METHOD
|
||||
|
||||
|
||||
// ----------------------
|
||||
// 更新式
|
||||
// ----------------------
|
||||
|
||||
// AdaGrad。これが安定しているのでお勧め。
|
||||
// #define ADA_GRAD_UPDATE
|
||||
|
||||
// 勾配の符号だけ見るSGD。省メモリで済むが精度は…。
|
||||
// #define SGD_UPDATE
|
||||
|
||||
// ----------------------
|
||||
// 学習時の設定
|
||||
// ----------------------
|
||||
|
||||
// mini-batchサイズ。
|
||||
// この数だけの局面をまとめて勾配を計算する。
|
||||
// 小さくするとupdate_weights()の回数が増えるので収束が速くなる。勾配が不正確になる。
|
||||
// 大きくするとupdate_weights()の回数が減るので収束が遅くなる。勾配は正確に出るようになる。
|
||||
// 多くの場合において、この値を変更する必要はないと思う。
|
||||
|
||||
#define LEARN_MINI_BATCH_SIZE (1000 * 1000 * 1)
|
||||
|
||||
// ファイルから1回に読み込む局面数。これだけ読み込んだあとshuffleする。
|
||||
// ある程度大きいほうが良いが、この数×40byte×3倍ぐらいのメモリを消費する。10M局面なら400MB*3程度消費する。
|
||||
// THREAD_BUFFER_SIZE(=10000)の倍数にすること。
|
||||
|
||||
#define LEARN_SFEN_READ_SIZE (1000 * 1000 * 10)
|
||||
|
||||
// 学習時の評価関数の保存間隔。この局面数だけ学習させるごとに保存。
|
||||
// 当然ながら、保存間隔を長くしたほうが学習時間は短くなる。
|
||||
// フォルダ名は 0/ , 1/ , 2/ ...のように保存ごとにインクリメントされていく。
|
||||
// デフォルトでは10億局面に1回。
|
||||
#define LEARN_EVAL_SAVE_INTERVAL (1000000000ULL)
|
||||
|
||||
|
||||
// ----------------------
|
||||
// 目的関数の選択
|
||||
// ----------------------
|
||||
|
||||
// 目的関数が勝率の差の二乗和
|
||||
// 詳しい説明は、learner.cppを見ること。
|
||||
|
||||
//#define LOSS_FUNCTION_IS_WINNING_PERCENTAGE
|
||||
|
||||
// 目的関数が交差エントロピー
|
||||
// 詳しい説明は、learner.cppを見ること。
|
||||
// いわゆる、普通の「雑巾絞り」
|
||||
//#define LOSS_FUNCTION_IS_CROSS_ENTOROPY
|
||||
|
||||
// 目的関数が交差エントロピーだが、勝率の関数を通さない版
|
||||
// #define LOSS_FUNCTION_IS_CROSS_ENTOROPY_FOR_VALUE
|
||||
|
||||
// elmo(WCSC27)の方式
|
||||
// #define LOSS_FUNCTION_IS_ELMO_METHOD
|
||||
|
||||
// ※ 他、色々追加するかも。
|
||||
|
||||
|
||||
// ----------------------
|
||||
// 学習に関するデバッグ設定
|
||||
// ----------------------
|
||||
|
||||
// 学習時のrmseの出力をこの回数に1回に減らす。
|
||||
// rmseの計算は1スレッドで行なうためそこそこ時間をとられるので出力を減らすと効果がある。
|
||||
#define LEARN_RMSE_OUTPUT_INTERVAL 1
|
||||
|
||||
|
||||
// ----------------------
|
||||
// ゼロベクトルからの学習
|
||||
// ----------------------
|
||||
|
||||
// 評価関数パラメーターをゼロベクトルから学習を開始する。
|
||||
// ゼロ初期化して棋譜生成してゼロベクトルから学習させて、
|
||||
// 棋譜生成→学習を繰り返すとプロの棋譜に依らないパラメーターが得られる。(かも)
|
||||
// (すごく時間かかる)
|
||||
|
||||
//#define RESET_TO_ZERO_VECTOR
|
||||
|
||||
|
||||
// ----------------------
|
||||
// 学習のときの浮動小数
|
||||
// ----------------------
|
||||
|
||||
// これをdoubleにしたほうが計算精度は上がるが、重み配列絡みのメモリが倍必要になる。
|
||||
// 現状、ここをfloatにした場合、評価関数ファイルに対して、重み配列はその4.5倍のサイズ。(KPPTで4.5GB程度)
|
||||
// double型にしても収束の仕方にほとんど差異がなかったのでfloatに固定する。
|
||||
|
||||
// floatを使う場合
|
||||
typedef float LearnFloatType;
|
||||
|
||||
// doubleを使う場合
|
||||
//typedef double LearnFloatType;
|
||||
|
||||
// float16を使う場合
|
||||
//#include "half_float.h"
|
||||
//typedef HalfFloat::float16 LearnFloatType;
|
||||
|
||||
// ----------------------
|
||||
// 省メモリ化
|
||||
// ----------------------
|
||||
|
||||
// Weight配列(のうちのKPP)に三角配列を用いて省メモリ化する。
|
||||
// これを用いると、学習用の重み配列は評価関数ファイルの3倍程度で済むようになる。
|
||||
|
||||
#define USE_TRIANGLE_WEIGHT_ARRAY
|
||||
|
||||
// ----------------------
|
||||
// 次元下げ
|
||||
// ----------------------
|
||||
|
||||
// ミラー(左右対称性)、インバース(先後対称性)に関して次元下げを行なう。
|
||||
// デフォルトではすべてオン。
|
||||
|
||||
// KKに対してミラー、インバースを利用した次元下げを行なう。(効果のほどは不明)
|
||||
// USE_KK_INVERSE_WRITEをオンにするときはUSE_KK_MIRROR_WRITEもオンでなければならない。
|
||||
#define USE_KK_MIRROR_WRITE
|
||||
#define USE_KK_INVERSE_WRITE
|
||||
|
||||
// KKPに対してミラー、インバースを利用した次元下げを行なう。(インバースのほうは効果のほどは不明)
|
||||
// USE_KKP_INVERSE_WRITEをオンにするときは、USE_KKP_MIRROR_WRITEもオンになっていなければならない。
|
||||
#define USE_KKP_MIRROR_WRITE
|
||||
#define USE_KKP_INVERSE_WRITE
|
||||
|
||||
// KPPに対してミラーを利用した次元下げを行なう。(これをオフにすると教師局面が倍ぐらい必要になる)
|
||||
// KPPにはインバースはない。(先手側のKしかないので)
|
||||
#define USE_KPP_MIRROR_WRITE
|
||||
|
||||
// KPPPに対してミラーを利用した次元下げを行なう。(これをオフにすると教師局面が倍ぐらい必要になる)
|
||||
// KPPPにもインバースはない。(先手側のKしかないので)
|
||||
#define USE_KPPP_MIRROR_WRITE
|
||||
|
||||
// KKPP成分に対して学習時にKPPによる次元下げを行なう。
|
||||
// 学習、めっちゃ遅くなる。
|
||||
// 未デバッグなので使わないこと。
|
||||
//#define USE_KKPP_LOWER_DIM
|
||||
|
||||
|
||||
// ======================
|
||||
// 教師局面生成時の設定
|
||||
// ======================
|
||||
|
||||
// ----------------------
|
||||
// 引き分けを書き出す
|
||||
// ----------------------
|
||||
|
||||
// 引き分けに至ったとき、それを教師局面として書き出す
|
||||
// これをするほうが良いかどうかは微妙。
|
||||
// #define LEARN_GENSFEN_USE_DRAW_RESULT
|
||||
|
||||
|
||||
// ======================
|
||||
// configure
|
||||
// ======================
|
||||
|
||||
// ----------------------
|
||||
// elmo(WCSC27)の方法での学習
|
||||
// ----------------------
|
||||
|
||||
#if defined( LEARN_ELMO_METHOD )
|
||||
#define LOSS_FUNCTION_IS_ELMO_METHOD
|
||||
#define ADA_GRAD_UPDATE
|
||||
#endif
|
||||
|
||||
|
||||
// ----------------------
|
||||
// Learnerで用いるstructの定義
|
||||
// ----------------------
|
||||
#include "../position.h"
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
// PackedSfenと評価値が一体化した構造体
|
||||
// オプションごとに書き出す内容が異なると教師棋譜を再利用するときに困るので
|
||||
// とりあえず、以下のメンバーはオプションによらずすべて書き出しておく。
|
||||
struct PackedSfenValue
|
||||
{
|
||||
// 局面
|
||||
PackedSfen sfen;
|
||||
|
||||
// Learner::search()から返ってきた評価値
|
||||
int16_t score;
|
||||
|
||||
// PVの初手
|
||||
// 教師との指し手一致率を求めるときなどに用いる
|
||||
uint16_t move;
|
||||
|
||||
// 初期局面からの局面の手数。
|
||||
uint16_t gamePly;
|
||||
|
||||
// この局面の手番側が、ゲームを最終的に勝っているなら1。負けているなら-1。
|
||||
// 引き分けに至った場合は、0。
|
||||
// 引き分けは、教師局面生成コマンドgensfenにおいて、
|
||||
// LEARN_GENSFEN_DRAW_RESULTが有効なときにだけ書き出す。
|
||||
int8_t game_result;
|
||||
|
||||
// 教師局面を書き出したファイルを他の人とやりとりするときに
|
||||
// この構造体サイズが不定だと困るため、paddingしてどの環境でも必ず40bytesになるようにしておく。
|
||||
uint8_t padding;
|
||||
|
||||
// 32 + 2 + 2 + 2 + 1 + 1 = 40bytes
|
||||
};
|
||||
|
||||
// 読み筋とそのときの評価値を返す型
|
||||
// Learner::search() , Learner::qsearch()で用いる。
|
||||
typedef std::pair<Value, std::vector<Move> > ValueAndPV;
|
||||
|
||||
// いまのところ、やねうら王2018 Otafukuしか、このスタブを持っていないが
|
||||
// EVAL_LEARNをdefineするなら、このスタブが必須。
|
||||
extern Learner::ValueAndPV search(Position& pos, int depth , size_t multiPV = 1 , uint64_t NodesLimit = 0);
|
||||
extern Learner::ValueAndPV qsearch(Position& pos);
|
||||
|
||||
double calc_grad(Value shallow, const PackedSfenValue& psv);
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif // ifndef _LEARN_H_
|
2922
src/learn/learner.cpp
Normal file
2922
src/learn/learner.cpp
Normal file
File diff suppressed because it is too large
Load diff
256
src/learn/learning_tools.cpp
Normal file
256
src/learn/learning_tools.cpp
Normal file
|
@ -0,0 +1,256 @@
|
|||
#include "learning_tools.h"
|
||||
|
||||
#if defined (EVAL_LEARN)
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
#include "../misc.h"
|
||||
|
||||
using namespace Eval;
|
||||
|
||||
namespace EvalLearningTools
|
||||
{
|
||||
|
||||
// --- static variables
|
||||
|
||||
double Weight::eta;
|
||||
double Weight::eta1;
|
||||
double Weight::eta2;
|
||||
double Weight::eta3;
|
||||
uint64_t Weight::eta1_epoch;
|
||||
uint64_t Weight::eta2_epoch;
|
||||
|
||||
std::vector<bool> min_index_flag;
|
||||
|
||||
// --- 個別のテーブルごとの初期化
|
||||
|
||||
void init_min_index_flag()
|
||||
{
|
||||
// mir_piece、inv_pieceの初期化が終わっていなければならない。
|
||||
assert(mir_piece(Eval::f_pawn) == Eval::e_pawn);
|
||||
|
||||
// 次元下げ用フラグ配列の初期化
|
||||
// KPPPに関しては関与しない。
|
||||
|
||||
KK g_kk;
|
||||
g_kk.set(SQUARE_NB, Eval::fe_end, 0);
|
||||
KKP g_kkp;
|
||||
g_kkp.set(SQUARE_NB, Eval::fe_end, g_kk.max_index());
|
||||
KPP g_kpp;
|
||||
g_kpp.set(SQUARE_NB, Eval::fe_end, g_kkp.max_index());
|
||||
|
||||
uint64_t size = g_kpp.max_index();
|
||||
min_index_flag.resize(size);
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
#if defined(_OPENMP)
|
||||
// Windows環境下でCPUが2つあるときに、論理64コアまでしか使用されないのを防ぐために
|
||||
// ここで明示的にCPUに割り当てる
|
||||
int thread_index = omp_get_thread_num(); // 自分のthread numberを取得
|
||||
WinProcGroup::bindThisThread(thread_index);
|
||||
#endif
|
||||
|
||||
#pragma omp for schedule(dynamic,20000)
|
||||
|
||||
for (int64_t index_ = 0; index_ < (int64_t)size; ++index_)
|
||||
{
|
||||
// OpenMPの制約からループ変数は符号型でないといけないらしいのだが、
|
||||
// さすがに使いにくい。
|
||||
uint64_t index = (uint64_t)index_;
|
||||
|
||||
if (g_kk.is_ok(index))
|
||||
{
|
||||
// indexからの変換と逆変換によって元のindexに戻ることを確認しておく。
|
||||
// 起動時に1回しか実行しない処理なのでassertで書いておく。
|
||||
assert(g_kk.fromIndex(index).toIndex() == index);
|
||||
|
||||
KK a[KK_LOWER_COUNT];
|
||||
g_kk.fromIndex(index).toLowerDimensions(a);
|
||||
|
||||
// 次元下げの1つ目の要素が元のindexと同一であることを確認しておく。
|
||||
assert(a[0].toIndex() == index);
|
||||
|
||||
uint64_t min_index = UINT64_MAX;
|
||||
for (auto& e : a)
|
||||
min_index = std::min(min_index, e.toIndex());
|
||||
min_index_flag[index] = (min_index == index);
|
||||
}
|
||||
else if (g_kkp.is_ok(index))
|
||||
{
|
||||
assert(g_kkp.fromIndex(index).toIndex() == index);
|
||||
|
||||
KKP x = g_kkp.fromIndex(index);
|
||||
KKP a[KKP_LOWER_COUNT];
|
||||
x.toLowerDimensions(a);
|
||||
|
||||
assert(a[0].toIndex() == index);
|
||||
|
||||
uint64_t min_index = UINT64_MAX;
|
||||
for (auto& e : a)
|
||||
min_index = std::min(min_index, e.toIndex());
|
||||
min_index_flag[index] = (min_index == index);
|
||||
}
|
||||
else if (g_kpp.is_ok(index))
|
||||
{
|
||||
assert(g_kpp.fromIndex(index).toIndex() == index);
|
||||
|
||||
KPP x = g_kpp.fromIndex(index);
|
||||
KPP a[KPP_LOWER_COUNT];
|
||||
x.toLowerDimensions(a);
|
||||
|
||||
assert(a[0].toIndex() == index);
|
||||
|
||||
uint64_t min_index = UINT64_MAX;
|
||||
for (auto& e : a)
|
||||
min_index = std::min(min_index, e.toIndex());
|
||||
min_index_flag[index] = (min_index == index);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void learning_tools_unit_test_kpp()
|
||||
{
|
||||
|
||||
// KPPの三角配列化にバグがないかテストする
|
||||
// k-p0-p1のすべての組み合わせがきちんとKPPの扱う対象になっていかと、そのときの次元下げが
|
||||
// 正しいかを判定する。
|
||||
|
||||
KK g_kk;
|
||||
g_kk.set(SQUARE_NB, Eval::fe_end, 0);
|
||||
KKP g_kkp;
|
||||
g_kkp.set(SQUARE_NB, Eval::fe_end, g_kk.max_index());
|
||||
KPP g_kpp;
|
||||
g_kpp.set(SQUARE_NB, Eval::fe_end, g_kkp.max_index());
|
||||
|
||||
std::vector<bool> f;
|
||||
f.resize(g_kpp.max_index() - g_kpp.min_index());
|
||||
|
||||
for(auto k = SQUARE_ZERO ; k < SQUARE_NB ; ++k)
|
||||
for(auto p0 = BonaPiece::BONA_PIECE_ZERO; p0 < fe_end ; ++p0)
|
||||
for (auto p1 = BonaPiece::BONA_PIECE_ZERO; p1 < fe_end; ++p1)
|
||||
{
|
||||
KPP kpp_org = g_kpp.fromKPP(k,p0,p1);
|
||||
KPP kpp0;
|
||||
KPP kpp1 = g_kpp.fromKPP(Mir(k), mir_piece(p0), mir_piece(p1));
|
||||
KPP kpp_array[2];
|
||||
|
||||
auto index = kpp_org.toIndex();
|
||||
assert(g_kpp.is_ok(index));
|
||||
|
||||
kpp0 = g_kpp.fromIndex(index);
|
||||
|
||||
//if (kpp0 != kpp_org)
|
||||
// std::cout << "index = " << index << "," << kpp_org << "," << kpp0 << std::endl;
|
||||
|
||||
kpp0.toLowerDimensions(kpp_array);
|
||||
|
||||
assert(kpp_array[0] == kpp0);
|
||||
assert(kpp0 == kpp_org);
|
||||
assert(kpp_array[1] == kpp1);
|
||||
|
||||
auto index2 = kpp1.toIndex();
|
||||
f[index - g_kpp.min_index()] = f[index2-g_kpp.min_index()] = true;
|
||||
}
|
||||
|
||||
// 抜けてるindexがなかったかの確認。
|
||||
for(size_t index = 0 ; index < f.size(); index++)
|
||||
if (!f[index])
|
||||
{
|
||||
std::cout << index << g_kpp.fromIndex(index + g_kpp.min_index()) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void learning_tools_unit_test_kppp()
|
||||
{
|
||||
// KPPPの計算に抜けがないかをテストする
|
||||
|
||||
KPPP g_kppp;
|
||||
g_kppp.set(15, Eval::fe_end,0);
|
||||
uint64_t min_index = g_kppp.min_index();
|
||||
uint64_t max_index = g_kppp.max_index();
|
||||
|
||||
// 最後の要素の確認。
|
||||
//KPPP x = KPPP::fromIndex(max_index-1);
|
||||
//std::cout << x << std::endl;
|
||||
|
||||
for (uint64_t index = min_index; index < max_index; ++index)
|
||||
{
|
||||
KPPP x = g_kppp.fromIndex(index);
|
||||
//std::cout << x << std::endl;
|
||||
|
||||
#if 0
|
||||
if ((index % 10000000) == 0)
|
||||
std::cout << "index = " << index << std::endl;
|
||||
|
||||
// index = 9360000000
|
||||
// done.
|
||||
|
||||
if (x.toIndex() != index)
|
||||
{
|
||||
std::cout << "assertion failed , index = " << index << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
assert(x.toIndex() == index);
|
||||
|
||||
// ASSERT((&kppp_ksq_pcpcpc(x.king(), x.piece0(), x.piece1(), x.piece2()) - &kppp[0][0]) == (index - min_index));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void learning_tools_unit_test_kkpp()
|
||||
{
|
||||
KKPP g_kkpp;
|
||||
g_kkpp.set(SQUARE_NB, 10000 , 0);
|
||||
uint64_t n = 0;
|
||||
for (int k = 0; k<SQUARE_NB; ++k)
|
||||
for (int i = 0; i<10000; ++i) // 試しに、かなり大きなfe_endを想定して10000で回してみる。
|
||||
for (int j = 0; j < i; ++j)
|
||||
{
|
||||
auto kkpp = g_kkpp.fromKKPP(k, (BonaPiece)i, (BonaPiece)j);
|
||||
auto r = kkpp.toRawIndex();
|
||||
assert(n++ == r);
|
||||
auto kkpp2 = g_kkpp.fromIndex(r + g_kkpp.min_index());
|
||||
assert(kkpp2.king() == k && kkpp2.piece0() == i && kkpp2.piece1() == j);
|
||||
}
|
||||
}
|
||||
|
||||
// このEvalLearningTools全体の初期化
|
||||
void init()
|
||||
{
|
||||
// 初期化は、起動後1回限りで良いのでそのためのフラグ。
|
||||
static bool first = true;
|
||||
|
||||
if (first)
|
||||
{
|
||||
std::cout << "EvalLearningTools init..";
|
||||
|
||||
// mir_piece()とinv_piece()を利用可能にする。
|
||||
// このあとmin_index_flagの初期化を行なうが、そこが
|
||||
// これに依存しているので、こちらを先に行なう必要がある。
|
||||
init_mir_inv_tables();
|
||||
|
||||
//learning_tools_unit_test_kpp();
|
||||
//learning_tools_unit_test_kppp();
|
||||
//learning_tools_unit_test_kkpp();
|
||||
|
||||
// UnitTestを実行するの最後でも良いのだが、init_min_index_flag()にとても時間がかかるので
|
||||
// デバッグ時はこのタイミングで行いたい。
|
||||
|
||||
init_min_index_flag();
|
||||
|
||||
std::cout << "done." << std::endl;
|
||||
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
1032
src/learn/learning_tools.h
Normal file
1032
src/learn/learning_tools.h
Normal file
File diff suppressed because it is too large
Load diff
123
src/learn/multi_think.cpp
Normal file
123
src/learn/multi_think.cpp
Normal file
|
@ -0,0 +1,123 @@
|
|||
#include "../types.h"
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
|
||||
#include "multi_think.h"
|
||||
#include "../tt.h"
|
||||
#include "../uci.h"
|
||||
|
||||
#include <thread>
|
||||
|
||||
void MultiThink::go_think()
|
||||
{
|
||||
// あとでOptionsの設定を復元するためにコピーで保持しておく。
|
||||
auto oldOptions = Options;
|
||||
|
||||
// 定跡を用いる場合、on the flyで行なうとすごく時間がかかる&ファイルアクセスを行なう部分が
|
||||
// thread safeではないので、メモリに丸読みされている状態であることをここで保証する。
|
||||
Options["BookOnTheFly"] = std::string("false");
|
||||
|
||||
// 評価関数の読み込み等
|
||||
// learnコマンドの場合、評価関数読み込み後に評価関数の値を補正している可能性があるので、
|
||||
// メモリの破損チェックは省略する。
|
||||
is_ready(true);
|
||||
|
||||
// 派生クラスのinit()を呼び出す。
|
||||
init();
|
||||
|
||||
// ループ上限はset_loop_max()で設定されているものとする。
|
||||
loop_count = 0;
|
||||
done_count = 0;
|
||||
|
||||
// threadをOptions["Threads"]の数だけ生成して思考開始。
|
||||
std::vector<std::thread> threads;
|
||||
auto thread_num = (size_t)Options["Threads"];
|
||||
|
||||
// worker threadの終了フラグの確保
|
||||
thread_finished.resize(thread_num);
|
||||
|
||||
// worker threadの起動
|
||||
for (size_t i = 0; i < thread_num; ++i)
|
||||
{
|
||||
thread_finished[i] = 0;
|
||||
threads.push_back(std::thread([i, this]
|
||||
{
|
||||
// プロセッサの全スレッドを使い切る。
|
||||
WinProcGroup::bindThisThread(i);
|
||||
|
||||
// オーバーライドされている処理を実行
|
||||
this->thread_worker(i);
|
||||
|
||||
// スレッドが終了したので終了フラグを立てる
|
||||
this->thread_finished[i] = 1;
|
||||
}));
|
||||
}
|
||||
|
||||
// すべてのthreadの終了待ちを
|
||||
// for (auto& th : threads)
|
||||
// th.join();
|
||||
// のように書くとスレッドがまだ仕事をしている状態でここに突入するので、
|
||||
// その間、callback_func()が呼び出せず、セーブできなくなる。
|
||||
// そこで終了フラグを自前でチェックする必要がある。
|
||||
|
||||
// すべてのスレッドが終了したかを判定する関数
|
||||
auto threads_done = [&]()
|
||||
{
|
||||
// ひとつでも終了していなければfalseを返す
|
||||
for (auto& f : thread_finished)
|
||||
if (!f)
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
// コールバック関数が設定されているならコールバックする。
|
||||
auto do_a_callback = [&]()
|
||||
{
|
||||
if (callback_func)
|
||||
callback_func();
|
||||
};
|
||||
|
||||
|
||||
for (uint64_t i = 0 ; ; )
|
||||
{
|
||||
// 全スレッドが終了していたら、ループを抜ける。
|
||||
if (threads_done())
|
||||
break;
|
||||
|
||||
sleep(1000);
|
||||
|
||||
// callback_secondsごとにcallback_func()が呼び出される。
|
||||
if (++i == callback_seconds)
|
||||
{
|
||||
do_a_callback();
|
||||
// ↑から戻ってきてからカウンターをリセットしているので、
|
||||
// do_a_callback()のなかでsave()などにどれだけ時間がかかろうと
|
||||
// 次に呼び出すのは、そこから一定時間の経過を要する。
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// 最後の保存。
|
||||
std::cout << std::endl << "finalize..";
|
||||
|
||||
// do_a_callback();
|
||||
// → 呼び出し元で保存するはずで、ここでは要らない気がする。
|
||||
|
||||
// 終了したフラグは立っているがスレッドの終了コードの実行中であるということはありうるので
|
||||
// join()でその終了を待つ必要がある。
|
||||
for (auto& th : threads)
|
||||
th.join();
|
||||
|
||||
// 全スレッドが終了しただけでfileの書き出しスレッドなどはまだ動いていて
|
||||
// 作業自体は完了していない可能性があるのでスレッドがすべて終了したことだけ出力する。
|
||||
std::cout << "all threads are joined." << std::endl;
|
||||
|
||||
// Optionsを書き換えたので復元。
|
||||
// 値を代入しないとハンドラが起動しないのでこうやって復元する。
|
||||
for (auto& s : oldOptions)
|
||||
Options[s.first] = std::string(s.second);
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif // defined(EVAL_LEARN)
|
151
src/learn/multi_think.h
Normal file
151
src/learn/multi_think.h
Normal file
|
@ -0,0 +1,151 @@
|
|||
#ifndef _MULTI_THINK_
|
||||
#define _MULTI_THINK_
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "../misc.h"
|
||||
#include "../learn/learn.h"
|
||||
#include "../thread_win32_osx.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
// 棋譜からの学習や、自ら思考させて定跡を生成するときなど、
|
||||
// 複数スレッドが個別にSearch::think()を呼び出したいときに用いるヘルパクラス。
|
||||
// このクラスを派生させて用いる。
|
||||
struct MultiThink
|
||||
{
|
||||
MultiThink() : prng(21120903)
|
||||
{
|
||||
loop_count = 0;
|
||||
}
|
||||
|
||||
// マスタースレッドからこの関数を呼び出すと、スレッドがそれぞれ思考して、
|
||||
// 思考終了条件を満たしたところで制御を返す。
|
||||
// 他にやってくれること。
|
||||
// ・各スレッドがLearner::search(),qsearch()を呼び出しても安全なように
|
||||
// 置換表をスレッドごとに分離してくれる。(終了後、元に戻してくれる。)
|
||||
// ・bookはon the flyモードだとthread safeではないので、このモードを一時的に
|
||||
// オフにしてくれる。
|
||||
// [要件]
|
||||
// 1) thread_worker()のオーバーライド
|
||||
// 2) set_loop_max()でループ回数の設定
|
||||
// 3) 定期的にcallbackされる関数を設定する(必要なら)
|
||||
// callback_funcとcallback_interval
|
||||
void go_think();
|
||||
|
||||
// 派生クラス側で初期化したいものがあればこれをoverrideしておけば、
|
||||
// go_think()で初期化が終わったタイミングで呼び出される。
|
||||
// 定跡の読み込みなどはそのタイミングで行うと良い。
|
||||
virtual void init() {}
|
||||
|
||||
// go_think()したときにスレッドを生成して呼び出されるthread worker
|
||||
// これをoverrideして用いる。
|
||||
virtual void thread_worker(size_t thread_id) = 0;
|
||||
|
||||
// go_think()したときにcallback_seconds[秒]ごとにcallbackされる。
|
||||
std::function<void()> callback_func;
|
||||
uint64_t callback_seconds = 600;
|
||||
|
||||
// workerが処理する(Search::think()を呼び出す)回数を設定する。
|
||||
void set_loop_max(uint64_t loop_max_) { loop_max = loop_max_; }
|
||||
|
||||
// set_loop_max()で設定した値を取得する。
|
||||
uint64_t get_loop_max() const { return loop_max; }
|
||||
|
||||
// [ASYNC] ループカウンターの値を取り出して、取り出し後にループカウンターを加算する。
|
||||
// もしループカウンターがloop_maxに達していたらUINT64_MAXを返す。
|
||||
// 局面を生成する場合などは、局面を生成するタイミングでこの関数を呼び出すようにしないと、
|
||||
// 生成した局面数と、カウンターの値が一致しなくなってしまうので注意すること。
|
||||
uint64_t get_next_loop_count() {
|
||||
std::unique_lock<Mutex> lk(loop_mutex);
|
||||
if (loop_count >= loop_max)
|
||||
return UINT64_MAX;
|
||||
return loop_count++;
|
||||
}
|
||||
|
||||
// [ASYNC] 処理した個数を返す用。呼び出されるごとにインクリメントされたカウンターが返る。
|
||||
uint64_t get_done_count() {
|
||||
std::unique_lock<Mutex> lk(loop_mutex);
|
||||
return ++done_count;
|
||||
}
|
||||
|
||||
// worker threadがI/Oにアクセスするときのmutex
|
||||
Mutex io_mutex;
|
||||
|
||||
protected:
|
||||
// 乱数発生器本体
|
||||
AsyncPRNG prng;
|
||||
|
||||
private:
|
||||
// workerが処理する(Search::think()を呼び出す)回数
|
||||
std::atomic<uint64_t> loop_max;
|
||||
// workerが処理した(Search::think()を呼び出した)回数
|
||||
std::atomic<uint64_t> loop_count;
|
||||
// 処理した回数を返す用。
|
||||
std::atomic<uint64_t> done_count;
|
||||
|
||||
// ↑の変数を変更するときのmutex
|
||||
Mutex loop_mutex;
|
||||
|
||||
// スレッドの終了フラグ。
|
||||
// vector<bool>にすると複数スレッドから書き換えようとしたときに正しく反映されないことがある…はず。
|
||||
typedef uint8_t Flag;
|
||||
std::vector<Flag> thread_finished;
|
||||
|
||||
};
|
||||
|
||||
// idle時間にtaskを処理する仕組み。
|
||||
// masterは好きなときにpush_task_async()でtaskを渡す。
|
||||
// slaveは暇なときにon_idle()を実行すると、taskを一つ取り出してqueueがなくなるまで実行を続ける。
|
||||
// MultiThinkのthread workerをmaster-slave方式で書きたいときに用いると便利。
|
||||
struct TaskDispatcher
|
||||
{
|
||||
typedef std::function<void(size_t /* thread_id */)> Task;
|
||||
|
||||
// slaveはidle中にこの関数を呼び出す。
|
||||
void on_idle(size_t thread_id)
|
||||
{
|
||||
Task task;
|
||||
while ((task = get_task_async()) != nullptr)
|
||||
task(thread_id);
|
||||
|
||||
sleep(1);
|
||||
}
|
||||
|
||||
// [ASYNC] taskを一つ積む。
|
||||
void push_task_async(Task task)
|
||||
{
|
||||
std::unique_lock<Mutex> lk(task_mutex);
|
||||
tasks.push_back(task);
|
||||
}
|
||||
|
||||
// task用の配列の要素をsize分だけ事前に確保する。
|
||||
void task_reserve(size_t size)
|
||||
{
|
||||
tasks.reserve(size);
|
||||
}
|
||||
|
||||
protected:
|
||||
// taskの集合
|
||||
std::vector<Task> tasks;
|
||||
|
||||
// [ASYNC] taskを一つ取り出す。on_idle()から呼び出される。
|
||||
Task get_task_async()
|
||||
{
|
||||
std::unique_lock<Mutex> lk(task_mutex);
|
||||
if (tasks.size() == 0)
|
||||
return nullptr;
|
||||
Task task = *tasks.rbegin();
|
||||
tasks.pop_back();
|
||||
return task;
|
||||
}
|
||||
|
||||
// tasksにアクセスするとき用のmutex
|
||||
Mutex task_mutex;
|
||||
};
|
||||
|
||||
#endif // defined(EVAL_LEARN) && defined(YANEURAOU_2018_OTAFUKU_ENGINE)
|
||||
|
||||
#endif
|
146
src/misc.cpp
146
src/misc.cpp
|
@ -42,6 +42,7 @@ typedef bool(*fun3_t)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY);
|
|||
#endif
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
@ -316,6 +317,27 @@ void bindThisThread(size_t idx) {
|
|||
|
||||
} // namespace WinProcGroup
|
||||
|
||||
// 現在時刻を文字列化したもを返す。(評価関数の学習時などに用いる)
|
||||
std::string now_string()
|
||||
{
|
||||
// std::ctime(), localtime()を使うと、MSVCでセキュアでないという警告が出る。
|
||||
// C++標準的にはそんなことないはずなのだが…。
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
// C4996 : 'ctime' : This function or variable may be unsafe.Consider using ctime_s instead.
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto tp = std::chrono::system_clock::to_time_t(now);
|
||||
auto result = string(std::ctime(&tp));
|
||||
|
||||
// 末尾に改行コードが含まれているならこれを除去する
|
||||
while (*result.rbegin() == '\n' || (*result.rbegin() == '\r'))
|
||||
result.pop_back();
|
||||
return result;
|
||||
}
|
||||
|
||||
void sleep(int ms)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(ms));
|
||||
|
@ -331,3 +353,127 @@ void* aligned_malloc(size_t size, size_t align)
|
|||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
int read_file_to_memory(std::string filename, std::function<void* (uint64_t)> callback_func)
|
||||
{
|
||||
fstream fs(filename, ios::in | ios::binary);
|
||||
if (fs.fail())
|
||||
return 1;
|
||||
|
||||
fs.seekg(0, fstream::end);
|
||||
uint64_t eofPos = (uint64_t)fs.tellg();
|
||||
fs.clear(); // これをしないと次のseekに失敗することがある。
|
||||
fs.seekg(0, fstream::beg);
|
||||
uint64_t begPos = (uint64_t)fs.tellg();
|
||||
uint64_t file_size = eofPos - begPos;
|
||||
//std::cout << "filename = " << filename << " , file_size = " << file_size << endl;
|
||||
|
||||
// ファイルサイズがわかったのでcallback_funcを呼び出してこの分のバッファを確保してもらい、
|
||||
// そのポインターをもらう。
|
||||
void* ptr = callback_func(file_size);
|
||||
|
||||
// バッファが確保できなかった場合や、想定していたファイルサイズと異なった場合は、
|
||||
// nullptrを返すことになっている。このとき、読み込みを中断し、エラーリターンする。
|
||||
if (ptr == nullptr)
|
||||
return 2;
|
||||
|
||||
// 細切れに読み込む
|
||||
|
||||
const uint64_t block_size = 1024 * 1024 * 1024; // 1回のreadで読み込む要素の数(1GB)
|
||||
for (uint64_t pos = 0; pos < file_size; pos += block_size)
|
||||
{
|
||||
// 今回読み込むサイズ
|
||||
uint64_t read_size = (pos + block_size < file_size) ? block_size : (file_size - pos);
|
||||
fs.read((char*)ptr + pos, read_size);
|
||||
|
||||
// ファイルの途中で読み込みエラーに至った。
|
||||
if (fs.fail())
|
||||
return 2;
|
||||
|
||||
//cout << ".";
|
||||
}
|
||||
fs.close();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int write_memory_to_file(std::string filename, void* ptr, uint64_t size)
|
||||
{
|
||||
fstream fs(filename, ios::out | ios::binary);
|
||||
if (fs.fail())
|
||||
return 1;
|
||||
|
||||
const uint64_t block_size = 1024 * 1024 * 1024; // 1回のwriteで書き出す要素の数(1GB)
|
||||
for (uint64_t pos = 0; pos < size; pos += block_size)
|
||||
{
|
||||
// 今回書き出すメモリサイズ
|
||||
uint64_t write_size = (pos + block_size < size) ? block_size : (size - pos);
|
||||
fs.write((char*)ptr + pos, write_size);
|
||||
//cout << ".";
|
||||
}
|
||||
fs.close();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// mkdir wrapper
|
||||
// ----------------------------
|
||||
|
||||
// カレントフォルダ相対で指定する。成功すれば0、失敗すれば非0が返る。
|
||||
// フォルダを作成する。日本語は使っていないものとする。
|
||||
// どうもmsys2環境下のgccだと_wmkdir()だとフォルダの作成に失敗する。原因不明。
|
||||
// 仕方ないので_mkdir()を用いる。
|
||||
|
||||
#if defined(_WIN32)
|
||||
// Windows用
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <codecvt> // mkdirするのにwstringが欲しいのでこれが必要
|
||||
#include <locale> // wstring_convertにこれが必要。
|
||||
|
||||
namespace Dependency {
|
||||
int mkdir(std::string dir_name)
|
||||
{
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> cv;
|
||||
return _wmkdir(cv.from_bytes(dir_name).c_str());
|
||||
// ::CreateDirectory(cv.from_bytes(dir_name).c_str(),NULL);
|
||||
}
|
||||
}
|
||||
|
||||
#elif defined(__GNUC__)
|
||||
|
||||
#include <direct.h>
|
||||
namespace Dependency {
|
||||
int mkdir(std::string dir_name)
|
||||
{
|
||||
return _mkdir(dir_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#elif defined(_LINUX)
|
||||
|
||||
// linux環境において、この_LINUXというシンボルはmakefileにて定義されるものとする。
|
||||
|
||||
// Linux用のmkdir実装。
|
||||
#include "sys/stat.h"
|
||||
|
||||
namespace Dependency {
|
||||
int mkdir(std::string dir_name)
|
||||
{
|
||||
return ::mkdir(dir_name.c_str(), 0777);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
// Linux環境かどうかを判定するためにはmakefileを分けないといけなくなってくるな..
|
||||
// linuxでフォルダ掘る機能は、とりあえずナシでいいや..。評価関数ファイルの保存にしか使ってないし…。
|
||||
|
||||
namespace Dependency {
|
||||
int mkdir(std::string dir_name)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
100
src/misc.h
100
src/misc.h
|
@ -24,11 +24,13 @@
|
|||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "types.h"
|
||||
#include "thread_win32_osx.h"
|
||||
|
||||
const std::string engine_info(bool to_uci = false);
|
||||
void prefetch(void* addr);
|
||||
|
@ -98,8 +100,20 @@ public:
|
|||
/// Output values only have 1/8th of their bits set on average.
|
||||
template<typename T> T sparse_rand()
|
||||
{ return T(rand64() & rand64() & rand64()); }
|
||||
|
||||
// 0からn-1までの乱数を返す。(一様分布ではないが現実的にはこれで十分)
|
||||
uint64_t rand(uint64_t n) { return rand<uint64_t>() % n; }
|
||||
|
||||
// 内部で使用している乱数seedを返す。
|
||||
uint64_t get_seed() const { return s; }
|
||||
};
|
||||
|
||||
// 乱数のseedを表示する。(デバッグ用)
|
||||
inline std::ostream& operator<<(std::ostream& os, PRNG& prng)
|
||||
{
|
||||
os << "PRNG::seed = " << std::hex << prng.get_seed() << std::dec;
|
||||
return os;
|
||||
}
|
||||
|
||||
/// Under Windows it is not possible for a process to run on more than one
|
||||
/// logical processor group. This usually means to be limited to use max 64
|
||||
|
@ -114,6 +128,9 @@ namespace WinProcGroup {
|
|||
// 指定されたミリ秒だけsleepする。
|
||||
extern void sleep(int ms);
|
||||
|
||||
// 現在時刻を文字列化したもを返す。(評価関数の学習時などにログ出力のために用いる)
|
||||
std::string now_string();
|
||||
|
||||
// 途中での終了処理のためのwrapper
|
||||
static void my_exit()
|
||||
{
|
||||
|
@ -121,6 +138,54 @@ static void my_exit()
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// msys2、Windows Subsystem for Linuxなどのgcc/clangでコンパイルした場合、
|
||||
// C++のstd::ifstreamで::read()は、一発で2GB以上のファイルの読み書きが出来ないのでそのためのwrapperである。
|
||||
//
|
||||
// read_file_to_memory()の引数のcallback_funcは、ファイルがオープン出来た時点でそのファイルサイズを引数として
|
||||
// callbackされるので、バッファを確保して、その先頭ポインタを返す関数を渡すと、そこに読み込んでくれる。
|
||||
// これらの関数は、ファイルが見つからないときなどエラーの際には非0を返す。
|
||||
//
|
||||
// また、callbackされた関数のなかでバッファが確保できなかった場合や、想定していたファイルサイズと異なった場合は、
|
||||
// nullptrを返せば良い。このとき、read_file_to_memory()は、読み込みを中断し、エラーリターンする。
|
||||
|
||||
int read_file_to_memory(std::string filename, std::function<void* (uint64_t)> callback_func);
|
||||
int write_memory_to_file(std::string filename, void* ptr, uint64_t size);
|
||||
|
||||
// --------------------
|
||||
// PRNGのasync版
|
||||
// --------------------
|
||||
|
||||
// PRNGのasync版
|
||||
struct AsyncPRNG
|
||||
{
|
||||
AsyncPRNG(uint64_t seed) : prng(seed) { assert(seed); }
|
||||
// [ASYNC] 乱数を一つ取り出す。
|
||||
template<typename T> T rand() {
|
||||
std::unique_lock<Mutex> lk(mutex);
|
||||
return prng.rand<T>();
|
||||
}
|
||||
|
||||
// [ASYNC] 0からn-1までの乱数を返す。(一様分布ではないが現実的にはこれで十分)
|
||||
uint64_t rand(uint64_t n) {
|
||||
std::unique_lock<Mutex> lk(mutex);
|
||||
return prng.rand(n);
|
||||
}
|
||||
|
||||
// 内部で使用している乱数seedを返す。
|
||||
uint64_t get_seed() const { return prng.get_seed(); }
|
||||
|
||||
protected:
|
||||
Mutex mutex;
|
||||
PRNG prng;
|
||||
};
|
||||
|
||||
// 乱数のseedを表示する。(デバッグ用)
|
||||
inline std::ostream& operator<<(std::ostream& os, AsyncPRNG& prng)
|
||||
{
|
||||
os << "AsyncPRNG::seed = " << std::hex << prng.get_seed() << std::dec;
|
||||
return os;
|
||||
}
|
||||
|
||||
// --------------------
|
||||
// Math
|
||||
// --------------------
|
||||
|
@ -176,4 +241,39 @@ struct Path
|
|||
extern void* aligned_malloc(size_t size, size_t align);
|
||||
static void aligned_free(void* ptr) { _mm_free(ptr); }
|
||||
|
||||
// alignasを指定しているのにnewのときに無視される&STLのコンテナがメモリ確保するときに無視するので、
|
||||
// そのために用いるカスタムアロケーター。
|
||||
template <typename T>
|
||||
class AlignedAllocator {
|
||||
public:
|
||||
using value_type = T;
|
||||
|
||||
AlignedAllocator() {}
|
||||
AlignedAllocator(const AlignedAllocator&) {}
|
||||
AlignedAllocator(AlignedAllocator&&) {}
|
||||
|
||||
template <typename U> AlignedAllocator(const AlignedAllocator<U>&) {}
|
||||
|
||||
T* allocate(std::size_t n) { return (T*)aligned_malloc(n * sizeof(T), alignof(T)); }
|
||||
void deallocate(T* p, std::size_t n) { aligned_free(p); }
|
||||
};
|
||||
|
||||
// --------------------
|
||||
// Dependency Wrapper
|
||||
// --------------------
|
||||
|
||||
namespace Dependency
|
||||
{
|
||||
// Linux環境ではgetline()したときにテキストファイルが'\r\n'だと
|
||||
// '\r'が末尾に残るのでこの'\r'を除去するためにwrapperを書く。
|
||||
// そのため、fstreamに対してgetline()を呼び出すときは、
|
||||
// std::getline()ではなく単にgetline()と書いて、この関数を使うべき。
|
||||
extern bool getline(std::ifstream& fs, std::string& s);
|
||||
|
||||
// フォルダを作成する。
|
||||
// カレントフォルダ相対で指定する。dir_nameに日本語は使っていないものとする。
|
||||
// 成功すれば0、失敗すれば非0が返る。
|
||||
extern int mkdir(std::string dir_name);
|
||||
}
|
||||
|
||||
#endif // #ifndef MISC_H_INCLUDED
|
||||
|
|
|
@ -68,6 +68,9 @@ struct MoveList {
|
|||
return std::find(begin(), end(), move) != end();
|
||||
}
|
||||
|
||||
// i番目の要素を返す
|
||||
const ExtMove at(size_t i) const { assert(0 <= i && i < size()); return begin()[i]; }
|
||||
|
||||
private:
|
||||
ExtMove moveList[MAX_MOVES], *last;
|
||||
};
|
||||
|
|
|
@ -1480,3 +1480,12 @@ PieceNumber Position::piece_no_of(Square sq) const
|
|||
return n;
|
||||
}
|
||||
#endif // defined(EVAL_NNUE)
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
// 現局面で指し手がないかをテストする。指し手生成ルーチンを用いるので速くない。探索中には使わないこと。
|
||||
bool Position::is_mated() const
|
||||
{
|
||||
// 不成で詰めろを回避できるパターンはないのでLEGAL_ALLである必要はない。
|
||||
return MoveList<LEGAL>(*this).size() == 0;
|
||||
}
|
||||
#endif // EVAL_LEARN
|
||||
|
|
|
@ -80,6 +80,9 @@ typedef std::unique_ptr<std::deque<StateInfo>> StateListPtr;
|
|||
/// traversing the search tree.
|
||||
class Thread;
|
||||
|
||||
// packされたsfen
|
||||
struct PackedSfen { uint8_t data[32]; };
|
||||
|
||||
class Position {
|
||||
public:
|
||||
static void init();
|
||||
|
@ -187,6 +190,29 @@ public:
|
|||
const Eval::EvalList* eval_list() const { return &evalList; }
|
||||
#endif // defined(EVAL_NNUE)
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
// 現局面で指し手がないかをテストする。指し手生成ルーチンを用いるので速くない。探索中には使わないこと。
|
||||
bool is_mated() const;
|
||||
|
||||
// -- sfen化ヘルパ
|
||||
|
||||
// packされたsfenを得る。引数に指定したバッファに返す。
|
||||
// gamePlyはpackに含めない。
|
||||
void sfen_pack(PackedSfen& sfen);
|
||||
|
||||
// ↑sfenを経由すると遅いので直接packされたsfenをセットする関数を作った。
|
||||
// pos.set(sfen_unpack(data),si,th); と等価。
|
||||
// 渡された局面に問題があって、エラーのときは非0を返す。
|
||||
// PackedSfenにgamePlyは含まないので復元できない。そこを設定したいのであれば引数で指定すること。
|
||||
int set_from_packed_sfen(const PackedSfen& sfen, StateInfo* si, Thread* th, bool mirror = false);
|
||||
|
||||
// 盤面と手駒、手番を与えて、そのsfenを返す。
|
||||
//static std::string sfen_from_rawdata(Piece board[81], Hand hands[2], Color turn, int gamePly);
|
||||
|
||||
// c側の玉の位置を返す。
|
||||
Square king_square(Color c) const { return pieceList[make_piece(c, KING)][0]; }
|
||||
#endif // EVAL_LEARN
|
||||
|
||||
private:
|
||||
// Initialization helpers (used while setting up a position)
|
||||
void set_castling_right(Color c, Square rfrom);
|
||||
|
|
280
src/search.cpp
280
src/search.cpp
|
@ -1721,3 +1721,283 @@ void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) {
|
|||
m.tbRank = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// --- 学習時に用いる、depth固定探索などの関数を外部に対して公開
|
||||
|
||||
#if defined (EVAL_LEARN)
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
// 学習用に、1つのスレッドからsearch,qsearch()を呼び出せるようなスタブを用意する。
|
||||
// いまにして思えば、AperyのようにSearcherを持ってスレッドごとに置換表などを用意するほうが
|
||||
// 良かったかも知れない。
|
||||
|
||||
// 学習のための初期化。
|
||||
// Learner::search(),Learner::qsearch()から呼び出される。
|
||||
void init_for_search(Position& pos, Stack* ss)
|
||||
{
|
||||
|
||||
// RootNodeはss->ply == 0がその条件。
|
||||
// ゼロクリアするので、ss->ply == 0となるので大丈夫…。
|
||||
|
||||
memset(ss - 4, 0, 7 * sizeof(Stack));
|
||||
|
||||
// Search::Limitsに関して
|
||||
// このメンバー変数はglobalなので他のスレッドに影響を及ぼすので気をつけること。
|
||||
{
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// 探索を"go infinite"コマンド相当にする。(time managementされると困るため)
|
||||
limits.infinite = true;
|
||||
|
||||
// PVを表示されると邪魔なので消しておく。
|
||||
//limits.silent = true;
|
||||
|
||||
// これを用いると各スレッドのnodesを積算したものと比較されてしまう。ゆえに使用しない。
|
||||
limits.nodes = 0;
|
||||
|
||||
// depthも、Learner::search()の引数として渡されたもので処理する。
|
||||
limits.depth = 0;
|
||||
|
||||
// 引き分け付近の手数で引き分けの値が返るのを防ぐために大きな値にしておく。
|
||||
//limits.max_game_ply = 1 << 16;
|
||||
|
||||
// 入玉ルールも入れておかないと引き分けになって決着つきにくい。
|
||||
//limits.enteringKingRule = EnteringKingRule::EKR_27_POINT;
|
||||
}
|
||||
|
||||
// DrawValueの設定
|
||||
{
|
||||
// スレッドごとに用意してないので
|
||||
// 他のスレッドで上書きされかねない。仕方がないが。
|
||||
// どうせそうなるなら、0にすべきだと思う。
|
||||
//drawValueTable[REPETITION_DRAW][BLACK] = VALUE_ZERO;
|
||||
//drawValueTable[REPETITION_DRAW][WHITE] = VALUE_ZERO;
|
||||
}
|
||||
|
||||
// this_threadに関して。
|
||||
{
|
||||
auto th = pos.this_thread();
|
||||
|
||||
th->completedDepth = DEPTH_ZERO;
|
||||
th->selDepth = 0;
|
||||
th->rootDepth = DEPTH_ZERO;
|
||||
|
||||
// 探索ノード数のゼロ初期化
|
||||
th->nodes = 0;
|
||||
|
||||
// history類を全部クリアする。この初期化は少し時間がかかるし、探索の精度はむしろ下がるので善悪はよくわからない。
|
||||
// th->clear();
|
||||
|
||||
for (int i = 4; i > 0; i--)
|
||||
(ss - i)->continuationHistory = &th->continuationHistory[SQUARE_ZERO][NO_PIECE];
|
||||
|
||||
// rootMovesの設定
|
||||
auto& rootMoves = th->rootMoves;
|
||||
|
||||
rootMoves.clear();
|
||||
for (auto m : MoveList<LEGAL>(pos))
|
||||
rootMoves.push_back(Search::RootMove(m));
|
||||
|
||||
assert(!rootMoves.empty());
|
||||
|
||||
//#if defined(USE_GLOBAL_OPTIONS)
|
||||
// 探索スレッドごとの置換表の世代を管理しているはずなので、
|
||||
// 新規の探索であるから、このスレッドに対する置換表の世代を増やす。
|
||||
//TT.new_search(th->thread_id());
|
||||
|
||||
// ↑ここでnew_searchを呼び出すと1手前の探索結果が使えなくて損ということはあるのでは…。
|
||||
// ここでこれはやらずに、呼び出し側で1局ごとにTT.new_search(th->thread_id())をやるべきでは…。
|
||||
|
||||
// → 同一の終局図に至るのを回避したいので、教師生成時には置換表は全スレ共通で使うようにする。
|
||||
//#endif
|
||||
}
|
||||
}
|
||||
|
||||
// 読み筋と評価値のペア。Learner::search(),Learner::qsearch()が返す。
|
||||
typedef std::pair<Value, std::vector<Move> > ValueAndPV;
|
||||
|
||||
// 静止探索。
|
||||
//
|
||||
// 前提条件) pos.set_this_thread(Threads[thread_id])で探索スレッドが設定されていること。
|
||||
// また、Threads.stopが来ると探索を中断してしまうので、そのときのPVは正しくない。
|
||||
// search()から戻ったあと、Threads.stop == trueなら、その探索結果を用いてはならない。
|
||||
// あと、呼び出し前は、Threads.stop == falseの状態で呼び出さないと、探索を中断して返ってしまうので注意。
|
||||
//
|
||||
// 詰まされている場合は、PV配列にMOVE_RESIGNが返る。
|
||||
//
|
||||
// 引数でalpha,betaを指定できるようにしていたが、これがその窓で探索したときの結果を
|
||||
// 置換表に書き込むので、その窓に対して枝刈りが出来るような値が書き込まれて学習のときに
|
||||
// 悪い影響があるので、窓の範囲を指定できるようにするのをやめることにした。
|
||||
ValueAndPV qsearch(Position& pos)
|
||||
{
|
||||
Stack stack[MAX_PLY + 7], * ss = stack + 4;
|
||||
Move pv[MAX_PLY + 1];
|
||||
std::vector<Move> pvs;
|
||||
|
||||
init_for_search(pos, ss);
|
||||
ss->pv = pv; // とりあえずダミーでどこかバッファがないといけない。
|
||||
|
||||
// 詰まされているのか
|
||||
if (pos.is_mated())
|
||||
{
|
||||
pvs.push_back(MOVE_NONE);
|
||||
return ValueAndPV(mated_in(/*ss->ply*/ 0 + 1), pvs);
|
||||
}
|
||||
|
||||
auto bestValue = ::qsearch<PV>(pos, ss, -VALUE_INFINITE, VALUE_INFINITE, DEPTH_ZERO);
|
||||
|
||||
// 得られたPVを返す。
|
||||
for (Move* p = &ss->pv[0]; is_ok(*p); ++p)
|
||||
pvs.push_back(*p);
|
||||
|
||||
return ValueAndPV(bestValue, pvs);
|
||||
}
|
||||
|
||||
// 通常探索。深さdepth(整数で指定)。
|
||||
// 3手読み時のスコアが欲しいなら、
|
||||
// auto v = search(pos,3);
|
||||
// のようにすべし。
|
||||
// v.firstに評価値、v.secondにPVが得られる。
|
||||
// multi pvが有効のときは、pos.this_thread()->rootMoves[N].pvにそのPV(読み筋)の配列が得られる。
|
||||
// multi pvの指定はこの関数の引数multiPVで行なう。(Options["MultiPV"]の値は無視される)
|
||||
//
|
||||
// rootでの宣言勝ち判定はしないので(扱いが面倒なので)、ここでは行わない。
|
||||
// 呼び出し側で処理すること。
|
||||
//
|
||||
// 前提条件) pos.set_this_thread(Threads[thread_id])で探索スレッドが設定されていること。
|
||||
// また、Threads.stopが来ると探索を中断してしまうので、そのときのPVは正しくない。
|
||||
// search()から戻ったあと、Threads.stop == trueなら、その探索結果を用いてはならない。
|
||||
// あと、呼び出し前は、Threads.stop == falseの状態で呼び出さないと、探索を中断して返ってしまうので注意。
|
||||
|
||||
ValueAndPV search(Position& pos, int depth_, size_t multiPV /* = 1 */, uint64_t nodesLimit /* = 0 */)
|
||||
{
|
||||
std::vector<Move> pvs;
|
||||
|
||||
Depth depth = depth_ * ONE_PLY;
|
||||
if (depth < DEPTH_ZERO)
|
||||
return std::pair<Value, std::vector<Move>>(Eval::evaluate(pos), std::vector<Move>());
|
||||
|
||||
if (depth == DEPTH_ZERO)
|
||||
return qsearch(pos);
|
||||
|
||||
Stack stack[MAX_PLY + 7], * ss = stack + 4;
|
||||
Move pv[MAX_PLY + 1];
|
||||
|
||||
init_for_search(pos, ss);
|
||||
|
||||
ss->pv = pv; // とりあえずダミーでどこかバッファがないといけない。
|
||||
|
||||
// this_threadに関連する変数の初期化
|
||||
auto th = pos.this_thread();
|
||||
auto& rootDepth = th->rootDepth;
|
||||
auto& pvIdx = th->pvIdx;
|
||||
auto& rootMoves = th->rootMoves;
|
||||
auto& completedDepth = th->completedDepth;
|
||||
auto& selDepth = th->selDepth;
|
||||
|
||||
// bestmoveとしてしこの局面の上位N個を探索する機能
|
||||
//size_t multiPV = Options["MultiPV"];
|
||||
|
||||
// この局面での指し手の数を上回ってはいけない
|
||||
multiPV = std::min(multiPV, rootMoves.size());
|
||||
|
||||
// ノード制限にMultiPVの値を掛けておかないと、depth固定、MultiPVありにしたときに1つの候補手に同じnodeだけ思考したことにならない。
|
||||
nodesLimit *= multiPV;
|
||||
|
||||
Value alpha = -VALUE_INFINITE;
|
||||
Value beta = VALUE_INFINITE;
|
||||
Value delta = -VALUE_INFINITE;
|
||||
Value bestValue = -VALUE_INFINITE;
|
||||
|
||||
while ((rootDepth += ONE_PLY) <= depth
|
||||
// node制限を超えた場合もこのループを抜ける
|
||||
// 探索ノード数は、この関数の引数で渡されている。
|
||||
&& !(nodesLimit /*node制限あり*/ && th->nodes.load(std::memory_order_relaxed) >= nodesLimit)
|
||||
)
|
||||
{
|
||||
for (RootMove& rm : rootMoves)
|
||||
rm.previousScore = rm.score;
|
||||
|
||||
// MultiPV
|
||||
for (pvIdx = 0; pvIdx < multiPV && !Threads.stop; ++pvIdx)
|
||||
{
|
||||
// それぞれのdepthとPV lineに対するUSI infoで出力するselDepth
|
||||
selDepth = 0;
|
||||
|
||||
// depth 5以上においてはaspiration searchに切り替える。
|
||||
if (rootDepth >= 5 * ONE_PLY)
|
||||
{
|
||||
delta = Value(20);
|
||||
|
||||
Value p = rootMoves[pvIdx].previousScore;
|
||||
|
||||
alpha = std::max(p - delta, -VALUE_INFINITE);
|
||||
beta = std::min(p + delta, VALUE_INFINITE);
|
||||
}
|
||||
|
||||
// aspiration search
|
||||
int failedHighCnt = 0;
|
||||
while (true)
|
||||
{
|
||||
Depth adjustedDepth = std::max(ONE_PLY, rootDepth - failedHighCnt * ONE_PLY);
|
||||
bestValue = ::search<PV>(pos, ss, alpha, beta, adjustedDepth, false);
|
||||
|
||||
stable_sort(rootMoves.begin() + pvIdx, rootMoves.end());
|
||||
//my_stable_sort(pos.this_thread()->thread_id(),&rootMoves[0] + pvIdx, rootMoves.size() - pvIdx);
|
||||
|
||||
// fail low/highに対してaspiration windowを広げる。
|
||||
// ただし、引数で指定されていた値になっていたら、もうfail low/high扱いとしてbreakする。
|
||||
if (bestValue <= alpha)
|
||||
{
|
||||
beta = (alpha + beta) / 2;
|
||||
alpha = std::max(bestValue - delta, -VALUE_INFINITE);
|
||||
|
||||
failedHighCnt = 0;
|
||||
//if (mainThread)
|
||||
// mainThread->stopOnPonderhit = false;
|
||||
|
||||
}
|
||||
else if (bestValue >= beta)
|
||||
{
|
||||
beta = std::min(bestValue + delta, VALUE_INFINITE);
|
||||
++failedHighCnt;
|
||||
}
|
||||
else
|
||||
break;
|
||||
|
||||
delta += delta / 4 + 5;
|
||||
assert(-VALUE_INFINITE <= alpha && beta <= VALUE_INFINITE);
|
||||
|
||||
// 暴走チェック
|
||||
//assert(th->nodes.load(std::memory_order_relaxed) <= 1000000 );
|
||||
}
|
||||
|
||||
stable_sort(rootMoves.begin(), rootMoves.begin() + pvIdx + 1);
|
||||
//my_stable_sort(pos.this_thread()->thread_id() , &rootMoves[0] , pvIdx + 1);
|
||||
|
||||
} // multi PV
|
||||
|
||||
completedDepth = rootDepth;
|
||||
}
|
||||
|
||||
// このPV、途中でNULL_MOVEの可能性があるかも知れないので排除するためにis_ok()を通す。
|
||||
// → PVなのでNULL_MOVEはしないことになっているはずだし、
|
||||
// MOVE_WINも突っ込まれていることはない。(いまのところ)
|
||||
for (Move move : rootMoves[0].pv)
|
||||
{
|
||||
if (!is_ok(move))
|
||||
break;
|
||||
pvs.push_back(move);
|
||||
}
|
||||
|
||||
//sync_cout << rootDepth << sync_endl;
|
||||
|
||||
// multiPV時を考慮して、rootMoves[0]のscoreをbestValueとして返す。
|
||||
bestValue = rootMoves[0].score;
|
||||
|
||||
return ValueAndPV(bestValue, pvs);
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
|
14
src/types.h
14
src/types.h
|
@ -235,8 +235,8 @@ enum Square : int {
|
|||
SQ_A8, SQ_B8, SQ_C8, SQ_D8, SQ_E8, SQ_F8, SQ_G8, SQ_H8,
|
||||
SQ_NONE,
|
||||
|
||||
SQUARE_NB = 64,
|
||||
SQUARE_NB_PLUS1 = SQUARE_NB + 1, // 玉がいない場合、SQ_NBに移動したものとして扱うため、配列をSQ_NB+1で確保しないといけないときがあるのでこの定数を用いる。
|
||||
SQUARE_ZERO = 0, SQUARE_NB = 64,
|
||||
SQUARE_NB_PLUS1 = SQUARE_NB + 1, // 玉がいない場合、SQUARE_NBに移動したものとして扱うため、配列をSQUARE_NB+1で確保しないといけないときがあるのでこの定数を用いる。
|
||||
};
|
||||
|
||||
enum Direction : int {
|
||||
|
@ -362,10 +362,6 @@ constexpr Square operator~(Square s) {
|
|||
return Square(s ^ SQ_A8); // Vertical flip SQ_A1 -> SQ_A8
|
||||
}
|
||||
|
||||
constexpr Square inverse(Square s) {
|
||||
return static_cast<Square>(static_cast<int>(SQUARE_NB) - s - 1);
|
||||
}
|
||||
|
||||
constexpr File operator~(File f) {
|
||||
return File(f ^ FILE_H); // Horizontal flip FILE_A -> FILE_H
|
||||
}
|
||||
|
@ -464,6 +460,12 @@ constexpr bool is_ok(Move m) {
|
|||
return from_sq(m) != to_sq(m); // Catch MOVE_NULL and MOVE_NONE
|
||||
}
|
||||
|
||||
// 盤面を180°回したときの升目を返す
|
||||
constexpr Square Inv(Square sq) { return (Square)((SQUARE_NB - 1) - sq); }
|
||||
|
||||
// 盤面をミラーしたときの升目を返す
|
||||
constexpr Square Mir(Square sq) { return make_square(File(7 - (int)file_of(sq)), rank_of(sq)); }
|
||||
|
||||
#if defined(EVAL_NNUE)
|
||||
// --------------------
|
||||
// ‹î”
|
||||
|
|
12
src/uci.cpp
12
src/uci.cpp
|
@ -37,12 +37,10 @@ using namespace std;
|
|||
|
||||
extern vector<string> setup_bench(const Position&, istream&);
|
||||
|
||||
namespace {
|
||||
|
||||
// FEN string of the initial position, normal chess
|
||||
const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
|
||||
|
||||
|
||||
namespace {
|
||||
// position() is called when engine receives the "position" UCI command.
|
||||
// The function sets up the position described in the given FEN string ("fen")
|
||||
// or the starting position ("startpos") and then makes the moves given in the
|
||||
|
@ -179,10 +177,11 @@ namespace {
|
|||
|
||||
// check sumを計算したとき、それを保存しておいてあとで次回以降、整合性のチェックを行なう。
|
||||
uint64_t eval_sum;
|
||||
} // namespace
|
||||
|
||||
// is_ready_cmd()を外部から呼び出せるようにしておく。(benchコマンドなどから呼び出したいため)
|
||||
// 局面は初期化されないので注意。
|
||||
void is_ready(Position& pos, istringstream& is, StateListPtr& states)
|
||||
void is_ready(bool skipCorruptCheck)
|
||||
{
|
||||
#if defined(EVAL_NNUE)
|
||||
// "isready"を受け取ったあと、"readyok"を返すまで5秒ごとに改行を送るように修正する。(keep alive的な処理)
|
||||
|
@ -226,7 +225,7 @@ namespace {
|
|||
{
|
||||
// メモリが破壊されていないかを調べるためにチェックサムを毎回調べる。
|
||||
// 時間が少しもったいない気もするが.. 0.1秒ぐらいのことなので良しとする。
|
||||
if (eval_sum != Eval::calc_check_sum())
|
||||
if (!skipCorruptCheck && eval_sum != Eval::calc_check_sum())
|
||||
sync_cout << "Error! : EVAL memory is corrupted" << sync_endl;
|
||||
}
|
||||
|
||||
|
@ -246,7 +245,6 @@ namespace {
|
|||
|
||||
sync_cout << "readyok" << sync_endl;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
||||
/// UCI::loop() waits for a command from stdin, parses it and calls the appropriate
|
||||
|
@ -296,7 +294,7 @@ void UCI::loop(int argc, char* argv[]) {
|
|||
else if (token == "go") go(pos, is, states);
|
||||
else if (token == "position") position(pos, is, states);
|
||||
else if (token == "ucinewgame") Search::clear();
|
||||
else if (token == "isready") is_ready(pos, is, states);
|
||||
else if (token == "isready") is_ready();
|
||||
|
||||
// Additional custom non-UCI commands, mainly for debugging
|
||||
else if (token == "flip") pos.flip();
|
||||
|
|
|
@ -81,4 +81,12 @@ extern bool load_eval_finished; // = false;
|
|||
|
||||
extern UCI::OptionsMap Options;
|
||||
|
||||
// USIの"isready"コマンドが呼び出されたときの処理。このときに評価関数の読み込みなどを行なう。
|
||||
// benchmarkコマンドのハンドラなどで"isready"が来ていないときに評価関数を読み込ませたいときに用いる。
|
||||
// skipCorruptCheck == trueのときは評価関数の2度目の読み込みのときのcheck sumによるメモリ破損チェックを省略する。
|
||||
// ※ この関数は、Stockfishにはないがないと不便なので追加しておく。
|
||||
void is_ready(bool skipCorruptCheck = false);
|
||||
|
||||
extern const char* StartFEN;
|
||||
|
||||
#endif // #ifndef UCI_H_INCLUDED
|
||||
|
|
Loading…
Add table
Reference in a new issue