1
0
Fork 0
mirror of https://github.com/sockspls/badfish synced 2025-06-28 00:19:50 +00:00

add clang-format

This introduces clang-format to enforce a consistent code style for Stockfish.

Having a documented and consistent style across the code will make contributing easier
for new developers, and will make larger changes to the codebase easier to make.

To facilitate formatting, this PR includes a Makefile target (`make format`) to format the code,
this requires clang-format (version 17 currently) to be installed locally.

Installing clang-format is straightforward on most OS and distros
(e.g. with https://apt.llvm.org/, brew install clang-format, etc), as this is part of quite commonly
used suite of tools and compilers (llvm / clang).

Additionally, a CI action is present that will verify if the code requires formatting,
and comment on the PR as needed. Initially, correct formatting is not required, it will be
done by maintainers as part of the merge or in later commits, but obviously this is encouraged.

fixes https://github.com/official-stockfish/Stockfish/issues/3608
closes https://github.com/official-stockfish/Stockfish/pull/4790

Co-Authored-By: Joost VandeVondele <Joost.VandeVondele@gmail.com>
This commit is contained in:
Disservin 2023-10-21 11:40:56 +02:00 committed by Joost VandeVondele
parent 8366ec48ae
commit 2d0237db3f
49 changed files with 6403 additions and 6197 deletions

44
.clang-format Normal file
View file

@ -0,0 +1,44 @@
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: Consecutive
AlignConsecutiveDeclarations: Consecutive
AlignEscapedNewlines: DontAlign
AlignOperands: AlignAfterOperator
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortCaseLabelsOnASingleLine: false
AllowShortEnumsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AlwaysBreakTemplateDeclarations: Yes
BasedOnStyle: WebKit
BitFieldColonSpacing: After
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeBraces: Custom
BraceWrapping:
AfterFunction: false
AfterClass: false
AfterControlStatement: true
BeforeElse: true
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakStringLiterals: false
ColumnLimit: 100
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
IndentGotoLabels: false
IndentPPDirectives: BeforeHash
IndentWidth: 4
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
PackConstructorInitializers: Never
ReflowComments: false
SortIncludes: false
SortUsingDeclarations: false
SpaceAfterCStyleCast: true
SpaceAfterTemplateKeyword: false
SpaceBeforeCaseColon: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeInheritanceColon: false
SpaceInEmptyBlock: false
SpacesBeforeTrailingComments: 2

View file

@ -0,0 +1,51 @@
# This workflow will run clang-format and comment on the PR.
# Because of security reasons, it is crucial that this workflow
# executes no shell script nor runs make.
# Read this before editing: https://securitylab.github.com/research/github-actions-preventing-pwn-requests/
name: Stockfish
on:
pull_request_target:
branches:
- 'master'
paths:
- '**.cpp'
- '**.h'
jobs:
Stockfish:
name: clang-format check
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Run clang-format style check
uses: jidicula/clang-format-action@f62da5e3d3a2d88ff364771d9d938773a618ab5e
id: clang-format
continue-on-error: true
with:
clang-format-version: '17'
exclude-regex: 'incbin'
- name: Comment on PR
if: steps.clang-format.outcome == 'failure'
uses: thollander/actions-comment-pull-request@1d3973dc4b8e1399c0620d3f2b1aa5e795465308
with:
message: |
clang-format 17 needs to be run on this PR.
If you do not have clang-format installed, the maintainer will run it when merging.
For the exact version please see https://packages.ubuntu.com/mantic/clang-format-17.
_(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_
comment_tag: execution
- name: Comment on PR
if: steps.clang-format.outcome != 'failure'
uses: thollander/actions-comment-pull-request@1d3973dc4b8e1399c0620d3f2b1aa5e795465308
with:
message: |
_(execution **${{ github.run_id }}** / attempt **${{ github.run_attempt }}**)_
create_if_not_exists: false
comment_tag: execution
mode: delete

View file

@ -57,8 +57,9 @@ discussion._
## Code Style ## Code Style
We do not have a strict code style. But it is best to stick to the existing Changes to Stockfish C++ code should respect our coding style defined by
style of the file you are editing. [.clang-format](.clang-format). You can format your changes by running
`make format`. This requires clang-format version 17 to be installed on your system.
## Community and Communication ## Community and Communication

View file

@ -57,6 +57,14 @@ SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp
HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
nnue/evaluate_nnue.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
nnue/layers/affine_transform_sparse_input.h nnue/layers/clipped_relu.h nnue/layers/simd.h \
nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
tt.h tune.h types.h uci.h
OBJS = $(notdir $(SRCS:.cpp=.o)) OBJS = $(notdir $(SRCS:.cpp=.o))
VPATH = syzygy:nnue:nnue/features VPATH = syzygy:nnue:nnue/features
@ -145,6 +153,12 @@ dotprod = no
arm_version = 0 arm_version = 0
STRIP = strip STRIP = strip
ifneq ($(shell command -v clang-format-17),)
CLANG-FORMAT = clang-format-17
else
CLANG-FORMAT = clang-format
endif
### 2.2 Architecture specific ### 2.2 Architecture specific
ifeq ($(findstring x86,$(ARCH)),x86) ifeq ($(findstring x86,$(ARCH)),x86)
@ -936,6 +950,9 @@ net: netvariables
fi; \ fi; \
fi; \ fi; \
format:
$(CLANG-FORMAT) -i $(SRCS) $(HEADERS) -style=file:../.clang-format
# default target # default target
default: default:
help help

View file

@ -27,6 +27,7 @@
namespace { namespace {
// clang-format off
const std::vector<std::string> Defaults = { const std::vector<std::string> Defaults = {
"setoption name UCI_Chess960 value false", "setoption name UCI_Chess960 value false",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
@ -90,8 +91,9 @@ const std::vector<std::string> Defaults = {
"nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1", "nqbnrkrb/pppppppp/8/8/8/8/PPPPPPPP/NQBNRKRB w KQkq - 0 1",
"setoption name UCI_Chess960 value false" "setoption name UCI_Chess960 value false"
}; };
// clang-format on
} // namespace } // namespace
namespace Stockfish { namespace Stockfish {
@ -109,56 +111,56 @@ namespace Stockfish {
std::vector<std::string> setup_bench(const Position& current, std::istream& is) { std::vector<std::string> setup_bench(const Position& current, std::istream& is) {
std::vector<std::string> fens, list; std::vector<std::string> fens, list;
std::string go, token; std::string go, token;
// Assign default values to missing arguments // Assign default values to missing arguments
std::string ttSize = (is >> token) ? token : "16"; std::string ttSize = (is >> token) ? token : "16";
std::string threads = (is >> token) ? token : "1"; std::string threads = (is >> token) ? token : "1";
std::string limit = (is >> token) ? token : "13"; std::string limit = (is >> token) ? token : "13";
std::string fenFile = (is >> token) ? token : "default"; std::string fenFile = (is >> token) ? token : "default";
std::string limitType = (is >> token) ? token : "depth"; std::string limitType = (is >> token) ? token : "depth";
go = limitType == "eval" ? "eval" : "go " + limitType + " " + limit; go = limitType == "eval" ? "eval" : "go " + limitType + " " + limit;
if (fenFile == "default") if (fenFile == "default")
fens = Defaults; fens = Defaults;
else if (fenFile == "current") else if (fenFile == "current")
fens.push_back(current.fen()); fens.push_back(current.fen());
else else
{ {
std::string fen; std::string fen;
std::ifstream file(fenFile); std::ifstream file(fenFile);
if (!file.is_open()) if (!file.is_open())
{ {
std::cerr << "Unable to open file " << fenFile << std::endl; std::cerr << "Unable to open file " << fenFile << std::endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
while (getline(file, fen)) while (getline(file, fen))
if (!fen.empty()) if (!fen.empty())
fens.push_back(fen); fens.push_back(fen);
file.close(); file.close();
} }
list.emplace_back("setoption name Threads value " + threads); list.emplace_back("setoption name Threads value " + threads);
list.emplace_back("setoption name Hash value " + ttSize); list.emplace_back("setoption name Hash value " + ttSize);
list.emplace_back("ucinewgame"); list.emplace_back("ucinewgame");
for (const std::string& fen : fens) for (const std::string& fen : fens)
if (fen.find("setoption") != std::string::npos) if (fen.find("setoption") != std::string::npos)
list.emplace_back(fen); list.emplace_back(fen);
else else
{ {
list.emplace_back("position fen " + fen); list.emplace_back("position fen " + fen);
list.emplace_back(go); list.emplace_back(go);
} }
return list; return list;
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -29,6 +29,6 @@ class Position;
std::vector<std::string> setup_bench(const Position&, std::istream&); std::vector<std::string> setup_bench(const Position&, std::istream&);
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef BENCHMARK_H_INCLUDED #endif // #ifndef BENCHMARK_H_INCLUDED

View file

@ -39,10 +39,10 @@ Magic BishopMagics[SQUARE_NB];
namespace { namespace {
Bitboard RookTable[0x19000]; // To store rook attacks Bitboard RookTable[0x19000]; // To store rook attacks
Bitboard BishopTable[0x1480]; // To store bishop attacks Bitboard BishopTable[0x1480]; // To store bishop attacks
void init_magics(PieceType pt, Bitboard table[], Magic magics[]); void init_magics(PieceType pt, Bitboard table[], Magic magics[]);
} }
@ -60,18 +60,18 @@ inline Bitboard safe_destination(Square s, int step) {
std::string Bitboards::pretty(Bitboard b) { std::string Bitboards::pretty(Bitboard b) {
std::string s = "+---+---+---+---+---+---+---+---+\n"; std::string s = "+---+---+---+---+---+---+---+---+\n";
for (Rank r = RANK_8; r >= RANK_1; --r) for (Rank r = RANK_8; r >= RANK_1; --r)
{ {
for (File f = FILE_A; f <= FILE_H; ++f) for (File f = FILE_A; f <= FILE_H; ++f)
s += b & make_square(f, r) ? "| X " : "| "; s += b & make_square(f, r) ? "| X " : "| ";
s += "| " + std::to_string(1 + r) + "\n+---+---+---+---+---+---+---+---+\n"; s += "| " + std::to_string(1 + r) + "\n+---+---+---+---+---+---+---+---+\n";
} }
s += " a b c d e f g h\n"; s += " a b c d e f g h\n";
return s; return s;
} }
@ -80,49 +80,50 @@ std::string Bitboards::pretty(Bitboard b) {
void Bitboards::init() { void Bitboards::init() {
for (unsigned i = 0; i < (1 << 16); ++i) for (unsigned i = 0; i < (1 << 16); ++i)
PopCnt16[i] = uint8_t(std::bitset<16>(i).count()); PopCnt16[i] = uint8_t(std::bitset<16>(i).count());
for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1)
for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2)
SquareDistance[s1][s2] = std::max(distance<File>(s1, s2), distance<Rank>(s1, s2)); SquareDistance[s1][s2] = std::max(distance<File>(s1, s2), distance<Rank>(s1, s2));
init_magics(ROOK, RookTable, RookMagics); init_magics(ROOK, RookTable, RookMagics);
init_magics(BISHOP, BishopTable, BishopMagics); init_magics(BISHOP, BishopTable, BishopMagics);
for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1) for (Square s1 = SQ_A1; s1 <= SQ_H8; ++s1)
{ {
PawnAttacks[WHITE][s1] = pawn_attacks_bb<WHITE>(square_bb(s1)); PawnAttacks[WHITE][s1] = pawn_attacks_bb<WHITE>(square_bb(s1));
PawnAttacks[BLACK][s1] = pawn_attacks_bb<BLACK>(square_bb(s1)); PawnAttacks[BLACK][s1] = pawn_attacks_bb<BLACK>(square_bb(s1));
for (int step : {-9, -8, -7, -1, 1, 7, 8, 9} ) for (int step : {-9, -8, -7, -1, 1, 7, 8, 9})
PseudoAttacks[KING][s1] |= safe_destination(s1, step); PseudoAttacks[KING][s1] |= safe_destination(s1, step);
for (int step : {-17, -15, -10, -6, 6, 10, 15, 17} ) for (int step : {-17, -15, -10, -6, 6, 10, 15, 17})
PseudoAttacks[KNIGHT][s1] |= safe_destination(s1, step); PseudoAttacks[KNIGHT][s1] |= safe_destination(s1, step);
PseudoAttacks[QUEEN][s1] = PseudoAttacks[BISHOP][s1] = attacks_bb<BISHOP>(s1, 0); PseudoAttacks[QUEEN][s1] = PseudoAttacks[BISHOP][s1] = attacks_bb<BISHOP>(s1, 0);
PseudoAttacks[QUEEN][s1] |= PseudoAttacks[ ROOK][s1] = attacks_bb< ROOK>(s1, 0); PseudoAttacks[QUEEN][s1] |= PseudoAttacks[ROOK][s1] = attacks_bb<ROOK>(s1, 0);
for (PieceType pt : { BISHOP, ROOK }) for (PieceType pt : {BISHOP, ROOK})
for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2) for (Square s2 = SQ_A1; s2 <= SQ_H8; ++s2)
{ {
if (PseudoAttacks[pt][s1] & s2) if (PseudoAttacks[pt][s1] & s2)
{ {
LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2; LineBB[s1][s2] = (attacks_bb(pt, s1, 0) & attacks_bb(pt, s2, 0)) | s1 | s2;
BetweenBB[s1][s2] = (attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1))); BetweenBB[s1][s2] =
} (attacks_bb(pt, s1, square_bb(s2)) & attacks_bb(pt, s2, square_bb(s1)));
BetweenBB[s1][s2] |= s2; }
} BetweenBB[s1][s2] |= s2;
} }
}
} }
namespace { namespace {
Bitboard sliding_attack(PieceType pt, Square sq, Bitboard occupied) { Bitboard sliding_attack(PieceType pt, Square sq, Bitboard occupied) {
Bitboard attacks = 0; Bitboard attacks = 0;
Direction RookDirections[4] = {NORTH, SOUTH, EAST, WEST}; Direction RookDirections[4] = {NORTH, SOUTH, EAST, WEST};
Direction BishopDirections[4] = {NORTH_EAST, SOUTH_EAST, SOUTH_WEST, NORTH_WEST}; Direction BishopDirections[4] = {NORTH_EAST, SOUTH_EAST, SOUTH_WEST, NORTH_WEST};
for (Direction d : (pt == ROOK ? RookDirections : BishopDirections)) for (Direction d : (pt == ROOK ? RookDirections : BishopDirections))
@ -133,22 +134,22 @@ namespace {
} }
return attacks; return attacks;
} }
// init_magics() computes all rook and bishop attacks at startup. Magic // init_magics() computes all rook and bishop attacks at startup. Magic
// bitboards are used to look up attacks of sliding pieces. As a reference see // bitboards are used to look up attacks of sliding pieces. As a reference see
// www.chessprogramming.org/Magic_Bitboards. In particular, here we use the so // www.chessprogramming.org/Magic_Bitboards. In particular, here we use the so
// called "fancy" approach. // called "fancy" approach.
void init_magics(PieceType pt, Bitboard table[], Magic magics[]) { void init_magics(PieceType pt, Bitboard table[], Magic magics[]) {
// Optimal PRNG seeds to pick the correct magics in the shortest time // Optimal PRNG seeds to pick the correct magics in the shortest time
int seeds[][RANK_NB] = { { 8977, 44560, 54343, 38998, 5731, 95205, 104912, 17020 }, int seeds[][RANK_NB] = {{8977, 44560, 54343, 38998, 5731, 95205, 104912, 17020},
{ 728, 10316, 55013, 32803, 12281, 15100, 16645, 255 } }; {728, 10316, 55013, 32803, 12281, 15100, 16645, 255}};
Bitboard occupancy[4096], reference[4096], edges, b; Bitboard occupancy[4096], reference[4096], edges, b;
int epoch[4096] = {}, cnt = 0, size = 0; int epoch[4096] = {}, cnt = 0, size = 0;
for (Square s = SQ_A1; s <= SQ_H8; ++s) for (Square s = SQ_A1; s <= SQ_H8; ++s)
{ {
@ -161,8 +162,8 @@ namespace {
// the number of 1s of the mask. Hence we deduce the size of the shift to // the number of 1s of the mask. Hence we deduce the size of the shift to
// apply to the 64 or 32 bits word to get the index. // apply to the 64 or 32 bits word to get the index.
Magic& m = magics[s]; Magic& m = magics[s];
m.mask = sliding_attack(pt, s, 0) & ~edges; m.mask = sliding_attack(pt, s, 0) & ~edges;
m.shift = (Is64Bit ? 64 : 32) - popcount(m.mask); m.shift = (Is64Bit ? 64 : 32) - popcount(m.mask);
// Set the offset for the attacks table of the square. We have individual // Set the offset for the attacks table of the square. We have individual
// table sizes for each square with "Fancy Magic Bitboards". // table sizes for each square with "Fancy Magic Bitboards".
@ -171,7 +172,8 @@ namespace {
// Use Carry-Rippler trick to enumerate all subsets of masks[s] and // Use Carry-Rippler trick to enumerate all subsets of masks[s] and
// store the corresponding sliding attack bitboard in reference[]. // store the corresponding sliding attack bitboard in reference[].
b = size = 0; b = size = 0;
do { do
{
occupancy[size] = b; occupancy[size] = b;
reference[size] = sliding_attack(pt, s, b); reference[size] = sliding_attack(pt, s, b);
@ -189,9 +191,9 @@ namespace {
// Find a magic for square 's' picking up an (almost) random number // Find a magic for square 's' picking up an (almost) random number
// until we find the one that passes the verification test. // until we find the one that passes the verification test.
for (int i = 0; i < size; ) for (int i = 0; i < size;)
{ {
for (m.magic = 0; popcount((m.magic * m.mask) >> 56) < 6; ) for (m.magic = 0; popcount((m.magic * m.mask) >> 56) < 6;)
m.magic = rng.sparse_rand<Bitboard>(); m.magic = rng.sparse_rand<Bitboard>();
// A good magic must map every possible occupancy to an index that // A good magic must map every possible occupancy to an index that
@ -206,7 +208,7 @@ namespace {
if (epoch[idx] < cnt) if (epoch[idx] < cnt)
{ {
epoch[idx] = cnt; epoch[idx] = cnt;
m.attacks[idx] = reference[i]; m.attacks[idx] = reference[i];
} }
else if (m.attacks[idx] != reference[i]) else if (m.attacks[idx] != reference[i])
@ -214,7 +216,7 @@ namespace {
} }
} }
} }
} }
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -32,10 +32,10 @@ namespace Stockfish {
namespace Bitboards { namespace Bitboards {
void init(); void init();
std::string pretty(Bitboard b); std::string pretty(Bitboard b);
} // namespace Stockfish::Bitboards } // namespace Stockfish::Bitboards
constexpr Bitboard FileABB = 0x0101010101010101ULL; constexpr Bitboard FileABB = 0x0101010101010101ULL;
constexpr Bitboard FileBBB = FileABB << 1; constexpr Bitboard FileBBB = FileABB << 1;
@ -66,85 +66,80 @@ extern Bitboard PawnAttacks[COLOR_NB][SQUARE_NB];
// Magic holds all magic bitboards relevant data for a single square // Magic holds all magic bitboards relevant data for a single square
struct Magic { struct Magic {
Bitboard mask; Bitboard mask;
Bitboard magic; Bitboard magic;
Bitboard* attacks; Bitboard* attacks;
unsigned shift; unsigned shift;
// Compute the attack's index using the 'magic bitboards' approach // Compute the attack's index using the 'magic bitboards' approach
unsigned index(Bitboard occupied) const { unsigned index(Bitboard occupied) const {
if (HasPext) if (HasPext)
return unsigned(pext(occupied, mask)); return unsigned(pext(occupied, mask));
if (Is64Bit) if (Is64Bit)
return unsigned(((occupied & mask) * magic) >> shift); return unsigned(((occupied & mask) * magic) >> shift);
unsigned lo = unsigned(occupied) & unsigned(mask); unsigned lo = unsigned(occupied) & unsigned(mask);
unsigned hi = unsigned(occupied >> 32) & unsigned(mask >> 32); unsigned hi = unsigned(occupied >> 32) & unsigned(mask >> 32);
return (lo * unsigned(magic) ^ hi * unsigned(magic >> 32)) >> shift; return (lo * unsigned(magic) ^ hi * unsigned(magic >> 32)) >> shift;
} }
}; };
extern Magic RookMagics[SQUARE_NB]; extern Magic RookMagics[SQUARE_NB];
extern Magic BishopMagics[SQUARE_NB]; extern Magic BishopMagics[SQUARE_NB];
inline Bitboard square_bb(Square s) { inline Bitboard square_bb(Square s) {
assert(is_ok(s)); assert(is_ok(s));
return (1ULL << s); return (1ULL << s);
} }
// Overloads of bitwise operators between a Bitboard and a Square for testing // Overloads of bitwise operators between a Bitboard and a Square for testing
// whether a given bit is set in a bitboard, and for setting and clearing bits. // whether a given bit is set in a bitboard, and for setting and clearing bits.
inline Bitboard operator&( Bitboard b, Square s) { return b & square_bb(s); } inline Bitboard operator&(Bitboard b, Square s) { return b & square_bb(s); }
inline Bitboard operator|( Bitboard b, Square s) { return b | square_bb(s); } inline Bitboard operator|(Bitboard b, Square s) { return b | square_bb(s); }
inline Bitboard operator^( Bitboard b, Square s) { return b ^ square_bb(s); } inline Bitboard operator^(Bitboard b, Square s) { return b ^ square_bb(s); }
inline Bitboard& operator|=(Bitboard& b, Square s) { return b |= square_bb(s); } inline Bitboard& operator|=(Bitboard& b, Square s) { return b |= square_bb(s); }
inline Bitboard& operator^=(Bitboard& b, Square s) { return b ^= square_bb(s); } inline Bitboard& operator^=(Bitboard& b, Square s) { return b ^= square_bb(s); }
inline Bitboard operator&(Square s, Bitboard b) { return b & s; } inline Bitboard operator&(Square s, Bitboard b) { return b & s; }
inline Bitboard operator|(Square s, Bitboard b) { return b | s; } inline Bitboard operator|(Square s, Bitboard b) { return b | s; }
inline Bitboard operator^(Square s, Bitboard b) { return b ^ s; } inline Bitboard operator^(Square s, Bitboard b) { return b ^ s; }
inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; } inline Bitboard operator|(Square s1, Square s2) { return square_bb(s1) | s2; }
constexpr bool more_than_one(Bitboard b) { constexpr bool more_than_one(Bitboard b) { return b & (b - 1); }
return b & (b - 1);
}
// rank_bb() and file_bb() return a bitboard representing all the squares on // rank_bb() and file_bb() return a bitboard representing all the squares on
// the given file or rank. // the given file or rank.
constexpr Bitboard rank_bb(Rank r) { constexpr Bitboard rank_bb(Rank r) { return Rank1BB << (8 * r); }
return Rank1BB << (8 * r);
}
constexpr Bitboard rank_bb(Square s) { constexpr Bitboard rank_bb(Square s) { return rank_bb(rank_of(s)); }
return rank_bb(rank_of(s));
}
constexpr Bitboard file_bb(File f) { constexpr Bitboard file_bb(File f) { return FileABB << f; }
return FileABB << f;
}
constexpr Bitboard file_bb(Square s) { constexpr Bitboard file_bb(Square s) { return file_bb(file_of(s)); }
return file_bb(file_of(s));
}
// shift() moves a bitboard one or two steps as specified by the direction D // shift() moves a bitboard one or two steps as specified by the direction D
template<Direction D> template<Direction D>
constexpr Bitboard shift(Bitboard b) { constexpr Bitboard shift(Bitboard b) {
return D == NORTH ? b << 8 : D == SOUTH ? b >> 8 return D == NORTH ? b << 8
: D == NORTH+NORTH? b <<16 : D == SOUTH+SOUTH? b >>16 : D == SOUTH ? b >> 8
: D == EAST ? (b & ~FileHBB) << 1 : D == WEST ? (b & ~FileABB) >> 1 : D == NORTH + NORTH ? b << 16
: D == NORTH_EAST ? (b & ~FileHBB) << 9 : D == NORTH_WEST ? (b & ~FileABB) << 7 : D == SOUTH + SOUTH ? b >> 16
: D == SOUTH_EAST ? (b & ~FileHBB) >> 7 : D == SOUTH_WEST ? (b & ~FileABB) >> 9 : D == EAST ? (b & ~FileHBB) << 1
: 0; : D == WEST ? (b & ~FileABB) >> 1
: D == NORTH_EAST ? (b & ~FileHBB) << 9
: D == NORTH_WEST ? (b & ~FileABB) << 7
: D == SOUTH_EAST ? (b & ~FileHBB) >> 7
: D == SOUTH_WEST ? (b & ~FileABB) >> 9
: 0;
} }
@ -153,14 +148,14 @@ constexpr Bitboard shift(Bitboard b) {
template<Color C> template<Color C>
constexpr Bitboard pawn_attacks_bb(Bitboard b) { constexpr Bitboard pawn_attacks_bb(Bitboard b) {
return C == WHITE ? shift<NORTH_WEST>(b) | shift<NORTH_EAST>(b) return C == WHITE ? shift<NORTH_WEST>(b) | shift<NORTH_EAST>(b)
: shift<SOUTH_WEST>(b) | shift<SOUTH_EAST>(b); : shift<SOUTH_WEST>(b) | shift<SOUTH_EAST>(b);
} }
inline Bitboard pawn_attacks_bb(Color c, Square s) { inline Bitboard pawn_attacks_bb(Color c, Square s) {
assert(is_ok(s)); assert(is_ok(s));
return PawnAttacks[c][s]; return PawnAttacks[c][s];
} }
// line_bb() returns a bitboard representing an entire line (from board edge // line_bb() returns a bitboard representing an entire line (from board edge
@ -170,9 +165,9 @@ inline Bitboard pawn_attacks_bb(Color c, Square s) {
inline Bitboard line_bb(Square s1, Square s2) { inline Bitboard line_bb(Square s1, Square s2) {
assert(is_ok(s1) && is_ok(s2)); assert(is_ok(s1) && is_ok(s2));
return LineBB[s1][s2]; return LineBB[s1][s2];
} }
@ -186,26 +181,34 @@ inline Bitboard line_bb(Square s1, Square s2) {
inline Bitboard between_bb(Square s1, Square s2) { inline Bitboard between_bb(Square s1, Square s2) {
assert(is_ok(s1) && is_ok(s2)); assert(is_ok(s1) && is_ok(s2));
return BetweenBB[s1][s2]; return BetweenBB[s1][s2];
} }
// aligned() returns true if the squares s1, s2 and s3 are aligned either on a // aligned() returns true if the squares s1, s2 and s3 are aligned either on a
// straight or on a diagonal line. // straight or on a diagonal line.
inline bool aligned(Square s1, Square s2, Square s3) { inline bool aligned(Square s1, Square s2, Square s3) { return line_bb(s1, s2) & s3; }
return line_bb(s1, s2) & s3;
}
// distance() functions return the distance between x and y, defined as the // distance() functions return the distance between x and y, defined as the
// number of steps for a king in x to reach y. // number of steps for a king in x to reach y.
template<typename T1 = Square> inline int distance(Square x, Square y); template<typename T1 = Square>
template<> inline int distance<File>(Square x, Square y) { return std::abs(file_of(x) - file_of(y)); } inline int distance(Square x, Square y);
template<> inline int distance<Rank>(Square x, Square y) { return std::abs(rank_of(x) - rank_of(y)); } template<>
template<> inline int distance<Square>(Square x, Square y) { return SquareDistance[x][y]; } inline int distance<File>(Square x, Square y) {
return std::abs(file_of(x) - file_of(y));
}
template<>
inline int distance<Rank>(Square x, Square y) {
return std::abs(rank_of(x) - rank_of(y));
}
template<>
inline int distance<Square>(Square x, Square y) {
return SquareDistance[x][y];
}
inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); } inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); }
@ -215,9 +218,9 @@ inline int edge_distance(File f) { return std::min(f, File(FILE_H - f)); }
template<PieceType Pt> template<PieceType Pt>
inline Bitboard attacks_bb(Square s) { inline Bitboard attacks_bb(Square s) {
assert((Pt != PAWN) && (is_ok(s))); assert((Pt != PAWN) && (is_ok(s)));
return PseudoAttacks[Pt][s]; return PseudoAttacks[Pt][s];
} }
@ -228,28 +231,36 @@ inline Bitboard attacks_bb(Square s) {
template<PieceType Pt> template<PieceType Pt>
inline Bitboard attacks_bb(Square s, Bitboard occupied) { inline Bitboard attacks_bb(Square s, Bitboard occupied) {
assert((Pt != PAWN) && (is_ok(s))); assert((Pt != PAWN) && (is_ok(s)));
switch (Pt) switch (Pt)
{ {
case BISHOP: return BishopMagics[s].attacks[BishopMagics[s].index(occupied)]; case BISHOP :
case ROOK : return RookMagics[s].attacks[ RookMagics[s].index(occupied)]; return BishopMagics[s].attacks[BishopMagics[s].index(occupied)];
case QUEEN : return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied); case ROOK :
default : return PseudoAttacks[Pt][s]; return RookMagics[s].attacks[RookMagics[s].index(occupied)];
} case QUEEN :
return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied);
default :
return PseudoAttacks[Pt][s];
}
} }
inline Bitboard attacks_bb(PieceType pt, Square s, Bitboard occupied) { inline Bitboard attacks_bb(PieceType pt, Square s, Bitboard occupied) {
assert((pt != PAWN) && (is_ok(s))); assert((pt != PAWN) && (is_ok(s)));
switch (pt) switch (pt)
{ {
case BISHOP: return attacks_bb<BISHOP>(s, occupied); case BISHOP :
case ROOK : return attacks_bb< ROOK>(s, occupied); return attacks_bb<BISHOP>(s, occupied);
case QUEEN : return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied); case ROOK :
default : return PseudoAttacks[pt][s]; return attacks_bb<ROOK>(s, occupied);
} case QUEEN :
return attacks_bb<BISHOP>(s, occupied) | attacks_bb<ROOK>(s, occupied);
default :
return PseudoAttacks[pt][s];
}
} }
@ -259,16 +270,19 @@ inline int popcount(Bitboard b) {
#ifndef USE_POPCNT #ifndef USE_POPCNT
union { Bitboard bb; uint16_t u[4]; } v = { b }; union {
return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]]; Bitboard bb;
uint16_t u[4];
} v = {b};
return PopCnt16[v.u[0]] + PopCnt16[v.u[1]] + PopCnt16[v.u[2]] + PopCnt16[v.u[3]];
#elif defined(_MSC_VER) #elif defined(_MSC_VER)
return int(_mm_popcnt_u64(b)); return int(_mm_popcnt_u64(b));
#else // Assumed gcc or compatible compiler #else // Assumed gcc or compatible compiler
return __builtin_popcountll(b); return __builtin_popcountll(b);
#endif #endif
} }
@ -279,66 +293,72 @@ inline int popcount(Bitboard b) {
#if defined(__GNUC__) // GCC, Clang, ICX #if defined(__GNUC__) // GCC, Clang, ICX
inline Square lsb(Bitboard b) { inline Square lsb(Bitboard b) {
assert(b); assert(b);
return Square(__builtin_ctzll(b)); return Square(__builtin_ctzll(b));
} }
inline Square msb(Bitboard b) { inline Square msb(Bitboard b) {
assert(b); assert(b);
return Square(63 ^ __builtin_clzll(b)); return Square(63 ^ __builtin_clzll(b));
} }
#elif defined(_MSC_VER) // MSVC #elif defined(_MSC_VER) // MSVC
#ifdef _WIN64 // MSVC, WIN64 #ifdef _WIN64 // MSVC, WIN64
inline Square lsb(Bitboard b) { inline Square lsb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
_BitScanForward64(&idx, b); _BitScanForward64(&idx, b);
return (Square) idx; return (Square) idx;
} }
inline Square msb(Bitboard b) { inline Square msb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
_BitScanReverse64(&idx, b); _BitScanReverse64(&idx, b);
return (Square) idx; return (Square) idx;
} }
#else // MSVC, WIN32 #else // MSVC, WIN32
inline Square lsb(Bitboard b) { inline Square lsb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
if (b & 0xffffffff) { if (b & 0xffffffff)
_BitScanForward(&idx, int32_t(b)); {
return Square(idx); _BitScanForward(&idx, int32_t(b));
} else { return Square(idx);
_BitScanForward(&idx, int32_t(b >> 32)); }
return Square(idx + 32); else
} {
_BitScanForward(&idx, int32_t(b >> 32));
return Square(idx + 32);
}
} }
inline Square msb(Bitboard b) { inline Square msb(Bitboard b) {
assert(b); assert(b);
unsigned long idx; unsigned long idx;
if (b >> 32) { if (b >> 32)
_BitScanReverse(&idx, int32_t(b >> 32)); {
return Square(idx + 32); _BitScanReverse(&idx, int32_t(b >> 32));
} else { return Square(idx + 32);
_BitScanReverse(&idx, int32_t(b)); }
return Square(idx); else
} {
_BitScanReverse(&idx, int32_t(b));
return Square(idx);
}
} }
#endif #endif
#else // Compiler is neither GCC nor MSVC compatible #else // Compiler is neither GCC nor MSVC compatible
#error "Compiler not supported." #error "Compiler not supported."
#endif #endif
@ -346,19 +366,19 @@ inline Square msb(Bitboard b) {
// square of a non-zero bitboard. It is equivalent to square_bb(lsb(bb)). // square of a non-zero bitboard. It is equivalent to square_bb(lsb(bb)).
inline Bitboard least_significant_square_bb(Bitboard b) { inline Bitboard least_significant_square_bb(Bitboard b) {
assert(b); assert(b);
return b & -b; return b & -b;
} }
// pop_lsb() finds and clears the least significant bit in a non-zero bitboard // pop_lsb() finds and clears the least significant bit in a non-zero bitboard
inline Square pop_lsb(Bitboard& b) { inline Square pop_lsb(Bitboard& b) {
assert(b); assert(b);
const Square s = lsb(b); const Square s = lsb(b);
b &= b - 1; b &= b - 1;
return s; return s;
} }
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef BITBOARD_H_INCLUDED #endif // #ifndef BITBOARD_H_INCLUDED

View file

@ -43,11 +43,11 @@
// const unsigned int gEmbeddedNNUESize; // the size of the embedded file // const unsigned int gEmbeddedNNUESize; // the size of the embedded file
// Note that this does not work in Microsoft Visual Studio. // Note that this does not work in Microsoft Visual Studio.
#if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF) #if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF)
INCBIN(EmbeddedNNUE, EvalFileDefaultName); INCBIN(EmbeddedNNUE, EvalFileDefaultName);
#else #else
const unsigned char gEmbeddedNNUEData[1] = {0x0}; const unsigned char gEmbeddedNNUEData[1] = {0x0};
const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1]; const unsigned char* const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1];
const unsigned int gEmbeddedNNUESize = 1; const unsigned int gEmbeddedNNUESize = 1;
#endif #endif
@ -55,27 +55,28 @@ namespace Stockfish {
namespace Eval { namespace Eval {
std::string currentEvalFileName = "None"; std::string currentEvalFileName = "None";
// NNUE::init() tries to load a NNUE network at startup time, or when the engine // NNUE::init() tries to load a NNUE network at startup time, or when the engine
// receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue" // receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue"
// The name of the NNUE network is always retrieved from the EvalFile option. // The name of the NNUE network is always retrieved from the EvalFile option.
// We search the given network in three locations: internally (the default // We search the given network in three locations: internally (the default
// network may be embedded in the binary), in the active working directory and // network may be embedded in the binary), in the active working directory and
// in the engine directory. Distro packagers may define the DEFAULT_NNUE_DIRECTORY // in the engine directory. Distro packagers may define the DEFAULT_NNUE_DIRECTORY
// variable to have the engine search in a special directory in their distro. // variable to have the engine search in a special directory in their distro.
void NNUE::init() { void NNUE::init() {
std::string eval_file = std::string(Options["EvalFile"]); std::string eval_file = std::string(Options["EvalFile"]);
if (eval_file.empty()) if (eval_file.empty())
eval_file = EvalFileDefaultName; eval_file = EvalFileDefaultName;
#if defined(DEFAULT_NNUE_DIRECTORY) #if defined(DEFAULT_NNUE_DIRECTORY)
std::vector<std::string> dirs = { "<internal>" , "" , CommandLine::binaryDirectory , stringify(DEFAULT_NNUE_DIRECTORY) }; std::vector<std::string> dirs = {"<internal>", "", CommandLine::binaryDirectory,
#else stringify(DEFAULT_NNUE_DIRECTORY)};
std::vector<std::string> dirs = { "<internal>" , "" , CommandLine::binaryDirectory }; #else
#endif std::vector<std::string> dirs = {"<internal>", "", CommandLine::binaryDirectory};
#endif
for (const std::string& directory : dirs) for (const std::string& directory : dirs)
if (currentEvalFileName != eval_file) if (currentEvalFileName != eval_file)
@ -90,23 +91,28 @@ namespace Eval {
if (directory == "<internal>" && eval_file == EvalFileDefaultName) if (directory == "<internal>" && eval_file == EvalFileDefaultName)
{ {
// C++ way to prepare a buffer for a memory stream // C++ way to prepare a buffer for a memory stream
class MemoryBuffer : public std::basic_streambuf<char> { class MemoryBuffer: public std::basic_streambuf<char> {
public: MemoryBuffer(char* p, size_t n) { setg(p, p, p + n); setp(p, p + n); } public:
MemoryBuffer(char* p, size_t n) {
setg(p, p, p + n);
setp(p, p + n);
}
}; };
MemoryBuffer buffer(const_cast<char*>(reinterpret_cast<const char*>(gEmbeddedNNUEData)), MemoryBuffer buffer(
size_t(gEmbeddedNNUESize)); const_cast<char*>(reinterpret_cast<const char*>(gEmbeddedNNUEData)),
(void) gEmbeddedNNUEEnd; // Silence warning on unused variable size_t(gEmbeddedNNUESize));
(void) gEmbeddedNNUEEnd; // Silence warning on unused variable
std::istream stream(&buffer); std::istream stream(&buffer);
if (NNUE::load_eval(eval_file, stream)) if (NNUE::load_eval(eval_file, stream))
currentEvalFileName = eval_file; currentEvalFileName = eval_file;
} }
} }
} }
// NNUE::verify() verifies that the last net used was loaded successfully // NNUE::verify() verifies that the last net used was loaded successfully
void NNUE::verify() { void NNUE::verify() {
std::string eval_file = std::string(Options["EvalFile"]); std::string eval_file = std::string(Options["EvalFile"]);
if (eval_file.empty()) if (eval_file.empty())
@ -115,10 +121,14 @@ namespace Eval {
if (currentEvalFileName != eval_file) if (currentEvalFileName != eval_file)
{ {
std::string msg1 = "Network evaluation parameters compatible with the engine must be available."; std::string msg1 =
"Network evaluation parameters compatible with the engine must be available.";
std::string msg2 = "The network file " + eval_file + " was not loaded successfully."; std::string msg2 = "The network file " + eval_file + " was not loaded successfully.";
std::string msg3 = "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file."; std::string msg3 =
std::string msg4 = "The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/" + std::string(EvalFileDefaultName); "The UCI option EvalFile might need to specify the full path, including the directory name, to the network file.";
std::string msg4 =
"The default net can be downloaded from: https://tests.stockfishchess.org/api/nn/"
+ std::string(EvalFileDefaultName);
std::string msg5 = "The engine will be terminated now."; std::string msg5 = "The engine will be terminated now.";
sync_cout << "info string ERROR: " << msg1 << sync_endl; sync_cout << "info string ERROR: " << msg1 << sync_endl;
@ -131,7 +141,7 @@ namespace Eval {
} }
sync_cout << "info string NNUE evaluation using " << eval_file << sync_endl; sync_cout << "info string NNUE evaluation using " << eval_file << sync_endl;
} }
} }
@ -140,8 +150,8 @@ namespace Eval {
// an approximation of the material advantage on the board in terms of pawns. // an approximation of the material advantage on the board in terms of pawns.
Value Eval::simple_eval(const Position& pos, Color c) { Value Eval::simple_eval(const Position& pos, Color c) {
return PawnValue * (pos.count<PAWN>(c) - pos.count<PAWN>(~c)) return PawnValue * (pos.count<PAWN>(c) - pos.count<PAWN>(~c))
+ (pos.non_pawn_material(c) - pos.non_pawn_material(~c)); + (pos.non_pawn_material(c) - pos.non_pawn_material(~c));
} }
@ -150,43 +160,41 @@ Value Eval::simple_eval(const Position& pos, Color c) {
Value Eval::evaluate(const Position& pos) { Value Eval::evaluate(const Position& pos) {
assert(!pos.checkers()); assert(!pos.checkers());
Value v; Value v;
Color stm = pos.side_to_move(); Color stm = pos.side_to_move();
int shuffling = pos.rule50_count(); int shuffling = pos.rule50_count();
int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3); int simpleEval = simple_eval(pos, stm) + (int(pos.key() & 7) - 3);
bool lazy = abs(simpleEval) >= RookValue + KnightValue bool lazy = abs(simpleEval) >= RookValue + KnightValue + 16 * shuffling * shuffling
+ 16 * shuffling * shuffling + abs(pos.this_thread()->bestValue)
+ abs(pos.this_thread()->bestValue) + abs(pos.this_thread()->rootSimpleEval);
+ abs(pos.this_thread()->rootSimpleEval);
if (lazy) if (lazy)
v = Value(simpleEval); v = Value(simpleEval);
else else
{ {
int nnueComplexity; int nnueComplexity;
Value nnue = NNUE::evaluate(pos, true, &nnueComplexity); Value nnue = NNUE::evaluate(pos, true, &nnueComplexity);
Value optimism = pos.this_thread()->optimism[stm]; Value optimism = pos.this_thread()->optimism[stm];
// Blend optimism and eval with nnue complexity and material imbalance // Blend optimism and eval with nnue complexity and material imbalance
optimism += optimism * (nnueComplexity + abs(simpleEval - nnue)) / 512; optimism += optimism * (nnueComplexity + abs(simpleEval - nnue)) / 512;
nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768; nnue -= nnue * (nnueComplexity + abs(simpleEval - nnue)) / 32768;
int npm = pos.non_pawn_material() / 64; int npm = pos.non_pawn_material() / 64;
v = ( nnue * (915 + npm + 9 * pos.count<PAWN>()) v = (nnue * (915 + npm + 9 * pos.count<PAWN>()) + optimism * (154 + npm)) / 1024;
+ optimism * (154 + npm )) / 1024; }
}
// Damp down the evaluation linearly when shuffling // Damp down the evaluation linearly when shuffling
v = v * (200 - shuffling) / 214; v = v * (200 - shuffling) / 214;
// Guarantee evaluation does not hit the tablebase range // Guarantee evaluation does not hit the tablebase range
v = std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1); v = std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1);
return v; return v;
} }
// trace() is like evaluate(), but instead of returning a value, it returns // trace() is like evaluate(), but instead of returning a value, it returns
@ -196,33 +204,33 @@ Value Eval::evaluate(const Position& pos) {
std::string Eval::trace(Position& pos) { std::string Eval::trace(Position& pos) {
if (pos.checkers()) if (pos.checkers())
return "Final evaluation: none (in check)"; return "Final evaluation: none (in check)";
// Reset any global variable used in eval // Reset any global variable used in eval
pos.this_thread()->bestValue = VALUE_ZERO; pos.this_thread()->bestValue = VALUE_ZERO;
pos.this_thread()->rootSimpleEval = VALUE_ZERO; pos.this_thread()->rootSimpleEval = VALUE_ZERO;
pos.this_thread()->optimism[WHITE] = VALUE_ZERO; pos.this_thread()->optimism[WHITE] = VALUE_ZERO;
pos.this_thread()->optimism[BLACK] = VALUE_ZERO; pos.this_thread()->optimism[BLACK] = VALUE_ZERO;
std::stringstream ss; std::stringstream ss;
ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2); ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2);
ss << '\n' << NNUE::trace(pos) << '\n'; ss << '\n' << NNUE::trace(pos) << '\n';
ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15); ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15);
Value v; Value v;
v = NNUE::evaluate(pos, false); v = NNUE::evaluate(pos, false);
v = pos.side_to_move() == WHITE ? v : -v; v = pos.side_to_move() == WHITE ? v : -v;
ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v) << " (white side)\n"; ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v) << " (white side)\n";
v = evaluate(pos); v = evaluate(pos);
v = pos.side_to_move() == WHITE ? v : -v; v = pos.side_to_move() == WHITE ? v : -v;
ss << "Final evaluation " << 0.01 * UCI::to_cp(v) << " (white side)"; ss << "Final evaluation " << 0.01 * UCI::to_cp(v) << " (white side)";
ss << " [with scaled NNUE, ...]"; ss << " [with scaled NNUE, ...]";
ss << "\n"; ss << "\n";
return ss.str(); return ss.str();
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -29,27 +29,27 @@ class Position;
namespace Eval { namespace Eval {
std::string trace(Position& pos); std::string trace(Position& pos);
Value simple_eval(const Position& pos, Color c); Value simple_eval(const Position& pos, Color c);
Value evaluate(const Position& pos); Value evaluate(const Position& pos);
extern std::string currentEvalFileName; extern std::string currentEvalFileName;
// The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue // The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue
// for the build process (profile-build and fishtest) to work. Do not change the // for the build process (profile-build and fishtest) to work. Do not change the
// name of the macro, as it is used in the Makefile. // name of the macro, as it is used in the Makefile.
#define EvalFileDefaultName "nn-0000000000a0.nnue" #define EvalFileDefaultName "nn-0000000000a0.nnue"
namespace NNUE { namespace NNUE {
void init(); void init();
void verify(); void verify();
} // namespace NNUE } // namespace NNUE
} // namespace Eval } // namespace Eval
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef EVALUATE_H_INCLUDED #endif // #ifndef EVALUATE_H_INCLUDED

View file

@ -33,19 +33,19 @@ using namespace Stockfish;
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
std::cout << engine_info() << std::endl; std::cout << engine_info() << std::endl;
CommandLine::init(argc, argv); CommandLine::init(argc, argv);
UCI::init(Options); UCI::init(Options);
Tune::init(); Tune::init();
Bitboards::init(); Bitboards::init();
Position::init(); Position::init();
Threads.set(size_t(Options["Threads"])); Threads.set(size_t(Options["Threads"]));
Search::clear(); // After threads are up Search::clear(); // After threads are up
Eval::NNUE::init(); Eval::NNUE::init();
UCI::loop(argc, argv); UCI::loop(argc, argv);
Threads.set(0); Threads.set(0);
return 0; return 0;
} }

File diff suppressed because it is too large Load diff

View file

@ -33,12 +33,13 @@ namespace Stockfish {
std::string engine_info(bool to_uci = false); std::string engine_info(bool to_uci = false);
std::string compiler_info(); std::string compiler_info();
void prefetch(void* addr); void prefetch(void* addr);
void start_logger(const std::string& fname); void start_logger(const std::string& fname);
void* std_aligned_alloc(size_t alignment, size_t size); void* std_aligned_alloc(size_t alignment, size_t size);
void std_aligned_free(void* ptr); void std_aligned_free(void* ptr);
void* aligned_large_pages_alloc(size_t size); // memory aligned by page size, min alignment: 4096 bytes void* aligned_large_pages_alloc(
void aligned_large_pages_free(void* mem); // nop if mem == nullptr size_t size); // memory aligned by page size, min alignment: 4096 bytes
void aligned_large_pages_free(void* mem); // nop if mem == nullptr
void dbg_hit_on(bool cond, int slot = 0); void dbg_hit_on(bool cond, int slot = 0);
void dbg_mean_of(int64_t value, int slot = 0); void dbg_mean_of(int64_t value, int slot = 0);
@ -46,15 +47,19 @@ void dbg_stdev_of(int64_t value, int slot = 0);
void dbg_correl_of(int64_t value1, int64_t value2, int slot = 0); void dbg_correl_of(int64_t value1, int64_t value2, int slot = 0);
void dbg_print(); void dbg_print();
using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds using TimePoint = std::chrono::milliseconds::rep; // A value in milliseconds
static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits"); static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits");
inline TimePoint now() { inline TimePoint now() {
return std::chrono::duration_cast<std::chrono::milliseconds> return std::chrono::duration_cast<std::chrono::milliseconds>(
(std::chrono::steady_clock::now().time_since_epoch()).count(); std::chrono::steady_clock::now().time_since_epoch())
.count();
} }
enum SyncCout { IO_LOCK, IO_UNLOCK }; enum SyncCout {
IO_LOCK,
IO_UNLOCK
};
std::ostream& operator<<(std::ostream&, SyncCout); std::ostream& operator<<(std::ostream&, SyncCout);
#define sync_cout std::cout << IO_LOCK #define sync_cout std::cout << IO_LOCK
@ -64,34 +69,37 @@ std::ostream& operator<<(std::ostream&, SyncCout);
// align_ptr_up() : get the first aligned element of an array. // align_ptr_up() : get the first aligned element of an array.
// ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes, // ptr must point to an array of size at least `sizeof(T) * N + alignment` bytes,
// where N is the number of elements in the array. // where N is the number of elements in the array.
template <uintptr_t Alignment, typename T> template<uintptr_t Alignment, typename T>
T* align_ptr_up(T* ptr) T* align_ptr_up(T* ptr) {
{ static_assert(alignof(T) < Alignment);
static_assert(alignof(T) < Alignment);
const uintptr_t ptrint = reinterpret_cast<uintptr_t>(reinterpret_cast<char*>(ptr)); const uintptr_t ptrint = reinterpret_cast<uintptr_t>(reinterpret_cast<char*>(ptr));
return reinterpret_cast<T*>(reinterpret_cast<char*>((ptrint + (Alignment - 1)) / Alignment * Alignment)); return reinterpret_cast<T*>(
reinterpret_cast<char*>((ptrint + (Alignment - 1)) / Alignment * Alignment));
} }
// IsLittleEndian : true if and only if the binary is compiled on a little-endian machine // IsLittleEndian : true if and only if the binary is compiled on a little-endian machine
static inline const union { uint32_t i; char c[4]; } Le = { 0x01020304 }; static inline const union {
uint32_t i;
char c[4];
} Le = {0x01020304};
static inline const bool IsLittleEndian = (Le.c[0] == 4); static inline const bool IsLittleEndian = (Le.c[0] == 4);
template <typename T, std::size_t MaxSize> template<typename T, std::size_t MaxSize>
class ValueList { class ValueList {
public: public:
std::size_t size() const { return size_; } std::size_t size() const { return size_; }
void push_back(const T& value) { values_[size_++] = value; } void push_back(const T& value) { values_[size_++] = value; }
const T* begin() const { return values_; } const T* begin() const { return values_; }
const T* end() const { return values_ + size_; } const T* end() const { return values_ + size_; }
const T& operator[](int index) const { return values_[index]; } const T& operator[](int index) const { return values_[index]; }
private: private:
T values_[MaxSize]; T values_[MaxSize];
std::size_t size_ = 0; std::size_t size_ = 0;
}; };
@ -112,23 +120,31 @@ private:
class PRNG { class PRNG {
uint64_t s; uint64_t s;
uint64_t rand64() { uint64_t rand64() {
s ^= s >> 12, s ^= s << 25, s ^= s >> 27; s ^= s >> 12, s ^= s << 25, s ^= s >> 27;
return s * 2685821657736338717LL; return s * 2685821657736338717LL;
} }
public: public:
PRNG(uint64_t seed) : s(seed) { assert(seed); } PRNG(uint64_t seed) :
s(seed) {
assert(seed);
}
template<typename T> T rand() { return T(rand64()); } template<typename T>
T rand() {
return T(rand64());
}
// Special generator used to fast init magic numbers. // Special generator used to fast init magic numbers.
// Output values only have 1/8th of their bits set on average. // Output values only have 1/8th of their bits set on average.
template<typename T> T sparse_rand() template<typename T>
{ return T(rand64() & rand64() & rand64()); } T sparse_rand() {
return T(rand64() & rand64() & rand64());
}
}; };
inline uint64_t mul_hi64(uint64_t a, uint64_t b) { inline uint64_t mul_hi64(uint64_t a, uint64_t b) {
@ -152,16 +168,16 @@ inline uint64_t mul_hi64(uint64_t a, uint64_t b) {
// Peter Österlund. // Peter Österlund.
namespace WinProcGroup { namespace WinProcGroup {
void bindThisThread(size_t idx); void bindThisThread(size_t idx);
} }
namespace CommandLine { namespace CommandLine {
void init(int argc, char* argv[]); void init(int argc, char* argv[]);
extern std::string binaryDirectory; // path of the executable directory extern std::string binaryDirectory; // path of the executable directory
extern std::string workingDirectory; // path of the working directory extern std::string workingDirectory; // path of the working directory
} }
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef MISC_H_INCLUDED #endif // #ifndef MISC_H_INCLUDED

View file

@ -28,8 +28,8 @@ namespace Stockfish {
namespace { namespace {
template<GenType Type, Direction D, bool Enemy> template<GenType Type, Direction D, bool Enemy>
ExtMove* make_promotions(ExtMove* moveList, [[maybe_unused]] Square to) { ExtMove* make_promotions(ExtMove* moveList, [[maybe_unused]] Square to) {
if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS)
{ {
@ -50,33 +50,32 @@ namespace {
} }
return moveList; return moveList;
} }
template<Color Us, GenType Type> template<Color Us, GenType Type>
ExtMove* generate_pawn_moves(const Position& pos, ExtMove* moveList, Bitboard target) { ExtMove* generate_pawn_moves(const Position& pos, ExtMove* moveList, Bitboard target) {
constexpr Color Them = ~Us; constexpr Color Them = ~Us;
constexpr Bitboard TRank7BB = (Us == WHITE ? Rank7BB : Rank2BB); constexpr Bitboard TRank7BB = (Us == WHITE ? Rank7BB : Rank2BB);
constexpr Bitboard TRank3BB = (Us == WHITE ? Rank3BB : Rank6BB); constexpr Bitboard TRank3BB = (Us == WHITE ? Rank3BB : Rank6BB);
constexpr Direction Up = pawn_push(Us); constexpr Direction Up = pawn_push(Us);
constexpr Direction UpRight = (Us == WHITE ? NORTH_EAST : SOUTH_WEST); constexpr Direction UpRight = (Us == WHITE ? NORTH_EAST : SOUTH_WEST);
constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST); constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST);
const Bitboard emptySquares = ~pos.pieces(); const Bitboard emptySquares = ~pos.pieces();
const Bitboard enemies = Type == EVASIONS ? pos.checkers() const Bitboard enemies = Type == EVASIONS ? pos.checkers() : pos.pieces(Them);
: pos.pieces(Them);
Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB; Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB;
Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB; Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB;
// Single and double pawn pushes, no promotions // Single and double pawn pushes, no promotions
if constexpr (Type != CAPTURES) if constexpr (Type != CAPTURES)
{ {
Bitboard b1 = shift<Up>(pawnsNotOn7) & emptySquares; Bitboard b1 = shift<Up>(pawnsNotOn7) & emptySquares;
Bitboard b2 = shift<Up>(b1 & TRank3BB) & emptySquares; Bitboard b2 = shift<Up>(b1 & TRank3BB) & emptySquares;
if constexpr (Type == EVASIONS) // Consider only blocking squares if constexpr (Type == EVASIONS) // Consider only blocking squares
{ {
b1 &= target; b1 &= target;
b2 &= target; b2 &= target;
@ -87,21 +86,21 @@ namespace {
// To make a quiet check, you either make a direct check by pushing a pawn // To make a quiet check, you either make a direct check by pushing a pawn
// or push a blocker pawn that is not on the same file as the enemy king. // or push a blocker pawn that is not on the same file as the enemy king.
// Discovered check promotion has been already generated amongst the captures. // Discovered check promotion has been already generated amongst the captures.
Square ksq = pos.square<KING>(Them); Square ksq = pos.square<KING>(Them);
Bitboard dcCandidatePawns = pos.blockers_for_king(Them) & ~file_bb(ksq); Bitboard dcCandidatePawns = pos.blockers_for_king(Them) & ~file_bb(ksq);
b1 &= pawn_attacks_bb(Them, ksq) | shift< Up>(dcCandidatePawns); b1 &= pawn_attacks_bb(Them, ksq) | shift<Up>(dcCandidatePawns);
b2 &= pawn_attacks_bb(Them, ksq) | shift<Up+Up>(dcCandidatePawns); b2 &= pawn_attacks_bb(Them, ksq) | shift<Up + Up>(dcCandidatePawns);
} }
while (b1) while (b1)
{ {
Square to = pop_lsb(b1); Square to = pop_lsb(b1);
*moveList++ = make_move(to - Up, to); *moveList++ = make_move(to - Up, to);
} }
while (b2) while (b2)
{ {
Square to = pop_lsb(b2); Square to = pop_lsb(b2);
*moveList++ = make_move(to - Up - Up, to); *moveList++ = make_move(to - Up - Up, to);
} }
} }
@ -110,8 +109,8 @@ namespace {
if (pawnsOn7) if (pawnsOn7)
{ {
Bitboard b1 = shift<UpRight>(pawnsOn7) & enemies; Bitboard b1 = shift<UpRight>(pawnsOn7) & enemies;
Bitboard b2 = shift<UpLeft >(pawnsOn7) & enemies; Bitboard b2 = shift<UpLeft>(pawnsOn7) & enemies;
Bitboard b3 = shift<Up >(pawnsOn7) & emptySquares; Bitboard b3 = shift<Up>(pawnsOn7) & emptySquares;
if constexpr (Type == EVASIONS) if constexpr (Type == EVASIONS)
b3 &= target; b3 &= target;
@ -123,24 +122,24 @@ namespace {
moveList = make_promotions<Type, UpLeft, true>(moveList, pop_lsb(b2)); moveList = make_promotions<Type, UpLeft, true>(moveList, pop_lsb(b2));
while (b3) while (b3)
moveList = make_promotions<Type, Up, false>(moveList, pop_lsb(b3)); moveList = make_promotions<Type, Up, false>(moveList, pop_lsb(b3));
} }
// Standard and en passant captures // Standard and en passant captures
if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS) if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS)
{ {
Bitboard b1 = shift<UpRight>(pawnsNotOn7) & enemies; Bitboard b1 = shift<UpRight>(pawnsNotOn7) & enemies;
Bitboard b2 = shift<UpLeft >(pawnsNotOn7) & enemies; Bitboard b2 = shift<UpLeft>(pawnsNotOn7) & enemies;
while (b1) while (b1)
{ {
Square to = pop_lsb(b1); Square to = pop_lsb(b1);
*moveList++ = make_move(to - UpRight, to); *moveList++ = make_move(to - UpRight, to);
} }
while (b2) while (b2)
{ {
Square to = pop_lsb(b2); Square to = pop_lsb(b2);
*moveList++ = make_move(to - UpLeft, to); *moveList++ = make_move(to - UpLeft, to);
} }
@ -162,11 +161,11 @@ namespace {
} }
return moveList; return moveList;
} }
template<Color Us, PieceType Pt, bool Checks> template<Color Us, PieceType Pt, bool Checks>
ExtMove* generate_moves(const Position& pos, ExtMove* moveList, Bitboard target) { ExtMove* generate_moves(const Position& pos, ExtMove* moveList, Bitboard target) {
static_assert(Pt != KING && Pt != PAWN, "Unsupported piece type in generate_moves()"); static_assert(Pt != KING && Pt != PAWN, "Unsupported piece type in generate_moves()");
@ -174,8 +173,8 @@ namespace {
while (bb) while (bb)
{ {
Square from = pop_lsb(bb); Square from = pop_lsb(bb);
Bitboard b = attacks_bb<Pt>(from, pos.pieces()) & target; Bitboard b = attacks_bb<Pt>(from, pos.pieces()) & target;
// To check, you either move freely a blocker or make a direct check. // To check, you either move freely a blocker or make a direct check.
if (Checks && (Pt == QUEEN || !(pos.blockers_for_king(~Us) & from))) if (Checks && (Pt == QUEEN || !(pos.blockers_for_king(~Us) & from)))
@ -186,31 +185,31 @@ namespace {
} }
return moveList; return moveList;
} }
template<Color Us, GenType Type> template<Color Us, GenType Type>
ExtMove* generate_all(const Position& pos, ExtMove* moveList) { ExtMove* generate_all(const Position& pos, ExtMove* moveList) {
static_assert(Type != LEGAL, "Unsupported type in generate_all()"); static_assert(Type != LEGAL, "Unsupported type in generate_all()");
constexpr bool Checks = Type == QUIET_CHECKS; // Reduce template instantiations constexpr bool Checks = Type == QUIET_CHECKS; // Reduce template instantiations
const Square ksq = pos.square<KING>(Us); const Square ksq = pos.square<KING>(Us);
Bitboard target; Bitboard target;
// Skip generating non-king moves when in double check // Skip generating non-king moves when in double check
if (Type != EVASIONS || !more_than_one(pos.checkers())) if (Type != EVASIONS || !more_than_one(pos.checkers()))
{ {
target = Type == EVASIONS ? between_bb(ksq, lsb(pos.checkers())) target = Type == EVASIONS ? between_bb(ksq, lsb(pos.checkers()))
: Type == NON_EVASIONS ? ~pos.pieces( Us) : Type == NON_EVASIONS ? ~pos.pieces(Us)
: Type == CAPTURES ? pos.pieces(~Us) : Type == CAPTURES ? pos.pieces(~Us)
: ~pos.pieces( ); // QUIETS || QUIET_CHECKS : ~pos.pieces(); // QUIETS || QUIET_CHECKS
moveList = generate_pawn_moves<Us, Type>(pos, moveList, target); moveList = generate_pawn_moves<Us, Type>(pos, moveList, target);
moveList = generate_moves<Us, KNIGHT, Checks>(pos, moveList, target); moveList = generate_moves<Us, KNIGHT, Checks>(pos, moveList, target);
moveList = generate_moves<Us, BISHOP, Checks>(pos, moveList, target); moveList = generate_moves<Us, BISHOP, Checks>(pos, moveList, target);
moveList = generate_moves<Us, ROOK, Checks>(pos, moveList, target); moveList = generate_moves<Us, ROOK, Checks>(pos, moveList, target);
moveList = generate_moves<Us, QUEEN, Checks>(pos, moveList, target); moveList = generate_moves<Us, QUEEN, Checks>(pos, moveList, target);
} }
if (!Checks || pos.blockers_for_king(~Us) & ksq) if (!Checks || pos.blockers_for_king(~Us) & ksq)
@ -223,15 +222,15 @@ namespace {
*moveList++ = make_move(ksq, pop_lsb(b)); *moveList++ = make_move(ksq, pop_lsb(b));
if ((Type == QUIETS || Type == NON_EVASIONS) && pos.can_castle(Us & ANY_CASTLING)) if ((Type == QUIETS || Type == NON_EVASIONS) && pos.can_castle(Us & ANY_CASTLING))
for (CastlingRights cr : { Us & KING_SIDE, Us & QUEEN_SIDE } ) for (CastlingRights cr : {Us & KING_SIDE, Us & QUEEN_SIDE})
if (!pos.castling_impeded(cr) && pos.can_castle(cr)) if (!pos.castling_impeded(cr) && pos.can_castle(cr))
*moveList++ = make<CASTLING>(ksq, pos.castling_rook_square(cr)); *moveList++ = make<CASTLING>(ksq, pos.castling_rook_square(cr));
} }
return moveList; return moveList;
} }
} // namespace } // namespace
// <CAPTURES> Generates all pseudo-legal captures plus queen promotions // <CAPTURES> Generates all pseudo-legal captures plus queen promotions
@ -246,13 +245,13 @@ namespace {
template<GenType Type> template<GenType Type>
ExtMove* generate(const Position& pos, ExtMove* moveList) { ExtMove* generate(const Position& pos, ExtMove* moveList) {
static_assert(Type != LEGAL, "Unsupported type in generate()"); static_assert(Type != LEGAL, "Unsupported type in generate()");
assert((Type == EVASIONS) == bool(pos.checkers())); assert((Type == EVASIONS) == bool(pos.checkers()));
Color us = pos.side_to_move(); Color us = pos.side_to_move();
return us == WHITE ? generate_all<WHITE, Type>(pos, moveList) return us == WHITE ? generate_all<WHITE, Type>(pos, moveList)
: generate_all<BLACK, Type>(pos, moveList); : generate_all<BLACK, Type>(pos, moveList);
} }
// Explicit template instantiations // Explicit template instantiations
@ -268,21 +267,21 @@ template ExtMove* generate<NON_EVASIONS>(const Position&, ExtMove*);
template<> template<>
ExtMove* generate<LEGAL>(const Position& pos, ExtMove* moveList) { ExtMove* generate<LEGAL>(const Position& pos, ExtMove* moveList) {
Color us = pos.side_to_move(); Color us = pos.side_to_move();
Bitboard pinned = pos.blockers_for_king(us) & pos.pieces(us); Bitboard pinned = pos.blockers_for_king(us) & pos.pieces(us);
Square ksq = pos.square<KING>(us); Square ksq = pos.square<KING>(us);
ExtMove* cur = moveList; ExtMove* cur = moveList;
moveList = pos.checkers() ? generate<EVASIONS >(pos, moveList) moveList =
: generate<NON_EVASIONS>(pos, moveList); pos.checkers() ? generate<EVASIONS>(pos, moveList) : generate<NON_EVASIONS>(pos, moveList);
while (cur != moveList) while (cur != moveList)
if ( ((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT) if (((pinned & from_sq(*cur)) || from_sq(*cur) == ksq || type_of(*cur) == EN_PASSANT)
&& !pos.legal(*cur)) && !pos.legal(*cur))
*cur = (--moveList)->move; *cur = (--moveList)->move;
else else
++cur; ++cur;
return moveList; return moveList;
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -19,7 +19,7 @@
#ifndef MOVEGEN_H_INCLUDED #ifndef MOVEGEN_H_INCLUDED
#define MOVEGEN_H_INCLUDED #define MOVEGEN_H_INCLUDED
#include <algorithm> // IWYU pragma: keep #include <algorithm> // IWYU pragma: keep
#include <cstddef> #include <cstddef>
#include "types.h" #include "types.h"
@ -29,29 +29,27 @@ namespace Stockfish {
class Position; class Position;
enum GenType { enum GenType {
CAPTURES, CAPTURES,
QUIETS, QUIETS,
QUIET_CHECKS, QUIET_CHECKS,
EVASIONS, EVASIONS,
NON_EVASIONS, NON_EVASIONS,
LEGAL LEGAL
}; };
struct ExtMove { struct ExtMove {
Move move; Move move;
int value; int value;
operator Move() const { return move; } operator Move() const { return move; }
void operator=(Move m) { move = m; } void operator=(Move m) { move = m; }
// Inhibit unwanted implicit conversions to Move // Inhibit unwanted implicit conversions to Move
// with an ambiguity that yields to a compile error. // with an ambiguity that yields to a compile error.
operator float() const = delete; operator float() const = delete;
}; };
inline bool operator<(const ExtMove& f, const ExtMove& s) { inline bool operator<(const ExtMove& f, const ExtMove& s) { return f.value < s.value; }
return f.value < s.value;
}
template<GenType> template<GenType>
ExtMove* generate(const Position& pos, ExtMove* moveList); ExtMove* generate(const Position& pos, ExtMove* moveList);
@ -62,18 +60,17 @@ ExtMove* generate(const Position& pos, ExtMove* moveList);
template<GenType T> template<GenType T>
struct MoveList { struct MoveList {
explicit MoveList(const Position& pos) : last(generate<T>(pos, moveList)) {} explicit MoveList(const Position& pos) :
const ExtMove* begin() const { return moveList; } last(generate<T>(pos, moveList)) {}
const ExtMove* end() const { return last; } const ExtMove* begin() const { return moveList; }
size_t size() const { return last - moveList; } const ExtMove* end() const { return last; }
bool contains(Move move) const { size_t size() const { return last - moveList; }
return std::find(begin(), end(), move) != end(); bool contains(Move move) const { return std::find(begin(), end(), move) != end(); }
}
private: private:
ExtMove moveList[MAX_MOVES], *last; ExtMove moveList[MAX_MOVES], *last;
}; };
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef MOVEGEN_H_INCLUDED #endif // #ifndef MOVEGEN_H_INCLUDED

View file

@ -30,29 +30,50 @@ namespace Stockfish {
namespace { namespace {
enum Stages { enum Stages {
MAIN_TT, CAPTURE_INIT, GOOD_CAPTURE, REFUTATION, QUIET_INIT, QUIET, BAD_CAPTURE, // generate main search moves
EVASION_TT, EVASION_INIT, EVASION, MAIN_TT,
PROBCUT_TT, PROBCUT_INIT, PROBCUT, CAPTURE_INIT,
QSEARCH_TT, QCAPTURE_INIT, QCAPTURE, QCHECK_INIT, QCHECK GOOD_CAPTURE,
}; REFUTATION,
QUIET_INIT,
QUIET,
BAD_CAPTURE,
// partial_insertion_sort() sorts moves in descending order up to and including // generate evasion moves
// a given limit. The order of moves smaller than the limit is left unspecified. EVASION_TT,
void partial_insertion_sort(ExtMove* begin, ExtMove* end, int limit) { EVASION_INIT,
EVASION,
// generate probcut moves
PROBCUT_TT,
PROBCUT_INIT,
PROBCUT,
// generate qsearch moves
QSEARCH_TT,
QCAPTURE_INIT,
QCAPTURE,
QCHECK_INIT,
QCHECK
};
// partial_insertion_sort() sorts moves in descending order up to and including
// a given limit. The order of moves smaller than the limit is left unspecified.
void partial_insertion_sort(ExtMove* begin, ExtMove* end, int limit) {
for (ExtMove *sortedEnd = begin, *p = begin + 1; p < end; ++p) for (ExtMove *sortedEnd = begin, *p = begin + 1; p < end; ++p)
if (p->value >= limit) if (p->value >= limit)
{ {
ExtMove tmp = *p, *q; ExtMove tmp = *p, *q;
*p = *++sortedEnd; *p = *++sortedEnd;
for (q = sortedEnd; q != begin && *(q - 1) < tmp; --q) for (q = sortedEnd; q != begin && *(q - 1) < tmp; --q)
*q = *(q - 1); *q = *(q - 1);
*q = tmp; *q = tmp;
} }
} }
} // namespace } // namespace
// Constructors of the MovePicker class. As arguments, we pass information // Constructors of the MovePicker class. As arguments, we pass information
@ -62,44 +83,57 @@ namespace {
// move ordering is at the current node. // move ordering is at the current node.
// MovePicker constructor for the main search // MovePicker constructor for the main search
MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, MovePicker::MovePicker(const Position& p,
const CapturePieceToHistory* cph, Move ttm,
const PieceToHistory** ch, Depth d,
Move cm, const ButterflyHistory* mh,
const Move* killers) const CapturePieceToHistory* cph,
: pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), const PieceToHistory** ch,
ttMove(ttm), refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}}, depth(d) Move cm,
{ const Move* killers) :
assert(d > 0); pos(p),
mainHistory(mh),
captureHistory(cph),
continuationHistory(ch),
ttMove(ttm),
refutations{{killers[0], 0}, {killers[1], 0}, {cm, 0}},
depth(d) {
assert(d > 0);
stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + stage = (pos.checkers() ? EVASION_TT : MAIN_TT) + !(ttm && pos.pseudo_legal(ttm));
!(ttm && pos.pseudo_legal(ttm));
} }
// MovePicker constructor for quiescence search // MovePicker constructor for quiescence search
MovePicker::MovePicker(const Position& p, Move ttm, Depth d, const ButterflyHistory* mh, MovePicker::MovePicker(const Position& p,
const CapturePieceToHistory* cph, Move ttm,
const PieceToHistory** ch, Depth d,
Square rs) const ButterflyHistory* mh,
: pos(p), mainHistory(mh), captureHistory(cph), continuationHistory(ch), ttMove(ttm), recaptureSquare(rs), depth(d) const CapturePieceToHistory* cph,
{ const PieceToHistory** ch,
assert(d <= 0); Square rs) :
pos(p),
mainHistory(mh),
captureHistory(cph),
continuationHistory(ch),
ttMove(ttm),
recaptureSquare(rs),
depth(d) {
assert(d <= 0);
stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + stage = (pos.checkers() ? EVASION_TT : QSEARCH_TT) + !(ttm && pos.pseudo_legal(ttm));
!( ttm
&& pos.pseudo_legal(ttm));
} }
// MovePicker constructor for ProbCut: we generate captures with SEE greater // MovePicker constructor for ProbCut: we generate captures with SEE greater
// than or equal to the given threshold. // than or equal to the given threshold.
MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePieceToHistory* cph) :
: pos(p), captureHistory(cph), ttMove(ttm), threshold(th) pos(p),
{ captureHistory(cph),
assert(!pos.checkers()); ttMove(ttm),
threshold(th) {
assert(!pos.checkers());
stage = PROBCUT_TT + !(ttm && pos.capture_stage(ttm) stage = PROBCUT_TT
&& pos.pseudo_legal(ttm) + !(ttm && pos.capture_stage(ttm) && pos.pseudo_legal(ttm) && pos.see_ge(ttm, threshold));
&& pos.see_ge(ttm, threshold));
} }
// MovePicker::score() assigns a numerical value to each move in a list, used // MovePicker::score() assigns a numerical value to each move in a list, used
@ -108,76 +142,78 @@ MovePicker::MovePicker(const Position& p, Move ttm, Value th, const CapturePiece
template<GenType Type> template<GenType Type>
void MovePicker::score() { void MovePicker::score() {
static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type"); static_assert(Type == CAPTURES || Type == QUIETS || Type == EVASIONS, "Wrong type");
[[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook, threatenedPieces; [[maybe_unused]] Bitboard threatenedByPawn, threatenedByMinor, threatenedByRook,
if constexpr (Type == QUIETS) threatenedPieces;
{ if constexpr (Type == QUIETS)
Color us = pos.side_to_move(); {
Color us = pos.side_to_move();
threatenedByPawn = pos.attacks_by<PAWN>(~us); threatenedByPawn = pos.attacks_by<PAWN>(~us);
threatenedByMinor = pos.attacks_by<KNIGHT>(~us) | pos.attacks_by<BISHOP>(~us) | threatenedByPawn; threatenedByMinor =
threatenedByRook = pos.attacks_by<ROOK>(~us) | threatenedByMinor; pos.attacks_by<KNIGHT>(~us) | pos.attacks_by<BISHOP>(~us) | threatenedByPawn;
threatenedByRook = pos.attacks_by<ROOK>(~us) | threatenedByMinor;
// Pieces threatened by pieces of lesser material value // Pieces threatened by pieces of lesser material value
threatenedPieces = (pos.pieces(us, QUEEN) & threatenedByRook) threatenedPieces = (pos.pieces(us, QUEEN) & threatenedByRook)
| (pos.pieces(us, ROOK) & threatenedByMinor) | (pos.pieces(us, ROOK) & threatenedByMinor)
| (pos.pieces(us, KNIGHT, BISHOP) & threatenedByPawn); | (pos.pieces(us, KNIGHT, BISHOP) & threatenedByPawn);
} }
for (auto& m : *this) for (auto& m : *this)
if constexpr (Type == CAPTURES) if constexpr (Type == CAPTURES)
m.value = (7 * int(PieceValue[pos.piece_on(to_sq(m))]) m.value =
+ (*captureHistory)[pos.moved_piece(m)][to_sq(m)][type_of(pos.piece_on(to_sq(m)))]) / 16; (7 * int(PieceValue[pos.piece_on(to_sq(m))])
+ (*captureHistory)[pos.moved_piece(m)][to_sq(m)][type_of(pos.piece_on(to_sq(m)))])
/ 16;
else if constexpr (Type == QUIETS) else if constexpr (Type == QUIETS)
{ {
Piece pc = pos.moved_piece(m); Piece pc = pos.moved_piece(m);
PieceType pt = type_of(pos.moved_piece(m)); PieceType pt = type_of(pos.moved_piece(m));
Square from = from_sq(m); Square from = from_sq(m);
Square to = to_sq(m); Square to = to_sq(m);
// histories // histories
m.value = 2 * (*mainHistory)[pos.side_to_move()][from_to(m)]; m.value = 2 * (*mainHistory)[pos.side_to_move()][from_to(m)];
m.value += 2 * (*continuationHistory[0])[pc][to]; m.value += 2 * (*continuationHistory[0])[pc][to];
m.value += (*continuationHistory[1])[pc][to]; m.value += (*continuationHistory[1])[pc][to];
m.value += (*continuationHistory[2])[pc][to] / 4; m.value += (*continuationHistory[2])[pc][to] / 4;
m.value += (*continuationHistory[3])[pc][to]; m.value += (*continuationHistory[3])[pc][to];
m.value += (*continuationHistory[5])[pc][to]; m.value += (*continuationHistory[5])[pc][to];
// bonus for checks // bonus for checks
m.value += bool(pos.check_squares(pt) & to) * 16384; m.value += bool(pos.check_squares(pt) & to) * 16384;
// bonus for escaping from capture // bonus for escaping from capture
m.value += threatenedPieces & from ? m.value += threatenedPieces & from ? (pt == QUEEN && !(to & threatenedByRook) ? 50000
(pt == QUEEN && !(to & threatenedByRook) ? 50000 : pt == ROOK && !(to & threatenedByMinor) ? 25000
: pt == ROOK && !(to & threatenedByMinor) ? 25000 : !(to & threatenedByPawn) ? 15000
: !(to & threatenedByPawn) ? 15000 : 0)
: 0 ) : 0;
: 0 ;
// malus for putting piece en prise // malus for putting piece en prise
m.value -= !(threatenedPieces & from) ? m.value -= !(threatenedPieces & from)
(pt == QUEEN ? bool(to & threatenedByRook) * 50000 ? (pt == QUEEN ? bool(to & threatenedByRook) * 50000
+ bool(to & threatenedByMinor) * 10000 + bool(to & threatenedByMinor) * 10000
+ bool(to & threatenedByPawn) * 20000 + bool(to & threatenedByPawn) * 20000
: pt == ROOK ? bool(to & threatenedByMinor) * 25000 : pt == ROOK ? bool(to & threatenedByMinor) * 25000
+ bool(to & threatenedByPawn) * 10000 + bool(to & threatenedByPawn) * 10000
: pt != PAWN ? bool(to & threatenedByPawn) * 15000 : pt != PAWN ? bool(to & threatenedByPawn) * 15000
: 0 ) : 0)
: 0 ; : 0;
} }
else // Type == EVASIONS else // Type == EVASIONS
{ {
if (pos.capture_stage(m)) if (pos.capture_stage(m))
m.value = PieceValue[pos.piece_on(to_sq(m))] m.value = PieceValue[pos.piece_on(to_sq(m))] - Value(type_of(pos.moved_piece(m)))
- Value(type_of(pos.moved_piece(m))) + (1 << 28);
+ (1 << 28); else
else m.value = (*mainHistory)[pos.side_to_move()][from_to(m)]
m.value = (*mainHistory)[pos.side_to_move()][from_to(m)] + (*continuationHistory[0])[pos.moved_piece(m)][to_sq(m)];
+ (*continuationHistory[0])[pos.moved_piece(m)][to_sq(m)]; }
}
} }
// MovePicker::select() returns the next move satisfying a predicate function. // MovePicker::select() returns the next move satisfying a predicate function.
@ -185,17 +221,17 @@ void MovePicker::score() {
template<MovePicker::PickType T, typename Pred> template<MovePicker::PickType T, typename Pred>
Move MovePicker::select(Pred filter) { Move MovePicker::select(Pred filter) {
while (cur < endMoves) while (cur < endMoves)
{ {
if constexpr (T == Best) if constexpr (T == Best)
std::swap(*cur, *std::max_element(cur, endMoves)); std::swap(*cur, *std::max_element(cur, endMoves));
if (*cur != ttMove && filter()) if (*cur != ttMove && filter())
return *cur++; return *cur++;
cur++; cur++;
} }
return MOVE_NONE; return MOVE_NONE;
} }
// MovePicker::next_move() is the most important method of the MovePicker class. It // MovePicker::next_move() is the most important method of the MovePicker class. It
@ -204,122 +240,126 @@ Move MovePicker::select(Pred filter) {
Move MovePicker::next_move(bool skipQuiets) { Move MovePicker::next_move(bool skipQuiets) {
top: top:
switch (stage) { switch (stage)
{
case MAIN_TT: case MAIN_TT :
case EVASION_TT: case EVASION_TT :
case QSEARCH_TT: case QSEARCH_TT :
case PROBCUT_TT: case PROBCUT_TT :
++stage; ++stage;
return ttMove; return ttMove;
case CAPTURE_INIT: case CAPTURE_INIT :
case PROBCUT_INIT: case PROBCUT_INIT :
case QCAPTURE_INIT: case QCAPTURE_INIT :
cur = endBadCaptures = moves; cur = endBadCaptures = moves;
endMoves = generate<CAPTURES>(pos, cur); endMoves = generate<CAPTURES>(pos, cur);
score<CAPTURES>(); score<CAPTURES>();
partial_insertion_sort(cur, endMoves, std::numeric_limits<int>::min()); partial_insertion_sort(cur, endMoves, std::numeric_limits<int>::min());
++stage; ++stage;
goto top; goto top;
case GOOD_CAPTURE: case GOOD_CAPTURE :
if (select<Next>([&](){ if (select<Next>([&]() {
return pos.see_ge(*cur, Value(-cur->value)) ? return pos.see_ge(*cur, Value(-cur->value))
// Move losing capture to endBadCaptures to be tried later ?
true : (*endBadCaptures++ = *cur, false); })) // Move losing capture to endBadCaptures to be tried later
return *(cur - 1); true
: (*endBadCaptures++ = *cur, false);
}))
return *(cur - 1);
// Prepare the pointers to loop over the refutations array // Prepare the pointers to loop over the refutations array
cur = std::begin(refutations); cur = std::begin(refutations);
endMoves = std::end(refutations); endMoves = std::end(refutations);
// If the countermove is the same as a killer, skip it // If the countermove is the same as a killer, skip it
if ( refutations[0].move == refutations[2].move if (refutations[0].move == refutations[2].move
|| refutations[1].move == refutations[2].move) || refutations[1].move == refutations[2].move)
--endMoves; --endMoves;
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case REFUTATION: case REFUTATION :
if (select<Next>([&](){ return *cur != MOVE_NONE if (select<Next>([&]() {
&& !pos.capture_stage(*cur) return *cur != MOVE_NONE && !pos.capture_stage(*cur) && pos.pseudo_legal(*cur);
&& pos.pseudo_legal(*cur); })) }))
return *(cur - 1); return *(cur - 1);
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case QUIET_INIT: case QUIET_INIT :
if (!skipQuiets) if (!skipQuiets)
{ {
cur = endBadCaptures; cur = endBadCaptures;
endMoves = generate<QUIETS>(pos, cur); endMoves = generate<QUIETS>(pos, cur);
score<QUIETS>(); score<QUIETS>();
partial_insertion_sort(cur, endMoves, -3000 * depth); partial_insertion_sort(cur, endMoves, -3000 * depth);
} }
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case QUIET: case QUIET :
if ( !skipQuiets if (!skipQuiets && select<Next>([&]() {
&& select<Next>([&](){return *cur != refutations[0].move return *cur != refutations[0].move && *cur != refutations[1].move
&& *cur != refutations[1].move && *cur != refutations[2].move;
&& *cur != refutations[2].move;})) }))
return *(cur - 1); return *(cur - 1);
// Prepare the pointers to loop over the bad captures // Prepare the pointers to loop over the bad captures
cur = moves; cur = moves;
endMoves = endBadCaptures; endMoves = endBadCaptures;
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case BAD_CAPTURE: case BAD_CAPTURE :
return select<Next>([](){ return true; }); return select<Next>([]() { return true; });
case EVASION_INIT: case EVASION_INIT :
cur = moves; cur = moves;
endMoves = generate<EVASIONS>(pos, cur); endMoves = generate<EVASIONS>(pos, cur);
score<EVASIONS>(); score<EVASIONS>();
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case EVASION: case EVASION :
return select<Best>([](){ return true; }); return select<Best>([]() { return true; });
case PROBCUT: case PROBCUT :
return select<Next>([&](){ return pos.see_ge(*cur, threshold); }); return select<Next>([&]() { return pos.see_ge(*cur, threshold); });
case QCAPTURE: case QCAPTURE :
if (select<Next>([&](){ return depth > DEPTH_QS_RECAPTURES if (select<Next>(
|| to_sq(*cur) == recaptureSquare; })) [&]() { return depth > DEPTH_QS_RECAPTURES || to_sq(*cur) == recaptureSquare; }))
return *(cur - 1); return *(cur - 1);
// If we did not find any move and we do not try checks, we have finished // If we did not find any move and we do not try checks, we have finished
if (depth != DEPTH_QS_CHECKS) if (depth != DEPTH_QS_CHECKS)
return MOVE_NONE; return MOVE_NONE;
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case QCHECK_INIT: case QCHECK_INIT :
cur = moves; cur = moves;
endMoves = generate<QUIET_CHECKS>(pos, cur); endMoves = generate<QUIET_CHECKS>(pos, cur);
++stage; ++stage;
[[fallthrough]]; [[fallthrough]];
case QCHECK: case QCHECK :
return select<Next>([](){ return true; }); return select<Next>([]() { return true; });
} }
assert(false); assert(false);
return MOVE_NONE; // Silence warning return MOVE_NONE; // Silence warning
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -24,7 +24,7 @@
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <limits> #include <limits>
#include <type_traits> // IWYU pragma: keep #include <type_traits> // IWYU pragma: keep
#include "movegen.h" #include "movegen.h"
#include "types.h" #include "types.h"
@ -39,22 +39,22 @@ class Position;
template<typename T, int D> template<typename T, int D>
class StatsEntry { class StatsEntry {
T entry; T entry;
public: public:
void operator=(const T& v) { entry = v; } void operator=(const T& v) { entry = v; }
T* operator&() { return &entry; } T* operator&() { return &entry; }
T* operator->() { return &entry; } T* operator->() { return &entry; }
operator const T&() const { return entry; } operator const T&() const { return entry; }
void operator<<(int bonus) { void operator<<(int bonus) {
assert(abs(bonus) <= D); // Ensure range is [-D, D] assert(abs(bonus) <= D); // Ensure range is [-D, D]
static_assert(D <= std::numeric_limits<T>::max(), "D overflows T"); static_assert(D <= std::numeric_limits<T>::max(), "D overflows T");
entry += (bonus * D - entry * abs(bonus)) / (D * 5 / 4); entry += (bonus * D - entry * abs(bonus)) / (D * 5 / 4);
assert(abs(entry) <= D); assert(abs(entry) <= D);
} }
}; };
// Stats is a generic N-dimensional array used to store various statistics. // Stats is a generic N-dimensional array used to store various statistics.
@ -62,28 +62,32 @@ public:
// template parameter D limits the range of updates in [-D, D] when we update // template parameter D limits the range of updates in [-D, D] when we update
// values with the << operator, while the last parameters (Size and Sizes) // values with the << operator, while the last parameters (Size and Sizes)
// encode the dimensions of the array. // encode the dimensions of the array.
template <typename T, int D, int Size, int... Sizes> template<typename T, int D, int Size, int... Sizes>
struct Stats : public std::array<Stats<T, D, Sizes...>, Size> struct Stats: public std::array<Stats<T, D, Sizes...>, Size> {
{ using stats = Stats<T, D, Size, Sizes...>;
using stats = Stats<T, D, Size, Sizes...>;
void fill(const T& v) { void fill(const T& v) {
// For standard-layout 'this' points to the first struct member // For standard-layout 'this' points to the first struct member
assert(std::is_standard_layout_v<stats>); assert(std::is_standard_layout_v<stats>);
using entry = StatsEntry<T, D>; using entry = StatsEntry<T, D>;
entry* p = reinterpret_cast<entry*>(this); entry* p = reinterpret_cast<entry*>(this);
std::fill(p, p + sizeof(*this) / sizeof(entry), v); std::fill(p, p + sizeof(*this) / sizeof(entry), v);
} }
}; };
template <typename T, int D, int Size> template<typename T, int D, int Size>
struct Stats<T, D, Size> : public std::array<StatsEntry<T, D>, Size> {}; struct Stats<T, D, Size>: public std::array<StatsEntry<T, D>, Size> {};
// In stats table, D=0 means that the template parameter is not used // In stats table, D=0 means that the template parameter is not used
enum StatsParams { NOT_USED = 0 }; enum StatsParams {
enum StatsType { NoCaptures, Captures }; NOT_USED = 0
};
enum StatsType {
NoCaptures,
Captures
};
// ButterflyHistory records how often quiet moves have been successful or // ButterflyHistory records how often quiet moves have been successful or
// unsuccessful during the current search, and is used for reduction and move // unsuccessful during the current search, and is used for reduction and move
@ -117,42 +121,53 @@ using ContinuationHistory = Stats<PieceToHistory, NOT_USED, PIECE_NB, SQUARE_NB>
// likely to get a cut-off first. // likely to get a cut-off first.
class MovePicker { class MovePicker {
enum PickType { Next, Best }; enum PickType {
Next,
Best
};
public: public:
MovePicker(const MovePicker&) = delete; MovePicker(const MovePicker&) = delete;
MovePicker& operator=(const MovePicker&) = delete; MovePicker& operator=(const MovePicker&) = delete;
MovePicker(const Position&, Move, Depth, const ButterflyHistory*, MovePicker(const Position&,
const CapturePieceToHistory*, Move,
const PieceToHistory**, Depth,
Move, const ButterflyHistory*,
const Move*); const CapturePieceToHistory*,
MovePicker(const Position&, Move, Depth, const ButterflyHistory*, const PieceToHistory**,
const CapturePieceToHistory*, Move,
const PieceToHistory**, const Move*);
Square); MovePicker(const Position&,
MovePicker(const Position&, Move, Value, const CapturePieceToHistory*); Move,
Move next_move(bool skipQuiets = false); Depth,
const ButterflyHistory*,
const CapturePieceToHistory*,
const PieceToHistory**,
Square);
MovePicker(const Position&, Move, Value, const CapturePieceToHistory*);
Move next_move(bool skipQuiets = false);
private: private:
template<PickType T, typename Pred> Move select(Pred); template<PickType T, typename Pred>
template<GenType> void score(); Move select(Pred);
ExtMove* begin() { return cur; } template<GenType>
ExtMove* end() { return endMoves; } void score();
ExtMove* begin() { return cur; }
ExtMove* end() { return endMoves; }
const Position& pos; const Position& pos;
const ButterflyHistory* mainHistory; const ButterflyHistory* mainHistory;
const CapturePieceToHistory* captureHistory; const CapturePieceToHistory* captureHistory;
const PieceToHistory** continuationHistory; const PieceToHistory** continuationHistory;
Move ttMove; Move ttMove;
ExtMove refutations[3], *cur, *endMoves, *endBadCaptures; ExtMove refutations[3], *cur, *endMoves, *endBadCaptures;
int stage; int stage;
Square recaptureSquare; Square recaptureSquare;
Value threshold; Value threshold;
Depth depth; Depth depth;
ExtMove moves[MAX_MOVES]; ExtMove moves[MAX_MOVES];
}; };
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef MOVEPICK_H_INCLUDED #endif // #ifndef MOVEPICK_H_INCLUDED

View file

@ -39,136 +39,144 @@
namespace Stockfish::Eval::NNUE { namespace Stockfish::Eval::NNUE {
// Input feature converter // Input feature converter
LargePagePtr<FeatureTransformer> featureTransformer; LargePagePtr<FeatureTransformer> featureTransformer;
// Evaluation function // Evaluation function
AlignedPtr<Network> network[LayerStacks]; AlignedPtr<Network> network[LayerStacks];
// Evaluation function file name // Evaluation function file name
std::string fileName; std::string fileName;
std::string netDescription; std::string netDescription;
namespace Detail { namespace Detail {
// Initialize the evaluation function parameters // Initialize the evaluation function parameters
template <typename T> template<typename T>
void initialize(AlignedPtr<T>& pointer) { void initialize(AlignedPtr<T>& pointer) {
pointer.reset(reinterpret_cast<T*>(std_aligned_alloc(alignof(T), sizeof(T)))); pointer.reset(reinterpret_cast<T*>(std_aligned_alloc(alignof(T), sizeof(T))));
std::memset(pointer.get(), 0, sizeof(T)); std::memset(pointer.get(), 0, sizeof(T));
} }
template <typename T> template<typename T>
void initialize(LargePagePtr<T>& pointer) { void initialize(LargePagePtr<T>& pointer) {
static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T"); static_assert(alignof(T) <= 4096,
"aligned_large_pages_alloc() may fail for such a big alignment requirement of T");
pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T)))); pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T))));
std::memset(pointer.get(), 0, sizeof(T)); std::memset(pointer.get(), 0, sizeof(T));
} }
// Read evaluation function parameters // Read evaluation function parameters
template <typename T> template<typename T>
bool read_parameters(std::istream& stream, T& reference) { bool read_parameters(std::istream& stream, T& reference) {
std::uint32_t header; std::uint32_t header;
header = read_little_endian<std::uint32_t>(stream); header = read_little_endian<std::uint32_t>(stream);
if (!stream || header != T::get_hash_value()) return false; if (!stream || header != T::get_hash_value())
return false;
return reference.read_parameters(stream); return reference.read_parameters(stream);
} }
// Write evaluation function parameters // Write evaluation function parameters
template <typename T> template<typename T>
bool write_parameters(std::ostream& stream, const T& reference) { bool write_parameters(std::ostream& stream, const T& reference) {
write_little_endian<std::uint32_t>(stream, T::get_hash_value()); write_little_endian<std::uint32_t>(stream, T::get_hash_value());
return reference.write_parameters(stream); return reference.write_parameters(stream);
} }
} // namespace Detail } // namespace Detail
// Initialize the evaluation function parameters // Initialize the evaluation function parameters
static void initialize() { static void initialize() {
Detail::initialize(featureTransformer); Detail::initialize(featureTransformer);
for (std::size_t i = 0; i < LayerStacks; ++i) for (std::size_t i = 0; i < LayerStacks; ++i)
Detail::initialize(network[i]); Detail::initialize(network[i]);
} }
// Read network header // Read network header
static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) static bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) {
{
std::uint32_t version, size; std::uint32_t version, size;
version = read_little_endian<std::uint32_t>(stream); version = read_little_endian<std::uint32_t>(stream);
*hashValue = read_little_endian<std::uint32_t>(stream); *hashValue = read_little_endian<std::uint32_t>(stream);
size = read_little_endian<std::uint32_t>(stream); size = read_little_endian<std::uint32_t>(stream);
if (!stream || version != Version) return false; if (!stream || version != Version)
return false;
desc->resize(size); desc->resize(size);
stream.read(&(*desc)[0], size); stream.read(&(*desc)[0], size);
return !stream.fail(); return !stream.fail();
} }
// Write network header // Write network header
static bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc) static bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc) {
{
write_little_endian<std::uint32_t>(stream, Version); write_little_endian<std::uint32_t>(stream, Version);
write_little_endian<std::uint32_t>(stream, hashValue); write_little_endian<std::uint32_t>(stream, hashValue);
write_little_endian<std::uint32_t>(stream, (std::uint32_t)desc.size()); write_little_endian<std::uint32_t>(stream, (std::uint32_t) desc.size());
stream.write(&desc[0], desc.size()); stream.write(&desc[0], desc.size());
return !stream.fail(); return !stream.fail();
} }
// Read network parameters // Read network parameters
static bool read_parameters(std::istream& stream) { static bool read_parameters(std::istream& stream) {
std::uint32_t hashValue; std::uint32_t hashValue;
if (!read_header(stream, &hashValue, &netDescription)) return false; if (!read_header(stream, &hashValue, &netDescription))
if (hashValue != HashValue) return false; return false;
if (!Detail::read_parameters(stream, *featureTransformer)) return false; if (hashValue != HashValue)
return false;
if (!Detail::read_parameters(stream, *featureTransformer))
return false;
for (std::size_t i = 0; i < LayerStacks; ++i) for (std::size_t i = 0; i < LayerStacks; ++i)
if (!Detail::read_parameters(stream, *(network[i]))) return false; if (!Detail::read_parameters(stream, *(network[i])))
return false;
return stream && stream.peek() == std::ios::traits_type::eof(); return stream && stream.peek() == std::ios::traits_type::eof();
} }
// Write network parameters // Write network parameters
static bool write_parameters(std::ostream& stream) { static bool write_parameters(std::ostream& stream) {
if (!write_header(stream, HashValue, netDescription)) return false; if (!write_header(stream, HashValue, netDescription))
if (!Detail::write_parameters(stream, *featureTransformer)) return false; return false;
if (!Detail::write_parameters(stream, *featureTransformer))
return false;
for (std::size_t i = 0; i < LayerStacks; ++i) for (std::size_t i = 0; i < LayerStacks; ++i)
if (!Detail::write_parameters(stream, *(network[i]))) return false; if (!Detail::write_parameters(stream, *(network[i])))
return false;
return bool(stream); return bool(stream);
} }
void hint_common_parent_position(const Position& pos) { void hint_common_parent_position(const Position& pos) {
featureTransformer->hint_common_access(pos); featureTransformer->hint_common_access(pos);
} }
// Evaluation function. Perform differential calculation. // Evaluation function. Perform differential calculation.
Value evaluate(const Position& pos, bool adjusted, int* complexity) { Value evaluate(const Position& pos, bool adjusted, int* complexity) {
// We manually align the arrays on the stack because with gcc < 9.3 // We manually align the arrays on the stack because with gcc < 9.3
// overaligning stack variables with alignas() doesn't work correctly. // overaligning stack variables with alignas() doesn't work correctly.
constexpr uint64_t alignment = CacheLineSize; constexpr uint64_t alignment = CacheLineSize;
constexpr int delta = 24; constexpr int delta = 24;
#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
TransformedFeatureType transformedFeaturesUnaligned[ TransformedFeatureType
FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; transformedFeaturesUnaligned[FeatureTransformer::BufferSize
+ alignment / sizeof(TransformedFeatureType)];
auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]); auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]);
#else #else
alignas(alignment) alignas(alignment) TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
#endif #endif
ASSERT_ALIGNED(transformedFeatures, alignment); ASSERT_ALIGNED(transformedFeatures, alignment);
const int bucket = (pos.count<ALL_PIECES>() - 1) / 4; const int bucket = (pos.count<ALL_PIECES>() - 1) / 4;
const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket); const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket);
const auto positional = network[bucket]->propagate(transformedFeatures); const auto positional = network[bucket]->propagate(transformedFeatures);
if (complexity) if (complexity)
@ -176,158 +184,164 @@ namespace Stockfish::Eval::NNUE {
// Give more value to positional evaluation when adjusted flag is set // Give more value to positional evaluation when adjusted flag is set
if (adjusted) if (adjusted)
return static_cast<Value>(((1024 - delta) * psqt + (1024 + delta) * positional) / (1024 * OutputScale)); return static_cast<Value>(((1024 - delta) * psqt + (1024 + delta) * positional)
/ (1024 * OutputScale));
else else
return static_cast<Value>((psqt + positional) / OutputScale); return static_cast<Value>((psqt + positional) / OutputScale);
} }
struct NnueEvalTrace { struct NnueEvalTrace {
static_assert(LayerStacks == PSQTBuckets); static_assert(LayerStacks == PSQTBuckets);
Value psqt[LayerStacks]; Value psqt[LayerStacks];
Value positional[LayerStacks]; Value positional[LayerStacks];
std::size_t correctBucket; std::size_t correctBucket;
}; };
static NnueEvalTrace trace_evaluate(const Position& pos) { static NnueEvalTrace trace_evaluate(const Position& pos) {
// We manually align the arrays on the stack because with gcc < 9.3 // We manually align the arrays on the stack because with gcc < 9.3
// overaligning stack variables with alignas() doesn't work correctly. // overaligning stack variables with alignas() doesn't work correctly.
constexpr uint64_t alignment = CacheLineSize; constexpr uint64_t alignment = CacheLineSize;
#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
TransformedFeatureType transformedFeaturesUnaligned[ TransformedFeatureType
FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; transformedFeaturesUnaligned[FeatureTransformer::BufferSize
+ alignment / sizeof(TransformedFeatureType)];
auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]); auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]);
#else #else
alignas(alignment) alignas(alignment) TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
#endif #endif
ASSERT_ALIGNED(transformedFeatures, alignment); ASSERT_ALIGNED(transformedFeatures, alignment);
NnueEvalTrace t{}; NnueEvalTrace t{};
t.correctBucket = (pos.count<ALL_PIECES>() - 1) / 4; t.correctBucket = (pos.count<ALL_PIECES>() - 1) / 4;
for (IndexType bucket = 0; bucket < LayerStacks; ++bucket) { for (IndexType bucket = 0; bucket < LayerStacks; ++bucket)
const auto materialist = featureTransformer->transform(pos, transformedFeatures, bucket); {
const auto positional = network[bucket]->propagate(transformedFeatures); const auto materialist = featureTransformer->transform(pos, transformedFeatures, bucket);
const auto positional = network[bucket]->propagate(transformedFeatures);
t.psqt[bucket] = static_cast<Value>( materialist / OutputScale ); t.psqt[bucket] = static_cast<Value>(materialist / OutputScale);
t.positional[bucket] = static_cast<Value>( positional / OutputScale ); t.positional[bucket] = static_cast<Value>(positional / OutputScale);
} }
return t; return t;
} }
constexpr std::string_view PieceToChar(" PNBRQK pnbrqk"); constexpr std::string_view PieceToChar(" PNBRQK pnbrqk");
// format_cp_compact() converts a Value into (centi)pawns and writes it in a buffer. // format_cp_compact() converts a Value into (centi)pawns and writes it in a buffer.
// The buffer must have capacity for at least 5 chars. // The buffer must have capacity for at least 5 chars.
static void format_cp_compact(Value v, char* buffer) { static void format_cp_compact(Value v, char* buffer) {
buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' '); buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' ');
int cp = std::abs(UCI::to_cp(v)); int cp = std::abs(UCI::to_cp(v));
if (cp >= 10000) if (cp >= 10000)
{ {
buffer[1] = '0' + cp / 10000; cp %= 10000; buffer[1] = '0' + cp / 10000;
buffer[2] = '0' + cp / 1000; cp %= 1000; cp %= 10000;
buffer[2] = '0' + cp / 1000;
cp %= 1000;
buffer[3] = '0' + cp / 100; buffer[3] = '0' + cp / 100;
buffer[4] = ' '; buffer[4] = ' ';
} }
else if (cp >= 1000) else if (cp >= 1000)
{ {
buffer[1] = '0' + cp / 1000; cp %= 1000; buffer[1] = '0' + cp / 1000;
buffer[2] = '0' + cp / 100; cp %= 100; cp %= 1000;
buffer[2] = '0' + cp / 100;
cp %= 100;
buffer[3] = '.'; buffer[3] = '.';
buffer[4] = '0' + cp / 10; buffer[4] = '0' + cp / 10;
} }
else else
{ {
buffer[1] = '0' + cp / 100; cp %= 100; buffer[1] = '0' + cp / 100;
cp %= 100;
buffer[2] = '.'; buffer[2] = '.';
buffer[3] = '0' + cp / 10; cp %= 10; buffer[3] = '0' + cp / 10;
cp %= 10;
buffer[4] = '0' + cp / 1; buffer[4] = '0' + cp / 1;
} }
} }
// format_cp_aligned_dot() converts a Value into pawns, always keeping two decimals // format_cp_aligned_dot() converts a Value into pawns, always keeping two decimals
static void format_cp_aligned_dot(Value v, std::stringstream &stream) { static void format_cp_aligned_dot(Value v, std::stringstream& stream) {
const double pawns = std::abs(0.01 * UCI::to_cp(v)); const double pawns = std::abs(0.01 * UCI::to_cp(v));
stream << (v < 0 ? '-' : v > 0 ? '+' : ' ') stream << (v < 0 ? '-'
<< std::setiosflags(std::ios::fixed) : v > 0 ? '+'
<< std::setw(6) : ' ')
<< std::setprecision(2) << std::setiosflags(std::ios::fixed) << std::setw(6) << std::setprecision(2) << pawns;
<< pawns; }
}
// trace() returns a string with the value of each piece on a board, // trace() returns a string with the value of each piece on a board,
// and a table for (PSQT, Layers) values bucket by bucket. // and a table for (PSQT, Layers) values bucket by bucket.
std::string trace(Position& pos) { std::string trace(Position& pos) {
std::stringstream ss; std::stringstream ss;
char board[3*8+1][8*8+2]; char board[3 * 8 + 1][8 * 8 + 2];
std::memset(board, ' ', sizeof(board)); std::memset(board, ' ', sizeof(board));
for (int row = 0; row < 3*8+1; ++row) for (int row = 0; row < 3 * 8 + 1; ++row)
board[row][8*8+1] = '\0'; board[row][8 * 8 + 1] = '\0';
// A lambda to output one box of the board // A lambda to output one box of the board
auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) { auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) {
const int x = int(file) * 8;
const int x = int(file) * 8; const int y = (7 - int(rank)) * 3;
const int y = (7 - int(rank)) * 3; for (int i = 1; i < 8; ++i)
for (int i = 1; i < 8; ++i) board[y][x + i] = board[y + 3][x + i] = '-';
board[y][x+i] = board[y+3][x+i] = '-'; for (int i = 1; i < 3; ++i)
for (int i = 1; i < 3; ++i) board[y + i][x] = board[y + i][x + 8] = '|';
board[y+i][x] = board[y+i][x+8] = '|'; board[y][x] = board[y][x + 8] = board[y + 3][x + 8] = board[y + 3][x] = '+';
board[y][x] = board[y][x+8] = board[y+3][x+8] = board[y+3][x] = '+'; if (pc != NO_PIECE)
if (pc != NO_PIECE) board[y + 1][x + 4] = PieceToChar[pc];
board[y+1][x+4] = PieceToChar[pc]; if (value != VALUE_NONE)
if (value != VALUE_NONE) format_cp_compact(value, &board[y + 2][x + 2]);
format_cp_compact(value, &board[y+2][x+2]);
}; };
// We estimate the value of each piece by doing a differential evaluation from // We estimate the value of each piece by doing a differential evaluation from
// the current base eval, simulating the removal of the piece from its square. // the current base eval, simulating the removal of the piece from its square.
Value base = evaluate(pos); Value base = evaluate(pos);
base = pos.side_to_move() == WHITE ? base : -base; base = pos.side_to_move() == WHITE ? base : -base;
for (File f = FILE_A; f <= FILE_H; ++f) for (File f = FILE_A; f <= FILE_H; ++f)
for (Rank r = RANK_1; r <= RANK_8; ++r) for (Rank r = RANK_1; r <= RANK_8; ++r)
{
Square sq = make_square(f, r);
Piece pc = pos.piece_on(sq);
Value v = VALUE_NONE;
if (pc != NO_PIECE && type_of(pc) != KING)
{ {
auto st = pos.state(); Square sq = make_square(f, r);
Piece pc = pos.piece_on(sq);
Value v = VALUE_NONE;
pos.remove_piece(sq); if (pc != NO_PIECE && type_of(pc) != KING)
st->accumulator.computed[WHITE] = false; {
st->accumulator.computed[BLACK] = false; auto st = pos.state();
Value eval = evaluate(pos); pos.remove_piece(sq);
eval = pos.side_to_move() == WHITE ? eval : -eval; st->accumulator.computed[WHITE] = false;
v = base - eval; st->accumulator.computed[BLACK] = false;
pos.put_piece(pc, sq); Value eval = evaluate(pos);
st->accumulator.computed[WHITE] = false; eval = pos.side_to_move() == WHITE ? eval : -eval;
st->accumulator.computed[BLACK] = false; v = base - eval;
pos.put_piece(pc, sq);
st->accumulator.computed[WHITE] = false;
st->accumulator.computed[BLACK] = false;
}
writeSquare(f, r, pc, v);
} }
writeSquare(f, r, pc, v);
}
ss << " NNUE derived piece values:\n"; ss << " NNUE derived piece values:\n";
for (int row = 0; row < 3*8+1; ++row) for (int row = 0; row < 3 * 8 + 1; ++row)
ss << board[row] << '\n'; ss << board[row] << '\n';
ss << '\n'; ss << '\n';
@ -342,41 +356,47 @@ namespace Stockfish::Eval::NNUE {
for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket) for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket)
{ {
ss << "| " << bucket << " "; ss << "| " << bucket << " ";
ss << " | "; format_cp_aligned_dot(t.psqt[bucket], ss); ss << " " ss << " | ";
<< " | "; format_cp_aligned_dot(t.positional[bucket], ss); ss << " " format_cp_aligned_dot(t.psqt[bucket], ss);
<< " | "; format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss); ss << " " ss << " "
<< " |"; << " | ";
if (bucket == t.correctBucket) format_cp_aligned_dot(t.positional[bucket], ss);
ss << " <-- this bucket is used"; ss << " "
ss << '\n'; << " | ";
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss);
ss << " "
<< " |";
if (bucket == t.correctBucket)
ss << " <-- this bucket is used";
ss << '\n';
} }
ss << "+------------+------------+------------+------------+\n"; ss << "+------------+------------+------------+------------+\n";
return ss.str(); return ss.str();
} }
// Load eval, from a file stream or a memory stream // Load eval, from a file stream or a memory stream
bool load_eval(std::string name, std::istream& stream) { bool load_eval(std::string name, std::istream& stream) {
initialize(); initialize();
fileName = name; fileName = name;
return read_parameters(stream); return read_parameters(stream);
} }
// Save eval, to a file stream or a memory stream // Save eval, to a file stream or a memory stream
bool save_eval(std::ostream& stream) { bool save_eval(std::ostream& stream) {
if (fileName.empty()) if (fileName.empty())
return false; return false;
return write_parameters(stream); return write_parameters(stream);
} }
// Save eval, to a file given by its name // Save eval, to a file given by its name
bool save_eval(const std::optional<std::string>& filename) { bool save_eval(const std::optional<std::string>& filename) {
std::string actualFilename; std::string actualFilename;
std::string msg; std::string msg;
@ -387,23 +407,23 @@ namespace Stockfish::Eval::NNUE {
{ {
if (currentEvalFileName != EvalFileDefaultName) if (currentEvalFileName != EvalFileDefaultName)
{ {
msg = "Failed to export a net. A non-embedded net can only be saved if the filename is specified"; msg =
"Failed to export a net. A non-embedded net can only be saved if the filename is specified";
sync_cout << msg << sync_endl; sync_cout << msg << sync_endl;
return false; return false;
} }
actualFilename = EvalFileDefaultName; actualFilename = EvalFileDefaultName;
} }
std::ofstream stream(actualFilename, std::ios_base::binary); std::ofstream stream(actualFilename, std::ios_base::binary);
bool saved = save_eval(stream); bool saved = save_eval(stream);
msg = saved ? "Network saved successfully to " + actualFilename msg = saved ? "Network saved successfully to " + actualFilename : "Failed to export a net";
: "Failed to export a net";
sync_cout << msg << sync_endl; sync_cout << msg << sync_endl;
return saved; return saved;
} }
} // namespace Stockfish::Eval::NNUE } // namespace Stockfish::Eval::NNUE

View file

@ -32,48 +32,48 @@
#include "nnue_feature_transformer.h" #include "nnue_feature_transformer.h"
namespace Stockfish { namespace Stockfish {
class Position; class Position;
enum Value : int; enum Value : int;
} }
namespace Stockfish::Eval::NNUE { namespace Stockfish::Eval::NNUE {
// Hash value of evaluation function structure // Hash value of evaluation function structure
constexpr std::uint32_t HashValue = constexpr std::uint32_t HashValue =
FeatureTransformer::get_hash_value() ^ Network::get_hash_value(); FeatureTransformer::get_hash_value() ^ Network::get_hash_value();
// Deleter for automating release of memory area // Deleter for automating release of memory area
template <typename T> template<typename T>
struct AlignedDeleter { struct AlignedDeleter {
void operator()(T* ptr) const { void operator()(T* ptr) const {
ptr->~T(); ptr->~T();
std_aligned_free(ptr); std_aligned_free(ptr);
} }
}; };
template <typename T> template<typename T>
struct LargePageDeleter { struct LargePageDeleter {
void operator()(T* ptr) const { void operator()(T* ptr) const {
ptr->~T(); ptr->~T();
aligned_large_pages_free(ptr); aligned_large_pages_free(ptr);
} }
}; };
template <typename T> template<typename T>
using AlignedPtr = std::unique_ptr<T, AlignedDeleter<T>>; using AlignedPtr = std::unique_ptr<T, AlignedDeleter<T>>;
template <typename T> template<typename T>
using LargePagePtr = std::unique_ptr<T, LargePageDeleter<T>>; using LargePagePtr = std::unique_ptr<T, LargePageDeleter<T>>;
std::string trace(Position& pos); std::string trace(Position& pos);
Value evaluate(const Position& pos, bool adjusted = false, int* complexity = nullptr); Value evaluate(const Position& pos, bool adjusted = false, int* complexity = nullptr);
void hint_common_parent_position(const Position& pos); void hint_common_parent_position(const Position& pos);
bool load_eval(std::string name, std::istream& stream); bool load_eval(std::string name, std::istream& stream);
bool save_eval(std::ostream& stream); bool save_eval(std::ostream& stream);
bool save_eval(const std::optional<std::string>& filename); bool save_eval(const std::optional<std::string>& filename);
} // namespace Stockfish::Eval::NNUE } // namespace Stockfish::Eval::NNUE
#endif // #ifndef NNUE_EVALUATE_NNUE_H_INCLUDED #endif // #ifndef NNUE_EVALUATE_NNUE_H_INCLUDED

View file

@ -27,61 +27,60 @@
namespace Stockfish::Eval::NNUE::Features { namespace Stockfish::Eval::NNUE::Features {
// Index of a feature for a given king position and another piece on some square // Index of a feature for a given king position and another piece on some square
template<Color Perspective> template<Color Perspective>
inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) { inline IndexType HalfKAv2_hm::make_index(Square s, Piece pc, Square ksq) {
return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + PieceSquareIndex[Perspective][pc] + KingBuckets[Perspective][ksq]); return IndexType((int(s) ^ OrientTBL[Perspective][ksq]) + PieceSquareIndex[Perspective][pc]
} + KingBuckets[Perspective][ksq]);
}
// Get a list of indices for active features // Get a list of indices for active features
template<Color Perspective> template<Color Perspective>
void HalfKAv2_hm::append_active_indices( void HalfKAv2_hm::append_active_indices(const Position& pos, IndexList& active) {
const Position& pos, Square ksq = pos.square<KING>(Perspective);
IndexList& active Bitboard bb = pos.pieces();
) {
Square ksq = pos.square<KING>(Perspective);
Bitboard bb = pos.pieces();
while (bb) while (bb)
{ {
Square s = pop_lsb(bb); Square s = pop_lsb(bb);
active.push_back(make_index<Perspective>(s, pos.piece_on(s), ksq)); active.push_back(make_index<Perspective>(s, pos.piece_on(s), ksq));
} }
} }
// Explicit template instantiations // Explicit template instantiations
template void HalfKAv2_hm::append_active_indices<WHITE>(const Position& pos, IndexList& active); template void HalfKAv2_hm::append_active_indices<WHITE>(const Position& pos, IndexList& active);
template void HalfKAv2_hm::append_active_indices<BLACK>(const Position& pos, IndexList& active); template void HalfKAv2_hm::append_active_indices<BLACK>(const Position& pos, IndexList& active);
// append_changed_indices() : get a list of indices for recently changed features // append_changed_indices() : get a list of indices for recently changed features
template<Color Perspective> template<Color Perspective>
void HalfKAv2_hm::append_changed_indices( void HalfKAv2_hm::append_changed_indices(Square ksq,
Square ksq, const DirtyPiece& dp,
const DirtyPiece& dp, IndexList& removed,
IndexList& removed, IndexList& added) {
IndexList& added for (int i = 0; i < dp.dirty_num; ++i)
) { {
for (int i = 0; i < dp.dirty_num; ++i) { if (dp.from[i] != SQ_NONE)
if (dp.from[i] != SQ_NONE) removed.push_back(make_index<Perspective>(dp.from[i], dp.piece[i], ksq));
removed.push_back(make_index<Perspective>(dp.from[i], dp.piece[i], ksq)); if (dp.to[i] != SQ_NONE)
if (dp.to[i] != SQ_NONE) added.push_back(make_index<Perspective>(dp.to[i], dp.piece[i], ksq));
added.push_back(make_index<Perspective>(dp.to[i], dp.piece[i], ksq));
} }
} }
// Explicit template instantiations // Explicit template instantiations
template void HalfKAv2_hm::append_changed_indices<WHITE>(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); template void HalfKAv2_hm::append_changed_indices<WHITE>(Square ksq,
template void HalfKAv2_hm::append_changed_indices<BLACK>(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added); const DirtyPiece& dp,
IndexList& removed,
IndexList& added);
template void HalfKAv2_hm::append_changed_indices<BLACK>(Square ksq,
const DirtyPiece& dp,
IndexList& removed,
IndexList& added);
int HalfKAv2_hm::update_cost(const StateInfo* st) { int HalfKAv2_hm::update_cost(const StateInfo* st) { return st->dirtyPiece.dirty_num; }
return st->dirtyPiece.dirty_num;
}
int HalfKAv2_hm::refresh_cost(const Position& pos) { int HalfKAv2_hm::refresh_cost(const Position& pos) { return pos.count<ALL_PIECES>(); }
return pos.count<ALL_PIECES>();
}
bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) { bool HalfKAv2_hm::requires_refresh(const StateInfo* st, Color perspective) {
return st->dirtyPiece.piece[0] == make_piece(perspective, KING); return st->dirtyPiece.piece[0] == make_piece(perspective, KING);
} }
} // namespace Stockfish::Eval::NNUE::Features } // namespace Stockfish::Eval::NNUE::Features

View file

@ -28,41 +28,40 @@
#include "../nnue_common.h" #include "../nnue_common.h"
namespace Stockfish { namespace Stockfish {
struct StateInfo; struct StateInfo;
class Position; class Position;
} }
namespace Stockfish::Eval::NNUE::Features { namespace Stockfish::Eval::NNUE::Features {
// Feature HalfKAv2_hm: Combination of the position of own king // Feature HalfKAv2_hm: Combination of the position of own king
// and the position of pieces. Position mirrored such that king always on e..h files. // and the position of pieces. Position mirrored such that king always on e..h files.
class HalfKAv2_hm { class HalfKAv2_hm {
// unique number for each piece type on each square // unique number for each piece type on each square
enum { enum {
PS_NONE = 0, PS_NONE = 0,
PS_W_PAWN = 0, PS_W_PAWN = 0,
PS_B_PAWN = 1 * SQUARE_NB, PS_B_PAWN = 1 * SQUARE_NB,
PS_W_KNIGHT = 2 * SQUARE_NB, PS_W_KNIGHT = 2 * SQUARE_NB,
PS_B_KNIGHT = 3 * SQUARE_NB, PS_B_KNIGHT = 3 * SQUARE_NB,
PS_W_BISHOP = 4 * SQUARE_NB, PS_W_BISHOP = 4 * SQUARE_NB,
PS_B_BISHOP = 5 * SQUARE_NB, PS_B_BISHOP = 5 * SQUARE_NB,
PS_W_ROOK = 6 * SQUARE_NB, PS_W_ROOK = 6 * SQUARE_NB,
PS_B_ROOK = 7 * SQUARE_NB, PS_B_ROOK = 7 * SQUARE_NB,
PS_W_QUEEN = 8 * SQUARE_NB, PS_W_QUEEN = 8 * SQUARE_NB,
PS_B_QUEEN = 9 * SQUARE_NB, PS_B_QUEEN = 9 * SQUARE_NB,
PS_KING = 10 * SQUARE_NB, PS_KING = 10 * SQUARE_NB,
PS_NB = 11 * SQUARE_NB PS_NB = 11 * SQUARE_NB
}; };
static constexpr IndexType PieceSquareIndex[COLOR_NB][PIECE_NB] = { static constexpr IndexType PieceSquareIndex[COLOR_NB][PIECE_NB] = {
// convention: W - us, B - them // convention: W - us, B - them
// viewed from other side, W and B are reversed // viewed from other side, W and B are reversed
{ PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE, {PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE,
PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE }, PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE},
{ PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE, {PS_NONE, PS_B_PAWN, PS_B_KNIGHT, PS_B_BISHOP, PS_B_ROOK, PS_B_QUEEN, PS_KING, PS_NONE,
PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE } PS_NONE, PS_W_PAWN, PS_W_KNIGHT, PS_W_BISHOP, PS_W_ROOK, PS_W_QUEEN, PS_KING, PS_NONE}};
};
// Index of a feature for a given king position and another piece on some square // Index of a feature for a given king position and another piece on some square
template<Color Perspective> template<Color Perspective>
@ -77,9 +76,10 @@ namespace Stockfish::Eval::NNUE::Features {
// Number of feature dimensions // Number of feature dimensions
static constexpr IndexType Dimensions = static constexpr IndexType Dimensions =
static_cast<IndexType>(SQUARE_NB) * static_cast<IndexType>(PS_NB) / 2; static_cast<IndexType>(SQUARE_NB) * static_cast<IndexType>(PS_NB) / 2;
#define B(v) (v * PS_NB) #define B(v) (v * PS_NB)
// clang-format off
static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = { static constexpr int KingBuckets[COLOR_NB][SQUARE_NB] = {
{ B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28), { B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28),
B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24),
@ -98,8 +98,9 @@ namespace Stockfish::Eval::NNUE::Features {
B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24), B(24), B(25), B(26), B(27), B(27), B(26), B(25), B(24),
B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28) } B(28), B(29), B(30), B(31), B(31), B(30), B(29), B(28) }
}; };
// clang-format on
#undef B #undef B
// clang-format off
// Orient a square according to perspective (rotates by 180 for black) // Orient a square according to perspective (rotates by 180 for black)
static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = { static constexpr int OrientTBL[COLOR_NB][SQUARE_NB] = {
{ SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1, { SQ_H1, SQ_H1, SQ_H1, SQ_H1, SQ_A1, SQ_A1, SQ_A1, SQ_A1,
@ -119,25 +120,20 @@ namespace Stockfish::Eval::NNUE::Features {
SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8, SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8,
SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8 } SQ_H8, SQ_H8, SQ_H8, SQ_H8, SQ_A8, SQ_A8, SQ_A8, SQ_A8 }
}; };
// clang-format on
// Maximum number of simultaneously active features. // Maximum number of simultaneously active features.
static constexpr IndexType MaxActiveDimensions = 32; static constexpr IndexType MaxActiveDimensions = 32;
using IndexList = ValueList<IndexType, MaxActiveDimensions>; using IndexList = ValueList<IndexType, MaxActiveDimensions>;
// Get a list of indices for active features // Get a list of indices for active features
template<Color Perspective> template<Color Perspective>
static void append_active_indices( static void append_active_indices(const Position& pos, IndexList& active);
const Position& pos,
IndexList& active);
// Get a list of indices for recently changed features // Get a list of indices for recently changed features
template<Color Perspective> template<Color Perspective>
static void append_changed_indices( static void
Square ksq, append_changed_indices(Square ksq, const DirtyPiece& dp, IndexList& removed, IndexList& added);
const DirtyPiece& dp,
IndexList& removed,
IndexList& added
);
// Returns the cost of updating one perspective, the most costly one. // Returns the cost of updating one perspective, the most costly one.
// Assumes no refresh needed. // Assumes no refresh needed.
@ -147,8 +143,8 @@ namespace Stockfish::Eval::NNUE::Features {
// Returns whether the change stored in this StateInfo means that // Returns whether the change stored in this StateInfo means that
// a full accumulator refresh is required. // a full accumulator refresh is required.
static bool requires_refresh(const StateInfo* st, Color perspective); static bool requires_refresh(const StateInfo* st, Color perspective);
}; };
} // namespace Stockfish::Eval::NNUE::Features } // namespace Stockfish::Eval::NNUE::Features
#endif // #ifndef NNUE_FEATURES_HALF_KA_V2_HM_H_INCLUDED #endif // #ifndef NNUE_FEATURES_HALF_KA_V2_HM_H_INCLUDED

View file

@ -42,95 +42,102 @@ namespace Stockfish::Eval::NNUE::Layers {
// Fallback implementation for older/other architectures. // Fallback implementation for older/other architectures.
// Requires the input to be padded to at least 16 values. // Requires the input to be padded to at least 16 values.
#if !defined(USE_SSSE3) #if !defined(USE_SSSE3)
template <IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions> template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input) static void affine_transform_non_ssse3(std::int32_t* output,
{ const std::int8_t* weights,
# if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON) const std::int32_t* biases,
# if defined(USE_SSE2) const std::uint8_t* input) {
#if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
#if defined(USE_SSE2)
// At least a multiple of 16, with SSE2. // At least a multiple of 16, with SSE2.
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16; constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
const __m128i Zeros = _mm_setzero_si128(); const __m128i Zeros = _mm_setzero_si128();
const auto inputVector = reinterpret_cast<const __m128i*>(input); const auto inputVector = reinterpret_cast<const __m128i*>(input);
# elif defined(USE_NEON_DOTPROD) #elif defined(USE_NEON_DOTPROD)
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16; constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
const auto inputVector = reinterpret_cast<const int8x16_t*>(input); const auto inputVector = reinterpret_cast<const int8x16_t*>(input);
# elif defined(USE_NEON) #elif defined(USE_NEON)
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16; constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
const auto inputVector = reinterpret_cast<const int8x8_t*>(input); const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
# endif #endif
for (IndexType i = 0; i < OutputDimensions; ++i) { for (IndexType i = 0; i < OutputDimensions; ++i)
const IndexType offset = i * PaddedInputDimensions; {
const IndexType offset = i * PaddedInputDimensions;
# if defined(USE_SSE2) #if defined(USE_SSE2)
__m128i sumLo = _mm_cvtsi32_si128(biases[i]); __m128i sumLo = _mm_cvtsi32_si128(biases[i]);
__m128i sumHi = Zeros; __m128i sumHi = Zeros;
const auto row = reinterpret_cast<const __m128i*>(&weights[offset]); const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
__m128i row_j = _mm_load_si128(&row[j]); {
__m128i input_j = _mm_load_si128(&inputVector[j]); __m128i row_j = _mm_load_si128(&row[j]);
__m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); __m128i input_j = _mm_load_si128(&inputVector[j]);
__m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8); __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
__m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros); __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
__m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros); __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
__m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo); __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
__m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi); __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo);
sumLo = _mm_add_epi32(sumLo, productLo); __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi);
sumHi = _mm_add_epi32(sumHi, productHi); sumLo = _mm_add_epi32(sumLo, productLo);
} sumHi = _mm_add_epi32(sumHi, productHi);
__m128i sum = _mm_add_epi32(sumLo, sumHi); }
__m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)); __m128i sum = _mm_add_epi32(sumLo, sumHi);
sum = _mm_add_epi32(sum, sumHigh_64); __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
__m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2)); sum = _mm_add_epi32(sum, sumHigh_64);
sum = _mm_add_epi32(sum, sum_second_32); __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
output[i] = _mm_cvtsi128_si32(sum); sum = _mm_add_epi32(sum, sum_second_32);
output[i] = _mm_cvtsi128_si32(sum);
# elif defined(USE_NEON_DOTPROD) #elif defined(USE_NEON_DOTPROD)
int32x4_t sum = {biases[i]}; int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]); const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
sum = vdotq_s32(sum, inputVector[j], row[j]); {
} sum = vdotq_s32(sum, inputVector[j], row[j]);
output[i] = vaddvq_s32(sum); }
output[i] = vaddvq_s32(sum);
# elif defined(USE_NEON) #elif defined(USE_NEON)
int32x4_t sum = {biases[i]}; int32x4_t sum = {biases[i]};
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]); const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
for (IndexType j = 0; j < NumChunks; ++j) { for (IndexType j = 0; j < NumChunks; ++j)
int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]); {
product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]); int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
sum = vpadalq_s16(sum, product); product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
} sum = vpadalq_s16(sum, product);
output[i] = sum[0] + sum[1] + sum[2] + sum[3]; }
output[i] = sum[0] + sum[1] + sum[2] + sum[3];
# endif #endif
} }
# else #else
std::memcpy(output, biases, sizeof(std::int32_t) * OutputDimensions); std::memcpy(output, biases, sizeof(std::int32_t) * OutputDimensions);
// Traverse weights in transpose order to take advantage of input sparsity // Traverse weights in transpose order to take advantage of input sparsity
for (IndexType i = 0; i < InputDimensions; ++i) for (IndexType i = 0; i < InputDimensions; ++i)
if (input[i]) { if (input[i])
const std::int8_t* w = &weights[i]; {
const int in = input[i]; const std::int8_t* w = &weights[i];
for (IndexType j = 0; j < OutputDimensions; ++j) const int in = input[i];
output[j] += w[j * PaddedInputDimensions] * in; for (IndexType j = 0; j < OutputDimensions; ++j)
} output[j] += w[j * PaddedInputDimensions] * in;
# endif }
} #endif
}
#endif #endif
template <IndexType InDims, IndexType OutDims> template<IndexType InDims, IndexType OutDims>
class AffineTransform { class AffineTransform {
public: public:
// Input/output type // Input/output type
using InputType = std::uint8_t; using InputType = std::uint8_t;
using OutputType = std::int32_t; using OutputType = std::int32_t;
// Number of input/output dimensions // Number of input/output dimensions
static constexpr IndexType InputDimensions = InDims; static constexpr IndexType InputDimensions = InDims;
static constexpr IndexType OutputDimensions = OutDims; static constexpr IndexType OutputDimensions = OutDims;
static constexpr IndexType PaddedInputDimensions = static constexpr IndexType PaddedInputDimensions =
@ -142,175 +149,168 @@ namespace Stockfish::Eval::NNUE::Layers {
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
std::uint32_t hashValue = 0xCC03DAE4u; std::uint32_t hashValue = 0xCC03DAE4u;
hashValue += OutputDimensions; hashValue += OutputDimensions;
hashValue ^= prevHash >> 1; hashValue ^= prevHash >> 1;
hashValue ^= prevHash << 31; hashValue ^= prevHash << 31;
return hashValue; return hashValue;
} }
static constexpr IndexType get_weight_index_scrambled(IndexType i) static constexpr IndexType get_weight_index_scrambled(IndexType i) {
{ return (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4
return + i / PaddedInputDimensions * 4 + i % 4;
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
i / PaddedInputDimensions * 4 +
i % 4;
} }
static constexpr IndexType get_weight_index(IndexType i) static constexpr IndexType get_weight_index(IndexType i) {
{ #if defined(USE_SSSE3)
#if defined (USE_SSSE3) return get_weight_index_scrambled(i);
return get_weight_index_scrambled(i);
#else #else
return i; return i;
#endif #endif
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream& stream) { bool read_parameters(std::istream& stream) {
read_little_endian<BiasType>(stream, biases, OutputDimensions); read_little_endian<BiasType>(stream, biases, OutputDimensions);
for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
weights[get_weight_index(i)] = read_little_endian<WeightType>(stream); weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
return !stream.fail(); return !stream.fail();
} }
// Write network parameters // Write network parameters
bool write_parameters(std::ostream& stream) const { bool write_parameters(std::ostream& stream) const {
write_little_endian<BiasType>(stream, biases, OutputDimensions); write_little_endian<BiasType>(stream, biases, OutputDimensions);
for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
write_little_endian<WeightType>(stream, weights[get_weight_index(i)]); write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
return !stream.fail(); return !stream.fail();
} }
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined (USE_SSSE3) #if defined(USE_SSSE3)
if constexpr (OutputDimensions > 1) if constexpr (OutputDimensions > 1)
{
#if defined (USE_AVX512)
using vec_t = __m512i;
#define vec_setzero _mm512_setzero_si512
#define vec_set_32 _mm512_set1_epi32
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
#define vec_hadd Simd::m512_hadd
#elif defined (USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
#define vec_hadd Simd::m256_hadd
#elif defined (USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
#define vec_hadd Simd::m128_hadd
#endif
static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
static_assert(OutputDimensions % OutputSimdWidth == 0);
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 4;
constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
const auto input32 = reinterpret_cast<const std::int32_t*>(input);
const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
vec_t acc[NumRegs];
for (IndexType k = 0; k < NumRegs; ++k)
acc[k] = biasvec[k];
for (IndexType i = 0; i < NumChunks; i += 2)
{ {
const vec_t in0 = vec_set_32(input32[i + 0]);
const vec_t in1 = vec_set_32(input32[i + 1]); #if defined(USE_AVX512)
const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]); using vec_t = __m512i;
const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]); #define vec_setzero _mm512_setzero_si512
for (IndexType k = 0; k < NumRegs; ++k) #define vec_set_32 _mm512_set1_epi32
vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]); #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
#define vec_hadd Simd::m512_hadd
#elif defined(USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
#define vec_hadd Simd::m256_hadd
#elif defined(USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
#define vec_hadd Simd::m128_hadd
#endif
static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
static_assert(OutputDimensions % OutputSimdWidth == 0);
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 4;
constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
const auto input32 = reinterpret_cast<const std::int32_t*>(input);
const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
vec_t acc[NumRegs];
for (IndexType k = 0; k < NumRegs; ++k)
acc[k] = biasvec[k];
for (IndexType i = 0; i < NumChunks; i += 2)
{
const vec_t in0 = vec_set_32(input32[i + 0]);
const vec_t in1 = vec_set_32(input32[i + 1]);
const auto col0 =
reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
const auto col1 =
reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]);
}
vec_t* outptr = reinterpret_cast<vec_t*>(output);
for (IndexType k = 0; k < NumRegs; ++k)
outptr[k] = acc[k];
#undef vec_setzero
#undef vec_set_32
#undef vec_add_dpbusd_32
#undef vec_add_dpbusd_32x2
#undef vec_hadd
} }
else if constexpr (OutputDimensions == 1)
vec_t* outptr = reinterpret_cast<vec_t*>(output);
for (IndexType k = 0; k < NumRegs; ++k)
outptr[k] = acc[k];
# undef vec_setzero
# undef vec_set_32
# undef vec_add_dpbusd_32
# undef vec_add_dpbusd_32x2
# undef vec_hadd
}
else if constexpr (OutputDimensions == 1)
{
// We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements.
#if defined (USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
#define vec_hadd Simd::m256_hadd
#elif defined (USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
#define vec_hadd Simd::m128_hadd
#endif
const auto inputVector = reinterpret_cast<const vec_t*>(input);
static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType);
static_assert(PaddedInputDimensions % InputSimdWidth == 0);
constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
vec_t sum0 = vec_setzero();
const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
for (int j = 0; j < int(NumChunks); ++j)
{ {
const vec_t in = inputVector[j];
vec_add_dpbusd_32(sum0, in, row0[j]); // We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements.
#if defined(USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
#define vec_hadd Simd::m256_hadd
#elif defined(USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
#define vec_hadd Simd::m128_hadd
#endif
const auto inputVector = reinterpret_cast<const vec_t*>(input);
static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType);
static_assert(PaddedInputDimensions % InputSimdWidth == 0);
constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
vec_t sum0 = vec_setzero();
const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
for (int j = 0; j < int(NumChunks); ++j)
{
const vec_t in = inputVector[j];
vec_add_dpbusd_32(sum0, in, row0[j]);
}
output[0] = vec_hadd(sum0, biases[0]);
#undef vec_setzero
#undef vec_set_32
#undef vec_add_dpbusd_32
#undef vec_add_dpbusd_32x2
#undef vec_hadd
} }
output[0] = vec_hadd(sum0, biases[0]);
# undef vec_setzero
# undef vec_set_32
# undef vec_add_dpbusd_32
# undef vec_add_dpbusd_32x2
# undef vec_hadd
}
#else #else
// Use old implementation for the other architectures. // Use old implementation for the other architectures.
affine_transform_non_ssse3< affine_transform_non_ssse3<InputDimensions, PaddedInputDimensions, OutputDimensions>(
InputDimensions, output, weights, biases, input);
PaddedInputDimensions,
OutputDimensions>(output, weights, biases, input);
#endif #endif
} }
private: private:
using BiasType = OutputType; using BiasType = OutputType;
using WeightType = std::int8_t; using WeightType = std::int8_t;
alignas(CacheLineSize) BiasType biases[OutputDimensions]; alignas(CacheLineSize) BiasType biases[OutputDimensions];
alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
}; };
} // namespace Stockfish::Eval::NNUE::Layers } // namespace Stockfish::Eval::NNUE::Layers
#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED #endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED

View file

@ -38,104 +38,110 @@
namespace Stockfish::Eval::NNUE::Layers { namespace Stockfish::Eval::NNUE::Layers {
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
alignas(CacheLineSize) static inline const std::array<std::array<std::uint16_t, 8>, 256> lookup_indices = [](){ alignas(CacheLineSize) static inline const
std::array<std::array<std::uint16_t, 8>, 256> v{}; std::array<std::array<std::uint16_t, 8>, 256> lookup_indices = []() {
for (unsigned i = 0; i < 256; ++i) std::array<std::array<std::uint16_t, 8>, 256> v{};
{ for (unsigned i = 0; i < 256; ++i)
std::uint64_t j = i, k = 0; {
while(j) std::uint64_t j = i, k = 0;
v[i][k++] = pop_lsb(j); while (j)
} v[i][k++] = pop_lsb(j);
return v; }
return v;
}(); }();
// Find indices of nonzero numbers in an int32_t array // Find indices of nonzero numbers in an int32_t array
template<const IndexType InputDimensions> template<const IndexType InputDimensions>
void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_out) { void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_out) {
#if defined (USE_SSSE3) #if defined(USE_SSSE3)
#if defined (USE_AVX512) #if defined(USE_AVX512)
using vec_t = __m512i; using vec_t = __m512i;
#define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512()) #define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512())
#elif defined (USE_AVX2) #elif defined(USE_AVX2)
using vec_t = __m256i; using vec_t = __m256i;
#if defined(USE_VNNI) && !defined(USE_AVXVNNI) #if defined(USE_VNNI) && !defined(USE_AVXVNNI)
#define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256()) #define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256())
#else #else
#define vec_nnz(a) _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256()))) #define vec_nnz(a) \
_mm256_movemask_ps( \
_mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256())))
#endif
#elif defined(USE_SSSE3)
using vec_t = __m128i;
#define vec_nnz(a) \
_mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128())))
#endif #endif
#elif defined (USE_SSSE3)
using vec_t = __m128i;
#define vec_nnz(a) _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128())))
#endif
using vec128_t = __m128i; using vec128_t = __m128i;
#define vec128_zero _mm_setzero_si128() #define vec128_zero _mm_setzero_si128()
#define vec128_set_16(a) _mm_set1_epi16(a) #define vec128_set_16(a) _mm_set1_epi16(a)
#define vec128_load(a) _mm_load_si128(a) #define vec128_load(a) _mm_load_si128(a)
#define vec128_storeu(a, b) _mm_storeu_si128(a, b) #define vec128_storeu(a, b) _mm_storeu_si128(a, b)
#define vec128_add(a, b) _mm_add_epi16(a, b) #define vec128_add(a, b) _mm_add_epi16(a, b)
#elif defined (USE_NEON) #elif defined(USE_NEON)
using vec_t = uint32x4_t; using vec_t = uint32x4_t;
static const std::uint32_t Mask[4] = {1, 2, 4, 8}; static const std::uint32_t Mask[4] = {1, 2, 4, 8};
#define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask))) #define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask)))
using vec128_t = uint16x8_t; using vec128_t = uint16x8_t;
#define vec128_zero vdupq_n_u16(0) #define vec128_zero vdupq_n_u16(0)
#define vec128_set_16(a) vdupq_n_u16(a) #define vec128_set_16(a) vdupq_n_u16(a)
#define vec128_load(a) vld1q_u16(reinterpret_cast<const std::uint16_t*>(a)) #define vec128_load(a) vld1q_u16(reinterpret_cast<const std::uint16_t*>(a))
#define vec128_storeu(a, b) vst1q_u16(reinterpret_cast<std::uint16_t*>(a), b) #define vec128_storeu(a, b) vst1q_u16(reinterpret_cast<std::uint16_t*>(a), b)
#define vec128_add(a, b) vaddq_u16(a, b) #define vec128_add(a, b) vaddq_u16(a, b)
#endif #endif
constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(std::int32_t); constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(std::int32_t);
// Inputs are processed InputSimdWidth at a time and outputs are processed 8 at a time so we process in chunks of max(InputSimdWidth, 8) // Inputs are processed InputSimdWidth at a time and outputs are processed 8 at a time so we process in chunks of max(InputSimdWidth, 8)
constexpr IndexType ChunkSize = std::max<IndexType>(InputSimdWidth, 8); constexpr IndexType ChunkSize = std::max<IndexType>(InputSimdWidth, 8);
constexpr IndexType NumChunks = InputDimensions / ChunkSize; constexpr IndexType NumChunks = InputDimensions / ChunkSize;
constexpr IndexType InputsPerChunk = ChunkSize / InputSimdWidth; constexpr IndexType InputsPerChunk = ChunkSize / InputSimdWidth;
constexpr IndexType OutputsPerChunk = ChunkSize / 8; constexpr IndexType OutputsPerChunk = ChunkSize / 8;
const auto inputVector = reinterpret_cast<const vec_t*>(input); const auto inputVector = reinterpret_cast<const vec_t*>(input);
IndexType count = 0; IndexType count = 0;
vec128_t base = vec128_zero; vec128_t base = vec128_zero;
const vec128_t increment = vec128_set_16(8); const vec128_t increment = vec128_set_16(8);
for (IndexType i = 0; i < NumChunks; ++i) for (IndexType i = 0; i < NumChunks; ++i)
{ {
// bitmask of nonzero values in this chunk // bitmask of nonzero values in this chunk
unsigned nnz = 0; unsigned nnz = 0;
for (IndexType j = 0; j < InputsPerChunk; ++j) for (IndexType j = 0; j < InputsPerChunk; ++j)
{ {
const vec_t inputChunk = inputVector[i * InputsPerChunk + j]; const vec_t inputChunk = inputVector[i * InputsPerChunk + j];
nnz |= unsigned(vec_nnz(inputChunk)) << (j * InputSimdWidth); nnz |= unsigned(vec_nnz(inputChunk)) << (j * InputSimdWidth);
} }
for (IndexType j = 0; j < OutputsPerChunk; ++j) for (IndexType j = 0; j < OutputsPerChunk; ++j)
{ {
const auto lookup = (nnz >> (j * 8)) & 0xFF; const auto lookup = (nnz >> (j * 8)) & 0xFF;
const auto offsets = vec128_load(reinterpret_cast<const vec128_t*>(&lookup_indices[lookup])); const auto offsets =
vec128_storeu(reinterpret_cast<vec128_t*>(out + count), vec128_add(base, offsets)); vec128_load(reinterpret_cast<const vec128_t*>(&lookup_indices[lookup]));
count += popcount(lookup); vec128_storeu(reinterpret_cast<vec128_t*>(out + count), vec128_add(base, offsets));
base = vec128_add(base, increment); count += popcount(lookup);
} base = vec128_add(base, increment);
}
} }
count_out = count; count_out = count;
} }
# undef vec_nnz #undef vec_nnz
# undef vec128_zero #undef vec128_zero
# undef vec128_set_16 #undef vec128_set_16
# undef vec128_load #undef vec128_load
# undef vec128_storeu #undef vec128_storeu
# undef vec128_add #undef vec128_add
#endif #endif
// Sparse input implementation // Sparse input implementation
template <IndexType InDims, IndexType OutDims> template<IndexType InDims, IndexType OutDims>
class AffineTransformSparseInput { class AffineTransformSparseInput {
public: public:
// Input/output type // Input/output type
using InputType = std::uint8_t; using InputType = std::uint8_t;
using OutputType = std::int32_t; using OutputType = std::int32_t;
// Number of input/output dimensions // Number of input/output dimensions
static constexpr IndexType InputDimensions = InDims; static constexpr IndexType InputDimensions = InDims;
static constexpr IndexType OutputDimensions = OutDims; static constexpr IndexType OutputDimensions = OutDims;
static_assert(OutputDimensions % 16 == 0, "Only implemented for OutputDimensions divisible by 16."); static_assert(OutputDimensions % 16 == 0,
"Only implemented for OutputDimensions divisible by 16.");
static constexpr IndexType PaddedInputDimensions = static constexpr IndexType PaddedInputDimensions =
ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth); ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
@ -152,127 +158,121 @@ namespace Stockfish::Eval::NNUE::Layers {
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
std::uint32_t hashValue = 0xCC03DAE4u; std::uint32_t hashValue = 0xCC03DAE4u;
hashValue += OutputDimensions; hashValue += OutputDimensions;
hashValue ^= prevHash >> 1; hashValue ^= prevHash >> 1;
hashValue ^= prevHash << 31; hashValue ^= prevHash << 31;
return hashValue; return hashValue;
} }
static constexpr IndexType get_weight_index_scrambled(IndexType i) static constexpr IndexType get_weight_index_scrambled(IndexType i) {
{ return (i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * ChunkSize
return + i / PaddedInputDimensions * ChunkSize + i % ChunkSize;
(i / ChunkSize) % (PaddedInputDimensions / ChunkSize) * OutputDimensions * ChunkSize +
i / PaddedInputDimensions * ChunkSize +
i % ChunkSize;
} }
static constexpr IndexType get_weight_index(IndexType i) static constexpr IndexType get_weight_index(IndexType i) {
{
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
return get_weight_index_scrambled(i); return get_weight_index_scrambled(i);
#else #else
return i; return i;
#endif #endif
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream& stream) { bool read_parameters(std::istream& stream) {
read_little_endian<BiasType>(stream, biases, OutputDimensions); read_little_endian<BiasType>(stream, biases, OutputDimensions);
for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
weights[get_weight_index(i)] = read_little_endian<WeightType>(stream); weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
return !stream.fail(); return !stream.fail();
} }
// Write network parameters // Write network parameters
bool write_parameters(std::ostream& stream) const { bool write_parameters(std::ostream& stream) const {
write_little_endian<BiasType>(stream, biases, OutputDimensions); write_little_endian<BiasType>(stream, biases, OutputDimensions);
for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
write_little_endian<WeightType>(stream, weights[get_weight_index(i)]); write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
return !stream.fail(); return !stream.fail();
} }
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if (USE_SSSE3 | (USE_NEON >= 8)) #if (USE_SSSE3 | (USE_NEON >= 8))
#if defined (USE_AVX512) #if defined(USE_AVX512)
using invec_t = __m512i; using invec_t = __m512i;
using outvec_t = __m512i; using outvec_t = __m512i;
#define vec_set_32 _mm512_set1_epi32 #define vec_set_32 _mm512_set1_epi32
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
#elif defined (USE_AVX2) #elif defined(USE_AVX2)
using invec_t = __m256i; using invec_t = __m256i;
using outvec_t = __m256i; using outvec_t = __m256i;
#define vec_set_32 _mm256_set1_epi32 #define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#elif defined (USE_SSSE3) #elif defined(USE_SSSE3)
using invec_t = __m128i; using invec_t = __m128i;
using outvec_t = __m128i; using outvec_t = __m128i;
#define vec_set_32 _mm_set1_epi32 #define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#elif defined (USE_NEON_DOTPROD) #elif defined(USE_NEON_DOTPROD)
using invec_t = int8x16_t; using invec_t = int8x16_t;
using outvec_t = int32x4_t; using outvec_t = int32x4_t;
#define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
#define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32 #define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32
#elif defined (USE_NEON) #elif defined(USE_NEON)
using invec_t = int8x16_t; using invec_t = int8x16_t;
using outvec_t = int32x4_t; using outvec_t = int32x4_t;
#define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a)) #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
#define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32 #define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32
#endif #endif
static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType); static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType);
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / ChunkSize; constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / ChunkSize;
constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
std::uint16_t nnz[NumChunks]; std::uint16_t nnz[NumChunks];
IndexType count; IndexType count;
const auto input32 = reinterpret_cast<const std::int32_t*>(input); const auto input32 = reinterpret_cast<const std::int32_t*>(input);
// Find indices of nonzero 32bit blocks // Find indices of nonzero 32bit blocks
find_nnz<NumChunks>(input32, nnz, count); find_nnz<NumChunks>(input32, nnz, count);
const outvec_t* biasvec = reinterpret_cast<const outvec_t*>(biases); const outvec_t* biasvec = reinterpret_cast<const outvec_t*>(biases);
outvec_t acc[NumRegs]; outvec_t acc[NumRegs];
for (IndexType k = 0; k < NumRegs; ++k)
acc[k] = biasvec[k];
for (IndexType j = 0; j < count; ++j)
{
const auto i = nnz[j];
const invec_t in = vec_set_32(input32[i]);
const auto col = reinterpret_cast<const invec_t*>(&weights[i * OutputDimensions * ChunkSize]);
for (IndexType k = 0; k < NumRegs; ++k) for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32(acc[k], in, col[k]); acc[k] = biasvec[k];
}
outvec_t* outptr = reinterpret_cast<outvec_t*>(output); for (IndexType j = 0; j < count; ++j)
for (IndexType k = 0; k < NumRegs; ++k) {
outptr[k] = acc[k]; const auto i = nnz[j];
# undef vec_set_32 const invec_t in = vec_set_32(input32[i]);
# undef vec_add_dpbusd_32 const auto col =
reinterpret_cast<const invec_t*>(&weights[i * OutputDimensions * ChunkSize]);
for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32(acc[k], in, col[k]);
}
outvec_t* outptr = reinterpret_cast<outvec_t*>(output);
for (IndexType k = 0; k < NumRegs; ++k)
outptr[k] = acc[k];
#undef vec_set_32
#undef vec_add_dpbusd_32
#else #else
// Use dense implementation for the other architectures. // Use dense implementation for the other architectures.
affine_transform_non_ssse3< affine_transform_non_ssse3<InputDimensions, PaddedInputDimensions, OutputDimensions>(
InputDimensions, output, weights, biases, input);
PaddedInputDimensions,
OutputDimensions>(output, weights, biases, input);
#endif #endif
} }
private: private:
using BiasType = OutputType; using BiasType = OutputType;
using WeightType = std::int8_t; using WeightType = std::int8_t;
alignas(CacheLineSize) BiasType biases[OutputDimensions]; alignas(CacheLineSize) BiasType biases[OutputDimensions];
alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions]; alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
}; };
} // namespace Stockfish::Eval::NNUE::Layers } // namespace Stockfish::Eval::NNUE::Layers
#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_SPARSE_INPUT_H_INCLUDED #endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_SPARSE_INPUT_H_INCLUDED

View file

@ -29,136 +29,140 @@
namespace Stockfish::Eval::NNUE::Layers { namespace Stockfish::Eval::NNUE::Layers {
// Clipped ReLU // Clipped ReLU
template <IndexType InDims> template<IndexType InDims>
class ClippedReLU { class ClippedReLU {
public: public:
// Input/output type // Input/output type
using InputType = std::int32_t; using InputType = std::int32_t;
using OutputType = std::uint8_t; using OutputType = std::uint8_t;
// Number of input/output dimensions // Number of input/output dimensions
static constexpr IndexType InputDimensions = InDims; static constexpr IndexType InputDimensions = InDims;
static constexpr IndexType OutputDimensions = InputDimensions; static constexpr IndexType OutputDimensions = InputDimensions;
static constexpr IndexType PaddedOutputDimensions = static constexpr IndexType PaddedOutputDimensions =
ceil_to_multiple<IndexType>(OutputDimensions, 32); ceil_to_multiple<IndexType>(OutputDimensions, 32);
using OutputBuffer = OutputType[PaddedOutputDimensions]; using OutputBuffer = OutputType[PaddedOutputDimensions];
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
std::uint32_t hashValue = 0x538D24C7u; std::uint32_t hashValue = 0x538D24C7u;
hashValue += prevHash; hashValue += prevHash;
return hashValue; return hashValue;
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream&) { bool read_parameters(std::istream&) { return true; }
return true;
}
// Write network parameters // Write network parameters
bool write_parameters(std::ostream&) const { bool write_parameters(std::ostream&) const { return true; }
return true;
}
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined(USE_AVX2) #if defined(USE_AVX2)
if constexpr (InputDimensions % SimdWidth == 0) { if constexpr (InputDimensions % SimdWidth == 0)
{
constexpr IndexType NumChunks = InputDimensions / SimdWidth;
const __m256i Zero = _mm256_setzero_si256();
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
const auto in = reinterpret_cast<const __m256i*>(input);
const auto out = reinterpret_cast<__m256i*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
const __m256i words0 =
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 0]),
_mm256_load_si256(&in[i * 4 + 1])),
WeightScaleBits);
const __m256i words1 =
_mm256_srai_epi16(_mm256_packs_epi32(_mm256_load_si256(&in[i * 4 + 2]),
_mm256_load_si256(&in[i * 4 + 3])),
WeightScaleBits);
_mm256_store_si256(
&out[i], _mm256_permutevar8x32_epi32(
_mm256_max_epi8(_mm256_packs_epi16(words0, words1), Zero), Offsets));
}
}
else
{
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const __m128i Zero = _mm_setzero_si128();
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
const __m128i words0 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
WeightScaleBits);
const __m128i words1 = _mm_srai_epi16(
_mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
}
}
constexpr IndexType Start = InputDimensions % SimdWidth == 0
? InputDimensions / SimdWidth * SimdWidth
: InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);
#elif defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / SimdWidth; constexpr IndexType NumChunks = InputDimensions / SimdWidth;
const __m256i Zero = _mm256_setzero_si256();
const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); #ifdef USE_SSE41
const auto in = reinterpret_cast<const __m256i*>(input);
const auto out = reinterpret_cast<__m256i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) {
const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
_mm256_load_si256(&in[i * 4 + 0]),
_mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits);
const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
_mm256_load_si256(&in[i * 4 + 2]),
_mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits);
_mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
_mm256_packs_epi16(words0, words1), Zero), Offsets));
}
} else {
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const __m128i Zero = _mm_setzero_si128(); const __m128i Zero = _mm_setzero_si128();
const auto in = reinterpret_cast<const __m128i*>(input); #else
const __m128i k0x80s = _mm_set1_epi8(-128);
#endif
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output); const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32( {
_mm_load_si128(&in[i * 4 + 0]), const __m128i words0 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 1])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1])),
const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32( WeightScaleBits);
_mm_load_si128(&in[i * 4 + 2]), const __m128i words1 = _mm_srai_epi16(
_mm_load_si128(&in[i * 4 + 3])), WeightScaleBits); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])),
const __m128i packedbytes = _mm_packs_epi16(words0, words1); WeightScaleBits);
_mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero)); const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i],
#ifdef USE_SSE41
_mm_max_epi8(packedbytes, Zero)
#else
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
#endif
);
} }
} constexpr IndexType Start = NumChunks * SimdWidth;
constexpr IndexType Start =
InputDimensions % SimdWidth == 0
? InputDimensions / SimdWidth * SimdWidth
: InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);
#elif defined(USE_SSE2) #elif defined(USE_NEON)
constexpr IndexType NumChunks = InputDimensions / SimdWidth; constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const int8x8_t Zero = {0};
const auto in = reinterpret_cast<const int32x4_t*>(input);
const auto out = reinterpret_cast<int8x8_t*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
{
int16x8_t shifted;
const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits);
out[i] = vmax_s8(vqmovn_s16(shifted), Zero);
}
constexpr IndexType Start = NumChunks * (SimdWidth / 2);
#else
constexpr IndexType Start = 0;
#endif
#ifdef USE_SSE41 for (IndexType i = Start; i < InputDimensions; ++i)
const __m128i Zero = _mm_setzero_si128(); {
#else output[i] = static_cast<OutputType>(std::clamp(input[i] >> WeightScaleBits, 0, 127));
const __m128i k0x80s = _mm_set1_epi8(-128); }
#endif
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) {
const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32(
_mm_load_si128(&in[i * 4 + 0]),
_mm_load_si128(&in[i * 4 + 1])), WeightScaleBits);
const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32(
_mm_load_si128(&in[i * 4 + 2]),
_mm_load_si128(&in[i * 4 + 3])), WeightScaleBits);
const __m128i packedbytes = _mm_packs_epi16(words0, words1);
_mm_store_si128(&out[i],
#ifdef USE_SSE41
_mm_max_epi8(packedbytes, Zero)
#else
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
#endif
);
}
constexpr IndexType Start = NumChunks * SimdWidth;
#elif defined(USE_NEON)
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const int8x8_t Zero = {0};
const auto in = reinterpret_cast<const int32x4_t*>(input);
const auto out = reinterpret_cast<int8x8_t*>(output);
for (IndexType i = 0; i < NumChunks; ++i) {
int16x8_t shifted;
const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits);
out[i] = vmax_s8(vqmovn_s16(shifted), Zero);
}
constexpr IndexType Start = NumChunks * (SimdWidth / 2);
#else
constexpr IndexType Start = 0;
#endif
for (IndexType i = Start; i < InputDimensions; ++i) {
output[i] = static_cast<OutputType>(
std::clamp(input[i] >> WeightScaleBits, 0, 127));
}
} }
}; };
} // namespace Stockfish::Eval::NNUE::Layers } // namespace Stockfish::Eval::NNUE::Layers
#endif // NNUE_LAYERS_CLIPPED_RELU_H_INCLUDED #endif // NNUE_LAYERS_CLIPPED_RELU_H_INCLUDED

View file

@ -20,30 +20,30 @@
#define STOCKFISH_SIMD_H_INCLUDED #define STOCKFISH_SIMD_H_INCLUDED
#if defined(USE_AVX2) #if defined(USE_AVX2)
# include <immintrin.h> #include <immintrin.h>
#elif defined(USE_SSE41) #elif defined(USE_SSE41)
# include <smmintrin.h> #include <smmintrin.h>
#elif defined(USE_SSSE3) #elif defined(USE_SSSE3)
# include <tmmintrin.h> #include <tmmintrin.h>
#elif defined(USE_SSE2) #elif defined(USE_SSE2)
# include <emmintrin.h> #include <emmintrin.h>
#elif defined(USE_NEON) #elif defined(USE_NEON)
# include <arm_neon.h> #include <arm_neon.h>
#endif #endif
namespace Stockfish::Simd { namespace Stockfish::Simd {
#if defined (USE_AVX512) #if defined(USE_AVX512)
[[maybe_unused]] static int m512_hadd(__m512i sum, int bias) { [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
return _mm512_reduce_add_epi32(sum) + bias; return _mm512_reduce_add_epi32(sum) + bias;
} }
/* /*
Parameters: Parameters:
sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]] sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]] sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
@ -58,186 +58,164 @@ namespace Stockfish::Simd {
reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3]) reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
] ]
*/ */
[[maybe_unused]] static __m512i m512_hadd128x16_interleave( [[maybe_unused]] static __m512i
__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) { m512_hadd128x16_interleave(__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
__m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1); __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
__m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1); __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
__m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3); __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
__m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3); __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
__m512i sum01 = _mm512_add_epi32(sum01a, sum01b); __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
__m512i sum23 = _mm512_add_epi32(sum23a, sum23b); __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
__m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23); __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
__m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23); __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
return _mm512_add_epi32(sum0123a, sum0123b); return _mm512_add_epi32(sum0123a, sum0123b);
} }
[[maybe_unused]] static void m512_add_dpbusd_epi32( [[maybe_unused]] static void m512_add_dpbusd_epi32(__m512i& acc, __m512i a, __m512i b) {
__m512i& acc,
__m512i a,
__m512i b) {
# if defined (USE_VNNI) #if defined(USE_VNNI)
acc = _mm512_dpbusd_epi32(acc, a, b); acc = _mm512_dpbusd_epi32(acc, a, b);
# else #else
__m512i product0 = _mm512_maddubs_epi16(a, b); __m512i product0 = _mm512_maddubs_epi16(a, b);
product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
acc = _mm512_add_epi32(acc, product0); acc = _mm512_add_epi32(acc, product0);
# endif #endif
} }
[[maybe_unused]] static void m512_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m512i& acc, m512_add_dpbusd_epi32x2(__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
__m512i a0, __m512i b0,
__m512i a1, __m512i b1) {
# if defined (USE_VNNI) #if defined(USE_VNNI)
acc = _mm512_dpbusd_epi32(acc, a0, b0); acc = _mm512_dpbusd_epi32(acc, a0, b0);
acc = _mm512_dpbusd_epi32(acc, a1, b1); acc = _mm512_dpbusd_epi32(acc, a1, b1);
# else #else
__m512i product0 = _mm512_maddubs_epi16(a0, b0); __m512i product0 = _mm512_maddubs_epi16(a0, b0);
__m512i product1 = _mm512_maddubs_epi16(a1, b1); __m512i product1 = _mm512_maddubs_epi16(a1, b1);
product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1)); product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1)); product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1)); acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
# endif #endif
} }
#endif #endif
#if defined (USE_AVX2) #if defined(USE_AVX2)
[[maybe_unused]] static int m256_hadd(__m256i sum, int bias) { [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC)); sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB)); sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
return _mm_cvtsi128_si32(sum128) + bias; return _mm_cvtsi128_si32(sum128) + bias;
} }
[[maybe_unused]] static void m256_add_dpbusd_epi32( [[maybe_unused]] static void m256_add_dpbusd_epi32(__m256i& acc, __m256i a, __m256i b) {
__m256i& acc,
__m256i a,
__m256i b) {
# if defined (USE_VNNI) #if defined(USE_VNNI)
acc = _mm256_dpbusd_epi32(acc, a, b); acc = _mm256_dpbusd_epi32(acc, a, b);
# else #else
__m256i product0 = _mm256_maddubs_epi16(a, b); __m256i product0 = _mm256_maddubs_epi16(a, b);
product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
acc = _mm256_add_epi32(acc, product0); acc = _mm256_add_epi32(acc, product0);
# endif #endif
} }
[[maybe_unused]] static void m256_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m256i& acc, m256_add_dpbusd_epi32x2(__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
__m256i a0, __m256i b0,
__m256i a1, __m256i b1) {
# if defined (USE_VNNI) #if defined(USE_VNNI)
acc = _mm256_dpbusd_epi32(acc, a0, b0); acc = _mm256_dpbusd_epi32(acc, a0, b0);
acc = _mm256_dpbusd_epi32(acc, a1, b1); acc = _mm256_dpbusd_epi32(acc, a1, b1);
# else #else
__m256i product0 = _mm256_maddubs_epi16(a0, b0); __m256i product0 = _mm256_maddubs_epi16(a0, b0);
__m256i product1 = _mm256_maddubs_epi16(a1, b1); __m256i product1 = _mm256_maddubs_epi16(a1, b1);
product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1)); product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1)); product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1)); acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
# endif #endif
} }
#endif #endif
#if defined (USE_SSSE3) #if defined(USE_SSSE3)
[[maybe_unused]] static int m128_hadd(__m128i sum, int bias) { [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
return _mm_cvtsi128_si32(sum) + bias; return _mm_cvtsi128_si32(sum) + bias;
} }
[[maybe_unused]] static void m128_add_dpbusd_epi32( [[maybe_unused]] static void m128_add_dpbusd_epi32(__m128i& acc, __m128i a, __m128i b) {
__m128i& acc,
__m128i a,
__m128i b) {
__m128i product0 = _mm_maddubs_epi16(a, b); __m128i product0 = _mm_maddubs_epi16(a, b);
product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
acc = _mm_add_epi32(acc, product0); acc = _mm_add_epi32(acc, product0);
} }
[[maybe_unused]] static void m128_add_dpbusd_epi32x2( [[maybe_unused]] static void
__m128i& acc, m128_add_dpbusd_epi32x2(__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
__m128i a0, __m128i b0,
__m128i a1, __m128i b1) {
__m128i product0 = _mm_maddubs_epi16(a0, b0); __m128i product0 = _mm_maddubs_epi16(a0, b0);
__m128i product1 = _mm_maddubs_epi16(a1, b1); __m128i product1 = _mm_maddubs_epi16(a1, b1);
product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1)); product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1)); product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1)); acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
} }
#endif #endif
#if defined (USE_NEON_DOTPROD) #if defined(USE_NEON_DOTPROD)
[[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2( [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
int32x4_t& acc, int32x4_t& acc, int8x16_t a0, int8x16_t b0, int8x16_t a1, int8x16_t b1) {
int8x16_t a0, int8x16_t b0,
int8x16_t a1, int8x16_t b1) {
acc = vdotq_s32(acc, a0, b0); acc = vdotq_s32(acc, a0, b0);
acc = vdotq_s32(acc, a1, b1); acc = vdotq_s32(acc, a1, b1);
} }
[[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32( [[maybe_unused]] static void
int32x4_t& acc, dotprod_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {
int8x16_t a, int8x16_t b) {
acc = vdotq_s32(acc, a, b); acc = vdotq_s32(acc, a, b);
} }
#endif #endif
#if defined (USE_NEON) #if defined(USE_NEON)
[[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) { [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
# if USE_NEON >= 8 #if USE_NEON >= 8
return vaddvq_s32(s); return vaddvq_s32(s);
# else #else
return s[0] + s[1] + s[2] + s[3]; return s[0] + s[1] + s[2] + s[3];
# endif #endif
} }
[[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) { [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
return neon_m128_reduce_add_epi32(sum) + bias; return neon_m128_reduce_add_epi32(sum) + bias;
} }
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2( [[maybe_unused]] static void
int32x4_t& acc, neon_m128_add_dpbusd_epi32x2(int32x4_t& acc, int8x8_t a0, int8x8_t b0, int8x8_t a1, int8x8_t b1) {
int8x8_t a0, int8x8_t b0,
int8x8_t a1, int8x8_t b1) {
int16x8_t product = vmull_s8(a0, b0); int16x8_t product = vmull_s8(a0, b0);
product = vmlal_s8(product, a1, b1); product = vmlal_s8(product, a1, b1);
acc = vpadalq_s16(acc, product); acc = vpadalq_s16(acc, product);
} }
#endif #endif
#if USE_NEON >= 8 #if USE_NEON >= 8
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32( [[maybe_unused]] static void neon_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {
int32x4_t& acc,
int8x16_t a, int8x16_t b) {
int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
int16x8_t product1 = vmull_high_s8(a, b); int16x8_t product1 = vmull_high_s8(a, b);
int16x8_t sum = vpaddq_s16(product0, product1); int16x8_t sum = vpaddq_s16(product0, product1);
acc = vpadalq_s16(acc, sum); acc = vpadalq_s16(acc, sum);
} }
#endif #endif
} }
#endif // STOCKFISH_SIMD_H_INCLUDED #endif // STOCKFISH_SIMD_H_INCLUDED

View file

@ -29,80 +29,75 @@
namespace Stockfish::Eval::NNUE::Layers { namespace Stockfish::Eval::NNUE::Layers {
// Clipped ReLU // Clipped ReLU
template <IndexType InDims> template<IndexType InDims>
class SqrClippedReLU { class SqrClippedReLU {
public: public:
// Input/output type // Input/output type
using InputType = std::int32_t; using InputType = std::int32_t;
using OutputType = std::uint8_t; using OutputType = std::uint8_t;
// Number of input/output dimensions // Number of input/output dimensions
static constexpr IndexType InputDimensions = InDims; static constexpr IndexType InputDimensions = InDims;
static constexpr IndexType OutputDimensions = InputDimensions; static constexpr IndexType OutputDimensions = InputDimensions;
static constexpr IndexType PaddedOutputDimensions = static constexpr IndexType PaddedOutputDimensions =
ceil_to_multiple<IndexType>(OutputDimensions, 32); ceil_to_multiple<IndexType>(OutputDimensions, 32);
using OutputBuffer = OutputType[PaddedOutputDimensions]; using OutputBuffer = OutputType[PaddedOutputDimensions];
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) { static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
std::uint32_t hashValue = 0x538D24C7u; std::uint32_t hashValue = 0x538D24C7u;
hashValue += prevHash; hashValue += prevHash;
return hashValue; return hashValue;
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream&) { bool read_parameters(std::istream&) { return true; }
return true;
}
// Write network parameters // Write network parameters
bool write_parameters(std::ostream&) const { bool write_parameters(std::ostream&) const { return true; }
return true;
}
// Forward propagation // Forward propagation
void propagate( void propagate(const InputType* input, OutputType* output) const {
const InputType* input, OutputType* output) const {
#if defined(USE_SSE2) #if defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / 16; constexpr IndexType NumChunks = InputDimensions / 16;
static_assert(WeightScaleBits == 6); static_assert(WeightScaleBits == 6);
const auto in = reinterpret_cast<const __m128i*>(input); const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output); const auto out = reinterpret_cast<__m128i*>(output);
for (IndexType i = 0; i < NumChunks; ++i) { for (IndexType i = 0; i < NumChunks; ++i)
__m128i words0 = _mm_packs_epi32( {
_mm_load_si128(&in[i * 4 + 0]), __m128i words0 =
_mm_load_si128(&in[i * 4 + 1])); _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1]));
__m128i words1 = _mm_packs_epi32( __m128i words1 =
_mm_load_si128(&in[i * 4 + 2]), _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3]));
_mm_load_si128(&in[i * 4 + 3]));
// We shift by WeightScaleBits * 2 = 12 and divide by 128 // We shift by WeightScaleBits * 2 = 12 and divide by 128
// which is an additional shift-right of 7, meaning 19 in total. // which is an additional shift-right of 7, meaning 19 in total.
// MulHi strips the lower 16 bits so we need to shift out 3 more to match. // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3); words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3); words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
} }
constexpr IndexType Start = NumChunks * 16; constexpr IndexType Start = NumChunks * 16;
#else #else
constexpr IndexType Start = 0; constexpr IndexType Start = 0;
#endif #endif
for (IndexType i = Start; i < InputDimensions; ++i) { for (IndexType i = Start; i < InputDimensions; ++i)
output[i] = static_cast<OutputType>( {
// Really should be /127 but we need to make it fast so we right shift output[i] = static_cast<OutputType>(
// by an extra 7 bits instead. Needs to be accounted for in the trainer. // Really should be /127 but we need to make it fast so we right shift
std::min(127ll, ((long long)input[i] * input[i]) >> (2 * WeightScaleBits + 7))); // by an extra 7 bits instead. Needs to be accounted for in the trainer.
} std::min(127ll, ((long long) input[i] * input[i]) >> (2 * WeightScaleBits + 7)));
}
} }
}; };
} // namespace Stockfish::Eval::NNUE::Layers } // namespace Stockfish::Eval::NNUE::Layers
#endif // NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED #endif // NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED

View file

@ -28,13 +28,13 @@
namespace Stockfish::Eval::NNUE { namespace Stockfish::Eval::NNUE {
// Class that holds the result of affine transformation of input features // Class that holds the result of affine transformation of input features
struct alignas(CacheLineSize) Accumulator { struct alignas(CacheLineSize) Accumulator {
std::int16_t accumulation[2][TransformedFeatureDimensions]; std::int16_t accumulation[2][TransformedFeatureDimensions];
std::int32_t psqtAccumulation[2][PSQTBuckets]; std::int32_t psqtAccumulation[2][PSQTBuckets];
bool computed[2]; bool computed[2];
}; };
} // namespace Stockfish::Eval::NNUE } // namespace Stockfish::Eval::NNUE
#endif // NNUE_ACCUMULATOR_H_INCLUDED #endif // NNUE_ACCUMULATOR_H_INCLUDED

View file

@ -39,97 +39,90 @@ using FeatureSet = Features::HalfKAv2_hm;
// Number of input feature dimensions after conversion // Number of input feature dimensions after conversion
constexpr IndexType TransformedFeatureDimensions = 2560; constexpr IndexType TransformedFeatureDimensions = 2560;
constexpr IndexType PSQTBuckets = 8; constexpr IndexType PSQTBuckets = 8;
constexpr IndexType LayerStacks = 8; constexpr IndexType LayerStacks = 8;
struct Network struct Network {
{ static constexpr int FC_0_OUTPUTS = 15;
static constexpr int FC_0_OUTPUTS = 15; static constexpr int FC_1_OUTPUTS = 32;
static constexpr int FC_1_OUTPUTS = 32;
Layers::AffineTransformSparseInput<TransformedFeatureDimensions, FC_0_OUTPUTS + 1> fc_0; Layers::AffineTransformSparseInput<TransformedFeatureDimensions, FC_0_OUTPUTS + 1> fc_0;
Layers::SqrClippedReLU<FC_0_OUTPUTS + 1> ac_sqr_0; Layers::SqrClippedReLU<FC_0_OUTPUTS + 1> ac_sqr_0;
Layers::ClippedReLU<FC_0_OUTPUTS + 1> ac_0; Layers::ClippedReLU<FC_0_OUTPUTS + 1> ac_0;
Layers::AffineTransform<FC_0_OUTPUTS * 2, FC_1_OUTPUTS> fc_1; Layers::AffineTransform<FC_0_OUTPUTS * 2, FC_1_OUTPUTS> fc_1;
Layers::ClippedReLU<FC_1_OUTPUTS> ac_1; Layers::ClippedReLU<FC_1_OUTPUTS> ac_1;
Layers::AffineTransform<FC_1_OUTPUTS, 1> fc_2; Layers::AffineTransform<FC_1_OUTPUTS, 1> fc_2;
// Hash value embedded in the evaluation file // Hash value embedded in the evaluation file
static constexpr std::uint32_t get_hash_value() { static constexpr std::uint32_t get_hash_value() {
// input slice hash // input slice hash
std::uint32_t hashValue = 0xEC42E90Du; std::uint32_t hashValue = 0xEC42E90Du;
hashValue ^= TransformedFeatureDimensions * 2; hashValue ^= TransformedFeatureDimensions * 2;
hashValue = decltype(fc_0)::get_hash_value(hashValue); hashValue = decltype(fc_0)::get_hash_value(hashValue);
hashValue = decltype(ac_0)::get_hash_value(hashValue); hashValue = decltype(ac_0)::get_hash_value(hashValue);
hashValue = decltype(fc_1)::get_hash_value(hashValue); hashValue = decltype(fc_1)::get_hash_value(hashValue);
hashValue = decltype(ac_1)::get_hash_value(hashValue); hashValue = decltype(ac_1)::get_hash_value(hashValue);
hashValue = decltype(fc_2)::get_hash_value(hashValue); hashValue = decltype(fc_2)::get_hash_value(hashValue);
return hashValue; return hashValue;
} }
// Read network parameters // Read network parameters
bool read_parameters(std::istream& stream) { bool read_parameters(std::istream& stream) {
return fc_0.read_parameters(stream) return fc_0.read_parameters(stream) && ac_0.read_parameters(stream)
&& ac_0.read_parameters(stream) && fc_1.read_parameters(stream) && ac_1.read_parameters(stream)
&& fc_1.read_parameters(stream) && fc_2.read_parameters(stream);
&& ac_1.read_parameters(stream) }
&& fc_2.read_parameters(stream);
}
// Write network parameters // Write network parameters
bool write_parameters(std::ostream& stream) const { bool write_parameters(std::ostream& stream) const {
return fc_0.write_parameters(stream) return fc_0.write_parameters(stream) && ac_0.write_parameters(stream)
&& ac_0.write_parameters(stream) && fc_1.write_parameters(stream) && ac_1.write_parameters(stream)
&& fc_1.write_parameters(stream) && fc_2.write_parameters(stream);
&& ac_1.write_parameters(stream) }
&& fc_2.write_parameters(stream);
}
std::int32_t propagate(const TransformedFeatureType* transformedFeatures) std::int32_t propagate(const TransformedFeatureType* transformedFeatures) {
{ struct alignas(CacheLineSize) Buffer {
struct alignas(CacheLineSize) Buffer alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out;
{ alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType
alignas(CacheLineSize) decltype(fc_0)::OutputBuffer fc_0_out; ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)];
alignas(CacheLineSize) decltype(ac_sqr_0)::OutputType ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)]; alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out;
alignas(CacheLineSize) decltype(ac_0)::OutputBuffer ac_0_out; alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out;
alignas(CacheLineSize) decltype(fc_1)::OutputBuffer fc_1_out; alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out;
alignas(CacheLineSize) decltype(ac_1)::OutputBuffer ac_1_out; alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out;
alignas(CacheLineSize) decltype(fc_2)::OutputBuffer fc_2_out;
Buffer() Buffer() { std::memset(this, 0, sizeof(*this)); }
{ };
std::memset(this, 0, sizeof(*this));
}
};
#if defined(__clang__) && (__APPLE__) #if defined(__clang__) && (__APPLE__)
// workaround for a bug reported with xcode 12 // workaround for a bug reported with xcode 12
static thread_local auto tlsBuffer = std::make_unique<Buffer>(); static thread_local auto tlsBuffer = std::make_unique<Buffer>();
// Access TLS only once, cache result. // Access TLS only once, cache result.
Buffer& buffer = *tlsBuffer; Buffer& buffer = *tlsBuffer;
#else #else
alignas(CacheLineSize) static thread_local Buffer buffer; alignas(CacheLineSize) static thread_local Buffer buffer;
#endif #endif
fc_0.propagate(transformedFeatures, buffer.fc_0_out); fc_0.propagate(transformedFeatures, buffer.fc_0_out);
ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out); ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out);
ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out); ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out);
std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out, FC_0_OUTPUTS * sizeof(decltype(ac_0)::OutputType)); std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out,
fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out); FC_0_OUTPUTS * sizeof(decltype(ac_0)::OutputType));
ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out); fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out);
fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out); ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out);
fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out);
// buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in quantized form // buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in quantized form
// but we want 1.0 to be equal to 600*OutputScale // but we want 1.0 to be equal to 600*OutputScale
std::int32_t fwdOut = int(buffer.fc_0_out[FC_0_OUTPUTS]) * (600*OutputScale) / (127*(1<<WeightScaleBits)); std::int32_t fwdOut =
std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut; int(buffer.fc_0_out[FC_0_OUTPUTS]) * (600 * OutputScale) / (127 * (1 << WeightScaleBits));
std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut;
return outputValue; return outputValue;
} }
}; };
} // namespace Stockfish::Eval::NNUE } // namespace Stockfish::Eval::NNUE
#endif // #ifndef NNUE_ARCHITECTURE_H_INCLUDED #endif // #ifndef NNUE_ARCHITECTURE_H_INCLUDED

View file

@ -31,255 +31,254 @@
#include "../misc.h" #include "../misc.h"
#if defined(USE_AVX2) #if defined(USE_AVX2)
#include <immintrin.h> #include <immintrin.h>
#elif defined(USE_SSE41) #elif defined(USE_SSE41)
#include <smmintrin.h> #include <smmintrin.h>
#elif defined(USE_SSSE3) #elif defined(USE_SSSE3)
#include <tmmintrin.h> #include <tmmintrin.h>
#elif defined(USE_SSE2) #elif defined(USE_SSE2)
#include <emmintrin.h> #include <emmintrin.h>
#elif defined(USE_NEON) #elif defined(USE_NEON)
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
namespace Stockfish::Eval::NNUE { namespace Stockfish::Eval::NNUE {
// Version of the evaluation file // Version of the evaluation file
constexpr std::uint32_t Version = 0x7AF32F20u; constexpr std::uint32_t Version = 0x7AF32F20u;
// Constant used in evaluation value calculation // Constant used in evaluation value calculation
constexpr int OutputScale = 16; constexpr int OutputScale = 16;
constexpr int WeightScaleBits = 6; constexpr int WeightScaleBits = 6;
// Size of cache line (in bytes) // Size of cache line (in bytes)
constexpr std::size_t CacheLineSize = 64; constexpr std::size_t CacheLineSize = 64;
constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128"; constexpr const char Leb128MagicString[] = "COMPRESSED_LEB128";
constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1; constexpr const std::size_t Leb128MagicStringSize = sizeof(Leb128MagicString) - 1;
// SIMD width (in bytes) // SIMD width (in bytes)
#if defined(USE_AVX2) #if defined(USE_AVX2)
constexpr std::size_t SimdWidth = 32; constexpr std::size_t SimdWidth = 32;
#elif defined(USE_SSE2) #elif defined(USE_SSE2)
constexpr std::size_t SimdWidth = 16; constexpr std::size_t SimdWidth = 16;
#elif defined(USE_NEON) #elif defined(USE_NEON)
constexpr std::size_t SimdWidth = 16; constexpr std::size_t SimdWidth = 16;
#endif #endif
constexpr std::size_t MaxSimdWidth = 32; constexpr std::size_t MaxSimdWidth = 32;
// Type of input feature after conversion // Type of input feature after conversion
using TransformedFeatureType = std::uint8_t; using TransformedFeatureType = std::uint8_t;
using IndexType = std::uint32_t; using IndexType = std::uint32_t;
// Round n up to be a multiple of base // Round n up to be a multiple of base
template <typename IntType> template<typename IntType>
constexpr IntType ceil_to_multiple(IntType n, IntType base) { constexpr IntType ceil_to_multiple(IntType n, IntType base) {
return (n + base - 1) / base * base; return (n + base - 1) / base * base;
} }
// read_little_endian() is our utility to read an integer (signed or unsigned, any size) // read_little_endian() is our utility to read an integer (signed or unsigned, any size)
// from a stream in little-endian order. We swap the byte order after the read if // from a stream in little-endian order. We swap the byte order after the read if
// necessary to return a result with the byte ordering of the compiling machine. // necessary to return a result with the byte ordering of the compiling machine.
template <typename IntType> template<typename IntType>
inline IntType read_little_endian(std::istream& stream) { inline IntType read_little_endian(std::istream& stream) {
IntType result; IntType result;
if (IsLittleEndian) if (IsLittleEndian)
stream.read(reinterpret_cast<char*>(&result), sizeof(IntType)); stream.read(reinterpret_cast<char*>(&result), sizeof(IntType));
else else
{ {
std::uint8_t u[sizeof(IntType)]; std::uint8_t u[sizeof(IntType)];
std::make_unsigned_t<IntType> v = 0; std::make_unsigned_t<IntType> v = 0;
stream.read(reinterpret_cast<char*>(u), sizeof(IntType)); stream.read(reinterpret_cast<char*>(u), sizeof(IntType));
for (std::size_t i = 0; i < sizeof(IntType); ++i) for (std::size_t i = 0; i < sizeof(IntType); ++i)
v = (v << 8) | u[sizeof(IntType) - i - 1]; v = (v << 8) | u[sizeof(IntType) - i - 1];
std::memcpy(&result, &v, sizeof(IntType)); std::memcpy(&result, &v, sizeof(IntType));
} }
return result; return result;
} }
// write_little_endian() is our utility to write an integer (signed or unsigned, any size) // write_little_endian() is our utility to write an integer (signed or unsigned, any size)
// to a stream in little-endian order. We swap the byte order before the write if // to a stream in little-endian order. We swap the byte order before the write if
// necessary to always write in little endian order, independently of the byte // necessary to always write in little endian order, independently of the byte
// ordering of the compiling machine. // ordering of the compiling machine.
template <typename IntType> template<typename IntType>
inline void write_little_endian(std::ostream& stream, IntType value) { inline void write_little_endian(std::ostream& stream, IntType value) {
if (IsLittleEndian) if (IsLittleEndian)
stream.write(reinterpret_cast<const char*>(&value), sizeof(IntType)); stream.write(reinterpret_cast<const char*>(&value), sizeof(IntType));
else else
{ {
std::uint8_t u[sizeof(IntType)]; std::uint8_t u[sizeof(IntType)];
std::make_unsigned_t<IntType> v = value; std::make_unsigned_t<IntType> v = value;
std::size_t i = 0; std::size_t i = 0;
// if constexpr to silence the warning about shift by 8 // if constexpr to silence the warning about shift by 8
if constexpr (sizeof(IntType) > 1) if constexpr (sizeof(IntType) > 1)
{ {
for (; i + 1 < sizeof(IntType); ++i) for (; i + 1 < sizeof(IntType); ++i)
{ {
u[i] = (std::uint8_t)v; u[i] = (std::uint8_t) v;
v >>= 8; v >>= 8;
} }
} }
u[i] = (std::uint8_t)v; u[i] = (std::uint8_t) v;
stream.write(reinterpret_cast<char*>(u), sizeof(IntType)); stream.write(reinterpret_cast<char*>(u), sizeof(IntType));
} }
} }
// read_little_endian(s, out, N) : read integers in bulk from a little indian stream. // read_little_endian(s, out, N) : read integers in bulk from a little indian stream.
// This reads N integers from stream s and put them in array out. // This reads N integers from stream s and put them in array out.
template <typename IntType> template<typename IntType>
inline void read_little_endian(std::istream& stream, IntType* out, std::size_t count) { inline void read_little_endian(std::istream& stream, IntType* out, std::size_t count) {
if (IsLittleEndian) if (IsLittleEndian)
stream.read(reinterpret_cast<char*>(out), sizeof(IntType) * count); stream.read(reinterpret_cast<char*>(out), sizeof(IntType) * count);
else else
for (std::size_t i = 0; i < count; ++i) for (std::size_t i = 0; i < count; ++i)
out[i] = read_little_endian<IntType>(stream); out[i] = read_little_endian<IntType>(stream);
} }
// write_little_endian(s, values, N) : write integers in bulk to a little indian stream. // write_little_endian(s, values, N) : write integers in bulk to a little indian stream.
// This takes N integers from array values and writes them on stream s. // This takes N integers from array values and writes them on stream s.
template <typename IntType> template<typename IntType>
inline void write_little_endian(std::ostream& stream, const IntType* values, std::size_t count) { inline void write_little_endian(std::ostream& stream, const IntType* values, std::size_t count) {
if (IsLittleEndian) if (IsLittleEndian)
stream.write(reinterpret_cast<const char*>(values), sizeof(IntType) * count); stream.write(reinterpret_cast<const char*>(values), sizeof(IntType) * count);
else else
for (std::size_t i = 0; i < count; ++i) for (std::size_t i = 0; i < count; ++i)
write_little_endian<IntType>(stream, values[i]); write_little_endian<IntType>(stream, values[i]);
} }
// read_leb_128(s, out, N) : read N signed integers from the stream s, putting them in // read_leb_128(s, out, N) : read N signed integers from the stream s, putting them in
// the array out. The stream is assumed to be compressed using the signed LEB128 format. // the array out. The stream is assumed to be compressed using the signed LEB128 format.
// See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme.
template <typename IntType> template<typename IntType>
inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) { inline void read_leb_128(std::istream& stream, IntType* out, std::size_t count) {
// Check the presence of our LEB128 magic string // Check the presence of our LEB128 magic string
char leb128MagicString[Leb128MagicStringSize]; char leb128MagicString[Leb128MagicStringSize];
stream.read(leb128MagicString, Leb128MagicStringSize); stream.read(leb128MagicString, Leb128MagicStringSize);
assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0); assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0);
static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types"); static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types");
const std::uint32_t BUF_SIZE = 4096; const std::uint32_t BUF_SIZE = 4096;
std::uint8_t buf[BUF_SIZE]; std::uint8_t buf[BUF_SIZE];
auto bytes_left = read_little_endian<std::uint32_t>(stream); auto bytes_left = read_little_endian<std::uint32_t>(stream);
std::uint32_t buf_pos = BUF_SIZE; std::uint32_t buf_pos = BUF_SIZE;
for (std::size_t i = 0; i < count; ++i) for (std::size_t i = 0; i < count; ++i)
{ {
IntType result = 0; IntType result = 0;
size_t shift = 0; size_t shift = 0;
do do
{ {
if (buf_pos == BUF_SIZE) if (buf_pos == BUF_SIZE)
{ {
stream.read(reinterpret_cast<char*>(buf), std::min(bytes_left, BUF_SIZE)); stream.read(reinterpret_cast<char*>(buf), std::min(bytes_left, BUF_SIZE));
buf_pos = 0; buf_pos = 0;
} }
std::uint8_t byte = buf[buf_pos++]; std::uint8_t byte = buf[buf_pos++];
--bytes_left; --bytes_left;
result |= (byte & 0x7f) << shift; result |= (byte & 0x7f) << shift;
shift += 7; shift += 7;
if ((byte & 0x80) == 0) if ((byte & 0x80) == 0)
{ {
out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0) ? result out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0)
: result | ~((1 << shift) - 1); ? result
break; : result | ~((1 << shift) - 1);
} break;
} }
while (shift < sizeof(IntType) * 8); } while (shift < sizeof(IntType) * 8);
} }
assert(bytes_left == 0); assert(bytes_left == 0);
} }
// write_leb_128(s, values, N) : write signed integers to a stream with LEB128 compression. // write_leb_128(s, values, N) : write signed integers to a stream with LEB128 compression.
// This takes N integers from array values, compress them with the LEB128 algorithm and // This takes N integers from array values, compress them with the LEB128 algorithm and
// writes the result on the stream s. // writes the result on the stream s.
// See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme. // See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme.
template <typename IntType> template<typename IntType>
inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) { inline void write_leb_128(std::ostream& stream, const IntType* values, std::size_t count) {
// Write our LEB128 magic string // Write our LEB128 magic string
stream.write(Leb128MagicString, Leb128MagicStringSize); stream.write(Leb128MagicString, Leb128MagicStringSize);
static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types"); static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types");
std::uint32_t byte_count = 0; std::uint32_t byte_count = 0;
for (std::size_t i = 0; i < count; ++i) for (std::size_t i = 0; i < count; ++i)
{ {
IntType value = values[i]; IntType value = values[i];
std::uint8_t byte; std::uint8_t byte;
do do
{ {
byte = value & 0x7f; byte = value & 0x7f;
value >>= 7; value >>= 7;
++byte_count; ++byte_count;
} } while ((byte & 0x40) == 0 ? value != 0 : value != -1);
while ((byte & 0x40) == 0 ? value != 0 : value != -1); }
}
write_little_endian(stream, byte_count); write_little_endian(stream, byte_count);
const std::uint32_t BUF_SIZE = 4096; const std::uint32_t BUF_SIZE = 4096;
std::uint8_t buf[BUF_SIZE]; std::uint8_t buf[BUF_SIZE];
std::uint32_t buf_pos = 0; std::uint32_t buf_pos = 0;
auto flush = [&]() { auto flush = [&]() {
if (buf_pos > 0) if (buf_pos > 0)
{ {
stream.write(reinterpret_cast<char*>(buf), buf_pos); stream.write(reinterpret_cast<char*>(buf), buf_pos);
buf_pos = 0; buf_pos = 0;
} }
}; };
auto write = [&](std::uint8_t byte) { auto write = [&](std::uint8_t byte) {
buf[buf_pos++] = byte; buf[buf_pos++] = byte;
if (buf_pos == BUF_SIZE) if (buf_pos == BUF_SIZE)
flush(); flush();
}; };
for (std::size_t i = 0; i < count; ++i) for (std::size_t i = 0; i < count; ++i)
{ {
IntType value = values[i]; IntType value = values[i];
while (true) while (true)
{ {
std::uint8_t byte = value & 0x7f; std::uint8_t byte = value & 0x7f;
value >>= 7; value >>= 7;
if ((byte & 0x40) == 0 ? value == 0 : value == -1) if ((byte & 0x40) == 0 ? value == 0 : value == -1)
{ {
write(byte); write(byte);
break; break;
} }
write(byte | 0x80); write(byte | 0x80);
} }
} }
flush(); flush();
} }
} // namespace Stockfish::Eval::NNUE } // namespace Stockfish::Eval::NNUE
#endif // #ifndef NNUE_COMMON_H_INCLUDED #endif // #ifndef NNUE_COMMON_H_INCLUDED

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -37,27 +37,27 @@ namespace Stockfish {
struct StateInfo { struct StateInfo {
// Copied when making a move // Copied when making a move
Key materialKey; Key materialKey;
Value nonPawnMaterial[COLOR_NB]; Value nonPawnMaterial[COLOR_NB];
int castlingRights; int castlingRights;
int rule50; int rule50;
int pliesFromNull; int pliesFromNull;
Square epSquare; Square epSquare;
// Not copied when making a move (will be recomputed anyhow) // Not copied when making a move (will be recomputed anyhow)
Key key; Key key;
Bitboard checkersBB; Bitboard checkersBB;
StateInfo* previous; StateInfo* previous;
Bitboard blockersForKing[COLOR_NB]; Bitboard blockersForKing[COLOR_NB];
Bitboard pinners[COLOR_NB]; Bitboard pinners[COLOR_NB];
Bitboard checkSquares[PIECE_TYPE_NB]; Bitboard checkSquares[PIECE_TYPE_NB];
Piece capturedPiece; Piece capturedPiece;
int repetition; int repetition;
// Used by NNUE // Used by NNUE
Eval::NNUE::Accumulator accumulator; Eval::NNUE::Accumulator accumulator;
DirtyPiece dirtyPiece; DirtyPiece dirtyPiece;
}; };
@ -75,329 +75,290 @@ using StateListPtr = std::unique_ptr<std::deque<StateInfo>>;
class Thread; class Thread;
class Position { class Position {
public: public:
static void init(); static void init();
Position() = default; Position() = default;
Position(const Position&) = delete; Position(const Position&) = delete;
Position& operator=(const Position&) = delete; Position& operator=(const Position&) = delete;
// FEN string input/output // FEN string input/output
Position& set(const std::string& fenStr, bool isChess960, StateInfo* si, Thread* th); Position& set(const std::string& fenStr, bool isChess960, StateInfo* si, Thread* th);
Position& set(const std::string& code, Color c, StateInfo* si); Position& set(const std::string& code, Color c, StateInfo* si);
std::string fen() const; std::string fen() const;
// Position representation // Position representation
Bitboard pieces(PieceType pt = ALL_PIECES) const; Bitboard pieces(PieceType pt = ALL_PIECES) const;
template<typename ...PieceTypes> Bitboard pieces(PieceType pt, PieceTypes... pts) const; template<typename... PieceTypes>
Bitboard pieces(Color c) const; Bitboard pieces(PieceType pt, PieceTypes... pts) const;
template<typename ...PieceTypes> Bitboard pieces(Color c, PieceTypes... pts) const; Bitboard pieces(Color c) const;
Piece piece_on(Square s) const; template<typename... PieceTypes>
Square ep_square() const; Bitboard pieces(Color c, PieceTypes... pts) const;
bool empty(Square s) const; Piece piece_on(Square s) const;
template<PieceType Pt> int count(Color c) const; Square ep_square() const;
template<PieceType Pt> int count() const; bool empty(Square s) const;
template<PieceType Pt> Square square(Color c) const; template<PieceType Pt>
int count(Color c) const;
template<PieceType Pt>
int count() const;
template<PieceType Pt>
Square square(Color c) const;
// Castling // Castling
CastlingRights castling_rights(Color c) const; CastlingRights castling_rights(Color c) const;
bool can_castle(CastlingRights cr) const; bool can_castle(CastlingRights cr) const;
bool castling_impeded(CastlingRights cr) const; bool castling_impeded(CastlingRights cr) const;
Square castling_rook_square(CastlingRights cr) const; Square castling_rook_square(CastlingRights cr) const;
// Checking // Checking
Bitboard checkers() const; Bitboard checkers() const;
Bitboard blockers_for_king(Color c) const; Bitboard blockers_for_king(Color c) const;
Bitboard check_squares(PieceType pt) const; Bitboard check_squares(PieceType pt) const;
Bitboard pinners(Color c) const; Bitboard pinners(Color c) const;
// Attacks to/from a given square // Attacks to/from a given square
Bitboard attackers_to(Square s) const; Bitboard attackers_to(Square s) const;
Bitboard attackers_to(Square s, Bitboard occupied) const; Bitboard attackers_to(Square s, Bitboard occupied) const;
void update_slider_blockers(Color c) const; void update_slider_blockers(Color c) const;
template<PieceType Pt> Bitboard attacks_by(Color c) const; template<PieceType Pt>
Bitboard attacks_by(Color c) const;
// Properties of moves // Properties of moves
bool legal(Move m) const; bool legal(Move m) const;
bool pseudo_legal(const Move m) const; bool pseudo_legal(const Move m) const;
bool capture(Move m) const; bool capture(Move m) const;
bool capture_stage(Move m) const; bool capture_stage(Move m) const;
bool gives_check(Move m) const; bool gives_check(Move m) const;
Piece moved_piece(Move m) const; Piece moved_piece(Move m) const;
Piece captured_piece() const; Piece captured_piece() const;
// Doing and undoing moves // Doing and undoing moves
void do_move(Move m, StateInfo& newSt); void do_move(Move m, StateInfo& newSt);
void do_move(Move m, StateInfo& newSt, bool givesCheck); void do_move(Move m, StateInfo& newSt, bool givesCheck);
void undo_move(Move m); void undo_move(Move m);
void do_null_move(StateInfo& newSt); void do_null_move(StateInfo& newSt);
void undo_null_move(); void undo_null_move();
// Static Exchange Evaluation // Static Exchange Evaluation
bool see_ge(Move m, Value threshold = VALUE_ZERO) const; bool see_ge(Move m, Value threshold = VALUE_ZERO) const;
// Accessing hash keys // Accessing hash keys
Key key() const; Key key() const;
Key key_after(Move m) const; Key key_after(Move m) const;
Key material_key() const; Key material_key() const;
// Other properties of the position // Other properties of the position
Color side_to_move() const; Color side_to_move() const;
int game_ply() const; int game_ply() const;
bool is_chess960() const; bool is_chess960() const;
Thread* this_thread() const; Thread* this_thread() const;
bool is_draw(int ply) const; bool is_draw(int ply) const;
bool has_game_cycle(int ply) const; bool has_game_cycle(int ply) const;
bool has_repeated() const; bool has_repeated() const;
int rule50_count() const; int rule50_count() const;
Value non_pawn_material(Color c) const; Value non_pawn_material(Color c) const;
Value non_pawn_material() const; Value non_pawn_material() const;
// Position consistency check, for debugging // Position consistency check, for debugging
bool pos_is_ok() const; bool pos_is_ok() const;
void flip(); void flip();
// Used by NNUE // Used by NNUE
StateInfo* state() const; StateInfo* state() const;
void put_piece(Piece pc, Square s); void put_piece(Piece pc, Square s);
void remove_piece(Square s); void remove_piece(Square s);
private: private:
// Initialization helpers (used while setting up a position) // Initialization helpers (used while setting up a position)
void set_castling_right(Color c, Square rfrom); void set_castling_right(Color c, Square rfrom);
void set_state() const; void set_state() const;
void set_check_info() const; void set_check_info() const;
// Other helpers // Other helpers
void move_piece(Square from, Square to); void move_piece(Square from, Square to);
template<bool Do> template<bool Do>
void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto); void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto);
template<bool AfterMove> template<bool AfterMove>
Key adjust_key50(Key k) const; Key adjust_key50(Key k) const;
// Data members // Data members
Piece board[SQUARE_NB]; Piece board[SQUARE_NB];
Bitboard byTypeBB[PIECE_TYPE_NB]; Bitboard byTypeBB[PIECE_TYPE_NB];
Bitboard byColorBB[COLOR_NB]; Bitboard byColorBB[COLOR_NB];
int pieceCount[PIECE_NB]; int pieceCount[PIECE_NB];
int castlingRightsMask[SQUARE_NB]; int castlingRightsMask[SQUARE_NB];
Square castlingRookSquare[CASTLING_RIGHT_NB]; Square castlingRookSquare[CASTLING_RIGHT_NB];
Bitboard castlingPath[CASTLING_RIGHT_NB]; Bitboard castlingPath[CASTLING_RIGHT_NB];
Thread* thisThread; Thread* thisThread;
StateInfo* st; StateInfo* st;
int gamePly; int gamePly;
Color sideToMove; Color sideToMove;
bool chess960; bool chess960;
}; };
std::ostream& operator<<(std::ostream& os, const Position& pos); std::ostream& operator<<(std::ostream& os, const Position& pos);
inline Color Position::side_to_move() const { inline Color Position::side_to_move() const { return sideToMove; }
return sideToMove;
}
inline Piece Position::piece_on(Square s) const { inline Piece Position::piece_on(Square s) const {
assert(is_ok(s)); assert(is_ok(s));
return board[s]; return board[s];
} }
inline bool Position::empty(Square s) const { inline bool Position::empty(Square s) const { return piece_on(s) == NO_PIECE; }
return piece_on(s) == NO_PIECE;
}
inline Piece Position::moved_piece(Move m) const { inline Piece Position::moved_piece(Move m) const { return piece_on(from_sq(m)); }
return piece_on(from_sq(m));
}
inline Bitboard Position::pieces(PieceType pt) const { inline Bitboard Position::pieces(PieceType pt) const { return byTypeBB[pt]; }
return byTypeBB[pt];
}
template<typename ...PieceTypes> template<typename... PieceTypes>
inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const { inline Bitboard Position::pieces(PieceType pt, PieceTypes... pts) const {
return pieces(pt) | pieces(pts...); return pieces(pt) | pieces(pts...);
} }
inline Bitboard Position::pieces(Color c) const { inline Bitboard Position::pieces(Color c) const { return byColorBB[c]; }
return byColorBB[c];
}
template<typename ...PieceTypes> template<typename... PieceTypes>
inline Bitboard Position::pieces(Color c, PieceTypes... pts) const { inline Bitboard Position::pieces(Color c, PieceTypes... pts) const {
return pieces(c) & pieces(pts...); return pieces(c) & pieces(pts...);
} }
template<PieceType Pt> inline int Position::count(Color c) const { template<PieceType Pt>
return pieceCount[make_piece(c, Pt)]; inline int Position::count(Color c) const {
return pieceCount[make_piece(c, Pt)];
} }
template<PieceType Pt> inline int Position::count() const { template<PieceType Pt>
return count<Pt>(WHITE) + count<Pt>(BLACK); inline int Position::count() const {
return count<Pt>(WHITE) + count<Pt>(BLACK);
} }
template<PieceType Pt> inline Square Position::square(Color c) const { template<PieceType Pt>
assert(count<Pt>(c) == 1); inline Square Position::square(Color c) const {
return lsb(pieces(c, Pt)); assert(count<Pt>(c) == 1);
return lsb(pieces(c, Pt));
} }
inline Square Position::ep_square() const { inline Square Position::ep_square() const { return st->epSquare; }
return st->epSquare;
}
inline bool Position::can_castle(CastlingRights cr) const { inline bool Position::can_castle(CastlingRights cr) const { return st->castlingRights & cr; }
return st->castlingRights & cr;
}
inline CastlingRights Position::castling_rights(Color c) const { inline CastlingRights Position::castling_rights(Color c) const {
return c & CastlingRights(st->castlingRights); return c & CastlingRights(st->castlingRights);
} }
inline bool Position::castling_impeded(CastlingRights cr) const { inline bool Position::castling_impeded(CastlingRights cr) const {
assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO);
return pieces() & castlingPath[cr]; return pieces() & castlingPath[cr];
} }
inline Square Position::castling_rook_square(CastlingRights cr) const { inline Square Position::castling_rook_square(CastlingRights cr) const {
assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO); assert(cr == WHITE_OO || cr == WHITE_OOO || cr == BLACK_OO || cr == BLACK_OOO);
return castlingRookSquare[cr]; return castlingRookSquare[cr];
} }
inline Bitboard Position::attackers_to(Square s) const { inline Bitboard Position::attackers_to(Square s) const { return attackers_to(s, pieces()); }
return attackers_to(s, pieces());
}
template<PieceType Pt> template<PieceType Pt>
inline Bitboard Position::attacks_by(Color c) const { inline Bitboard Position::attacks_by(Color c) const {
if constexpr (Pt == PAWN) if constexpr (Pt == PAWN)
return c == WHITE ? pawn_attacks_bb<WHITE>(pieces(WHITE, PAWN)) return c == WHITE ? pawn_attacks_bb<WHITE>(pieces(WHITE, PAWN))
: pawn_attacks_bb<BLACK>(pieces(BLACK, PAWN)); : pawn_attacks_bb<BLACK>(pieces(BLACK, PAWN));
else else
{ {
Bitboard threats = 0; Bitboard threats = 0;
Bitboard attackers = pieces(c, Pt); Bitboard attackers = pieces(c, Pt);
while (attackers) while (attackers)
threats |= attacks_bb<Pt>(pop_lsb(attackers), pieces()); threats |= attacks_bb<Pt>(pop_lsb(attackers), pieces());
return threats; return threats;
} }
} }
inline Bitboard Position::checkers() const { inline Bitboard Position::checkers() const { return st->checkersBB; }
return st->checkersBB;
}
inline Bitboard Position::blockers_for_king(Color c) const { inline Bitboard Position::blockers_for_king(Color c) const { return st->blockersForKing[c]; }
return st->blockersForKing[c];
}
inline Bitboard Position::pinners(Color c) const { inline Bitboard Position::pinners(Color c) const { return st->pinners[c]; }
return st->pinners[c];
}
inline Bitboard Position::check_squares(PieceType pt) const { inline Bitboard Position::check_squares(PieceType pt) const { return st->checkSquares[pt]; }
return st->checkSquares[pt];
}
inline Key Position::key() const { inline Key Position::key() const { return adjust_key50<false>(st->key); }
return adjust_key50<false>(st->key);
}
template<bool AfterMove> template<bool AfterMove>
inline Key Position::adjust_key50(Key k) const inline Key Position::adjust_key50(Key k) const {
{ return st->rule50 < 14 - AfterMove ? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8);
return st->rule50 < 14 - AfterMove
? k : k ^ make_key((st->rule50 - (14 - AfterMove)) / 8);
} }
inline Key Position::material_key() const { inline Key Position::material_key() const { return st->materialKey; }
return st->materialKey;
}
inline Value Position::non_pawn_material(Color c) const { inline Value Position::non_pawn_material(Color c) const { return st->nonPawnMaterial[c]; }
return st->nonPawnMaterial[c];
}
inline Value Position::non_pawn_material() const { inline Value Position::non_pawn_material() const {
return non_pawn_material(WHITE) + non_pawn_material(BLACK); return non_pawn_material(WHITE) + non_pawn_material(BLACK);
} }
inline int Position::game_ply() const { inline int Position::game_ply() const { return gamePly; }
return gamePly;
}
inline int Position::rule50_count() const { inline int Position::rule50_count() const { return st->rule50; }
return st->rule50;
}
inline bool Position::is_chess960() const { inline bool Position::is_chess960() const { return chess960; }
return chess960;
}
inline bool Position::capture(Move m) const { inline bool Position::capture(Move m) const {
assert(is_ok(m)); assert(is_ok(m));
return (!empty(to_sq(m)) && type_of(m) != CASTLING) return (!empty(to_sq(m)) && type_of(m) != CASTLING) || type_of(m) == EN_PASSANT;
|| type_of(m) == EN_PASSANT;
} }
// Returns true if a move is generated from the capture stage, having also // Returns true if a move is generated from the capture stage, having also
// queen promotions covered, i.e. consistency with the capture stage move generation // queen promotions covered, i.e. consistency with the capture stage move generation
// is needed to avoid the generation of duplicate moves. // is needed to avoid the generation of duplicate moves.
inline bool Position::capture_stage(Move m) const { inline bool Position::capture_stage(Move m) const {
assert(is_ok(m)); assert(is_ok(m));
return capture(m) || promotion_type(m) == QUEEN; return capture(m) || promotion_type(m) == QUEEN;
} }
inline Piece Position::captured_piece() const { inline Piece Position::captured_piece() const { return st->capturedPiece; }
return st->capturedPiece;
}
inline Thread* Position::this_thread() const { inline Thread* Position::this_thread() const { return thisThread; }
return thisThread;
}
inline void Position::put_piece(Piece pc, Square s) { inline void Position::put_piece(Piece pc, Square s) {
board[s] = pc; board[s] = pc;
byTypeBB[ALL_PIECES] |= byTypeBB[type_of(pc)] |= s; byTypeBB[ALL_PIECES] |= byTypeBB[type_of(pc)] |= s;
byColorBB[color_of(pc)] |= s; byColorBB[color_of(pc)] |= s;
pieceCount[pc]++; pieceCount[pc]++;
pieceCount[make_piece(color_of(pc), ALL_PIECES)]++; pieceCount[make_piece(color_of(pc), ALL_PIECES)]++;
} }
inline void Position::remove_piece(Square s) { inline void Position::remove_piece(Square s) {
Piece pc = board[s]; Piece pc = board[s];
byTypeBB[ALL_PIECES] ^= s; byTypeBB[ALL_PIECES] ^= s;
byTypeBB[type_of(pc)] ^= s; byTypeBB[type_of(pc)] ^= s;
byColorBB[color_of(pc)] ^= s; byColorBB[color_of(pc)] ^= s;
board[s] = NO_PIECE; board[s] = NO_PIECE;
pieceCount[pc]--; pieceCount[pc]--;
pieceCount[make_piece(color_of(pc), ALL_PIECES)]--; pieceCount[make_piece(color_of(pc), ALL_PIECES)]--;
} }
inline void Position::move_piece(Square from, Square to) { inline void Position::move_piece(Square from, Square to) {
Piece pc = board[from]; Piece pc = board[from];
Bitboard fromTo = from | to; Bitboard fromTo = from | to;
byTypeBB[ALL_PIECES] ^= fromTo; byTypeBB[ALL_PIECES] ^= fromTo;
byTypeBB[type_of(pc)] ^= fromTo; byTypeBB[type_of(pc)] ^= fromTo;
byColorBB[color_of(pc)] ^= fromTo; byColorBB[color_of(pc)] ^= fromTo;
board[from] = NO_PIECE; board[from] = NO_PIECE;
board[to] = pc; board[to] = pc;
} }
inline void Position::do_move(Move m, StateInfo& newSt) { inline void Position::do_move(Move m, StateInfo& newSt) { do_move(m, newSt, gives_check(m)); }
do_move(m, newSt, gives_check(m));
}
inline StateInfo* Position::state() const { inline StateInfo* Position::state() const { return st; }
return st; } // namespace Stockfish
}
} // namespace Stockfish #endif // #ifndef POSITION_H_INCLUDED
#endif // #ifndef POSITION_H_INCLUDED

File diff suppressed because it is too large Load diff

View file

@ -38,20 +38,20 @@ namespace Search {
// its own array of Stack objects, indexed by the current ply. // its own array of Stack objects, indexed by the current ply.
struct Stack { struct Stack {
Move* pv; Move* pv;
PieceToHistory* continuationHistory; PieceToHistory* continuationHistory;
int ply; int ply;
Move currentMove; Move currentMove;
Move excludedMove; Move excludedMove;
Move killers[2]; Move killers[2];
Value staticEval; Value staticEval;
int statScore; int statScore;
int moveCount; int moveCount;
bool inCheck; bool inCheck;
bool ttPv; bool ttPv;
bool ttHit; bool ttHit;
int doubleExtensions; int doubleExtensions;
int cutoffCnt; int cutoffCnt;
}; };
@ -61,24 +61,24 @@ struct Stack {
struct RootMove { struct RootMove {
explicit RootMove(Move m) : pv(1, m) {} explicit RootMove(Move m) :
bool extract_ponder_from_tt(Position& pos); pv(1, m) {}
bool operator==(const Move& m) const { return pv[0] == m; } bool extract_ponder_from_tt(Position& pos);
bool operator<(const RootMove& m) const { // Sort in descending order bool operator==(const Move& m) const { return pv[0] == m; }
return m.score != score ? m.score < score bool operator<(const RootMove& m) const { // Sort in descending order
: m.previousScore < previousScore; return m.score != score ? m.score < score : m.previousScore < previousScore;
} }
Value score = -VALUE_INFINITE; Value score = -VALUE_INFINITE;
Value previousScore = -VALUE_INFINITE; Value previousScore = -VALUE_INFINITE;
Value averageScore = -VALUE_INFINITE; Value averageScore = -VALUE_INFINITE;
Value uciScore = -VALUE_INFINITE; Value uciScore = -VALUE_INFINITE;
bool scoreLowerbound = false; bool scoreLowerbound = false;
bool scoreUpperbound = false; bool scoreUpperbound = false;
int selDepth = 0; int selDepth = 0;
int tbRank = 0; int tbRank = 0;
Value tbScore; Value tbScore;
std::vector<Move> pv; std::vector<Move> pv;
}; };
using RootMoves = std::vector<RootMove>; using RootMoves = std::vector<RootMove>;
@ -89,20 +89,18 @@ using RootMoves = std::vector<RootMove>;
struct LimitsType { struct LimitsType {
LimitsType() { // Init explicitly due to broken value-initialization of non POD in MSVC LimitsType() { // Init explicitly due to broken value-initialization of non POD in MSVC
time[WHITE] = time[BLACK] = inc[WHITE] = inc[BLACK] = npmsec = movetime = TimePoint(0); time[WHITE] = time[BLACK] = inc[WHITE] = inc[BLACK] = npmsec = movetime = TimePoint(0);
movestogo = depth = mate = perft = infinite = 0; movestogo = depth = mate = perft = infinite = 0;
nodes = 0; nodes = 0;
} }
bool use_time_management() const { bool use_time_management() const { return time[WHITE] || time[BLACK]; }
return time[WHITE] || time[BLACK];
}
std::vector<Move> searchmoves; std::vector<Move> searchmoves;
TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime; TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime;
int movestogo, depth, mate, perft, infinite; int movestogo, depth, mate, perft, infinite;
int64_t nodes; int64_t nodes;
}; };
extern LimitsType Limits; extern LimitsType Limits;
@ -110,8 +108,8 @@ extern LimitsType Limits;
void init(); void init();
void clear(); void clear();
} // namespace Search } // namespace Search
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef SEARCH_H_INCLUDED #endif // #ifndef SEARCH_H_INCLUDED

File diff suppressed because it is too large Load diff

View file

@ -30,30 +30,30 @@ class Position;
namespace Stockfish::Tablebases { namespace Stockfish::Tablebases {
enum WDLScore { enum WDLScore {
WDLLoss = -2, // Loss WDLLoss = -2, // Loss
WDLBlessedLoss = -1, // Loss, but draw under 50-move rule WDLBlessedLoss = -1, // Loss, but draw under 50-move rule
WDLDraw = 0, // Draw WDLDraw = 0, // Draw
WDLCursedWin = 1, // Win, but draw under 50-move rule WDLCursedWin = 1, // Win, but draw under 50-move rule
WDLWin = 2, // Win WDLWin = 2, // Win
}; };
// Possible states after a probing operation // Possible states after a probing operation
enum ProbeState { enum ProbeState {
FAIL = 0, // Probe failed (missing file table) FAIL = 0, // Probe failed (missing file table)
OK = 1, // Probe successful OK = 1, // Probe successful
CHANGE_STM = -1, // DTZ should check the other side CHANGE_STM = -1, // DTZ should check the other side
ZEROING_BEST_MOVE = 2 // Best move zeroes DTZ (capture or pawn move) ZEROING_BEST_MOVE = 2 // Best move zeroes DTZ (capture or pawn move)
}; };
extern int MaxCardinality; extern int MaxCardinality;
void init(const std::string& paths); void init(const std::string& paths);
WDLScore probe_wdl(Position& pos, ProbeState* result); WDLScore probe_wdl(Position& pos, ProbeState* result);
int probe_dtz(Position& pos, ProbeState* result); int probe_dtz(Position& pos, ProbeState* result);
bool root_probe(Position& pos, Search::RootMoves& rootMoves); bool root_probe(Position& pos, Search::RootMoves& rootMoves);
bool root_probe_wdl(Position& pos, Search::RootMoves& rootMoves); bool root_probe_wdl(Position& pos, Search::RootMoves& rootMoves);
void rank_root_moves(Position& pos, Search::RootMoves& rootMoves); void rank_root_moves(Position& pos, Search::RootMoves& rootMoves);
} // namespace Stockfish::Tablebases } // namespace Stockfish::Tablebases
#endif #endif

View file

@ -37,15 +37,17 @@
namespace Stockfish { namespace Stockfish {
ThreadPool Threads; // Global object ThreadPool Threads; // Global object
// Thread constructor launches the thread and waits until it goes to sleep // Thread constructor launches the thread and waits until it goes to sleep
// in idle_loop(). Note that 'searching' and 'exit' should be already set. // in idle_loop(). Note that 'searching' and 'exit' should be already set.
Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) { Thread::Thread(size_t n) :
idx(n),
stdThread(&Thread::idle_loop, this) {
wait_for_search_finished(); wait_for_search_finished();
} }
@ -54,11 +56,11 @@ Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) {
Thread::~Thread() { Thread::~Thread() {
assert(!searching); assert(!searching);
exit = true; exit = true;
start_searching(); start_searching();
stdThread.join(); stdThread.join();
} }
@ -66,25 +68,25 @@ Thread::~Thread() {
void Thread::clear() { void Thread::clear() {
counterMoves.fill(MOVE_NONE); counterMoves.fill(MOVE_NONE);
mainHistory.fill(0); mainHistory.fill(0);
captureHistory.fill(0); captureHistory.fill(0);
for (bool inCheck : { false, true }) for (bool inCheck : {false, true})
for (StatsType c : { NoCaptures, Captures }) for (StatsType c : {NoCaptures, Captures})
for (auto& to : continuationHistory[inCheck][c]) for (auto& to : continuationHistory[inCheck][c])
for (auto& h : to) for (auto& h : to)
h->fill(-71); h->fill(-71);
} }
// Thread::start_searching() wakes up the thread that will start the search // Thread::start_searching() wakes up the thread that will start the search
void Thread::start_searching() { void Thread::start_searching() {
mutex.lock(); mutex.lock();
searching = true; searching = true;
mutex.unlock(); // Unlock before notifying saves a few CPU-cycles mutex.unlock(); // Unlock before notifying saves a few CPU-cycles
cv.notify_one(); // Wake up the thread in idle_loop() cv.notify_one(); // Wake up the thread in idle_loop()
} }
@ -93,8 +95,8 @@ void Thread::start_searching() {
void Thread::wait_for_search_finished() { void Thread::wait_for_search_finished() {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [&]{ return !searching; }); cv.wait(lk, [&] { return !searching; });
} }
@ -103,28 +105,28 @@ void Thread::wait_for_search_finished() {
void Thread::idle_loop() { void Thread::idle_loop() {
// If OS already scheduled us on a different group than 0 then don't overwrite // If OS already scheduled us on a different group than 0 then don't overwrite
// the choice, eventually we are one of many one-threaded processes running on // the choice, eventually we are one of many one-threaded processes running on
// some Windows NUMA hardware, for instance in fishtest. To make it simple, // some Windows NUMA hardware, for instance in fishtest. To make it simple,
// just check if running threads are below a threshold, in this case, all this // just check if running threads are below a threshold, in this case, all this
// NUMA machinery is not needed. // NUMA machinery is not needed.
if (Options["Threads"] > 8) if (Options["Threads"] > 8)
WinProcGroup::bindThisThread(idx); WinProcGroup::bindThisThread(idx);
while (true) while (true)
{ {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
searching = false; searching = false;
cv.notify_one(); // Wake up anyone waiting for search finished cv.notify_one(); // Wake up anyone waiting for search finished
cv.wait(lk, [&]{ return searching; }); cv.wait(lk, [&] { return searching; });
if (exit) if (exit)
return; return;
lk.unlock(); lk.unlock();
search(); search();
} }
} }
// ThreadPool::set() creates/destroys threads to match the requested number. // ThreadPool::set() creates/destroys threads to match the requested number.
@ -133,28 +135,28 @@ void Thread::idle_loop() {
void ThreadPool::set(size_t requested) { void ThreadPool::set(size_t requested) {
if (threads.size() > 0) // destroy any existing thread(s) if (threads.size() > 0) // destroy any existing thread(s)
{ {
main()->wait_for_search_finished(); main()->wait_for_search_finished();
while (threads.size() > 0) while (threads.size() > 0)
delete threads.back(), threads.pop_back(); delete threads.back(), threads.pop_back();
} }
if (requested > 0) // create new thread(s) if (requested > 0) // create new thread(s)
{ {
threads.push_back(new MainThread(0)); threads.push_back(new MainThread(0));
while (threads.size() < requested) while (threads.size() < requested)
threads.push_back(new Thread(threads.size())); threads.push_back(new Thread(threads.size()));
clear(); clear();
// Reallocate the hash with the new threadpool size // Reallocate the hash with the new threadpool size
TT.resize(size_t(Options["Hash"])); TT.resize(size_t(Options["Hash"]));
// Init thread number dependent search params. // Init thread number dependent search params.
Search::init(); Search::init();
} }
} }
@ -162,77 +164,79 @@ void ThreadPool::set(size_t requested) {
void ThreadPool::clear() { void ThreadPool::clear() {
for (Thread* th : threads) for (Thread* th : threads)
th->clear(); th->clear();
main()->callsCnt = 0; main()->callsCnt = 0;
main()->bestPreviousScore = VALUE_INFINITE; main()->bestPreviousScore = VALUE_INFINITE;
main()->bestPreviousAverageScore = VALUE_INFINITE; main()->bestPreviousAverageScore = VALUE_INFINITE;
main()->previousTimeReduction = 1.0; main()->previousTimeReduction = 1.0;
} }
// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and // ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and
// returns immediately. Main thread will wake up other threads and start the search. // returns immediately. Main thread will wake up other threads and start the search.
void ThreadPool::start_thinking(Position& pos, StateListPtr& states, void ThreadPool::start_thinking(Position& pos,
const Search::LimitsType& limits, bool ponderMode) { StateListPtr& states,
const Search::LimitsType& limits,
bool ponderMode) {
main()->wait_for_search_finished(); main()->wait_for_search_finished();
main()->stopOnPonderhit = stop = false; main()->stopOnPonderhit = stop = false;
increaseDepth = true; increaseDepth = true;
main()->ponder = ponderMode; main()->ponder = ponderMode;
Search::Limits = limits; Search::Limits = limits;
Search::RootMoves rootMoves; Search::RootMoves rootMoves;
for (const auto& m : MoveList<LEGAL>(pos)) for (const auto& m : MoveList<LEGAL>(pos))
if ( limits.searchmoves.empty() if (limits.searchmoves.empty()
|| std::count(limits.searchmoves.begin(), limits.searchmoves.end(), m)) || std::count(limits.searchmoves.begin(), limits.searchmoves.end(), m))
rootMoves.emplace_back(m); rootMoves.emplace_back(m);
if (!rootMoves.empty()) if (!rootMoves.empty())
Tablebases::rank_root_moves(pos, rootMoves); Tablebases::rank_root_moves(pos, rootMoves);
// After ownership transfer 'states' becomes empty, so if we stop the search // After ownership transfer 'states' becomes empty, so if we stop the search
// and call 'go' again without setting a new position states.get() == nullptr. // and call 'go' again without setting a new position states.get() == nullptr.
assert(states.get() || setupStates.get()); assert(states.get() || setupStates.get());
if (states.get()) if (states.get())
setupStates = std::move(states); // Ownership transfer, states is now empty setupStates = std::move(states); // Ownership transfer, states is now empty
// We use Position::set() to set root position across threads. But there are // We use Position::set() to set root position across threads. But there are
// some StateInfo fields (previous, pliesFromNull, capturedPiece) that cannot // some StateInfo fields (previous, pliesFromNull, capturedPiece) that cannot
// be deduced from a fen string, so set() clears them and they are set from // be deduced from a fen string, so set() clears them and they are set from
// setupStates->back() later. The rootState is per thread, earlier states are shared // setupStates->back() later. The rootState is per thread, earlier states are shared
// since they are read-only. // since they are read-only.
for (Thread* th : threads) for (Thread* th : threads)
{ {
th->nodes = th->tbHits = th->nmpMinPly = th->bestMoveChanges = 0; th->nodes = th->tbHits = th->nmpMinPly = th->bestMoveChanges = 0;
th->rootDepth = th->completedDepth = 0; th->rootDepth = th->completedDepth = 0;
th->rootMoves = rootMoves; th->rootMoves = rootMoves;
th->rootPos.set(pos.fen(), pos.is_chess960(), &th->rootState, th); th->rootPos.set(pos.fen(), pos.is_chess960(), &th->rootState, th);
th->rootState = setupStates->back(); th->rootState = setupStates->back();
th->rootSimpleEval = Eval::simple_eval(pos, pos.side_to_move()); th->rootSimpleEval = Eval::simple_eval(pos, pos.side_to_move());
} }
main()->start_searching(); main()->start_searching();
} }
Thread* ThreadPool::get_best_thread() const { Thread* ThreadPool::get_best_thread() const {
Thread* bestThread = threads.front(); Thread* bestThread = threads.front();
std::map<Move, int64_t> votes; std::map<Move, int64_t> votes;
Value minScore = VALUE_NONE; Value minScore = VALUE_NONE;
// Find the minimum score of all threads // Find the minimum score of all threads
for (Thread* th: threads) for (Thread* th : threads)
minScore = std::min(minScore, th->rootMoves[0].score); minScore = std::min(minScore, th->rootMoves[0].score);
// Vote according to score and depth, and select the best thread // Vote according to score and depth, and select the best thread
auto thread_value = [minScore](Thread* th) { auto thread_value = [minScore](Thread* th) {
return (th->rootMoves[0].score - minScore + 14) * int(th->completedDepth); return (th->rootMoves[0].score - minScore + 14) * int(th->completedDepth);
}; };
for (Thread* th : threads) for (Thread* th : threads)
votes[th->rootMoves[0].pv[0]] += thread_value(th); votes[th->rootMoves[0].pv[0]] += thread_value(th);
@ -244,12 +248,13 @@ Thread* ThreadPool::get_best_thread() const {
if (th->rootMoves[0].score > bestThread->rootMoves[0].score) if (th->rootMoves[0].score > bestThread->rootMoves[0].score)
bestThread = th; bestThread = th;
} }
else if ( th->rootMoves[0].score >= VALUE_TB_WIN_IN_MAX_PLY else if (th->rootMoves[0].score >= VALUE_TB_WIN_IN_MAX_PLY
|| ( th->rootMoves[0].score > VALUE_TB_LOSS_IN_MAX_PLY || (th->rootMoves[0].score > VALUE_TB_LOSS_IN_MAX_PLY
&& ( votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]] && (votes[th->rootMoves[0].pv[0]] > votes[bestThread->rootMoves[0].pv[0]]
|| ( votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]] || (votes[th->rootMoves[0].pv[0]] == votes[bestThread->rootMoves[0].pv[0]]
&& thread_value(th) * int(th->rootMoves[0].pv.size() > 2) && thread_value(th) * int(th->rootMoves[0].pv.size() > 2)
> thread_value(bestThread) * int(bestThread->rootMoves[0].pv.size() > 2))))) > thread_value(bestThread)
* int(bestThread->rootMoves[0].pv.size() > 2)))))
bestThread = th; bestThread = th;
return bestThread; return bestThread;
@ -275,4 +280,4 @@ void ThreadPool::wait_for_search_finished() const {
th->wait_for_search_finished(); th->wait_for_search_finished();
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -41,56 +41,56 @@ namespace Stockfish {
class Thread { class Thread {
std::mutex mutex; std::mutex mutex;
std::condition_variable cv; std::condition_variable cv;
size_t idx; size_t idx;
bool exit = false, searching = true; // Set before starting std::thread bool exit = false, searching = true; // Set before starting std::thread
NativeThread stdThread; NativeThread stdThread;
public: public:
explicit Thread(size_t); explicit Thread(size_t);
virtual ~Thread(); virtual ~Thread();
virtual void search(); virtual void search();
void clear(); void clear();
void idle_loop(); void idle_loop();
void start_searching(); void start_searching();
void wait_for_search_finished(); void wait_for_search_finished();
size_t id() const { return idx; } size_t id() const { return idx; }
size_t pvIdx, pvLast; size_t pvIdx, pvLast;
std::atomic<uint64_t> nodes, tbHits, bestMoveChanges; std::atomic<uint64_t> nodes, tbHits, bestMoveChanges;
int selDepth, nmpMinPly; int selDepth, nmpMinPly;
Value bestValue, optimism[COLOR_NB]; Value bestValue, optimism[COLOR_NB];
Position rootPos; Position rootPos;
StateInfo rootState; StateInfo rootState;
Search::RootMoves rootMoves; Search::RootMoves rootMoves;
Depth rootDepth, completedDepth; Depth rootDepth, completedDepth;
Value rootDelta; Value rootDelta;
Value rootSimpleEval; Value rootSimpleEval;
CounterMoveHistory counterMoves; CounterMoveHistory counterMoves;
ButterflyHistory mainHistory; ButterflyHistory mainHistory;
CapturePieceToHistory captureHistory; CapturePieceToHistory captureHistory;
ContinuationHistory continuationHistory[2][2]; ContinuationHistory continuationHistory[2][2];
}; };
// MainThread is a derived class specific for main thread // MainThread is a derived class specific for main thread
struct MainThread : public Thread { struct MainThread: public Thread {
using Thread::Thread; using Thread::Thread;
void search() override; void search() override;
void check_time(); void check_time();
double previousTimeReduction; double previousTimeReduction;
Value bestPreviousScore; Value bestPreviousScore;
Value bestPreviousAverageScore; Value bestPreviousAverageScore;
Value iterValue[4]; Value iterValue[4];
int callsCnt; int callsCnt;
bool stopOnPonderhit; bool stopOnPonderhit;
std::atomic_bool ponder; std::atomic_bool ponder;
}; };
@ -100,41 +100,41 @@ struct MainThread : public Thread {
struct ThreadPool { struct ThreadPool {
void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false); void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false);
void clear(); void clear();
void set(size_t); void set(size_t);
MainThread* main() const { return static_cast<MainThread*>(threads.front()); } MainThread* main() const { return static_cast<MainThread*>(threads.front()); }
uint64_t nodes_searched() const { return accumulate(&Thread::nodes); } uint64_t nodes_searched() const { return accumulate(&Thread::nodes); }
uint64_t tb_hits() const { return accumulate(&Thread::tbHits); } uint64_t tb_hits() const { return accumulate(&Thread::tbHits); }
Thread* get_best_thread() const; Thread* get_best_thread() const;
void start_searching(); void start_searching();
void wait_for_search_finished() const; void wait_for_search_finished() const;
std::atomic_bool stop, increaseDepth; std::atomic_bool stop, increaseDepth;
auto cbegin() const noexcept { return threads.cbegin(); } auto cbegin() const noexcept { return threads.cbegin(); }
auto begin() noexcept { return threads.begin(); } auto begin() noexcept { return threads.begin(); }
auto end() noexcept { return threads.end(); } auto end() noexcept { return threads.end(); }
auto cend() const noexcept { return threads.cend(); } auto cend() const noexcept { return threads.cend(); }
auto size() const noexcept { return threads.size(); } auto size() const noexcept { return threads.size(); }
auto empty() const noexcept { return threads.empty(); } auto empty() const noexcept { return threads.empty(); }
private: private:
StateListPtr setupStates; StateListPtr setupStates;
std::vector<Thread*> threads; std::vector<Thread*> threads;
uint64_t accumulate(std::atomic<uint64_t> Thread::* member) const { uint64_t accumulate(std::atomic<uint64_t> Thread::*member) const {
uint64_t sum = 0; uint64_t sum = 0;
for (Thread* th : threads) for (Thread* th : threads)
sum += (th->*member).load(std::memory_order_relaxed); sum += (th->*member).load(std::memory_order_relaxed);
return sum; return sum;
} }
}; };
extern ThreadPool Threads; extern ThreadPool Threads;
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef THREAD_H_INCLUDED #endif // #ifndef THREAD_H_INCLUDED

View file

@ -29,46 +29,45 @@
#if defined(__APPLE__) || defined(__MINGW32__) || defined(__MINGW64__) || defined(USE_PTHREADS) #if defined(__APPLE__) || defined(__MINGW32__) || defined(__MINGW64__) || defined(USE_PTHREADS)
#include <pthread.h> #include <pthread.h>
namespace Stockfish { namespace Stockfish {
static const size_t TH_STACK_SIZE = 8 * 1024 * 1024; static const size_t TH_STACK_SIZE = 8 * 1024 * 1024;
template <class T, class P = std::pair<T*, void(T::*)()>> template<class T, class P = std::pair<T*, void (T::*)()>>
void* start_routine(void* ptr) void* start_routine(void* ptr) {
{ P* p = reinterpret_cast<P*>(ptr);
P* p = reinterpret_cast<P*>(ptr); (p->first->*(p->second))(); // Call member function pointer
(p->first->*(p->second))(); // Call member function pointer delete p;
delete p; return nullptr;
return nullptr;
} }
class NativeThread { class NativeThread {
pthread_t thread; pthread_t thread;
public: public:
template<class T, class P = std::pair<T*, void(T::*)()>> template<class T, class P = std::pair<T*, void (T::*)()>>
explicit NativeThread(void(T::*fun)(), T* obj) { explicit NativeThread(void (T::*fun)(), T* obj) {
pthread_attr_t attr_storage, *attr = &attr_storage; pthread_attr_t attr_storage, *attr = &attr_storage;
pthread_attr_init(attr); pthread_attr_init(attr);
pthread_attr_setstacksize(attr, TH_STACK_SIZE); pthread_attr_setstacksize(attr, TH_STACK_SIZE);
pthread_create(&thread, attr, start_routine<T>, new P(obj, fun)); pthread_create(&thread, attr, start_routine<T>, new P(obj, fun));
} }
void join() { pthread_join(thread, nullptr); } void join() { pthread_join(thread, nullptr); }
}; };
} // namespace Stockfish } // namespace Stockfish
#else // Default case: use STL classes #else // Default case: use STL classes
namespace Stockfish { namespace Stockfish {
using NativeThread = std::thread; using NativeThread = std::thread;
} // namespace Stockfish } // namespace Stockfish
#endif #endif
#endif // #ifndef THREAD_WIN32_OSX_H_INCLUDED #endif // #ifndef THREAD_WIN32_OSX_H_INCLUDED

View file

@ -26,7 +26,7 @@
namespace Stockfish { namespace Stockfish {
TimeManagement Time; // Our global time management object TimeManagement Time; // Our global time management object
// TimeManagement::init() is called at the beginning of the search and calculates // TimeManagement::init() is called at the beginning of the search and calculates
@ -36,74 +36,74 @@ TimeManagement Time; // Our global time management object
void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) { void TimeManagement::init(Search::LimitsType& limits, Color us, int ply) {
// If we have no time, no need to initialize TM, except for the start time, // If we have no time, no need to initialize TM, except for the start time,
// which is used by movetime. // which is used by movetime.
startTime = limits.startTime; startTime = limits.startTime;
if (limits.time[us] == 0) if (limits.time[us] == 0)
return; return;
TimePoint moveOverhead = TimePoint(Options["Move Overhead"]); TimePoint moveOverhead = TimePoint(Options["Move Overhead"]);
TimePoint slowMover = TimePoint(Options["Slow Mover"]); TimePoint slowMover = TimePoint(Options["Slow Mover"]);
TimePoint npmsec = TimePoint(Options["nodestime"]); TimePoint npmsec = TimePoint(Options["nodestime"]);
// optScale is a percentage of available time to use for the current move. // optScale is a percentage of available time to use for the current move.
// maxScale is a multiplier applied to optimumTime. // maxScale is a multiplier applied to optimumTime.
double optScale, maxScale; double optScale, maxScale;
// If we have to play in 'nodes as time' mode, then convert from time // If we have to play in 'nodes as time' mode, then convert from time
// to nodes, and use resulting values in time management formulas. // to nodes, and use resulting values in time management formulas.
// WARNING: to avoid time losses, the given npmsec (nodes per millisecond) // WARNING: to avoid time losses, the given npmsec (nodes per millisecond)
// must be much lower than the real engine speed. // must be much lower than the real engine speed.
if (npmsec) if (npmsec)
{ {
if (!availableNodes) // Only once at game start if (!availableNodes) // Only once at game start
availableNodes = npmsec * limits.time[us]; // Time is in msec availableNodes = npmsec * limits.time[us]; // Time is in msec
// Convert from milliseconds to nodes // Convert from milliseconds to nodes
limits.time[us] = TimePoint(availableNodes); limits.time[us] = TimePoint(availableNodes);
limits.inc[us] *= npmsec; limits.inc[us] *= npmsec;
limits.npmsec = npmsec; limits.npmsec = npmsec;
} }
// Maximum move horizon of 50 moves // Maximum move horizon of 50 moves
int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50; int mtg = limits.movestogo ? std::min(limits.movestogo, 50) : 50;
// Make sure timeLeft is > 0 since we may use it as a divisor // Make sure timeLeft is > 0 since we may use it as a divisor
TimePoint timeLeft = std::max(TimePoint(1), TimePoint timeLeft = std::max(TimePoint(1), limits.time[us] + limits.inc[us] * (mtg - 1)
limits.time[us] + limits.inc[us] * (mtg - 1) - moveOverhead * (2 + mtg)); - moveOverhead * (2 + mtg));
// Use extra time with larger increments // Use extra time with larger increments
double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12); double optExtra = std::clamp(1.0 + 12.0 * limits.inc[us] / limits.time[us], 1.0, 1.12);
// A user may scale time usage by setting UCI option "Slow Mover" // A user may scale time usage by setting UCI option "Slow Mover"
// Default is 100 and changing this value will probably lose elo. // Default is 100 and changing this value will probably lose elo.
timeLeft = slowMover * timeLeft / 100; timeLeft = slowMover * timeLeft / 100;
// x basetime (+ z increment) // x basetime (+ z increment)
// If there is a healthy increment, timeLeft can exceed actual available // If there is a healthy increment, timeLeft can exceed actual available
// game time for the current move, so also cap to 20% of available game time. // game time for the current move, so also cap to 20% of available game time.
if (limits.movestogo == 0) if (limits.movestogo == 0)
{ {
optScale = std::min(0.0120 + std::pow(ply + 3.0, 0.45) * 0.0039, optScale = std::min(0.0120 + std::pow(ply + 3.0, 0.45) * 0.0039,
0.2 * limits.time[us] / double(timeLeft)) 0.2 * limits.time[us] / double(timeLeft))
* optExtra; * optExtra;
maxScale = std::min(7.0, 4.0 + ply / 12.0); maxScale = std::min(7.0, 4.0 + ply / 12.0);
} }
// x moves in y seconds (+ z increment) // x moves in y seconds (+ z increment)
else else
{ {
optScale = std::min((0.88 + ply / 116.4) / mtg, optScale = std::min((0.88 + ply / 116.4) / mtg, 0.88 * limits.time[us] / double(timeLeft));
0.88 * limits.time[us] / double(timeLeft)); maxScale = std::min(6.3, 1.5 + 0.11 * mtg);
maxScale = std::min(6.3, 1.5 + 0.11 * mtg); }
}
// Never use more than 80% of the available time for this move // Never use more than 80% of the available time for this move
optimumTime = TimePoint(optScale * timeLeft); optimumTime = TimePoint(optScale * timeLeft);
maximumTime = TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10; maximumTime =
TimePoint(std::min(0.8 * limits.time[us] - moveOverhead, maxScale * optimumTime)) - 10;
if (Options["Ponder"]) if (Options["Ponder"])
optimumTime += optimumTime / 4; optimumTime += optimumTime / 4;
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -32,23 +32,24 @@ namespace Stockfish {
// the maximum available time, the game move number, and other parameters. // the maximum available time, the game move number, and other parameters.
class TimeManagement { class TimeManagement {
public: public:
void init(Search::LimitsType& limits, Color us, int ply); void init(Search::LimitsType& limits, Color us, int ply);
TimePoint optimum() const { return optimumTime; } TimePoint optimum() const { return optimumTime; }
TimePoint maximum() const { return maximumTime; } TimePoint maximum() const { return maximumTime; }
TimePoint elapsed() const { return Search::Limits.npmsec ? TimePoint elapsed() const {
TimePoint(Threads.nodes_searched()) : now() - startTime; } return Search::Limits.npmsec ? TimePoint(Threads.nodes_searched()) : now() - startTime;
}
int64_t availableNodes; // When in 'nodes as time' mode int64_t availableNodes; // When in 'nodes as time' mode
private: private:
TimePoint startTime; TimePoint startTime;
TimePoint optimumTime; TimePoint optimumTime;
TimePoint maximumTime; TimePoint maximumTime;
}; };
extern TimeManagement Time; extern TimeManagement Time;
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef TIMEMAN_H_INCLUDED #endif // #ifndef TIMEMAN_H_INCLUDED

View file

@ -31,31 +31,29 @@
namespace Stockfish { namespace Stockfish {
TranspositionTable TT; // Our global transposition table TranspositionTable TT; // Our global transposition table
// TTEntry::save() populates the TTEntry with a new node's data, possibly // TTEntry::save() populates the TTEntry with a new node's data, possibly
// overwriting an old position. The update is not atomic and can be racy. // overwriting an old position. The update is not atomic and can be racy.
void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev) { void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev) {
// Preserve any existing move for the same position // Preserve any existing move for the same position
if (m || uint16_t(k) != key16) if (m || uint16_t(k) != key16)
move16 = uint16_t(m); move16 = uint16_t(m);
// Overwrite less valuable entries (cheapest checks first) // Overwrite less valuable entries (cheapest checks first)
if ( b == BOUND_EXACT if (b == BOUND_EXACT || uint16_t(k) != key16 || d - DEPTH_OFFSET + 2 * pv > depth8 - 4)
|| uint16_t(k) != key16 {
|| d - DEPTH_OFFSET + 2 * pv > depth8 - 4) assert(d > DEPTH_OFFSET);
{ assert(d < 256 + DEPTH_OFFSET);
assert(d > DEPTH_OFFSET);
assert(d < 256 + DEPTH_OFFSET);
key16 = uint16_t(k); key16 = uint16_t(k);
depth8 = uint8_t(d - DEPTH_OFFSET); depth8 = uint8_t(d - DEPTH_OFFSET);
genBound8 = uint8_t(TT.generation8 | uint8_t(pv) << 2 | b); genBound8 = uint8_t(TT.generation8 | uint8_t(pv) << 2 | b);
value16 = int16_t(v); value16 = int16_t(v);
eval16 = int16_t(ev); eval16 = int16_t(ev);
} }
} }
@ -65,21 +63,20 @@ void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev)
void TranspositionTable::resize(size_t mbSize) { void TranspositionTable::resize(size_t mbSize) {
Threads.main()->wait_for_search_finished(); Threads.main()->wait_for_search_finished();
aligned_large_pages_free(table); aligned_large_pages_free(table);
clusterCount = mbSize * 1024 * 1024 / sizeof(Cluster); clusterCount = mbSize * 1024 * 1024 / sizeof(Cluster);
table = static_cast<Cluster*>(aligned_large_pages_alloc(clusterCount * sizeof(Cluster))); table = static_cast<Cluster*>(aligned_large_pages_alloc(clusterCount * sizeof(Cluster)));
if (!table) if (!table)
{ {
std::cerr << "Failed to allocate " << mbSize std::cerr << "Failed to allocate " << mbSize << "MB for transposition table." << std::endl;
<< "MB for transposition table." << std::endl; exit(EXIT_FAILURE);
exit(EXIT_FAILURE); }
}
clear(); clear();
} }
@ -88,28 +85,27 @@ void TranspositionTable::resize(size_t mbSize) {
void TranspositionTable::clear() { void TranspositionTable::clear() {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx) for (size_t idx = 0; idx < size_t(Options["Threads"]); ++idx)
{ {
threads.emplace_back([this, idx]() { threads.emplace_back([this, idx]() {
// Thread binding gives faster search on systems with a first-touch policy
if (Options["Threads"] > 8)
WinProcGroup::bindThisThread(idx);
// Thread binding gives faster search on systems with a first-touch policy // Each thread will zero its part of the hash table
if (Options["Threads"] > 8) const size_t stride = size_t(clusterCount / Options["Threads"]),
WinProcGroup::bindThisThread(idx); start = size_t(stride * idx),
len =
idx != size_t(Options["Threads"]) - 1 ? stride : clusterCount - start;
// Each thread will zero its part of the hash table std::memset(&table[start], 0, len * sizeof(Cluster));
const size_t stride = size_t(clusterCount / Options["Threads"]), });
start = size_t(stride * idx), }
len = idx != size_t(Options["Threads"]) - 1 ?
stride : clusterCount - start;
std::memset(&table[start], 0, len * sizeof(Cluster)); for (std::thread& th : threads)
}); th.join();
}
for (std::thread& th : threads)
th.join();
} }
@ -122,30 +118,33 @@ void TranspositionTable::clear() {
TTEntry* TranspositionTable::probe(const Key key, bool& found) const { TTEntry* TranspositionTable::probe(const Key key, bool& found) const {
TTEntry* const tte = first_entry(key); TTEntry* const tte = first_entry(key);
const uint16_t key16 = uint16_t(key); // Use the low 16 bits as key inside the cluster const uint16_t key16 = uint16_t(key); // Use the low 16 bits as key inside the cluster
for (int i = 0; i < ClusterSize; ++i) for (int i = 0; i < ClusterSize; ++i)
if (tte[i].key16 == key16 || !tte[i].depth8) if (tte[i].key16 == key16 || !tte[i].depth8)
{ {
tte[i].genBound8 = uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh tte[i].genBound8 =
uint8_t(generation8 | (tte[i].genBound8 & (GENERATION_DELTA - 1))); // Refresh
return found = bool(tte[i].depth8), &tte[i]; return found = bool(tte[i].depth8), &tte[i];
} }
// Find an entry to be replaced according to the replacement strategy // Find an entry to be replaced according to the replacement strategy
TTEntry* replace = tte; TTEntry* replace = tte;
for (int i = 1; i < ClusterSize; ++i) for (int i = 1; i < ClusterSize; ++i)
// Due to our packed storage format for generation and its cyclic // Due to our packed storage format for generation and its cyclic
// nature we add GENERATION_CYCLE (256 is the modulus, plus what // nature we add GENERATION_CYCLE (256 is the modulus, plus what
// is needed to keep the unrelated lowest n bits from affecting // is needed to keep the unrelated lowest n bits from affecting
// the result) to calculate the entry age correctly even after // the result) to calculate the entry age correctly even after
// generation8 overflows into the next cycle. // generation8 overflows into the next cycle.
if ( replace->depth8 - ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK) if (replace->depth8
> tte[i].depth8 - ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK)) - ((GENERATION_CYCLE + generation8 - replace->genBound8) & GENERATION_MASK)
replace = &tte[i]; > tte[i].depth8
- ((GENERATION_CYCLE + generation8 - tte[i].genBound8) & GENERATION_MASK))
replace = &tte[i];
return found = false, replace; return found = false, replace;
} }
@ -154,12 +153,13 @@ TTEntry* TranspositionTable::probe(const Key key, bool& found) const {
int TranspositionTable::hashfull() const { int TranspositionTable::hashfull() const {
int cnt = 0; int cnt = 0;
for (int i = 0; i < 1000; ++i) for (int i = 0; i < 1000; ++i)
for (int j = 0; j < ClusterSize; ++j) for (int j = 0; j < ClusterSize; ++j)
cnt += table[i].entry[j].depth8 && (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8; cnt += table[i].entry[j].depth8
&& (table[i].entry[j].genBound8 & GENERATION_MASK) == generation8;
return cnt / ClusterSize; return cnt / ClusterSize;
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -40,23 +40,23 @@ namespace Stockfish {
struct TTEntry { struct TTEntry {
Move move() const { return Move (move16); } Move move() const { return Move(move16); }
Value value() const { return Value(value16); } Value value() const { return Value(value16); }
Value eval() const { return Value(eval16); } Value eval() const { return Value(eval16); }
Depth depth() const { return Depth(depth8 + DEPTH_OFFSET); } Depth depth() const { return Depth(depth8 + DEPTH_OFFSET); }
bool is_pv() const { return bool (genBound8 & 0x4); } bool is_pv() const { return bool(genBound8 & 0x4); }
Bound bound() const { return Bound(genBound8 & 0x3); } Bound bound() const { return Bound(genBound8 & 0x3); }
void save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev); void save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev);
private: private:
friend class TranspositionTable; friend class TranspositionTable;
uint16_t key16; uint16_t key16;
uint8_t depth8; uint8_t depth8;
uint8_t genBound8; uint8_t genBound8;
uint16_t move16; uint16_t move16;
int16_t value16; int16_t value16;
int16_t eval16; int16_t eval16;
}; };
@ -68,43 +68,45 @@ private:
class TranspositionTable { class TranspositionTable {
static constexpr int ClusterSize = 3; static constexpr int ClusterSize = 3;
struct Cluster { struct Cluster {
TTEntry entry[ClusterSize]; TTEntry entry[ClusterSize];
char padding[2]; // Pad to 32 bytes char padding[2]; // Pad to 32 bytes
}; };
static_assert(sizeof(Cluster) == 32, "Unexpected Cluster size"); static_assert(sizeof(Cluster) == 32, "Unexpected Cluster size");
// Constants used to refresh the hash table periodically // Constants used to refresh the hash table periodically
static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things static constexpr unsigned GENERATION_BITS = 3; // nb of bits reserved for other things
static constexpr int GENERATION_DELTA = (1 << GENERATION_BITS); // increment for generation field static constexpr int GENERATION_DELTA =
static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length (1 << GENERATION_BITS); // increment for generation field
static constexpr int GENERATION_MASK = (0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number static constexpr int GENERATION_CYCLE = 255 + (1 << GENERATION_BITS); // cycle length
static constexpr int GENERATION_MASK =
(0xFF << GENERATION_BITS) & 0xFF; // mask to pull out generation number
public: public:
~TranspositionTable() { aligned_large_pages_free(table); } ~TranspositionTable() { aligned_large_pages_free(table); }
void new_search() { generation8 += GENERATION_DELTA; } // Lower bits are used for other things void new_search() { generation8 += GENERATION_DELTA; } // Lower bits are used for other things
TTEntry* probe(const Key key, bool& found) const; TTEntry* probe(const Key key, bool& found) const;
int hashfull() const; int hashfull() const;
void resize(size_t mbSize); void resize(size_t mbSize);
void clear(); void clear();
TTEntry* first_entry(const Key key) const { TTEntry* first_entry(const Key key) const {
return &table[mul_hi64(key, clusterCount)].entry[0]; return &table[mul_hi64(key, clusterCount)].entry[0];
} }
private: private:
friend struct TTEntry; friend struct TTEntry;
size_t clusterCount; size_t clusterCount;
Cluster* table; Cluster* table;
uint8_t generation8; // Size must be not bigger than TTEntry::genBound8 uint8_t generation8; // Size must be not bigger than TTEntry::genBound8
}; };
extern TranspositionTable TT; extern TranspositionTable TT;
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef TT_H_INCLUDED #endif // #ifndef TT_H_INCLUDED

View file

@ -34,75 +34,84 @@ using std::string;
namespace Stockfish { namespace Stockfish {
bool Tune::update_on_last; bool Tune::update_on_last;
const UCI::Option* LastOption = nullptr; const UCI::Option* LastOption = nullptr;
static std::map<std::string, int> TuneResults; static std::map<std::string, int> TuneResults;
string Tune::next(string& names, bool pop) { string Tune::next(string& names, bool pop) {
string name; string name;
do { do
string token = names.substr(0, names.find(',')); {
string token = names.substr(0, names.find(','));
if (pop) if (pop)
names.erase(0, token.size() + 1); names.erase(0, token.size() + 1);
std::stringstream ws(token); std::stringstream ws(token);
name += (ws >> token, token); // Remove trailing whitespace name += (ws >> token, token); // Remove trailing whitespace
} while ( std::count(name.begin(), name.end(), '(') } while (std::count(name.begin(), name.end(), '(') - std::count(name.begin(), name.end(), ')'));
- std::count(name.begin(), name.end(), ')'));
return name; return name;
} }
static void on_tune(const UCI::Option& o) { static void on_tune(const UCI::Option& o) {
if (!Tune::update_on_last || LastOption == &o) if (!Tune::update_on_last || LastOption == &o)
Tune::read_options(); Tune::read_options();
} }
static void make_option(const string& n, int v, const SetRange& r) { static void make_option(const string& n, int v, const SetRange& r) {
// Do not generate option when there is nothing to tune (ie. min = max) // Do not generate option when there is nothing to tune (ie. min = max)
if (r(v).first == r(v).second) if (r(v).first == r(v).second)
return; return;
if (TuneResults.count(n)) if (TuneResults.count(n))
v = TuneResults[n]; v = TuneResults[n];
Options[n] << UCI::Option(v, r(v).first, r(v).second, on_tune); Options[n] << UCI::Option(v, r(v).first, r(v).second, on_tune);
LastOption = &Options[n]; LastOption = &Options[n];
// Print formatted parameters, ready to be copy-pasted in Fishtest // Print formatted parameters, ready to be copy-pasted in Fishtest
std::cout << n << "," std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
<< v << "," << (r(v).second - r(v).first) / 20.0 << ","
<< r(v).first << "," << r(v).second << "," << "0.0020" << std::endl;
<< (r(v).second - r(v).first) / 20.0 << ","
<< "0.0020"
<< std::endl;
} }
template<> void Tune::Entry<int>::init_option() { make_option(name, value, range); } template<>
void Tune::Entry<int>::init_option() {
template<> void Tune::Entry<int>::read_option() { make_option(name, value, range);
if (Options.count(name))
value = int(Options[name]);
} }
template<> void Tune::Entry<Value>::init_option() { make_option(name, value, range); } template<>
void Tune::Entry<int>::read_option() {
if (Options.count(name))
value = int(Options[name]);
}
template<> void Tune::Entry<Value>::read_option() { template<>
if (Options.count(name)) void Tune::Entry<Value>::init_option() {
value = Value(int(Options[name])); make_option(name, value, range);
}
template<>
void Tune::Entry<Value>::read_option() {
if (Options.count(name))
value = Value(int(Options[name]));
} }
// Instead of a variable here we have a PostUpdate function: just call it // Instead of a variable here we have a PostUpdate function: just call it
template<> void Tune::Entry<Tune::PostUpdate>::init_option() {} template<>
template<> void Tune::Entry<Tune::PostUpdate>::read_option() { value(); } void Tune::Entry<Tune::PostUpdate>::init_option() {}
template<>
void Tune::Entry<Tune::PostUpdate>::read_option() {
value();
}
} // namespace Stockfish } // namespace Stockfish
// Init options with tuning session results instead of default values. Useful to // Init options with tuning session results instead of default values. Useful to
@ -117,9 +126,7 @@ template<> void Tune::Entry<Tune::PostUpdate>::read_option() { value(); }
namespace Stockfish { namespace Stockfish {
void Tune::read_results() { void Tune::read_results() { /* ...insert your values here... */
/* ...insert your values here... */
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -22,28 +22,29 @@
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> // IWYU pragma: keep #include <type_traits> // IWYU pragma: keep
#include <utility> #include <utility>
#include <vector> #include <vector>
namespace Stockfish { namespace Stockfish {
enum Value : int; enum Value : int;
using Range = std::pair<int, int>; // Option's min-max values using Range = std::pair<int, int>; // Option's min-max values
using RangeFun = Range (int); using RangeFun = Range(int);
// Default Range function, to calculate Option's min-max values // Default Range function, to calculate Option's min-max values
inline Range default_range(int v) { inline Range default_range(int v) { return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0); }
return v > 0 ? Range(0, 2 * v) : Range(2 * v, 0);
}
struct SetRange { struct SetRange {
explicit SetRange(RangeFun f) : fun(f) {} explicit SetRange(RangeFun f) :
SetRange(int min, int max) : fun(nullptr), range(min, max) {} fun(f) {}
Range operator()(int v) const { return fun ? fun(v) : range; } SetRange(int min, int max) :
fun(nullptr),
range(min, max) {}
Range operator()(int v) const { return fun ? fun(v) : range; }
RangeFun* fun; RangeFun* fun;
Range range; Range range;
}; };
#define SetDefaultRange SetRange(default_range) #define SetDefaultRange SetRange(default_range)
@ -76,88 +77,102 @@ struct SetRange {
class Tune { class Tune {
using PostUpdate = void (); // Post-update function using PostUpdate = void(); // Post-update function
Tune() { read_results(); } Tune() { read_results(); }
Tune(const Tune&) = delete; Tune(const Tune&) = delete;
void operator=(const Tune&) = delete; void operator=(const Tune&) = delete;
void read_results(); void read_results();
static Tune& instance() { static Tune t; return t; } // Singleton static Tune& instance() {
static Tune t;
return t;
} // Singleton
// Use polymorphism to accommodate Entry of different types in the same vector // Use polymorphism to accommodate Entry of different types in the same vector
struct EntryBase { struct EntryBase {
virtual ~EntryBase() = default; virtual ~EntryBase() = default;
virtual void init_option() = 0; virtual void init_option() = 0;
virtual void read_option() = 0; virtual void read_option() = 0;
}; };
template<typename T> template<typename T>
struct Entry : public EntryBase { struct Entry: public EntryBase {
static_assert(!std::is_const_v<T>, "Parameter cannot be const!"); static_assert(!std::is_const_v<T>, "Parameter cannot be const!");
static_assert( std::is_same_v<T, int> static_assert(std::is_same_v<T, int> || std::is_same_v<T, Value>
|| std::is_same_v<T, Value> || std::is_same_v<T, PostUpdate>,
|| std::is_same_v<T, PostUpdate>, "Parameter type not supported!"); "Parameter type not supported!");
Entry(const std::string& n, T& v, const SetRange& r) : name(n), value(v), range(r) {} Entry(const std::string& n, T& v, const SetRange& r) :
void operator=(const Entry&) = delete; // Because 'value' is a reference name(n),
void init_option() override; value(v),
void read_option() override; range(r) {}
void operator=(const Entry&) = delete; // Because 'value' is a reference
void init_option() override;
void read_option() override;
std::string name; std::string name;
T& value; T& value;
SetRange range; SetRange range;
}; };
// Our facility to fill the container, each Entry corresponds to a parameter // Our facility to fill the container, each Entry corresponds to a parameter
// to tune. We use variadic templates to deal with an unspecified number of // to tune. We use variadic templates to deal with an unspecified number of
// entries, each one of a possible different type. // entries, each one of a possible different type.
static std::string next(std::string& names, bool pop = true); static std::string next(std::string& names, bool pop = true);
int add(const SetRange&, std::string&&) { return 0; } int add(const SetRange&, std::string&&) { return 0; }
template<typename T, typename... Args> template<typename T, typename... Args>
int add(const SetRange& range, std::string&& names, T& value, Args&&... args) { int add(const SetRange& range, std::string&& names, T& value, Args&&... args) {
list.push_back(std::unique_ptr<EntryBase>(new Entry<T>(next(names), value, range))); list.push_back(std::unique_ptr<EntryBase>(new Entry<T>(next(names), value, range)));
return add(range, std::move(names), args...); return add(range, std::move(names), args...);
} }
// Template specialization for arrays: recursively handle multi-dimensional arrays // Template specialization for arrays: recursively handle multi-dimensional arrays
template<typename T, size_t N, typename... Args> template<typename T, size_t N, typename... Args>
int add(const SetRange& range, std::string&& names, T (&value)[N], Args&&... args) { int add(const SetRange& range, std::string&& names, T (&value)[N], Args&&... args) {
for (size_t i = 0; i < N; i++) for (size_t i = 0; i < N; i++)
add(range, next(names, i == N - 1) + "[" + std::to_string(i) + "]", value[i]); add(range, next(names, i == N - 1) + "[" + std::to_string(i) + "]", value[i]);
return add(range, std::move(names), args...); return add(range, std::move(names), args...);
} }
// Template specialization for SetRange // Template specialization for SetRange
template<typename... Args> template<typename... Args>
int add(const SetRange&, std::string&& names, SetRange& value, Args&&... args) { int add(const SetRange&, std::string&& names, SetRange& value, Args&&... args) {
return add(value, (next(names), std::move(names)), args...); return add(value, (next(names), std::move(names)), args...);
} }
std::vector<std::unique_ptr<EntryBase>> list; std::vector<std::unique_ptr<EntryBase>> list;
public: public:
template<typename... Args> template<typename... Args>
static int add(const std::string& names, Args&&... args) { static int add(const std::string& names, Args&&... args) {
return instance().add(SetDefaultRange, names.substr(1, names.size() - 2), args...); // Remove trailing parenthesis return instance().add(SetDefaultRange, names.substr(1, names.size() - 2),
} args...); // Remove trailing parenthesis
static void init() { for (auto& e : instance().list) e->init_option(); read_options(); } // Deferred, due to UCI::Options access }
static void read_options() { for (auto& e : instance().list) e->read_option(); } static void init() {
static bool update_on_last; for (auto& e : instance().list)
e->init_option();
read_options();
} // Deferred, due to UCI::Options access
static void read_options() {
for (auto& e : instance().list)
e->read_option();
}
static bool update_on_last;
}; };
// Some macro magic :-) we define a dummy int variable that the compiler initializes calling Tune::add() // Some macro magic :-) we define a dummy int variable that the compiler initializes calling Tune::add()
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define UNIQUE2(x, y) x ## y #define UNIQUE2(x, y) x##y
#define UNIQUE(x, y) UNIQUE2(x, y) // Two indirection levels to expand __LINE__ #define UNIQUE(x, y) UNIQUE2(x, y) // Two indirection levels to expand __LINE__
#define TUNE(...) int UNIQUE(p, __LINE__) = Tune::add(STRINGIFY((__VA_ARGS__)), __VA_ARGS__) #define TUNE(...) int UNIQUE(p, __LINE__) = Tune::add(STRINGIFY((__VA_ARGS__)), __VA_ARGS__)
#define UPDATE_ON_LAST() bool UNIQUE(p, __LINE__) = Tune::update_on_last = true #define UPDATE_ON_LAST() bool UNIQUE(p, __LINE__) = Tune::update_on_last = true
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef TUNE_H_INCLUDED #endif // #ifndef TUNE_H_INCLUDED

View file

@ -17,7 +17,7 @@
*/ */
#ifndef TYPES_H_INCLUDED #ifndef TYPES_H_INCLUDED
#define TYPES_H_INCLUDED #define TYPES_H_INCLUDED
// When compiling with provided Makefile (e.g. for Linux and OSX), configuration // When compiling with provided Makefile (e.g. for Linux and OSX), configuration
// is done automatically. To get started type 'make help'. // is done automatically. To get started type 'make help'.
@ -36,15 +36,15 @@
// -DUSE_PEXT | Add runtime support for use of pext asm-instruction. Works // -DUSE_PEXT | Add runtime support for use of pext asm-instruction. Works
// | only in 64-bit mode and requires hardware with pext support. // | only in 64-bit mode and requires hardware with pext support.
#include <cassert> #include <cassert>
#include <cstdint> #include <cstdint>
#if defined(_MSC_VER) #if defined(_MSC_VER)
// Disable some silly and noisy warnings from MSVC compiler // Disable some silly and noisy warnings from MSVC compiler
#pragma warning(disable: 4127) // Conditional expression is constant #pragma warning(disable: 4127) // Conditional expression is constant
#pragma warning(disable: 4146) // Unary minus operator applied to unsigned type #pragma warning(disable: 4146) // Unary minus operator applied to unsigned type
#pragma warning(disable: 4800) // Forcing value to bool 'true' or 'false' #pragma warning(disable: 4800) // Forcing value to bool 'true' or 'false'
#endif #endif
// Predefined macros hell: // Predefined macros hell:
// //
@ -55,53 +55,54 @@
// _WIN32 Building on Windows (any) // _WIN32 Building on Windows (any)
// _WIN64 Building on Windows 64 bit // _WIN64 Building on Windows 64 bit
#if defined(__GNUC__ ) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) && defined(_WIN32) && !defined(__clang__) #if defined(__GNUC__) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) \
#define ALIGNAS_ON_STACK_VARIABLES_BROKEN && defined(_WIN32) && !defined(__clang__)
#endif #define ALIGNAS_ON_STACK_VARIABLES_BROKEN
#endif
#define ASSERT_ALIGNED(ptr, alignment) assert(reinterpret_cast<uintptr_t>(ptr) % alignment == 0) #define ASSERT_ALIGNED(ptr, alignment) assert(reinterpret_cast<uintptr_t>(ptr) % alignment == 0)
#if defined(_WIN64) && defined(_MSC_VER) // No Makefile used #if defined(_WIN64) && defined(_MSC_VER) // No Makefile used
# include <intrin.h> // Microsoft header for _BitScanForward64() #include <intrin.h> // Microsoft header for _BitScanForward64()
# define IS_64BIT #define IS_64BIT
#endif #endif
#if defined(USE_POPCNT) && defined(_MSC_VER) #if defined(USE_POPCNT) && defined(_MSC_VER)
# include <nmmintrin.h> // Microsoft header for _mm_popcnt_u64() #include <nmmintrin.h> // Microsoft header for _mm_popcnt_u64()
#endif #endif
#if !defined(NO_PREFETCH) && defined(_MSC_VER) #if !defined(NO_PREFETCH) && defined(_MSC_VER)
# include <xmmintrin.h> // Microsoft header for _mm_prefetch() #include <xmmintrin.h> // Microsoft header for _mm_prefetch()
#endif #endif
#if defined(USE_PEXT) #if defined(USE_PEXT)
# include <immintrin.h> // Header for _pext_u64() intrinsic #include <immintrin.h> // Header for _pext_u64() intrinsic
# define pext(b, m) _pext_u64(b, m) #define pext(b, m) _pext_u64(b, m)
#else #else
# define pext(b, m) 0 #define pext(b, m) 0
#endif #endif
namespace Stockfish { namespace Stockfish {
#ifdef USE_POPCNT #ifdef USE_POPCNT
constexpr bool HasPopCnt = true; constexpr bool HasPopCnt = true;
#else #else
constexpr bool HasPopCnt = false; constexpr bool HasPopCnt = false;
#endif #endif
#ifdef USE_PEXT #ifdef USE_PEXT
constexpr bool HasPext = true; constexpr bool HasPext = true;
#else #else
constexpr bool HasPext = false; constexpr bool HasPext = false;
#endif #endif
#ifdef IS_64BIT #ifdef IS_64BIT
constexpr bool Is64Bit = true; constexpr bool Is64Bit = true;
#else #else
constexpr bool Is64Bit = false; constexpr bool Is64Bit = false;
#endif #endif
using Key = uint64_t; using Key = uint64_t;
using Bitboard = uint64_t; using Bitboard = uint64_t;
constexpr int MAX_MOVES = 256; constexpr int MAX_MOVES = 256;
@ -120,164 +121,187 @@ constexpr int MAX_PLY = 246;
// while MOVE_NONE and MOVE_NULL have the same origin and destination square. // while MOVE_NONE and MOVE_NULL have the same origin and destination square.
enum Move : int { enum Move : int {
MOVE_NONE, MOVE_NONE,
MOVE_NULL = 65 MOVE_NULL = 65
}; };
enum MoveType { enum MoveType {
NORMAL, NORMAL,
PROMOTION = 1 << 14, PROMOTION = 1 << 14,
EN_PASSANT = 2 << 14, EN_PASSANT = 2 << 14,
CASTLING = 3 << 14 CASTLING = 3 << 14
}; };
enum Color { enum Color {
WHITE, BLACK, COLOR_NB = 2 WHITE,
BLACK,
COLOR_NB = 2
}; };
enum CastlingRights { enum CastlingRights {
NO_CASTLING, NO_CASTLING,
WHITE_OO, WHITE_OO,
WHITE_OOO = WHITE_OO << 1, WHITE_OOO = WHITE_OO << 1,
BLACK_OO = WHITE_OO << 2, BLACK_OO = WHITE_OO << 2,
BLACK_OOO = WHITE_OO << 3, BLACK_OOO = WHITE_OO << 3,
KING_SIDE = WHITE_OO | BLACK_OO, KING_SIDE = WHITE_OO | BLACK_OO,
QUEEN_SIDE = WHITE_OOO | BLACK_OOO, QUEEN_SIDE = WHITE_OOO | BLACK_OOO,
WHITE_CASTLING = WHITE_OO | WHITE_OOO, WHITE_CASTLING = WHITE_OO | WHITE_OOO,
BLACK_CASTLING = BLACK_OO | BLACK_OOO, BLACK_CASTLING = BLACK_OO | BLACK_OOO,
ANY_CASTLING = WHITE_CASTLING | BLACK_CASTLING, ANY_CASTLING = WHITE_CASTLING | BLACK_CASTLING,
CASTLING_RIGHT_NB = 16 CASTLING_RIGHT_NB = 16
}; };
enum Bound { enum Bound {
BOUND_NONE, BOUND_NONE,
BOUND_UPPER, BOUND_UPPER,
BOUND_LOWER, BOUND_LOWER,
BOUND_EXACT = BOUND_UPPER | BOUND_LOWER BOUND_EXACT = BOUND_UPPER | BOUND_LOWER
}; };
enum Value : int { enum Value : int {
VALUE_ZERO = 0, VALUE_ZERO = 0,
VALUE_DRAW = 0, VALUE_DRAW = 0,
VALUE_MATE = 32000, VALUE_MATE = 32000,
VALUE_INFINITE = 32001, VALUE_INFINITE = 32001,
VALUE_NONE = 32002, VALUE_NONE = 32002,
VALUE_TB_WIN_IN_MAX_PLY = VALUE_MATE - 2 * MAX_PLY, VALUE_TB_WIN_IN_MAX_PLY = VALUE_MATE - 2 * MAX_PLY,
VALUE_TB_LOSS_IN_MAX_PLY = -VALUE_TB_WIN_IN_MAX_PLY, VALUE_TB_LOSS_IN_MAX_PLY = -VALUE_TB_WIN_IN_MAX_PLY,
VALUE_MATE_IN_MAX_PLY = VALUE_MATE - MAX_PLY, VALUE_MATE_IN_MAX_PLY = VALUE_MATE - MAX_PLY,
VALUE_MATED_IN_MAX_PLY = -VALUE_MATE_IN_MAX_PLY, VALUE_MATED_IN_MAX_PLY = -VALUE_MATE_IN_MAX_PLY,
// In the code, we make the assumption that these values // In the code, we make the assumption that these values
// are such that non_pawn_material() can be used to uniquely // are such that non_pawn_material() can be used to uniquely
// identify the material on the board. // identify the material on the board.
PawnValue = 208, PawnValue = 208,
KnightValue = 781, KnightValue = 781,
BishopValue = 825, BishopValue = 825,
RookValue = 1276, RookValue = 1276,
QueenValue = 2538, QueenValue = 2538,
}; };
// clang-format off
enum PieceType { enum PieceType {
NO_PIECE_TYPE, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING, NO_PIECE_TYPE, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING,
ALL_PIECES = 0, ALL_PIECES = 0,
PIECE_TYPE_NB = 8 PIECE_TYPE_NB = 8
}; };
enum Piece { enum Piece {
NO_PIECE, NO_PIECE,
W_PAWN = PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING, W_PAWN = PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING,
B_PAWN = PAWN + 8, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING, B_PAWN = PAWN + 8, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING,
PIECE_NB = 16 PIECE_NB = 16
}; };
// clang-format on
constexpr Value PieceValue[PIECE_NB] = { VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO, constexpr Value PieceValue[PIECE_NB] = {
VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO }; VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO,
VALUE_ZERO, PawnValue, KnightValue, BishopValue, RookValue, QueenValue, VALUE_ZERO, VALUE_ZERO};
using Depth = int; using Depth = int;
enum : int { enum : int {
DEPTH_QS_CHECKS = 0, DEPTH_QS_CHECKS = 0,
DEPTH_QS_NO_CHECKS = -1, DEPTH_QS_NO_CHECKS = -1,
DEPTH_QS_RECAPTURES = -5, DEPTH_QS_RECAPTURES = -5,
DEPTH_NONE = -6, DEPTH_NONE = -6,
DEPTH_OFFSET = -7 // value used only for TT entry occupancy check DEPTH_OFFSET = -7 // value used only for TT entry occupancy check
}; };
// clang-format off
enum Square : int { enum Square : int {
SQ_A1, SQ_B1, SQ_C1, SQ_D1, SQ_E1, SQ_F1, SQ_G1, SQ_H1, SQ_A1, SQ_B1, SQ_C1, SQ_D1, SQ_E1, SQ_F1, SQ_G1, SQ_H1,
SQ_A2, SQ_B2, SQ_C2, SQ_D2, SQ_E2, SQ_F2, SQ_G2, SQ_H2, SQ_A2, SQ_B2, SQ_C2, SQ_D2, SQ_E2, SQ_F2, SQ_G2, SQ_H2,
SQ_A3, SQ_B3, SQ_C3, SQ_D3, SQ_E3, SQ_F3, SQ_G3, SQ_H3, SQ_A3, SQ_B3, SQ_C3, SQ_D3, SQ_E3, SQ_F3, SQ_G3, SQ_H3,
SQ_A4, SQ_B4, SQ_C4, SQ_D4, SQ_E4, SQ_F4, SQ_G4, SQ_H4, SQ_A4, SQ_B4, SQ_C4, SQ_D4, SQ_E4, SQ_F4, SQ_G4, SQ_H4,
SQ_A5, SQ_B5, SQ_C5, SQ_D5, SQ_E5, SQ_F5, SQ_G5, SQ_H5, SQ_A5, SQ_B5, SQ_C5, SQ_D5, SQ_E5, SQ_F5, SQ_G5, SQ_H5,
SQ_A6, SQ_B6, SQ_C6, SQ_D6, SQ_E6, SQ_F6, SQ_G6, SQ_H6, SQ_A6, SQ_B6, SQ_C6, SQ_D6, SQ_E6, SQ_F6, SQ_G6, SQ_H6,
SQ_A7, SQ_B7, SQ_C7, SQ_D7, SQ_E7, SQ_F7, SQ_G7, SQ_H7, SQ_A7, SQ_B7, SQ_C7, SQ_D7, SQ_E7, SQ_F7, SQ_G7, SQ_H7,
SQ_A8, SQ_B8, SQ_C8, SQ_D8, SQ_E8, SQ_F8, SQ_G8, SQ_H8, SQ_A8, SQ_B8, SQ_C8, SQ_D8, SQ_E8, SQ_F8, SQ_G8, SQ_H8,
SQ_NONE, SQ_NONE,
SQUARE_ZERO = 0, SQUARE_ZERO = 0,
SQUARE_NB = 64 SQUARE_NB = 64
}; };
// clang-format on
enum Direction : int { enum Direction : int {
NORTH = 8, NORTH = 8,
EAST = 1, EAST = 1,
SOUTH = -NORTH, SOUTH = -NORTH,
WEST = -EAST, WEST = -EAST,
NORTH_EAST = NORTH + EAST, NORTH_EAST = NORTH + EAST,
SOUTH_EAST = SOUTH + EAST, SOUTH_EAST = SOUTH + EAST,
SOUTH_WEST = SOUTH + WEST, SOUTH_WEST = SOUTH + WEST,
NORTH_WEST = NORTH + WEST NORTH_WEST = NORTH + WEST
}; };
enum File : int { enum File : int {
FILE_A, FILE_B, FILE_C, FILE_D, FILE_E, FILE_F, FILE_G, FILE_H, FILE_NB FILE_A,
FILE_B,
FILE_C,
FILE_D,
FILE_E,
FILE_F,
FILE_G,
FILE_H,
FILE_NB
}; };
enum Rank : int { enum Rank : int {
RANK_1, RANK_2, RANK_3, RANK_4, RANK_5, RANK_6, RANK_7, RANK_8, RANK_NB RANK_1,
RANK_2,
RANK_3,
RANK_4,
RANK_5,
RANK_6,
RANK_7,
RANK_8,
RANK_NB
}; };
// Keep track of what a move changes on the board (used by NNUE) // Keep track of what a move changes on the board (used by NNUE)
struct DirtyPiece { struct DirtyPiece {
// Number of changed pieces // Number of changed pieces
int dirty_num; int dirty_num;
// Max 3 pieces can change in one move. A promotion with capture moves // Max 3 pieces can change in one move. A promotion with capture moves
// both the pawn and the captured piece to SQ_NONE and the piece promoted // both the pawn and the captured piece to SQ_NONE and the piece promoted
// to from SQ_NONE to the capture square. // to from SQ_NONE to the capture square.
Piece piece[3]; Piece piece[3];
// From and to squares, which may be SQ_NONE // From and to squares, which may be SQ_NONE
Square from[3]; Square from[3];
Square to[3]; Square to[3];
}; };
#define ENABLE_BASE_OPERATORS_ON(T) \ #define ENABLE_BASE_OPERATORS_ON(T) \
constexpr T operator+(T d1, int d2) { return T(int(d1) + d2); } \ constexpr T operator+(T d1, int d2) { return T(int(d1) + d2); } \
constexpr T operator-(T d1, int d2) { return T(int(d1) - d2); } \ constexpr T operator-(T d1, int d2) { return T(int(d1) - d2); } \
constexpr T operator-(T d) { return T(-int(d)); } \ constexpr T operator-(T d) { return T(-int(d)); } \
inline T& operator+=(T& d1, int d2) { return d1 = d1 + d2; } \ inline T& operator+=(T& d1, int d2) { return d1 = d1 + d2; } \
inline T& operator-=(T& d1, int d2) { return d1 = d1 - d2; } inline T& operator-=(T& d1, int d2) { return d1 = d1 - d2; }
#define ENABLE_INCR_OPERATORS_ON(T) \ #define ENABLE_INCR_OPERATORS_ON(T) \
inline T& operator++(T& d) { return d = T(int(d) + 1); } \ inline T& operator++(T& d) { return d = T(int(d) + 1); } \
inline T& operator--(T& d) { return d = T(int(d) - 1); } inline T& operator--(T& d) { return d = T(int(d) - 1); }
#define ENABLE_FULL_OPERATORS_ON(T) \ #define ENABLE_FULL_OPERATORS_ON(T) \
ENABLE_BASE_OPERATORS_ON(T) \ ENABLE_BASE_OPERATORS_ON(T) \
constexpr T operator*(int i, T d) { return T(i * int(d)); } \ constexpr T operator*(int i, T d) { return T(i * int(d)); } \
constexpr T operator*(T d, int i) { return T(int(d) * i); } \ constexpr T operator*(T d, int i) { return T(int(d) * i); } \
constexpr T operator/(T d, int i) { return T(int(d) / i); } \ constexpr T operator/(T d, int i) { return T(int(d) / i); } \
constexpr int operator/(T d1, T d2) { return int(d1) / int(d2); } \ constexpr int operator/(T d1, T d2) { return int(d1) / int(d2); } \
inline T& operator*=(T& d, int i) { return d = T(int(d) * i); } \ inline T& operator*=(T& d, int i) { return d = T(int(d) * i); } \
inline T& operator/=(T& d, int i) { return d = T(int(d) / i); } inline T& operator/=(T& d, int i) { return d = T(int(d) / i); }
ENABLE_FULL_OPERATORS_ON(Value) ENABLE_FULL_OPERATORS_ON(Value)
ENABLE_FULL_OPERATORS_ON(Direction) ENABLE_FULL_OPERATORS_ON(Direction)
@ -287,131 +311,97 @@ ENABLE_INCR_OPERATORS_ON(Square)
ENABLE_INCR_OPERATORS_ON(File) ENABLE_INCR_OPERATORS_ON(File)
ENABLE_INCR_OPERATORS_ON(Rank) ENABLE_INCR_OPERATORS_ON(Rank)
#undef ENABLE_FULL_OPERATORS_ON #undef ENABLE_FULL_OPERATORS_ON
#undef ENABLE_INCR_OPERATORS_ON #undef ENABLE_INCR_OPERATORS_ON
#undef ENABLE_BASE_OPERATORS_ON #undef ENABLE_BASE_OPERATORS_ON
// Additional operators to add a Direction to a Square // Additional operators to add a Direction to a Square
constexpr Square operator+(Square s, Direction d) { return Square(int(s) + int(d)); } constexpr Square operator+(Square s, Direction d) { return Square(int(s) + int(d)); }
constexpr Square operator-(Square s, Direction d) { return Square(int(s) - int(d)); } constexpr Square operator-(Square s, Direction d) { return Square(int(s) - int(d)); }
inline Square& operator+=(Square& s, Direction d) { return s = s + d; } inline Square& operator+=(Square& s, Direction d) { return s = s + d; }
inline Square& operator-=(Square& s, Direction d) { return s = s - d; } inline Square& operator-=(Square& s, Direction d) { return s = s - d; }
constexpr Color operator~(Color c) { constexpr Color operator~(Color c) {
return Color(c ^ BLACK); // Toggle color return Color(c ^ BLACK); // Toggle color
} }
constexpr Square flip_rank(Square s) { // Swap A1 <-> A8 constexpr Square flip_rank(Square s) { // Swap A1 <-> A8
return Square(s ^ SQ_A8); return Square(s ^ SQ_A8);
} }
constexpr Square flip_file(Square s) { // Swap A1 <-> H1 constexpr Square flip_file(Square s) { // Swap A1 <-> H1
return Square(s ^ SQ_H1); return Square(s ^ SQ_H1);
} }
constexpr Piece operator~(Piece pc) { constexpr Piece operator~(Piece pc) {
return Piece(pc ^ 8); // Swap color of piece B_KNIGHT <-> W_KNIGHT return Piece(pc ^ 8); // Swap color of piece B_KNIGHT <-> W_KNIGHT
} }
constexpr CastlingRights operator&(Color c, CastlingRights cr) { constexpr CastlingRights operator&(Color c, CastlingRights cr) {
return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr); return CastlingRights((c == WHITE ? WHITE_CASTLING : BLACK_CASTLING) & cr);
} }
constexpr Value mate_in(int ply) { constexpr Value mate_in(int ply) { return VALUE_MATE - ply; }
return VALUE_MATE - ply;
}
constexpr Value mated_in(int ply) { constexpr Value mated_in(int ply) { return -VALUE_MATE + ply; }
return -VALUE_MATE + ply;
}
constexpr Square make_square(File f, Rank r) { constexpr Square make_square(File f, Rank r) { return Square((r << 3) + f); }
return Square((r << 3) + f);
}
constexpr Piece make_piece(Color c, PieceType pt) { constexpr Piece make_piece(Color c, PieceType pt) { return Piece((c << 3) + pt); }
return Piece((c << 3) + pt);
}
constexpr PieceType type_of(Piece pc) { constexpr PieceType type_of(Piece pc) { return PieceType(pc & 7); }
return PieceType(pc & 7);
}
inline Color color_of(Piece pc) { inline Color color_of(Piece pc) {
assert(pc != NO_PIECE); assert(pc != NO_PIECE);
return Color(pc >> 3); return Color(pc >> 3);
} }
constexpr bool is_ok(Move m) { constexpr bool is_ok(Move m) { return m != MOVE_NONE && m != MOVE_NULL; }
return m != MOVE_NONE && m != MOVE_NULL;
}
constexpr bool is_ok(Square s) { constexpr bool is_ok(Square s) { return s >= SQ_A1 && s <= SQ_H8; }
return s >= SQ_A1 && s <= SQ_H8;
}
constexpr File file_of(Square s) { constexpr File file_of(Square s) { return File(s & 7); }
return File(s & 7);
}
constexpr Rank rank_of(Square s) { constexpr Rank rank_of(Square s) { return Rank(s >> 3); }
return Rank(s >> 3);
}
constexpr Square relative_square(Color c, Square s) { constexpr Square relative_square(Color c, Square s) { return Square(s ^ (c * 56)); }
return Square(s ^ (c * 56));
}
constexpr Rank relative_rank(Color c, Rank r) { constexpr Rank relative_rank(Color c, Rank r) { return Rank(r ^ (c * 7)); }
return Rank(r ^ (c * 7));
}
constexpr Rank relative_rank(Color c, Square s) { constexpr Rank relative_rank(Color c, Square s) { return relative_rank(c, rank_of(s)); }
return relative_rank(c, rank_of(s));
}
constexpr Direction pawn_push(Color c) { constexpr Direction pawn_push(Color c) { return c == WHITE ? NORTH : SOUTH; }
return c == WHITE ? NORTH : SOUTH;
}
constexpr Square from_sq(Move m) { constexpr Square from_sq(Move m) {
assert(is_ok(m)); assert(is_ok(m));
return Square((m >> 6) & 0x3F); return Square((m >> 6) & 0x3F);
} }
constexpr Square to_sq(Move m) { constexpr Square to_sq(Move m) {
assert(is_ok(m)); assert(is_ok(m));
return Square(m & 0x3F); return Square(m & 0x3F);
} }
constexpr int from_to(Move m) { constexpr int from_to(Move m) { return m & 0xFFF; }
return m & 0xFFF;
}
constexpr MoveType type_of(Move m) { constexpr MoveType type_of(Move m) { return MoveType(m & (3 << 14)); }
return MoveType(m & (3 << 14));
}
constexpr PieceType promotion_type(Move m) { constexpr PieceType promotion_type(Move m) { return PieceType(((m >> 12) & 3) + KNIGHT); }
return PieceType(((m >> 12) & 3) + KNIGHT);
}
constexpr Move make_move(Square from, Square to) { constexpr Move make_move(Square from, Square to) { return Move((from << 6) + to); }
return Move((from << 6) + to);
}
template<MoveType T> template<MoveType T>
constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) { constexpr Move make(Square from, Square to, PieceType pt = KNIGHT) {
return Move(T + ((pt - KNIGHT) << 12) + (from << 6) + to); return Move(T + ((pt - KNIGHT) << 12) + (from << 6) + to);
} }
// Based on a congruential pseudo-random number generator // Based on a congruential pseudo-random number generator
constexpr Key make_key(uint64_t seed) { constexpr Key make_key(uint64_t seed) {
return seed * 6364136223846793005ULL + 1442695040888963407ULL; return seed * 6364136223846793005ULL + 1442695040888963407ULL;
} }
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef TYPES_H_INCLUDED #endif // #ifndef TYPES_H_INCLUDED
#include "tune.h" // Global visibility to tuning setup #include "tune.h" // Global visibility to tuning setup

View file

@ -45,18 +45,18 @@ namespace Stockfish {
namespace { namespace {
// FEN string for the initial position in standard chess // FEN string for the initial position in standard chess
const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
// position() is called when the engine receives the "position" UCI command. // position() is called when the engine receives the "position" UCI command.
// It sets up the position that is described in the given FEN string ("fen") or // It sets up the position that is described in the given FEN string ("fen") or
// the initial position ("startpos") and then makes the moves given in the following // the initial position ("startpos") and then makes the moves given in the following
// move list ("moves"). // move list ("moves").
void position(Position& pos, std::istringstream& is, StateListPtr& states) { void position(Position& pos, std::istringstream& is, StateListPtr& states) {
Move m; Move m;
std::string token, fen; std::string token, fen;
is >> token; is >> token;
@ -64,7 +64,7 @@ namespace {
if (token == "startpos") if (token == "startpos")
{ {
fen = StartFEN; fen = StartFEN;
is >> token; // Consume the "moves" token, if any is >> token; // Consume the "moves" token, if any
} }
else if (token == "fen") else if (token == "fen")
while (is >> token && token != "moves") while (is >> token && token != "moves")
@ -72,7 +72,7 @@ namespace {
else else
return; return;
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop the old state and create a new one states = StateListPtr(new std::deque<StateInfo>(1)); // Drop the old state and create a new one
pos.set(fen, Options["UCI_Chess960"], &states->back(), Threads.main()); pos.set(fen, Options["UCI_Chess960"], &states->back(), Threads.main());
// Parse the move list, if any // Parse the move list, if any
@ -81,33 +81,33 @@ namespace {
states->emplace_back(); states->emplace_back();
pos.do_move(m, states->back()); pos.do_move(m, states->back());
} }
} }
// trace_eval() prints the evaluation of the current position, consistent with // trace_eval() prints the evaluation of the current position, consistent with
// the UCI options set so far. // the UCI options set so far.
void trace_eval(Position& pos) { void trace_eval(Position& pos) {
StateListPtr states(new std::deque<StateInfo>(1)); StateListPtr states(new std::deque<StateInfo>(1));
Position p; Position p;
p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main()); p.set(pos.fen(), Options["UCI_Chess960"], &states->back(), Threads.main());
Eval::NNUE::verify(); Eval::NNUE::verify();
sync_cout << "\n" << Eval::trace(p) << sync_endl; sync_cout << "\n" << Eval::trace(p) << sync_endl;
} }
// setoption() is called when the engine receives the "setoption" UCI command. // setoption() is called when the engine receives the "setoption" UCI command.
// The function updates the UCI option ("name") to the given value ("value"). // The function updates the UCI option ("name") to the given value ("value").
void setoption(std::istringstream& is) { void setoption(std::istringstream& is) {
Threads.main()->wait_for_search_finished(); Threads.main()->wait_for_search_finished();
std::string token, name, value; std::string token, name, value;
is >> token; // Consume the "name" token is >> token; // Consume the "name" token
// Read the option name (can contain spaces) // Read the option name (can contain spaces)
while (is >> token && token != "value") while (is >> token && token != "value")
@ -121,54 +121,67 @@ namespace {
Options[name] = value; Options[name] = value;
else else
sync_cout << "No such option: " << name << sync_endl; sync_cout << "No such option: " << name << sync_endl;
} }
// go() is called when the engine receives the "go" UCI command. The function // go() is called when the engine receives the "go" UCI command. The function
// sets the thinking time and other parameters from the input string, then starts // sets the thinking time and other parameters from the input string, then starts
// with a search. // with a search.
void go(Position& pos, std::istringstream& is, StateListPtr& states) { void go(Position& pos, std::istringstream& is, StateListPtr& states) {
Search::LimitsType limits; Search::LimitsType limits;
std::string token; std::string token;
bool ponderMode = false; bool ponderMode = false;
limits.startTime = now(); // The search starts as early as possible limits.startTime = now(); // The search starts as early as possible
while (is >> token) while (is >> token)
if (token == "searchmoves") // Needs to be the last command on the line if (token == "searchmoves") // Needs to be the last command on the line
while (is >> token) while (is >> token)
limits.searchmoves.push_back(UCI::to_move(pos, token)); limits.searchmoves.push_back(UCI::to_move(pos, token));
else if (token == "wtime") is >> limits.time[WHITE]; else if (token == "wtime")
else if (token == "btime") is >> limits.time[BLACK]; is >> limits.time[WHITE];
else if (token == "winc") is >> limits.inc[WHITE]; else if (token == "btime")
else if (token == "binc") is >> limits.inc[BLACK]; is >> limits.time[BLACK];
else if (token == "movestogo") is >> limits.movestogo; else if (token == "winc")
else if (token == "depth") is >> limits.depth; is >> limits.inc[WHITE];
else if (token == "nodes") is >> limits.nodes; else if (token == "binc")
else if (token == "movetime") is >> limits.movetime; is >> limits.inc[BLACK];
else if (token == "mate") is >> limits.mate; else if (token == "movestogo")
else if (token == "perft") is >> limits.perft; is >> limits.movestogo;
else if (token == "infinite") limits.infinite = 1; else if (token == "depth")
else if (token == "ponder") ponderMode = true; is >> limits.depth;
else if (token == "nodes")
is >> limits.nodes;
else if (token == "movetime")
is >> limits.movetime;
else if (token == "mate")
is >> limits.mate;
else if (token == "perft")
is >> limits.perft;
else if (token == "infinite")
limits.infinite = 1;
else if (token == "ponder")
ponderMode = true;
Threads.start_thinking(pos, states, limits, ponderMode); Threads.start_thinking(pos, states, limits, ponderMode);
} }
// bench() is called when the engine receives the "bench" command. // bench() is called when the engine receives the "bench" command.
// First, a list of UCI commands is set up according to the bench // First, a list of UCI commands is set up according to the bench
// parameters, then it is run one by one, printing a summary at the end. // parameters, then it is run one by one, printing a summary at the end.
void bench(Position& pos, std::istream& args, StateListPtr& states) { void bench(Position& pos, std::istream& args, StateListPtr& states) {
std::string token; std::string token;
uint64_t num, nodes = 0, cnt = 1; uint64_t num, nodes = 0, cnt = 1;
std::vector<std::string> list = setup_bench(pos, args); std::vector<std::string> list = setup_bench(pos, args);
num = count_if(list.begin(), list.end(), [](const std::string& s) { return s.find("go ") == 0 || s.find("eval") == 0; }); num = count_if(list.begin(), list.end(),
[](const std::string& s) { return s.find("go ") == 0 || s.find("eval") == 0; });
TimePoint elapsed = now(); TimePoint elapsed = now();
@ -179,58 +192,64 @@ namespace {
if (token == "go" || token == "eval") if (token == "go" || token == "eval")
{ {
std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")" << std::endl; std::cerr << "\nPosition: " << cnt++ << '/' << num << " (" << pos.fen() << ")"
<< std::endl;
if (token == "go") if (token == "go")
{ {
go(pos, is, states); go(pos, is, states);
Threads.main()->wait_for_search_finished(); Threads.main()->wait_for_search_finished();
nodes += Threads.nodes_searched(); nodes += Threads.nodes_searched();
} }
else else
trace_eval(pos); trace_eval(pos);
} }
else if (token == "setoption") setoption(is); else if (token == "setoption")
else if (token == "position") position(pos, is, states); setoption(is);
else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take a while else if (token == "position")
position(pos, is, states);
else if (token == "ucinewgame")
{
Search::clear();
elapsed = now();
} // Search::clear() may take a while
} }
elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero' elapsed = now() - elapsed + 1; // Ensure positivity to avoid a 'divide by zero'
dbg_print(); dbg_print();
std::cerr << "\n===========================" std::cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
<< "\nNodes searched : " << nodes
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
} }
// The win rate model returns the probability of winning (in per mille units) given an // The win rate model returns the probability of winning (in per mille units) given an
// eval and a game ply. It fits the LTC fishtest statistics rather accurately. // eval and a game ply. It fits the LTC fishtest statistics rather accurately.
int win_rate_model(Value v, int ply) { int win_rate_model(Value v, int ply) {
// The model only captures up to 240 plies, so limit the input and then rescale // The model only captures up to 240 plies, so limit the input and then rescale
double m = std::min(240, ply) / 64.0; double m = std::min(240, ply) / 64.0;
// The coefficients of a third-order polynomial fit is based on the fishtest data // The coefficients of a third-order polynomial fit is based on the fishtest data
// for two parameters that need to transform eval to the argument of a logistic // for two parameters that need to transform eval to the argument of a logistic
// function. // function.
constexpr double as[] = { 0.38036525, -2.82015070, 23.17882135, 307.36768407}; constexpr double as[] = {0.38036525, -2.82015070, 23.17882135, 307.36768407};
constexpr double bs[] = { -2.29434733, 13.27689788, -14.26828904, 63.45318330 }; constexpr double bs[] = {-2.29434733, 13.27689788, -14.26828904, 63.45318330};
// Enforce that NormalizeToPawnValue corresponds to a 50% win rate at ply 64 // Enforce that NormalizeToPawnValue corresponds to a 50% win rate at ply 64
static_assert(UCI::NormalizeToPawnValue == int(as[0] + as[1] + as[2] + as[3])); static_assert(UCI::NormalizeToPawnValue == int(as[0] + as[1] + as[2] + as[3]));
double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3];
double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3];
// Transform the eval to centipawns with limited range // Transform the eval to centipawns with limited range
double x = std::clamp(double(v), -4000.0, 4000.0); double x = std::clamp(double(v), -4000.0, 4000.0);
// Return the win rate in per mille units, rounded to the nearest integer // Return the win rate in per mille units, rounded to the nearest integer
return int(0.5 + 1000 / (1 + std::exp((a - x) / b))); return int(0.5 + 1000 / (1 + std::exp((a - x) / b)));
} }
} // namespace } // namespace
// UCI::loop() waits for a command from the stdin, parses it, and then calls the appropriate // UCI::loop() waits for a command from the stdin, parses it, and then calls the appropriate
@ -241,81 +260,91 @@ namespace {
void UCI::loop(int argc, char* argv[]) { void UCI::loop(int argc, char* argv[]) {
Position pos; Position pos;
std::string token, cmd; std::string token, cmd;
StateListPtr states(new std::deque<StateInfo>(1)); StateListPtr states(new std::deque<StateInfo>(1));
pos.set(StartFEN, false, &states->back(), Threads.main()); pos.set(StartFEN, false, &states->back(), Threads.main());
for (int i = 1; i < argc; ++i) for (int i = 1; i < argc; ++i)
cmd += std::string(argv[i]) + " "; cmd += std::string(argv[i]) + " ";
do { do
if (argc == 1 && !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication {
cmd = "quit"; if (argc == 1
&& !getline(std::cin, cmd)) // Wait for an input or an end-of-file (EOF) indication
cmd = "quit";
std::istringstream is(cmd); std::istringstream is(cmd);
token.clear(); // Avoid a stale if getline() returns nothing or a blank line token.clear(); // Avoid a stale if getline() returns nothing or a blank line
is >> std::skipws >> token; is >> std::skipws >> token;
if ( token == "quit" if (token == "quit" || token == "stop")
|| token == "stop") Threads.stop = true;
Threads.stop = true;
// The GUI sends 'ponderhit' to tell that the user has played the expected move. // The GUI sends 'ponderhit' to tell that the user has played the expected move.
// So, 'ponderhit' is sent if pondering was done on the same move that the user // So, 'ponderhit' is sent if pondering was done on the same move that the user
// has played. The search should continue, but should also switch from pondering // has played. The search should continue, but should also switch from pondering
// to the normal search. // to the normal search.
else if (token == "ponderhit") else if (token == "ponderhit")
Threads.main()->ponder = false; // Switch to the normal search Threads.main()->ponder = false; // Switch to the normal search
else if (token == "uci") else if (token == "uci")
sync_cout << "id name " << engine_info(true) sync_cout << "id name " << engine_info(true) << "\n"
<< "\n" << Options << Options << "\nuciok" << sync_endl;
<< "\nuciok" << sync_endl;
else if (token == "setoption") setoption(is); else if (token == "setoption")
else if (token == "go") go(pos, is, states); setoption(is);
else if (token == "position") position(pos, is, states); else if (token == "go")
else if (token == "ucinewgame") Search::clear(); go(pos, is, states);
else if (token == "isready") sync_cout << "readyok" << sync_endl; else if (token == "position")
position(pos, is, states);
else if (token == "ucinewgame")
Search::clear();
else if (token == "isready")
sync_cout << "readyok" << sync_endl;
// Add custom non-UCI commands, mainly for debugging purposes. // Add custom non-UCI commands, mainly for debugging purposes.
// These commands must not be used during a search! // These commands must not be used during a search!
else if (token == "flip") pos.flip(); else if (token == "flip")
else if (token == "bench") bench(pos, is, states); pos.flip();
else if (token == "d") sync_cout << pos << sync_endl; else if (token == "bench")
else if (token == "eval") trace_eval(pos); bench(pos, is, states);
else if (token == "compiler") sync_cout << compiler_info() << sync_endl; else if (token == "d")
else if (token == "export_net") sync_cout << pos << sync_endl;
{ else if (token == "eval")
std::optional<std::string> filename; trace_eval(pos);
std::string f; else if (token == "compiler")
if (is >> std::skipws >> f) sync_cout << compiler_info() << sync_endl;
filename = f; else if (token == "export_net")
Eval::NNUE::save_eval(filename); {
} std::optional<std::string> filename;
else if (token == "--help" || token == "help" || token == "--license" || token == "license") std::string f;
sync_cout << "\nStockfish is a powerful chess engine for playing and analyzing." if (is >> std::skipws >> f)
"\nIt is released as free software licensed under the GNU GPLv3 License." filename = f;
"\nStockfish is normally used with a graphical user interface (GUI) and implements" Eval::NNUE::save_eval(filename);
"\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc." }
"\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme" else if (token == "--help" || token == "help" || token == "--license" || token == "license")
"\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n" << sync_endl; sync_cout
else if (!token.empty() && token[0] != '#') << "\nStockfish is a powerful chess engine for playing and analyzing."
sync_cout << "Unknown command: '" << cmd << "'. Type help for more information." << sync_endl; "\nIt is released as free software licensed under the GNU GPLv3 License."
"\nStockfish is normally used with a graphical user interface (GUI) and implements"
"\nthe Universal Chess Interface (UCI) protocol to communicate with a GUI, an API, etc."
"\nFor any further information, visit https://github.com/official-stockfish/Stockfish#readme"
"\nor read the corresponding README.md and Copying.txt files distributed along with this program.\n"
<< sync_endl;
else if (!token.empty() && token[0] != '#')
sync_cout << "Unknown command: '" << cmd << "'. Type help for more information."
<< sync_endl;
} while (token != "quit" && argc == 1); // The command-line arguments are one-shot } while (token != "quit" && argc == 1); // The command-line arguments are one-shot
} }
// Turns a Value to an integer centipawn number, // Turns a Value to an integer centipawn number,
// without treatment of mate and similar special scores. // without treatment of mate and similar special scores.
int UCI::to_cp(Value v) { int UCI::to_cp(Value v) { return 100 * v / UCI::NormalizeToPawnValue; }
return 100 * v / UCI::NormalizeToPawnValue;
}
// UCI::value() converts a Value to a string by adhering to the UCI protocol specification: // UCI::value() converts a Value to a string by adhering to the UCI protocol specification:
// //
@ -325,21 +354,21 @@ int UCI::to_cp(Value v) {
std::string UCI::value(Value v) { std::string UCI::value(Value v) {
assert(-VALUE_INFINITE < v && v < VALUE_INFINITE); assert(-VALUE_INFINITE < v && v < VALUE_INFINITE);
std::stringstream ss; std::stringstream ss;
if (abs(v) < VALUE_TB_WIN_IN_MAX_PLY) if (abs(v) < VALUE_TB_WIN_IN_MAX_PLY)
ss << "cp " << UCI::to_cp(v); ss << "cp " << UCI::to_cp(v);
else if (abs(v) < VALUE_MATE_IN_MAX_PLY) else if (abs(v) < VALUE_MATE_IN_MAX_PLY)
{ {
const int ply = VALUE_MATE_IN_MAX_PLY - 1 - std::abs(v); // recompute ss->ply const int ply = VALUE_MATE_IN_MAX_PLY - 1 - std::abs(v); // recompute ss->ply
ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply); ss << "cp " << (v > 0 ? 20000 - ply : -20000 + ply);
} }
else else
ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2; ss << "mate " << (v > 0 ? VALUE_MATE - v + 1 : -VALUE_MATE - v) / 2;
return ss.str(); return ss.str();
} }
@ -348,21 +377,21 @@ std::string UCI::value(Value v) {
std::string UCI::wdl(Value v, int ply) { std::string UCI::wdl(Value v, int ply) {
std::stringstream ss; std::stringstream ss;
int wdl_w = win_rate_model( v, ply); int wdl_w = win_rate_model(v, ply);
int wdl_l = win_rate_model(-v, ply); int wdl_l = win_rate_model(-v, ply);
int wdl_d = 1000 - wdl_w - wdl_l; int wdl_d = 1000 - wdl_w - wdl_l;
ss << " wdl " << wdl_w << " " << wdl_d << " " << wdl_l; ss << " wdl " << wdl_w << " " << wdl_d << " " << wdl_l;
return ss.str(); return ss.str();
} }
// UCI::square() converts a Square to a string in algebraic notation (g1, a7, etc.) // UCI::square() converts a Square to a string in algebraic notation (g1, a7, etc.)
std::string UCI::square(Square s) { std::string UCI::square(Square s) {
return std::string{ char('a' + file_of(s)), char('1' + rank_of(s)) }; return std::string{char('a' + file_of(s)), char('1' + rank_of(s))};
} }
@ -373,24 +402,24 @@ std::string UCI::square(Square s) {
std::string UCI::move(Move m, bool chess960) { std::string UCI::move(Move m, bool chess960) {
if (m == MOVE_NONE) if (m == MOVE_NONE)
return "(none)"; return "(none)";
if (m == MOVE_NULL) if (m == MOVE_NULL)
return "0000"; return "0000";
Square from = from_sq(m); Square from = from_sq(m);
Square to = to_sq(m); Square to = to_sq(m);
if (type_of(m) == CASTLING && !chess960) if (type_of(m) == CASTLING && !chess960)
to = make_square(to > from ? FILE_G : FILE_C, rank_of(from)); to = make_square(to > from ? FILE_G : FILE_C, rank_of(from));
std::string move = UCI::square(from) + UCI::square(to); std::string move = UCI::square(from) + UCI::square(to);
if (type_of(m) == PROMOTION) if (type_of(m) == PROMOTION)
move += " pnbrqk"[promotion_type(m)]; move += " pnbrqk"[promotion_type(m)];
return move; return move;
} }
@ -399,14 +428,14 @@ std::string UCI::move(Move m, bool chess960) {
Move UCI::to_move(const Position& pos, std::string& str) { Move UCI::to_move(const Position& pos, std::string& str) {
if (str.length() == 5) if (str.length() == 5)
str[4] = char(tolower(str[4])); // The promotion piece character must be lowercased str[4] = char(tolower(str[4])); // The promotion piece character must be lowercased
for (const auto& m : MoveList<LEGAL>(pos)) for (const auto& m : MoveList<LEGAL>(pos))
if (str == UCI::move(m, pos.is_chess960())) if (str == UCI::move(m, pos.is_chess960()))
return m; return m;
return MOVE_NONE; return MOVE_NONE;
} }
} // namespace Stockfish } // namespace Stockfish

View file

@ -43,7 +43,7 @@ class Option;
// Define a custom comparator, because the UCI options should be case-insensitive // Define a custom comparator, because the UCI options should be case-insensitive
struct CaseInsensitiveLess { struct CaseInsensitiveLess {
bool operator() (const std::string&, const std::string&) const; bool operator()(const std::string&, const std::string&) const;
}; };
// The options container is defined as a std::map // The options container is defined as a std::map
@ -52,44 +52,44 @@ using OptionsMap = std::map<std::string, Option, CaseInsensitiveLess>;
// The Option class implements each option as specified by the UCI protocol // The Option class implements each option as specified by the UCI protocol
class Option { class Option {
using OnChange = void (*)(const Option&); using OnChange = void (*)(const Option&);
public: public:
Option(OnChange = nullptr); Option(OnChange = nullptr);
Option(bool v, OnChange = nullptr); Option(bool v, OnChange = nullptr);
Option(const char* v, OnChange = nullptr); Option(const char* v, OnChange = nullptr);
Option(double v, int minv, int maxv, OnChange = nullptr); Option(double v, int minv, int maxv, OnChange = nullptr);
Option(const char* v, const char* cur, OnChange = nullptr); Option(const char* v, const char* cur, OnChange = nullptr);
Option& operator=(const std::string&); Option& operator=(const std::string&);
void operator<<(const Option&); void operator<<(const Option&);
operator int() const; operator int() const;
operator std::string() const; operator std::string() const;
bool operator==(const char*) const; bool operator==(const char*) const;
private: private:
friend std::ostream& operator<<(std::ostream&, const OptionsMap&); friend std::ostream& operator<<(std::ostream&, const OptionsMap&);
std::string defaultValue, currentValue, type; std::string defaultValue, currentValue, type;
int min, max; int min, max;
size_t idx; size_t idx;
OnChange on_change; OnChange on_change;
}; };
void init(OptionsMap&); void init(OptionsMap&);
void loop(int argc, char* argv[]); void loop(int argc, char* argv[]);
int to_cp(Value v); int to_cp(Value v);
std::string value(Value v); std::string value(Value v);
std::string square(Square s); std::string square(Square s);
std::string move(Move m, bool chess960); std::string move(Move m, bool chess960);
std::string pv(const Position& pos, Depth depth); std::string pv(const Position& pos, Depth depth);
std::string wdl(Value v, int ply); std::string wdl(Value v, int ply);
Move to_move(const Position& pos, std::string& str); Move to_move(const Position& pos, std::string& str);
} // namespace UCI } // namespace UCI
extern UCI::OptionsMap Options; extern UCI::OptionsMap Options;
} // namespace Stockfish } // namespace Stockfish
#endif // #ifndef UCI_H_INCLUDED #endif // #ifndef UCI_H_INCLUDED

View file

@ -40,7 +40,7 @@ using std::string;
namespace Stockfish { namespace Stockfish {
UCI::OptionsMap Options; // Global object UCI::OptionsMap Options; // Global object
namespace UCI { namespace UCI {
@ -53,10 +53,10 @@ static void on_tb_path(const Option& o) { Tablebases::init(o); }
static void on_eval_file(const Option&) { Eval::NNUE::init(); } static void on_eval_file(const Option&) { Eval::NNUE::init(); }
// Our case insensitive less() function as required by UCI protocol // Our case insensitive less() function as required by UCI protocol
bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { bool CaseInsensitiveLess::operator()(const string& s1, const string& s2) const {
return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(), return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(),
[](char c1, char c2) { return tolower(c1) < tolower(c2); }); [](char c1, char c2) { return tolower(c1) < tolower(c2); });
} }
@ -64,28 +64,28 @@ bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const
void init(OptionsMap& o) { void init(OptionsMap& o) {
constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048; constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048;
o["Debug Log File"] << Option("", on_logger); o["Debug Log File"] << Option("", on_logger);
o["Threads"] << Option(1, 1, 1024, on_threads); o["Threads"] << Option(1, 1, 1024, on_threads);
o["Hash"] << Option(16, 1, MaxHashMB, on_hash_size); o["Hash"] << Option(16, 1, MaxHashMB, on_hash_size);
o["Clear Hash"] << Option(on_clear_hash); o["Clear Hash"] << Option(on_clear_hash);
o["Ponder"] << Option(false); o["Ponder"] << Option(false);
o["MultiPV"] << Option(1, 1, 500); o["MultiPV"] << Option(1, 1, 500);
o["Skill Level"] << Option(20, 0, 20); o["Skill Level"] << Option(20, 0, 20);
o["Move Overhead"] << Option(10, 0, 5000); o["Move Overhead"] << Option(10, 0, 5000);
o["Slow Mover"] << Option(100, 10, 1000); o["Slow Mover"] << Option(100, 10, 1000);
o["nodestime"] << Option(0, 0, 10000); o["nodestime"] << Option(0, 0, 10000);
o["UCI_Chess960"] << Option(false); o["UCI_Chess960"] << Option(false);
o["UCI_AnalyseMode"] << Option(false); o["UCI_AnalyseMode"] << Option(false);
o["UCI_LimitStrength"] << Option(false); o["UCI_LimitStrength"] << Option(false);
o["UCI_Elo"] << Option(1320, 1320, 3190); o["UCI_Elo"] << Option(1320, 1320, 3190);
o["UCI_ShowWDL"] << Option(false); o["UCI_ShowWDL"] << Option(false);
o["SyzygyPath"] << Option("<empty>", on_tb_path); o["SyzygyPath"] << Option("<empty>", on_tb_path);
o["SyzygyProbeDepth"] << Option(1, 1, 100); o["SyzygyProbeDepth"] << Option(1, 1, 100);
o["Syzygy50MoveRule"] << Option(true); o["Syzygy50MoveRule"] << Option(true);
o["SyzygyProbeLimit"] << Option(7, 0, 7); o["SyzygyProbeLimit"] << Option(7, 0, 7);
o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file);
} }
@ -94,59 +94,81 @@ void init(OptionsMap& o) {
std::ostream& operator<<(std::ostream& os, const OptionsMap& om) { std::ostream& operator<<(std::ostream& os, const OptionsMap& om) {
for (size_t idx = 0; idx < om.size(); ++idx) for (size_t idx = 0; idx < om.size(); ++idx)
for (const auto& it : om) for (const auto& it : om)
if (it.second.idx == idx) if (it.second.idx == idx)
{ {
const Option& o = it.second; const Option& o = it.second;
os << "\noption name " << it.first << " type " << o.type; os << "\noption name " << it.first << " type " << o.type;
if (o.type == "string" || o.type == "check" || o.type == "combo") if (o.type == "string" || o.type == "check" || o.type == "combo")
os << " default " << o.defaultValue; os << " default " << o.defaultValue;
if (o.type == "spin") if (o.type == "spin")
os << " default " << int(stof(o.defaultValue)) os << " default " << int(stof(o.defaultValue)) << " min " << o.min << " max "
<< " min " << o.min << o.max;
<< " max " << o.max;
break; break;
} }
return os; return os;
} }
// Option class constructors and conversion operators // Option class constructors and conversion operators
Option::Option(const char* v, OnChange f) : type("string"), min(0), max(0), on_change(f) Option::Option(const char* v, OnChange f) :
{ defaultValue = currentValue = v; } type("string"),
min(0),
max(0),
on_change(f) {
defaultValue = currentValue = v;
}
Option::Option(bool v, OnChange f) : type("check"), min(0), max(0), on_change(f) Option::Option(bool v, OnChange f) :
{ defaultValue = currentValue = (v ? "true" : "false"); } type("check"),
min(0),
max(0),
on_change(f) {
defaultValue = currentValue = (v ? "true" : "false");
}
Option::Option(OnChange f) : type("button"), min(0), max(0), on_change(f) Option::Option(OnChange f) :
{} type("button"),
min(0),
max(0),
on_change(f) {}
Option::Option(double v, int minv, int maxv, OnChange f) : type("spin"), min(minv), max(maxv), on_change(f) Option::Option(double v, int minv, int maxv, OnChange f) :
{ defaultValue = currentValue = std::to_string(v); } type("spin"),
min(minv),
max(maxv),
on_change(f) {
defaultValue = currentValue = std::to_string(v);
}
Option::Option(const char* v, const char* cur, OnChange f) : type("combo"), min(0), max(0), on_change(f) Option::Option(const char* v, const char* cur, OnChange f) :
{ defaultValue = v; currentValue = cur; } type("combo"),
min(0),
max(0),
on_change(f) {
defaultValue = v;
currentValue = cur;
}
Option::operator int() const { Option::operator int() const {
assert(type == "check" || type == "spin"); assert(type == "check" || type == "spin");
return (type == "spin" ? std::stoi(currentValue) : currentValue == "true"); return (type == "spin" ? std::stoi(currentValue) : currentValue == "true");
} }
Option::operator std::string() const { Option::operator std::string() const {
assert(type == "string"); assert(type == "string");
return currentValue; return currentValue;
} }
bool Option::operator==(const char* s) const { bool Option::operator==(const char* s) const {
assert(type == "combo"); assert(type == "combo");
return !CaseInsensitiveLess()(currentValue, s) return !CaseInsensitiveLess()(currentValue, s) && !CaseInsensitiveLess()(s, currentValue);
&& !CaseInsensitiveLess()(s, currentValue);
} }
@ -154,10 +176,10 @@ bool Option::operator==(const char* s) const {
void Option::operator<<(const Option& o) { void Option::operator<<(const Option& o) {
static size_t insert_order = 0; static size_t insert_order = 0;
*this = o; *this = o;
idx = insert_order++; idx = insert_order++;
} }
@ -167,33 +189,33 @@ void Option::operator<<(const Option& o) {
Option& Option::operator=(const string& v) { Option& Option::operator=(const string& v) {
assert(!type.empty()); assert(!type.empty());
if ( (type != "button" && type != "string" && v.empty()) if ((type != "button" && type != "string" && v.empty())
|| (type == "check" && v != "true" && v != "false") || (type == "check" && v != "true" && v != "false")
|| (type == "spin" && (stof(v) < min || stof(v) > max))) || (type == "spin" && (stof(v) < min || stof(v) > max)))
return *this; return *this;
if (type == "combo") if (type == "combo")
{ {
OptionsMap comboMap; // To have case insensitive compare OptionsMap comboMap; // To have case insensitive compare
string token; string token;
std::istringstream ss(defaultValue); std::istringstream ss(defaultValue);
while (ss >> token) while (ss >> token)
comboMap[token] << Option(); comboMap[token] << Option();
if (!comboMap.count(v) || v == "var") if (!comboMap.count(v) || v == "var")
return *this; return *this;
} }
if (type != "button") if (type != "button")
currentValue = v; currentValue = v;
if (on_change) if (on_change)
on_change(*this); on_change(*this);
return *this; return *this;
} }
} // namespace UCI } // namespace UCI
} // namespace Stockfish } // namespace Stockfish