From bc80ece6c78cafb3a89d3abcec6c71a517c29f2d Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Tue, 30 Jul 2024 01:33:56 -0700 Subject: [PATCH] Improve Comments for Pairwise Multiplication Optimization closes https://github.com/official-stockfish/Stockfish/pull/5524 no functional change --- src/nnue/nnue_feature_transformer.h | 78 ++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index ad0fb1b4..2f74dcae 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -352,26 +352,68 @@ class FeatureTransformer { reinterpret_cast(&(accumulation[perspectives[p]][HalfDimensions / 2])); vec_t* out = reinterpret_cast(output + offset); + // Per the NNUE architecture, here we want to multiply pairs of + // clipped elements and divide the product by 128. To do this, + // we can naively perform min/max operations to clip each of the + // four int16 vectors, mullo pairs together, then pack them into + // one int8 vector. However, there exists a faster way. + + // The idea here is to use the implicit clipping from packus to + // save us two vec_max_16 instructions. This clipping works due + // to the fact that any int16 integer below zero will be zeroed + // on packus. + + // Consider the case where the second element is negative. + // If we do standard clipping, that element will be zero, which + // means our pairwise product is zero. If we perform packus and + // remove the lower-side clip for the second element, then our + // product before packus will be negative, and is zeroed on pack. + // The two operations produce equivalent results, but the second + // one (using packus) saves one max operation per pair. + + // But here we run into a problem: mullo does not preserve the + // sign of the multiplication. We can get around this by doing + // mulhi, which keeps the sign. But that requires an additional + // tweak. 
+ + // mulhi cuts off the last 16 bits of the resulting product, + // which is the same as performing a rightward shift of 16 bits. + // We can use this to our advantage. Recall that we want to + // divide the final product by 128, which is equivalent to a + // 7-bit right shift. Intuitively, if we shift the clipped + // value left by 9, and perform mulhi, which shifts the product + // right by 16 bits, then we will net a right shift of 7 bits. + // However, this won't work as intended. Since we clip the + // values to have a maximum value of 127, shifting it by 9 bits + // might occupy the sign bit, resulting in some positive + // values being interpreted as negative after the shift. + + // There is a way, however, to get around this limitation. When + // loading the network, scale accumulator weights and biases by + // 2. To get the same pairwise multiplication result as before, + // we need to divide the product by 128 * 2 * 2 = 512, which + // amounts to a right shift of 9 bits. So now we only have to + // shift left by 7 bits, perform mulhi (shifts right by 16 bits) + // and net a 9-bit right shift. Since we scaled everything by + // two, the values are clipped at 127 * 2 = 254, which occupies + // 8 bits. Shifting it by 7 bits left will no longer occupy the + // sign bit, so we are safe. + + // Note that on NEON processors, we shift left by 6 instead + // because the instruction "vqdmulhq_s16" also doubles the + // return value after the multiplication, adding an extra shift + // to the left by 1, so we compensate by shifting less before + // the multiplication. + + constexpr int shift = + #if defined(USE_SSE2) + 7; + #else + 6; + #endif + for (IndexType j = 0; j < NumOutputChunks; ++j) { - // What we want to do is multiply inputs in a pairwise manner - // (after clipping), and then shift right by 9. 
Instead, we - // shift left by 7, and use mulhi, stripping the bottom 16 bits, - // effectively shifting right by 16, resulting in a net shift - // of 9 bits. We use mulhi because it maintains the sign of - // the multiplication (unlike mullo), allowing us to make use - // of packus to clip 2 of the inputs, resulting in a save of 2 - // "vec_max_16" calls. A special case is when we use NEON, - // where we shift left by 6 instead, because the instruction - // "vqdmulhq_s16" also doubles the return value after the - // multiplication, adding an extra shift to the left by 1, so - // we compensate by shifting less before the multiplication. - - #if defined(USE_SSE2) - constexpr int shift = 7; - #else - constexpr int shift = 6; - #endif const vec_t sum0a = vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift); const vec_t sum0b =