mirror of
https://github.com/sockspls/badfish
synced 2025-04-30 00:33:09 +00:00
Exporting the currently loaded network file
This PR adds the ability to export any currently loaded network. The export_net command now takes an optional filename parameter. If the loaded net is not the embedded net, the filename parameter is required. Two changes were required to support this: * the "architecture" string, which is really just some kind of description in the net, is now saved into netDescription on load and correctly saved on export. * the AffineTransform scrambles weights for some architectures and sparsifies them, such that retrieving the index is hard. This is solved by having a temporary scrambled<->unscrambled index lookup table when loading the network, and the actual index is saved for each individual weight that makes it to canSaturate16. This increases the size of the canSaturate16 entries by 6 bytes. closes https://github.com/official-stockfish/Stockfish/pull/3456 No functional change
This commit is contained in:
parent
d777ea79ff
commit
58054fd0fa
10 changed files with 159 additions and 27 deletions
10
README.md
10
README.md
|
@ -156,8 +156,14 @@ For developers the following non-standard commands might be of interest, mainly
|
||||||
* #### eval
|
* #### eval
|
||||||
Return the evaluation of the current position.
|
Return the evaluation of the current position.
|
||||||
|
|
||||||
* #### export_net
|
* #### export_net [filename]
|
||||||
If the binary contains an embedded net, save it in a file (named according to the default value of EvalFile).
|
Exports the currently loaded network to a file.
|
||||||
|
If the currently loaded network is the embedded network and the filename
|
||||||
|
is not specified then the network is saved to the file matching the name
|
||||||
|
of the embedded network, as defined in evaluate.h.
|
||||||
|
If the currently loaded network is not the embedded network (some net set
|
||||||
|
through the UCI setoption) then the filename parameter is required and the
|
||||||
|
network is saved into that file.
|
||||||
|
|
||||||
* #### flip
|
* #### flip
|
||||||
Flips the side to move.
|
Flips the side to move.
|
||||||
|
|
|
@ -47,9 +47,7 @@
|
||||||
// Note that this does not work in Microsoft Visual Studio.
|
// Note that this does not work in Microsoft Visual Studio.
|
||||||
#if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF)
|
#if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF)
|
||||||
INCBIN(EmbeddedNNUE, EvalFileDefaultName);
|
INCBIN(EmbeddedNNUE, EvalFileDefaultName);
|
||||||
constexpr bool gHasEmbeddedNet = true;
|
|
||||||
#else
|
#else
|
||||||
constexpr bool gHasEmbeddedNet = false;
|
|
||||||
const unsigned char gEmbeddedNNUEData[1] = {0x0};
|
const unsigned char gEmbeddedNNUEData[1] = {0x0};
|
||||||
const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1];
|
const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1];
|
||||||
const unsigned int gEmbeddedNNUESize = 1;
|
const unsigned int gEmbeddedNNUESize = 1;
|
||||||
|
@ -116,12 +114,23 @@ namespace Eval {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void NNUE::export_net() {
|
void NNUE::export_net(const std::optional<std::string>& filename) {
|
||||||
if constexpr (gHasEmbeddedNet) {
|
std::string actualFilename;
|
||||||
ofstream stream(EvalFileDefaultName, std::ios_base::binary);
|
if (filename.has_value()) {
|
||||||
stream.write(reinterpret_cast<const char*>(gEmbeddedNNUEData), gEmbeddedNNUESize);
|
actualFilename = filename.value();
|
||||||
} else {
|
} else {
|
||||||
sync_cout << "No embedded network file." << sync_endl;
|
if (eval_file_loaded != EvalFileDefaultName) {
|
||||||
|
sync_cout << "Failed to export a net. A non-embedded net can only be saved if the filename is specified." << sync_endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
actualFilename = EvalFileDefaultName;
|
||||||
|
}
|
||||||
|
|
||||||
|
ofstream stream(actualFilename, std::ios_base::binary);
|
||||||
|
if (save_eval(stream)) {
|
||||||
|
sync_cout << "Network saved successfully to " << actualFilename << "." << sync_endl;
|
||||||
|
} else {
|
||||||
|
sync_cout << "Failed to export a net." << sync_endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#define EVALUATE_H_INCLUDED
|
#define EVALUATE_H_INCLUDED
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "types.h"
|
#include "types.h"
|
||||||
|
|
||||||
|
@ -44,8 +45,9 @@ namespace Eval {
|
||||||
|
|
||||||
Value evaluate(const Position& pos);
|
Value evaluate(const Position& pos);
|
||||||
bool load_eval(std::string name, std::istream& stream);
|
bool load_eval(std::string name, std::istream& stream);
|
||||||
|
bool save_eval(std::ostream& stream);
|
||||||
void init();
|
void init();
|
||||||
void export_net();
|
void export_net(const std::optional<std::string>& filename);
|
||||||
void verify();
|
void verify();
|
||||||
|
|
||||||
} // namespace NNUE
|
} // namespace NNUE
|
||||||
|
|
|
@ -39,6 +39,7 @@ namespace Stockfish::Eval::NNUE {
|
||||||
|
|
||||||
// Evaluation function file name
|
// Evaluation function file name
|
||||||
std::string fileName;
|
std::string fileName;
|
||||||
|
std::string netDescription;
|
||||||
|
|
||||||
namespace Detail {
|
namespace Detail {
|
||||||
|
|
||||||
|
@ -68,6 +69,14 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return reference.read_parameters(stream);
|
return reference.read_parameters(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write evaluation function parameters
|
||||||
|
template <typename T>
|
||||||
|
bool write_parameters(std::ostream& stream, const T& reference) {
|
||||||
|
|
||||||
|
write_little_endian<std::uint32_t>(stream, T::get_hash_value());
|
||||||
|
return reference.write_parameters(stream);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Detail
|
} // namespace Detail
|
||||||
|
|
||||||
// Initialize the evaluation function parameters
|
// Initialize the evaluation function parameters
|
||||||
|
@ -78,7 +87,7 @@ namespace Stockfish::Eval::NNUE {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read network header
|
// Read network header
|
||||||
bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* architecture)
|
bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc)
|
||||||
{
|
{
|
||||||
std::uint32_t version, size;
|
std::uint32_t version, size;
|
||||||
|
|
||||||
|
@ -86,8 +95,18 @@ namespace Stockfish::Eval::NNUE {
|
||||||
*hashValue = read_little_endian<std::uint32_t>(stream);
|
*hashValue = read_little_endian<std::uint32_t>(stream);
|
||||||
size = read_little_endian<std::uint32_t>(stream);
|
size = read_little_endian<std::uint32_t>(stream);
|
||||||
if (!stream || version != Version) return false;
|
if (!stream || version != Version) return false;
|
||||||
architecture->resize(size);
|
desc->resize(size);
|
||||||
stream.read(&(*architecture)[0], size);
|
stream.read(&(*desc)[0], size);
|
||||||
|
return !stream.fail();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write network header
|
||||||
|
bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc)
|
||||||
|
{
|
||||||
|
write_little_endian<std::uint32_t>(stream, Version);
|
||||||
|
write_little_endian<std::uint32_t>(stream, hashValue);
|
||||||
|
write_little_endian<std::uint32_t>(stream, desc.size());
|
||||||
|
stream.write(&desc[0], desc.size());
|
||||||
return !stream.fail();
|
return !stream.fail();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,14 +114,22 @@ namespace Stockfish::Eval::NNUE {
|
||||||
bool read_parameters(std::istream& stream) {
|
bool read_parameters(std::istream& stream) {
|
||||||
|
|
||||||
std::uint32_t hashValue;
|
std::uint32_t hashValue;
|
||||||
std::string architecture;
|
if (!read_header(stream, &hashValue, &netDescription)) return false;
|
||||||
if (!read_header(stream, &hashValue, &architecture)) return false;
|
|
||||||
if (hashValue != HashValue) return false;
|
if (hashValue != HashValue) return false;
|
||||||
if (!Detail::read_parameters(stream, *featureTransformer)) return false;
|
if (!Detail::read_parameters(stream, *featureTransformer)) return false;
|
||||||
if (!Detail::read_parameters(stream, *network)) return false;
|
if (!Detail::read_parameters(stream, *network)) return false;
|
||||||
return stream && stream.peek() == std::ios::traits_type::eof();
|
return stream && stream.peek() == std::ios::traits_type::eof();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write network parameters
|
||||||
|
bool write_parameters(std::ostream& stream) {
|
||||||
|
|
||||||
|
if (!write_header(stream, HashValue, netDescription)) return false;
|
||||||
|
if (!Detail::write_parameters(stream, *featureTransformer)) return false;
|
||||||
|
if (!Detail::write_parameters(stream, *network)) return false;
|
||||||
|
return (bool)stream;
|
||||||
|
}
|
||||||
|
|
||||||
// Evaluation function. Perform differential calculation.
|
// Evaluation function. Perform differential calculation.
|
||||||
Value evaluate(const Position& pos) {
|
Value evaluate(const Position& pos) {
|
||||||
|
|
||||||
|
@ -141,4 +168,13 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return read_parameters(stream);
|
return read_parameters(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save eval, to a file stream or a memory stream
|
||||||
|
bool save_eval(std::ostream& stream) {
|
||||||
|
|
||||||
|
if (fileName.empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return write_parameters(stream);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Stockfish::Eval::NNUE
|
} // namespace Stockfish::Eval::NNUE
|
||||||
|
|
|
@ -69,15 +69,19 @@ namespace Stockfish::Eval::NNUE::Layers {
|
||||||
if (!previousLayer.read_parameters(stream)) return false;
|
if (!previousLayer.read_parameters(stream)) return false;
|
||||||
for (std::size_t i = 0; i < OutputDimensions; ++i)
|
for (std::size_t i = 0; i < OutputDimensions; ++i)
|
||||||
biases[i] = read_little_endian<BiasType>(stream);
|
biases[i] = read_little_endian<BiasType>(stream);
|
||||||
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
|
|
||||||
#if !defined (USE_SSSE3)
|
#if !defined (USE_SSSE3)
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
|
||||||
weights[i] = read_little_endian<WeightType>(stream);
|
weights[i] = read_little_endian<WeightType>(stream);
|
||||||
#else
|
#else
|
||||||
weights[
|
std::unique_ptr<uint32_t[]> indexMap = std::make_unique<uint32_t[]>(OutputDimensions * PaddedInputDimensions);
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
|
||||||
|
const uint32_t scrambledIdx =
|
||||||
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
|
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
|
||||||
i / PaddedInputDimensions * 4 +
|
i / PaddedInputDimensions * 4 +
|
||||||
i % 4
|
i % 4;
|
||||||
] = read_little_endian<WeightType>(stream);
|
weights[scrambledIdx] = read_little_endian<WeightType>(stream);
|
||||||
|
indexMap[scrambledIdx] = i;
|
||||||
|
}
|
||||||
|
|
||||||
// Determine if eights of weight and input products can be summed using 16bits
|
// Determine if eights of weight and input products can be summed using 16bits
|
||||||
// without saturation. We assume worst case combinations of 0 and 127 for all inputs.
|
// without saturation. We assume worst case combinations of 0 and 127 for all inputs.
|
||||||
|
@ -109,7 +113,8 @@ namespace Stockfish::Eval::NNUE::Layers {
|
||||||
|
|
||||||
IndexType idx = maxK / 2 * OutputDimensions * 4 + maxK % 2;
|
IndexType idx = maxK / 2 * OutputDimensions * 4 + maxK % 2;
|
||||||
sum[sign == -1] -= w[idx];
|
sum[sign == -1] -= w[idx];
|
||||||
canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx]);
|
const uint32_t scrambledIdx = idx + i * OutputDimensions + j * 4 + x * 2;
|
||||||
|
canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx], indexMap[scrambledIdx]);
|
||||||
w[idx] = 0;
|
w[idx] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -125,6 +130,34 @@ namespace Stockfish::Eval::NNUE::Layers {
|
||||||
return !stream.fail();
|
return !stream.fail();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write network parameters
|
||||||
|
bool write_parameters(std::ostream& stream) const {
|
||||||
|
if (!previousLayer.write_parameters(stream)) return false;
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions; ++i)
|
||||||
|
write_little_endian<BiasType>(stream, biases[i]);
|
||||||
|
#if !defined (USE_SSSE3)
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
|
||||||
|
write_little_endian<WeightType>(stream, weights[i]);
|
||||||
|
#else
|
||||||
|
std::unique_ptr<WeightType[]> unscrambledWeights = std::make_unique<WeightType[]>(OutputDimensions * PaddedInputDimensions);
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
|
||||||
|
unscrambledWeights[i] =
|
||||||
|
weights[
|
||||||
|
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
|
||||||
|
i / PaddedInputDimensions * 4 +
|
||||||
|
i % 4
|
||||||
|
];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < canSaturate16.count; ++i)
|
||||||
|
unscrambledWeights[canSaturate16.ids[i].wIdx] = canSaturate16.ids[i].w;
|
||||||
|
|
||||||
|
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
|
||||||
|
write_little_endian<WeightType>(stream, unscrambledWeights[i]);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return !stream.fail();
|
||||||
|
}
|
||||||
|
|
||||||
// Forward propagation
|
// Forward propagation
|
||||||
const OutputType* propagate(
|
const OutputType* propagate(
|
||||||
const TransformedFeatureType* transformedFeatures, char* buffer) const {
|
const TransformedFeatureType* transformedFeatures, char* buffer) const {
|
||||||
|
@ -444,12 +477,14 @@ namespace Stockfish::Eval::NNUE::Layers {
|
||||||
struct CanSaturate {
|
struct CanSaturate {
|
||||||
int count;
|
int count;
|
||||||
struct Entry {
|
struct Entry {
|
||||||
|
uint32_t wIdx;
|
||||||
uint16_t out;
|
uint16_t out;
|
||||||
uint16_t in;
|
uint16_t in;
|
||||||
int8_t w;
|
int8_t w;
|
||||||
} ids[PaddedInputDimensions * OutputDimensions * 3 / 4];
|
} ids[PaddedInputDimensions * OutputDimensions * 3 / 4];
|
||||||
|
|
||||||
void add(int i, int j, int8_t w) {
|
void add(int i, int j, int8_t w, uint32_t wIdx) {
|
||||||
|
ids[count].wIdx = wIdx;
|
||||||
ids[count].out = i;
|
ids[count].out = i;
|
||||||
ids[count].in = j;
|
ids[count].in = j;
|
||||||
ids[count].w = w;
|
ids[count].w = w;
|
||||||
|
|
|
@ -59,6 +59,11 @@ namespace Stockfish::Eval::NNUE::Layers {
|
||||||
return previousLayer.read_parameters(stream);
|
return previousLayer.read_parameters(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write network parameters
|
||||||
|
bool write_parameters(std::ostream& stream) const {
|
||||||
|
return previousLayer.write_parameters(stream);
|
||||||
|
}
|
||||||
|
|
||||||
// Forward propagation
|
// Forward propagation
|
||||||
const OutputType* propagate(
|
const OutputType* propagate(
|
||||||
const TransformedFeatureType* transformedFeatures, char* buffer) const {
|
const TransformedFeatureType* transformedFeatures, char* buffer) const {
|
||||||
|
|
|
@ -53,6 +53,11 @@ class InputSlice {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write network parameters
|
||||||
|
bool write_parameters(std::ostream& /*stream*/) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Forward propagation
|
// Forward propagation
|
||||||
const OutputType* propagate(
|
const OutputType* propagate(
|
||||||
const TransformedFeatureType* transformedFeatures,
|
const TransformedFeatureType* transformedFeatures,
|
||||||
|
|
|
@ -99,6 +99,24 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename IntType>
|
||||||
|
inline void write_little_endian(std::ostream& stream, IntType value) {
|
||||||
|
|
||||||
|
std::uint8_t u[sizeof(IntType)];
|
||||||
|
typename std::make_unsigned<IntType>::type v = value;
|
||||||
|
|
||||||
|
std::size_t i = 0;
|
||||||
|
// if constexpr to silence the warning about shift by 8
|
||||||
|
if constexpr (sizeof(IntType) > 1) {
|
||||||
|
for (; i + 1 < sizeof(IntType); ++i) {
|
||||||
|
u[i] = v;
|
||||||
|
v >>= 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u[i] = v;
|
||||||
|
|
||||||
|
stream.write(reinterpret_cast<char*>(u), sizeof(IntType));
|
||||||
|
}
|
||||||
} // namespace Stockfish::Eval::NNUE
|
} // namespace Stockfish::Eval::NNUE
|
||||||
|
|
||||||
#endif // #ifndef NNUE_COMMON_H_INCLUDED
|
#endif // #ifndef NNUE_COMMON_H_INCLUDED
|
||||||
|
|
|
@ -118,6 +118,15 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return !stream.fail();
|
return !stream.fail();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write network parameters
|
||||||
|
bool write_parameters(std::ostream& stream) const {
|
||||||
|
for (std::size_t i = 0; i < HalfDimensions; ++i)
|
||||||
|
write_little_endian<BiasType>(stream, biases[i]);
|
||||||
|
for (std::size_t i = 0; i < HalfDimensions * InputDimensions; ++i)
|
||||||
|
write_little_endian<WeightType>(stream, weights[i]);
|
||||||
|
return !stream.fail();
|
||||||
|
}
|
||||||
|
|
||||||
// Convert input features
|
// Convert input features
|
||||||
void transform(const Position& pos, OutputType* output) const {
|
void transform(const Position& pos, OutputType* output) const {
|
||||||
update_accumulator(pos, WHITE);
|
update_accumulator(pos, WHITE);
|
||||||
|
|
|
@ -277,7 +277,14 @@ void UCI::loop(int argc, char* argv[]) {
|
||||||
else if (token == "d") sync_cout << pos << sync_endl;
|
else if (token == "d") sync_cout << pos << sync_endl;
|
||||||
else if (token == "eval") trace_eval(pos);
|
else if (token == "eval") trace_eval(pos);
|
||||||
else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
|
else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
|
||||||
else if (token == "export_net") Eval::NNUE::export_net();
|
else if (token == "export_net") {
|
||||||
|
std::optional<std::string> filename;
|
||||||
|
std::string f;
|
||||||
|
if (is >> skipws >> f) {
|
||||||
|
filename = f;
|
||||||
|
}
|
||||||
|
Eval::NNUE::export_net(filename);
|
||||||
|
}
|
||||||
else if (!token.empty() && token[0] != '#')
|
else if (!token.empty() && token[0] != '#')
|
||||||
sync_cout << "Unknown command: " << cmd << sync_endl;
|
sync_cout << "Unknown command: " << cmd << sync_endl;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue