mirror of https://github.com/sockspls/badfish synced 2025-04-29 08:13:08 +00:00

Improve Comments for Pairwise Multiplication Optimization

closes https://github.com/official-stockfish/Stockfish/pull/5524

no functional change
Shawn Xu 2024-07-30 01:33:56 -07:00 committed by Joost VandeVondele
parent d626af5c3a
commit bc80ece6c7
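
The arithmetic described in the comment added by this commit can be checked with a scalar model. The sketch below is not Stockfish code: every function name in it is hypothetical, and each function stands in for a single 16-bit lane of the vector operation the comment names (mulhi, vqdmulhq_s16, packus). It assumes the clip bound of 127 * 2 = 254 and the shifts of 7 and 6 given in the comment, and compares the fast path against plain clip, multiply, and divide by 512.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical scalar stand-ins for one 16-bit lane; not Stockfish functions.

constexpr int kClipMax = 127 * 2;  // clip bound once weights and biases are scaled by 2

// High 16 bits of the signed 32-bit product (models one lane of a mulhi).
// Relies on >> being an arithmetic shift for negative values, which C++20
// guarantees and mainstream compilers provided before that.
inline int16_t mulhi16(int16_t a, int16_t b) {
    return static_cast<int16_t>((int32_t(a) * int32_t(b)) >> 16);
}

// Saturating doubling multiply, high half (models one lane of vqdmulhq_s16).
inline int16_t qdmulh16(int16_t a, int16_t b) {
    const int64_t hi = (2 * int64_t(a) * int64_t(b)) >> 16;
    return static_cast<int16_t>(std::min<int64_t>(hi, INT16_MAX));
}

// Unsigned saturation from int16 to uint8 (models one lane of packus).
inline uint8_t packus16(int16_t v) {
    return static_cast<uint8_t>(std::min<int>(std::max<int>(v, 0), 255));
}

// Straightforward version: clip both inputs to [0, kClipMax], then divide by 512.
inline uint8_t reference(int16_t a, int16_t b) {
    const int ca = std::min<int>(std::max<int>(a, 0), kClipMax);
    const int cb = std::min<int>(std::max<int>(b, 0), kClipMax);
    return static_cast<uint8_t>((ca * cb) >> 9);
}

// Fast version with mulhi: a is fully clipped and shifted left by 7, b only
// gets the upper clip; the saturating pack zeroes any negative product.
inline uint8_t fast_mulhi(int16_t a, int16_t b) {
    const int ca = std::min<int>(std::max<int>(a, 0), kClipMax);
    const int cb = std::min<int>(b, kClipMax);
    return packus16(mulhi16(static_cast<int16_t>(ca << 7), static_cast<int16_t>(cb)));
}

// Same idea with the doubling multiply, so the pre-shift is 6 instead of 7.
inline uint8_t fast_qdmulh(int16_t a, int16_t b) {
    const int ca = std::min<int>(std::max<int>(a, 0), kClipMax);
    const int cb = std::min<int>(b, kClipMax);
    return packus16(qdmulh16(static_cast<int16_t>(ca << 6), static_cast<int16_t>(cb)));
}

// True if both fast versions match the reference for every pair in [lo, hi].
// lo and hi must lie within the int16 range.
inline bool agree_on_range(int lo, int hi) {
    for (int a = lo; a <= hi; ++a)
        for (int b = lo; b <= hi; ++b) {
            const int16_t x = static_cast<int16_t>(a);
            const int16_t y = static_cast<int16_t>(b);
            if (fast_mulhi(x, y) != reference(x, y)) return false;
            if (fast_qdmulh(x, y) != reference(x, y)) return false;
        }
    return true;
}
```

Calling `agree_on_range(-600, 600)` covers negative inputs, the clipped range, and values above the clip bound. The model only checks the arithmetic; it says nothing about the performance of the vectorized code.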

@@ -352,26 +352,68 @@ class FeatureTransformer {
reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
vec_t* out = reinterpret_cast<vec_t*>(output + offset);
// Per the NNUE architecture, here we want to multiply pairs of
// clipped elements and divide the product by 128. To do this,
// we can naively perform min/max operations to clip each of the
// four int16 vectors, mullo pairs together, then pack them into
// one int8 vector. However, there exists a faster way.
// The idea here is to use the implicit clipping from packus to
// save us two vec_max_16 instructions. This clipping works due
// to the fact that any int16 integer below zero will be zeroed
// on packus.
// Consider the case where the second element is negative.
// If we do standard clipping, that element will be zero, which
// means our pairwise product is zero. If we perform packus and
// remove the lower-side clip for the second element, then our
// product before packus will be negative, and is zeroed on pack.
// The two operations produce equivalent results, but the second
// one (using packus) saves one max operation per pair.
// But here we run into a problem: mullo does not preserve the
// sign of the multiplication. We can get around this by doing
// mulhi, which keeps the sign. But that requires an additional
// tweak.
// mulhi cuts off the last 16 bits of the resulting product,
// which is the same as performing a rightward shift of 16 bits.
// We can use this to our advantage. Recall that we want to
// divide the final product by 128, which is equivalent to a
// 7-bit right shift. Intuitively, if we shift the clipped
// value left by 9, and perform mulhi, which shifts the product
// right by 16 bits, then we will net a right shift of 7 bits.
// However, this won't work as intended. Since we clip the
// values to a maximum of 127, shifting them left by 9 bits
// can set the sign bit (127 << 9 = 65024 > 32767), so some
// positive values would be interpreted as negative after the shift.
// There is a way, however, to get around this limitation. When
// loading the network, scale accumulator weights and biases by
// 2. To get the same pairwise multiplication result as before,
// we need to divide the product by 128 * 2 * 2 = 512, which
// amounts to a right shift of 9 bits. So now we only have to
// shift left by 7 bits, perform mulhi (shifts right by 16 bits)
// and net a 9-bit right shift. Since we scaled everything by
// two, the values are clipped at 127 * 2 = 254, which occupies
// 8 bits. Shifting that left by 7 bits gives at most
// 254 << 7 = 32512, which leaves the sign bit clear, so we are safe.
// Note that on NEON processors, we shift left by 6 instead
// because the instruction "vqdmulhq_s16" also doubles the
// return value after the multiplication, adding an extra shift
// to the left by 1, so we compensate by shifting less before
// the multiplication.
constexpr int shift =
#if defined(USE_SSE2)
7;
#else
6;
#endif
for (IndexType j = 0; j < NumOutputChunks; ++j)
{
// What we want to do is multiply inputs in a pairwise manner
// (after clipping), and then shift right by 9. Instead, we
// shift left by 7, and use mulhi, stripping the bottom 16 bits,
// effectively shifting right by 16, resulting in a net shift
// of 9 bits. We use mulhi because it maintains the sign of
// the multiplication (unlike mullo), allowing us to make use
// of packus to clip 2 of the inputs, resulting in a save of 2
// "vec_max_16" calls. A special case is when we use NEON,
// where we shift left by 6 instead, because the instruction
// "vqdmulhq_s16" also doubles the return value after the
// multiplication, adding an extra shift to the left by 1, so
// we compensate by shifting less before the multiplication.
#if defined(USE_SSE2)
constexpr int shift = 7;
#else
constexpr int shift = 6;
#endif
const vec_t sum0a =
vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);
const vec_t sum0b =