From 27eb49a2211c90650ef64d5102e6e36ca5e69af0 Mon Sep 17 00:00:00 2001 From: cj5716 <125858804+cj5716@users.noreply.github.com> Date: Fri, 17 May 2024 18:05:12 +0800 Subject: [PATCH] Simplify ClippedReLU Removes some max calls Some speedup stats, courtesy of @AndyGrant (albeit measured in an alternate implementation) Dev 749240 nps Base 748495 nps Gain 0.100% 289936 games STC: LLR: 2.94 (-2.94,2.94) <-1.75,0.25> Total: 203040 W: 52213 L: 52179 D: 98648 Ptnml(0-2): 480, 20722, 59139, 20642, 537 https://tests.stockfishchess.org/tests/view/664805fe6dcff0d1d6b05f2c closes #5261 No functional change --- src/nnue/layers/clipped_relu.h | 48 ++++++++++++++++------------------ src/nnue/nnue_misc.cpp | 9 +++---- src/tune.cpp | 6 ++--- src/uci.cpp | 6 ++--- 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/src/nnue/layers/clipped_relu.h b/src/nnue/layers/clipped_relu.h index 813234c5..2ee378ad 100644 --- a/src/nnue/layers/clipped_relu.h +++ b/src/nnue/layers/clipped_relu.h @@ -65,41 +65,37 @@ class ClippedReLU { if constexpr (InputDimensions % SimdWidth == 0) { constexpr IndexType NumChunks = InputDimensions / SimdWidth; - const __m256i Zero = _mm256_setzero_si256(); const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); const auto in = reinterpret_cast(input); const auto out = reinterpret_cast<__m256i*>(output); for (IndexType i = 0; i < NumChunks; ++i) { const __m256i words0 = - _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]), - _mm256_load_si256(&in[i * 4 + 1])), + _mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 0]), + _mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits); const __m256i words1 = - _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]), - _mm256_load_si256(&in[i * 4 + 3])), + _mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 2]), + _mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits); - _mm256_store_si256( - &out[i], _mm256_permutevar8x32_epi32( - _mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets)); + _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32( + _mm256_packs_epi16(words0, words1), Offsets)); } } else { constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); - const __m128i Zero = _mm_setzero_si128(); const auto in = reinterpret_cast(input); const auto out = reinterpret_cast<__m128i*>(output); for (IndexType i = 0; i < NumChunks; ++i) { - const __m128i words0 = _mm_srai_epi16( - _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])), + const __m128i words0 = _mm_srli_epi16( + _mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); - const __m128i words1 = _mm_srai_epi16( - _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])), + const __m128i words1 = _mm_srli_epi16( + _mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); - const __m128i packedbytes = _mm_packs_epi16(words0, words1); - _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero)); + _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); } } constexpr IndexType Start = InputDimensions % SimdWidth == 0 @@ -109,9 +105,7 @@ class ClippedReLU { #elif defined(USE_SSE2) constexpr IndexType NumChunks = InputDimensions / SimdWidth; - #ifdef USE_SSE41 - const __m128i Zero = _mm_setzero_si128(); - #else + #ifndef USE_SSE41 const __m128i k0x80s = _mm_set1_epi8(-128); #endif @@ -119,6 +113,15 @@ class ClippedReLU { const auto out = reinterpret_cast<__m128i*>(output); for (IndexType i = 0; i < NumChunks; ++i) { + #if defined(USE_SSE41) + const __m128i words0 = _mm_srli_epi16( + _mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])), + WeightScaleBits); + const __m128i words1 = _mm_srli_epi16( + _mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])), + WeightScaleBits); + _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); + #else const __m128i words0 = _mm_srai_epi16( _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); @@ -126,15 +129,8 @@ class ClippedReLU { _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); const __m128i packedbytes = _mm_packs_epi16(words0, words1); - _mm_store_si128(&out[i], - - #ifdef USE_SSE41 - _mm_max_epi8(packedbytes, Zero) - #else - _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s) + _mm_store_si128(&out[i], _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)); #endif - - ); } constexpr IndexType Start = NumChunks * SimdWidth; diff --git a/src/nnue/nnue_misc.cpp b/src/nnue/nnue_misc.cpp index a13c717c..b54bbaba 100644 --- a/src/nnue/nnue_misc.cpp +++ b/src/nnue/nnue_misc.cpp @@ -178,14 +178,11 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat ss << "| " << bucket << " "; ss << " | "; format_cp_aligned_dot(t.psqt[bucket], ss, pos); - ss << " " - << " | "; + ss << " " << " | "; format_cp_aligned_dot(t.positional[bucket], ss, pos); - ss << " " - << " | "; + ss << " " << " | "; format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss, pos); - ss << " " - << " |"; + ss << " " << " |"; if (bucket == t.correctBucket) ss << " <-- this bucket is used"; ss << '\n'; diff --git a/src/tune.cpp b/src/tune.cpp index 3e5ebe5e..84f59524 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -59,8 +59,7 @@ void make_option(OptionsMap* options, const string& n, int v, const SetRange& r) // Print formatted parameters, ready to be copy-pasted in Fishtest std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << "," - << (r(v).second - r(v).first) / 20.0 << "," - << "0.0020" << std::endl; + << (r(v).second - r(v).first) / 20.0 << "," << "0.0020" << std::endl; } } @@ -118,7 +117,6 @@ void Tune::Entry::read_option() { namespace Stockfish { -void Tune::read_results() { /* ...insert your values here... */ -} +void Tune::read_results() { /* ...insert your values here... */ } } // namespace Stockfish diff --git a/src/uci.cpp b/src/uci.cpp index cb686a02..cb9d7b08 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -286,9 +286,9 @@ void UCIEngine::bench(std::istream& args) { dbg_print(); - std::cerr << "\n===========================" - << "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes - << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; + std::cerr << "\n===========================" << "\nTotal time (ms) : " << elapsed + << "\nNodes searched : " << nodes << "\nNodes/second : " << 1000 * nodes / elapsed + << std::endl; // reset callback, to not capture a dangling reference to nodesSearched engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });