// USI extension commands for the NNUE evaluation function.

#if defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE)

#include "../../thread.h"
#include "../../uci.h"
#include "evaluate_nnue.h"
#include "nnue_test_command.h"

// NOTE(review): the two standard include targets were missing in the text
// under review ("#include #include"). <set> and <fstream> are restored from
// usage (index sets with count/insert/erase, std::ifstream) -- confirm.
#include <set>
#include <fstream>

// Headers for standard names this file uses directly.
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

// Test-only assertion. When X is false it prints the failed expression with
// file, line and function, sleeps for 3000 microseconds (presumably so the
// message can reach the console), and then deliberately writes through the
// invalid address 1 to crash the process on the spot.
// Wrapped in do { } while (0) so that `ASSERT(x);` is a single statement and
// stays safe inside an unbraced if/else.
#define ASSERT(X)                                                          \
  do {                                                                     \
    if (!(X)) {                                                            \
      std::cout << "\nError : ASSERT(" << #X << "), " << __FILE__ << "("   \
                << __LINE__ << "): " << __func__ << std::endl;             \
      std::this_thread::sleep_for(std::chrono::microseconds(3000));        \
      *(int*)1 = 0;                                                        \
    }                                                                      \
  } while (0)

namespace Eval {

namespace NNUE {

namespace {

// Tests RawFeatures, mainly its incremental (difference) index computation.
//
// Plays `num_games` random games from StartFEN. After every move the index
// sets are updated incrementally from RawFeatures::AppendChangedIndices and
// compared against sets rebuilt from scratch with
// RawFeatures::AppendActiveIndices; any mismatch trips ASSERT. Summary
// statistics are written to std::cout at the end.
//
// pos: position object used to play the games; it is reset to StartFEN on
//      entry and again after every game.
void TestFeatures(Position& pos) {
  const std::uint64_t num_games = 1000;
  StateInfo si;
  pos.set(StartFEN, false, &si, Threads.main());
  const int MAX_PLY = 256;  // each game is played for at most 256 plies

  StateInfo state[MAX_PLY];  // one StateInfo per ply of the current game
  PRNG prng(20171128);

  std::uint64_t num_moves = 0;
  // NOTE(review): the element types of the three vectors below and of the
  // nested index-set container were missing in the text under review; they
  // are reconstructed from how the values are used -- confirm.
  // num_updates has one extra trailing slot: back() is the total over all
  // refresh triggers.
  std::vector<std::uint64_t> num_updates(kRefreshTriggers.size() + 1);
  std::vector<std::uint64_t> num_resets(kRefreshTriggers.size());
  constexpr IndexType kUnknown = -1;
  // For each feature index, the refresh trigger it has been observed under
  // (kUnknown until first seen). The ASSERTs below check that an index is
  // never observed under two different triggers.
  std::vector<IndexType> trigger_map(RawFeatures::kDimensions, kUnknown);

  // Builds, from scratch, the set of active feature indices of `position`
  // for every refresh trigger and both perspectives.
  auto make_index_sets = [&](const Position& position) {
    std::vector<std::vector<std::set<IndexType>>> index_sets(
        kRefreshTriggers.size(),
        std::vector<std::set<IndexType>>(2));
    for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
      Features::IndexList active_indices[2];
      RawFeatures::AppendActiveIndices(position, kRefreshTriggers[i],
                                       active_indices);
      for (const auto perspective : Colors) {
        for (const auto index : active_indices[perspective]) {
          ASSERT(index < RawFeatures::kDimensions);
          // An active index must not be reported twice.
          ASSERT(index_sets[i][perspective].count(index) == 0);
          ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
          index_sets[i][perspective].insert(index);
          trigger_map[index] = i;
        }
      }
    }
    return index_sets;
  };

  // Applies the changes reported by AppendChangedIndices for `position`
  // (the position after the last move) to `*index_sets` in place.
  auto update_index_sets = [&](const Position& position, auto* index_sets) {
    for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
      Features::IndexList removed_indices[2], added_indices[2];
      bool reset[2];
      RawFeatures::AppendChangedIndices(position, kRefreshTriggers[i],
                                        removed_indices, added_indices, reset);
      for (const auto perspective : Colors) {
        if (reset[perspective]) {
          // On a reset the set is emptied and removed_indices is ignored;
          // added_indices is still applied below.
          (*index_sets)[i][perspective].clear();
          ++num_resets[i];
        } else {
          for (const auto index : removed_indices[perspective]) {
            ASSERT(index < RawFeatures::kDimensions);
            // A removed index must currently be present.
            ASSERT((*index_sets)[i][perspective].count(index) == 1);
            ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
            (*index_sets)[i][perspective].erase(index);
            ++num_updates.back();
            ++num_updates[i];
            trigger_map[index] = i;
          }
        }
        for (const auto index : added_indices[perspective]) {
          ASSERT(index < RawFeatures::kDimensions);
          // An added index must not already be present.
          ASSERT((*index_sets)[i][perspective].count(index) == 0);
          ASSERT(trigger_map[index] == kUnknown || trigger_map[index] == i);
          (*index_sets)[i][perspective].insert(index);
          ++num_updates.back();
          ++num_updates[i];
          trigger_map[index] = i;
        }
      }
    }
  };

  std::cout << "feature set: " << RawFeatures::GetName()
            << "[" << RawFeatures::kDimensions << "]" << std::endl;
  std::cout << "start testing with random games";
  for (std::uint64_t i = 0; i < num_games; ++i) {
    auto index_sets = make_index_sets(pos);
    for (int ply = 0; ply < MAX_PLY; ++ply) {
      // Generate all legal moves.
      // NOTE(review): the template argument of MoveList was missing in the
      // text under review; LEGAL is restored from the comment above --
      // confirm.
      MoveList<LEGAL> mg(pos);

      // No legal move is available: the game is over (mate).
      if (mg.size() == 0)
        break;

      // Pick one of the generated moves at random and play it.
      Move m = mg.begin()[prng.rand(mg.size())];
      pos.do_move(m, state[ply]);

      ++num_moves;
      update_index_sets(pos, &index_sets);
      // The incrementally maintained sets must equal a fresh rebuild.
      ASSERT(index_sets == make_index_sets(pos));
    }

    pos.set(StartFEN, false, &si, Threads.main());

    // Print a '.' once every 100 games so that progress is visible.
    if ((i % 100) == 0)
      std::cout << "." << std::flush;
  }
  std::cout << "passed." << std::endl;
  std::cout << num_games << " games, " << num_moves << " moves, "
            << num_updates.back() << " updates, "
            << (1.0 * num_updates.back() / num_moves)
            << " updates per move" << std::endl;
  std::size_t num_observed_indices = 0;
  for (IndexType i = 0; i < kRefreshTriggers.size(); ++i) {
    const auto count = std::count(trigger_map.begin(), trigger_map.end(), i);
    num_observed_indices += count;
    // NOTE(review): the target type of this static_cast was missing in the
    // text under review; int is restored so the trigger prints as a number
    // -- confirm.
    std::cout << "TriggerEvent(" << static_cast<int>(kRefreshTriggers[i])
              << "): " << count << " features ("
              << (100.0 * count / RawFeatures::kDimensions) << "%), "
              << num_updates[i] << " updates ("
              << (1.0 * num_updates[i] / num_moves) << " per move), "
              << num_resets[i] << " resets ("
              << (100.0 * num_resets[i] / num_moves) << "%)"
              << std::endl;
  }
  std::cout << "observed " << num_observed_indices << " ("
            << (100.0 * num_observed_indices / RawFeatures::kDimensions)
            << "% of " << RawFeatures::kDimensions
            << ") features" << std::endl;
}

// Prints the string that describes the structure of the evaluation function.
//
// First prints the architecture string of this binary. Then, for every
// whitespace-separated file name remaining in `stream`, opens the file in
// binary mode, reads its header with ReadHeader and reports whether its hash
// value equals kHashValue (and, if so, whether the architecture string also
// matches); otherwise prints the architecture string stored in the file.
// Files that cannot be opened or parsed are reported as
// "failed to read header".
void PrintInfo(std::istream& stream) {
  std::cout << "network architecture: " << GetArchitectureString() << std::endl;

  std::string file_name;
  // Extraction fails exactly when no further token is available.
  while (stream >> file_name) {
    std::uint32_t hash_value = 0;  // only meaningful when `success` is true
    std::string architecture;
    const bool success = [&]() {
      std::ifstream file_stream(file_name, std::ios::binary);
      if (!file_stream) return false;
      if (!ReadHeader(file_stream, &hash_value, &architecture)) return false;
      return true;
    }();

    std::cout << file_name << ": ";
    if (success) {
      if (hash_value == kHashValue) {
        std::cout << "matches with this binary";
        if (architecture != GetArchitectureString()) {
          std::cout << ", but architecture string differs: " << architecture;
        }
        std::cout << std::endl;
      } else {
        std::cout << architecture << std::endl;
      }
    } else {
      std::cout << "failed to read header" << std::endl;
    }
  }
}

}  // namespace

// USI extension command for the NNUE evaluation function.
//
// Reads one sub-command word from `stream` and dispatches on it:
//   "test_features" runs TestFeatures(pos);
//   "info"          runs PrintInfo on the rest of `stream`;
//   anything else   prints a usage message.
void TestCommand(Position& pos, std::istream& stream) {
  std::string sub_command;
  stream >> sub_command;

  if (sub_command == "test_features") {
    TestFeatures(pos);
  } else if (sub_command == "info") {
    PrintInfo(stream);
  } else {
    std::cout << "usage:" << std::endl;
    std::cout << " test nnue test_features" << std::endl;
    std::cout << " test nnue info [path/to/" << kFileName << "...]" << std::endl;
  }
}

}  // namespace NNUE

}  // namespace Eval

#endif  // defined(ENABLE_TEST_CMD) && defined(EVAL_NNUE)