diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h
index ad9167c0..59a6149f 100644
--- a/src/nnue/layers/affine_transform.h
+++ b/src/nnue/layers/affine_transform.h
@@ -39,25 +39,26 @@
 namespace Stockfish::Eval::NNUE::Layers {
 
+#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
+    #define ENABLE_SEQ_OPT
+#endif
+
 // Fallback implementation for older/other architectures.
 // Requires the input to be padded to at least 16 values.
-#if !defined(USE_SSSE3)
+#ifndef ENABLE_SEQ_OPT
+
 template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
 static void affine_transform_non_ssse3(std::int32_t*       output,
                                        const std::int8_t*  weights,
                                        const std::int32_t* biases,
                                        const std::uint8_t* input) {
-    #if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
+    #if defined(USE_SSE2) || defined(USE_NEON)
 
     #if defined(USE_SSE2)
     // At least a multiple of 16, with SSE2.
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const __m128i       Zeros       = _mm_setzero_si128();
     const auto          inputVector = reinterpret_cast<const __m128i*>(input);
 
-    #elif defined(USE_NEON_DOTPROD)
-    constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
-    const auto          inputVector = reinterpret_cast<const int8x16_t*>(input);
-
     #elif defined(USE_NEON)
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const auto          inputVector = reinterpret_cast<const int8x8_t*>(input);
@@ -91,16 +92,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
             sum = _mm_add_epi32(sum, sum_second_32);
             output[i] = _mm_cvtsi128_si32(sum);
 
-    #elif defined(USE_NEON_DOTPROD)
-            int32x4_t sum = {biases[i]};
-            const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
-            for (IndexType j = 0; j < NumChunks; ++j)
-            {
-                sum = vdotq_s32(sum, inputVector[j], row[j]);
-            }
-            output[i] = vaddvq_s32(sum);
-
     #elif defined(USE_NEON)
+
             int32x4_t sum = {biases[i]};
             const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
             for (IndexType j = 0; j < NumChunks; ++j)
@@ -127,7 +120,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
     }
     #endif
 }
-#endif
+
+#endif  // !ENABLE_SEQ_OPT
 
 template<IndexType InDims, IndexType OutDims>
 class AffineTransform {
@@ -162,7 +156,7 @@ class AffineTransform {
     }
 
     static constexpr IndexType get_weight_index(IndexType i) {
-#if defined(USE_SSSE3)
+#ifdef ENABLE_SEQ_OPT
         return get_weight_index_scrambled(i);
 #else
         return i;
@@ -190,29 +184,28 @@ class AffineTransform {
     // Forward propagation
     void propagate(const InputType* input, OutputType* output) const {
 
-#if defined(USE_SSSE3)
+#ifdef ENABLE_SEQ_OPT
+
         if constexpr (OutputDimensions > 1)
         {
-
     #if defined(USE_AVX512)
         using vec_t = __m512i;
-        #define vec_setzero _mm512_setzero_si512
         #define vec_set_32 _mm512_set1_epi32
         #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
-        #define vec_hadd Simd::m512_hadd
     #elif defined(USE_AVX2)
         using vec_t = __m256i;
-        #define vec_setzero _mm256_setzero_si256
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
-        #define vec_hadd Simd::m256_hadd
     #elif defined(USE_SSSE3)
         using vec_t = __m128i;
-        #define vec_setzero _mm_setzero_si128
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
-        #define vec_hadd Simd::m128_hadd
+    #elif defined(USE_NEON_DOTPROD)
+        using vec_t = int32x4_t;
+        #define vec_set_32 vdupq_n_s32
+        #define vec_add_dpbusd_32(acc, a, b) \
+            Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
+                                                vreinterpretq_s8_s32(b))
     #endif
 
         static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
@@ -242,28 +235,33 @@ class AffineTransform {
             for (IndexType k = 0; k < NumRegs; ++k)
                 outptr[k] = acc[k];
 
-    #undef vec_setzero
     #undef vec_set_32
     #undef vec_add_dpbusd_32
-    #undef vec_hadd
         }
         else if constexpr (OutputDimensions == 1)
         {
-
 // We cannot use AVX512 for the last layer because there are only 32 inputs
 // and the buffer is not padded to 64 elements.
    #if defined(USE_AVX2)
         using vec_t = __m256i;
-        #define vec_setzero _mm256_setzero_si256
+        #define vec_setzero() _mm256_setzero_si256()
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
         #define vec_hadd Simd::m256_hadd
     #elif defined(USE_SSSE3)
         using vec_t = __m128i;
-        #define vec_setzero _mm_setzero_si128
+        #define vec_setzero() _mm_setzero_si128()
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
         #define vec_hadd Simd::m128_hadd
+    #elif defined(USE_NEON_DOTPROD)
+        using vec_t = int32x4_t;
+        #define vec_setzero() vdupq_n_s32(0)
+        #define vec_set_32 vdupq_n_s32
+        #define vec_add_dpbusd_32(acc, a, b) \
+            Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
+                                                vreinterpretq_s8_s32(b))
+        #define vec_hadd Simd::neon_m128_hadd
     #endif
 
         const auto inputVector = reinterpret_cast<const vec_t*>(input);