diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index caf315b2..0e0515f9 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -680,9 +680,8 @@ namespace Eval::NNUE::Layers { for (IndexType j = 0; j < kNumChunks; ++j) { __m128i row_j = _mm_load_si128(&row[j]); __m128i input_j = _mm_load_si128(&input_vector[j]); - __m128i row_signs = _mm_cmpgt_epi8(kZeros, row_j); - __m128i extended_row_lo = _mm_unpacklo_epi8(row_j, row_signs); - __m128i extended_row_hi = _mm_unpackhi_epi8(row_j, row_signs); + __m128i extended_row_lo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); + __m128i extended_row_hi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8); __m128i extended_input_lo = _mm_unpacklo_epi8(input_j, kZeros); __m128i extended_input_hi = _mm_unpackhi_epi8(input_j, kZeros); __m128i product_lo = _mm_madd_epi16(extended_row_lo, extended_input_lo); @@ -704,9 +703,8 @@ namespace Eval::NNUE::Layers { for (IndexType j = 0; j < kNumChunks; ++j) { __m64 row_j = row[j]; __m64 input_j = input_vector[j]; - __m64 row_signs = _mm_cmpgt_pi8(kZeros, row_j); - __m64 extended_row_lo = _mm_unpacklo_pi8(row_j, row_signs); - __m64 extended_row_hi = _mm_unpackhi_pi8(row_j, row_signs); + __m64 extended_row_lo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8); + __m64 extended_row_hi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8); __m64 extended_input_lo = _mm_unpacklo_pi8(input_j, kZeros); __m64 extended_input_hi = _mm_unpackhi_pi8(input_j, kZeros); __m64 product_lo = _mm_madd_pi16(extended_row_lo, extended_input_lo);