mirror of
https://github.com/sockspls/badfish
synced 2025-04-29 16:23:09 +00:00
Simplify ClippedReLU
Removes some max calls. Some speedup stats, courtesy of @AndyGrant (albeit measured in an alternate implementation): Dev 749240 nps; Base 748495 nps; Gain 0.100%; 289936 games. STC: LLR: 2.94 (-2.94,2.94) <-1.75,0.25>; Total: 203040 W: 52213 L: 52179 D: 98648; Ptnml(0-2): 480, 20722, 59139, 20642, 537 (https://tests.stockfishchess.org/tests/view/664805fe6dcff0d1d6b05f2c). Closes #5261. No functional change.
This commit is contained in:
parent
2d32581623
commit
27eb49a221
4 changed files with 30 additions and 39 deletions
|
@ -65,41 +65,37 @@ class ClippedReLU {
|
||||||
if constexpr (InputDimensions % SimdWidth == 0)
|
if constexpr (InputDimensions % SimdWidth == 0)
|
||||||
{
|
{
|
||||||
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
|
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
|
||||||
const __m256i Zero = _mm256_setzero_si256();
|
|
||||||
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
|
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
|
||||||
const auto in = reinterpret_cast<const __m256i*>(input);
|
const auto in = reinterpret_cast<const __m256i*>(input);
|
||||||
const auto out = reinterpret_cast<__m256i*>(output);
|
const auto out = reinterpret_cast<__m256i*>(output);
|
||||||
for (IndexType i = 0; i < NumChunks; ++i)
|
for (IndexType i = 0; i < NumChunks; ++i)
|
||||||
{
|
{
|
||||||
const __m256i words0 =
|
const __m256i words0 =
|
||||||
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]),
|
_mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 0]),
|
||||||
_mm256_load_si256(&in[i * 4 + 1])),
|
_mm256_load_si256(&in[i * 4 + 1])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
const __m256i words1 =
|
const __m256i words1 =
|
||||||
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]),
|
_mm256_srli_epi16(_mm256_packus_epi32(_mm256_load_si256(&in[i * 4 + 2]),
|
||||||
_mm256_load_si256(&in[i * 4 + 3])),
|
_mm256_load_si256(&in[i * 4 + 3])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
_mm256_store_si256(
|
_mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(
|
||||||
&out[i], _mm256_permutevar8x32_epi32(
|
_mm256_packs_epi16(words0, words1), Offsets));
|
||||||
_mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
|
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
|
||||||
const __m128i Zero = _mm_setzero_si128();
|
|
||||||
const auto in = reinterpret_cast<const __m128i*>(input);
|
const auto in = reinterpret_cast<const __m128i*>(input);
|
||||||
const auto out = reinterpret_cast<__m128i*>(output);
|
const auto out = reinterpret_cast<__m128i*>(output);
|
||||||
for (IndexType i = 0; i < NumChunks; ++i)
|
for (IndexType i = 0; i < NumChunks; ++i)
|
||||||
{
|
{
|
||||||
const __m128i words0 = _mm_srai_epi16(
|
const __m128i words0 = _mm_srli_epi16(
|
||||||
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
|
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
const __m128i words1 = _mm_srai_epi16(
|
const __m128i words1 = _mm_srli_epi16(
|
||||||
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
|
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
|
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
|
||||||
_mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constexpr IndexType Start = InputDimensions % SimdWidth == 0
|
constexpr IndexType Start = InputDimensions % SimdWidth == 0
|
||||||
|
@ -109,9 +105,7 @@ class ClippedReLU {
|
||||||
#elif defined(USE_SSE2)
|
#elif defined(USE_SSE2)
|
||||||
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
|
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
|
||||||
|
|
||||||
#ifdef USE_SSE41
|
#ifndef USE_SSE41
|
||||||
const __m128i Zero = _mm_setzero_si128();
|
|
||||||
#else
|
|
||||||
const __m128i k0x80s = _mm_set1_epi8(-128);
|
const __m128i k0x80s = _mm_set1_epi8(-128);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -119,6 +113,15 @@ class ClippedReLU {
|
||||||
const auto out = reinterpret_cast<__m128i*>(output);
|
const auto out = reinterpret_cast<__m128i*>(output);
|
||||||
for (IndexType i = 0; i < NumChunks; ++i)
|
for (IndexType i = 0; i < NumChunks; ++i)
|
||||||
{
|
{
|
||||||
|
#if defined(USE_SSE41)
|
||||||
|
const __m128i words0 = _mm_srli_epi16(
|
||||||
|
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
|
||||||
|
WeightScaleBits);
|
||||||
|
const __m128i words1 = _mm_srli_epi16(
|
||||||
|
_mm_packus_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
|
||||||
|
WeightScaleBits);
|
||||||
|
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
|
||||||
|
#else
|
||||||
const __m128i words0 = _mm_srai_epi16(
|
const __m128i words0 = _mm_srai_epi16(
|
||||||
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
|
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
|
@ -126,15 +129,8 @@ class ClippedReLU {
|
||||||
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
|
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
|
||||||
WeightScaleBits);
|
WeightScaleBits);
|
||||||
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
|
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
|
||||||
_mm_store_si128(&out[i],
|
_mm_store_si128(&out[i], _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s));
|
||||||
|
|
||||||
#ifdef USE_SSE41
|
|
||||||
_mm_max_epi8(packedbytes, Zero)
|
|
||||||
#else
|
|
||||||
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
constexpr IndexType Start = NumChunks * SimdWidth;
|
constexpr IndexType Start = NumChunks * SimdWidth;
|
||||||
|
|
||||||
|
|
|
@ -178,14 +178,11 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
|
||||||
ss << "| " << bucket << " ";
|
ss << "| " << bucket << " ";
|
||||||
ss << " | ";
|
ss << " | ";
|
||||||
format_cp_aligned_dot(t.psqt[bucket], ss, pos);
|
format_cp_aligned_dot(t.psqt[bucket], ss, pos);
|
||||||
ss << " "
|
ss << " " << " | ";
|
||||||
<< " | ";
|
|
||||||
format_cp_aligned_dot(t.positional[bucket], ss, pos);
|
format_cp_aligned_dot(t.positional[bucket], ss, pos);
|
||||||
ss << " "
|
ss << " " << " | ";
|
||||||
<< " | ";
|
|
||||||
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss, pos);
|
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss, pos);
|
||||||
ss << " "
|
ss << " " << " |";
|
||||||
<< " |";
|
|
||||||
if (bucket == t.correctBucket)
|
if (bucket == t.correctBucket)
|
||||||
ss << " <-- this bucket is used";
|
ss << " <-- this bucket is used";
|
||||||
ss << '\n';
|
ss << '\n';
|
||||||
|
|
|
@ -59,8 +59,7 @@ void make_option(OptionsMap* options, const string& n, int v, const SetRange& r)
|
||||||
|
|
||||||
// Print formatted parameters, ready to be copy-pasted in Fishtest
|
// Print formatted parameters, ready to be copy-pasted in Fishtest
|
||||||
std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
|
std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
|
||||||
<< (r(v).second - r(v).first) / 20.0 << ","
|
<< (r(v).second - r(v).first) / 20.0 << "," << "0.0020" << std::endl;
|
||||||
<< "0.0020" << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +117,6 @@ void Tune::Entry<Tune::PostUpdate>::read_option() {
|
||||||
|
|
||||||
namespace Stockfish {
|
namespace Stockfish {
|
||||||
|
|
||||||
void Tune::read_results() { /* ...insert your values here... */
|
void Tune::read_results() { /* ...insert your values here... */ }
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace Stockfish
|
} // namespace Stockfish
|
||||||
|
|
|
@ -286,9 +286,9 @@ void UCIEngine::bench(std::istream& args) {
|
||||||
|
|
||||||
dbg_print();
|
dbg_print();
|
||||||
|
|
||||||
std::cerr << "\n==========================="
|
std::cerr << "\n===========================" << "\nTotal time (ms) : " << elapsed
|
||||||
<< "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
|
<< "\nNodes searched : " << nodes << "\nNodes/second : " << 1000 * nodes / elapsed
|
||||||
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
// reset callback, to not capture a dangling reference to nodesSearched
|
// reset callback, to not capture a dangling reference to nodesSearched
|
||||||
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
|
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
|
||||||
|
|
Loading…
Add table
Reference in a new issue