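// Repository page header captured together with the file (not part of the
// original source):
//   1
//   0
//   Fork 0
//   mirror of https://github.com/sockspls/badfish synced 2025-05-03 01:59:36 +00:00
//   BadFish/src/eval/nnue/trainer/trainer.h
//   2019-06-09 10:40:12 +09:00
//
//   127 lines
//   3.7 KiB
//   C++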

// Common header for the class templates used to train the NNUE evaluation function
#ifndef _NNUE_TRAINER_H_
#define _NNUE_TRAINER_H_
#include "../../../config.h"
#if defined(EVAL_LEARN) && defined(EVAL_NNUE)
#include "../nnue_common.h"
#include "../features/index_list.h"
#include <sstream>
#if defined(USE_BLAS)
static_assert(std::is_same<LearnFloatType, float>::value, "");
#include <cblas.h>
#endif
namespace Eval {
namespace NNUE {
// Ponanza constant used in the formula that relates evaluation values to
// winning rates.
constexpr double kPonanzaConstant = 600.0;
// 学習用特徴量のインデックス1つを表すクラス
class TrainingFeature {
using StorageType = std::uint32_t;
static_assert(std::is_unsigned<StorageType>::value, "");
public:
static constexpr std::uint32_t kIndexBits = 24;
static_assert(kIndexBits < std::numeric_limits<StorageType>::digits, "");
static constexpr std::uint32_t kCountBits =
std::numeric_limits<StorageType>::digits - kIndexBits;
explicit TrainingFeature(IndexType index) :
index_and_count_((index << kCountBits) | 1) {
ASSERT_LV3(index < (1 << kIndexBits));
}
TrainingFeature& operator+=(const TrainingFeature& other) {
ASSERT_LV3(other.GetIndex() == GetIndex());
ASSERT_LV3(other.GetCount() + GetCount() < (1 << kCountBits));
index_and_count_ += other.GetCount();
return *this;
}
IndexType GetIndex() const {
return static_cast<IndexType>(index_and_count_ >> kCountBits);
}
void ShiftIndex(IndexType offset) {
ASSERT_LV3(GetIndex() + offset < (1 << kIndexBits));
index_and_count_ += offset << kCountBits;
}
IndexType GetCount() const {
return static_cast<IndexType>(index_and_count_ & ((1 << kCountBits) - 1));
}
bool operator<(const TrainingFeature& other) const {
return index_and_count_ < other.index_and_count_;
}
private:
StorageType index_and_count_;
};
// One sample of training data.
struct Example {
// Two lists of training features. NOTE(review): the meaning of the two slots
// is not visible here - presumably one list per side/perspective; confirm
// against the code that fills them.
std::vector<TrainingFeature> training_features[2];
// The packed position/value record this sample refers to, as declared by
// Learner.
Learner::PackedSfenValue psv;
// NOTE(review): presumably a +1/-1 factor applied to the sample's score;
// confirm against the code that sets it.
int sign;
// Weight attached to this sample. NOTE(review): how it is applied is decided
// by the trainers, which are not visible here.
double weight;
};
// A message passed to trainers, used for things such as setting
// hyperparameters. `name` identifies the intended recipient(s) and `value`
// carries the optional payload; both are fixed at construction.
struct Message {
  // Creates a message called `name` with an optional `value`. Both counters
  // start at zero.
  Message(const std::string& name, const std::string& value = "")
      : name(name), value(value) {}

  const std::string name;
  const std::string value;
  // Number of times ReceiveMessage() saw this message addressed with a
  // subscript ("name[...]") matching the caller's name.
  std::uint32_t num_peekers = 0;
  // Number of times ReceiveMessage() accepted this message.
  std::uint32_t num_receivers = 0;
};
// Decides whether the recipient called `name` should accept `message`.
//
// The message is accepted when its name is exactly `name`, or is `name`
// followed by "[i]" where i is message->num_peekers as it was on entry.
// Every call whose `name` matches a subscripted message name ("name[...]")
// increments num_peekers, so successive callers using the same `name` are
// matched against subscripts 0, 1, 2, ... in call order.
//
// On acceptance message->num_receivers is incremented and true is returned;
// otherwise false is returned. `message` must not be null.
//
// Declared inline because the definition lives in a header: without it,
// including this header from more than one translation unit produces
// multiple-definition link errors.
inline bool ReceiveMessage(const std::string& name, Message* message) {
  // The subscript must be built before num_peekers is bumped below.
  const auto subscript = "[" + std::to_string(message->num_peekers) + "]";
  const auto subscripted_prefix = name + "[";
  if (message->name.compare(0, subscripted_prefix.size(),
                            subscripted_prefix) == 0) {
    ++message->num_peekers;
  }
  if (message->name == name || message->name == name + subscript) {
    ++message->num_receivers;
    return true;
  }
  return false;
}
// Splits `input` into fields separated by `delimiter`.
//
// Empty fields between two consecutive delimiters are kept. An empty input
// yields an empty vector, and a single trailing delimiter does not produce a
// trailing empty field (this follows from std::getline reaching end of input
// without extracting anything).
//
// Declared inline because the definition lives in a header: without it,
// including this header from more than one translation unit produces
// multiple-definition link errors.
inline std::vector<std::string> Split(const std::string& input,
                                      char delimiter) {
  std::istringstream stream(input);
  std::vector<std::string> fields;
  for (std::string field; std::getline(stream, field, delimiter);) {
    fields.push_back(field);
  }
  return fields;
}
// Rounds a floating-point value to an integer of type IntType.
// The result is floor(value + 0.5), so exact halves round toward positive
// infinity (1.5 -> 2, -1.5 -> -1).
template <typename IntType>
IntType Round(double value) {
  const double shifted = value + 0.5;
  const double rounded_down = std::floor(shifted);
  return static_cast<IntType>(rounded_down);
}
// Alignment-aware counterpart of std::make_shared.
// Obtains storage for one T from aligned_malloc using alignof(T), constructs
// the object there from the forwarded arguments, and hands ownership to a
// std::shared_ptr that uses AlignedDeleter<T> as its deleter.
template <typename T, typename... ArgumentTypes>
std::shared_ptr<T> MakeAlignedSharedPtr(ArgumentTypes&&... arguments) {
  void* const storage = aligned_malloc(sizeof(T), alignof(T));
  T* const object =
      new (storage) T(std::forward<ArgumentTypes>(arguments)...);
  return std::shared_ptr<T>(object, AlignedDeleter<T>());
}
} // namespace NNUE
} // namespace Eval
#endif // defined(EVAL_LEARN) && defined(EVAL_NNUE)
#endif