diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index 4bff8e57..dfd067b7 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -63,6 +63,8 @@ struct SparseEntry { static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes"); +typedef uint16_t Sym; + struct LR { enum Side { Left, Right, Value }; @@ -72,7 +74,7 @@ struct LR { // If symbol has length 1, then the first byte // is the stored value. template - int get() { + Sym get() { if (S == Left) return ((lr[1] & 0xF) << 8) | lr[0]; if (S == Right) @@ -93,14 +95,14 @@ struct PairsData { int real_num_blocks; int max_sym_len; // Maximum length in bits of the Huffman symbols int min_sym_len; // Minimum length in bits of the Huffman symbols - uint16_t* lowestSym; // Value of the lowest symbol of length l is lowestSym[l] + Sym* lowestSym; // Value of the lowest symbol of length l is lowestSym[l] LR* btree; // btree[sym] stores the left and right symbols that expand sym uint16_t* blockLengths; // Number of stored positions (minus one) for each block int blockLengthsSize; // Size of blockLengths[] table SparseEntry* sparseIndex; // Partial indices into blockLengths[] size_t sparseIndexSize; // Size of SparseIndex[] table uint8_t* data; // Start of Huffman compressed data - std::vector base; // Smallest symbol of length l padded to 64 bits is at base[l - min_sym_len] + std::vector base64; // Smallest symbol of length l padded to 64 bits is at base64[l - min_sym_len] std::vector symlen; // Number of values (-1) represented by a given Huffman symbol: 1..256 Piece pieces[TBPIECES]; uint64_t factor[TBPIECES]; @@ -179,6 +181,8 @@ auto item(DTZPiece& e, int , int ) -> decltype(e)& { return e; } auto item(WDLPawn& e, int stm, int f) -> decltype(e.file[stm][f])& { return e.file[stm][f]; } auto item(DTZPawn& e, int , int f) -> decltype(e.file[f])& { return e.file[f]; } +int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } + int MapToEdges[SQUARE_NB]; int MapB1H1H7[SQUARE_NB]; int 
MapA1D1D4[SQUARE_NB]; @@ -478,8 +482,8 @@ void HashTable::insert(const std::vector& pieces) insert(keys[BLACK], &WDLTable.back()); } -// TB are compressed with canonical Huffman code. The the compressed data is divided into -// blocks of size d->blocksize, and each block stores a variable number of symbols. +// TB are compressed with canonical Huffman code. The compressed data is divided into +// blocks of size d->sizeofBlock, and each block stores a variable number of symbols. // Each symbol represents either a WDL or (remapped) DTZ value, or a pair of other symbols // (recursively). If you keep expanding the symbols in a block, you end up with up to 65536 // WDL or DTZ values. Each symbol represents up to 256 values and will correspond after @@ -544,24 +548,24 @@ int decompress_pairs(PairsData* d, uint64_t idx) // Read the first 64 bits in our block. We still don't know the symbol length but // we know is at the beginning of this 64 bits sequence. uint64_t buf64 = number(ptr); ptr += 2; - int sym, buf64Size = 64; + int buf64Size = 64; + Sym sym; for (;;) { int len = d->min_sym_len; - // Now get the symbol length. Given two symbols of length l1 and l2, where - // l1 < l2 then d->base[l1] > d->base[l2]. Moreover, any symbol of length l - // right-padded to 64 bits is >= d->base[l] so we can find the symbol length - // iterating through base[] starting from minimum length. - while (buf64 < d->base[len - d->min_sym_len]) + // Now get the symbol length. For any symbol s64 of length l right-padded + // to 64 bits, d->base64[l-1] >= s64 >= d->base64[l] holds, so we can find + // the symbol length by iterating through base64[]. + while (buf64 < d->base64[len - d->min_sym_len]) ++len; // Symbols of same length are mapped to consecutive numbers, so we can compute // the offset of our symbol of length len, stored at the beginning of buf64.
- sym = (buf64 - d->base[len - d->min_sym_len]) >> (64 - len); + sym = (buf64 - d->base64[len - d->min_sym_len]) >> (64 - len); // Now add the value of the lowest symbol of length len to get our symbol - sym += number(&d->lowestSym[len]); + sym += number(&d->lowestSym[len]); // If our offset is within the number of values represented by symbol sym // we are done... @@ -585,10 +589,10 @@ int decompress_pairs(PairsData* d, uint64_t idx) // at a symbol of length 1 (symlen[sym] + 1 == 1), which is the value we need. while (d->symlen[sym]) { - // Each btree[] entry expand in a left-handed and right-handed pair of + // Each btree[] entry expands in a left-handed and right-handed pair of // additional symbols. We keep expanding recursively picking the symbol // that contains our idxOffset. - int sl = d->btree[sym].get(); + Sym sl = d->btree[sym].get(); if (idxOffset < (int)d->symlen[sl] + 1) sym = sl; @@ -630,8 +634,8 @@ int map_score(DTZEntry* entry, File f, int value, WDLScore wdl) { int flags = entry->hasPawns ? entry->pawn.file[f].precomp->flags : entry->piece.precomp->flags; - uint8_t* map = entry->hasPawns ? entry->pawn.map - : entry->piece.map; + uint8_t* map = entry->hasPawns ? entry->pawn.map + : entry->piece.map; uint16_t* idx = entry->hasPawns ? 
entry->pawn.file[f].map_idx : entry->piece.map_idx; @@ -650,8 +654,6 @@ int map_score(DTZEntry* entry, File f, int value, WDLScore wdl) { return value; } -int off_A1H8(Square sq) { return int(rank_of(sq)) - file_of(sq); } - template uint64_t probe_table(const Position& pos, Entry* entry, WDLScore wdl = WDLDraw, int* success = nullptr) { @@ -926,15 +928,15 @@ void set_norms(T* p, int num, const uint8_t pawns[]) ++p->norm[i]; } -uint8_t set_symlen(PairsData* d, size_t s, std::vector& visited) +uint8_t set_symlen(PairsData* d, Sym s, std::vector& visited) { - visited[s] = true; // We can set now because tree is acyclic - int sr = d->btree[s].get(); + visited[s] = true; // We can set it now because tree is acyclic + Sym sr = d->btree[s].get(); if (sr == 0xFFF) return 0; else { - int sl = d->btree[s].get(); + Sym sl = d->btree[s].get(); if (!visited[sl]) d->symlen[sl] = set_symlen(d, sl, visited); @@ -953,7 +955,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data, uint64_t tb_size) if (d->flags & TBFlag::SingleValue) { d->real_num_blocks = d->span = d->blockLengthsSize = d->sparseIndexSize = 0; // Broken MSVC zero-init - d->min_sym_len = *data++; + d->min_sym_len = *data++; // Here we store the single value return data; } @@ -965,29 +967,41 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data, uint64_t tb_size) d->blockLengthsSize += d->real_num_blocks; d->max_sym_len = *data++; d->min_sym_len = *data++; - d->lowestSym = (uint16_t*)data; - d->base.resize(d->max_sym_len - d->min_sym_len + 1); + d->lowestSym = (Sym*)data; + d->base64.resize(d->max_sym_len - d->min_sym_len + 1); - for (int i = d->base.size() - 2; i >= 0; --i) - d->base[i] = (d->base[i + 1] + number(&d->lowestSym[i]) - - number(&d->lowestSym[i + 1])) / 2; + // The canonical code is ordered such that longer symbols (in terms of + // the number of bits of their Huffman code) have lower numeric value, + // so that d->lowestSym[i] >= d->lowestSym[i+1] (when read as LittleEndian). 
+ // Starting from this we compute a base64[] table indexed by symbol length + // and containing 64-bit values so that d->base64[i] >= d->base64[i+1] + for (int i = d->base64.size() - 2; i >= 0; --i) { + d->base64[i] = (d->base64[i + 1] + number(&d->lowestSym[i]) + - number(&d->lowestSym[i + 1])) / 2; - for (size_t i = 0; i < d->base.size(); ++i) - d->base[i] <<= (64 - d->min_sym_len) - i; // Right-padding to 64 bits + assert(d->base64[i] * 2 >= d->base64[i+1]); + } + + // Now left-shift by an amount so that d->base64[i] gets shifted 1 bit more + // than d->base64[i+1]; given the above assert condition, this ensures that + // d->base64[i] >= d->base64[i+1]. Moreover, for any symbol s64 of length i + // right-padded to 64 bits, d->base64[i-1] >= s64 >= d->base64[i] holds. + for (size_t i = 0; i < d->base64.size(); ++i) + d->base64[i] <<= 64 - i - d->min_sym_len; // Right-padding to 64 bits d->lowestSym -= d->min_sym_len; - data += d->base.size() * sizeof(*d->lowestSym); + data += d->base64.size() * sizeof(Sym); d->symlen.resize(number(data)); data += sizeof(uint16_t); d->btree = (LR*)data; std::vector visited(d->symlen.size()); - for (size_t sym = 0; sym < d->symlen.size(); ++sym) + for (Sym sym = 0; sym < d->symlen.size(); ++sym) if (!visited[sym]) d->symlen[sym] = set_symlen(d, sym, visited); - return data + 3 * d->symlen.size() + (d->symlen.size() & 1); + return data + d->symlen.size() * sizeof(LR) + (d->symlen.size() & 1); } template