diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index 8cdff9ab..f4426f26 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -151,11 +151,6 @@ struct DTZEntry_pawn { uint8_t *map; }; -struct TBHashEntry { - uint64_t key; - struct TBEntry *ptr; -}; - struct DTZTableEntry { uint64_t key1; uint64_t key2; @@ -370,8 +365,6 @@ const Value WDL_to_value[] = { const int DTZ_ENTRIES = 64; const int TBMAX_PIECE = 254; const int TBMAX_PAWN = 256; -const int TBHASHBITS = 10; -const int HSHMAX = 5; const std::string PieceChar = " PNBRQK"; @@ -381,16 +374,59 @@ Mutex TB_mutex; std::string TBPaths; TBEntry_piece TB_piece[TBMAX_PIECE]; TBEntry_pawn TB_pawn[TBMAX_PAWN]; -TBHashEntry TB_hash[1 << TBHASHBITS][HSHMAX]; DTZTableEntry DTZ_table[DTZ_ENTRIES]; int Binomial[5][64]; int Pawnidx[5][24]; int Pfactor[5][4]; +class HashTable { + + struct Entry { + Key key; + struct TBEntry* ptr; + }; + + static const int TBHASHBITS = 10; + static const int HSHMAX = 5; + + Entry table[1 << TBHASHBITS][HSHMAX]; + +public: + TBEntry* operator[](Key key) { + Entry* entry = table[key >> (64 - TBHASHBITS)]; + + for (int i = 0; i < HSHMAX; i++, entry++) + if (entry->key == key) + return entry->ptr; + + return nullptr; + } + + void insert(TBEntry* ptr, Key key) { + Entry* entry = table[key >> (64 - TBHASHBITS)]; + + for (int i = 0; i < HSHMAX; i++, entry++) + if (!entry->ptr) { + entry->key = key; + entry->ptr = ptr; + return; + } + + std::cerr << "HSHMAX too low!" << std::endl; + exit(1); + } + + void clear() { std::memset(table, 0, sizeof(table)); } +}; + +HashTable TBHash; + + class TBFile : public std::ifstream { std::string fname; + public: // Open the file with the given name found among the TBPaths. TBPaths stores // the paths to directories where the .rtbw and .rtbz files can be found. @@ -505,21 +541,6 @@ Key get_key(uint8_t* pcs, bool mirror) return key; } -void add_to_hash(TBEntry* ptr, Key key) -{ - TBHashEntry* entry = TB_hash[key >> (64 - TBHASHBITS)]; - - for (int i = 0; i < HSHMAX && entry->ptr; i++, entry++) {} - - if (!entry->ptr) { - entry->key = key; - entry->ptr = ptr; - } else { - std::cerr << "HSHMAX too low!" << std::endl; - exit(1); - } -} - void free_wdl_entry(TBEntry_piece* entry) { TBFile::unmap(entry->data, entry->mapping); @@ -635,10 +656,10 @@ void init_tb(const std::vector& pieces) entry->symmetric = (key1 == key2); entry->has_pawns = hasPawns; - add_to_hash(entry, key1); + TBHash.insert(entry, key1); if (key2 != key1) - add_to_hash(entry, key2); + TBHash.insert(entry, key2); } uint64_t encode_piece(TBEntry_piece* ptr, uint8_t* norm, int* pos, int* factor) @@ -1407,27 +1428,16 @@ uint8_t decompress_pairs(PairsData *d, uint64_t idx) void load_dtz_table(const std::string& str, uint64_t key1, uint64_t key2) { - int i; - TBEntry *ptr, *ptr3; - TBHashEntry *ptr2; - DTZ_table[0].key1 = key1; DTZ_table[0].key2 = key2; DTZ_table[0].entry = NULL; - // find corresponding WDL entry - ptr2 = TB_hash[key1 >> (64 - TBHASHBITS)]; + TBEntry* ptr = TBHash[key1]; - for (i = 0; i < HSHMAX; i++) - if (ptr2[i].key == key1) - break; - - if (i == HSHMAX) + if (!ptr) return; - ptr = ptr2[i].ptr; - - ptr3 = (TBEntry *)malloc(ptr->has_pawns + TBEntry* ptr3 = (TBEntry*)malloc(ptr->has_pawns ? sizeof(DTZEntry_pawn) : sizeof(DTZEntry_piece)); @@ -1476,37 +1486,24 @@ std::string prt_str(Position& pos, bool mirror) return s; } -// probe_wdl_table and probe_dtz_table require similar adaptations. int probe_wdl_table(Position& pos, int *success) { - TBEntry *ptr; - TBHashEntry *ptr2; uint64_t idx; - uint64_t key; - int i; - uint8_t res; + int i, res; int p[TBPIECES]; - // Obtain the position's material signature key. - key = pos.material_key(); + Key key = pos.material_key(); - // Test for KvK. - if (key == (Zobrist::psq[WHITE][KING][0] ^ Zobrist::psq[BLACK][KING][0])) - return 0; + if (pos.count(WHITE) + pos.count(BLACK) == 2) + return 0; // KvK - ptr2 = TB_hash[key >> (64 - TBHASHBITS)]; + TBEntry* ptr = TBHash[key]; - for (i = 0; i < HSHMAX; i++) - if (ptr2[i].key == key) - break; - - if (i == HSHMAX) { + if (!ptr) { *success = 0; return 0; } - ptr = ptr2[i].ptr; - if (!ptr->ready) { TB_mutex.lock(); @@ -1514,7 +1511,7 @@ int probe_wdl_table(Position& pos, int *success) std::string s = prt_str(pos, ptr->key != key); if (!init_table_wdl(ptr, s)) { - ptr2[i].key = 0ULL; + // Was ptr2->key = 0ULL; Just leave !ptr->ready condition *success = 0; TB_mutex.unlock(); return 0; @@ -1598,7 +1595,6 @@ int probe_wdl_table(Position& pos, int *success) int probe_dtz_table(Position& pos, int wdl, int *success) { - TBEntry *ptr; uint64_t idx; int i, res; int p[TBPIECES]; @@ -1618,18 +1614,12 @@ int probe_dtz_table(Position& pos, int wdl, int *success) DTZ_table[0] = table_entry; } else { - TBHashEntry *ptr2 = TB_hash[key >> (64 - TBHASHBITS)]; - - for (i = 0; i < HSHMAX; i++) - if (ptr2[i].key == key) - break; - - if (i == HSHMAX) { + TBEntry* ptr = TBHash[key]; + if (!ptr) { *success = 0; return 0; } - ptr = ptr2[i].ptr; bool mirror = (ptr->key != key); std::string s = prt_str(pos, mirror); @@ -1643,7 +1633,7 @@ int probe_dtz_table(Position& pos, int wdl, int *success) } } - ptr = DTZ_table[0].entry; + TBEntry* ptr = DTZ_table[0].entry; if (!ptr) { *success = 0; @@ -1825,7 +1815,7 @@ void Tablebases::init(const std::string& paths) DTZ_table[i].entry = nullptr; } - std::memset(TB_hash, 0, sizeof(TB_hash)); + TBHash.clear(); TBnum_piece = TBnum_pawn = 0; MaxCardinality = 0;