mirror of
https://github.com/sockspls/badfish
synced 2025-04-29 16:23:09 +00:00
Optimise NNUE Accumulator updates
Passed STC:
https://tests.stockfishchess.org/tests/view/662e3c6a5e9274400985a741
LLR: 2.94 (-2.94,2.94) <0.00,2.00>
Total: 86176 W: 22284 L: 21905 D: 41987
Ptnml(0-2): 254, 9572, 23051, 9963, 248

closes https://github.com/official-stockfish/Stockfish/pull/5202

No functional change
This commit is contained in:
parent
eb20de36c0
commit
6a9b8a0c7b
1 changed files with 38 additions and 38 deletions
|
@ -404,19 +404,25 @@ class FeatureTransformer {
|
|||
return {st, next};
|
||||
}
|
||||
|
||||
// NOTE: The parameter states_to_update is an array of position states, ending with nullptr.
|
||||
// NOTE: The parameter states_to_update is an array of position states.
|
||||
// All states must be sequential, that is states_to_update[i] must either be reachable
|
||||
// by repeatedly applying ->previous from states_to_update[i+1] or
|
||||
// states_to_update[i] == nullptr.
|
||||
// by repeatedly applying ->previous from states_to_update[i+1].
|
||||
// computed_st must be reachable by repeatedly applying ->previous on
|
||||
// states_to_update[0], if not nullptr.
|
||||
// states_to_update[0].
|
||||
template<Color Perspective, size_t N>
|
||||
void update_accumulator_incremental(const Position& pos,
|
||||
StateInfo* computed_st,
|
||||
StateInfo* states_to_update[N],
|
||||
bool psqtOnly) const {
|
||||
static_assert(N > 0);
|
||||
assert(states_to_update[N - 1] == nullptr);
|
||||
assert([&]() {
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if (states_to_update[i] == nullptr)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}());
|
||||
|
||||
#ifdef VECTOR
|
||||
// Gcc-10.2 unnecessarily spills AVX2 registers if this array
|
||||
|
@ -425,11 +431,7 @@ class FeatureTransformer {
|
|||
psqt_vec_t psqt[NumPsqtRegs];
|
||||
#endif
|
||||
|
||||
if (states_to_update[0] == nullptr)
|
||||
return;
|
||||
|
||||
// Update incrementally going back through states_to_update.
|
||||
|
||||
// Gather all features to be updated.
|
||||
const Square ksq = pos.square<KING>(Perspective);
|
||||
|
||||
|
@ -437,28 +439,18 @@ class FeatureTransformer {
|
|||
// That might depend on the feature set and generally relies on the
|
||||
// feature set's update cost calculation to be correct and never allow
|
||||
// updates with more added/removed features than MaxActiveDimensions.
|
||||
FeatureSet::IndexList removed[N - 1], added[N - 1];
|
||||
FeatureSet::IndexList removed[N], added[N];
|
||||
|
||||
for (int i = N - 1; i >= 0; --i)
|
||||
{
|
||||
int i =
|
||||
N
|
||||
- 2; // Last potential state to update. Skip last element because it must be nullptr.
|
||||
while (states_to_update[i] == nullptr)
|
||||
--i;
|
||||
(states_to_update[i]->*accPtr).computed[Perspective] = !psqtOnly;
|
||||
(states_to_update[i]->*accPtr).computedPSQT[Perspective] = true;
|
||||
|
||||
StateInfo* st2 = states_to_update[i];
|
||||
const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1];
|
||||
|
||||
for (; i >= 0; --i)
|
||||
{
|
||||
(states_to_update[i]->*accPtr).computed[Perspective] = !psqtOnly;
|
||||
(states_to_update[i]->*accPtr).computedPSQT[Perspective] = true;
|
||||
|
||||
const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1];
|
||||
|
||||
for (; st2 != end_state; st2 = st2->previous)
|
||||
FeatureSet::append_changed_indices<Perspective>(ksq, st2->dirtyPiece,
|
||||
removed[i], added[i]);
|
||||
}
|
||||
for (StateInfo* st2 = states_to_update[i]; st2 != end_state; st2 = st2->previous)
|
||||
FeatureSet::append_changed_indices<Perspective>(ksq, st2->dirtyPiece, removed[i],
|
||||
added[i]);
|
||||
}
|
||||
|
||||
StateInfo* st = computed_st;
|
||||
|
@ -466,8 +458,7 @@ class FeatureTransformer {
|
|||
// Now update the accumulators listed in states_to_update[]. All N entries are valid states; the array no longer ends with a nullptr sentinel.
|
||||
#ifdef VECTOR
|
||||
|
||||
if (states_to_update[1] == nullptr && (removed[0].size() == 1 || removed[0].size() == 2)
|
||||
&& added[0].size() == 1)
|
||||
if (N == 1 && (removed[0].size() == 1 || removed[0].size() == 2) && added[0].size() == 1)
|
||||
{
|
||||
assert(states_to_update[0]);
|
||||
|
||||
|
@ -541,7 +532,7 @@ class FeatureTransformer {
|
|||
for (IndexType k = 0; k < NumRegs; ++k)
|
||||
acc[k] = vec_load(&accTileIn[k]);
|
||||
|
||||
for (IndexType i = 0; states_to_update[i]; ++i)
|
||||
for (IndexType i = 0; i < N; ++i)
|
||||
{
|
||||
// Difference calculation for the deactivated features
|
||||
for (const auto index : removed[i])
|
||||
|
@ -578,7 +569,7 @@ class FeatureTransformer {
|
|||
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
|
||||
psqt[k] = vec_load_psqt(&accTilePsqtIn[k]);
|
||||
|
||||
for (IndexType i = 0; states_to_update[i]; ++i)
|
||||
for (IndexType i = 0; i < N; ++i)
|
||||
{
|
||||
// Difference calculation for the deactivated features
|
||||
for (const auto index : removed[i])
|
||||
|
@ -608,7 +599,7 @@ class FeatureTransformer {
|
|||
}
|
||||
}
|
||||
#else
|
||||
for (IndexType i = 0; states_to_update[i]; ++i)
|
||||
for (IndexType i = 0; i < N; ++i)
|
||||
{
|
||||
if (!psqtOnly)
|
||||
std::memcpy((states_to_update[i]->*accPtr).accumulation[Perspective],
|
||||
|
@ -847,8 +838,8 @@ class FeatureTransformer {
|
|||
|| (psqtOnly && (oldest_st->*accPtr).computedPSQT[Perspective]))
|
||||
{
|
||||
// Only update current position accumulator to minimize work.
|
||||
StateInfo* states_to_update[2] = {pos.state(), nullptr};
|
||||
update_accumulator_incremental<Perspective, 2>(pos, oldest_st, states_to_update,
|
||||
StateInfo* states_to_update[1] = {pos.state()};
|
||||
update_accumulator_incremental<Perspective, 1>(pos, oldest_st, states_to_update,
|
||||
psqtOnly);
|
||||
}
|
||||
else
|
||||
|
@ -873,11 +864,20 @@ class FeatureTransformer {
|
|||
// 1. for the current position
|
||||
// 2. the next accumulator after the computed one
|
||||
// The heuristic may change in the future.
|
||||
StateInfo* states_to_update[3] = {next, next == pos.state() ? nullptr : pos.state(),
|
||||
nullptr};
|
||||
if (next == pos.state())
|
||||
{
|
||||
StateInfo* states_to_update[1] = {next};
|
||||
|
||||
update_accumulator_incremental<Perspective, 3>(pos, oldest_st, states_to_update,
|
||||
psqtOnly);
|
||||
update_accumulator_incremental<Perspective, 1>(pos, oldest_st, states_to_update,
|
||||
psqtOnly);
|
||||
}
|
||||
else
|
||||
{
|
||||
StateInfo* states_to_update[2] = {next, pos.state()};
|
||||
|
||||
update_accumulator_incremental<Perspective, 2>(pos, oldest_st, states_to_update,
|
||||
psqtOnly);
|
||||
}
|
||||
}
|
||||
else
|
||||
update_accumulator_refresh_cache<Perspective>(pos, cache, psqtOnly);
|
||||
|
|
Loading…
Add table
Reference in a new issue