1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-06-28 00:19:50 +00:00

add clang-format

This introduces clang-format to enforce a consistent code style for Stockfish.

Having a documented and consistent style across the code will make contributing easier
for new developers, and will make larger changes to the codebase easier to make.

To facilitate formatting, this PR includes a Makefile target (`make format`) to format the code,
this requires clang-format (version 17 currently) to be installed locally.

Installing clang-format is straightforward on most OS and distros
(e.g. with https://apt.llvm.org/, brew install clang-format, etc), as this is part of quite commonly
used suite of tools and compilers (llvm / clang).

Additionally, a CI action is present that will verify if the code requires formatting,
and comment on the PR as needed. Initially, correct formatting is not required, it will be
done by maintainers as part of the merge or in later commits, but obviously this is encouraged.

fixes https://github.com/official-stockfish/Stockfish/issues/3608
closes https://github.com/official-stockfish/Stockfish/pull/4790

Co-Authored-By: Joost VandeVondele <Joost.VandeVondele@gmail.com>
This commit is contained in:
Disservin 2023-10-21 11:40:56 +02:00 committed by Joost VandeVondele
parent 8366ec48ae
commit 2d0237db3f
49 changed files with 6403 additions and 6197 deletions

44
.clang-format Normal file
View file

@ -0,0 +1,44 @@
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: Consecutive
AlignConsecutiveDeclarations: Consecutive
AlignEscapedNewlines: DontAlign
AlignOperands: AlignAfterOperator
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortCaseLabelsOnASingleLine: false
AllowShortEnumsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AlwaysBreakTemplateDeclarations: Yes
BasedOnStyle: WebKit
BitFieldColonSpacing: After
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Custom
BraceWrapping:
AfterFunction: false
AfterClass: false
AfterControlStatement: true
BeforeElse: true
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakStringLiterals: false
ColumnLimit: 100
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
IndentGotoLabels: false
IndentPPDirectives: BeforeHash
IndentWidth: 4
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
PackConstructorInitializers: Never
ReflowComments: false
SortIncludes: false
SortUsingDeclarations: false
SpaceAfterCStyleCast: true
SpaceAfterTemplateKeyword: false
SpaceBeforeCaseColon: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeInheritanceColon: false
SpaceInEmptyBlock: false
SpacesBeforeTrailingComments: 2

View file

@ -0,0 +1,51 @@
# This workflow will run clang-format and comment on the PR.
# Because of security reasons, it is crucial that this workflow
# executes no shell script nor runs make.
# Read this before editing: https://securitylab.github.com/research/github-actions-preventing-pwn-requests/
name: Stockfish
on:
pull_request_target:
branches:
- 'master'
paths:
- '**.cpp'
- '**.h'
jobs:
Stockfish:
name: clang-format check
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Run clang-format style check
uses: jidicula/clang-format-action@f62da5e3d3a2d88ff364771d9d938773a618ab5e
id: clang-format
continue-on-error: true
with:
clang-format-version: '17'
exclude-regex: 'incbin'
- name: Comment on PR
if: steps.clang-format.outcome == 'failure'
uses: thollander/actions-comment-pull-request@1d3973dc4b8e1399c0620d3f2b1aa5e795465308
with:
message: |
clang-format 17 needs to be run on this PR.
If you do not have clang-format installed, the maintainer will run it when merging.
For the exact version please see https://packages.ubuntu.com/mantic/clang-format-17.
_(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_
comment_tag: execution
- name: Comment on PR
if: steps.clang-format.outcome != 'failure'
uses: thollander/actions-comment-pull-request@1d3973dc4b8e1399c0620d3f2b1aa5e795465308
with:
message: |
_(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_
create_if_not_exists: false
comment_tag: execution
mode: delete

View file

@ -57,8 +57,9 @@ discussion._
## Code Style ## Code Style
We do not have a strict code style. But it is best to stick to the existing Changes to Stockfish C++ code should respect our coding style defined by
style of the file you are editing. [.clang-format](.clang-format). You can format your changes by running
`make format`. This requires clang-format version 17 to be installed on your system.
## Community and Communication ## Community and Communication

View file

@ -57,6 +57,14 @@ SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp
HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
nnue/evaluate_nnue.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
nnue/layers/affine_transform_sparse_input.h nnue/layers/clipped_relu.h nnue/layers/simd.h \
nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
tt.h tune.h types.h uci.h
OBJS = $(notdir $(SRCS:.cpp=.o)) OBJS = $(notdir $(SRCS:.cpp=.o))
VPATH = syzygy:nnue:nnue/features VPATH = syzygy:nnue:nnue/features
@ -145,6 +153,12 @@ dotprod = no
arm_version = 0 arm_version = 0
STRIP = strip STRIP = strip
ifneq ($(shell command -v clang-format-17),)
CLANG-FORMAT = clang-format-17
else
CLANG-FORMAT = clang-format
endif
### 2.2 Architecture specific ### 2.2 Architecture specific
ifeq ($(findstring x86,$(ARCH)),x86) ifeq ($(findstring x86,$(ARCH)),x86)
@ -936,6 +950,9 @@ net: netvariables
fi; \ fi; \
fi; \ fi; \
format:
$(CLANG-FORMAT) -i $(SRCS) $(HEADERS) -style=file:../.clang-format
# default target # default target
default: default:
help help

View file

@ -27,6 +27,7 @@
namespace { namespace {
// clang-format off
const std::vector<std::string> Defaults = { const std::vector<std::string> Defaults = {
"setoption name UCI_Chess960 value false", "setoption name UCI_Chess960 value false",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
@ -90,6 +91,7 @@ const std::vector<std::string> Defaults = {
"nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1", "nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1",
"setoption name UCI_Chess960 value false" "setoption name UCI_Chess960 value false"
}; };
// clang-format on
} // namespace } // namespace

View file

@ -110,7 +110,8 @@ void Bitboards::init() {
if (PseudoAttacks[pt][s1] & s2) if (PseudoAttacks[pt][s1] & s2)
{ {
LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2; LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2;
BetweenBB[s1][s2] = (attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1))); BetweenBB[s1][s2] =
(attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1)));
} }
BetweenBB[s1][s2] |= s2; BetweenBB[s1][s2] |= s2;
} }
@ -171,7 +172,8 @@ namespace {
// Use Carry-Rippler trick to enumerate all subsets of masks[s] and // Use Carry-Rippler trick to enumerate all subsets of masks[s] and
// store the corresponding sliding attack bitboard in reference[]. // store the corresponding sliding attack bitboard in reference[].
b = size = 0; b = size = 0;
do { do
{
occupancy[size] = b; occupancy[size] = b;
reference[size] = sliding_attack(pt, s, b); reference[size] = sliding_attack(pt, s, b);

View file

@ -110,40 +110,35 @@ inline Bitboard operator^(Square s, Bitboard b) { return b ^ s; }
inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; } inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; }
constexpr bool more_than_one(Bitboard b) { constexpr bool more_than_one(Bitboard b) { return b & (b - 1); }
return b & (b - 1);
}
// rank_bb() and file_bb() return a bitboard representing all the squares on // rank_bb() and file_bb() return a bitboard representing all the squares on
// the given file or rank. // the given file or rank.
constexpr Bitboard rank_bb(Rank r) { constexpr Bitboard rank_bb(Rank r) { return Rank1BB << (8 * r); }
return Rank1BB << (8 * r);
}
constexpr Bitboard rank_bb(Square s) { constexpr Bitboard rank_bb(Square s) { return rank_bb(rank_of(s)); }
return rank_bb(rank_of(s));
}
constexpr Bitboard file_bb(File f) { constexpr Bitboard file_bb(File f) { return FileABB << f; }
return FileABB << f;
}
constexpr Bitboard file_bb(Square s) { constexpr Bitboard file_bb(Square s) { return file_bb(file_of(s)); }
return file_bb(file_of(s));
}
// shift() moves a bitboard one or two steps as specified by the direction D // shift() moves a bitboard one or two steps as specified by the direction D
template<Direction D> template<Direction D>
constexpr Bitboard shift(Bitboard b) { constexpr Bitboard shift(Bitboard b) {
return D == NORTH ? b << 8 : D == SOUTH ? b >> 8 return D == NORTH ? b << 8
: D == NORTH+NORTH? b <<16 : D == SOUTH+SOUTH? b >>16 : D == SOUTH ? b >> 8
: D == EAST ? (b & ~FileHBB) << 1 : D == WEST ? (b & ~FileABB) >> 1 : D == NORTH + NORTH ? b << 16
: D == NORTH_EAST ? (b & ~FileHBB) << 9 : D == NORTH_WEST ? (b & ~FileABB) << 7 : D == SOUTH + SOUTH ? b >> 16
: D == SOUTH_EAST ? (b & ~FileHBB) >> 7 : D == SOUTH_WEST ? (b & ~FileABB) >> 9 : D == EAST ? (b & ~FileHBB) << 1
: D == WEST ? (b & ~FileABB) >> 1
: D == NORTH_EAST ? (b & ~FileHBB) << 9
: D == NORTH_WEST ? (b & ~FileABB) << 7
: D == SOUTH_EAST ? (b & ~FileHBB) >> 7
: D == SOUTH_WEST ? (b & ~FileABB) >> 9
: 0; : 0;
} }
@ -194,18 +189,26 @@ inline Bitboard between_bb(Square s1, Square s2) {
// aligned() returns true if the squares s1, s2 and s3 are aligned either on a // aligned() returns true if the squares s1, s2 and s3 are aligned either on a
// straight or on a diagonal line. // straight or on a diagonal line.
inline bool aligned(Square s1, Square s2, Square s3) { inline bool aligned(Square s1, Square s2, Square s3) { return line_bb(s1, s2) & s3; }
return line_bb(s1, s2) & s3;
}
// distance() functions return the distance between x and y, defined as the // distance() functions return the distance between x and y, defined as the
// number of steps for a king in x to reach y. // number of steps for a king in x to reach y.
template<typename T1 = Square> inline int distance(Square x, Square y); template<typename T1 = Square>
template<> inline int distance<File>(Square x, Square y) { return std::abs(file_of(x) - file_of(y)); } inline int distance(Square x, Square y);
template<> inline int distance<Rank>(Square x, Square y) { return std::abs(rank_of(x) - rank_of(y)); } template<>
template<> inline int distance<Square>(Square x, Square y) { return SquareDistance[x][y]; } inline int distance<File>(Square x, Square y) {
return std::abs(file_of(x) - file_of(y));
}
template<>
inline int distance<Rank>(Square x, Square y) {
return std::abs(rank_of(x) - rank_of(y));
}
template<>
inline int distance<Square>(Square x, Square y) {
return SquareDistance[x][y];
}
inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); } inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); }
@ -232,10 +235,14 @@ inline Bitboard attacks_bb(Square s, Bitboard occupied) {
switch (Pt) switch (Pt)
{ {
case BISHOP: return BishopMagics[s].attacks[BishopMagics[s].index(occupied)]; case BISHOP :
case ROOK : return RookMagics[s].attacks[ RookMagics[s].index(occupied)]; return BishopMagics[s].attacks[BishopMagics[s].index(occupied)];
case QUEEN : return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied); case ROOK :
default : return PseudoAttacks[Pt][s]; return RookMagics[s].attacks[RookMagics[s].index(occupied)];
case QUEEN :
return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied);
default :
return PseudoAttacks[Pt][s];
} }
} }
@ -245,10 +252,14 @@ inline Bitboard attacks_bb(PieceType pt, Square s, Bitboard occupied) {
switch (pt) switch (pt)
{ {
case BISHOP: return attacks_bb<BISHOP>(s, occupied); case BISHOP :
case ROOK : return attacks_bb< ROOK>(s, occupied); return attacks_bb<BISHOP>(s, occupied);
case QUEEN : return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied); case ROOK :
default : return PseudoAttacks[pt][s]; return attacks_bb<ROOK>(s, occupied);
case QUEEN :
return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied);
default :
return PseudoAttacks[pt][s];
} }
} }
@ -259,7 +270,10 @@ inline int popcount(Bitboard b) {
#ifndef USE_POPCNT #ifndef USE_POPCNT
union { Bitboard bb; uint16_t u[4]; } v = { b }; union {
Bitboard bb;
uint16_t u[4];
} v = {b};
return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]]; return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]];
#elif defined(_MSC_VER) #elif defined(_MSC_VER)
@ -312,10 +326,13 @@ inline Square lsb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
if (b & 0xffffffff) { if (b & 0xffffffff)
{
_BitScanForward(&idx, int32_t(b)); _BitScanForward(&idx, int32_t(b));
return Square(idx); return Square(idx);
} else { }
else
{
_BitScanForward(&idx, int32_t(b >> 32)); _BitScanForward(&idx, int32_t(b >> 32));
return Square(idx + 32); return Square(idx + 32);
} }
@ -325,10 +342,13 @@ inline Square msb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
if (b >> 32) { if (b >> 32)
{
_BitScanReverse(&idx, int32_t(b >> 32)); _BitScanReverse(&idx, int32_t(b >> 32));
return Square(idx + 32); return Square(idx + 32);
} else { }
else
{
_BitScanReverse(&idx, int32_t(b)); _BitScanReverse(&idx, int32_t(b));
return Square(idx); return Square(idx);
} }

View file

@ -72,7 +72,8 @@ namespace Eval {
eval_file = EvalFileDefaultName; eval_file = EvalFileDefaultName;
#if defined(DEFAULT_NNUE_DIRECTORY) #if defined(DEFAULT_NNUE_DIRECTORY)
std::vector<std::string> dirs = { "<internal>" , "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; std::vector<std::string> dirs = {"<internal>", "", CommandLine::binaryDirectory,
stringify(DEFAULT_NNUE_DIRECTORY)};
#else #else
std::vector<std::string> dirs = {"<internal>", "", CommandLine::binaryDirectory}; std::vector<std::string> dirs = {"<internal>", "", CommandLine::binaryDirectory};
#endif #endif
@ -91,10 +92,15 @@ namespace Eval {
{ {
// C++ way to prepare a buffer for a memory stream // C++ way to prepare a buffer for a memory stream
class MemoryBuffer: public std::basic_streambuf<char> { class MemoryBuffer: public std::basic_streambuf<char> {
public: MemoryBuffer(char* p, size_t n) { setg(p, p, p + n); setp(p, p + n); } public:
MemoryBuffer(char* p, size_t n) {
setg(p, p, p + n);
setp(p, p + n);
}
}; };
MemoryBuffer buffer(const_cast<char*>(reinterpret_cast<const char*>(gEmbeddedNNUEData)), MemoryBuffer buffer(
const_cast<char*>(reinterpret_cast<const char*>(gEmbeddedNNUEData)),
size_t(gEmbeddedNNUESize)); size_t(gEmbeddedNNUESize));
(void) gEmbeddedNNUEEnd; // Silence warning on unused variable (void) gEmbeddedNNUEEnd; // Silence warning on unused variable
@ -115,10 +121,14 @@ namespace Eval {
if (currentEvalFileName != eval_file) if (currentEvalFileName != eval_file)
{ {
std::string msg1 = "Network evaluation parameters compatible with the engine must be available."; std::string msg1 =
"Network evaluation parameters compatible with the engine must be available.";
std::string msg2 = "The network file " + eval_file + " was not loaded successfully."; std::string msg2 = "The network file " + eval_file + " was not loaded successfully.";
std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; std::string msg3 =
std::string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + std::string(EvalFileDefaultName); "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file.";
std::string msg4 =
"The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/"
+ std::string(EvalFileDefaultName);
std::string msg5 = "The engine will be terminated now."; std::string msg5 = "The engine will be terminated now.";
sync_cout << "info string ERROR: " << msg1 << sync_endl; sync_cout << "info string ERROR: " << msg1 << sync_endl;
@ -157,8 +167,7 @@ Value Eval::evaluate(const Position& pos) {
int shuffling = pos.rule50_count(); int shuffling = pos.rule50_count();
int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3); int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3);
bool lazy = abs(simpleEval) >= RookValue + KnightValue bool lazy = abs(simpleEval) >= RookValue + KnightValue + 16 * shuffling * shuffling
+ 16 * shuffling * shuffling
+ abs(pos.this_thread()->bestValue) + abs(pos.this_thread()->bestValue)
+ abs(pos.this_thread()->rootSimpleEval); + abs(pos.this_thread()->rootSimpleEval);
@ -176,8 +185,7 @@ Value Eval::evaluate(const Position& pos) {
nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768; nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768;
int npm = pos.non_pawn_material() / 64; int npm = pos.non_pawn_material() / 64;
v = ( nnue * (915 + npm + 9 * pos.count<PAWN>()) v = (nnue * (915 + npm + 9 * pos.count<PAWN>()) + optimism * (154 + npm)) / 1024;
+ optimism * (154 + npm )) / 1024;
} }
// Damp down the evaluation linearly when shuffling // Damp down the evaluation linearly when shuffling

View file

@ -35,7 +35,8 @@
// first to define the corresponding function pointers. // first to define the corresponding function pointers.
extern "C" { extern "C" {
using fun1_t = bool (*)(LOGICAL_PROCESSOR_RELATIONSHIP, using fun1_t = bool (*)(LOGICAL_PROCESSOR_RELATIONSHIP,
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, PDWORD); PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
PDWORD);
using fun2_t = bool (*)(USHORT, PGROUP_AFFINITY); using fun2_t = bool (*)(USHORT, PGROUP_AFFINITY);
using fun3_t = bool (*)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY); using fun3_t = bool (*)(HANDLE, CONST GROUP_AFFINITY*, PGROUP_AFFINITY);
using fun4_t = bool (*)(USHORT, PGROUP_AFFINITY, USHORT, PUSHORT); using fun4_t = bool (*)(USHORT, PGROUP_AFFINITY, USHORT, PUSHORT);
@ -62,7 +63,9 @@ using fun8_t = bool(*)(HANDLE, BOOL, PTOKEN_PRIVILEGES, DWORD, PTOKEN_PRIVILEGES
#include <sys/mman.h> #include <sys/mman.h>
#endif #endif
#if defined(__APPLE__) || defined(__ANDROID__) || defined(__OpenBSD__) || (defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) && !defined(_WIN32)) || defined(__e2k__) #if defined(__APPLE__) || defined(__ANDROID__) || defined(__OpenBSD__) \
|| (defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) && !defined(_WIN32)) \
|| defined(__e2k__)
#define POSIXALIGNEDALLOC #define POSIXALIGNEDALLOC
#include <stdlib.h> #include <stdlib.h>
#endif #endif
@ -82,7 +85,9 @@ constexpr std::string_view version = "dev";
struct Tie: public std::streambuf { // MSVC requires split streambuf for cin and cout struct Tie: public std::streambuf { // MSVC requires split streambuf for cin and cout
Tie(std::streambuf* b, std::streambuf* l) : buf(b), logBuf(l) {} Tie(std::streambuf* b, std::streambuf* l) :
buf(b),
logBuf(l) {}
int sync() override { return logBuf->pubsync(), buf->pubsync(); } int sync() override { return logBuf->pubsync(), buf->pubsync(); }
int overflow(int c) override { return log(buf->sputc(char(c)), "<< "); } int overflow(int c) override { return log(buf->sputc(char(c)), "<< "); }
@ -104,7 +109,9 @@ struct Tie: public std::streambuf { // MSVC requires split streambuf for cin and
class Logger { class Logger {
Logger() : in(std::cin.rdbuf(), file.rdbuf()), out(std::cout.rdbuf(), file.rdbuf()) {} Logger() :
in(std::cin.rdbuf(), file.rdbuf()),
out(std::cout.rdbuf(), file.rdbuf()) {}
~Logger() { start(""); } ~Logger() { start(""); }
std::ofstream file; std::ofstream file;
@ -166,7 +173,8 @@ std::string engine_info(bool to_uci) {
std::stringstream date(__DATE__); // From compiler, format is "Sep 21 2008" std::stringstream date(__DATE__); // From compiler, format is "Sep 21 2008"
date >> month >> day >> year; date >> month >> day >> year;
ss << year << std::setw(2) << std::setfill('0') << (1 + months.find(month) / 4) << std::setw(2) << std::setfill('0') << day; ss << year << std::setw(2) << std::setfill('0') << (1 + months.find(month) / 4)
<< std::setw(2) << std::setfill('0') << day;
#endif #endif
ss << "-"; ss << "-";
@ -178,8 +186,7 @@ std::string engine_info(bool to_uci) {
#endif #endif
} }
ss << (to_uci ? "\nid author ": " by ") ss << (to_uci ? "\nid author " : " by ") << "the Stockfish developers (see AUTHORS file)";
<< "the Stockfish developers (see AUTHORS file)";
return ss.str(); return ss.str();
} }
@ -189,7 +196,8 @@ std::string engine_info(bool to_uci) {
std::string compiler_info() { std::string compiler_info() {
#define make_version_string(major, minor, patch) stringify(major) "." stringify(minor) "." stringify(patch) #define make_version_string(major, minor, patch) \
stringify(major) "." stringify(minor) "." stringify(patch)
// Predefined macros hell: // Predefined macros hell:
// //
@ -222,9 +230,7 @@ std::string compiler_info() {
compiler += "MCST LCC "; compiler += "MCST LCC ";
compiler += "(version "; compiler += "(version ";
compiler += std::to_string(__LCC__ / 100); compiler += std::to_string(__LCC__ / 100);
dot_ver2(__LCC__ % 100) dot_ver2(__LCC__ % 100) dot_ver2(__LCC_MINOR__) compiler += ")";
dot_ver2(__LCC_MINOR__)
compiler += ")";
#elif __GNUC__ #elif __GNUC__
compiler += "g++ (GNUC) "; compiler += "g++ (GNUC) ";
compiler += make_version_string(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__); compiler += make_version_string(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__);
@ -362,26 +368,20 @@ void dbg_print() {
for (int i = 0; i < MaxDebugSlots; ++i) for (int i = 0; i < MaxDebugSlots; ++i)
if ((n = hit[i][0])) if ((n = hit[i][0]))
std::cerr << "Hit #" << i std::cerr << "Hit #" << i << ": Total " << n << " Hits " << hit[i][1]
<< ": Total " << n << " Hits " << hit[i][1] << " Hit Rate (%) " << 100.0 * E(hit[i][1]) << std::endl;
<< " Hit Rate (%) " << 100.0 * E(hit[i][1])
<< std::endl;
for (int i = 0; i < MaxDebugSlots; ++i) for (int i = 0; i < MaxDebugSlots; ++i)
if ((n = mean[i][0])) if ((n = mean[i][0]))
{ {
std::cerr << "Mean #" << i std::cerr << "Mean #" << i << ": Total " << n << " Mean " << E(mean[i][1]) << std::endl;
<< ": Total " << n << " Mean " << E(mean[i][1])
<< std::endl;
} }
for (int i = 0; i < MaxDebugSlots; ++i) for (int i = 0; i < MaxDebugSlots; ++i)
if ((n = stdev[i][0])) if ((n = stdev[i][0]))
{ {
double r = sqrt(E(stdev[i][2]) - sqr(E(stdev[i][1]))); double r = sqrt(E(stdev[i][2]) - sqr(E(stdev[i][1])));
std::cerr << "Stdev #" << i std::cerr << "Stdev #" << i << ": Total " << n << " Stdev " << r << std::endl;
<< ": Total " << n << " Stdev " << r
<< std::endl;
} }
for (int i = 0; i < MaxDebugSlots; ++i) for (int i = 0; i < MaxDebugSlots; ++i)
@ -390,9 +390,7 @@ void dbg_print() {
double r = (E(correl[i][5]) - E(correl[i][1]) * E(correl[i][3])) double r = (E(correl[i][5]) - E(correl[i][1]) * E(correl[i][3]))
/ (sqrt(E(correl[i][2]) - sqr(E(correl[i][1]))) / (sqrt(E(correl[i][2]) - sqr(E(correl[i][1])))
* sqrt(E(correl[i][4]) - sqr(E(correl[i][3])))); * sqrt(E(correl[i][4]) - sqr(E(correl[i][3]))));
std::cerr << "Correl. #" << i std::cerr << "Correl. #" << i << ": Total " << n << " Coefficient " << r << std::endl;
<< ": Total " << n << " Coefficient " << r
<< std::endl;
} }
} }
@ -524,13 +522,13 @@ static void* aligned_large_pages_alloc_windows([[maybe_unused]] size_t allocSize
// Try to enable SeLockMemoryPrivilege. Note that even if AdjustTokenPrivileges() succeeds, // Try to enable SeLockMemoryPrivilege. Note that even if AdjustTokenPrivileges() succeeds,
// we still need to query GetLastError() to ensure that the privileges were actually obtained. // we still need to query GetLastError() to ensure that the privileges were actually obtained.
if (fun8( // AdjustTokenPrivileges() if (fun8( // AdjustTokenPrivileges()
hProcessToken, FALSE, &tp, sizeof(TOKEN_PRIVILEGES), &prevTp, &prevTpLen) && hProcessToken, FALSE, &tp, sizeof(TOKEN_PRIVILEGES), &prevTp, &prevTpLen)
GetLastError() == ERROR_SUCCESS) && GetLastError() == ERROR_SUCCESS)
{ {
// Round up size to full pages and allocate // Round up size to full pages and allocate
allocSize = (allocSize + largePageSize - 1) & ~size_t(largePageSize - 1); allocSize = (allocSize + largePageSize - 1) & ~size_t(largePageSize - 1);
mem = VirtualAlloc( mem = VirtualAlloc(nullptr, allocSize, MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES,
nullptr, allocSize, MEM_RESERVE | MEM_COMMIT | MEM_LARGE_PAGES, PAGE_READWRITE); PAGE_READWRITE);
// Privilege no longer needed, restore previous state // Privilege no longer needed, restore previous state
fun8( // AdjustTokenPrivileges () fun8( // AdjustTokenPrivileges ()
@ -588,8 +586,7 @@ void aligned_large_pages_free(void* mem) {
if (mem && !VirtualFree(mem, 0, MEM_RELEASE)) if (mem && !VirtualFree(mem, 0, MEM_RELEASE))
{ {
DWORD err = GetLastError(); DWORD err = GetLastError();
std::cerr << "Failed to free large page memory. Error code: 0x" std::cerr << "Failed to free large page memory. Error code: 0x" << std::hex << err
<< std::hex << err
<< std::dec << std::endl; << std::dec << std::endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
@ -597,9 +594,7 @@ void aligned_large_pages_free(void* mem) {
#else #else
void aligned_large_pages_free(void *mem) { void aligned_large_pages_free(void* mem) { std_aligned_free(mem); }
std_aligned_free(mem);
}
#endif #endif
@ -718,7 +713,8 @@ void bindThisThread(size_t idx) {
elements = fun5(); // GetMaximumProcessorGroupCount elements = fun5(); // GetMaximumProcessorGroupCount
GROUP_AFFINITY* affinity = (GROUP_AFFINITY*) malloc(elements * sizeof(GROUP_AFFINITY)); GROUP_AFFINITY* affinity = (GROUP_AFFINITY*) malloc(elements * sizeof(GROUP_AFFINITY));
if (fun4(node, affinity, elements, &returnedElements)) // GetNumaNodeProcessorMask2 if (fun4(node, affinity, elements, &returnedElements)) // GetNumaNodeProcessorMask2
fun3(GetCurrentThread(), &affinity[idx % returnedElements], nullptr); // SetThreadGroupAffinity fun3(GetCurrentThread(), &affinity[idx % returnedElements],
nullptr); // SetThreadGroupAffinity
free(affinity); free(affinity);
} }
} }

View file

@ -37,7 +37,8 @@ void prefetch(void* addr);
void start_logger(const std::string& fname); void start_logger(const std::string& fname);
void* std_aligned_alloc(size_t alignment, size_t size); void* std_aligned_alloc(size_t alignment, size_t size);
void std_aligned_free(void* ptr); void std_aligned_free(void* ptr);
void* aligned_large_pages_alloc(size_t size); // memory aligned by page size, min alignment: 4096 bytes void* aligned_large_pages_alloc(
size_t size); // memory aligned by page size, min alignment: 4096 bytes
void aligned_large_pages_free(void* mem); // nop if mem == nullptr void aligned_large_pages_free(void* mem); // nop if mem == nullptr
void dbg_hit_on(bool cond, int slot = 0); void dbg_hit_on(bool cond, int slot = 0);
@ -49,12 +50,16 @@ void dbg_print();
using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds
static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits");
inline TimePoint now() { inline TimePoint now() {
return std::chrono::duration_cast<std::chrono::milliseconds> return std::chrono::duration_cast<std::chrono::milliseconds>(
(std::chrono::steady_clock::now().time_since_epoch()).count(); std::chrono::steady_clock::now().time_since_epoch())
.count();
} }
enum SyncCout { IO_LOCK, IO_UNLOCK }; enum SyncCout {
IO_LOCK,
IO_UNLOCK
};
std::ostream& operator<<(std::ostream&, SyncCout); std::ostream& operator<<(std::ostream&, SyncCout);
#define sync_cout std::cout << IO_LOCK #define sync_cout std::cout << IO_LOCK
@ -65,17 +70,20 @@ std::ostream& operator<<(std::ostream&, SyncCout);
// ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes, // ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes,
// where N is the number of elements in the array. // where N is the number of elements in the array.
template<uintptr_t Alignment, typename T> template<uintptr_t Alignment, typename T>
T* align_ptr_up(T* ptr) T* align_ptr_up(T* ptr) {
{
static_assert(alignof(T) < Alignment); static_assert(alignof(T) < Alignment);
const uintptr_t ptrint = reinterpret_cast<uintptr_t>(reinterpret_cast<char*>(ptr)); const uintptr_t ptrint = reinterpret_cast<uintptr_t>(reinterpret_cast<char*>(ptr));
return reinterpret_cast<T*>(reinterpret_cast<char*>((ptrint + (Alignment - 1)) / Alignment * Alignment)); return reinterpret_cast<T*>(
reinterpret_cast<char*>((ptrint + (Alignment - 1)) / Alignment * Alignment));
} }
// IsLittleEndian : true if and only if the binary is compiled on a little-endian machine // IsLittleEndian : true if and only if the binary is compiled on a little-endian machine
static inline const union { uint32_t i; char c[4]; } Le = { 0x01020304 }; static inline const union {
uint32_t i;
char c[4];
} Le = {0x01020304};
static inline const bool IsLittleEndian = (Le.c[0] == 4); static inline const bool IsLittleEndian = (Le.c[0] == 4);
@ -121,14 +129,22 @@ class PRNG {
} }
public: public:
PRNG(uint64_t seed) : s(seed) { assert(seed); } PRNG(uint64_t seed) :
s(seed) {
assert(seed);
}
template<typename T> T rand() { return T(rand64()); } template<typename T>
T rand() {
return T(rand64());
}
// Special generator used to fast init magic numbers. // Special generator used to fast init magic numbers.
// Output values only have 1/8th of their bits set on average. // Output values only have 1/8th of their bits set on average.
template<typename T> T sparse_rand() template<typename T>
{ return T(rand64() & rand64() & rand64()); } T sparse_rand() {
return T(rand64() & rand64() & rand64());
}
}; };
inline uint64_t mul_hi64(uint64_t a, uint64_t b) { inline uint64_t mul_hi64(uint64_t a, uint64_t b) {

View file

@ -64,8 +64,7 @@ namespace {
constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST); constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST);
const Bitboard emptySquares = ~pos.pieces(); const Bitboard emptySquares = ~pos.pieces();
const Bitboard enemies = Type == EVASIONS ? pos.checkers() const Bitboard enemies = Type == EVASIONS ? pos.checkers() : pos.pieces(Them);
: pos.pieces(Them);
Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB; Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB;
Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB; Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB;
@ -273,8 +272,8 @@ ExtMove* generate<LEGAL>(const Position& pos, ExtMove* moveList) {
Square ksq = pos.square<KING>(us); Square ksq = pos.square<KING>(us);
ExtMove* cur = moveList; ExtMove* cur = moveList;
moveList = pos.checkers() ? generate<EVASIONS >(pos, moveList) moveList =
: generate<NON_EVASIONS>(pos, moveList); pos.checkers() ? generate<EVASIONS>(pos, moveList) : generate<NON_EVASIONS>(pos, moveList);
while (cur != moveList) while (cur != moveList)
if (((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT) if (((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT)
&& !pos.legal(*cur)) && !pos.legal(*cur))

View file

@ -49,9 +49,7 @@ struct ExtMove {
operator float() const = delete; operator float() const = delete;
}; };
inline bool operator<(const ExtMove& f, const ExtMove& s) { inline bool operator<(const ExtMove& f, const ExtMove& s) { return f.value < s.value; }
return f.value < s.value;
}
template<GenType> template<GenType>
ExtMove* generate(const Position& pos, ExtMove* moveList); ExtMove* generate(const Position& pos, ExtMove* moveList);
@ -62,13 +60,12 @@ ExtMove* generate(const Position& pos, ExtMove* moveList);
template<GenType T> template<GenType T>
struct MoveList { struct MoveList {
explicit MoveList(const Position& pos) : last(generate<T>(pos, moveList)) {} explicit MoveList(const Position& pos) :
last(generate<T>(pos, moveList)) {}
const ExtMove* begin() const { return moveList; } const ExtMove* begin() const { return moveList; }
const ExtMove* end() const { return last; } const ExtMove* end() const { return last; }
size_t size() const { return last - moveList; } size_t size() const { return last - moveList; }
bool contains(Move move) const { bool contains(Move move) const { return std::find(begin(), end(), move) != end(); }
return std::find(begin(), end(), move) != end();
}
private: private:
ExtMove moveList[MAX_MOVES], *last; ExtMove moveList[MAX_MOVES], *last;

View file

@ -31,10 +31,31 @@ namespace Stockfish {
namespace { namespace {
enum Stages { enum Stages {
MAIN_TT, CAPTURE_INIT, GOOD_CAPTURE, REFUTATION, QUIET_INIT, QUIET, BAD_CAPTURE, // generate main search moves
EVASION_TT, EVASION_INIT, EVASION, MAIN_TT,
PROBCUT_TT, PROBCUT_INIT, PROBCUT, CAPTURE_INIT,
QSEARCH_TT, QCAPTURE_INIT, QCAPTURE, QCHECK_INIT, QCHECK GOOD_CAPTURE,
REFUTATION,
QUIET_INIT,
QUIET,
BAD_CAPTURE,
// generate evasion moves
EVASION_TT,
EVASION_INIT,
EVASION,
// generate probcut moves
PROBCUT_TT,
PROBCUT_INIT,
PROBCUT,
// generate qsearch moves
QSEARCH_TT,
QCAPTURE_INIT,
QCAPTURE,
QCHECK_INIT,
QCHECK
}; };
// partial_insertion_sort() sorts moves in descending order up to and including // partial_insertion_sort() sorts moves in descending order up to and including
@ -62,44 +83,57 @@ namespace {
// move ordering is at the current node. // move ordering is at the current node.
// MovePicker constructor for the main search // MovePicker constructor for the main search
MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, MovePicker::MovePicker(const Position& p,
Move ttm,
Depth d,
const ButterflyHistory* mh,
const CapturePieceToHistory* cph, const CapturePieceToHistory* cph,
const PieceToHistory** ch, const PieceToHistory** ch,
Move cm, Move cm,
const Move* killers) const Move* killers) :
: pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), pos(p),
ttMove(ttm), refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}}, depth(d) mainHistory(mh),
{ captureHistory(cph),
continuationHistory(ch),
ttMove(ttm),
refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}},
depth(d) {
assert(d > 0); assert(d > 0);
stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + !(ttm && pos.pseudo_legal(ttm));
!(ttm && pos.pseudo_legal(ttm));
} }
// MovePicker constructor for quiescence search // MovePicker constructor for quiescence search
MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, MovePicker::MovePicker(const Position& p,
Move ttm,
Depth d,
const ButterflyHistory* mh,
const CapturePieceToHistory* cph, const CapturePieceToHistory* cph,
const PieceToHistory** ch, const PieceToHistory** ch,
Square rs) Square rs) :
: pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), ttMove(ttm), recaptureSquare(rs), depth(d) pos(p),
{ mainHistory(mh),
captureHistory(cph),
continuationHistory(ch),
ttMove(ttm),
recaptureSquare(rs),
depth(d) {
assert(d <= 0); assert(d <= 0);
stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + !(ttm && pos.pseudo_legal(ttm));
!( ttm
&& pos.pseudo_legal(ttm));
} }
// MovePicker constructor for ProbCut: we generate captures with SEE greater // MovePicker constructor for ProbCut: we generate captures with SEE greater
// than or equal to the given threshold. // than or equal to the given threshold.
MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) :
: pos(p), captureHistory(cph), ttMove(ttm), threshold(th) pos(p),
{ captureHistory(cph),
ttMove(ttm),
threshold(th) {
assert(!pos.checkers()); assert(!pos.checkers());
stage = PROBCUT_TT + !(ttm && pos.capture_stage(ttm) stage = PROBCUT_TT
&& pos.pseudo_legal(ttm) + !(ttm && pos.capture_stage(ttm) && pos.pseudo_legal(ttm) && pos.see_ge(ttm, threshold));
&& pos.see_ge(ttm, threshold));
} }
// MovePicker::score() assigns a numerical value to each move in a list, used // MovePicker::score() assigns a numerical value to each move in a list, used
@ -110,13 +144,15 @@ void MovePicker::score() {
static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type"); static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type");
[[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook, threatenedPieces; [[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook,
threatenedPieces;
if constexpr (Type == QUIETS) if constexpr (Type == QUIETS)
{ {
Color us = pos.side_to_move(); Color us = pos.side_to_move();
threatenedByPawn = pos.attacks_by<PAWN>(~us); threatenedByPawn = pos.attacks_by<PAWN>(~us);
threatenedByMinor = pos.attacks_by<KNIGHT>(~us) | pos.attacks_by<BISHOP>(~us) | threatenedByPawn; threatenedByMinor =
pos.attacks_by<KNIGHT>(~us) | pos.attacks_by<BISHOP>(~us) | threatenedByPawn;
threatenedByRook = pos.attacks_by<ROOK>(~us) | threatenedByMinor; threatenedByRook = pos.attacks_by<ROOK>(~us) | threatenedByMinor;
// Pieces threatened by pieces of lesser material value // Pieces threatened by pieces of lesser material value
@ -127,8 +163,10 @@ void MovePicker::score() {
for (auto& m : *this) for (auto& m : *this)
if constexpr (Type == CAPTURES) if constexpr (Type == CAPTURES)
m.value = (7 * int(PieceValue[pos.piece_on(to_sq(m))]) m.value =
+ (*captureHistory)[pos.moved_piece(m)][to_sq(m)][type_of(pos.piece_on(to_sq(m)))]) / 16; (7 * int(PieceValue[pos.piece_on(to_sq(m))])
+ (*captureHistory)[pos.moved_piece(m)][to_sq(m)][type_of(pos.piece_on(to_sq(m)))])
/ 16;
else if constexpr (Type == QUIETS) else if constexpr (Type == QUIETS)
{ {
@ -149,16 +187,15 @@ void MovePicker::score() {
m.value += bool(pos.check_squares(pt) & to) * 16384; m.value += bool(pos.check_squares(pt) & to) * 16384;
// bonus for escaping from capture // bonus for escaping from capture
m.value += threatenedPieces & from ? m.value += threatenedPieces & from ? (pt == QUEEN && !(to & threatenedByRook) ? 50000
(pt == QUEEN && !(to & threatenedByRook) ? 50000
: pt == ROOK && !(to & threatenedByMinor) ? 25000 : pt == ROOK && !(to & threatenedByMinor) ? 25000
: !(to & threatenedByPawn) ? 15000 : !(to & threatenedByPawn) ? 15000
: 0) : 0)
: 0; : 0;
// malus for putting piece en prise // malus for putting piece en prise
m.value -= !(threatenedPieces & from) ? m.value -= !(threatenedPieces & from)
(pt == QUEEN ? bool(to & threatenedByRook) * 50000 ? (pt == QUEEN ? bool(to & threatenedByRook) * 50000
+ bool(to & threatenedByMinor) * 10000 + bool(to & threatenedByMinor) * 10000
+ bool(to & threatenedByPawn) * 20000 + bool(to & threatenedByPawn) * 20000
: pt == ROOK ? bool(to & threatenedByMinor) * 25000 : pt == ROOK ? bool(to & threatenedByMinor) * 25000
@ -171,8 +208,7 @@ void MovePicker::score() {
else // Type == EVASIONS else // Type == EVASIONS
{ {
if (pos.capture_stage(m)) if (pos.capture_stage(m))
m.value = PieceValue[pos.piece_on(to_sq(m))] m.value = PieceValue[pos.piece_on(to_sq(m))] - Value(type_of(pos.moved_piece(m)))
- Value(type_of(pos.moved_piece(m)))
+ (1 << 28); + (1 << 28);
else else
m.value = (*mainHistory)[pos.side_to_move()][from_to(m)] m.value = (*mainHistory)[pos.side_to_move()][from_to(m)]
@ -204,7 +240,8 @@ Move MovePicker::select(Pred filter) {
Move MovePicker::next_move(bool skipQuiets) { Move MovePicker::next_move(bool skipQuiets) {
top: top:
switch (stage) { switch (stage)
{
case MAIN_TT : case MAIN_TT :
case EVASION_TT : case EVASION_TT :
@ -226,9 +263,12 @@ top:
case GOOD_CAPTURE : case GOOD_CAPTURE :
if (select<Next>([&]() { if (select<Next>([&]() {
return pos.see_ge(*cur, Value(-cur->value)) ? return pos.see_ge(*cur, Value(-cur->value))
?
// Move losing capture to endBadCaptures to be tried later // Move losing capture to endBadCaptures to be tried later
true : (*endBadCaptures++ = *cur, false); })) true
: (*endBadCaptures++ = *cur, false);
}))
return *(cur - 1); return *(cur - 1);
// Prepare the pointers to loop over the refutations array // Prepare the pointers to loop over the refutations array
@ -244,9 +284,9 @@ top:
[[fallthrough]]; [[fallthrough]];
case REFUTATION : case REFUTATION :
if (select<Next>([&](){ return *cur != MOVE_NONE if (select<Next>([&]() {
&& !pos.capture_stage(*cur) return *cur != MOVE_NONE && !pos.capture_stage(*cur) && pos.pseudo_legal(*cur);
&& pos.pseudo_legal(*cur); })) }))
return *(cur - 1); return *(cur - 1);
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
@ -265,10 +305,10 @@ top:
[[fallthrough]]; [[fallthrough]];
case QUIET : case QUIET :
if ( !skipQuiets if (!skipQuiets && select<Next>([&]() {
&& select<Next>([&](){return *cur != refutations[0].move return *cur != refutations[0].move && *cur != refutations[1].move
&& *cur != refutations[1].move && *cur != refutations[2].move;
&& *cur != refutations[2].move;})) }))
return *(cur - 1); return *(cur - 1);
// Prepare the pointers to loop over the bad captures // Prepare the pointers to loop over the bad captures
@ -296,8 +336,8 @@ top:
return select<Next>([&]() { return pos.see_ge(*cur, threshold); }); return select<Next>([&]() { return pos.see_ge(*cur, threshold); });
case QCAPTURE : case QCAPTURE :
if (select<Next>([&](){ return depth > DEPTH_QS_RECAPTURES if (select<Next>(
|| to_sq(*cur) == recaptureSquare; })) [&]() { return depth > DEPTH_QS_RECAPTURES || to_sq(*cur) == recaptureSquare; }))
return *(cur - 1); return *(cur - 1);
// If we did not find any move and we do not try checks, we have finished // If we did not find any move and we do not try checks, we have finished

View file

@ -63,8 +63,7 @@ public:
// values with the << operator, while the last parameters (Size and Sizes) // values with the << operator, while the last parameters (Size and Sizes)
// encode the dimensions of the array. // encode the dimensions of the array.
template<typename T, int D, int Size, int... Sizes> template<typename T, int D, int Size, int... Sizes>
struct Stats : public std::array<Stats<T, D, Sizes...>, Size> struct Stats: public std::array<Stats<T, D, Sizes...>, Size> {
{
using stats = Stats<T, D, Size, Sizes...>; using stats = Stats<T, D, Size, Sizes...>;
void fill(const T& v) { void fill(const T& v) {
@ -82,8 +81,13 @@ template <typename T, int D, int Size>
struct Stats<T, D, Size>: public std::array<StatsEntry<T, D>, Size> {}; struct Stats<T, D, Size>: public std::array<StatsEntry<T, D>, Size> {};
// In stats table, D=0 means that the template parameter is not used // In stats table, D=0 means that the template parameter is not used
enum StatsParams { NOT_USED = 0 }; enum StatsParams {
enum StatsType { NoCaptures, Captures }; NOT_USED = 0
};
enum StatsType {
NoCaptures,
Captures
};
// ButterflyHistory records how often quiet moves have been successful or // ButterflyHistory records how often quiet moves have been successful or
// unsuccessful during the current search, and is used for reduction and move // unsuccessful during the current search, and is used for reduction and move
@ -117,17 +121,26 @@ using ContinuationHistory = Stats<PieceToHistory, NOT_USED, PIECE_NB, SQUARE_NB>
// likely to get a cut-off first. // likely to get a cut-off first.
class MovePicker { class MovePicker {
enum PickType { Next, Best }; enum PickType {
Next,
Best
};
public: public:
MovePicker(const MovePicker&) = delete; MovePicker(const MovePicker&) = delete;
MovePicker& operator=(const MovePicker&) = delete; MovePicker& operator=(const MovePicker&) = delete;
MovePicker(const Position&, Move, Depth, const ButterflyHistory*, MovePicker(const Position&,
Move,
Depth,
const ButterflyHistory*,
const CapturePieceToHistory*, const CapturePieceToHistory*,
const PieceToHistory**, const PieceToHistory**,
Move, Move,
const Move*); const Move*);
MovePicker(const Position&, Move, Depth, const ButterflyHistory*, MovePicker(const Position&,
Move,
Depth,
const ButterflyHistory*,
const CapturePieceToHistory*, const CapturePieceToHistory*,
const PieceToHistory**, const PieceToHistory**,
Square); Square);
@ -135,8 +148,10 @@ public:
Move next_move(bool skipQuiets = false); Move next_move(bool skipQuiets = false);
private: private:
template<PickType T, typename Pred> Move select(Pred); template<PickType T, typename Pred>
template<GenType> void score(); Move select(Pred);
template<GenType>
void score();
ExtMove* begin() { return cur; } ExtMove* begin() { return cur; }
ExtMove* end() { return endMoves; } ExtMove* end() { return endMoves; }

View file

@ -62,7 +62,8 @@ namespace Stockfish::Eval::NNUE {
template<typename T> template<typename T>
void initialize(LargePagePtr<T>& pointer) { void initialize(LargePagePtr<T>& pointer) {
static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T"); static_assert(alignof(T) <= 4096,
"aligned_large_pages_alloc() may fail for such a big alignment requirement of T");
pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T)))); pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T))));
std::memset(pointer.get(), 0, sizeof(T)); std::memset(pointer.get(), 0, sizeof(T));
} }
@ -73,7 +74,8 @@ namespace Stockfish::Eval::NNUE {
std::uint32_t header; std::uint32_t header;
header = read_little_endian<std::uint32_t>(stream); header = read_little_endian<std::uint32_t>(stream);
if (!stream || header != T::get_hash_value()) return false; if (!stream || header != T::get_hash_value())
return false;
return reference.read_parameters(stream); return reference.read_parameters(stream);
} }
@ -97,22 +99,21 @@ namespace Stockfish::Eval::NNUE {
} }
// Read network header // Read network header
static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) {
{
std::uint32_t version, size; std::uint32_t version, size;
version = read_little_endian<std::uint32_t>(stream); version = read_little_endian<std::uint32_t>(stream);
*hashValue = read_little_endian<std::uint32_t>(stream); *hashValue = read_little_endian<std::uint32_t>(stream);
size = read_little_endian<std::uint32_t>(stream); size = read_little_endian<std::uint32_t>(stream);
if (!stream || version != Version) return false; if (!stream || version != Version)
return false;
desc->resize(size); desc->resize(size);
stream.read(&(*desc)[0], size); stream.read(&(*desc)[0], size);
return !stream.fail(); return !stream.fail();
} }
// Write network header // Write network header
static bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc) static bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc) {
{
write_little_endian<std::uint32_t>(stream, Version); write_little_endian<std::uint32_t>(stream, Version);
write_little_endian<std::uint32_t>(stream, hashValue); write_little_endian<std::uint32_t>(stream, hashValue);
write_little_endian<std::uint32_t>(stream, (std::uint32_t) desc.size()); write_little_endian<std::uint32_t>(stream, (std::uint32_t) desc.size());
@ -124,21 +125,28 @@ namespace Stockfish::Eval::NNUE {
static bool read_parameters(std::istream& stream) { static bool read_parameters(std::istream& stream) {
std::uint32_t hashValue; std::uint32_t hashValue;
if (!read_header(stream, &hashValue, &netDescription)) return false; if (!read_header(stream, &hashValue, &netDescription))
if (hashValue != HashValue) return false; return false;
if (!Detail::read_parameters(stream, *featureTransformer)) return false; if (hashValue != HashValue)
return false;
if (!Detail::read_parameters(stream, *featureTransformer))
return false;
for (std::size_t i = 0; i < LayerStacks; ++i) for (std::size_t i = 0; i < LayerStacks; ++i)
if (!Detail::read_parameters(stream, *(network[i]))) return false; if (!Detail::read_parameters(stream, *(network[i])))
return false;
return stream && stream.peek() == std::ios::traits_type::eof(); return stream && stream.peek() == std::ios::traits_type::eof();
} }
// Write network parameters // Write network parameters
static bool write_parameters(std::ostream& stream) { static bool write_parameters(std::ostream& stream) {
if (!write_header(stream, HashValue, netDescription)) return false; if (!write_header(stream, HashValue, netDescription))
if (!Detail::write_parameters(stream, *featureTransformer)) return false; return false;
if (!Detail::write_parameters(stream, *featureTransformer))
return false;
for (std::size_t i = 0; i < LayerStacks; ++i) for (std::size_t i = 0; i < LayerStacks; ++i)
if (!Detail::write_parameters(stream, *(network[i]))) return false; if (!Detail::write_parameters(stream, *(network[i])))
return false;
return bool(stream); return bool(stream);
} }
@ -156,13 +164,13 @@ namespace Stockfish::Eval::NNUE {
constexpr int delta = 24; constexpr int delta = 24;
#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
TransformedFeatureType transformedFeaturesUnaligned[ TransformedFeatureType
FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; transformedFeaturesUnaligned[FeatureTransformer::BufferSize
+ alignment / sizeof(TransformedFeatureType)];
auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]); auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]);
#else #else
alignas(alignment) alignas(alignment) TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
#endif #endif
ASSERT_ALIGNED(transformedFeatures, alignment); ASSERT_ALIGNED(transformedFeatures, alignment);
@ -176,7 +184,8 @@ namespace Stockfish::Eval::NNUE {
// Give more value to positional evaluation when adjusted flag is set // Give more value to positional evaluation when adjusted flag is set
if (adjusted) if (adjusted)
return static_cast<Value>(((1024 - delta) * psqt + (1024 + delta) * positional) / (1024 * OutputScale)); return static_cast<Value>(((1024 - delta) * psqt + (1024 + delta) * positional)
/ (1024 * OutputScale));
else else
return static_cast<Value>((psqt + positional) / OutputScale); return static_cast<Value>((psqt + positional) / OutputScale);
} }
@ -196,20 +205,21 @@ namespace Stockfish::Eval::NNUE {
constexpr uint64_t alignment = CacheLineSize; constexpr uint64_t alignment = CacheLineSize;
#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
TransformedFeatureType transformedFeaturesUnaligned[ TransformedFeatureType
FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; transformedFeaturesUnaligned[FeatureTransformer::BufferSize
+ alignment / sizeof(TransformedFeatureType)];
auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]); auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]);
#else #else
alignas(alignment) alignas(alignment) TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
#endif #endif
ASSERT_ALIGNED(transformedFeatures, alignment); ASSERT_ALIGNED(transformedFeatures, alignment);
NnueEvalTrace t{}; NnueEvalTrace t{};
t.correctBucket = (pos.count<ALL_PIECES>() - 1) / 4; t.correctBucket = (pos.count<ALL_PIECES>() - 1) / 4;
for (IndexType bucket = 0; bucket < LayerStacks; ++bucket) { for (IndexType bucket = 0; bucket < LayerStacks; ++bucket)
{
const auto materialist = featureTransformer->transform(pos, transformedFeatures, bucket); const auto materialist = featureTransformer->transform(pos, transformedFeatures, bucket);
const auto positional = network[bucket]->propagate(transformedFeatures); const auto positional = network[bucket]->propagate(transformedFeatures);
@ -232,23 +242,29 @@ namespace Stockfish::Eval::NNUE {
int cp = std::abs(UCI::to_cp(v)); int cp = std::abs(UCI::to_cp(v));
if (cp >= 10000) if (cp >= 10000)
{ {
buffer[1] = '0' + cp / 10000; cp %= 10000; buffer[1] = '0' + cp / 10000;
buffer[2] = '0' + cp / 1000; cp %= 1000; cp %= 10000;
buffer[2] = '0' + cp / 1000;
cp %= 1000;
buffer[3] = '0' + cp / 100; buffer[3] = '0' + cp / 100;
buffer[4] = ' '; buffer[4] = ' ';
} }
else if (cp >= 1000) else if (cp >= 1000)
{ {
buffer[1] = '0' + cp / 1000; cp %= 1000; buffer[1] = '0' + cp / 1000;
buffer[2] = '0' + cp / 100; cp %= 100; cp %= 1000;
buffer[2] = '0' + cp / 100;
cp %= 100;
buffer[3] = '.'; buffer[3] = '.';
buffer[4] = '0' + cp / 10; buffer[4] = '0' + cp / 10;
} }
else else
{ {
buffer[1] = '0' + cp / 100; cp %= 100; buffer[1] = '0' + cp / 100;
cp %= 100;
buffer[2] = '.'; buffer[2] = '.';
buffer[3] = '0' + cp / 10; cp %= 10; buffer[3] = '0' + cp / 10;
cp %= 10;
buffer[4] = '0' + cp / 1; buffer[4] = '0' + cp / 1;
} }
} }
@ -259,11 +275,10 @@ namespace Stockfish::Eval::NNUE {
const double pawns = std::abs(0.01 * UCI::to_cp(v)); const double pawns = std::abs(0.01 * UCI::to_cp(v));
stream << (v < 0 ? '-' : v > 0 ? '+' : ' ') stream << (v < 0 ? '-'
<< std::setiosflags(std::ios::fixed) : v > 0 ? '+'
<< std::setw(6) : ' ')
<< std::setprecision(2) << std::setiosflags(std::ios::fixed) << std::setw(6) << std::setprecision(2) << pawns;
<< pawns;
} }
@ -280,7 +295,6 @@ namespace Stockfish::Eval::NNUE {
// A lambda to output one box of the board // A lambda to output one box of the board
auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) { auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) {
const int x = int(file) * 8; const int x = int(file) * 8;
const int y = (7 - int(rank)) * 3; const int y = (7 - int(rank)) * 3;
for (int i = 1; i < 8; ++i) for (int i = 1; i < 8; ++i)
@ -343,9 +357,15 @@ namespace Stockfish::Eval::NNUE {
for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket) for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket)
{ {
ss << "| " << bucket << " "; ss << "| " << bucket << " ";
ss << " | "; format_cp_aligned_dot(t.psqt[bucket], ss); ss << " " ss << " | ";
<< " | "; format_cp_aligned_dot(t.positional[bucket], ss); ss << " " format_cp_aligned_dot(t.psqt[bucket], ss);
<< " | "; format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss); ss << " " ss << " "
<< " | ";
format_cp_aligned_dot(t.positional[bucket], ss);
ss << " "
<< " | ";
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss);
ss << " "
<< " |"; << " |";
if (bucket == t.correctBucket) if (bucket == t.correctBucket)
ss << " <-- this bucket is used"; ss << " <-- this bucket is used";
@ -387,7 +407,8 @@ namespace Stockfish::Eval::NNUE {
{ {
if (currentEvalFileName != EvalFileDefaultName) if (currentEvalFileName != EvalFileDefaultName)
{ {
msg = "Failed to export a net. A non-embedded net can only be saved if the filename is specified"; msg =
"Failed to export a net. A non-embedded net can only be saved if the filename is specified";
sync_cout << msg << sync_endl; sync_cout << msg << sync_endl;
return false; return false;
@ -398,8 +419,7 @@ namespace Stockfish::Eval::NNUE {
std::ofstream stream(actualFilename, std::ios_base::binary); std::ofstream stream(actualFilename, std::ios_base::binary);
bool saved = save_eval(stream); bool saved = save_eval(stream);
msg = saved ? "Network saved successfully to " + actualFilename msg = saved ? "Network saved successfully to " + actualFilename : "Failed to export a net";
: "Failed to export a net";
sync_cout << msg << sync_endl; sync_cout << msg << sync_endl;
return saved; return saved;

View file

@ -30,15 +30,13 @@ namespace Stockfish::Eval::NNUE::Features {
// Index of a feature for a given king position and another piece on some square // Index of a feature for a given king position and another piece on some square
template<Color Perspective> template<Color Perspective>
inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) { inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) {
return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + PieceSquareIndex[Perspective][pc] + KingBuckets[Perspective][ksq]); return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + PieceSquareIndex[Perspective][pc]
+ KingBuckets[Perspective][ksq]);
} }
// Get a list of indices for active features // Get a list of indices for active features
template<Color Perspective> template<Color Perspective>
void HalfKAv2_hm::append_active_indices( void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active) {
const Position& pos,
IndexList& active
) {
Square ksq = pos.square<KING>(Perspective); Square ksq = pos.square<KING>(Perspective);
Bitboard bb = pos.pieces(); Bitboard bb = pos.pieces();
while (bb) while (bb)
@ -54,13 +52,12 @@ namespace Stockfish::Eval::NNUE::Features {
// append_changed_indices() : get a list of indices for recently changed features // append_changed_indices() : get a list of indices for recently changed features
template<Color Perspective> template<Color Perspective>
void HalfKAv2_hm::append_changed_indices( void HalfKAv2_hm::append_changed_indices(Square ksq,
Square ksq,
const DirtyPiece& dp, const DirtyPiece& dp,
IndexList& removed, IndexList& removed,
IndexList& added IndexList& added) {
) { for (int i = 0; i < dp.dirty_num; ++i)
for (int i = 0; i < dp.dirty_num; ++i) { {
if (dp.from[i] != SQ_NONE) if (dp.from[i] != SQ_NONE)
removed.push_back(make_index<Perspective>(dp.from[i], dp.piece[i], ksq)); removed.push_back(make_index<Perspective>(dp.from[i], dp.piece[i], ksq));
if (dp.to[i] != SQ_NONE) if (dp.to[i] != SQ_NONE)
@ -69,16 +66,18 @@ namespace Stockfish::Eval::NNUE::Features {
} }
// Explicit template instantiations // Explicit template instantiations
template void HalfKAv2_hm::append_changed_indices<WHITE>(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); template void HalfKAv2_hm::append_changed_indices<WHITE>(Square ksq,
template void HalfKAv2_hm::append_changed_indices<BLACK>(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); const DirtyPiece& dp,
IndexList& removed,
IndexList& added);
template void HalfKAv2_hm::append_changed_indices<BLACK>(Square ksq,
const DirtyPiece& dp,
IndexList& removed,
IndexList& added);
int HalfKAv2_hm::update_cost(const StateInfo* st) { int HalfKAv2_hm::update_cost(const StateInfo* st) { return st->dirtyPiece.dirty_num; }
return st->dirtyPiece.dirty_num;
}
int HalfKAv2_hm::refresh_cost(const Position& pos) { int HalfKAv2_hm::refresh_cost(const Position& pos) { return pos.count<ALL_PIECES>(); }
return pos.count<ALL_PIECES>();
}
bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) { bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) {
return st->dirtyPiece.piece[0] == make_piece(perspective, KING); return st->dirtyPiece.piece[0] == make_piece(perspective, KING);

View file

@ -61,8 +61,7 @@ namespace Stockfish::Eval::NNUE::Features {
{PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE, {PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE,
PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE}, PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE},
{PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE, {PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE,
PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE } PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE}};
};
// Index of a feature for a given king position and another piece on some square // Index of a feature for a given king position and another piece on some square
template<Color Perspective> template<Color Perspective>
@ -80,6 +79,7 @@ namespace Stockfish::Eval::NNUE::Features {
static_cast<IndexType>(SQUARE_NB) * static_cast<IndexType>(PS_NB) / 2; static_cast<IndexType>(SQUARE_NB) * static_cast<IndexType>(PS_NB) / 2;
#define B(v) (v * PS_NB) #define B(v) (v * PS_NB)
// clang-format off
static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = { static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = {
{ B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28), { B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28),
B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24),
@ -98,8 +98,9 @@ namespace Stockfish::Eval::NNUE::Features {
B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24),
B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28) } B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28) }
}; };
// clang-format on
#undef B #undef B
// clang-format off
// Orient a square according to perspective (rotates by 180 for black) // Orient a square according to perspective (rotates by 180 for black)
static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = { static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = {
{ SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, { SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1,
@ -119,6 +120,7 @@ namespace Stockfish::Eval::NNUE::Features {
SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8,
SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8 } SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8 }
}; };
// clang-format on
// Maximum number of simultaneously active features. // Maximum number of simultaneously active features.
static constexpr IndexType MaxActiveDimensions = 32; static constexpr IndexType MaxActiveDimensions = 32;
@ -126,18 +128,12 @@ namespace Stockfish::Eval::NNUE::Features {
// Get a list of indices for active features // Get a list of indices for active features
template<Color Perspective> template<Color Perspective>
static void append_active_indices( static void append_active_indices(const Position& pos, IndexList& active);
const Position& pos,
IndexList& active);
// Get a list of indices for recently changed features // Get a list of indices for recently changed features
template<Color Perspective> template<Color Perspective>
static void append_changed_indices( static void
Square ksq, append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added);
const DirtyPiece& dp,
IndexList& removed,
IndexList& added
);
// Returns the cost of updating one perspective, the most costly one. // Returns the cost of updating one perspective, the most costly one.
// Assumes no refresh needed. // Assumes no refresh needed.

View file

@ -43,8 +43,10 @@ namespace Stockfish::Eval::NNUE::Layers {
// Requires the input to be padded to at least 16 values. // Requires the input to be padded to at least 16 values.
#if !defined(USE_SSSE3) #if !defined(USE_SSSE3)
template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions> template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input) static void affine_transform_non_ssse3(std::int32_t* output,
{ const std::int8_t* weights,
const std::int32_t* biases,
const std::uint8_t* input) {
#if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON) #if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
#if defined(USE_SSE2) #if defined(USE_SSE2)
// At least a multiple of 16, with SSE2. // At least a multiple of 16, with SSE2.
@ -61,14 +63,16 @@ namespace Stockfish::Eval::NNUE::Layers {
const auto inputVector = reinterpret_cast<const int8x8_t*>(input); const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
#endif #endif
for (IndexType i = 0; i < OutputDimensions; ++i) { for (IndexType i = 0; i < OutputDimensions; ++i)
{
const IndexType offset = i * PaddedInputDimensions; const IndexType offset = i * PaddedInputDimensions;
#if defined(USE_SSE2) #if defined(USE_SSE2)
__m128i sumLo = _mm_cvtsi32_si128(biases[i]); __m128i sumLo = _mm_cvtsi32_si128(biases[i]);
__m128i sumHi = Zeros; __m128i sumHi = Zeros;
const auto row = reinterpret_cast<const __m128i*>(&weights[offset]); const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
{
__m128i row_j = _mm_load_si128(&row[j]); __m128i row_j = _mm_load_si128(&row[j]);
__m128i input_j = _mm_load_si128(&inputVector[j]); __m128i input_j = _mm_load_si128(&inputVector[j]);
__m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
@ -90,7 +94,8 @@ namespace Stockfish::Eval::NNUE::Layers {
#elif defined(USE_NEON_DOTPROD) #elif defined(USE_NEON_DOTPROD)
int32x4_t sum = {biases[i]}; int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]); const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
{
sum = vdotq_s32(sum, inputVector[j], row[j]); sum = vdotq_s32(sum, inputVector[j], row[j]);
} }
output[i] = vaddvq_s32(sum); output[i] = vaddvq_s32(sum);
@ -98,7 +103,8 @@ namespace Stockfish::Eval::NNUE::Layers {
#elif defined(USE_NEON) #elif defined(USE_NEON)
int32x4_t sum = {biases[i]}; int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]); const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
{
int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]); int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]); product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
sum = vpadalq_s16(sum, product); sum = vpadalq_s16(sum, product);
@ -112,7 +118,8 @@ namespace Stockfish::Eval::NNUE::Layers {
// Traverse weights in transpose order to take advantage of input sparsity // Traverse weights in transpose order to take advantage of input sparsity
for (IndexType i = 0; i < InputDimensions; ++i) for (IndexType i = 0; i < InputDimensions; ++i)
if (input[i]) { if (input[i])
{
const std::int8_t* w = &weights[i]; const std::int8_t* w = &weights[i];
const int in = input[i]; const int in = input[i];
for (IndexType j = 0; j < OutputDimensions; ++j) for (IndexType j = 0; j < OutputDimensions; ++j)
@ -149,16 +156,12 @@ namespace Stockfish::Eval::NNUE::Layers {
return hashValue; return hashValue;
} }
static constexpr IndexType get_weight_index_scrambled(IndexType i) static constexpr IndexType get_weight_index_scrambled(IndexType i) {
{ return (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4
return + i / PaddedInputDimensions * 4 + i % 4;
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
i / PaddedInputDimensions * 4 +
i % 4;
} }
static constexpr IndexType get_weight_index(IndexType i) static constexpr IndexType get_weight_index(IndexType i) {
{
#if defined(USE_SSSE3) #if defined(USE_SSSE3)
return get_weight_index_scrambled(i); return get_weight_index_scrambled(i);
#else #else
@ -185,8 +188,7 @@ namespace Stockfish::Eval::NNUE::Layers {
return !stream.fail(); return !stream.fail();
} }
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined(USE_SSSE3) #if defined(USE_SSSE3)
@ -233,8 +235,10 @@ namespace Stockfish::Eval::NNUE::Layers {
{ {
const vec_t in0 = vec_set_32(input32[i + 0]); const vec_t in0 = vec_set_32(input32[i + 0]);
const vec_t in1 = vec_set_32(input32[i + 1]); const vec_t in1 = vec_set_32(input32[i + 1]);
const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]); const auto col0 =
const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]); reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
const auto col1 =
reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
for (IndexType k = 0; k < NumRegs; ++k) for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]); vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]);
} }
@ -248,7 +252,6 @@ namespace Stockfish::Eval::NNUE::Layers {
#undef vec_add_dpbusd_32 #undef vec_add_dpbusd_32
#undef vec_add_dpbusd_32x2 #undef vec_add_dpbusd_32x2
#undef vec_hadd #undef vec_hadd
} }
else if constexpr (OutputDimensions == 1) else if constexpr (OutputDimensions == 1)
{ {
@ -292,14 +295,11 @@ namespace Stockfish::Eval::NNUE::Layers {
#undef vec_add_dpbusd_32 #undef vec_add_dpbusd_32
#undef vec_add_dpbusd_32x2 #undef vec_add_dpbusd_32x2
#undef vec_hadd #undef vec_hadd
} }
#else #else
// Use old implementation for the other architectures. // Use old implementation for the other architectures.
affine_transform_non_ssse3< affine_transform_non_ssse3<InputDimensions, PaddedInputDimensions, OutputDimensions>(
InputDimensions, output, weights, biases, input);
PaddedInputDimensions,
OutputDimensions>(output, weights, biases, input);
#endif #endif
} }

View file

@ -38,7 +38,8 @@
namespace Stockfish::Eval::NNUE::Layers { namespace Stockfish::Eval::NNUE::Layers {
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
alignas(CacheLineSize) static inline const std::array<std::array<std::uint16_t, 8>, 256> lookup_indices = [](){ alignas(CacheLineSize) static inline const
std::array<std::array<std::uint16_t, 8>, 256> lookup_indices = []() {
std::array<std::array<std::uint16_t, 8>, 256> v{}; std::array<std::array<std::uint16_t, 8>, 256> v{};
for (unsigned i = 0; i < 256; ++i) for (unsigned i = 0; i < 256; ++i)
{ {
@ -61,11 +62,14 @@ namespace Stockfish::Eval::NNUE::Layers {
#if defined(USE_VNNI) && !defined(USE_AVXVNNI) #if defined(USE_VNNI) && !defined(USE_AVXVNNI)
#define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256()) #define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256())
#else #else
#define vec_nnz(a) _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256()))) #define vec_nnz(a) \
_mm256_movemask_ps( \
_mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256())))
#endif #endif
#elif defined(USE_SSSE3) #elif defined(USE_SSSE3)
using vec_t = __m128i; using vec_t = __m128i;
#define vec_nnz(a) _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128()))) #define vec_nnz(a) \
_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128())))
#endif #endif
using vec128_t = __m128i; using vec128_t = __m128i;
#define vec128_zero _mm_setzero_si128() #define vec128_zero _mm_setzero_si128()
@ -107,7 +111,8 @@ namespace Stockfish::Eval::NNUE::Layers {
for (IndexType j = 0; j < OutputsPerChunk; ++j) for (IndexType j = 0; j < OutputsPerChunk; ++j)
{ {
const auto lookup = (nnz >> (j * 8)) & 0xFF; const auto lookup = (nnz >> (j * 8)) & 0xFF;
const auto offsets = vec128_load(reinterpret_cast<const vec128_t*>(&lookup_indices[lookup])); const auto offsets =
vec128_load(reinterpret_cast<const vec128_t*>(&lookup_indices[lookup]));
vec128_storeu(reinterpret_cast<vec128_t*>(out + count), vec128_add(base, offsets)); vec128_storeu(reinterpret_cast<vec128_t*>(out + count), vec128_add(base, offsets));
count += popcount(lookup); count += popcount(lookup);
base = vec128_add(base, increment); base = vec128_add(base, increment);
@ -135,7 +140,8 @@ namespace Stockfish::Eval::NNUE::Layers {
static constexpr IndexType InputDimensions = InDims; static constexpr IndexType InputDimensions = InDims;
static constexpr IndexType OutputDimensions = OutDims; static constexpr IndexType OutputDimensions = OutDims;
static_assert(OutputDimensions % 16 == 0, "Only implemented for OutputDimensions divisible by 16."); static_assert(OutputDimensions % 16 == 0,
"Only implemented for OutputDimensions divisible by 16.");
static constexpr IndexType PaddedInputDimensions = static constexpr IndexType PaddedInputDimensions =
ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth); ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
@ -159,16 +165,12 @@ namespace Stockfish::Eval::NNUE::Layers {
return hashValue; return hashValue;
} }
static constexpr IndexType get_weight_index_scrambled(IndexType i) static constexpr IndexType get_weight_index_scrambled(IndexType i) {
{ return (i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * ChunkSize
return + i / PaddedInputDimensions * ChunkSize + i % ChunkSize;
(i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * ChunkSize +
i / PaddedInputDimensions * ChunkSize +
i % ChunkSize;
} }
static constexpr IndexType get_weight_index(IndexType i) static constexpr IndexType get_weight_index(IndexType i) {
{
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
return get_weight_index_scrambled(i); return get_weight_index_scrambled(i);
#else #else
@ -195,8 +197,7 @@ namespace Stockfish::Eval::NNUE::Layers {
return !stream.fail(); return !stream.fail();
} }
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
#if defined(USE_AVX512) #if defined(USE_AVX512)
@ -246,7 +247,8 @@ namespace Stockfish::Eval::NNUE::Layers {
{ {
const auto i = nnz[j]; const auto i = nnz[j];
const invec_t in = vec_set_32(input32[i]); const invec_t in = vec_set_32(input32[i]);
const auto col = reinterpret_cast<const invec_t*>(&weights[i * OutputDimensions * ChunkSize]); const auto col =
reinterpret_cast<const invec_t*>(&weights[i * OutputDimensions * ChunkSize]);
for (IndexType k = 0; k < NumRegs; ++k) for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32(acc[k], in, col[k]); vec_add_dpbusd_32(acc[k], in, col[k]);
} }
@ -258,10 +260,8 @@ namespace Stockfish::Eval::NNUE::Layers {
#undef vec_add_dpbusd_32 #undef vec_add_dpbusd_32
#else #else
// Use dense implementation for the other architectures. // Use dense implementation for the other architectures.
affine_transform_non_ssse3< affine_transform_non_ssse3<InputDimensions, PaddedInputDimensions, OutputDimensions>(
InputDimensions, output, weights, biases, input);
PaddedInputDimensions,
OutputDimensions>(output, weights, biases, input);
#endif #endif
} }

View file

@ -53,54 +53,56 @@ namespace Stockfish::Eval::NNUE::Layers {
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream&) { bool read_parameters(std::istream&) { return true; }
return true;
}
// Write network parameters // Write network parameters
bool write_parameters(std::ostream&) const { bool write_parameters(std::ostream&) const { return true; }
return true;
}
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined(USE_AVX2) #if defined(USE_AVX2)
if constexpr (InputDimensions % SimdWidth == 0) { if constexpr (InputDimensions % SimdWidth == 0)
{
constexpr IndexType NumChunks = InputDimensions / SimdWidth; constexpr IndexType NumChunks = InputDimensions / SimdWidth;
const __m256i Zero = _mm256_setzero_si256(); const __m256i Zero = _mm256_setzero_si256();
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
const auto in = reinterpret_cast<const __m256i*>(input); const auto in = reinterpret_cast<const __m256i*>(input);
const auto out = reinterpret_cast<__m256i*>(output); const auto out = reinterpret_cast<__m256i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32( {
_mm256_load_si256(&in[i * 4 + 0]), const __m256i words0 =
_mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits); _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]),
const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32( _mm256_load_si256(&in[i * 4 + 1])),
_mm256_load_si256(&in[i * 4 + 2]), WeightScaleBits);
_mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits); const __m256i words1 =
_mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8( _mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]),
_mm256_packs_epi16(words0, words1), Zero), Offsets)); _mm256_load_si256(&in[i * 4 + 3])),
WeightScaleBits);
_mm256_store_si256(
&out[i], _mm256_permutevar8x32_epi32(
_mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets));
} }
} else { }
else
{
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2); constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const __m128i Zero = _mm_setzero_si128(); const __m128i Zero = _mm_setzero_si128();
const auto in = reinterpret_cast<const __m128i*>(input); const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output); const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32( {
_mm_load_si128(&in[i * 4 + 0]), const __m128i words0 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32( WeightScaleBits);
_mm_load_si128(&in[i * 4 + 2]), const __m128i words1 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1); const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero)); _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
} }
} }
constexpr IndexType Start = constexpr IndexType Start = InputDimensions % SimdWidth == 0
InputDimensions % SimdWidth == 0
? InputDimensions / SimdWidth * SimdWidth ? InputDimensions / SimdWidth * SimdWidth
: InputDimensions / (SimdWidth / 2) * (SimdWidth / 2); : InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);
@ -115,13 +117,14 @@ namespace Stockfish::Eval::NNUE::Layers {
const auto in = reinterpret_cast<const __m128i*>(input); const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output); const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32( {
_mm_load_si128(&in[i * 4 + 0]), const __m128i words0 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32( WeightScaleBits);
_mm_load_si128(&in[i * 4 + 2]), const __m128i words1 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1); const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i], _mm_store_si128(&out[i],
@ -140,7 +143,8 @@ namespace Stockfish::Eval::NNUE::Layers {
const int8x8_t Zero = {0}; const int8x8_t Zero = {0};
const auto in = reinterpret_cast<const int32x4_t*>(input); const auto in = reinterpret_cast<const int32x4_t*>(input);
const auto out = reinterpret_cast<int8x8_t*>(output); const auto out = reinterpret_cast<int8x8_t*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
{
int16x8_t shifted; int16x8_t shifted;
const auto pack = reinterpret_cast<int16x4_t*>(&shifted); const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits); pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
@ -152,9 +156,9 @@ namespace Stockfish::Eval::NNUE::Layers {
constexpr IndexType Start = 0; constexpr IndexType Start = 0;
#endif #endif
for (IndexType i = Start; i < InputDimensions; ++i) { for (IndexType i = Start; i < InputDimensions; ++i)
output[i] = static_cast<OutputType>( {
std::clamp(input[i] >> WeightScaleBits, 0, 127)); output[i] = static_cast<OutputType>(std::clamp(input[i] >> WeightScaleBits, 0, 127));
} }
} }
}; };

View file

@ -58,8 +58,8 @@ namespace Stockfish::Simd {
reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3]) reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
] ]
*/ */
[[maybe_unused]] static __m512i m512_hadd128x16_interleave( [[maybe_unused]] static __m512i
__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) { m512_hadd128x16_interleave(__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
__m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1); __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
__m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1); __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
@ -76,10 +76,7 @@ namespace Stockfish::Simd {
return _mm512_add_epi32(sum0123a, sum0123b); return _mm512_add_epi32(sum0123a, sum0123b);
} }
[[maybe_unused]] static void m512_add_dpbusd_epi32( [[maybe_unused]] static void m512_add_dpbusd_epi32(__m512i& acc, __m512i a, __m512i b) {
__m512i& acc,
__m512i a,
__m512i b) {
#if defined(USE_VNNI) #if defined(USE_VNNI)
acc = _mm512_dpbusd_epi32(acc, a, b); acc = _mm512_dpbusd_epi32(acc, a, b);
@ -90,10 +87,8 @@ namespace Stockfish::Simd {
#endif #endif
} }
[[maybe_unused]] static void m512_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m512i& acc, m512_add_dpbusd_epi32x2(__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
__m512i a0, __m512i b0,
__m512i a1, __m512i b1) {
#if defined(USE_VNNI) #if defined(USE_VNNI)
acc = _mm512_dpbusd_epi32(acc, a0, b0); acc = _mm512_dpbusd_epi32(acc, a0, b0);
@ -118,10 +113,7 @@ namespace Stockfish::Simd {
return _mm_cvtsi128_si32(sum128) + bias; return _mm_cvtsi128_si32(sum128) + bias;
} }
[[maybe_unused]] static void m256_add_dpbusd_epi32( [[maybe_unused]] static void m256_add_dpbusd_epi32(__m256i& acc, __m256i a, __m256i b) {
__m256i& acc,
__m256i a,
__m256i b) {
#if defined(USE_VNNI) #if defined(USE_VNNI)
acc = _mm256_dpbusd_epi32(acc, a, b); acc = _mm256_dpbusd_epi32(acc, a, b);
@ -132,10 +124,8 @@ namespace Stockfish::Simd {
#endif #endif
} }
[[maybe_unused]] static void m256_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m256i& acc, m256_add_dpbusd_epi32x2(__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
__m256i a0, __m256i b0,
__m256i a1, __m256i b1) {
#if defined(USE_VNNI) #if defined(USE_VNNI)
acc = _mm256_dpbusd_epi32(acc, a0, b0); acc = _mm256_dpbusd_epi32(acc, a0, b0);
@ -159,20 +149,15 @@ namespace Stockfish::Simd {
return _mm_cvtsi128_si32(sum) + bias; return _mm_cvtsi128_si32(sum) + bias;
} }
[[maybe_unused]] static void m128_add_dpbusd_epi32( [[maybe_unused]] static void m128_add_dpbusd_epi32(__m128i& acc, __m128i a, __m128i b) {
__m128i& acc,
__m128i a,
__m128i b) {
__m128i product0 = _mm_maddubs_epi16(a, b); __m128i product0 = _mm_maddubs_epi16(a, b);
product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
acc = _mm_add_epi32(acc, product0); acc = _mm_add_epi32(acc, product0);
} }
[[maybe_unused]] static void m128_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m128i& acc, m128_add_dpbusd_epi32x2(__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
__m128i a0, __m128i b0,
__m128i a1, __m128i b1) {
__m128i product0 = _mm_maddubs_epi16(a0, b0); __m128i product0 = _mm_maddubs_epi16(a0, b0);
__m128i product1 = _mm_maddubs_epi16(a1, b1); __m128i product1 = _mm_maddubs_epi16(a1, b1);
@ -186,17 +171,14 @@ namespace Stockfish::Simd {
#if defined(USE_NEON_DOTPROD) #if defined(USE_NEON_DOTPROD)
[[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2( [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
int32x4_t& acc, int32x4_t& acc, int8x16_t a0, int8x16_t b0, int8x16_t a1, int8x16_t b1) {
int8x16_t a0, int8x16_t b0,
int8x16_t a1, int8x16_t b1) {
acc = vdotq_s32(acc, a0, b0); acc = vdotq_s32(acc, a0, b0);
acc = vdotq_s32(acc, a1, b1); acc = vdotq_s32(acc, a1, b1);
} }
[[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32( [[maybe_unused]] static void
int32x4_t& acc, dotprod_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {
int8x16_t a, int8x16_t b) {
acc = vdotq_s32(acc, a, b); acc = vdotq_s32(acc, a, b);
} }
@ -216,10 +198,8 @@ namespace Stockfish::Simd {
return neon_m128_reduce_add_epi32(sum) + bias; return neon_m128_reduce_add_epi32(sum) + bias;
} }
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2( [[maybe_unused]] static void
int32x4_t& acc, neon_m128_add_dpbusd_epi32x2(int32x4_t& acc, int8x8_t a0, int8x8_t b0, int8x8_t a1, int8x8_t b1) {
int8x8_t a0, int8x8_t b0,
int8x8_t a1, int8x8_t b1) {
int16x8_t product = vmull_s8(a0, b0); int16x8_t product = vmull_s8(a0, b0);
product = vmlal_s8(product, a1, b1); product = vmlal_s8(product, a1, b1);
@ -228,9 +208,7 @@ namespace Stockfish::Simd {
#endif #endif
#if USE_NEON >= 8 #if USE_NEON >= 8
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32( [[maybe_unused]] static void neon_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {
int32x4_t& acc,
int8x16_t a, int8x16_t b) {
int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
int16x8_t product1 = vmull_high_s8(a, b); int16x8_t product1 = vmull_high_s8(a, b);

View file

@ -53,18 +53,13 @@ namespace Stockfish::Eval::NNUE::Layers {
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream&) { bool read_parameters(std::istream&) { return true; }
return true;
}
// Write network parameters // Write network parameters
bool write_parameters(std::ostream&) const { bool write_parameters(std::ostream&) const { return true; }
return true;
}
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined(USE_SSE2) #if defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / 16; constexpr IndexType NumChunks = InputDimensions / 16;
@ -72,13 +67,12 @@ namespace Stockfish::Eval::NNUE::Layers {
static_assert(WeightScaleBits == 6); static_assert(WeightScaleBits == 6);
const auto in = reinterpret_cast<const __m128i*>(input); const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output); const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
__m128i words0 = _mm_packs_epi32( {
_mm_load_si128(&in[i * 4 + 0]), __m128i words0 =
_mm_load_si128(&in[i * 4 + 1])); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1]));
__m128i words1 = _mm_packs_epi32( __m128i words1 =
_mm_load_si128(&in[i * 4 + 2]), _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3]));
_mm_load_si128(&in[i * 4 + 3]));
// We shift by WeightScaleBits * 2 = 12 and divide by 128 // We shift by WeightScaleBits * 2 = 12 and divide by 128
// which is an additional shift-right of 7, meaning 19 in total. // which is an additional shift-right of 7, meaning 19 in total.
@ -94,7 +88,8 @@ namespace Stockfish::Eval::NNUE::Layers {
constexpr IndexType Start = 0; constexpr IndexType Start = 0;
#endif #endif
for (IndexType i = Start; i < InputDimensions; ++i) { for (IndexType i = Start; i < InputDimensions; ++i)
{
output[i] = static_cast<OutputType>( output[i] = static_cast<OutputType>(
// Really should be /127 but we need to make it fast so we right shift // Really should be /127 but we need to make it fast so we right shift
// by an extra 7 bits instead. Needs to be accounted for in the trainer. // by an extra 7 bits instead. Needs to be accounted for in the trainer.

View file

@ -42,8 +42,7 @@ constexpr IndexType TransformedFeatureDimensions = 2560;
constexpr IndexType PSQTBuckets = 8; constexpr IndexType PSQTBuckets = 8;
constexpr IndexType LayerStacks = 8; constexpr IndexType LayerStacks = 8;
struct Network struct Network {
{
static constexpr int FC_0_OUTPUTS = 15; static constexpr int FC_0_OUTPUTS = 15;
static constexpr int FC_1_OUTPUTS = 32; static constexpr int FC_1_OUTPUTS = 32;
@ -71,37 +70,29 @@ struct Network
// Read network parameters // Read network parameters
bool read_parameters(std::istream& stream) { bool read_parameters(std::istream& stream) {
return fc_0.read_parameters(stream) return fc_0.read_parameters(stream) && ac_0.read_parameters(stream)
&& ac_0.read_parameters(stream) && fc_1.read_parameters(stream) && ac_1.read_parameters(stream)
&& fc_1.read_parameters(stream)
&& ac_1.read_parameters(stream)
&& fc_2.read_parameters(stream); && fc_2.read_parameters(stream);
} }
// Write network parameters // Write network parameters
bool write_parameters(std::ostream& stream) const { bool write_parameters(std::ostream& stream) const {
return fc_0.write_parameters(stream) return fc_0.write_parameters(stream) && ac_0.write_parameters(stream)
&& ac_0.write_parameters(stream) && fc_1.write_parameters(stream) && ac_1.write_parameters(stream)
&& fc_1.write_parameters(stream)
&& ac_1.write_parameters(stream)
&& fc_2.write_parameters(stream); && fc_2.write_parameters(stream);
} }
std::int32_t propagate(const TransformedFeatureType* transformedFeatures) std::int32_t propagate(const TransformedFeatureType* transformedFeatures) {
{ struct alignas(CacheLineSize) Buffer {
struct alignas(CacheLineSize) Buffer
{
alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out; alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out;
alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)]; alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType
ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)];
alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out; alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out;
alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out; alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out;
alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out; alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out;
alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out; alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out;
Buffer() Buffer() { std::memset(this, 0, sizeof(*this)); }
{
std::memset(this, 0, sizeof(*this));
}
}; };
#if defined(__clang__) && (__APPLE__) #if defined(__clang__) && (__APPLE__)
@ -116,14 +107,16 @@ struct Network
fc_0.propagate(transformedFeatures, buffer.fc_0_out); fc_0.propagate(transformedFeatures, buffer.fc_0_out);
ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out); ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out);
ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out); ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out);
std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out, FC_0_OUTPUTS * sizeof(decltype(ac_0)::OutputType)); std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out,
FC_0_OUTPUTS * sizeof(decltype(ac_0)::OutputType));
fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out); fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out);
ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out); ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out);
fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out); fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out);
// buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in quantized form // buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in quantized form
// but we want 1.0 to be equal to 600*OutputScale // but we want 1.0 to be equal to 600*OutputScale
std::int32_t fwdOut = int(buffer.fc_0_out[FC_0_OUTPUTS]) * (600*OutputScale) / (127*(1<<WeightScaleBits)); std::int32_t fwdOut =
int(buffer.fc_0_out[FC_0_OUTPUTS]) * (600 * OutputScale) / (127 * (1 << WeightScaleBits));
std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut; std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut;
return outputValue; return outputValue;

View file

@ -203,12 +203,12 @@ namespace Stockfish::Eval::NNUE {
if ((byte & 0x80) == 0) if ((byte & 0x80) == 0)
{ {
out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0) ? result out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0)
? result
: result | ~((1 << shift) - 1); : result | ~((1 << shift) - 1);
break; break;
} }
} } while (shift < sizeof(IntType) * 8);
while (shift < sizeof(IntType) * 8);
} }
assert(bytes_left == 0); assert(bytes_left == 0);
@ -237,8 +237,7 @@ namespace Stockfish::Eval::NNUE {
byte = value & 0x7f; byte = value & 0x7f;
value >>= 7; value >>= 7;
++byte_count; ++byte_count;
} } while ((byte & 0x40) == 0 ? value != 0 : value != -1);
while ((byte & 0x40) == 0 ? value != 0 : value != -1);
} }
write_little_endian(stream, byte_count); write_little_endian(stream, byte_count);

View file

@ -125,7 +125,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_16(a, b) vaddq_s16(a, b) #define vec_add_16(a, b) vaddq_s16(a, b)
#define vec_sub_16(a, b) vsubq_s16(a, b) #define vec_sub_16(a, b) vsubq_s16(a, b)
#define vec_mul_16(a, b) vmulq_s16(a, b) #define vec_mul_16(a, b) vmulq_s16(a, b)
#define vec_zero() vec_t{0} #define vec_zero() \
vec_t { 0 }
#define vec_set_16(a) vdupq_n_s16(a) #define vec_set_16(a) vdupq_n_s16(a)
#define vec_max_16(a, b) vmaxq_s16(a, b) #define vec_max_16(a, b) vmaxq_s16(a, b)
#define vec_min_16(a, b) vminq_s16(a, b) #define vec_min_16(a, b) vminq_s16(a, b)
@ -139,7 +140,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_store_psqt(a, b) *(a) = (b) #define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) vaddq_s32(a, b) #define vec_add_psqt_32(a, b) vaddq_s32(a, b)
#define vec_sub_psqt_32(a, b) vsubq_s32(a, b) #define vec_sub_psqt_32(a, b) vsubq_s32(a, b)
#define vec_zero_psqt() psqt_vec_t{0} #define vec_zero_psqt() \
psqt_vec_t { 0 }
#define NumRegistersSIMD 16 #define NumRegistersSIMD 16
#define MaxChunkSize 16 #define MaxChunkSize 16
@ -161,12 +163,8 @@ namespace Stockfish::Eval::NNUE {
#pragma GCC diagnostic ignored "-Wignored-attributes" #pragma GCC diagnostic ignored "-Wignored-attributes"
#endif #endif
template <typename SIMDRegisterType, template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
typename LaneType, static constexpr int BestRegisterCount() {
int NumLanes,
int MaxRegisters>
static constexpr int BestRegisterCount()
{
#define RegisterSize sizeof(SIMDRegisterType) #define RegisterSize sizeof(SIMDRegisterType)
#define LaneSize sizeof(LaneType) #define LaneSize sizeof(LaneType)
@ -189,15 +187,16 @@ namespace Stockfish::Eval::NNUE {
return 1; return 1;
} }
static constexpr int NumRegs = BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>(); static constexpr int NumRegs =
static constexpr int NumPsqtRegs = BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>(); BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>();
static constexpr int NumPsqtRegs =
BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();
#if defined(__GNUC__) #if defined(__GNUC__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
#endif #endif
// Input feature converter // Input feature converter
class FeatureTransformer { class FeatureTransformer {
@ -221,8 +220,7 @@ namespace Stockfish::Eval::NNUE {
static constexpr IndexType OutputDimensions = HalfDimensions; static constexpr IndexType OutputDimensions = HalfDimensions;
// Size of forward propagation buffer // Size of forward propagation buffer
static constexpr std::size_t BufferSize = static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType);
OutputDimensions * sizeof(OutputType);
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value() { static constexpr std::uint32_t get_hash_value() {
@ -258,10 +256,9 @@ namespace Stockfish::Eval::NNUE {
const auto& accumulation = pos.state()->accumulator.accumulation; const auto& accumulation = pos.state()->accumulator.accumulation;
const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation; const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation;
const auto psqt = ( const auto psqt =
psqtAccumulation[perspectives[0]][bucket] (psqtAccumulation[perspectives[0]][bucket] - psqtAccumulation[perspectives[1]][bucket])
- psqtAccumulation[perspectives[1]][bucket] / 2;
) / 2;
for (IndexType p = 0; p < 2; ++p) for (IndexType p = 0; p < 2; ++p)
@ -278,7 +275,8 @@ namespace Stockfish::Eval::NNUE {
vec_t One = vec_set_16(127); vec_t One = vec_set_16(127);
const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0])); const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));
const vec_t* in1 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2])); const vec_t* in1 =
reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
vec_t* out = reinterpret_cast<vec_t*>(output + offset); vec_t* out = reinterpret_cast<vec_t*>(output + offset);
for (IndexType j = 0; j < NumOutputChunks; j += 1) for (IndexType j = 0; j < NumOutputChunks; j += 1)
@ -296,9 +294,11 @@ namespace Stockfish::Eval::NNUE {
#else #else
for (IndexType j = 0; j < HalfDimensions / 2; ++j) { for (IndexType j = 0; j < HalfDimensions / 2; ++j)
{
BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0]; BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
BiasType sum1 = accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2]; BiasType sum1 =
accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
sum0 = std::clamp<BiasType>(sum0, 0, 127); sum0 = std::clamp<BiasType>(sum0, 0, 127);
sum1 = std::clamp<BiasType>(sum1, 0, 127); sum1 = std::clamp<BiasType>(sum1, 0, 127);
output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 128); output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 128);
@ -317,7 +317,8 @@ namespace Stockfish::Eval::NNUE {
private: private:
template<Color Perspective> template<Color Perspective>
[[nodiscard]] std::pair<StateInfo*, StateInfo*> try_find_computed_accumulator(const Position& pos) const { [[nodiscard]] std::pair<StateInfo*, StateInfo*>
try_find_computed_accumulator(const Position& pos) const {
// Look for a usable accumulator of an earlier position. We keep track // Look for a usable accumulator of an earlier position. We keep track
// of the estimated gain in terms of features to be added/subtracted. // of the estimated gain in terms of features to be added/subtracted.
StateInfo *st = pos.state(), *next = nullptr; StateInfo *st = pos.state(), *next = nullptr;
@ -340,7 +341,9 @@ namespace Stockfish::Eval::NNUE {
// by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr. // by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr.
// computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr. // computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr.
template<Color Perspective, size_t N> template<Color Perspective, size_t N>
void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, StateInfo* states_to_update[N]) const { void update_accumulator_incremental(const Position& pos,
StateInfo* computed_st,
StateInfo* states_to_update[N]) const {
static_assert(N > 0); static_assert(N > 0);
assert(states_to_update[N - 1] == nullptr); assert(states_to_update[N - 1] == nullptr);
@ -366,7 +369,9 @@ namespace Stockfish::Eval::NNUE {
FeatureSet::IndexList removed[N - 1], added[N - 1]; FeatureSet::IndexList removed[N - 1], added[N - 1];
{ {
int i = N-2; // last potential state to update. Skip last element because it must be nullptr. int i =
N
- 2; // last potential state to update. Skip last element because it must be nullptr.
while (states_to_update[i] == nullptr) while (states_to_update[i] == nullptr)
--i; --i;
@ -379,8 +384,8 @@ namespace Stockfish::Eval::NNUE {
const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1]; const StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1];
for (; st2 != end_state; st2 = st2->previous) for (; st2 != end_state; st2 = st2->previous)
FeatureSet::append_changed_indices<Perspective>( FeatureSet::append_changed_indices<Perspective>(ksq, st2->dirtyPiece,
ksq, st2->dirtyPiece, removed[i], added[i]); removed[i], added[i]);
} }
} }
@ -389,14 +394,13 @@ namespace Stockfish::Eval::NNUE {
// Now update the accumulators listed in states_to_update[], where the last element is a sentinel. // Now update the accumulators listed in states_to_update[], where the last element is a sentinel.
#ifdef VECTOR #ifdef VECTOR
if ( states_to_update[1] == nullptr if (states_to_update[1] == nullptr && (removed[0].size() == 1 || removed[0].size() == 2)
&& (removed[0].size() == 1 || removed[0].size() == 2)
&& added[0].size() == 1) && added[0].size() == 1)
{ {
assert(states_to_update[0]); assert(states_to_update[0]);
auto accIn = reinterpret_cast<const vec_t*>( auto accIn =
&st->accumulator.accumulation[Perspective][0]); reinterpret_cast<const vec_t*>(&st->accumulator.accumulation[Perspective][0]);
auto accOut = reinterpret_cast<vec_t*>( auto accOut = reinterpret_cast<vec_t*>(
&states_to_update[0]->accumulator.accumulation[Perspective][0]); &states_to_update[0]->accumulator.accumulation[Perspective][0]);
@ -407,7 +411,8 @@ namespace Stockfish::Eval::NNUE {
if (removed[0].size() == 1) if (removed[0].size() == 1)
{ {
for (IndexType k = 0; k < HalfDimensions * sizeof(std::int16_t) / sizeof(vec_t); ++k) for (IndexType k = 0; k < HalfDimensions * sizeof(std::int16_t) / sizeof(vec_t);
++k)
accOut[k] = vec_add_16(vec_sub_16(accIn[k], columnR0[k]), columnA[k]); accOut[k] = vec_add_16(vec_sub_16(accIn[k], columnR0[k]), columnA[k]);
} }
else else
@ -415,9 +420,9 @@ namespace Stockfish::Eval::NNUE {
const IndexType offsetR1 = HalfDimensions * removed[0][1]; const IndexType offsetR1 = HalfDimensions * removed[0][1];
auto columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]); auto columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]);
for (IndexType k = 0; k < HalfDimensions * sizeof(std::int16_t) / sizeof(vec_t); ++k) for (IndexType k = 0; k < HalfDimensions * sizeof(std::int16_t) / sizeof(vec_t);
accOut[k] = vec_sub_16( ++k)
vec_add_16(accIn[k], columnA[k]), accOut[k] = vec_sub_16(vec_add_16(accIn[k], columnA[k]),
vec_add_16(columnR0[k], columnR1[k])); vec_add_16(columnR0[k], columnR1[k]));
} }
@ -433,18 +438,20 @@ namespace Stockfish::Eval::NNUE {
if (removed[0].size() == 1) if (removed[0].size() == 1)
{ {
for (std::size_t k = 0; k < PSQTBuckets * sizeof(std::int32_t) / sizeof(psqt_vec_t); ++k) for (std::size_t k = 0; k < PSQTBuckets * sizeof(std::int32_t) / sizeof(psqt_vec_t);
accPsqtOut[k] = vec_add_psqt_32(vec_sub_psqt_32( ++k)
accPsqtIn[k], columnPsqtR0[k]), columnPsqtA[k]); accPsqtOut[k] = vec_add_psqt_32(vec_sub_psqt_32(accPsqtIn[k], columnPsqtR0[k]),
columnPsqtA[k]);
} }
else else
{ {
const IndexType offsetPsqtR1 = PSQTBuckets * removed[0][1]; const IndexType offsetPsqtR1 = PSQTBuckets * removed[0][1];
auto columnPsqtR1 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]); auto columnPsqtR1 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]);
for (std::size_t k = 0; k < PSQTBuckets * sizeof(std::int32_t) / sizeof(psqt_vec_t); ++k) for (std::size_t k = 0; k < PSQTBuckets * sizeof(std::int32_t) / sizeof(psqt_vec_t);
accPsqtOut[k] = vec_sub_psqt_32( ++k)
vec_add_psqt_32(accPsqtIn[k], columnPsqtA[k]), accPsqtOut[k] =
vec_sub_psqt_32(vec_add_psqt_32(accPsqtIn[k], columnPsqtA[k]),
vec_add_psqt_32(columnPsqtR0[k], columnPsqtR1[k])); vec_add_psqt_32(columnPsqtR0[k], columnPsqtR1[k]));
} }
} }
@ -516,7 +523,8 @@ namespace Stockfish::Eval::NNUE {
// Store accumulator // Store accumulator
auto accTilePsqtOut = reinterpret_cast<psqt_vec_t*>( auto accTilePsqtOut = reinterpret_cast<psqt_vec_t*>(
&states_to_update[i]->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); &states_to_update[i]
->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
for (std::size_t k = 0; k < NumPsqtRegs; ++k) for (std::size_t k = 0; k < NumPsqtRegs; ++k)
vec_store_psqt(&accTilePsqtOut[k], psqt[k]); vec_store_psqt(&accTilePsqtOut[k], psqt[k]);
} }
@ -530,7 +538,8 @@ namespace Stockfish::Eval::NNUE {
HalfDimensions * sizeof(BiasType)); HalfDimensions * sizeof(BiasType));
for (std::size_t k = 0; k < PSQTBuckets; ++k) for (std::size_t k = 0; k < PSQTBuckets; ++k)
states_to_update[i]->accumulator.psqtAccumulation[Perspective][k] = st->accumulator.psqtAccumulation[Perspective][k]; states_to_update[i]->accumulator.psqtAccumulation[Perspective][k] =
st->accumulator.psqtAccumulation[Perspective][k];
st = states_to_update[i]; st = states_to_update[i];
@ -543,7 +552,8 @@ namespace Stockfish::Eval::NNUE {
st->accumulator.accumulation[Perspective][j] -= weights[offset + j]; st->accumulator.accumulation[Perspective][j] -= weights[offset + j];
for (std::size_t k = 0; k < PSQTBuckets; ++k) for (std::size_t k = 0; k < PSQTBuckets; ++k)
st->accumulator.psqtAccumulation[Perspective][k] -= psqtWeights[index * PSQTBuckets + k]; st->accumulator.psqtAccumulation[Perspective][k] -=
psqtWeights[index * PSQTBuckets + k];
} }
// Difference calculation for the activated features // Difference calculation for the activated features
@ -555,7 +565,8 @@ namespace Stockfish::Eval::NNUE {
st->accumulator.accumulation[Perspective][j] += weights[offset + j]; st->accumulator.accumulation[Perspective][j] += weights[offset + j];
for (std::size_t k = 0; k < PSQTBuckets; ++k) for (std::size_t k = 0; k < PSQTBuckets; ++k)
st->accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k]; st->accumulator.psqtAccumulation[Perspective][k] +=
psqtWeights[index * PSQTBuckets + k];
} }
} }
#endif #endif
@ -581,8 +592,7 @@ namespace Stockfish::Eval::NNUE {
#ifdef VECTOR #ifdef VECTOR
for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
{ {
auto biasesTile = reinterpret_cast<const vec_t*>( auto biasesTile = reinterpret_cast<const vec_t*>(&biases[j * TileHeight]);
&biases[j * TileHeight]);
for (IndexType k = 0; k < NumRegs; ++k) for (IndexType k = 0; k < NumRegs; ++k)
acc[k] = biasesTile[k]; acc[k] = biasesTile[k];
@ -595,8 +605,8 @@ namespace Stockfish::Eval::NNUE {
acc[k] = vec_add_16(acc[k], column[k]); acc[k] = vec_add_16(acc[k], column[k]);
} }
auto accTile = reinterpret_cast<vec_t*>( auto accTile =
&accumulator.accumulation[Perspective][j * TileHeight]); reinterpret_cast<vec_t*>(&accumulator.accumulation[Perspective][j * TileHeight]);
for (unsigned k = 0; k < NumRegs; k++) for (unsigned k = 0; k < NumRegs; k++)
vec_store(&accTile[k], acc[k]); vec_store(&accTile[k], acc[k]);
} }
@ -636,7 +646,8 @@ namespace Stockfish::Eval::NNUE {
accumulator.accumulation[Perspective][j] += weights[offset + j]; accumulator.accumulation[Perspective][j] += weights[offset + j];
for (std::size_t k = 0; k < PSQTBuckets; ++k) for (std::size_t k = 0; k < PSQTBuckets; ++k)
accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k]; accumulator.psqtAccumulation[Perspective][k] +=
psqtWeights[index * PSQTBuckets + k];
} }
#endif #endif
} }
@ -682,8 +693,8 @@ namespace Stockfish::Eval::NNUE {
// 1. for the current position // 1. for the current position
// 2. the next accumulator after the computed one // 2. the next accumulator after the computed one
// The heuristic may change in the future. // The heuristic may change in the future.
StateInfo *states_to_update[3] = StateInfo* states_to_update[3] = {next, next == pos.state() ? nullptr : pos.state(),
{ next, next == pos.state() ? nullptr : pos.state(), nullptr }; nullptr};
update_accumulator_incremental<Perspective, 3>(pos, oldest_st, states_to_update); update_accumulator_incremental<Perspective, 3>(pos, oldest_st, states_to_update);
} }

View file

@ -76,15 +76,13 @@ std::ostream& operator<<(std::ostream& os, const Position& pos) {
} }
os << " a b c d e f g h\n" os << " a b c d e f g h\n"
<< "\nFen: " << pos.fen() << "\nKey: " << std::hex << std::uppercase << "\nFen: " << pos.fen() << "\nKey: " << std::hex << std::uppercase << std::setfill('0')
<< std::setfill('0') << std::setw(16) << pos.key() << std::setw(16) << pos.key() << std::setfill(' ') << std::dec << "\nCheckers: ";
<< std::setfill(' ') << std::dec << "\nCheckers: ";
for (Bitboard b = pos.checkers(); b;) for (Bitboard b = pos.checkers(); b;)
os << UCI::square(pop_lsb(b)) << " "; os << UCI::square(pop_lsb(b)) << " ";
if ( int(Tablebases::MaxCardinality) >= popcount(pos.pieces()) if (int(Tablebases::MaxCardinality) >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING))
&& !pos.can_castle(ANY_CASTLING))
{ {
StateInfo st; StateInfo st;
ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize); ASSERT_ALIGNED(&st, Eval::NNUE::CacheLineSize);
@ -220,7 +218,8 @@ Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Th
else if (token == '/') else if (token == '/')
sq += 2 * SOUTH; sq += 2 * SOUTH;
else if ((idx = PieceToChar.find(token)) != string::npos) { else if ((idx = PieceToChar.find(token)) != string::npos)
{
put_piece(Piece(idx), sq); put_piece(Piece(idx), sq);
++sq; ++sq;
} }
@ -245,10 +244,12 @@ Position& Position::set(const string& fenStr, bool isChess960, StateInfo* si, Th
token = char(toupper(token)); token = char(toupper(token));
if (token == 'K') if (token == 'K')
for (rsq = relative_square(c, SQ_H1); piece_on(rsq) != rook; --rsq) {} for (rsq = relative_square(c, SQ_H1); piece_on(rsq) != rook; --rsq)
{}
else if (token == 'Q') else if (token == 'Q')
for (rsq = relative_square(c, SQ_A1); piece_on(rsq) != rook; ++rsq) {} for (rsq = relative_square(c, SQ_A1); piece_on(rsq) != rook; ++rsq)
{}
else if (token >= 'A' && token <= 'H') else if (token >= 'A' && token <= 'H')
rsq = make_square(File(token - 'A'), relative_rank(c, RANK_1)); rsq = make_square(File(token - 'A'), relative_rank(c, RANK_1));
@ -313,8 +314,7 @@ void Position::set_castling_right(Color c, Square rfrom) {
Square kto = relative_square(c, cr & KING_SIDE ? SQ_G1 : SQ_C1); Square kto = relative_square(c, cr & KING_SIDE ? SQ_G1 : SQ_C1);
Square rto = relative_square(c, cr & KING_SIDE ? SQ_F1 : SQ_D1); Square rto = relative_square(c, cr & KING_SIDE ? SQ_F1 : SQ_D1);
castlingPath[cr] = (between_bb(rfrom, rto) | between_bb(kfrom, kto)) castlingPath[cr] = (between_bb(rfrom, rto) | between_bb(kfrom, kto)) & ~(kfrom | rfrom);
& ~(kfrom | rfrom);
} }
@ -388,8 +388,8 @@ Position& Position::set(const string& code, Color c, StateInfo* si) {
std::transform(sides[c].begin(), sides[c].end(), sides[c].begin(), tolower); std::transform(sides[c].begin(), sides[c].end(), sides[c].begin(), tolower);
string fenStr = "8/" + sides[0] + char(8 - sides[0].length() + '0') + "/8/8/8/8/" string fenStr = "8/" + sides[0] + char(8 - sides[0].length() + '0') + "/8/8/8/8/" + sides[1]
+ sides[1] + char(8 - sides[1].length() + '0') + "/8 w - - 0 10"; + char(8 - sides[1].length() + '0') + "/8 w - - 0 10";
return set(fenStr, false, si, nullptr); return set(fenStr, false, si, nullptr);
} }
@ -438,8 +438,8 @@ string Position::fen() const {
if (!can_castle(ANY_CASTLING)) if (!can_castle(ANY_CASTLING))
ss << '-'; ss << '-';
ss << (ep_square() == SQ_NONE ? " - " : " " + UCI::square(ep_square()) + " ") ss << (ep_square() == SQ_NONE ? " - " : " " + UCI::square(ep_square()) + " ") << st->rule50
<< st->rule50 << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2; << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2;
return ss.str(); return ss.str();
} }
@ -456,7 +456,8 @@ void Position::update_slider_blockers(Color c) const {
// Snipers are sliders that attack 's' when a piece and other snipers are removed // Snipers are sliders that attack 's' when a piece and other snipers are removed
Bitboard snipers = ((attacks_bb<ROOK>(ksq) & pieces(QUEEN, ROOK)) Bitboard snipers = ((attacks_bb<ROOK>(ksq) & pieces(QUEEN, ROOK))
| (attacks_bb<BISHOP>(ksq) & pieces(QUEEN, BISHOP))) & pieces(~c); | (attacks_bb<BISHOP>(ksq) & pieces(QUEEN, BISHOP)))
& pieces(~c);
Bitboard occupancy = pieces() ^ snipers; Bitboard occupancy = pieces() ^ snipers;
while (snipers) while (snipers)
@ -544,8 +545,7 @@ bool Position::legal(Move m) const {
// A non-king move is legal if and only if it is not pinned or it // A non-king move is legal if and only if it is not pinned or it
// is moving along the ray towards or away from the king. // is moving along the ray towards or away from the king.
return !(blockers_for_king(us) & from) return !(blockers_for_king(us) & from) || aligned(from, to, square<KING>(us));
|| aligned(from, to, square<KING>(us));
} }
@ -589,9 +589,7 @@ bool Position::pseudo_legal(const Move m) const {
if (!(pawn_attacks_bb(us, from) & pieces(~us) & to) // Not a capture if (!(pawn_attacks_bb(us, from) & pieces(~us) & to) // Not a capture
&& !((from + pawn_push(us) == to) && empty(to)) // Not a single push && !((from + pawn_push(us) == to) && empty(to)) // Not a single push
&& !((from + 2 * pawn_push(us) == to) // Not a double push && !((from + 2 * pawn_push(us) == to) // Not a double push
&& (relative_rank(us, from) == RANK_2) && (relative_rank(us, from) == RANK_2) && empty(to) && empty(to - pawn_push(us))))
&& empty(to)
&& empty(to - pawn_push(us))))
return false; return false;
} }
else if (!(attacks_bb(type_of(pc), from, pieces()) & to)) else if (!(attacks_bb(type_of(pc), from, pieces()) & to))
@ -638,8 +636,7 @@ bool Position::gives_check(Move m) const {
// Is there a discovered check? // Is there a discovered check?
if (blockers_for_king(~sideToMove) & from) if (blockers_for_king(~sideToMove) & from)
return !aligned(from, to, square<KING>(~sideToMove)) return !aligned(from, to, square<KING>(~sideToMove)) || type_of(m) == CASTLING;
|| type_of(m) == CASTLING;
switch (type_of(m)) switch (type_of(m))
{ {
@ -653,13 +650,13 @@ bool Position::gives_check(Move m) const {
// of direct checks and ordinary discovered check, so the only case we // of direct checks and ordinary discovered check, so the only case we
// need to handle is the unusual case of a discovered check through // need to handle is the unusual case of a discovered check through
// the captured pawn. // the captured pawn.
case EN_PASSANT: case EN_PASSANT : {
{
Square capsq = make_square(file_of(to), rank_of(from)); Square capsq = make_square(file_of(to), rank_of(from));
Bitboard b = (pieces() ^ from ^ capsq) | to; Bitboard b = (pieces() ^ from ^ capsq) | to;
return (attacks_bb<ROOK>(square<KING>(~sideToMove), b) & pieces(sideToMove, QUEEN, ROOK)) return (attacks_bb<ROOK>(square<KING>(~sideToMove), b) & pieces(sideToMove, QUEEN, ROOK))
| (attacks_bb<BISHOP>(square<KING>(~sideToMove), b) & pieces(sideToMove, QUEEN, BISHOP)); | (attacks_bb<BISHOP>(square<KING>(~sideToMove), b)
& pieces(sideToMove, QUEEN, BISHOP));
} }
default : //CASTLING default : //CASTLING
{ {
@ -822,8 +819,8 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) {
// Update hash keys // Update hash keys
k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[promotion][to]; k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[promotion][to];
st->materialKey ^= Zobrist::psq[promotion][pieceCount[promotion]-1] st->materialKey ^=
^ Zobrist::psq[pc][pieceCount[pc]]; Zobrist::psq[promotion][pieceCount[promotion] - 1] ^ Zobrist::psq[pc][pieceCount[pc]];
// Update material // Update material
st->nonPawnMaterial[us] += PieceValue[promotion]; st->nonPawnMaterial[us] += PieceValue[promotion];
@ -959,7 +956,8 @@ void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Squ
// Remove both pieces first since squares could overlap in Chess960 // Remove both pieces first since squares could overlap in Chess960
remove_piece(Do ? from : to); remove_piece(Do ? from : to);
remove_piece(Do ? rfrom : rto); remove_piece(Do ? rfrom : rto);
board[Do ? from : to] = board[Do ? rfrom : rto] = NO_PIECE; // remove_piece does not do this for us board[Do ? from : to] = board[Do ? rfrom : rto] =
NO_PIECE; // remove_piece does not do this for us
put_piece(make_piece(us, KING), Do ? to : from); put_piece(make_piece(us, KING), Do ? to : from);
put_piece(make_piece(us, ROOK), Do ? rto : rfrom); put_piece(make_piece(us, ROOK), Do ? rto : rfrom);
} }
@ -1033,8 +1031,7 @@ Key Position::key_after(Move m) const {
k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[pc][from]; k ^= Zobrist::psq[pc][to] ^ Zobrist::psq[pc][from];
return (captured || type_of(pc) == PAWN) return (captured || type_of(pc) == PAWN) ? k : adjust_key50<true>(k);
? k : adjust_key50<true>(k);
} }
@ -1195,8 +1192,7 @@ bool Position::has_game_cycle(int ply) const {
stp = stp->previous->previous; stp = stp->previous->previous;
Key moveKey = originalKey ^ stp->key; Key moveKey = originalKey ^ stp->key;
if ( (j = H1(moveKey), cuckoo[j] == moveKey) if ((j = H1(moveKey), cuckoo[j] == moveKey) || (j = H2(moveKey), cuckoo[j] == moveKey))
|| (j = H2(moveKey), cuckoo[j] == moveKey))
{ {
Move move = cuckooMove[j]; Move move = cuckooMove[j];
Square s1 = from_sq(move); Square s1 = from_sq(move);
@ -1267,30 +1263,23 @@ bool Position::pos_is_ok() const {
constexpr bool Fast = true; // Quick (default) or full check? constexpr bool Fast = true; // Quick (default) or full check?
if ( (sideToMove != WHITE && sideToMove != BLACK) if ((sideToMove != WHITE && sideToMove != BLACK) || piece_on(square<KING>(WHITE)) != W_KING
|| piece_on(square<KING>(WHITE)) != W_KING
|| piece_on(square<KING>(BLACK)) != B_KING || piece_on(square<KING>(BLACK)) != B_KING
|| ( ep_square() != SQ_NONE || (ep_square() != SQ_NONE && relative_rank(sideToMove, ep_square()) != RANK_6))
&& relative_rank(sideToMove, ep_square()) != RANK_6))
assert(0 && "pos_is_ok: Default"); assert(0 && "pos_is_ok: Default");
if (Fast) if (Fast)
return true; return true;
if ( pieceCount[W_KING] != 1 if (pieceCount[W_KING] != 1 || pieceCount[B_KING] != 1
|| pieceCount[B_KING] != 1
|| attackers_to(square<KING>(~sideToMove)) & pieces(sideToMove)) || attackers_to(square<KING>(~sideToMove)) & pieces(sideToMove))
assert(0 && "pos_is_ok: Kings"); assert(0 && "pos_is_ok: Kings");
if ( (pieces(PAWN) & (Rank1BB | Rank8BB)) if ((pieces(PAWN) & (Rank1BB | Rank8BB)) || pieceCount[W_PAWN] > 8 || pieceCount[B_PAWN] > 8)
|| pieceCount[W_PAWN] > 8
|| pieceCount[B_PAWN] > 8)
assert(0 && "pos_is_ok: Pawns"); assert(0 && "pos_is_ok: Pawns");
if ( (pieces(WHITE) & pieces(BLACK)) if ((pieces(WHITE) & pieces(BLACK)) || (pieces(WHITE) | pieces(BLACK)) != pieces()
|| (pieces(WHITE) | pieces(BLACK)) != pieces() || popcount(pieces(WHITE)) > 16 || popcount(pieces(BLACK)) > 16)
|| popcount(pieces(WHITE)) > 16
|| popcount(pieces(BLACK)) > 16)
assert(0 && "pos_is_ok: Bitboards"); assert(0 && "pos_is_ok: Bitboards");
for (PieceType p1 = PAWN; p1 <= KING; ++p1) for (PieceType p1 = PAWN; p1 <= KING; ++p1)

View file

@ -89,15 +89,20 @@ public:
// Position representation // Position representation
Bitboard pieces(PieceType pt = ALL_PIECES) const; Bitboard pieces(PieceType pt = ALL_PIECES) const;
template<typename ...PieceTypes> Bitboard pieces(PieceType pt, PieceTypes... pts) const; template<typename... PieceTypes>
Bitboard pieces(PieceType pt, PieceTypes... pts) const;
Bitboard pieces(Color c) const; Bitboard pieces(Color c) const;
template<typename ...PieceTypes> Bitboard pieces(Color c, PieceTypes... pts) const; template<typename... PieceTypes>
Bitboard pieces(Color c, PieceTypes... pts) const;
Piece piece_on(Square s) const; Piece piece_on(Square s) const;
Square ep_square() const; Square ep_square() const;
bool empty(Square s) const; bool empty(Square s) const;
template<PieceType Pt> int count(Color c) const; template<PieceType Pt>
template<PieceType Pt> int count() const; int count(Color c) const;
template<PieceType Pt> Square square(Color c) const; template<PieceType Pt>
int count() const;
template<PieceType Pt>
Square square(Color c) const;
// Castling // Castling
CastlingRights castling_rights(Color c) const; CastlingRights castling_rights(Color c) const;
@ -115,7 +120,8 @@ public:
Bitboard attackers_to(Square s) const; Bitboard attackers_to(Square s) const;
Bitboard attackers_to(Square s, Bitboard occupied) const; Bitboard attackers_to(Square s, Bitboard occupied) const;
void update_slider_blockers(Color c) const; void update_slider_blockers(Color c) const;
template<PieceType Pt> Bitboard attacks_by(Color c) const; template<PieceType Pt>
Bitboard attacks_by(Color c) const;
// Properties of moves // Properties of moves
bool legal(Move m) const; bool legal(Move m) const;
@ -193,61 +199,50 @@ private:
std::ostream& operator<<(std::ostream& os, const Position& pos); std::ostream& operator<<(std::ostream& os, const Position& pos);
inline Color Position::side_to_move() const { inline Color Position::side_to_move() const { return sideToMove; }
return sideToMove;
}
inline Piece Position::piece_on(Square s) const { inline Piece Position::piece_on(Square s) const {
assert(is_ok(s)); assert(is_ok(s));
return board[s]; return board[s];
} }
inline bool Position::empty(Square s) const { inline bool Position::empty(Square s) const { return piece_on(s) == NO_PIECE; }
return piece_on(s) == NO_PIECE;
}
inline Piece Position::moved_piece(Move m) const { inline Piece Position::moved_piece(Move m) const { return piece_on(from_sq(m)); }
return piece_on(from_sq(m));
}
inline Bitboard Position::pieces(PieceType pt) const { inline Bitboard Position::pieces(PieceType pt) const { return byTypeBB[pt]; }
return byTypeBB[pt];
}
template<typename... PieceTypes> template<typename... PieceTypes>
inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const { inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const {
return pieces(pt) | pieces(pts...); return pieces(pt) | pieces(pts...);
} }
inline Bitboard Position::pieces(Color c) const { inline Bitboard Position::pieces(Color c) const { return byColorBB[c]; }
return byColorBB[c];
}
template<typename... PieceTypes> template<typename... PieceTypes>
inline Bitboard Position::pieces(Color c, PieceTypes... pts) const { inline Bitboard Position::pieces(Color c, PieceTypes... pts) const {
return pieces(c) & pieces(pts...); return pieces(c) & pieces(pts...);
} }
template<PieceType Pt> inline int Position::count(Color c) const { template<PieceType Pt>
inline int Position::count(Color c) const {
return pieceCount[make_piece(c, Pt)]; return pieceCount[make_piece(c, Pt)];
} }
template<PieceType Pt> inline int Position::count() const { template<PieceType Pt>
inline int Position::count() const {
return count<Pt>(WHITE) + count<Pt>(BLACK); return count<Pt>(WHITE) + count<Pt>(BLACK);
} }
template<PieceType Pt> inline Square Position::square(Color c) const { template<PieceType Pt>
inline Square Position::square(Color c) const {
assert(count<Pt>(c) == 1); assert(count<Pt>(c) == 1);
return lsb(pieces(c, Pt)); return lsb(pieces(c, Pt));
} }
inline Square Position::ep_square() const { inline Square Position::ep_square() const { return st->epSquare; }
return st->epSquare;
}
inline bool Position::can_castle(CastlingRights cr) const { inline bool Position::can_castle(CastlingRights cr) const { return st->castlingRights & cr; }
return st->castlingRights & cr;
}
inline CastlingRights Position::castling_rights(Color c) const { inline CastlingRights Position::castling_rights(Color c) const {
return c & CastlingRights(st->castlingRights); return c & CastlingRights(st->castlingRights);
@ -265,9 +260,7 @@ inline Square Position::castling_rook_square(CastlingRights cr) const {
return castlingRookSquare[cr]; return castlingRookSquare[cr];
} }
inline Bitboard Position::attackers_to(Square s) const { inline Bitboard Position::attackers_to(Square s) const { return attackers_to(s, pieces()); }
return attackers_to(s, pieces());
}
template<PieceType Pt> template<PieceType Pt>
inline Bitboard Position::attacks_by(Color c) const { inline Bitboard Position::attacks_by(Color c) const {
@ -285,61 +278,38 @@ inline Bitboard Position::attacks_by(Color c) const {
} }
} }
inline Bitboard Position::checkers() const { inline Bitboard Position::checkers() const { return st->checkersBB; }
return st->checkersBB;
}
inline Bitboard Position::blockers_for_king(Color c) const { inline Bitboard Position::blockers_for_king(Color c) const { return st->blockersForKing[c]; }
return st->blockersForKing[c];
}
inline Bitboard Position::pinners(Color c) const { inline Bitboard Position::pinners(Color c) const { return st->pinners[c]; }
return st->pinners[c];
}
inline Bitboard Position::check_squares(PieceType pt) const { inline Bitboard Position::check_squares(PieceType pt) const { return st->checkSquares[pt]; }
return st->checkSquares[pt];
}
inline Key Position::key() const { inline Key Position::key() const { return adjust_key50<false>(st->key); }
return adjust_key50<false>(st->key);
}
template<bool AfterMove> template<bool AfterMove>
inline Key Position::adjust_key50(Key k) const inline Key Position::adjust_key50(Key k) const {
{ return st->rule50 < 14 - AfterMove ? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8);
return st->rule50 < 14 - AfterMove
? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8);
} }
inline Key Position::material_key() const { inline Key Position::material_key() const { return st->materialKey; }
return st->materialKey;
}
inline Value Position::non_pawn_material(Color c) const { inline Value Position::non_pawn_material(Color c) const { return st->nonPawnMaterial[c]; }
return st->nonPawnMaterial[c];
}
inline Value Position::non_pawn_material() const { inline Value Position::non_pawn_material() const {
return non_pawn_material(WHITE) + non_pawn_material(BLACK); return non_pawn_material(WHITE) + non_pawn_material(BLACK);
} }
inline int Position::game_ply() const { inline int Position::game_ply() const { return gamePly; }
return gamePly;
}
inline int Position::rule50_count() const { inline int Position::rule50_count() const { return st->rule50; }
return st->rule50;
}
inline bool Position::is_chess960() const { inline bool Position::is_chess960() const { return chess960; }
return chess960;
}
inline bool Position::capture(Move m) const { inline bool Position::capture(Move m) const {
assert(is_ok(m)); assert(is_ok(m));
return (!empty(to_sq(m)) && type_of(m) != CASTLING) return (!empty(to_sq(m)) && type_of(m) != CASTLING) || type_of(m) == EN_PASSANT;
|| type_of(m) == EN_PASSANT;
} }
// Returns true if a move is generated from the capture stage, having also // Returns true if a move is generated from the capture stage, having also
@ -350,13 +320,9 @@ inline bool Position::capture_stage(Move m) const {
return capture(m) || promotion_type(m) == QUEEN; return capture(m) || promotion_type(m) == QUEEN;
} }
inline Piece Position::captured_piece() const { inline Piece Position::captured_piece() const { return st->capturedPiece; }
return st->capturedPiece;
}
inline Thread* Position::this_thread() const { inline Thread* Position::this_thread() const { return thisThread; }
return thisThread;
}
inline void Position::put_piece(Piece pc, Square s) { inline void Position::put_piece(Piece pc, Square s) {
@ -389,14 +355,9 @@ inline void Position::move_piece(Square from, Square to) {
board[to] = pc; board[to] = pc;
} }
inline void Position::do_move(Move m, StateInfo& newSt) { inline void Position::do_move(Move m, StateInfo& newSt) { do_move(m, newSt, gives_check(m)); }
do_move(m, newSt, gives_check(m));
}
inline StateInfo* Position::state() const { inline StateInfo* Position::state() const { return st; }
return st;
}
} // namespace Stockfish } // namespace Stockfish

View file

@ -69,7 +69,11 @@ using namespace Search;
namespace { namespace {
// Different node types, used as a template parameter // Different node types, used as a template parameter
enum NodeType { NonPV, PV, Root }; enum NodeType {
NonPV,
PV,
Root
};
// Futility margin // Futility margin
Value futility_margin(Depth d, bool noTtCutNode, bool improving) { Value futility_margin(Depth d, bool noTtCutNode, bool improving) {
@ -86,14 +90,11 @@ namespace {
} }
constexpr int futility_move_count(bool improving, Depth depth) { constexpr int futility_move_count(bool improving, Depth depth) {
return improving ? (3 + depth * depth) return improving ? (3 + depth * depth) : (3 + depth * depth) / 2;
: (3 + depth * depth) / 2;
} }
// History and stats update bonus, based on depth // History and stats update bonus, based on depth
int stat_bonus(Depth d) { int stat_bonus(Depth d) { return std::min(334 * d - 531, 1538); }
return std::min(334 * d - 531, 1538);
}
// Add a small random component to draw evaluations to avoid 3-fold blindness // Add a small random component to draw evaluations to avoid 3-fold blindness
Value value_draw(const Thread* thisThread) { Value value_draw(const Thread* thisThread) {
@ -135,8 +136,17 @@ namespace {
void update_pv(Move* pv, Move move, const Move* childPv); void update_pv(Move* pv, Move move, const Move* childPv);
void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus); void update_continuation_histories(Stack* ss, Piece pc, Square to, int bonus);
void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus); void update_quiet_stats(const Position& pos, Stack* ss, Move move, int bonus);
void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, Value beta, Square prevSq, void update_all_stats(const Position& pos,
Move* quietsSearched, int quietCount, Move* capturesSearched, int captureCount, Depth depth); Stack* ss,
Move bestMove,
Value bestValue,
Value beta,
Square prevSq,
Move* quietsSearched,
int quietCount,
Move* capturesSearched,
int captureCount,
Depth depth);
// perft() is our utility to verify move generation. All the leaf nodes up // perft() is our utility to verify move generation. All the leaf nodes up
// to the given depth are generated and counted, and the sum is returned. // to the given depth are generated and counted, and the sum is returned.
@ -213,8 +223,7 @@ void MainThread::search() {
{ {
rootMoves.emplace_back(MOVE_NONE); rootMoves.emplace_back(MOVE_NONE);
sync_cout << "info depth 0 score " sync_cout << "info depth 0 score "
<< UCI::value(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW) << UCI::value(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW) << sync_endl;
<< sync_endl;
} }
else else
{ {
@ -244,11 +253,10 @@ void MainThread::search() {
Time.availableNodes += Limits.inc[us] - Threads.nodes_searched(); Time.availableNodes += Limits.inc[us] - Threads.nodes_searched();
Thread* bestThread = this; Thread* bestThread = this;
Skill skill = Skill(Options["Skill Level"], Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0); Skill skill =
Skill(Options["Skill Level"], Options["UCI_LimitStrength"] ? int(Options["UCI_Elo"]) : 0);
if ( int(Options["MultiPV"]) == 1 if (int(Options["MultiPV"]) == 1 && !Limits.depth && !skill.enabled()
&& !Limits.depth
&& !skill.enabled()
&& rootMoves[0].pv[0] != MOVE_NONE) && rootMoves[0].pv[0] != MOVE_NONE)
bestThread = Threads.get_best_thread(); bestThread = Threads.get_best_thread();
@ -261,7 +269,8 @@ void MainThread::search() {
sync_cout << "bestmove " << UCI::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960()); sync_cout << "bestmove " << UCI::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960());
if (bestThread->rootMoves[0].pv.size() > 1 || bestThread->rootMoves[0].extract_ponder_from_tt(rootPos)) if (bestThread->rootMoves[0].pv.size() > 1
|| bestThread->rootMoves[0].extract_ponder_from_tt(rootPos))
std::cout << " ponder " << UCI::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960()); std::cout << " ponder " << UCI::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960());
std::cout << sync_endl; std::cout << sync_endl;
@ -290,7 +299,8 @@ void Thread::search() {
std::memset(ss - 7, 0, 10 * sizeof(Stack)); std::memset(ss - 7, 0, 10 * sizeof(Stack));
for (int i = 7; i > 0; --i) for (int i = 7; i > 0; --i)
{ {
(ss-i)->continuationHistory = &this->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel (ss - i)->continuationHistory =
&this->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel
(ss - i)->staticEval = VALUE_NONE; (ss - i)->staticEval = VALUE_NONE;
} }
@ -324,8 +334,7 @@ void Thread::search() {
int searchAgainCounter = 0; int searchAgainCounter = 0;
// Iterative deepening loop until requested to stop or the target depth is reached // Iterative deepening loop until requested to stop or the target depth is reached
while ( ++rootDepth < MAX_PLY while (++rootDepth < MAX_PLY && !Threads.stop
&& !Threads.stop
&& !(Limits.depth && mainThread && rootDepth > Limits.depth)) && !(Limits.depth && mainThread && rootDepth > Limits.depth))
{ {
// Age out PV variability metric // Age out PV variability metric
@ -376,7 +385,8 @@ void Thread::search() {
{ {
// Adjust the effective depth searched, but ensure at least one effective increment for every // Adjust the effective depth searched, but ensure at least one effective increment for every
// four searchAgain steps (see issue #2717). // four searchAgain steps (see issue #2717).
Depth adjustedDepth = std::max(1, rootDepth - failedHighCnt - 3 * (searchAgainCounter + 1) / 4); Depth adjustedDepth =
std::max(1, rootDepth - failedHighCnt - 3 * (searchAgainCounter + 1) / 4);
bestValue = Stockfish::search<Root>(rootPos, ss, alpha, beta, adjustedDepth, false); bestValue = Stockfish::search<Root>(rootPos, ss, alpha, beta, adjustedDepth, false);
// Bring the best move to the front. It is critical that sorting // Bring the best move to the front. It is critical that sorting
@ -395,9 +405,7 @@ void Thread::search() {
// When failing high/low give some update (without cluttering // When failing high/low give some update (without cluttering
// the UI) before a re-search. // the UI) before a re-search.
if ( mainThread if (mainThread && multiPV == 1 && (bestValue <= alpha || bestValue >= beta)
&& multiPV == 1
&& (bestValue <= alpha || bestValue >= beta)
&& Time.elapsed() > 3000) && Time.elapsed() > 3000)
sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl;
@ -428,8 +436,7 @@ void Thread::search() {
// Sort the PV lines searched so far and update the GUI // Sort the PV lines searched so far and update the GUI
std::stable_sort(rootMoves.begin() + pvFirst, rootMoves.begin() + pvIdx + 1); std::stable_sort(rootMoves.begin() + pvFirst, rootMoves.begin() + pvIdx + 1);
if ( mainThread if (mainThread && (Threads.stop || pvIdx + 1 == multiPV || Time.elapsed() > 3000))
&& (Threads.stop || pvIdx + 1 == multiPV || Time.elapsed() > 3000))
sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl; sync_cout << UCI::pv(rootPos, rootDepth) << sync_endl;
} }
@ -443,8 +450,7 @@ void Thread::search() {
} }
// Have we found a "mate in x"? // Have we found a "mate in x"?
if ( Limits.mate if (Limits.mate && bestValue >= VALUE_MATE_IN_MAX_PLY
&& bestValue >= VALUE_MATE_IN_MAX_PLY
&& VALUE_MATE - bestValue <= 2 * Limits.mate) && VALUE_MATE - bestValue <= 2 * Limits.mate)
Threads.stop = true; Threads.stop = true;
@ -463,12 +469,11 @@ void Thread::search() {
} }
// Do we have time for the next iteration? Can we stop searching now? // Do we have time for the next iteration? Can we stop searching now?
if ( Limits.use_time_management() if (Limits.use_time_management() && !Threads.stop && !mainThread->stopOnPonderhit)
&& !Threads.stop
&& !mainThread->stopOnPonderhit)
{ {
double fallingEval = (69 + 13 * (mainThread->bestPreviousAverageScore - bestValue) double fallingEval = (69 + 13 * (mainThread->bestPreviousAverageScore - bestValue)
+ 6 * (mainThread->iterValue[iterIdx] - bestValue)) / 619.6; + 6 * (mainThread->iterValue[iterIdx] - bestValue))
/ 619.6;
fallingEval = std::clamp(fallingEval, 0.5, 1.5); fallingEval = std::clamp(fallingEval, 0.5, 1.5);
// If the bestMove is stable over several iterations, reduce time accordingly // If the bestMove is stable over several iterations, reduce time accordingly
@ -492,8 +497,7 @@ void Thread::search() {
else else
Threads.stop = true; Threads.stop = true;
} }
else if ( !mainThread->ponder else if (!mainThread->ponder && Time.elapsed() > totalTime * 0.50)
&& Time.elapsed() > totalTime * 0.50)
Threads.increaseDepth = false; Threads.increaseDepth = false;
else else
Threads.increaseDepth = true; Threads.increaseDepth = true;
@ -531,9 +535,7 @@ namespace {
// Check if we have an upcoming move that draws by repetition, or // Check if we have an upcoming move that draws by repetition, or
// if the opponent had an alternative move earlier to this position. // if the opponent had an alternative move earlier to this position.
if ( !rootNode if (!rootNode && alpha < VALUE_DRAW && pos.has_game_cycle(ss->ply))
&& alpha < VALUE_DRAW
&& pos.has_game_cycle(ss->ply))
{ {
alpha = value_draw(pos.this_thread()); alpha = value_draw(pos.this_thread());
if (alpha >= beta) if (alpha >= beta)
@ -573,15 +575,13 @@ namespace {
static_cast<MainThread*>(thisThread)->check_time(); static_cast<MainThread*>(thisThread)->check_time();
// Used to send selDepth info to GUI (selDepth counts from 1, ply from 0) // Used to send selDepth info to GUI (selDepth counts from 1, ply from 0)
if ( PvNode if (PvNode && thisThread->selDepth < ss->ply + 1)
&& thisThread->selDepth < ss->ply + 1)
thisThread->selDepth = ss->ply + 1; thisThread->selDepth = ss->ply + 1;
if (!rootNode) if (!rootNode)
{ {
// Step 2. Check for aborted search and immediate draw // Step 2. Check for aborted search and immediate draw
if ( Threads.stop.load(std::memory_order_relaxed) if (Threads.stop.load(std::memory_order_relaxed) || pos.is_draw(ss->ply)
|| pos.is_draw(ss->ply)
|| ss->ply >= MAX_PLY) || ss->ply >= MAX_PLY)
return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos)
: value_draw(pos.this_thread()); : value_draw(pos.this_thread());
@ -615,7 +615,8 @@ namespace {
tte = TT.probe(posKey, ss->ttHit); tte = TT.probe(posKey, ss->ttHit);
ttValue = ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE; ttValue = ss->ttHit ? value_from_tt(tte->value(), ss->ply, pos.rule50_count()) : VALUE_NONE;
ttMove = rootNode ? thisThread->rootMoves[thisThread->pvIdx].pv[0] ttMove = rootNode ? thisThread->rootMoves[thisThread->pvIdx].pv[0]
: ss->ttHit ? tte->move() : MOVE_NONE; : ss->ttHit ? tte->move()
: MOVE_NONE;
ttCapture = ttMove && pos.capture_stage(ttMove); ttCapture = ttMove && pos.capture_stage(ttMove);
// At this point, if excluded, skip straight to step 6, static eval. However, // At this point, if excluded, skip straight to step 6, static eval. However,
@ -624,9 +625,7 @@ namespace {
ss->ttPv = PvNode || (ss->ttHit && tte->is_pv()); ss->ttPv = PvNode || (ss->ttHit && tte->is_pv());
// At non-PV nodes we check for an early TT cutoff // At non-PV nodes we check for an early TT cutoff
if ( !PvNode if (!PvNode && !excludedMove && tte->depth() > depth
&& !excludedMove
&& tte->depth() > depth
&& ttValue != VALUE_NONE // Possible in case of TT access race or if !ttHit && ttValue != VALUE_NONE // Possible in case of TT access race or if !ttHit
&& (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER)))
{ {
@ -640,10 +639,9 @@ namespace {
update_quiet_stats(pos, ss, ttMove, stat_bonus(depth)); update_quiet_stats(pos, ss, ttMove, stat_bonus(depth));
// Extra penalty for early quiet moves of the previous ply (~0 Elo on STC, ~2 Elo on LTC) // Extra penalty for early quiet moves of the previous ply (~0 Elo on STC, ~2 Elo on LTC)
if ( prevSq != SQ_NONE if (prevSq != SQ_NONE && (ss - 1)->moveCount <= 2 && !priorCapture)
&& (ss-1)->moveCount <= 2 update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq,
&& !priorCapture) -stat_bonus(depth + 1));
update_continuation_histories(ss-1, pos.piece_on(prevSq), prevSq, -stat_bonus(depth + 1));
} }
// Penalty for a quiet ttMove that fails low (~1 Elo) // Penalty for a quiet ttMove that fails low (~1 Elo)
else if (!ttCapture) else if (!ttCapture)
@ -666,8 +664,7 @@ namespace {
int piecesCount = pos.count<ALL_PIECES>(); int piecesCount = pos.count<ALL_PIECES>();
if (piecesCount <= TB::Cardinality if (piecesCount <= TB::Cardinality
&& (piecesCount < TB::Cardinality || depth >= TB::ProbeDepth) && (piecesCount < TB::Cardinality || depth >= TB::ProbeDepth) && pos.rule50_count() == 0
&& pos.rule50_count() == 0
&& !pos.can_castle(ANY_CASTLING)) && !pos.can_castle(ANY_CASTLING))
{ {
TB::ProbeState err; TB::ProbeState err;
@ -689,14 +686,13 @@ namespace {
: VALUE_DRAW + 2 * wdl * drawScore; : VALUE_DRAW + 2 * wdl * drawScore;
Bound b = wdl < -drawScore ? BOUND_UPPER Bound b = wdl < -drawScore ? BOUND_UPPER
: wdl > drawScore ? BOUND_LOWER : BOUND_EXACT; : wdl > drawScore ? BOUND_LOWER
: BOUND_EXACT;
if ( b == BOUND_EXACT if (b == BOUND_EXACT || (b == BOUND_LOWER ? value >= beta : value <= alpha))
|| (b == BOUND_LOWER ? value >= beta : value <= alpha))
{ {
tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, b, tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, b,
std::min(MAX_PLY - 1, depth + 6), std::min(MAX_PLY - 1, depth + 6), MOVE_NONE, VALUE_NONE);
MOVE_NONE, VALUE_NONE);
return value; return value;
} }
@ -738,8 +734,7 @@ namespace {
Eval::NNUE::hint_common_parent_position(pos); Eval::NNUE::hint_common_parent_position(pos);
// ttValue can be used as a better position evaluation (~7 Elo) // ttValue can be used as a better position evaluation (~7 Elo)
if ( ttValue != VALUE_NONE if (ttValue != VALUE_NONE && (tte->bound() & (ttValue > eval ? BOUND_LOWER : BOUND_UPPER)))
&& (tte->bound() & (ttValue > eval ? BOUND_LOWER : BOUND_UPPER)))
eval = ttValue; eval = ttValue;
} }
else else
@ -750,9 +745,7 @@ namespace {
} }
// Use static evaluation difference to improve quiet move ordering (~4 Elo) // Use static evaluation difference to improve quiet move ordering (~4 Elo)
if ( is_ok((ss-1)->currentMove) if (is_ok((ss - 1)->currentMove) && !(ss - 1)->inCheck && !priorCapture)
&& !(ss-1)->inCheck
&& !priorCapture)
{ {
int bonus = std::clamp(-18 * int((ss - 1)->staticEval + ss->staticEval), -1812, 1812); int bonus = std::clamp(-18 * int((ss - 1)->staticEval + ss->staticEval), -1812, 1812);
thisThread->mainHistory[~us][from_to((ss - 1)->currentMove)] << bonus; thisThread->mainHistory[~us][from_to((ss - 1)->currentMove)] << bonus;
@ -780,25 +773,18 @@ namespace {
// Step 8. Futility pruning: child node (~40 Elo) // Step 8. Futility pruning: child node (~40 Elo)
// The depth condition is important for mate finding. // The depth condition is important for mate finding.
if ( !ss->ttPv if (!ss->ttPv && depth < 9
&& depth < 9 && eval - futility_margin(depth, cutNode && !ss->ttHit, improving)
&& eval - futility_margin(depth, cutNode && !ss->ttHit, improving) - (ss-1)->statScore / 321 >= beta - (ss - 1)->statScore / 321
&& eval >= beta >= beta
&& eval < 29462 // smaller than TB wins && eval >= beta && eval < 29462 // smaller than TB wins
&& !( !ttCapture && !(!ttCapture && ttMove))
&& ttMove))
return eval; return eval;
// Step 9. Null move search with verification search (~35 Elo) // Step 9. Null move search with verification search (~35 Elo)
if ( !PvNode if (!PvNode && (ss - 1)->currentMove != MOVE_NULL && (ss - 1)->statScore < 17257 && eval >= beta
&& (ss-1)->currentMove != MOVE_NULL && eval >= ss->staticEval && ss->staticEval >= beta - 24 * depth + 281 && !excludedMove
&& (ss-1)->statScore < 17257 && pos.non_pawn_material(us) && ss->ply >= thisThread->nmpMinPly
&& eval >= beta
&& eval >= ss->staticEval
&& ss->staticEval >= beta - 24 * depth + 281
&& !excludedMove
&& pos.non_pawn_material(us)
&& ss->ply >= thisThread->nmpMinPly
&& beta > VALUE_TB_LOSS_IN_MAX_PLY) && beta > VALUE_TB_LOSS_IN_MAX_PLY)
{ {
assert(eval - beta >= 0); assert(eval - beta >= 0);
@ -839,16 +825,13 @@ namespace {
// Step 10. If the position doesn't have a ttMove, decrease depth by 2 // Step 10. If the position doesn't have a ttMove, decrease depth by 2
// (or by 4 if the TT entry for the current position was hit and the stored depth is greater than or equal to the current depth). // (or by 4 if the TT entry for the current position was hit and the stored depth is greater than or equal to the current depth).
// Use qsearch if depth is equal or below zero (~9 Elo) // Use qsearch if depth is equal or below zero (~9 Elo)
if ( PvNode if (PvNode && !ttMove)
&& !ttMove)
depth -= 2 + 2 * (ss->ttHit && tte->depth() >= depth); depth -= 2 + 2 * (ss->ttHit && tte->depth() >= depth);
if (depth <= 0) if (depth <= 0)
return qsearch<PV>(pos, ss, alpha, beta); return qsearch<PV>(pos, ss, alpha, beta);
if ( cutNode if (cutNode && depth >= 8 && !ttMove)
&& depth >= 8
&& !ttMove)
depth -= 2; depth -= 2;
probCutBeta = beta + 168 - 70 * improving; probCutBeta = beta + 168 - 70 * improving;
@ -856,16 +839,14 @@ namespace {
// Step 11. ProbCut (~10 Elo) // Step 11. ProbCut (~10 Elo)
// If we have a good enough capture (or queen promotion) and a reduced search returns a value // If we have a good enough capture (or queen promotion) and a reduced search returns a value
// much above beta, we can (almost) safely prune the previous move. // much above beta, we can (almost) safely prune the previous move.
if ( !PvNode if (
&& depth > 3 !PvNode && depth > 3
&& abs(beta) < VALUE_TB_WIN_IN_MAX_PLY && abs(beta) < VALUE_TB_WIN_IN_MAX_PLY
// If value from transposition table is lower than probCutBeta, don't attempt probCut // If value from transposition table is lower than probCutBeta, don't attempt probCut
// there and in further interactions with transposition table cutoff depth is set to depth - 3 // there and in further interactions with transposition table cutoff depth is set to depth - 3
// because probCut search has depth set to depth - 4 but we also do a move before it // because probCut search has depth set to depth - 4 but we also do a move before it
// So effective depth is equal to depth - 3 // So effective depth is equal to depth - 3
&& !( tte->depth() >= depth - 3 && !(tte->depth() >= depth - 3 && ttValue != VALUE_NONE && ttValue < probCutBeta))
&& ttValue != VALUE_NONE
&& ttValue < probCutBeta))
{ {
assert(probCutBeta < VALUE_INFINITE); assert(probCutBeta < VALUE_INFINITE);
@ -877,10 +858,9 @@ namespace {
assert(pos.capture_stage(move)); assert(pos.capture_stage(move));
ss->currentMove = move; ss->currentMove = move;
ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] ss->continuationHistory =
[true] &thisThread
[pos.moved_piece(move)] ->continuationHistory[ss->inCheck][true][pos.moved_piece(move)][to_sq(move)];
[to_sq(move)];
pos.do_move(move, st); pos.do_move(move, st);
@ -889,14 +869,16 @@ namespace {
// If the qsearch held, perform the regular search // If the qsearch held, perform the regular search
if (value >= probCutBeta) if (value >= probCutBeta)
value = -search<NonPV>(pos, ss+1, -probCutBeta, -probCutBeta+1, depth - 4, !cutNode); value = -search<NonPV>(pos, ss + 1, -probCutBeta, -probCutBeta + 1, depth - 4,
!cutNode);
pos.undo_move(move); pos.undo_move(move);
if (value >= probCutBeta) if (value >= probCutBeta)
{ {
// Save ProbCut data into transposition table // Save ProbCut data into transposition table
tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, BOUND_LOWER, depth - 3, move, ss->staticEval); tte->save(posKey, value_to_tt(value, ss->ply), ss->ttPv, BOUND_LOWER, depth - 3,
move, ss->staticEval);
return value - (probCutBeta - beta); return value - (probCutBeta - beta);
} }
} }
@ -908,27 +890,23 @@ moves_loop: // When in check, search starts here
// Step 12. A small Probcut idea, when we are in check (~4 Elo) // Step 12. A small Probcut idea, when we are in check (~4 Elo)
probCutBeta = beta + 416; probCutBeta = beta + 416;
if ( ss->inCheck if (ss->inCheck && !PvNode && ttCapture && (tte->bound() & BOUND_LOWER)
&& !PvNode && tte->depth() >= depth - 4 && ttValue >= probCutBeta
&& ttCapture && abs(ttValue) < VALUE_TB_WIN_IN_MAX_PLY && abs(beta) < VALUE_TB_WIN_IN_MAX_PLY)
&& (tte->bound() & BOUND_LOWER)
&& tte->depth() >= depth - 4
&& ttValue >= probCutBeta
&& abs(ttValue) < VALUE_TB_WIN_IN_MAX_PLY
&& abs(beta) < VALUE_TB_WIN_IN_MAX_PLY)
return probCutBeta; return probCutBeta;
const PieceToHistory* contHist[] = { (ss-1)->continuationHistory, (ss-2)->continuationHistory, const PieceToHistory* contHist[] = {(ss - 1)->continuationHistory,
(ss-3)->continuationHistory, (ss-4)->continuationHistory, (ss - 2)->continuationHistory,
nullptr , (ss-6)->continuationHistory }; (ss - 3)->continuationHistory,
(ss - 4)->continuationHistory,
nullptr,
(ss - 6)->continuationHistory};
Move countermove = prevSq != SQ_NONE ? thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] : MOVE_NONE; Move countermove =
prevSq != SQ_NONE ? thisThread->counterMoves[pos.piece_on(prevSq)][prevSq] : MOVE_NONE;
MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, &captureHistory, contHist,
&captureHistory, countermove, ss->killers);
contHist,
countermove,
ss->killers);
value = bestValue; value = bestValue;
moveCountPruning = singularQuietLMR = false; moveCountPruning = singularQuietLMR = false;
@ -936,10 +914,7 @@ moves_loop: // When in check, search starts here
// Indicate PvNodes that will probably fail low if the node was searched // Indicate PvNodes that will probably fail low if the node was searched
// at a depth equal to or greater than the current depth, and the result // at a depth equal to or greater than the current depth, and the result
// of this search was a fail low. // of this search was a fail low.
bool likelyFailLow = PvNode bool likelyFailLow = PvNode && ttMove && (tte->bound() & BOUND_UPPER) && tte->depth() >= depth;
&& ttMove
&& (tte->bound() & BOUND_UPPER)
&& tte->depth() >= depth;
// Step 13. Loop through all pseudo-legal moves until no moves remain // Step 13. Loop through all pseudo-legal moves until no moves remain
// or a beta cutoff occurs. // or a beta cutoff occurs.
@ -957,16 +932,17 @@ moves_loop: // When in check, search starts here
// At root obey the "searchmoves" option and skip moves not listed in Root // At root obey the "searchmoves" option and skip moves not listed in Root
// Move List. In MultiPV mode we also skip PV moves that have been already // Move List. In MultiPV mode we also skip PV moves that have been already
// searched and those of lower "TB rank" if we are in a TB root position. // searched and those of lower "TB rank" if we are in a TB root position.
if (rootNode && !std::count(thisThread->rootMoves.begin() + thisThread->pvIdx, if (rootNode
&& !std::count(thisThread->rootMoves.begin() + thisThread->pvIdx,
thisThread->rootMoves.begin() + thisThread->pvLast, move)) thisThread->rootMoves.begin() + thisThread->pvLast, move))
continue; continue;
ss->moveCount = ++moveCount; ss->moveCount = ++moveCount;
if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000) if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000)
sync_cout << "info depth " << depth sync_cout << "info depth " << depth << " currmove "
<< " currmove " << UCI::move(move, pos.is_chess960()) << UCI::move(move, pos.is_chess960()) << " currmovenumber "
<< " currmovenumber " << moveCount + thisThread->pvIdx << sync_endl; << moveCount + thisThread->pvIdx << sync_endl;
if (PvNode) if (PvNode)
(ss + 1)->pv = nullptr; (ss + 1)->pv = nullptr;
@ -984,9 +960,7 @@ moves_loop: // When in check, search starts here
// Step 14. Pruning at shallow depth (~120 Elo). // Step 14. Pruning at shallow depth (~120 Elo).
// Depth conditions are important for mate finding. // Depth conditions are important for mate finding.
if ( !rootNode if (!rootNode && pos.non_pawn_material(us) && bestValue > VALUE_TB_LOSS_IN_MAX_PLY)
&& pos.non_pawn_material(us)
&& bestValue > VALUE_TB_LOSS_IN_MAX_PLY)
{ {
// Skip quiet moves if movecount exceeds our FutilityMoveCount threshold (~8 Elo) // Skip quiet moves if movecount exceeds our FutilityMoveCount threshold (~8 Elo)
if (!moveCountPruning) if (!moveCountPruning)
@ -995,15 +969,15 @@ moves_loop: // When in check, search starts here
// Reduced depth of the next LMR search // Reduced depth of the next LMR search
int lmrDepth = newDepth - r; int lmrDepth = newDepth - r;
if ( capture if (capture || givesCheck)
|| givesCheck)
{ {
// Futility pruning for captures (~2 Elo) // Futility pruning for captures (~2 Elo)
if ( !givesCheck if (!givesCheck && lmrDepth < 7 && !ss->inCheck
&& lmrDepth < 7
&& !ss->inCheck
&& ss->staticEval + 188 + 206 * lmrDepth + PieceValue[pos.piece_on(to_sq(move))] && ss->staticEval + 188 + 206 * lmrDepth + PieceValue[pos.piece_on(to_sq(move))]
+ captureHistory[movedPiece][to_sq(move)][type_of(pos.piece_on(to_sq(move)))] / 7 < alpha) + captureHistory[movedPiece][to_sq(move)]
[type_of(pos.piece_on(to_sq(move)))]
/ 7
< alpha)
continue; continue;
// SEE based pruning for captures and checks (~11 Elo) // SEE based pruning for captures and checks (~11 Elo)
@ -1017,8 +991,7 @@ moves_loop: // When in check, search starts here
+ (*contHist[3])[movedPiece][to_sq(move)]; + (*contHist[3])[movedPiece][to_sq(move)];
// Continuation history based pruning (~2 Elo) // Continuation history based pruning (~2 Elo)
if ( lmrDepth < 6 if (lmrDepth < 6 && history < -3232 * depth)
&& history < -3232 * depth)
continue; continue;
history += 2 * thisThread->mainHistory[us][from_to(move)]; history += 2 * thisThread->mainHistory[us][from_to(move)];
@ -1027,9 +1000,7 @@ moves_loop: // When in check, search starts here
lmrDepth = std::max(lmrDepth, -2); lmrDepth = std::max(lmrDepth, -2);
// Futility pruning: parent node (~13 Elo) // Futility pruning: parent node (~13 Elo)
if ( !ss->inCheck if (!ss->inCheck && lmrDepth < 13 && ss->staticEval + 115 + 122 * lmrDepth <= alpha)
&& lmrDepth < 13
&& ss->staticEval + 115 + 122 * lmrDepth <= alpha)
continue; continue;
lmrDepth = std::max(lmrDepth, 0); lmrDepth = std::max(lmrDepth, 0);
@ -1054,17 +1025,16 @@ moves_loop: // When in check, search starts here
// so changing them requires tests at this type of time controls. // so changing them requires tests at this type of time controls.
if (!rootNode if (!rootNode
&& depth >= 4 - (thisThread->completedDepth > 24) + 2 * (PvNode && tte->is_pv()) && depth >= 4 - (thisThread->completedDepth > 24) + 2 * (PvNode && tte->is_pv())
&& move == ttMove && move == ttMove && !excludedMove // Avoid recursive singular search
&& !excludedMove // Avoid recursive singular search && abs(ttValue) < VALUE_TB_WIN_IN_MAX_PLY && (tte->bound() & BOUND_LOWER)
&& abs(ttValue) < VALUE_TB_WIN_IN_MAX_PLY
&& (tte->bound() & BOUND_LOWER)
&& tte->depth() >= depth - 3) && tte->depth() >= depth - 3)
{ {
Value singularBeta = ttValue - (64 + 57 * (ss->ttPv && !PvNode)) * depth / 64; Value singularBeta = ttValue - (64 + 57 * (ss->ttPv && !PvNode)) * depth / 64;
Depth singularDepth = (depth - 1) / 2; Depth singularDepth = (depth - 1) / 2;
ss->excludedMove = move; ss->excludedMove = move;
value = search<NonPV>(pos, ss, singularBeta - 1, singularBeta, singularDepth, cutNode); value =
search<NonPV>(pos, ss, singularBeta - 1, singularBeta, singularDepth, cutNode);
ss->excludedMove = MOVE_NONE; ss->excludedMove = MOVE_NONE;
if (value < singularBeta) if (value < singularBeta)
@ -1073,9 +1043,7 @@ moves_loop: // When in check, search starts here
singularQuietLMR = !ttCapture; singularQuietLMR = !ttCapture;
// Avoid search explosion by limiting the number of double extensions // Avoid search explosion by limiting the number of double extensions
if ( !PvNode if (!PvNode && value < singularBeta - 18 && ss->doubleExtensions <= 11)
&& value < singularBeta - 18
&& ss->doubleExtensions <= 11)
{ {
extension = 2; extension = 2;
depth += depth < 15; depth += depth < 15;
@ -1104,14 +1072,11 @@ moves_loop: // When in check, search starts here
} }
// Check extensions (~1 Elo) // Check extensions (~1 Elo)
else if ( givesCheck else if (givesCheck && depth > 9)
&& depth > 9)
extension = 1; extension = 1;
// Quiet ttMove extensions (~1 Elo) // Quiet ttMove extensions (~1 Elo)
else if ( PvNode else if (PvNode && move == ttMove && move == ss->killers[0]
&& move == ttMove
&& move == ss->killers[0]
&& (*contHist[0])[movedPiece][to_sq(move)] >= 4194) && (*contHist[0])[movedPiece][to_sq(move)] >= 4194)
extension = 1; extension = 1;
} }
@ -1125,17 +1090,14 @@ moves_loop: // When in check, search starts here
// Update the current move (this must be done after singular extension search) // Update the current move (this must be done after singular extension search)
ss->currentMove = move; ss->currentMove = move;
ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] ss->continuationHistory =
[capture] &thisThread->continuationHistory[ss->inCheck][capture][movedPiece][to_sq(move)];
[movedPiece]
[to_sq(move)];
// Step 16. Make the move // Step 16. Make the move
pos.do_move(move, st, givesCheck); pos.do_move(move, st, givesCheck);
// Decrease reduction if position is or has been on the PV (~4 Elo) // Decrease reduction if position is or has been on the PV (~4 Elo)
if ( ss->ttPv if (ss->ttPv && !likelyFailLow)
&& !likelyFailLow)
r -= cutNode && tte->depth() >= depth ? 3 : 2; r -= cutNode && tte->depth() >= depth ? 3 : 2;
// Decrease reduction if opponent's move count is high (~1 Elo) // Decrease reduction if opponent's move count is high (~1 Elo)
@ -1159,8 +1121,7 @@ moves_loop: // When in check, search starts here
r--; r--;
// Increase reduction on repetition (~1 Elo) // Increase reduction on repetition (~1 Elo)
if ( move == (ss-4)->currentMove if (move == (ss - 4)->currentMove && pos.has_repeated())
&& pos.has_repeated())
r += 2; r += 2;
// Increase reduction if next ply has a lot of fail high (~5 Elo) // Increase reduction if next ply has a lot of fail high (~5 Elo)
@ -1174,8 +1135,7 @@ moves_loop: // When in check, search starts here
ss->statScore = 2 * thisThread->mainHistory[us][from_to(move)] ss->statScore = 2 * thisThread->mainHistory[us][from_to(move)]
+ (*contHist[0])[movedPiece][to_sq(move)] + (*contHist[0])[movedPiece][to_sq(move)]
+ (*contHist[1])[movedPiece][to_sq(move)] + (*contHist[1])[movedPiece][to_sq(move)]
+ (*contHist[3])[movedPiece][to_sq(move)] + (*contHist[3])[movedPiece][to_sq(move)] - 3848;
- 3848;
// Decrease/increase reduction for moves with a good/bad history (~25 Elo) // Decrease/increase reduction for moves with a good/bad history (~25 Elo)
r -= ss->statScore / (10216 + 3855 * (depth > 5 && depth < 23)); r -= ss->statScore / (10216 + 3855 * (depth > 5 && depth < 23));
@ -1184,11 +1144,8 @@ moves_loop: // When in check, search starts here
// We use various heuristics for the sons of a node after the first son has // We use various heuristics for the sons of a node after the first son has
// been searched. In general, we would like to reduce them, but there are many // been searched. In general, we would like to reduce them, but there are many
// cases where we extend a son if it has good chances to be "interesting". // cases where we extend a son if it has good chances to be "interesting".
if ( depth >= 2 if (depth >= 2 && moveCount > 1 + (PvNode && ss->ply <= 1)
&& moveCount > 1 + (PvNode && ss->ply <= 1) && (!ss->ttPv || !capture || (cutNode && (ss - 1)->moveCount > 1)))
&& ( !ss->ttPv
|| !capture
|| (cutNode && (ss-1)->moveCount > 1)))
{ {
// In general we want to cap the LMR depth search at newDepth, but when // In general we want to cap the LMR depth search at newDepth, but when
// reduction is negative, we allow this move a limited search extension // reduction is negative, we allow this move a limited search extension
@ -1198,8 +1155,7 @@ moves_loop: // When in check, search starts here
value = -search<NonPV>(pos, ss + 1, -(alpha + 1), -alpha, d, true); value = -search<NonPV>(pos, ss + 1, -(alpha + 1), -alpha, d, true);
// Do a full-depth search when reduced LMR search fails high // Do a full-depth search when reduced LMR search fails high
if ( value > alpha if (value > alpha && d < newDepth)
&& d < newDepth)
{ {
// Adjust full-depth search based on LMR results - if the result // Adjust full-depth search based on LMR results - if the result
// was good enough search deeper, if it was bad enough search shallower. // was good enough search deeper, if it was bad enough search shallower.
@ -1226,8 +1182,7 @@ moves_loop: // When in check, search starts here
else if (!PvNode || moveCount > 1) else if (!PvNode || moveCount > 1)
{ {
// Increase reduction for cut nodes and not ttMove (~1 Elo) // Increase reduction for cut nodes and not ttMove (~1 Elo)
if ( !ttMove if (!ttMove && cutNode)
&& cutNode)
r += 2; r += 2;
// Note that if expected reduction is high, we reduce search depth by 1 here // Note that if expected reduction is high, we reduce search depth by 1 here
@ -1236,8 +1191,7 @@ moves_loop: // When in check, search starts here
// For PV nodes only, do a full PV search on the first move or after a fail high, // For PV nodes only, do a full PV search on the first move or after a fail high,
// otherwise let the parent node fail low with value <= alpha and try another move. // otherwise let the parent node fail low with value <= alpha and try another move.
if ( PvNode if (PvNode && (moveCount == 1 || value > alpha))
&& (moveCount == 1 || value > alpha))
{ {
(ss + 1)->pv = pv; (ss + 1)->pv = pv;
(ss + 1)->pv[0] = MOVE_NONE; (ss + 1)->pv[0] = MOVE_NONE;
@ -1259,10 +1213,11 @@ moves_loop: // When in check, search starts here
if (rootNode) if (rootNode)
{ {
RootMove& rm = *std::find(thisThread->rootMoves.begin(), RootMove& rm =
thisThread->rootMoves.end(), move); *std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move);
rm.averageScore = rm.averageScore != -VALUE_INFINITE ? (2 * value + rm.averageScore) / 3 : value; rm.averageScore =
rm.averageScore != -VALUE_INFINITE ? (2 * value + rm.averageScore) / 3 : value;
// PV move or new best move? // PV move or new best move?
if (moveCount == 1 || value > alpha) if (moveCount == 1 || value > alpha)
@ -1292,8 +1247,7 @@ moves_loop: // When in check, search starts here
// We record how often the best move has been changed in each iteration. // We record how often the best move has been changed in each iteration.
// This information is used for time management. In MultiPV mode, // This information is used for time management. In MultiPV mode,
// we must take care to only do this for the first PV line. // we must take care to only do this for the first PV line.
if ( moveCount > 1 if (moveCount > 1 && !thisThread->pvIdx)
&& !thisThread->pvIdx)
++thisThread->bestMoveChanges; ++thisThread->bestMoveChanges;
} }
else else
@ -1323,10 +1277,7 @@ moves_loop: // When in check, search starts here
else else
{ {
// Reduce other moves if we have found at least one score improvement (~2 Elo) // Reduce other moves if we have found at least one score improvement (~2 Elo)
if ( depth > 2 if (depth > 2 && depth < 12 && beta < 13828 && value > -11369)
&& depth < 12
&& beta < 13828
&& value > -11369)
depth -= 2; depth -= 2;
assert(depth > 0); assert(depth > 0);
@ -1355,21 +1306,22 @@ moves_loop: // When in check, search starts here
assert(moveCount || !ss->inCheck || excludedMove || !MoveList<LEGAL>(pos).size()); assert(moveCount || !ss->inCheck || excludedMove || !MoveList<LEGAL>(pos).size());
if (!moveCount) if (!moveCount)
bestValue = excludedMove ? alpha : bestValue = excludedMove ? alpha : ss->inCheck ? mated_in(ss->ply) : VALUE_DRAW;
ss->inCheck ? mated_in(ss->ply)
: VALUE_DRAW;
// If there is a move that produces search value greater than alpha we update the stats of searched moves // If there is a move that produces search value greater than alpha we update the stats of searched moves
else if (bestMove) else if (bestMove)
update_all_stats(pos, ss, bestMove, bestValue, beta, prevSq, update_all_stats(pos, ss, bestMove, bestValue, beta, prevSq, quietsSearched, quietCount,
quietsSearched, quietCount, capturesSearched, captureCount, depth); capturesSearched, captureCount, depth);
// Bonus for prior countermove that caused the fail low // Bonus for prior countermove that caused the fail low
else if (!priorCapture && prevSq != SQ_NONE) else if (!priorCapture && prevSq != SQ_NONE)
{ {
int bonus = (depth > 6) + (PvNode || cutNode) + (bestValue < alpha - 653) + ((ss-1)->moveCount > 11); int bonus = (depth > 6) + (PvNode || cutNode) + (bestValue < alpha - 653)
update_continuation_histories(ss-1, pos.piece_on(prevSq), prevSq, stat_bonus(depth) * bonus); + ((ss - 1)->moveCount > 11);
thisThread->mainHistory[~us][from_to((ss-1)->currentMove)] << stat_bonus(depth) * bonus / 2; update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq,
stat_bonus(depth) * bonus);
thisThread->mainHistory[~us][from_to((ss - 1)->currentMove)]
<< stat_bonus(depth) * bonus / 2;
} }
if (PvNode) if (PvNode)
@ -1383,8 +1335,9 @@ moves_loop: // When in check, search starts here
// Write gathered information in transposition table // Write gathered information in transposition table
if (!excludedMove && !(rootNode && thisThread->pvIdx)) if (!excludedMove && !(rootNode && thisThread->pvIdx))
tte->save(posKey, value_to_tt(bestValue, ss->ply), ss->ttPv, tte->save(posKey, value_to_tt(bestValue, ss->ply), ss->ttPv,
bestValue >= beta ? BOUND_LOWER : bestValue >= beta ? BOUND_LOWER
PvNode && bestMove ? BOUND_EXACT : BOUND_UPPER, : PvNode && bestMove ? BOUND_EXACT
: BOUND_UPPER,
depth, bestMove, ss->staticEval); depth, bestMove, ss->staticEval);
assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE);
@ -1408,8 +1361,7 @@ moves_loop: // When in check, search starts here
// Check if we have an upcoming move that draws by repetition, or // Check if we have an upcoming move that draws by repetition, or
// if the opponent had an alternative move earlier to this position. // if the opponent had an alternative move earlier to this position.
if ( alpha < VALUE_DRAW if (alpha < VALUE_DRAW && pos.has_game_cycle(ss->ply))
&& pos.has_game_cycle(ss->ply))
{ {
alpha = value_draw(pos.this_thread()); alpha = value_draw(pos.this_thread());
if (alpha >= beta) if (alpha >= beta)
@ -1442,8 +1394,7 @@ moves_loop: // When in check, search starts here
moveCount = 0; moveCount = 0;
// Step 2. Check for an immediate draw or maximum ply reached // Step 2. Check for an immediate draw or maximum ply reached
if ( pos.is_draw(ss->ply) if (pos.is_draw(ss->ply) || ss->ply >= MAX_PLY)
|| ss->ply >= MAX_PLY)
return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) : VALUE_DRAW; return (ss->ply >= MAX_PLY && !ss->inCheck) ? evaluate(pos) : VALUE_DRAW;
assert(0 <= ss->ply && ss->ply < MAX_PLY); assert(0 <= ss->ply && ss->ply < MAX_PLY);
@ -1451,8 +1402,7 @@ moves_loop: // When in check, search starts here
// Decide whether or not to include checks: this fixes also the type of // Decide whether or not to include checks: this fixes also the type of
// TT entry depth that we are going to use. Note that in qsearch we use // TT entry depth that we are going to use. Note that in qsearch we use
// only two types of depth in TT: DEPTH_QS_CHECKS or DEPTH_QS_NO_CHECKS. // only two types of depth in TT: DEPTH_QS_CHECKS or DEPTH_QS_NO_CHECKS.
ttDepth = ss->inCheck || depth >= DEPTH_QS_CHECKS ? DEPTH_QS_CHECKS ttDepth = ss->inCheck || depth >= DEPTH_QS_CHECKS ? DEPTH_QS_CHECKS : DEPTH_QS_NO_CHECKS;
: DEPTH_QS_NO_CHECKS;
// Step 3. Transposition table lookup // Step 3. Transposition table lookup
posKey = pos.key(); posKey = pos.key();
@ -1462,8 +1412,7 @@ moves_loop: // When in check, search starts here
pvHit = ss->ttHit && tte->is_pv(); pvHit = ss->ttHit && tte->is_pv();
// At non-PV nodes we check for an early TT cutoff // At non-PV nodes we check for an early TT cutoff
if ( !PvNode if (!PvNode && tte->depth() >= ttDepth
&& tte->depth() >= ttDepth
&& ttValue != VALUE_NONE // Only in case of TT access race or if !ttHit && ttValue != VALUE_NONE // Only in case of TT access race or if !ttHit
&& (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER))) && (tte->bound() & (ttValue >= beta ? BOUND_LOWER : BOUND_UPPER)))
return ttValue; return ttValue;
@ -1486,15 +1435,15 @@ moves_loop: // When in check, search starts here
} }
else else
// In case of null move search use previous static eval with a different sign // In case of null move search use previous static eval with a different sign
ss->staticEval = bestValue = (ss-1)->currentMove != MOVE_NULL ? evaluate(pos) ss->staticEval = bestValue =
: -(ss-1)->staticEval; (ss - 1)->currentMove != MOVE_NULL ? evaluate(pos) : -(ss - 1)->staticEval;
// Stand pat. Return immediately if static value is at least beta // Stand pat. Return immediately if static value is at least beta
if (bestValue >= beta) if (bestValue >= beta)
{ {
if (!ss->ttHit) if (!ss->ttHit)
tte->save(posKey, value_to_tt(bestValue, ss->ply), false, BOUND_LOWER, tte->save(posKey, value_to_tt(bestValue, ss->ply), false, BOUND_LOWER, DEPTH_NONE,
DEPTH_NONE, MOVE_NONE, ss->staticEval); MOVE_NONE, ss->staticEval);
return bestValue; return bestValue;
} }
@ -1505,17 +1454,16 @@ moves_loop: // When in check, search starts here
futilityBase = std::min(ss->staticEval, bestValue) + 200; futilityBase = std::min(ss->staticEval, bestValue) + 200;
} }
const PieceToHistory* contHist[] = {(ss-1)->continuationHistory, (ss-2)->continuationHistory}; const PieceToHistory* contHist[] = {(ss - 1)->continuationHistory,
(ss - 2)->continuationHistory};
// Initialize a MovePicker object for the current position, and prepare // Initialize a MovePicker object for the current position, and prepare
// to search the moves. Because the depth is <= 0 here, only captures, // to search the moves. Because the depth is <= 0 here, only captures,
// queen promotions, and other checks (only if depth >= DEPTH_QS_CHECKS) // queen promotions, and other checks (only if depth >= DEPTH_QS_CHECKS)
// will be generated. // will be generated.
Square prevSq = is_ok((ss - 1)->currentMove) ? to_sq((ss - 1)->currentMove) : SQ_NONE; Square prevSq = is_ok((ss - 1)->currentMove) ? to_sq((ss - 1)->currentMove) : SQ_NONE;
MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, MovePicker mp(pos, ttMove, depth, &thisThread->mainHistory, &thisThread->captureHistory,
&thisThread->captureHistory, contHist, prevSq);
contHist,
prevSq);
int quietCheckEvasions = 0; int quietCheckEvasions = 0;
@ -1535,13 +1483,10 @@ moves_loop: // When in check, search starts here
moveCount++; moveCount++;
// Step 6. Pruning // Step 6. Pruning
if ( bestValue > VALUE_TB_LOSS_IN_MAX_PLY if (bestValue > VALUE_TB_LOSS_IN_MAX_PLY && pos.non_pawn_material(us))
&& pos.non_pawn_material(us))
{ {
// Futility pruning and moveCount pruning (~10 Elo) // Futility pruning and moveCount pruning (~10 Elo)
if ( !givesCheck if (!givesCheck && to_sq(move) != prevSq && futilityBase > VALUE_TB_LOSS_IN_MAX_PLY
&& to_sq(move) != prevSq
&& futilityBase > VALUE_TB_LOSS_IN_MAX_PLY
&& type_of(move) != PROMOTION) && type_of(move) != PROMOTION)
{ {
if (moveCount > 2) if (moveCount > 2)
@ -1559,8 +1504,7 @@ moves_loop: // When in check, search starts here
// If static eval is much lower than alpha and move is not winning material // If static eval is much lower than alpha and move is not winning material
// we can prune this move. // we can prune this move.
if ( futilityBase <= alpha if (futilityBase <= alpha && !pos.see_ge(move, VALUE_ZERO + 1))
&& !pos.see_ge(move, VALUE_ZERO + 1))
{ {
bestValue = std::max(bestValue, futilityBase); bestValue = std::max(bestValue, futilityBase);
continue; continue;
@ -1582,8 +1526,7 @@ moves_loop: // When in check, search starts here
break; break;
// Continuation history based pruning (~3 Elo) // Continuation history based pruning (~3 Elo)
if ( !capture if (!capture && (*contHist[0])[pos.moved_piece(move)][to_sq(move)] < 0
&& (*contHist[0])[pos.moved_piece(move)][to_sq(move)] < 0
&& (*contHist[1])[pos.moved_piece(move)][to_sq(move)] < 0) && (*contHist[1])[pos.moved_piece(move)][to_sq(move)] < 0)
continue; continue;
@ -1597,10 +1540,9 @@ moves_loop: // When in check, search starts here
// Update the current move // Update the current move
ss->currentMove = move; ss->currentMove = move;
ss->continuationHistory = &thisThread->continuationHistory[ss->inCheck] ss->continuationHistory =
[capture] &thisThread
[pos.moved_piece(move)] ->continuationHistory[ss->inCheck][capture][pos.moved_piece(move)][to_sq(move)];
[to_sq(move)];
quietCheckEvasions += !capture && ss->inCheck; quietCheckEvasions += !capture && ss->inCheck;
@ -1643,8 +1585,7 @@ moves_loop: // When in check, search starts here
// Save gathered info in transposition table // Save gathered info in transposition table
tte->save(posKey, value_to_tt(bestValue, ss->ply), pvHit, tte->save(posKey, value_to_tt(bestValue, ss->ply), pvHit,
bestValue >= beta ? BOUND_LOWER : BOUND_UPPER, bestValue >= beta ? BOUND_LOWER : BOUND_UPPER, ttDepth, bestMove, ss->staticEval);
ttDepth, bestMove, ss->staticEval);
assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE); assert(bestValue > -VALUE_INFINITE && bestValue < VALUE_INFINITE);
@ -1660,8 +1601,7 @@ moves_loop: // When in check, search starts here
assert(v != VALUE_NONE); assert(v != VALUE_NONE);
return v >= VALUE_TB_WIN_IN_MAX_PLY ? v + ply return v >= VALUE_TB_WIN_IN_MAX_PLY ? v + ply : v <= VALUE_TB_LOSS_IN_MAX_PLY ? v - ply : v;
: v <= VALUE_TB_LOSS_IN_MAX_PLY ? v - ply : v;
} }
@ -1708,8 +1648,17 @@ moves_loop: // When in check, search starts here
// update_all_stats() updates stats at the end of search() when a bestMove is found // update_all_stats() updates stats at the end of search() when a bestMove is found
void update_all_stats(const Position& pos, Stack* ss, Move bestMove, Value bestValue, Value beta, Square prevSq, void update_all_stats(const Position& pos,
Move* quietsSearched, int quietCount, Move* capturesSearched, int captureCount, Depth depth) { Stack* ss,
Move bestMove,
Value bestValue,
Value beta,
Square prevSq,
Move* quietsSearched,
int quietCount,
Move* capturesSearched,
int captureCount,
Depth depth) {
Color us = pos.side_to_move(); Color us = pos.side_to_move();
Thread* thisThread = pos.this_thread(); Thread* thisThread = pos.this_thread();
@ -1731,7 +1680,8 @@ moves_loop: // When in check, search starts here
for (int i = 0; i < quietCount; ++i) for (int i = 0; i < quietCount; ++i)
{ {
thisThread->mainHistory[us][from_to(quietsSearched[i])] << -bestMoveBonus; thisThread->mainHistory[us][from_to(quietsSearched[i])] << -bestMoveBonus;
update_continuation_histories(ss, pos.moved_piece(quietsSearched[i]), to_sq(quietsSearched[i]), -bestMoveBonus); update_continuation_histories(ss, pos.moved_piece(quietsSearched[i]),
to_sq(quietsSearched[i]), -bestMoveBonus);
} }
} }
else else
@ -1744,7 +1694,8 @@ moves_loop: // When in check, search starts here
// Extra penalty for a quiet early move that was not a TT move or // Extra penalty for a quiet early move that was not a TT move or
// main killer move in previous ply when it gets refuted. // main killer move in previous ply when it gets refuted.
if (prevSq != SQ_NONE if (prevSq != SQ_NONE
&& ((ss-1)->moveCount == 1 + (ss-1)->ttHit || ((ss-1)->currentMove == (ss-1)->killers[0])) && ((ss - 1)->moveCount == 1 + (ss - 1)->ttHit
|| ((ss - 1)->currentMove == (ss - 1)->killers[0]))
&& !pos.captured_piece()) && !pos.captured_piece())
update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq, -quietMoveBonus); update_continuation_histories(ss - 1, pos.piece_on(prevSq), prevSq, -quietMoveBonus);
@ -1819,7 +1770,8 @@ moves_loop: // When in check, search starts here
{ {
// This is our magic formula // This is our magic formula
int push = int((weakness * int(topScore - rootMoves[i].score) int push = int((weakness * int(topScore - rootMoves[i].score)
+ delta * (rng.rand<unsigned>() % int(weakness))) / 128); + delta * (rng.rand<unsigned>() % int(weakness)))
/ 128);
if (rootMoves[i].score + push >= maxScore) if (rootMoves[i].score + push >= maxScore)
{ {
@ -1900,23 +1852,19 @@ string UCI::pv(const Position& pos, Depth depth) {
ss << "\n"; ss << "\n";
ss << "info" ss << "info"
<< " depth " << d << " depth " << d << " seldepth " << rootMoves[i].selDepth << " multipv " << i + 1
<< " seldepth " << rootMoves[i].selDepth
<< " multipv " << i + 1
<< " score " << UCI::value(v); << " score " << UCI::value(v);
if (Options["UCI_ShowWDL"]) if (Options["UCI_ShowWDL"])
ss << UCI::wdl(v, pos.game_ply()); ss << UCI::wdl(v, pos.game_ply());
if (i == pvIdx && !tb && updated) // tablebase- and previous-scores are exact if (i == pvIdx && !tb && updated) // tablebase- and previous-scores are exact
ss << (rootMoves[i].scoreLowerbound ? " lowerbound" : (rootMoves[i].scoreUpperbound ? " upperbound" : "")); ss << (rootMoves[i].scoreLowerbound
? " lowerbound"
: (rootMoves[i].scoreUpperbound ? " upperbound" : ""));
ss << " nodes " << nodesSearched ss << " nodes " << nodesSearched << " nps " << nodesSearched * 1000 / elapsed
<< " nps " << nodesSearched * 1000 / elapsed << " hashfull " << TT.hashfull() << " tbhits " << tbHits << " time " << elapsed << " pv";
<< " hashfull " << TT.hashfull()
<< " tbhits " << tbHits
<< " time " << elapsed
<< " pv";
for (Move m : rootMoves[i].pv) for (Move m : rootMoves[i].pv)
ss << " " << UCI::move(m, pos.is_chess960()); ss << " " << UCI::move(m, pos.is_chess960());

View file

@ -61,12 +61,12 @@ struct Stack {
struct RootMove { struct RootMove {
explicit RootMove(Move m) : pv(1, m) {} explicit RootMove(Move m) :
pv(1, m) {}
bool extract_ponder_from_tt(Position& pos); bool extract_ponder_from_tt(Position& pos);
bool operator==(const Move& m) const { return pv[0] == m; } bool operator==(const Move& m) const { return pv[0] == m; }
bool operator<(const RootMove& m) const { // Sort in descending order bool operator<(const RootMove& m) const { // Sort in descending order
return m.score != score ? m.score < score return m.score != score ? m.score < score : m.previousScore < previousScore;
: m.previousScore < previousScore;
} }
Value score = -VALUE_INFINITE; Value score = -VALUE_INFINITE;
@ -95,9 +95,7 @@ struct LimitsType {
nodes = 0; nodes = 0;
} }
bool use_time_management() const { bool use_time_management() const { return time[WHITE] || time[BLACK]; }
return time[WHITE] || time[BLACK];
}
std::vector<Move> searchmoves; std::vector<Move> searchmoves;
TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime; TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime;

View file

@ -65,13 +65,27 @@ namespace Stockfish {
namespace { namespace {
constexpr int TBPIECES = 7; // Max number of supported pieces constexpr int TBPIECES = 7; // Max number of supported pieces
constexpr int MAX_DTZ = 1 << 18; // Max DTZ supported, large enough to deal with the syzygy TB limit. constexpr int MAX_DTZ =
1 << 18; // Max DTZ supported, large enough to deal with the syzygy TB limit.
enum { BigEndian, LittleEndian }; enum {
enum TBType { WDL, DTZ }; // Used as template parameter BigEndian,
LittleEndian
};
enum TBType {
WDL,
DTZ
}; // Used as template parameter
// Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables
enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, Wide = 16, SingleValue = 128 }; enum TBFlag {
STM = 1,
Mapped = 2,
WinPlies = 4,
LossPlies = 8,
Wide = 16,
SingleValue = 128
};
inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); } inline WDLScore operator-(WDLScore d) { return WDLScore(-int(d)); }
inline Square operator^(Square s, int i) { return Square(int(s) ^ i); } inline Square operator^(Square s, int i) { return Square(int(s) ^ i); }
@ -91,27 +105,22 @@ int LeadPawnsSize[6][4]; // [leadPawnsCnt][FILE_A..FILE_D]
bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; } bool pawns_comp(Square i, Square j) { return MapPawns[i] < MapPawns[j]; }
int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); }
constexpr Value WDL_to_value[] = { constexpr Value WDL_to_value[] = {-VALUE_MATE + MAX_PLY + 1, VALUE_DRAW - 2, VALUE_DRAW,
-VALUE_MATE + MAX_PLY + 1, VALUE_DRAW + 2, VALUE_MATE - MAX_PLY - 1};
VALUE_DRAW - 2,
VALUE_DRAW,
VALUE_DRAW + 2,
VALUE_MATE - MAX_PLY - 1
};
template<typename T, int Half = sizeof(T) / 2, int End = sizeof(T) - 1> template<typename T, int Half = sizeof(T) / 2, int End = sizeof(T) - 1>
inline void swap_endian(T& x) inline void swap_endian(T& x) {
{
static_assert(std::is_unsigned_v<T>, "Argument of swap_endian not unsigned"); static_assert(std::is_unsigned_v<T>, "Argument of swap_endian not unsigned");
uint8_t tmp, *c = (uint8_t*) &x; uint8_t tmp, *c = (uint8_t*) &x;
for (int i = 0; i < Half; ++i) for (int i = 0; i < Half; ++i)
tmp = c[i], c[i] = c[End - i], c[End - i] = tmp; tmp = c[i], c[i] = c[End - i], c[End - i] = tmp;
} }
template<> inline void swap_endian<uint8_t>(uint8_t&) {} template<>
inline void swap_endian<uint8_t>(uint8_t&) {}
template<typename T, int LE> T number(void* addr) template<typename T, int LE>
{ T number(void* addr) {
T v; T v;
if (uintptr_t(addr) & (alignof(T) - 1)) // Unaligned pointer (very rare) if (uintptr_t(addr) & (alignof(T) - 1)) // Unaligned pointer (very rare)
@ -128,14 +137,16 @@ template<typename T, int LE> T number(void* addr)
// like captures and pawn moves but we can easily recover the correct dtz of the // like captures and pawn moves but we can easily recover the correct dtz of the
// previous move if we know the position's WDL score. // previous move if we know the position's WDL score.
int dtz_before_zeroing(WDLScore wdl) { int dtz_before_zeroing(WDLScore wdl) {
return wdl == WDLWin ? 1 : return wdl == WDLWin ? 1
wdl == WDLCursedWin ? 101 : : wdl == WDLCursedWin ? 101
wdl == WDLBlessedLoss ? -101 : : wdl == WDLBlessedLoss ? -101
wdl == WDLLoss ? -1 : 0; : wdl == WDLLoss ? -1
: 0;
} }
// Return the sign of a number (-1, 0, 1) // Return the sign of a number (-1, 0, 1)
template <typename T> int sign_of(T val) { template<typename T>
int sign_of(T val) {
return (T(0) < val) - (val < T(0)); return (T(0) < val) - (val < T(0));
} }
@ -150,15 +161,19 @@ static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes");
using Sym = uint16_t; // Huffman symbol using Sym = uint16_t; // Huffman symbol
struct LR { struct LR {
enum Side { Left, Right }; enum Side {
Left,
Right
};
uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12 uint8_t lr[3]; // The first 12 bits is the left-hand symbol, the second 12
// bits is the right-hand symbol. If the symbol has length 1, // bits is the right-hand symbol. If the symbol has length 1,
// then the left-hand symbol is the stored value. // then the left-hand symbol is the stored value.
template<Side S> template<Side S>
Sym get() { Sym get() {
return S == Left ? ((lr[1] & 0xF) << 8) | lr[0] : return S == Left ? ((lr[1] & 0xF) << 8) | lr[0]
S == Right ? (lr[2] << 4) | (lr[1] >> 4) : (assert(false), Sym(-1)); : S == Right ? (lr[2] << 4) | (lr[1] >> 4)
: (assert(false), Sym(-1));
} }
}; };
@ -275,8 +290,7 @@ public:
#endif #endif
uint8_t* data = (uint8_t*) *baseAddress; uint8_t* data = (uint8_t*) *baseAddress;
constexpr uint8_t Magics[][4] = { { 0xD7, 0x66, 0x0C, 0xA5 }, constexpr uint8_t Magics[][4] = {{0xD7, 0x66, 0x0C, 0xA5}, {0x71, 0xE8, 0x23, 0x5D}};
{ 0x71, 0xE8, 0x23, 0x5D } };
if (memcmp(data, Magics[type == WDL], 4)) if (memcmp(data, Magics[type == WDL], 4))
{ {
@ -318,8 +332,10 @@ struct PairsData {
SparseEntry* sparseIndex; // Partial indices into blockLength[] SparseEntry* sparseIndex; // Partial indices into blockLength[]
size_t sparseIndexSize; // Size of SparseIndex[] table size_t sparseIndexSize; // Size of SparseIndex[] table
uint8_t* data; // Start of Huffman compressed data uint8_t* data; // Start of Huffman compressed data
std::vector<uint64_t> base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l std::vector<uint64_t>
std::vector<uint8_t> symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 base64; // base64[l - min_sym_len] is the 64bit-padded lowest symbol of length l
std::vector<uint8_t>
symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256
Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups Piece pieces[TBPIECES]; // Position pieces: the order of pieces defines the groups
uint64_t groupIdx[TBPIECES + 1]; // Start index used for the encoding of the group's pieces uint64_t groupIdx[TBPIECES + 1]; // Start index used for the encoding of the group's pieces
int groupLen[TBPIECES + 1]; // Number of pieces in a given group: KRKN -> (3, 1) int groupLen[TBPIECES + 1]; // Number of pieces in a given group: KRKN -> (3, 1)
@ -348,11 +364,11 @@ struct TBTable {
uint8_t pawnCount[2]; // [Lead color / other color] uint8_t pawnCount[2]; // [Lead color / other color]
PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0] PairsData items[Sides][4]; // [wtm / btm][FILE_A..FILE_D or 0]
PairsData* get(int stm, int f) { PairsData* get(int stm, int f) { return &items[stm % Sides][hasPawns ? f : 0]; }
return &items[stm % Sides][hasPawns ? f : 0];
}
TBTable() : ready(false), baseAddress(nullptr) {} TBTable() :
ready(false),
baseAddress(nullptr) {}
explicit TBTable(const std::string& code); explicit TBTable(const std::string& code);
explicit TBTable(const TBTable<WDL>& wdl); explicit TBTable(const TBTable<WDL>& wdl);
@ -363,7 +379,8 @@ struct TBTable {
}; };
template<> template<>
TBTable<WDL>::TBTable(const std::string& code) : TBTable() { TBTable<WDL>::TBTable(const std::string& code) :
TBTable() {
StateInfo st; StateInfo st;
Position pos; Position pos;
@ -381,8 +398,7 @@ TBTable<WDL>::TBTable(const std::string& code) : TBTable() {
// Set the leading color. In case both sides have pawns the leading color // Set the leading color. In case both sides have pawns the leading color
// is the side with fewer pawns because this leads to better compression. // is the side with fewer pawns because this leads to better compression.
bool c = !pos.count<PAWN>(BLACK) bool c = !pos.count<PAWN>(BLACK)
|| ( pos.count<PAWN>(WHITE) || (pos.count<PAWN>(WHITE) && pos.count<PAWN>(BLACK) >= pos.count<PAWN>(WHITE));
&& pos.count<PAWN>(BLACK) >= pos.count<PAWN>(WHITE));
pawnCount[0] = pos.count<PAWN>(c ? WHITE : BLACK); pawnCount[0] = pos.count<PAWN>(c ? WHITE : BLACK);
pawnCount[1] = pos.count<PAWN>(c ? BLACK : WHITE); pawnCount[1] = pos.count<PAWN>(c ? BLACK : WHITE);
@ -391,7 +407,8 @@ TBTable<WDL>::TBTable(const std::string& code) : TBTable() {
} }
template<> template<>
TBTable<DTZ>::TBTable(const TBTable<WDL>& wdl) : TBTable() { TBTable<DTZ>::TBTable(const TBTable<WDL>& wdl) :
TBTable() {
// Use the corresponding WDL table to avoid recalculating all from scratch // Use the corresponding WDL table to avoid recalculating all from scratch
key = wdl.key; key = wdl.key;
@ -408,8 +425,7 @@ TBTable<DTZ>::TBTable(const TBTable<WDL>& wdl) : TBTable() {
// at init time, accessed at probe time. // at init time, accessed at probe time.
class TBTables { class TBTables {
struct Entry struct Entry {
{
Key key; Key key;
TBTable<WDL>* wdl; TBTable<WDL>* wdl;
TBTable<DTZ>* dtz; TBTable<DTZ>* dtz;
@ -433,9 +449,11 @@ class TBTables {
Entry entry{key, wdl, dtz}; Entry entry{key, wdl, dtz};
// Ensure last element is empty to avoid overflow when looking up // Ensure last element is empty to avoid overflow when looking up
for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket) { for (uint32_t bucket = homeBucket; bucket < Size + Overflow - 1; ++bucket)
{
Key otherKey = hashTable[bucket].key; Key otherKey = hashTable[bucket].key;
if (otherKey == key || !hashTable[bucket].get<WDL>()) { if (otherKey == key || !hashTable[bucket].get<WDL>())
{
hashTable[bucket] = entry; hashTable[bucket] = entry;
return; return;
} }
@ -443,7 +461,8 @@ class TBTables {
// Robin Hood hashing: If we've probed for longer than this element, // Robin Hood hashing: If we've probed for longer than this element,
// insert here and search for a new spot for the other element instead. // insert here and search for a new spot for the other element instead.
uint32_t otherHomeBucket = uint32_t(otherKey) & (Size - 1); uint32_t otherHomeBucket = uint32_t(otherKey) & (Size - 1);
if (otherHomeBucket > homeBucket) { if (otherHomeBucket > homeBucket)
{
std::swap(entry, hashTable[bucket]); std::swap(entry, hashTable[bucket]);
key = otherKey; key = otherKey;
homeBucket = otherHomeBucket; homeBucket = otherHomeBucket;
@ -456,7 +475,8 @@ class TBTables {
public: public:
template<TBType Type> template<TBType Type>
TBTable<Type>* get(Key key) { TBTable<Type>* get(Key key) {
for (const Entry* entry = &hashTable[uint32_t(key) & (Size - 1)]; ; ++entry) { for (const Entry* entry = &hashTable[uint32_t(key) & (Size - 1)];; ++entry)
{
if (entry->key == key || !entry->get<Type>()) if (entry->key == key || !entry->get<Type>())
return entry->get<Type>(); return entry->get<Type>();
} }
@ -565,7 +585,8 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
// Read the first 64 bits in our block, this is a (truncated) sequence of // Read the first 64 bits in our block, this is a (truncated) sequence of
// unknown number of symbols of unknown length but we know the first one // unknown number of symbols of unknown length but we know the first one
// is at the beginning of this 64-bit sequence. // is at the beginning of this 64-bit sequence.
uint64_t buf64 = number<uint64_t, BigEndian>(ptr); ptr += 2; uint64_t buf64 = number<uint64_t, BigEndian>(ptr);
ptr += 2;
int buf64Size = 64; int buf64Size = 64;
Sym sym; Sym sym;
@ -598,7 +619,8 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
buf64 <<= len; // Consume the just processed symbol buf64 <<= len; // Consume the just processed symbol
buf64Size -= len; buf64Size -= len;
if (buf64Size <= 32) { // Refill the buffer if (buf64Size <= 32)
{ // Refill the buffer
buf64Size += 32; buf64Size += 32;
buf64 |= uint64_t(number<uint32_t, BigEndian>(ptr++)) << (64 - buf64Size); buf64 |= uint64_t(number<uint32_t, BigEndian>(ptr++)) << (64 - buf64Size);
} }
@ -618,7 +640,8 @@ int decompress_pairs(PairsData* d, uint64_t idx) {
// the left side because in Recursive Pairing child symbols are adjacent. // the left side because in Recursive Pairing child symbols are adjacent.
if (offset < d->symlen[left] + 1) if (offset < d->symlen[left] + 1)
sym = left; sym = left;
else { else
{
offset -= d->symlen[left] + 1; offset -= d->symlen[left] + 1;
sym = d->btree[sym].get<LR::Right>(); sym = d->btree[sym].get<LR::Right>();
} }
@ -632,8 +655,7 @@ bool check_dtz_stm(TBTable<WDL>*, int, File) { return true; }
bool check_dtz_stm(TBTable<DTZ>* entry, int stm, File f) { bool check_dtz_stm(TBTable<DTZ>* entry, int stm, File f) {
auto flags = entry->get(stm, f)->flags; auto flags = entry->get(stm, f)->flags;
return (flags & TBFlag::STM) == stm return (flags & TBFlag::STM) == stm || ((entry->key == entry->key2) && !entry->hasPawns);
|| ((entry->key == entry->key2) && !entry->hasPawns);
} }
// DTZ scores are sorted by frequency of occurrence and then assigned the // DTZ scores are sorted by frequency of occurrence and then assigned the
@ -650,7 +672,8 @@ int map_score(TBTable<DTZ>* entry, File f, int value, WDLScore wdl) {
uint8_t* map = entry->map; uint8_t* map = entry->map;
uint16_t* idx = entry->get(0, f)->map_idx; uint16_t* idx = entry->get(0, f)->map_idx;
if (flags & TBFlag::Mapped) { if (flags & TBFlag::Mapped)
{
if (flags & TBFlag::Wide) if (flags & TBFlag::Wide)
value = ((uint16_t*) map)[idx[WDLMap[wdl + 2]] + value]; value = ((uint16_t*) map)[idx[WDLMap[wdl + 2]] + value];
else else
@ -660,8 +683,7 @@ int map_score(TBTable<DTZ>* entry, File f, int value, WDLScore wdl) {
// DTZ tables store distance to zero in number of moves or plies. We // DTZ tables store distance to zero in number of moves or plies. We
// want to return plies, so we have to convert to plies when needed. // want to return plies, so we have to convert to plies when needed.
if ((wdl == WDLWin && !(flags & TBFlag::WinPlies)) if ((wdl == WDLWin && !(flags & TBFlag::WinPlies))
|| (wdl == WDLLoss && !(flags & TBFlag::LossPlies)) || (wdl == WDLLoss && !(flags & TBFlag::LossPlies)) || wdl == WDLCursedWin
|| wdl == WDLCursedWin
|| wdl == WDLBlessedLoss) || wdl == WDLBlessedLoss)
value *= 2; value *= 2;
@ -704,7 +726,8 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// For pawns, TB files store 4 separate tables according if leading pawn is on // For pawns, TB files store 4 separate tables according if leading pawn is on
// file a, b, c or d after reordering. The leading pawn is the one with maximum // file a, b, c or d after reordering. The leading pawn is the one with maximum
// MapPawns[] value, that is the one most toward the edges and with lowest rank. // MapPawns[] value, that is the one most toward the edges and with lowest rank.
if (entry->hasPawns) { if (entry->hasPawns)
{
// In all the 4 tables, pawns are at the beginning of the piece sequence and // In all the 4 tables, pawns are at the beginning of the piece sequence and
// their color is the reference one. So we just pick the first one. // their color is the reference one. So we just pick the first one.
@ -733,7 +756,8 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// Now we are ready to get all the position pieces (but the lead pawns) and // Now we are ready to get all the position pieces (but the lead pawns) and
// directly map them to the correct color and square. // directly map them to the correct color and square.
b = pos.pieces() ^ leadPawns; b = pos.pieces() ^ leadPawns;
do { do
{
Square s = pop_lsb(b); Square s = pop_lsb(b);
squares[size] = s ^ flipSquares; squares[size] = s ^ flipSquares;
pieces[size++] = Piece(pos.piece_on(s) ^ flipColor); pieces[size++] = Piece(pos.piece_on(s) ^ flipColor);
@ -762,7 +786,8 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// Encode leading pawns starting with the one with minimum MapPawns[] and // Encode leading pawns starting with the one with minimum MapPawns[] and
// proceeding in ascending order. // proceeding in ascending order.
if (entry->hasPawns) { if (entry->hasPawns)
{
idx = LeadPawnIdx[leadPawnsCnt][squares[0]]; idx = LeadPawnIdx[leadPawnsCnt][squares[0]];
std::stable_sort(squares + 1, squares + leadPawnsCnt, pawns_comp); std::stable_sort(squares + 1, squares + leadPawnsCnt, pawns_comp);
@ -781,7 +806,8 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// Look for the first piece of the leading group not on the A1-D4 diagonal // Look for the first piece of the leading group not on the A1-D4 diagonal
// and ensure it is mapped below the diagonal. // and ensure it is mapped below the diagonal.
for (int i = 0; i < d->groupLen[0]; ++i) { for (int i = 0; i < d->groupLen[0]; ++i)
{
if (!off_A1H8(squares[i])) if (!off_A1H8(squares[i]))
continue; continue;
@ -818,7 +844,8 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// //
// In case we have at least 3 unique pieces (including kings) we encode them // In case we have at least 3 unique pieces (including kings) we encode them
// together. // together.
if (entry->hasUniquePieces) { if (entry->hasUniquePieces)
{
int adjust1 = squares[1] > squares[0]; int adjust1 = squares[1] > squares[0];
int adjust2 = (squares[2] > squares[0]) + (squares[2] > squares[1]); int adjust2 = (squares[2] > squares[0]) + (squares[2] > squares[1]);
@ -827,32 +854,26 @@ Ret do_probe_table(const Position& pos, T* entry, WDLScore wdl, ProbeState* resu
// triangle to 0...5. There are 63 squares for second piece and and 62 // triangle to 0...5. There are 63 squares for second piece and and 62
// (mapped to 0...61) for the third. // (mapped to 0...61) for the third.
if (off_A1H8(squares[0])) if (off_A1H8(squares[0]))
idx = ( MapA1D1D4[squares[0]] * 63 idx = (MapA1D1D4[squares[0]] * 63 + (squares[1] - adjust1)) * 62 + squares[2] - adjust2;
+ (squares[1] - adjust1)) * 62
+ squares[2] - adjust2;
// First piece is on a1-h8 diagonal, second below: map this occurrence to // First piece is on a1-h8 diagonal, second below: map this occurrence to
// 6 to differentiate from the above case, rank_of() maps a1-d4 diagonal // 6 to differentiate from the above case, rank_of() maps a1-d4 diagonal
// to 0...3 and finally MapB1H1H7[] maps the b1-h1-h7 triangle to 0..27. // to 0...3 and finally MapB1H1H7[] maps the b1-h1-h7 triangle to 0..27.
else if (off_A1H8(squares[1])) else if (off_A1H8(squares[1]))
idx = ( 6 * 63 + rank_of(squares[0]) * 28 idx = (6 * 63 + rank_of(squares[0]) * 28 + MapB1H1H7[squares[1]]) * 62 + squares[2]
+ MapB1H1H7[squares[1]]) * 62 - adjust2;
+ squares[2] - adjust2;
// First two pieces are on a1-h8 diagonal, third below // First two pieces are on a1-h8 diagonal, third below
else if (off_A1H8(squares[2])) else if (off_A1H8(squares[2]))
idx = 6 * 63 * 62 + 4 * 28 * 62 idx = 6 * 63 * 62 + 4 * 28 * 62 + rank_of(squares[0]) * 7 * 28
+ rank_of(squares[0]) * 7 * 28 + (rank_of(squares[1]) - adjust1) * 28 + MapB1H1H7[squares[2]];
+ (rank_of(squares[1]) - adjust1) * 28
+ MapB1H1H7[squares[2]];
// All 3 pieces on the diagonal a1-h8 // All 3 pieces on the diagonal a1-h8
else else
idx = 6 * 63 * 62 + 4 * 28 * 62 + 4 * 7 * 28 idx = 6 * 63 * 62 + 4 * 28 * 62 + 4 * 7 * 28 + rank_of(squares[0]) * 7 * 6
+ rank_of(squares[0]) * 7 * 6 + (rank_of(squares[1]) - adjust1) * 6 + (rank_of(squares[2]) - adjust2);
+ (rank_of(squares[1]) - adjust1) * 6 }
+ (rank_of(squares[2]) - adjust2); else
} else
// We don't have at least 3 unique pieces, like in KRRvKBB, just map // We don't have at least 3 unique pieces, like in KRRvKBB, just map
// the kings. // the kings.
idx = MapKK[MapA1D1D4[squares[0]]][squares[1]]; idx = MapKK[MapA1D1D4[squares[0]]][squares[1]];
@ -933,8 +954,7 @@ void set_groups(T& e, PairsData* d, int order[], File f) {
if (k == order[0]) // Leading pawns or pieces if (k == order[0]) // Leading pawns or pieces
{ {
d->groupIdx[0] = idx; d->groupIdx[0] = idx;
idx *= e.hasPawns ? LeadPawnsSize[d->groupLen[0]][f] idx *= e.hasPawns ? LeadPawnsSize[d->groupLen[0]][f] : e.hasUniquePieces ? 31332 : 462;
: e.hasUniquePieces ? 31332 : 462;
} }
else if (k == order[1]) // Remaining pawns else if (k == order[1]) // Remaining pawns
{ {
@ -977,7 +997,8 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
d->flags = *data++; d->flags = *data++;
if (d->flags & TBFlag::SingleValue) { if (d->flags & TBFlag::SingleValue)
{
d->blocksNum = d->blockLengthSize = 0; d->blocksNum = d->blockLengthSize = 0;
d->span = d->sparseIndexSize = 0; // Broken MSVC zero-init d->span = d->sparseIndexSize = 0; // Broken MSVC zero-init
d->minSymLen = *data++; // Here we store the single value d->minSymLen = *data++; // Here we store the single value
@ -992,7 +1013,8 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
d->span = 1ULL << *data++; d->span = 1ULL << *data++;
d->sparseIndexSize = size_t((tbSize + d->span - 1) / d->span); // Round up d->sparseIndexSize = size_t((tbSize + d->span - 1) / d->span); // Round up
auto padding = number<uint8_t, LittleEndian>(data++); auto padding = number<uint8_t, LittleEndian>(data++);
d->blocksNum = number<uint32_t, LittleEndian>(data); data += sizeof(uint32_t); d->blocksNum = number<uint32_t, LittleEndian>(data);
data += sizeof(uint32_t);
d->blockLengthSize = d->blocksNum + padding; // Padded to ensure SparseIndex[] d->blockLengthSize = d->blocksNum + padding; // Padded to ensure SparseIndex[]
// does not point out of range. // does not point out of range.
d->maxSymLen = *data++; d->maxSymLen = *data++;
@ -1012,9 +1034,11 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
// avoiding unsigned overflow warnings. // avoiding unsigned overflow warnings.
int base64_size = static_cast<int>(d->base64.size()); int base64_size = static_cast<int>(d->base64.size());
for (int i = base64_size - 2; i >= 0; --i) { for (int i = base64_size - 2; i >= 0; --i)
{
d->base64[i] = (d->base64[i + 1] + number<Sym, LittleEndian>(&d->lowestSym[i]) d->base64[i] = (d->base64[i + 1] + number<Sym, LittleEndian>(&d->lowestSym[i])
- number<Sym, LittleEndian>(&d->lowestSym[i + 1])) / 2; - number<Sym, LittleEndian>(&d->lowestSym[i + 1]))
/ 2;
assert(d->base64[i] * 2 >= d->base64[i + 1]); assert(d->base64[i] * 2 >= d->base64[i + 1]);
} }
@ -1027,7 +1051,8 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data) {
d->base64[i] <<= 64 - i - d->minSymLen; // Right-padding to 64 bits d->base64[i] <<= 64 - i - d->minSymLen; // Right-padding to 64 bits
data += base64_size * sizeof(Sym); data += base64_size * sizeof(Sym);
d->symlen.resize(number<uint16_t, LittleEndian>(data)); data += sizeof(uint16_t); d->symlen.resize(number<uint16_t, LittleEndian>(data));
data += sizeof(uint16_t);
d->btree = (LR*) data; d->btree = (LR*) data;
// The compression scheme used is "Recursive Pairing", that replaces the most // The compression scheme used is "Recursive Pairing", that replaces the most
@ -1050,18 +1075,24 @@ uint8_t* set_dtz_map(TBTable<DTZ>& e, uint8_t* data, File maxFile) {
e.map = data; e.map = data;
for (File f = FILE_A; f <= maxFile; ++f) { for (File f = FILE_A; f <= maxFile; ++f)
{
auto flags = e.get(0, f)->flags; auto flags = e.get(0, f)->flags;
if (flags & TBFlag::Mapped) { if (flags & TBFlag::Mapped)
if (flags & TBFlag::Wide) { {
if (flags & TBFlag::Wide)
{
data += uintptr_t(data) & 1; // Word alignment, we may have a mixed table data += uintptr_t(data) & 1; // Word alignment, we may have a mixed table
for (int i = 0; i < 4; ++i) { // Sequence like 3,x,x,x,1,x,0,2,x,x for (int i = 0; i < 4; ++i)
{ // Sequence like 3,x,x,x,1,x,0,2,x,x
e.get(0, f)->map_idx[i] = uint16_t((uint16_t*) data - (uint16_t*) e.map + 1); e.get(0, f)->map_idx[i] = uint16_t((uint16_t*) data - (uint16_t*) e.map + 1);
data += 2 * number<uint16_t, LittleEndian>(data) + 2; data += 2 * number<uint16_t, LittleEndian>(data) + 2;
} }
} }
else { else
for (int i = 0; i < 4; ++i) { {
for (int i = 0; i < 4; ++i)
{
e.get(0, f)->map_idx[i] = uint16_t(data - e.map + 1); e.get(0, f)->map_idx[i] = uint16_t(data - e.map + 1);
data += *data + 1; data += *data + 1;
} }
@ -1079,7 +1110,10 @@ void set(T& e, uint8_t* data) {
PairsData* d; PairsData* d;
enum { Split = 1, HasPawns = 2 }; enum {
Split = 1,
HasPawns = 2
};
assert(e.hasPawns == bool(*data & HasPawns)); assert(e.hasPawns == bool(*data & HasPawns));
assert((e.key != e.key2) == bool(*data & Split)); assert((e.key != e.key2) == bool(*data & Split));
@ -1093,7 +1127,8 @@ void set(T& e, uint8_t* data) {
assert(!pp || e.pawnCount[0]); assert(!pp || e.pawnCount[0]);
for (File f = FILE_A; f <= maxFile; ++f) { for (File f = FILE_A; f <= maxFile; ++f)
{
for (int i = 0; i < sides; i++) for (int i = 0; i < sides; i++)
*e.get(i, f) = PairsData(); *e.get(i, f) = PairsData();
@ -1119,19 +1154,22 @@ void set(T& e, uint8_t* data) {
data = set_dtz_map(e, data, maxFile); data = set_dtz_map(e, data, maxFile);
for (File f = FILE_A; f <= maxFile; ++f) for (File f = FILE_A; f <= maxFile; ++f)
for (int i = 0; i < sides; i++) { for (int i = 0; i < sides; i++)
{
(d = e.get(i, f))->sparseIndex = (SparseEntry*) data; (d = e.get(i, f))->sparseIndex = (SparseEntry*) data;
data += d->sparseIndexSize * sizeof(SparseEntry); data += d->sparseIndexSize * sizeof(SparseEntry);
} }
for (File f = FILE_A; f <= maxFile; ++f) for (File f = FILE_A; f <= maxFile; ++f)
for (int i = 0; i < sides; i++) { for (int i = 0; i < sides; i++)
{
(d = e.get(i, f))->blockLength = (uint16_t*) data; (d = e.get(i, f))->blockLength = (uint16_t*) data;
data += d->blockLengthSize * sizeof(uint16_t); data += d->blockLengthSize * sizeof(uint16_t);
} }
for (File f = FILE_A; f <= maxFile; ++f) for (File f = FILE_A; f <= maxFile; ++f)
for (int i = 0; i < sides; i++) { for (int i = 0; i < sides; i++)
{
data = (uint8_t*) ((uintptr_t(data) + 0x3F) & ~0x3F); // 64 byte alignment data = (uint8_t*) ((uintptr_t(data) + 0x3F) & ~0x3F); // 64 byte alignment
(d = e.get(i, f))->data = data; (d = e.get(i, f))->data = data;
data += d->blocksNum * d->sizeofBlock; data += d->blocksNum * d->sizeofBlock;
@ -1159,13 +1197,14 @@ void* mapped(TBTable<Type>& e, const Position& pos) {
// Pieces strings in decreasing order for each color, like ("KPP","KR") // Pieces strings in decreasing order for each color, like ("KPP","KR")
std::string fname, w, b; std::string fname, w, b;
for (PieceType pt = KING; pt >= PAWN; --pt) { for (PieceType pt = KING; pt >= PAWN; --pt)
{
w += std::string(popcount(pos.pieces(WHITE, pt)), PieceToChar[pt]); w += std::string(popcount(pos.pieces(WHITE, pt)), PieceToChar[pt]);
b += std::string(popcount(pos.pieces(BLACK, pt)), PieceToChar[pt]); b += std::string(popcount(pos.pieces(BLACK, pt)), PieceToChar[pt]);
} }
fname = (e.key == pos.material_key() ? w + 'v' + b : b + 'v' + w) fname =
+ (Type == WDL ? ".rtbw" : ".rtbz"); (e.key == pos.material_key() ? w + 'v' + b : b + 'v' + w) + (Type == WDL ? ".rtbw" : ".rtbz");
uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, Type); uint8_t* data = TBFile(fname).map(&e.baseAddress, &e.mapping, Type);
@ -1214,8 +1253,7 @@ WDLScore search(Position& pos, ProbeState* result) {
for (const Move move : moveList) for (const Move move : moveList)
{ {
if ( !pos.capture(move) if (!pos.capture(move) && (!CheckZeroingMoves || type_of(pos.moved_piece(move)) != PAWN))
&& (!CheckZeroingMoves || type_of(pos.moved_piece(move)) != PAWN))
continue; continue;
moveCount++; moveCount++;
@ -1259,8 +1297,7 @@ WDLScore search(Position& pos, ProbeState* result) {
// DTZ stores a "don't care" value if bestValue is a win // DTZ stores a "don't care" value if bestValue is a win
if (bestValue >= value) if (bestValue >= value)
return *result = ( bestValue > WDLDraw return *result = (bestValue > WDLDraw || noMoreMoves ? ZEROING_BEST_MOVE : OK), bestValue;
|| noMoreMoves ? ZEROING_BEST_MOVE : OK), bestValue;
return *result = OK, value; return *result = OK, value;
} }
@ -1333,8 +1370,8 @@ void Tablebases::init(const std::string& paths) {
for (int n = 1; n < 64; n++) // Squares for (int n = 1; n < 64; n++) // Squares
for (int k = 0; k < 6 && k <= n; ++k) // Pieces for (int k = 0; k < 6 && k <= n; ++k) // Pieces
Binomial[k][n] = (k > 0 ? Binomial[k - 1][n - 1] : 0) Binomial[k][n] =
+ (k < n ? Binomial[k ][n - 1] : 0); (k > 0 ? Binomial[k - 1][n - 1] : 0) + (k < n ? Binomial[k][n - 1] : 0);
// MapPawns[s] encodes squares a2-h7 to 0..47. This is the number of possible // MapPawns[s] encodes squares a2-h7 to 0..47. This is the number of possible
// available squares when the leading one is in 's'. Moreover the pawn with // available squares when the leading one is in 's'. Moreover the pawn with
@ -1375,20 +1412,24 @@ void Tablebases::init(const std::string& paths) {
} }
// Add entries in TB tables if the corresponding ".rtbw" file exists // Add entries in TB tables if the corresponding ".rtbw" file exists
for (PieceType p1 = PAWN; p1 < KING; ++p1) { for (PieceType p1 = PAWN; p1 < KING; ++p1)
{
TBTables.add({KING, p1, KING}); TBTables.add({KING, p1, KING});
for (PieceType p2 = PAWN; p2 <= p1; ++p2) { for (PieceType p2 = PAWN; p2 <= p1; ++p2)
{
TBTables.add({KING, p1, p2, KING}); TBTables.add({KING, p1, p2, KING});
TBTables.add({KING, p1, KING, p2}); TBTables.add({KING, p1, KING, p2});
for (PieceType p3 = PAWN; p3 < KING; ++p3) for (PieceType p3 = PAWN; p3 < KING; ++p3)
TBTables.add({KING, p1, p2, KING, p3}); TBTables.add({KING, p1, p2, KING, p3});
for (PieceType p3 = PAWN; p3 <= p2; ++p3) { for (PieceType p3 = PAWN; p3 <= p2; ++p3)
{
TBTables.add({KING, p1, p2, p3, KING}); TBTables.add({KING, p1, p2, p3, KING});
for (PieceType p4 = PAWN; p4 <= p3; ++p4) { for (PieceType p4 = PAWN; p4 <= p3; ++p4)
{
TBTables.add({KING, p1, p2, p3, p4, KING}); TBTables.add({KING, p1, p2, p3, p4, KING});
for (PieceType p5 = PAWN; p5 <= p4; ++p5) for (PieceType p5 = PAWN; p5 <= p4; ++p5)
@ -1398,7 +1439,8 @@ void Tablebases::init(const std::string& paths) {
TBTables.add({KING, p1, p2, p3, p4, KING, p5}); TBTables.add({KING, p1, p2, p3, p4, KING, p5});
} }
for (PieceType p4 = PAWN; p4 < KING; ++p4) { for (PieceType p4 = PAWN; p4 < KING; ++p4)
{
TBTables.add({KING, p1, p2, p3, KING, p4}); TBTables.add({KING, p1, p2, p3, KING, p4});
for (PieceType p5 = PAWN; p5 <= p4; ++p5) for (PieceType p5 = PAWN; p5 <= p4; ++p5)
@ -1491,8 +1533,7 @@ int Tablebases::probe_dtz(Position& pos, ProbeState* result) {
// otherwise we will get the dtz of the next move sequence. Search the // otherwise we will get the dtz of the next move sequence. Search the
// position after the move to get the score sign (because even in a // position after the move to get the score sign (because even in a
// winning position we could make a losing capture or go for a draw). // winning position we could make a losing capture or go for a draw).
dtz = zeroing ? -dtz_before_zeroing(search<false>(pos, result)) dtz = zeroing ? -dtz_before_zeroing(search<false>(pos, result)) : -probe_dtz(pos, result);
: -probe_dtz(pos, result);
// If the move mates, force minDTZ to 1 // If the move mates, force minDTZ to 1
if (dtz == 1 && pos.checkers() && MoveList<LEGAL>(pos).size() == 0) if (dtz == 1 && pos.checkers() && MoveList<LEGAL>(pos).size() == 0)
@ -1557,14 +1598,11 @@ bool Tablebases::root_probe(Position& pos, Search::RootMoves& rootMoves) {
{ {
// Otherwise, take dtz for the new position and correct by 1 ply // Otherwise, take dtz for the new position and correct by 1 ply
dtz = -probe_dtz(pos, &result); dtz = -probe_dtz(pos, &result);
dtz = dtz > 0 ? dtz + 1 dtz = dtz > 0 ? dtz + 1 : dtz < 0 ? dtz - 1 : dtz;
: dtz < 0 ? dtz - 1 : dtz;
} }
// Make sure that a mating move is assigned a dtz value of 1 // Make sure that a mating move is assigned a dtz value of 1
if ( pos.checkers() if (pos.checkers() && dtz == 2 && MoveList<LEGAL>(pos).size() == 0)
&& dtz == 2
&& MoveList<LEGAL>(pos).size() == 0)
dtz = 1; dtz = 1;
pos.undo_move(m.pv[0]); pos.undo_move(m.pv[0]);
@ -1625,8 +1663,7 @@ bool Tablebases::root_probe_wdl(Position& pos, Search::RootMoves& rootMoves) {
m.tbRank = WDL_to_rank[wdl + 2]; m.tbRank = WDL_to_rank[wdl + 2];
if (!rule50) if (!rule50)
wdl = wdl > WDLDraw ? WDLWin wdl = wdl > WDLDraw ? WDLWin : wdl < WDLDraw ? WDLLoss : WDLDraw;
: wdl < WDLDraw ? WDLLoss : WDLDraw;
m.tbScore = WDL_to_value[wdl + 2]; m.tbScore = WDL_to_value[wdl + 2];
} }

View file

@ -43,7 +43,9 @@ ThreadPool Threads; // Global object
// Thread constructor launches the thread and waits until it goes to sleep // Thread constructor launches the thread and waits until it goes to sleep
// in idle_loop(). Note that 'searching' and 'exit' should be already set. // in idle_loop(). Note that 'searching' and 'exit' should be already set.
Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) { Thread::Thread(size_t n) :
idx(n),
stdThread(&Thread::idle_loop, this) {
wait_for_search_finished(); wait_for_search_finished();
} }
@ -175,8 +177,10 @@ void ThreadPool::clear() {
// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and // ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and
// returns immediately. Main thread will wake up other threads and start the search. // returns immediately. Main thread will wake up other threads and start the search.
void ThreadPool::start_thinking(Position& pos, StateListPtr& states, void ThreadPool::start_thinking(Position& pos,
const Search::LimitsType& limits, bool ponderMode) { StateListPtr& states,
const Search::LimitsType& limits,
bool ponderMode) {
main()->wait_for_search_finished(); main()->wait_for_search_finished();
@ -249,7 +253,8 @@ Thread* ThreadPool::get_best_thread() const {
&& (votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]] && (votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]]
|| (votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]] || (votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]]
&& thread_value(th) * int(th->rootMoves[0].pv.size() > 2) && thread_value(th) * int(th->rootMoves[0].pv.size() > 2)
> thread_value(bestThread) * int(bestThread->rootMoves[0].pv.size() > 2))))) > thread_value(bestThread)
* int(bestThread->rootMoves[0].pv.size() > 2)))))
bestThread = th; bestThread = th;
return bestThread; return bestThread;

View file

@ -36,8 +36,7 @@ namespace Stockfish {
static const size_t TH_STACK_SIZE = 8 * 1024 * 1024; static const size_t TH_STACK_SIZE = 8 * 1024 * 1024;
template<class T, class P = std::pair<T*, void (T::*)()>> template<class T, class P = std::pair<T*, void (T::*)()>>
void* start_routine(void* ptr) void* start_routine(void* ptr) {
{
P* p = reinterpret_cast<P*>(ptr); P* p = reinterpret_cast<P*>(ptr);
(p->first->*(p->second))(); // Call member function pointer (p->first->*(p->second))(); // Call member function pointer
delete p; delete p;

View file

@ -69,8 +69,8 @@ void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) {
int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50; int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50;
// Make sure timeLeft is > 0 since we may use it as a divisor // Make sure timeLeft is > 0 since we may use it as a divisor
TimePoint timeLeft = std::max(TimePoint(1), TimePoint timeLeft = std::max(TimePoint(1), limits.time[us] + limits.inc[us] * (mtg - 1)
limits.time[us] + limits.inc[us] * (mtg - 1) - moveOverhead * (2 + mtg)); - moveOverhead * (2 + mtg));
// Use extra time with larger increments // Use extra time with larger increments
double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12); double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12);
@ -93,14 +93,14 @@ void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) {
// x moves in y seconds (+ z increment) // x moves in y seconds (+ z increment)
else else
{ {
optScale = std::min((0.88 + ply / 116.4) / mtg, optScale = std::min((0.88 + ply / 116.4) / mtg, 0.88 * limits.time[us] / double(timeLeft));
0.88 * limits.time[us] / double(timeLeft));
maxScale = std::min(6.3, 1.5 + 0.11 * mtg); maxScale = std::min(6.3, 1.5 + 0.11 * mtg);
} }
// Never use more than 80% of the available time for this move // Never use more than 80% of the available time for this move
optimumTime = TimePoint(optScale * timeLeft); optimumTime = TimePoint(optScale * timeLeft);
maximumTime = TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10; maximumTime =
TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10;
if (Options["Ponder"]) if (Options["Ponder"])
optimumTime += optimumTime / 4; optimumTime += optimumTime / 4;

View file

@ -36,8 +36,9 @@ public:
void init(Search::LimitsType& limits, Color us, int ply); void init(Search::LimitsType& limits, Color us, int ply);
TimePoint optimum() const { return optimumTime; } TimePoint optimum() const { return optimumTime; }
TimePoint maximum() const { return maximumTime; } TimePoint maximum() const { return maximumTime; }
TimePoint elapsed() const { return Search::Limits.npmsec ? TimePoint elapsed() const {
TimePoint(Threads.nodes_searched()) : now() - startTime; } return Search::Limits.npmsec ? TimePoint(Threads.nodes_searched()) : now() - startTime;
}
int64_t availableNodes; // When in 'nodes as time' mode int64_t availableNodes; // When in 'nodes as time' mode

View file

@ -43,9 +43,7 @@ void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev)
move16 = uint16_t(m); move16 = uint16_t(m);
// Overwrite less valuable entries (cheapest checks first) // Overwrite less valuable entries (cheapest checks first)
if ( b == BOUND_EXACT if (b == BOUND_EXACT || uint16_t(k) != key16 || d - DEPTH_OFFSET + 2 * pv > depth8 - 4)
|| uint16_t(k) != key16
|| d - DEPTH_OFFSET + 2 * pv > depth8 - 4)
{ {
assert(d > DEPTH_OFFSET); assert(d > DEPTH_OFFSET);
assert(d < 256 + DEPTH_OFFSET); assert(d < 256 + DEPTH_OFFSET);
@ -74,8 +72,7 @@ void TranspositionTable::resize(size_t mbSize) {
table = static_cast<Cluster*>(aligned_large_pages_alloc(clusterCount * sizeof(Cluster))); table = static_cast<Cluster*>(aligned_large_pages_alloc(clusterCount * sizeof(Cluster)));
if (!table) if (!table)
{ {
std::cerr << "Failed to allocate " << mbSize std::cerr << "Failed to allocate " << mbSize << "MB for transposition table." << std::endl;
<< "MB for transposition table." << std::endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
@ -93,7 +90,6 @@ void TranspositionTable::clear() {
for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx) for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx)
{ {
threads.emplace_back([this, idx]() { threads.emplace_back([this, idx]() {
// Thread binding gives faster search on systems with a first-touch policy // Thread binding gives faster search on systems with a first-touch policy
if (Options["Threads"] > 8) if (Options["Threads"] > 8)
WinProcGroup::bindThisThread(idx); WinProcGroup::bindThisThread(idx);
@ -101,8 +97,8 @@ void TranspositionTable::clear() {
// Each thread will zero its part of the hash table // Each thread will zero its part of the hash table
const size_t stride = size_t(clusterCount / Options["Threads"]), const size_t stride = size_t(clusterCount / Options["Threads"]),
start = size_t(stride * idx), start = size_t(stride * idx),
len = idx != size_t(Options["Threads"]) - 1 ? len =
stride : clusterCount - start; idx != size_t(Options["Threads"]) - 1 ? stride : clusterCount - start;
std::memset(&table[start], 0, len * sizeof(Cluster)); std::memset(&table[start], 0, len * sizeof(Cluster));
}); });
@ -128,7 +124,8 @@ TTEntry* TranspositionTable::probe(const Key key, bool& found) const {
for (int i = 0; i < ClusterSize; ++i) for (int i = 0; i < ClusterSize; ++i)
if (tte[i].key16 == key16 || !tte[i].depth8) if (tte[i].key16 == key16 || !tte[i].depth8)
{ {
tte[i].genBound8 = uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh tte[i].genBound8 =
uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh
return found = bool(tte[i].depth8), &tte[i]; return found = bool(tte[i].depth8), &tte[i];
} }
@ -141,8 +138,10 @@ TTEntry* TranspositionTable::probe(const Key key, bool& found) const {
// is needed to keep the unrelated lowest n bits from affecting // is needed to keep the unrelated lowest n bits from affecting
// the result) to calculate the entry age correctly even after // the result) to calculate the entry age correctly even after
// generation8 overflows into the next cycle. // generation8 overflows into the next cycle.
if ( replace->depth8 - ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK) if (replace->depth8
> tte[i].depth8 - ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK)) - ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK)
> tte[i].depth8
- ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK))
replace = &tte[i]; replace = &tte[i];
return found = false, replace; return found = false, replace;
@ -157,7 +156,8 @@ int TranspositionTable::hashfull() const {
int cnt = 0; int cnt = 0;
for (int i = 0; i < 1000; ++i) for (int i = 0; i < 1000; ++i)
for (int j = 0; j < ClusterSize; ++j) for (int j = 0; j < ClusterSize; ++j)
cnt += table[i].entry[j].depth8 && (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8; cnt += table[i].entry[j].depth8
&& (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8;
return cnt / ClusterSize; return cnt / ClusterSize;
} }

View file

@ -79,9 +79,11 @@ class TranspositionTable {
// Constants used to refresh the hash table periodically // Constants used to refresh the hash table periodically
static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things
static constexpr int GENERATION_DELTA = (1 << GENERATION_BITS); // increment for generation field static constexpr int GENERATION_DELTA =
(1 << GENERATION_BITS); // increment for generation field
static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length
static constexpr int GENERATION_MASK = (0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number static constexpr int GENERATION_MASK =
(0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number
public: public:
~TranspositionTable() { aligned_large_pages_free(table); } ~TranspositionTable() { aligned_large_pages_free(table); }

View file

@ -42,7 +42,8 @@ string Tune::next(string& names, bool pop) {
string name; string name;
do { do
{
string token = names.substr(0, names.find(',')); string token = names.substr(0, names.find(','));
if (pop) if (pop)
@ -51,8 +52,7 @@ string Tune::next(string& names, bool pop) {
std::stringstream ws(token); std::stringstream ws(token);
name += (ws >> token, token); // Remove trailing whitespace name += (ws >> token, token); // Remove trailing whitespace
} while ( std::count(name.begin(), name.end(), '(') } while (std::count(name.begin(), name.end(), '(') - std::count(name.begin(), name.end(), ')'));
- std::count(name.begin(), name.end(), ')'));
return name; return name;
} }
@ -76,31 +76,40 @@ static void make_option(const string& n, int v, const SetRange& r) {
LastOption = &Options[n]; LastOption = &Options[n];
// Print formatted parameters, ready to be copy-pasted in Fishtest // Print formatted parameters, ready to be copy-pasted in Fishtest
std::cout << n << "," std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
<< v << ","
<< r(v).first << "," << r(v).second << ","
<< (r(v).second - r(v).first) / 20.0 << "," << (r(v).second - r(v).first) / 20.0 << ","
<< "0.0020" << "0.0020" << std::endl;
<< std::endl;
} }
template<> void Tune::Entry<int>::init_option() { make_option(name, value, range); } template<>
void Tune::Entry<int>::init_option() {
make_option(name, value, range);
}
template<> void Tune::Entry<int>::read_option() { template<>
void Tune::Entry<int>::read_option() {
if (Options.count(name)) if (Options.count(name))
value = int(Options[name]); value = int(Options[name]);
} }
template<> void Tune::Entry<Value>::init_option() { make_option(name, value, range); } template<>
void Tune::Entry<Value>::init_option() {
make_option(name, value, range);
}
template<> void Tune::Entry<Value>::read_option() { template<>
void Tune::Entry<Value>::read_option() {
if (Options.count(name)) if (Options.count(name))
value = Value(int(Options[name])); value = Value(int(Options[name]));
} }
// Instead of a variable here we have a PostUpdate function: just call it // Instead of a variable here we have a PostUpdate function: just call it
template<> void Tune::Entry<Tune::PostUpdate>::init_option() {} template<>
template<> void Tune::Entry<Tune::PostUpdate>::read_option() { value(); } void Tune::Entry<Tune::PostUpdate>::init_option() {}
template<>
void Tune::Entry<Tune::PostUpdate>::read_option() {
value();
}
} // namespace Stockfish } // namespace Stockfish
@ -117,9 +126,7 @@ template<> void Tune::Entry<Tune::PostUpdate>::read_option() { value(); }
namespace Stockfish { namespace Stockfish {
void Tune::read_results() { void Tune::read_results() { /* ...insert your values here... */
/* ...insert your values here... */
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -33,13 +33,14 @@ using Range = std::pair<int, int>; // Option's min-max values
using RangeFun = Range(int); using RangeFun = Range(int);
// Default Range function, to calculate Option's min-max values // Default Range function, to calculate Option's min-max values
inline Range default_range(int v) { inline Range default_range(int v) { return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0); }
return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0);
}
struct SetRange { struct SetRange {
explicit SetRange(RangeFun f) : fun(f) {} explicit SetRange(RangeFun f) :
SetRange(int min, int max) : fun(nullptr), range(min, max) {} fun(f) {}
SetRange(int min, int max) :
fun(nullptr),
range(min, max) {}
Range operator()(int v) const { return fun ? fun(v) : range; } Range operator()(int v) const { return fun ? fun(v) : range; }
RangeFun* fun; RangeFun* fun;
@ -83,7 +84,10 @@ class Tune {
void operator=(const Tune&) = delete; void operator=(const Tune&) = delete;
void read_results(); void read_results();
static Tune& instance() { static Tune t; return t; } // Singleton static Tune& instance() {
static Tune t;
return t;
} // Singleton
// Use polymorphism to accommodate Entry of different types in the same vector // Use polymorphism to accommodate Entry of different types in the same vector
struct EntryBase { struct EntryBase {
@ -97,11 +101,14 @@ class Tune {
static_assert(!std::is_const_v<T>, "Parameter cannot be const!"); static_assert(!std::is_const_v<T>, "Parameter cannot be const!");
static_assert( std::is_same_v<T, int> static_assert(std::is_same_v<T, int> || std::is_same_v<T, Value>
|| std::is_same_v<T, Value> || std::is_same_v<T, PostUpdate>,
|| std::is_same_v<T, PostUpdate>, "Parameter type not supported!"); "Parameter type not supported!");
Entry(const std::string& n, T& v, const SetRange& r) : name(n), value(v), range(r) {} Entry(const std::string& n, T& v, const SetRange& r) :
name(n),
value(v),
range(r) {}
void operator=(const Entry&) = delete; // Because 'value' is a reference void operator=(const Entry&) = delete; // Because 'value' is a reference
void init_option() override; void init_option() override;
void read_option() override; void read_option() override;
@ -143,10 +150,18 @@ class Tune {
public: public:
template<typename... Args> template<typename... Args>
static int add(const std::string& names, Args&&... args) { static int add(const std::string& names, Args&&... args) {
return instance().add(SetDefaultRange, names.substr(1, names.size() - 2), args...); // Remove trailing parenthesis return instance().add(SetDefaultRange, names.substr(1, names.size() - 2),
args...); // Remove trailing parenthesis
}
static void init() {
for (auto& e : instance().list)
e->init_option();
read_options();
} // Deferred, due to UCI::Options access
static void read_options() {
for (auto& e : instance().list)
e->read_option();
} }
static void init() { for (auto& e : instance().list) e->init_option(); read_options(); } // Deferred, due to UCI::Options access
static void read_options() { for (auto& e : instance().list) e->read_option(); }
static bool update_on_last; static bool update_on_last;
}; };

View file

@ -55,7 +55,8 @@
// _WIN32 Building on Windows (any) // _WIN32 Building on Windows (any)
// _WIN64 Building on Windows 64 bit // _WIN64 Building on Windows 64 bit
#if defined(__GNUC__ ) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) && defined(_WIN32) && !defined(__clang__) #if defined(__GNUC__) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) \
&& defined(_WIN32) && !defined(__clang__)
#define ALIGNAS_ON_STACK_VARIABLES_BROKEN #define ALIGNAS_ON_STACK_VARIABLES_BROKEN
#endif #endif
@ -132,7 +133,9 @@ enum MoveType {
}; };
enum Color { enum Color {
WHITE, BLACK, COLOR_NB = 2 WHITE,
BLACK,
COLOR_NB = 2
}; };
enum CastlingRights { enum CastlingRights {
@ -180,6 +183,7 @@ enum Value : int {
QueenValue = 2538, QueenValue = 2538,
}; };
// clang-format off
enum PieceType { enum PieceType {
NO_PIECE_TYPE, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING, NO_PIECE_TYPE, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING,
ALL_PIECES = 0, ALL_PIECES = 0,
@ -192,8 +196,10 @@ enum Piece {
B_PAWN = PAWN + 8, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING, B_PAWN = PAWN + 8, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING,
PIECE_NB = 16 PIECE_NB = 16
}; };
// clang-format on
constexpr Value PieceValue[PIECE_NB] = { VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO, constexpr Value PieceValue[PIECE_NB] = {
VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO,
VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO}; VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO};
using Depth = int; using Depth = int;
@ -208,6 +214,7 @@ enum : int {
DEPTH_OFFSET = -7 // value used only for TT entry occupancy check DEPTH_OFFSET = -7 // value used only for TT entry occupancy check
}; };
// clang-format off
enum Square : int { enum Square : int {
SQ_A1, SQ_B1, SQ_C1, SQ_D1, SQ_E1, SQ_F1, SQ_G1, SQ_H1, SQ_A1, SQ_B1, SQ_C1, SQ_D1, SQ_E1, SQ_F1, SQ_G1, SQ_H1,
SQ_A2, SQ_B2, SQ_C2, SQ_D2, SQ_E2, SQ_F2, SQ_G2, SQ_H2, SQ_A2, SQ_B2, SQ_C2, SQ_D2, SQ_E2, SQ_F2, SQ_G2, SQ_H2,
@ -222,6 +229,7 @@ enum Square : int {
SQUARE_ZERO = 0, SQUARE_ZERO = 0,
SQUARE_NB = 64 SQUARE_NB = 64
}; };
// clang-format on
enum Direction : int { enum Direction : int {
NORTH = 8, NORTH = 8,
@ -236,11 +244,27 @@ enum Direction : int {
}; };
enum File : int { enum File : int {
FILE_A, FILE_B, FILE_C, FILE_D, FILE_E, FILE_F, FILE_G, FILE_H, FILE_NB FILE_A,
FILE_B,
FILE_C,
FILE_D,
FILE_E,
FILE_F,
FILE_G,
FILE_H,
FILE_NB
}; };
enum Rank : int { enum Rank : int {
RANK_1, RANK_2, RANK_3, RANK_4, RANK_5, RANK_6, RANK_7, RANK_8, RANK_NB RANK_1,
RANK_2,
RANK_3,
RANK_4,
RANK_5,
RANK_6,
RANK_7,
RANK_8,
RANK_NB
}; };
// Keep track of what a move changes on the board (used by NNUE) // Keep track of what a move changes on the board (used by NNUE)
@ -317,62 +341,36 @@ constexpr CastlingRights operator&(Color c, CastlingRights cr) {
return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr); return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr);
} }
constexpr Value mate_in(int ply) { constexpr Value mate_in(int ply) { return VALUE_MATE - ply; }
return VALUE_MATE - ply;
}
constexpr Value mated_in(int ply) { constexpr Value mated_in(int ply) { return -VALUE_MATE + ply; }
return -VALUE_MATE + ply;
}
constexpr Square make_square(File f, Rank r) { constexpr Square make_square(File f, Rank r) { return Square((r << 3) + f); }
return Square((r << 3) + f);
}
constexpr Piece make_piece(Color c, PieceType pt) { constexpr Piece make_piece(Color c, PieceType pt) { return Piece((c << 3) + pt); }
return Piece((c << 3) + pt);
}
constexpr PieceType type_of(Piece pc) { constexpr PieceType type_of(Piece pc) { return PieceType(pc & 7); }
return PieceType(pc & 7);
}
inline Color color_of(Piece pc) { inline Color color_of(Piece pc) {
assert(pc != NO_PIECE); assert(pc != NO_PIECE);
return Color(pc >> 3); return Color(pc >> 3);
} }
constexpr bool is_ok(Move m) { constexpr bool is_ok(Move m) { return m != MOVE_NONE && m != MOVE_NULL; }
return m != MOVE_NONE && m != MOVE_NULL;
}
constexpr bool is_ok(Square s) { constexpr bool is_ok(Square s) { return s >= SQ_A1 && s <= SQ_H8; }
return s >= SQ_A1 && s <= SQ_H8;
}
constexpr File file_of(Square s) { constexpr File file_of(Square s) { return File(s & 7); }
return File(s & 7);
}
constexpr Rank rank_of(Square s) { constexpr Rank rank_of(Square s) { return Rank(s >> 3); }
return Rank(s >> 3);
}
constexpr Square relative_square(Color c, Square s) { constexpr Square relative_square(Color c, Square s) { return Square(s ^ (c * 56)); }
return Square(s ^ (c * 56));
}
constexpr Rank relative_rank(Color c, Rank r) { constexpr Rank relative_rank(Color c, Rank r) { return Rank(r ^ (c * 7)); }
return Rank(r ^ (c * 7));
}
constexpr Rank relative_rank(Color c, Square s) { constexpr Rank relative_rank(Color c, Square s) { return relative_rank(c, rank_of(s)); }
return relative_rank(c, rank_of(s));
}
constexpr Direction pawn_push(Color c) { constexpr Direction pawn_push(Color c) { return c == WHITE ? NORTH : SOUTH; }
return c == WHITE ? NORTH : SOUTH;
}
constexpr Square from_sq(Move m) { constexpr Square from_sq(Move m) {
assert(is_ok(m)); assert(is_ok(m));
@ -384,21 +382,13 @@ constexpr Square to_sq(Move m) {
return Square(m & 0x3F); return Square(m & 0x3F);
} }
constexpr int from_to(Move m) { constexpr int from_to(Move m) { return m & 0xFFF; }
return m & 0xFFF;
}
constexpr MoveType type_of(Move m) { constexpr MoveType type_of(Move m) { return MoveType(m & (3 << 14)); }
return MoveType(m & (3 << 14));
}
constexpr PieceType promotion_type(Move m) { constexpr PieceType promotion_type(Move m) { return PieceType(((m >> 12) & 3) + KNIGHT); }
return PieceType(((m >> 12) & 3) + KNIGHT);
}
constexpr Move make_move(Square from, Square to) { constexpr Move make_move(Square from, Square to) { return Move((from << 6) + to); }
return Move((from << 6) + to);
}
template<MoveType T> template<MoveType T>
constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) { constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) {

View file

@ -141,18 +141,30 @@ namespace {
while (is >> token) while (is >> token)
limits.searchmoves.push_back(UCI::to_move(pos, token)); limits.searchmoves.push_back(UCI::to_move(pos, token));
else if (token == "wtime") is >> limits.time[WHITE]; else if (token == "wtime")
else if (token == "btime") is >> limits.time[BLACK]; is >> limits.time[WHITE];
else if (token == "winc") is >> limits.inc[WHITE]; else if (token == "btime")
else if (token == "binc") is >> limits.inc[BLACK]; is >> limits.time[BLACK];
else if (token == "movestogo") is >> limits.movestogo; else if (token == "winc")
else if (token == "depth") is >> limits.depth; is >> limits.inc[WHITE];
else if (token == "nodes") is >> limits.nodes; else if (token == "binc")
else if (token == "movetime") is >> limits.movetime; is >> limits.inc[BLACK];
else if (token == "mate") is >> limits.mate; else if (token == "movestogo")
else if (token == "perft") is >> limits.perft; is >> limits.movestogo;
else if (token == "infinite") limits.infinite = 1; else if (token == "depth")
else if (token == "ponder") ponderMode = true; is >> limits.depth;
else if (token == "nodes")
is >> limits.nodes;
else if (token == "movetime")
is >> limits.movetime;
else if (token == "mate")
is >> limits.mate;
else if (token == "perft")
is >> limits.perft;
else if (token == "infinite")
limits.infinite = 1;
else if (token == "ponder")
ponderMode = true;
Threads.start_thinking(pos, states, limits, ponderMode); Threads.start_thinking(pos, states, limits, ponderMode);
} }
@ -168,7 +180,8 @@ namespace {
uint64_t num, nodes = 0, cnt = 1; uint64_t num, nodes = 0, cnt = 1;
std::vector<std::string> list = setup_bench(pos, args); std::vector<std::string> list = setup_bench(pos, args);
num = count_if(list.begin(), list.end(), [](const std::string& s) { return s.find("go ") == 0 || s.find("eval") == 0; }); num = count_if(list.begin(), list.end(),
[](const std::string& s) { return s.find("go ") == 0 || s.find("eval") == 0; });
TimePoint elapsed = now(); TimePoint elapsed = now();
@ -179,7 +192,8 @@ namespace {
if (token == "go" || token == "eval") if (token == "go" || token == "eval")
{ {
std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")" << std::endl; std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")"
<< std::endl;
if (token == "go") if (token == "go")
{ {
go(pos, is, states); go(pos, is, states);
@ -189,9 +203,15 @@ namespace {
else else
trace_eval(pos); trace_eval(pos);
} }
else if (token == "setoption") setoption(is); else if (token == "setoption")
else if (token == "position") position(pos, is, states); setoption(is);
else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take a while else if (token == "position")
position(pos, is, states);
else if (token == "ucinewgame")
{
Search::clear();
elapsed = now();
} // Search::clear() may take a while
} }
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
@ -199,8 +219,7 @@ namespace {
dbg_print(); dbg_print();
std::cerr << "\n===========================" std::cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
<< "\nNodes searched : " << nodes
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
} }
@ -250,8 +269,10 @@ void UCI::loop(int argc, char* argv[]) {
for (int i = 1; i < argc; ++i) for (int i = 1; i < argc; ++i)
cmd += std::string(argv[i]) + " "; cmd += std::string(argv[i]) + " ";
do { do
if (argc == 1 && !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication {
if (argc == 1
&& !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication
cmd = "quit"; cmd = "quit";
std::istringstream is(cmd); std::istringstream is(cmd);
@ -259,8 +280,7 @@ void UCI::loop(int argc, char* argv[]) {
token.clear(); // Avoid a stale if getline() returns nothing or a blank line token.clear(); // Avoid a stale if getline() returns nothing or a blank line
is >> std::skipws >> token; is >> std::skipws >> token;
if ( token == "quit" if (token == "quit" || token == "stop")
|| token == "stop")
Threads.stop = true; Threads.stop = true;
// The GUI sends 'ponderhit' to tell that the user has played the expected move. // The GUI sends 'ponderhit' to tell that the user has played the expected move.
@ -271,23 +291,32 @@ void UCI::loop(int argc, char* argv[]) {
Threads.main()->ponder = false; // Switch to the normal search Threads.main()->ponder = false; // Switch to the normal search
else if (token == "uci") else if (token == "uci")
sync_cout << "id name " << engine_info(true) sync_cout << "id name " << engine_info(true) << "\n"
<< "\n" << Options << Options << "\nuciok" << sync_endl;
<< "\nuciok" << sync_endl;
else if (token == "setoption") setoption(is); else if (token == "setoption")
else if (token == "go") go(pos, is, states); setoption(is);
else if (token == "position") position(pos, is, states); else if (token == "go")
else if (token == "ucinewgame") Search::clear(); go(pos, is, states);
else if (token == "isready") sync_cout << "readyok" << sync_endl; else if (token == "position")
position(pos, is, states);
else if (token == "ucinewgame")
Search::clear();
else if (token == "isready")
sync_cout << "readyok" << sync_endl;
// Add custom non-UCI commands, mainly for debugging purposes. // Add custom non-UCI commands, mainly for debugging purposes.
// These commands must not be used during a search! // These commands must not be used during a search!
else if (token == "flip") pos.flip(); else if (token == "flip")
else if (token == "bench") bench(pos, is, states); pos.flip();
else if (token == "d") sync_cout << pos << sync_endl; else if (token == "bench")
else if (token == "eval") trace_eval(pos); bench(pos, is, states);
else if (token == "compiler") sync_cout << compiler_info() << sync_endl; else if (token == "d")
sync_cout << pos << sync_endl;
else if (token == "eval")
trace_eval(pos);
else if (token == "compiler")
sync_cout << compiler_info() << sync_endl;
else if (token == "export_net") else if (token == "export_net")
{ {
std::optional<std::string> filename; std::optional<std::string> filename;
@ -297,14 +326,17 @@ void UCI::loop(int argc, char* argv[]) {
Eval::NNUE::save_eval(filename); Eval::NNUE::save_eval(filename);
} }
else if (token == "--help" || token == "help" || token == "--license" || token == "license") else if (token == "--help" || token == "help" || token == "--license" || token == "license")
sync_cout << "\nStockfish is a powerful chess engine for playing and analyzing." sync_cout
<< "\nStockfish is a powerful chess engine for playing and analyzing."
"\nIt is released as free software licensed under the GNU GPLv3 License." "\nIt is released as free software licensed under the GNU GPLv3 License."
"\nStockfish is normally used with a graphical user interface (GUI) and implements" "\nStockfish is normally used with a graphical user interface (GUI) and implements"
"\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc." "\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc."
"\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme" "\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme"
"\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n" << sync_endl; "\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n"
<< sync_endl;
else if (!token.empty() && token[0] != '#') else if (!token.empty() && token[0] != '#')
sync_cout << "Unknown command: '" << cmd << "'. Type help for more information." << sync_endl; sync_cout << "Unknown command: '" << cmd << "'. Type help for more information."
<< sync_endl;
} while (token != "quit" && argc == 1); // The command-line arguments are one-shot } while (token != "quit" && argc == 1); // The command-line arguments are one-shot
} }
@ -312,10 +344,7 @@ void UCI::loop(int argc, char* argv[]) {
// Turns a Value to an integer centipawn number, // Turns a Value to an integer centipawn number,
// without treatment of mate and similar special scores. // without treatment of mate and similar special scores.
int UCI::to_cp(Value v) { int UCI::to_cp(Value v) { return 100 * v / UCI::NormalizeToPawnValue; }
return 100 * v / UCI::NormalizeToPawnValue;
}
// UCI::value() converts a Value to a string by adhering to the UCI protocol specification: // UCI::value() converts a Value to a string by adhering to the UCI protocol specification:
// //

View file

@ -105,9 +105,8 @@ std::ostream& operator<<(std::ostream& os, const OptionsMap& om) {
os << " default " << o.defaultValue; os << " default " << o.defaultValue;
if (o.type == "spin") if (o.type == "spin")
os << " default " << int(stof(o.defaultValue)) os << " default " << int(stof(o.defaultValue)) << " min " << o.min << " max "
<< " min " << o.min << o.max;
<< " max " << o.max;
break; break;
} }
@ -118,20 +117,44 @@ std::ostream& operator<<(std::ostream& os, const OptionsMap& om) {
// Option class constructors and conversion operators // Option class constructors and conversion operators
Option::Option(const char* v, OnChange f) : type("string"), min(0), max(0), on_change(f) Option::Option(const char* v, OnChange f) :
{ defaultValue = currentValue = v; } type("string"),
min(0),
max(0),
on_change(f) {
defaultValue = currentValue = v;
}
Option::Option(bool v, OnChange f) : type("check"), min(0), max(0), on_change(f) Option::Option(bool v, OnChange f) :
{ defaultValue = currentValue = (v ? "true" : "false"); } type("check"),
min(0),
max(0),
on_change(f) {
defaultValue = currentValue = (v ? "true" : "false");
}
Option::Option(OnChange f) : type("button"), min(0), max(0), on_change(f) Option::Option(OnChange f) :
{} type("button"),
min(0),
max(0),
on_change(f) {}
Option::Option(double v, int minv, int maxv, OnChange f) : type("spin"), min(minv), max(maxv), on_change(f) Option::Option(double v, int minv, int maxv, OnChange f) :
{ defaultValue = currentValue = std::to_string(v); } type("spin"),
min(minv),
max(maxv),
on_change(f) {
defaultValue = currentValue = std::to_string(v);
}
Option::Option(const char* v, const char* cur, OnChange f) : type("combo"), min(0), max(0), on_change(f) Option::Option(const char* v, const char* cur, OnChange f) :
{ defaultValue = v; currentValue = cur; } type("combo"),
min(0),
max(0),
on_change(f) {
defaultValue = v;
currentValue = cur;
}
Option::operator int() const { Option::operator int() const {
assert(type == "check" || type == "spin"); assert(type == "check" || type == "spin");
@ -145,8 +168,7 @@ Option::operator std::string() const {
bool Option::operator==(const char* s) const { bool Option::operator==(const char* s) const {
assert(type == "combo"); assert(type == "combo");
return !CaseInsensitiveLess()(currentValue, s) return !CaseInsensitiveLess()(currentValue, s) && !CaseInsensitiveLess()(s, currentValue);
&& !CaseInsensitiveLess()(s, currentValue);
} }