NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks and indentation were rebuilt from the hunk headers (4-space indent is
assumed). Template argument lists on the context lines `std::vector& tmp`,
`std::vector tmp(...)`, `number(data)` and the trailing `template` look like
they were stripped during extraction and could not be recovered from the patch
itself -- restore them from the target file before applying.

diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp
index 27da143b..8ac96c64 100644
--- a/src/syzygy/tbprobe.cpp
+++ b/src/syzygy/tbprobe.cpp
@@ -55,7 +55,7 @@ inline Square operator^(Square s, int i) { return Square(int(s) ^ i); }
 // Each table has a set of flags: all of them refer to DTZ tables, the last one to WDL tables
 enum TBFlag { STM = 1, Mapped = 2, WinPlies = 4, LossPlies = 8, SingleValue = 128 };
 
-// Little endian numbers of index in blockLengths[] and offet within the block
+// Little endian numbers of one index in blockLengths[] and the offset within the block
 struct SparseEntry {
     char block[4];
     char offset[2];
@@ -63,6 +63,29 @@ struct SparseEntry {
 
 static_assert(sizeof(SparseEntry) == 6, "SparseEntry must be 6 bytes");
 
+struct LR {
+
+    enum Side { Left, Right, Value };
+
+    uint8_t lr[3]; // The first 12 bits are the left-hand symbol,
+                   // the second 12 bits are the right-hand symbol.
+                   // If symbol has length 1, then the first byte
+                   // is the stored value.
+    template<Side S>
+    int get() const {
+        if (S == Left)
+            return ((lr[1] & 0xF) << 8) | lr[0];
+        if (S == Right)
+            return (lr[2] << 4) | (lr[1] >> 4);
+        if (S == Value)
+            return lr[0];
+
+        return assert(false), -1; // Only hit for an invalid Side; keeps every path returning
+    }
+};
+
+static_assert(sizeof(LR) == 3, "LR tree entry must be 3 bytes");
+
 struct PairsData {
     int flags;
     size_t sizeofBlock; // Block size in bytes
@@ -71,7 +94,7 @@ struct PairsData {
     int max_sym_len; // Maximum length in bits of the Huffman symbols
     int min_sym_len; // Minimum length in bits of the Huffman symbols
     uint16_t* lowestSym; // Value of the lowest symbol of length l is lowestSym[l]
-    uint8_t* sympat;
+    LR* btree; // btree[sym] stores the left and right symbols that expand sym
     uint16_t* blockLengths; // Number of stored positions (minus one) for each block
     int blockLengthsSize; // Size of blockLengths[] table
     SparseEntry* sparseIndex; // Partial indices into blockLengths[]
@@ -507,8 +530,8 @@ int decompress_pairs(PairsData* d, uint64_t idx)
     // Sum to idxOffset to find the offset corresponding to our idx
     idxOffset += diff;
 
-    // Move to previous or next block, until we reach the correct block that contains idx,
-    // that is when 0 <= idxOffset <= d->sizetable[block]
+    // Move to previous/next block, until we reach the correct block that contains idx,
+    // that is when 0 <= idxOffset <= d->blockLengths[block]
     while (idxOffset < 0)
         idxOffset += d->blockLengths[--block] + 1;
 
@@ -556,21 +579,26 @@ int decompress_pairs(PairsData* d, uint64_t idx)
         }
     }
 
-    uint8_t *sympat = d->sympat;
+    // Ok, now we have our symbol that expands into d->symlen[sym] + 1 values, the score
+    // we are looking for is among those values. We binary-search for it expanding the
+    // symbol in a pair of left and right child symbols and continue recursively until we
+    // are at a symbol of length 1 (symlen[sym] + 1 == 1), which is the value we need.
+    while (d->symlen[sym]) {
 
-    while (d->symlen[sym] != 0) {
-        uint8_t* w = sympat + (3 * sym);
-        int s1 = ((w[1] & 0xf) << 8) | w[0];
+        // Each btree[] entry expands into a left-hand and a right-hand pair of
+        // additional symbols. We keep expanding recursively picking the symbol
+        // that contains our idxOffset.
+        int sl = d->btree[sym].get<LR::Left>();
 
-        if (idxOffset < (int)d->symlen[s1] + 1)
-            sym = s1;
+        if (idxOffset < (int)d->symlen[sl] + 1)
+            sym = sl;
         else {
-            idxOffset -= (int)d->symlen[s1] + 1;
-            sym = (w[2] << 4) | (w[1] >> 4);
+            idxOffset -= (int)d->symlen[sl] + 1;
+            sym = d->btree[sym].get<LR::Right>();
         }
     }
 
-    return sympat[3 * sym];
+    return d->btree[sym].get<LR::Value>();
 }
 
 template
@@ -900,23 +928,20 @@ void set_norms(T* p, int num, const uint8_t pawns[])
 
 void calc_symlen(PairsData* d, size_t s, std::vector& tmp)
 {
-    int s1, s2;
+    int sr = d->btree[s].get<LR::Right>();
 
-    uint8_t* w = d->sympat + 3 * s;
-    s2 = (w[2] << 4) | (w[1] >> 4);
-
-    if (s2 == 0xFFF)
+    if (sr == 0xFFF)
         d->symlen[s] = 0;
     else {
-        s1 = ((w[1] & 0xF) << 8) | w[0];
+        int sl = d->btree[s].get<LR::Left>();
 
-        if (!tmp[s1])
-            calc_symlen(d, s1, tmp);
+        if (!tmp[sl])
+            calc_symlen(d, sl, tmp);
 
-        if (!tmp[s2])
-            calc_symlen(d, s2, tmp);
+        if (!tmp[sr])
+            calc_symlen(d, sr, tmp);
 
-        d->symlen[s] = d->symlen[s1] + d->symlen[s2] + 1;
+        d->symlen[s] = d->symlen[sl] + d->symlen[sr] + 1;
     }
 
     tmp[s] = 1;
@@ -955,7 +980,7 @@ uint8_t* set_sizes(PairsData* d, uint8_t* data, uint64_t tb_size)
     data += d->base.size() * sizeof(*d->lowestSym);
     d->symlen.resize(number(data));
     data += sizeof(uint16_t);
-    d->sympat = data;
+    d->btree = (LR*)data;
 
     std::vector tmp(d->symlen.size());
 