mirror of
https://github.com/sockspls/badfish
synced 2025-07-11 19:49:14 +00:00
201 lines
7.1 KiB
C++
201 lines
7.1 KiB
C++
// USI extension commands for the NNUE evaluation function
|
|
|
|
#if defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE)
|
|
|
|
#include "../../thread.h"
|
|
#include "../../uci.h"
|
|
#include "evaluate_nnue.h"
|
|
#include "nnue_test_command.h"
|
|
|
|
#include <set>
|
|
#include <fstream>
|
|
|
|
// Test-only assertion. When X evaluates to false it prints the failed
// expression together with file, line and function name to stdout, sleeps for
// 3000 microseconds, and then deliberately crashes the process by writing to
// an invalid address. The store goes through a volatile pointer so the
// compiler cannot drop it.
// The body is wrapped in do { ... } while (0) so that `ASSERT(x);` expands to
// exactly one statement and stays safe inside an unbraced if/else.
#define ASSERT(X)                                                           \
  do {                                                                      \
    if (!(X)) {                                                             \
      std::cout << "\nError : ASSERT(" << #X << "), " << __FILE__ << "("    \
                << __LINE__ << "): " << __func__ << std::endl;              \
      std::this_thread::sleep_for(std::chrono::microseconds(3000));         \
      *reinterpret_cast<volatile int*>(1) = 0;                              \
    }                                                                       \
  } while (0)
|
|
|
|
namespace Eval {
|
|
|
|
namespace NNUE {
|
|
|
|
namespace {
|
|
|
|
// 主に差分計算に関するRawFeaturesのテスト
|
|
void TestFeatures(Position& pos) {
|
|
const std::uint64_t num_games = 1000;
|
|
StateInfo si;
|
|
pos.set(StartFEN, false, &si, Threads.main());
|
|
const int MAX_PLY = 256; // 256手までテスト
|
|
|
|
StateInfo state[MAX_PLY]; // StateInfoを最大手数分だけ
|
|
int ply; // 初期局面からの手数
|
|
|
|
PRNG prng(20171128);
|
|
|
|
std::uint64_t num_moves = 0;
|
|
std::vector<std::uint64_t> num_updates(kRefreshTriggers.size() + 1);
|
|
std::vector<std::uint64_t> num_resets(kRefreshTriggers.size());
|
|
constexpr IndexType kUnknown = -1;
|
|
std::vector<IndexType> trigger_map(RawFeatures::kDimensions, kUnknown);
|
|
auto make_index_sets = [&](const Position& pos) {
|
|
std::vector<std::vector<std::set<IndexType>>> index_sets(
|
|
kRefreshTriggers.size(), std::vector<std::set<IndexType>>(2));
|
|
for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
|
|
Features::IndexList active_indices[2];
|
|
RawFeatures::AppendActiveIndices(pos, kRefreshTriggers[i],
|
|
active_indices);
|
|
for (const auto perspective : Colors) {
|
|
for (const auto index : active_indices[perspective]) {
|
|
ASSERT(index < RawFeatures::kDimensions);
|
|
ASSERT(index_sets[i][perspective].count(index) == 0);
|
|
ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
|
|
index_sets[i][perspective].insert(index);
|
|
trigger_map[index] = i;
|
|
}
|
|
}
|
|
}
|
|
return index_sets;
|
|
};
|
|
auto update_index_sets = [&](const Position& pos, auto* index_sets) {
|
|
for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
|
|
Features::IndexList removed_indices[2], added_indices[2];
|
|
bool reset[2];
|
|
RawFeatures::AppendChangedIndices(pos, kRefreshTriggers[i],
|
|
removed_indices, added_indices, reset);
|
|
for (const auto perspective : Colors) {
|
|
if (reset[perspective]) {
|
|
(*index_sets)[i][perspective].clear();
|
|
++num_resets[i];
|
|
} else {
|
|
for (const auto index : removed_indices[perspective]) {
|
|
ASSERT(index < RawFeatures::kDimensions);
|
|
ASSERT((*index_sets)[i][perspective].count(index) == 1);
|
|
ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
|
|
(*index_sets)[i][perspective].erase(index);
|
|
++num_updates.back();
|
|
++num_updates[i];
|
|
trigger_map[index] = i;
|
|
}
|
|
}
|
|
for (const auto index : added_indices[perspective]) {
|
|
ASSERT(index < RawFeatures::kDimensions);
|
|
ASSERT((*index_sets)[i][perspective].count(index) == 0);
|
|
ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
|
|
(*index_sets)[i][perspective].insert(index);
|
|
++num_updates.back();
|
|
++num_updates[i];
|
|
trigger_map[index] = i;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
std::cout << "feature set: " << RawFeatures::GetName()
|
|
<< "[" << RawFeatures::kDimensions << "]" << std::endl;
|
|
std::cout << "start testing with random games";
|
|
|
|
for (std::uint64_t i = 0; i < num_games; ++i) {
|
|
auto index_sets = make_index_sets(pos);
|
|
for (ply = 0; ply < MAX_PLY; ++ply) {
|
|
MoveList<LEGAL> mg(pos); // 全合法手の生成
|
|
|
|
// 合法な指し手がなかった == 詰み
|
|
if (mg.size() == 0)
|
|
break;
|
|
|
|
// 生成された指し手のなかからランダムに選び、その指し手で局面を進める。
|
|
Move m = mg.begin()[prng.rand(mg.size())];
|
|
pos.do_move(m, state[ply]);
|
|
|
|
++num_moves;
|
|
update_index_sets(pos, &index_sets);
|
|
ASSERT(index_sets == make_index_sets(pos));
|
|
}
|
|
|
|
pos.set(StartFEN, false, &si, Threads.main());
|
|
|
|
// 100回に1回ごとに'.'を出力(進んでいることがわかるように)
|
|
if ((i % 100) == 0)
|
|
std::cout << "." << std::flush;
|
|
}
|
|
std::cout << "passed." << std::endl;
|
|
std::cout << num_games << " games, " << num_moves << " moves, "
|
|
<< num_updates.back() << " updates, "
|
|
<< (1.0 * num_updates.back() / num_moves)
|
|
<< " updates per move" << std::endl;
|
|
std::size_t num_observed_indices = 0;
|
|
for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
|
|
const auto count = std::count(trigger_map.begin(), trigger_map.end(), i);
|
|
num_observed_indices += count;
|
|
std::cout << "TriggerEvent(" << static_cast<int>(kRefreshTriggers[i])
|
|
<< "): " << count << " features ("
|
|
<< (100.0 * count / RawFeatures::kDimensions) << "%), "
|
|
<< num_updates[i] << " updates ("
|
|
<< (1.0 * num_updates[i] / num_moves) << " per move), "
|
|
<< num_resets[i] << " resets ("
|
|
<< (100.0 * num_resets[i] / num_moves) << "%)"
|
|
<< std::endl;
|
|
}
|
|
std::cout << "observed " << num_observed_indices << " ("
|
|
<< (100.0 * num_observed_indices / RawFeatures::kDimensions)
|
|
<< "% of " << RawFeatures::kDimensions
|
|
<< ") features" << std::endl;
|
|
}
|
|
|
|
// 評価関数の構造を表す文字列を出力する
|
|
void PrintInfo(std::istream& stream) {
|
|
std::cout << "network architecture: " << GetArchitectureString() << std::endl;
|
|
|
|
while (true) {
|
|
std::string file_name;
|
|
stream >> file_name;
|
|
if (file_name.empty()) break;
|
|
|
|
std::uint32_t hash_value;
|
|
std::string architecture;
|
|
const bool success = [&]() {
|
|
std::ifstream file_stream(file_name, std::ios::binary);
|
|
if (!file_stream) return false;
|
|
if (!ReadHeader(file_stream, &hash_value, &architecture)) return false;
|
|
return true;
|
|
}();
|
|
|
|
std::cout << file_name << ": ";
|
|
if (success) {
|
|
if (hash_value == kHashValue) {
|
|
std::cout << "matches with this binary";
|
|
if (architecture != GetArchitectureString()) {
|
|
std::cout << ", but architecture string differs: " << architecture;
|
|
}
|
|
std::cout << std::endl;
|
|
} else {
|
|
std::cout << architecture << std::endl;
|
|
}
|
|
} else {
|
|
std::cout << "failed to read header" << std::endl;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// NNUE評価関数に関するUSI拡張コマンド
|
|
void TestCommand(Position& pos, std::istream& stream) {
|
|
std::string sub_command;
|
|
stream >> sub_command;
|
|
|
|
if (sub_command == "test_features") {
|
|
TestFeatures(pos);
|
|
} else if (sub_command == "info") {
|
|
PrintInfo(stream);
|
|
} else {
|
|
std::cout << "usage:" << std::endl;
|
|
std::cout << " test nnue test_features" << std::endl;
|
|
std::cout << " test nnue info [path/to/" << kFileName << "...]" << std::endl;
|
|
}
|
|
}
|
|
|
|
} // namespace NNUE
|
|
|
|
} // namespace Eval
|
|
|
|
#endif // defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE)
|