diff --git a/src/nnue/layers/simd.h b/src/nnue/layers/simd.h index 231f7891..381e7a68 100644 --- a/src/nnue/layers/simd.h +++ b/src/nnue/layers/simd.h @@ -165,18 +165,19 @@ namespace Stockfish::Simd { __m512i tmp0 = _mm512_maddubs_epi16(a0, b0); __m512i tmp1 = _mm512_maddubs_epi16(a1, b1); asm( - "vpaddsw %[tmp0], %[tmp1], %[tmp0]\n\t" "vpmaddwd %[tmp0], %[ones], %[tmp0]\n\t" + "vpmaddwd %[tmp1], %[ones], %[tmp1]\n\t" + "vpaddd %[tmp0], %[tmp1], %[tmp0]\n\t" "vpaddd %[acc], %[tmp0], %[acc]\n\t" - : [acc]"+v"(acc), [tmp0]"+&v"(tmp0) - : [tmp1]"v"(tmp1), [ones]"v"(_mm512_set1_epi16(1)) + : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1) + : [ones]"v"(_mm512_set1_epi16(1)) ); # else __m512i product0 = _mm512_maddubs_epi16(a0, b0); __m512i product1 = _mm512_maddubs_epi16(a1, b1); - product0 = _mm512_adds_epi16(product0, product1); product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); - acc = _mm512_add_epi32(acc, product0); + product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1)); + acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1)); # endif # endif } @@ -261,18 +262,19 @@ namespace Stockfish::Simd { __m256i tmp0 = _mm256_maddubs_epi16(a0, b0); __m256i tmp1 = _mm256_maddubs_epi16(a1, b1); asm( - "vpaddsw %[tmp0], %[tmp1], %[tmp0]\n\t" "vpmaddwd %[tmp0], %[ones], %[tmp0]\n\t" + "vpmaddwd %[tmp1], %[ones], %[tmp1]\n\t" + "vpaddd %[tmp0], %[tmp1], %[tmp0]\n\t" "vpaddd %[acc], %[tmp0], %[acc]\n\t" - : [acc]"+v"(acc), [tmp0]"+&v"(tmp0) - : [tmp1]"v"(tmp1), [ones]"v"(_mm256_set1_epi16(1)) + : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1) + : [ones]"v"(_mm256_set1_epi16(1)) ); # else __m256i product0 = _mm256_maddubs_epi16(a0, b0); __m256i product1 = _mm256_maddubs_epi16(a1, b1); - product0 = _mm256_adds_epi16(product0, product1); product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); - acc = _mm256_add_epi32(acc, product0); + product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1)); + acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1)); # endif # endif } @@ -326,18 +328,19 @@ namespace Stockfish::Simd { __m128i tmp0 = _mm_maddubs_epi16(a0, b0); __m128i tmp1 = _mm_maddubs_epi16(a1, b1); asm( - "paddsw %[tmp1], %[tmp0]\n\t" "pmaddwd %[ones], %[tmp0]\n\t" + "pmaddwd %[ones], %[tmp1]\n\t" + "paddd %[tmp1], %[tmp0]\n\t" "paddd %[tmp0], %[acc]\n\t" - : [acc]"+v"(acc), [tmp0]"+&v"(tmp0) - : [tmp1]"v"(tmp1), [ones]"v"(_mm_set1_epi16(1)) + : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1) + : [ones]"v"(_mm_set1_epi16(1)) ); # else __m128i product0 = _mm_maddubs_epi16(a0, b0); __m128i product1 = _mm_maddubs_epi16(a1, b1); - product0 = _mm_adds_epi16(product0, product1); product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); - acc = _mm_add_epi32(acc, product0); + product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1)); + acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1)); # endif }