mirror of
https://github.com/sockspls/badfish
synced 2025-04-29 16:23:09 +00:00
Move DotProd code into optimized affine layer
This patch moves the DotProd code into the propagation function which has sequential access optimization. To prove the speedup, the comparison is done without the sparse layer. With the sparse layer the effect is marginal (GCC 0.3%, LLVM/Clang 0.1%). For both tests, binary is compiled with GCC 14.1. Each test had 50 runs. Sparse layer included: ``` speedup = +0.0030 P(speedup > 0) = 1.0000 ``` Sparse layer excluded: ``` speedup = +0.0561 P(speedup > 0) = 1.0000 ``` closes https://github.com/official-stockfish/Stockfish/pull/5520 No functional change
This commit is contained in:
parent
8e560c4fd3
commit
b976f0a101
1 changed files with 28 additions and 30 deletions
|
@ -39,25 +39,26 @@
|
|||
|
||||
namespace Stockfish::Eval::NNUE::Layers {
|
||||
|
||||
#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
|
||||
#define ENABLE_SEQ_OPT
|
||||
#endif
|
||||
|
||||
// Fallback implementation for older/other architectures.
|
||||
// Requires the input to be padded to at least 16 values.
|
||||
#if !defined(USE_SSSE3)
|
||||
#ifndef ENABLE_SEQ_OPT
|
||||
|
||||
template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
|
||||
static void affine_transform_non_ssse3(std::int32_t* output,
|
||||
const std::int8_t* weights,
|
||||
const std::int32_t* biases,
|
||||
const std::uint8_t* input) {
|
||||
#if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
|
||||
#if defined(USE_SSE2) || defined(USE_NEON)
|
||||
#if defined(USE_SSE2)
|
||||
// At least a multiple of 16, with SSE2.
|
||||
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
|
||||
const __m128i Zeros = _mm_setzero_si128();
|
||||
const auto inputVector = reinterpret_cast<const __m128i*>(input);
|
||||
|
||||
#elif defined(USE_NEON_DOTPROD)
|
||||
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
|
||||
const auto inputVector = reinterpret_cast<const int8x16_t*>(input);
|
||||
|
||||
#elif defined(USE_NEON)
|
||||
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
|
||||
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
|
||||
|
@ -91,16 +92,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
|
|||
sum = _mm_add_epi32(sum, sum_second_32);
|
||||
output[i] = _mm_cvtsi128_si32(sum);
|
||||
|
||||
#elif defined(USE_NEON_DOTPROD)
|
||||
int32x4_t sum = {biases[i]};
|
||||
const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
|
||||
for (IndexType j = 0; j < NumChunks; ++j)
|
||||
{
|
||||
sum = vdotq_s32(sum, inputVector[j], row[j]);
|
||||
}
|
||||
output[i] = vaddvq_s32(sum);
|
||||
|
||||
#elif defined(USE_NEON)
|
||||
|
||||
int32x4_t sum = {biases[i]};
|
||||
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
|
||||
for (IndexType j = 0; j < NumChunks; ++j)
|
||||
|
@ -127,7 +120,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
|
|||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // !ENABLE_SEQ_OPT
|
||||
|
||||
template<IndexType InDims, IndexType OutDims>
|
||||
class AffineTransform {
|
||||
|
@ -162,7 +156,7 @@ class AffineTransform {
|
|||
}
|
||||
|
||||
static constexpr IndexType get_weight_index(IndexType i) {
|
||||
#if defined(USE_SSSE3)
|
||||
#ifdef ENABLE_SEQ_OPT
|
||||
return get_weight_index_scrambled(i);
|
||||
#else
|
||||
return i;
|
||||
|
@ -190,29 +184,28 @@ class AffineTransform {
|
|||
// Forward propagation
|
||||
void propagate(const InputType* input, OutputType* output) const {
|
||||
|
||||
#if defined(USE_SSSE3)
|
||||
#ifdef ENABLE_SEQ_OPT
|
||||
|
||||
if constexpr (OutputDimensions > 1)
|
||||
{
|
||||
|
||||
#if defined(USE_AVX512)
|
||||
using vec_t = __m512i;
|
||||
#define vec_setzero _mm512_setzero_si512
|
||||
#define vec_set_32 _mm512_set1_epi32
|
||||
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
|
||||
#define vec_hadd Simd::m512_hadd
|
||||
#elif defined(USE_AVX2)
|
||||
using vec_t = __m256i;
|
||||
#define vec_setzero _mm256_setzero_si256
|
||||
#define vec_set_32 _mm256_set1_epi32
|
||||
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
|
||||
#define vec_hadd Simd::m256_hadd
|
||||
#elif defined(USE_SSSE3)
|
||||
using vec_t = __m128i;
|
||||
#define vec_setzero _mm_setzero_si128
|
||||
#define vec_set_32 _mm_set1_epi32
|
||||
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
|
||||
#define vec_hadd Simd::m128_hadd
|
||||
#elif defined(USE_NEON_DOTPROD)
|
||||
using vec_t = int32x4_t;
|
||||
#define vec_set_32 vdupq_n_s32
|
||||
#define vec_add_dpbusd_32(acc, a, b) \
|
||||
Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
|
||||
vreinterpretq_s8_s32(b))
|
||||
#endif
|
||||
|
||||
static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
|
||||
|
@ -242,28 +235,33 @@ class AffineTransform {
|
|||
for (IndexType k = 0; k < NumRegs; ++k)
|
||||
outptr[k] = acc[k];
|
||||
|
||||
#undef vec_setzero
|
||||
#undef vec_set_32
|
||||
#undef vec_add_dpbusd_32
|
||||
#undef vec_hadd
|
||||
}
|
||||
else if constexpr (OutputDimensions == 1)
|
||||
{
|
||||
|
||||
// We cannot use AVX512 for the last layer because there are only 32 inputs
|
||||
// and the buffer is not padded to 64 elements.
|
||||
#if defined(USE_AVX2)
|
||||
using vec_t = __m256i;
|
||||
#define vec_setzero _mm256_setzero_si256
|
||||
#define vec_setzero() _mm256_setzero_si256()
|
||||
#define vec_set_32 _mm256_set1_epi32
|
||||
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
|
||||
#define vec_hadd Simd::m256_hadd
|
||||
#elif defined(USE_SSSE3)
|
||||
using vec_t = __m128i;
|
||||
#define vec_setzero _mm_setzero_si128
|
||||
#define vec_setzero() _mm_setzero_si128()
|
||||
#define vec_set_32 _mm_set1_epi32
|
||||
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
|
||||
#define vec_hadd Simd::m128_hadd
|
||||
#elif defined(USE_NEON_DOTPROD)
|
||||
using vec_t = int32x4_t;
|
||||
#define vec_setzero() vdupq_n_s32(0)
|
||||
#define vec_set_32 vdupq_n_s32
|
||||
#define vec_add_dpbusd_32(acc, a, b) \
|
||||
Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
|
||||
vreinterpretq_s8_s32(b))
|
||||
#define vec_hadd Simd::neon_m128_hadd
|
||||
#endif
|
||||
|
||||
const auto inputVector = reinterpret_cast<const vec_t*>(input);
|
||||
|
|
Loading…
Add table
Reference in a new issue