From 6a9b8a0c7b913b9d4c4474bae7804184d20e8c4a Mon Sep 17 00:00:00 2001 From: cj5716 <125858804+cj5716@users.noreply.github.com> Date: Sun, 28 Apr 2024 16:33:59 +0800 Subject: [PATCH] Optimise NNUE Accumulator updates Passed STC: https://tests.stockfishchess.org/tests/view/662e3c6a5e9274400985a741 LLR: 2.94 (-2.94,2.94) <0.00,2.00> Total: 86176 W: 22284 L: 21905 D: 41987 Ptnml(0-2): 254, 9572, 23051, 9963, 248 closes https://github.com/official-stockfish/Stockfish/pull/5202 No functional change --- src/nnue/nnue_feature_transformer.h | 76 ++++++++++++++--------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 6b3f78a9..402a47a8 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -404,19 +404,25 @@ class FeatureTransformer { return {st, next}; } - // NOTE: The parameter states_to_update is an array of position states, ending with nullptr. + // NOTE: The parameter states_to_update is an array of position states. // All states must be sequential, that is states_to_update[i] must either be reachable - // by repeatedly applying ->previous from states_to_update[i+1] or - // states_to_update[i] == nullptr. + // by repeatedly applying ->previous from states_to_update[i+1]. // computed_st must be reachable by repeatedly applying ->previous on - // states_to_update[0], if not nullptr. + // states_to_update[0]. template void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, StateInfo* states_to_update[N], bool psqtOnly) const { static_assert(N > 0); - assert(states_to_update[N - 1] == nullptr); + assert([&]() { + for (size_t i = 0; i < N; ++i) + { + if (states_to_update[i] == nullptr) + return false; + } + return true; + }()); #ifdef VECTOR // Gcc-10.2 unnecessarily spills AVX2 registers if this array @@ -425,11 +431,7 @@ class FeatureTransformer { psqt_vec_t psqt[NumPsqtRegs]; #endif - if (states_to_update[0] == nullptr) - return; - // Update incrementally going back through states_to_update. - // Gather all features to be updated. const Square ksq = pos.square(Perspective); @@ -437,28 +439,18 @@ class FeatureTransformer { // That might depend on the feature set and generally relies on the // feature set's update cost calculation to be correct and never allow // updates with more added/removed features than MaxActiveDimensions. - FeatureSet::IndexList removed[N - 1], added[N - 1]; + FeatureSet::IndexList removed[N], added[N]; + for (int i = N - 1; i >= 0; --i) { - int i = - N - - 2; // Last potential state to update. Skip last element because it must be nullptr. - while (states_to_update[i] == nullptr) - --i; + (states_to_update[i]->*accPtr).computed[Perspective] = !psqtOnly; + (states_to_update[i]->*accPtr).computedPSQT[Perspective] = true; - StateInfo* st2 = states_to_update[i]; + const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1]; - for (; i >= 0; --i) - { - (states_to_update[i]->*accPtr).computed[Perspective] = !psqtOnly; - (states_to_update[i]->*accPtr).computedPSQT[Perspective] = true; - - const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1]; - - for (; st2 != end_state; st2 = st2->previous) - FeatureSet::append_changed_indices(ksq, st2->dirtyPiece, - removed[i], added[i]); - } + for (StateInfo* st2 = states_to_update[i]; st2 != end_state; st2 = st2->previous) + FeatureSet::append_changed_indices(ksq, st2->dirtyPiece, removed[i], + added[i]); } StateInfo* st = computed_st; @@ -466,8 +458,7 @@ class FeatureTransformer { // Now update the accumulators listed in states_to_update[], where the last element is a sentinel. #ifdef VECTOR - if (states_to_update[1] == nullptr && (removed[0].size() == 1 || removed[0].size() == 2) - && added[0].size() == 1) + if (N == 1 && (removed[0].size() == 1 || removed[0].size() == 2) && added[0].size() == 1) { assert(states_to_update[0]); @@ -541,7 +532,7 @@ class FeatureTransformer { for (IndexType k = 0; k < NumRegs; ++k) acc[k] = vec_load(&accTileIn[k]); - for (IndexType i = 0; states_to_update[i]; ++i) + for (IndexType i = 0; i < N; ++i) { // Difference calculation for the deactivated features for (const auto index : removed[i]) @@ -578,7 +569,7 @@ class FeatureTransformer { for (std::size_t k = 0; k < NumPsqtRegs; ++k) psqt[k] = vec_load_psqt(&accTilePsqtIn[k]); - for (IndexType i = 0; states_to_update[i]; ++i) + for (IndexType i = 0; i < N; ++i) { // Difference calculation for the deactivated features for (const auto index : removed[i]) @@ -608,7 +599,7 @@ class FeatureTransformer { } } #else - for (IndexType i = 0; states_to_update[i]; ++i) + for (IndexType i = 0; i < N; ++i) { if (!psqtOnly) std::memcpy((states_to_update[i]->*accPtr).accumulation[Perspective], @@ -847,8 +838,8 @@ class FeatureTransformer { || (psqtOnly && (oldest_st->*accPtr).computedPSQT[Perspective])) { // Only update current position accumulator to minimize work. - StateInfo* states_to_update[2] = {pos.state(), nullptr}; - update_accumulator_incremental(pos, oldest_st, states_to_update, + StateInfo* states_to_update[1] = {pos.state()}; + update_accumulator_incremental(pos, oldest_st, states_to_update, psqtOnly); } else @@ -873,11 +864,20 @@ class FeatureTransformer { // 1. for the current position // 2. the next accumulator after the computed one // The heuristic may change in the future. - StateInfo* states_to_update[3] = {next, next == pos.state() ? nullptr : pos.state(), - nullptr}; + if (next == pos.state()) + { + StateInfo* states_to_update[1] = {next}; - update_accumulator_incremental(pos, oldest_st, states_to_update, - psqtOnly); + update_accumulator_incremental(pos, oldest_st, states_to_update, + psqtOnly); + } + else + { + StateInfo* states_to_update[2] = {next, pos.state()}; + + update_accumulator_incremental(pos, oldest_st, states_to_update, + psqtOnly); + } } else update_accumulator_refresh_cache(pos, cache, psqtOnly);