mirror of
https://github.com/sockspls/badfish
synced 2025-04-29 08:13:08 +00:00
Improve Comments for Pairwise Multiplication Optimization
closes https://github.com/official-stockfish/Stockfish/pull/5524
No functional change
parent d626af5c3a
commit bc80ece6c7
1 changed files with 60 additions and 18 deletions
@@ -352,26 +352,68 @@ class FeatureTransformer {
               reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
             vec_t* out = reinterpret_cast<vec_t*>(output + offset);

+            // Per the NNUE architecture, here we want to multiply pairs of
+            // clipped elements and divide the product by 128. To do this,
+            // we can naively perform min/max operation to clip each of the
+            // four int16 vectors, mullo pairs together, then pack them into
+            // one int8 vector. However, there exists a faster way.
+
+            // The idea here is to use the implicit clipping from packus to
+            // save us two vec_max_16 instructions. This clipping works due
+            // to the fact that any int16 integer below zero will be zeroed
+            // on packus.
+
+            // Consider the case where the second element is negative.
+            // If we do standard clipping, that element will be zero, which
+            // means our pairwise product is zero. If we perform packus and
+            // remove the lower-side clip for the second element, then our
+            // product before packus will be negative, and is zeroed on pack.
+            // The two operations produce equivalent results, but the second
+            // one (using packus) saves one max operation per pair.
+
+            // But here we run into a problem: mullo does not preserve the
+            // sign of the multiplication. We can get around this by doing
+            // mulhi, which keeps the sign. But that requires an additional
+            // tweak.
+
+            // mulhi cuts off the last 16 bits of the resulting product,
+            // which is the same as performing a rightward shift of 16 bits.
+            // We can use this to our advantage. Recall that we want to
+            // divide the final product by 128, which is equivalent to a
+            // 7-bit right shift. Intuitively, if we shift the clipped
+            // value left by 9, and perform mulhi, which shifts the product
+            // right by 16 bits, then we will net a right shift of 7 bits.
+            // However, this won't work as intended. Since we clip the
+            // values to have a maximum value of 127, shifting it by 9 bits
+            // might occupy the sign bit, resulting in some positive
+            // values being interpreted as negative after the shift.
+
+            // There is a way, however, to get around this limitation. When
+            // loading the network, scale accumulator weights and biases by
+            // 2. To get the same pairwise multiplication result as before,
+            // we need to divide the product by 128 * 2 * 2 = 512, which
+            // amounts to a right shift of 9 bits. So now we only have to
+            // shift left by 7 bits, perform mulhi (shifts right by 16 bits)
+            // and net a 9 bit right shift. Since we scaled everything by
+            // two, the values are clipped at 127 * 2 = 254, which occupies
+            // 8 bits. Shifting it by 7 bits left will no longer occupy the
+            // sign bit, so we are safe.
+
+            // Note that on NEON processors, we shift left by 6 instead
+            // because the instruction "vqdmulhq_s16" also doubles the
+            // return value after the multiplication, adding an extra shift
+            // to the left by 1, so we compensate by shifting less before
+            // the multiplication.
+
+            constexpr int shift =
+#if defined(USE_SSE2)
+              7;
+#else
+              6;
+#endif
+
             for (IndexType j = 0; j < NumOutputChunks; ++j)
             {
-                // What we want to do is multiply inputs in a pairwise manner
-                // (after clipping), and then shift right by 9. Instead, we
-                // shift left by 7, and use mulhi, stripping the bottom 16 bits,
-                // effectively shifting right by 16, resulting in a net shift
-                // of 9 bits. We use mulhi because it maintains the sign of
-                // the multiplication (unlike mullo), allowing us to make use
-                // of packus to clip 2 of the inputs, resulting in a save of 2
-                // "vec_max_16" calls. A special case is when we use NEON,
-                // where we shift left by 6 instead, because the instruction
-                // "vqdmulhq_s16" also doubles the return value after the
-                // multiplication, adding an extra shift to the left by 1, so
-                // we compensate by shifting less before the multiplication.
-
-#if defined(USE_SSE2)
-                constexpr int shift = 7;
-#else
-                constexpr int shift = 6;
-#endif
                 const vec_t sum0a =
                   vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);
                 const vec_t sum0b =
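
The reasoning in the new comment block can be checked with a small scalar model. The sketch below is not Stockfish code: `high_half`, `mulhi16`, `qdmulh16`, `packus`, `reference`, `optimized` and `count_mismatches` are hypothetical helpers that imitate, one lane at a time, what the comment says about mulhi, `vqdmulhq_s16` and packus. It assumes the inputs are already scaled by 2, so the clip ceiling is 127 * 2 = 254, and compares the straightforward "clip both inputs, multiply, shift right by 9" computation against the shortcut that drops the lower clip on the second input.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Scalar stand-ins for the vector operations named in the comment.
// All helpers here are illustrative models, not Stockfish functions.

constexpr int One = 127 * 2;  // clip ceiling once weights and biases are scaled by 2

// Why the scaling matters: 127 << 9 does not fit in a signed 16-bit lane,
// while 254 << 7 does.
static_assert((127 << 9) > INT16_MAX, "127 << 9 reaches the sign bit");
static_assert((One << 7) <= INT16_MAX, "254 << 7 stays positive in int16");

// Floor division by 2^16, i.e. dropping the low 16 bits of a signed product.
inline int high_half(long long p) {
    return static_cast<int>(p >= 0 ? p / 65536 : -((-p + 65535) / 65536));
}

// Model of a signed 16-bit multiply-high (the "mulhi" of the comment).
inline int mulhi16(int x, int y) { return high_half(static_cast<long long>(x) * y); }

// Model of one lane of vqdmulhq_s16: double the product, keep the high half,
// saturate at INT16_MAX.
inline int qdmulh16(int x, int y) {
    return std::min(high_half(2LL * x * y), static_cast<int>(INT16_MAX));
}

// Model of packus: saturate a signed 16-bit value into [0, 255].
inline int packus(int v) { return std::max(0, std::min(v, 255)); }

// Straightforward version: clip both inputs, multiply, shift right by 9.
inline int reference(int a, int b) {
    const int ca = std::max(std::min(a, One), 0);
    const int cb = std::max(std::min(b, One), 0);
    return (ca * cb) >> 9;
}

// Shortcut version: only the first input gets the lower clip. A negative
// second input gives a negative high half, which packus turns into 0.
inline int optimized(int a, int b, bool neon) {
    const int shift = neon ? 6 : 7;  // one bit less when the multiply doubles
    const int sa = std::max(std::min(a, One), 0) << shift;
    assert(sa <= INT16_MAX);  // the shifted value never touches the sign bit
    const int sb = std::min(b, One);
    return packus(neon ? qdmulh16(sa, sb) : mulhi16(sa, sb));
}

// Counts disagreements over a window of inputs plus the int16 extremes.
inline int count_mismatches() {
    int bad = 0;
    auto check = [&bad](int a, int b) {
        const int r = reference(a, b);
        bad += (optimized(a, b, false) != r) + (optimized(a, b, true) != r);
    };
    for (int a = -600; a <= 600; ++a) {
        for (int b = -600; b <= 600; ++b)
            check(a, b);
        for (int b : {INT16_MIN, INT16_MAX})
            check(a, b);
    }
    return bad;
}
```

Calling `count_mismatches()` from a `main` or a unit test returns 0 for this model, for both the shift-by-7 path and the shift-by-6 doubling path. The real code works on whole vectors and packs two 16-bit vectors into one 8-bit vector, which this per-lane sketch leaves out.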