diff --git a/model.cpp b/model.cpp
index 912e2e43..71b3c1bb 100644
--- a/model.cpp
+++ b/model.cpp
@@ -14,6 +14,8 @@
 #include "ggml/ggml-backend.h"
 #include "ggml/ggml.h"
 
+#define ST_HEADER_SIZE_LEN 8
+
 uint64_t read_u64(uint8_t* buffer) {
     // little endian
     uint64_t value = 0;
@@ -533,17 +535,89 @@ std::map unicode_to_byte() {
     return byte_decoder;
 }
 
+bool is_zip_file(const std::string& file_path) {
+    struct zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
+    if (zip == NULL) {
+        return false;
+    }
+    zip_close(zip);
+    return true;
+}
+
+bool is_gguf_file(const std::string& file_path) {
+    std::ifstream file(file_path, std::ios::binary);
+    if (!file.is_open()) {
+        return false;
+    }
+
+    char magic[4];
+
+    file.read(magic, sizeof(magic));
+    if (!file) {
+        return false;
+    }
+    for (uint32_t i = 0; i < sizeof(magic); i++) {
+        if (magic[i] != GGUF_MAGIC[i]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool is_safetensors_file(const std::string& file_path) {
+    std::ifstream file(file_path, std::ios::binary);
+    if (!file.is_open()) {
+        return false;
+    }
+
+    // get file size
+    file.seekg(0, file.end);
+    size_t file_size_ = file.tellg();
+    file.seekg(0, file.beg);
+
+    // read header size
+    if (file_size_ <= ST_HEADER_SIZE_LEN) {
+        return false;
+    }
+
+    uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
+    file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
+    if (!file) {
+        return false;
+    }
+
+    size_t header_size_ = read_u64(header_size_buf);
+    if (header_size_ >= file_size_) {
+        return false;
+    }
+
+    // read header
+    std::vector<char> header_buf;
+    header_buf.resize(header_size_ + 1);
+    header_buf[header_size_] = '\0';
+    file.read(header_buf.data(), header_size_);
+    if (!file) {
+        return false;
+    }
+    nlohmann::json header_ = nlohmann::json::parse(header_buf.data(), nullptr, false);  // allow_exceptions=false: malformed JSON is reported via is_discarded()
+    if (header_.is_discarded()) {
+        return false;
+    }
+    return true;
+}
+
 bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
     if (is_directory(file_path)) {
         LOG_INFO("load %s using diffusers format", file_path.c_str());
         return init_from_diffusers_file(file_path, prefix);
-    } else if (ends_with(file_path, ".gguf")) {
+    } else if (is_gguf_file(file_path)) {
         LOG_INFO("load %s using gguf format", file_path.c_str());
         return init_from_gguf_file(file_path, prefix);
-    } else if (ends_with(file_path, ".safetensors")) {
+    } else if (is_safetensors_file(file_path)) {
         LOG_INFO("load %s using safetensors format", file_path.c_str());
         return init_from_safetensors_file(file_path, prefix);
-    } else if (ends_with(file_path, ".ckpt")) {
+    } else if (is_zip_file(file_path)) {
         LOG_INFO("load %s using checkpoint format", file_path.c_str());
         return init_from_ckpt_file(file_path, prefix);
     } else {
@@ -593,8 +667,6 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
 
 /*================================================= SafeTensorsModelLoader ==================================================*/
 
-#define ST_HEADER_SIZE_LEN 8
-
 ggml_type str_to_ggml_type(const std::string& dtype) {
     ggml_type ttype = GGML_TYPE_COUNT;
     if (dtype == "F16") {
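
For anyone who wants to experiment with this detection logic outside of ModelLoader, below is a minimal standalone sketch of the same content-sniffing idea, using only the C++ standard library. It checks the signatures the new helpers rely on: the ASCII "GGUF" magic, the ZIP local-file-header magic "PK\x03\x04" that .ckpt archives begin with, and the safetensors layout of an 8-byte little-endian header length followed by a JSON header. The names detect_model_format and ModelFormat are illustrative only, not part of this patch, and unlike is_safetensors_file above the sketch only peeks at the first byte of the JSON header instead of parsing it.

// Illustrative sketch only: detect_model_format and ModelFormat are
// hypothetical names, not part of the patch above.
#include <cstdint>
#include <cstring>
#include <fstream>
#include <string>

enum class ModelFormat { GGUF, SafeTensors, Checkpoint, Unknown };

ModelFormat detect_model_format(const std::string& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file.is_open()) {
        return ModelFormat::Unknown;
    }

    // Read the first 8 bytes; every signature tested below fits in them.
    uint8_t buf[8] = {0};
    file.read(reinterpret_cast<char*>(buf), sizeof(buf));
    if (!file) {
        return ModelFormat::Unknown;
    }

    // GGUF files begin with the ASCII magic "GGUF".
    if (std::memcmp(buf, "GGUF", 4) == 0) {
        return ModelFormat::GGUF;
    }

    // .ckpt files are PyTorch pickles stored in a ZIP archive, and ZIP
    // local file headers begin with "PK\x03\x04".
    if (std::memcmp(buf, "PK\x03\x04", 4) == 0) {
        return ModelFormat::Checkpoint;
    }

    // safetensors files begin with a little-endian uint64 header length,
    // immediately followed by a JSON header, which starts with '{'.
    uint64_t header_len = 0;
    for (int i = 7; i >= 0; i--) {
        header_len = (header_len << 8) | buf[i];
    }
    char first_json_byte = 0;
    file.read(&first_json_byte, 1);
    if (file && header_len > 0 && first_json_byte == '{') {
        return ModelFormat::SafeTensors;
    }

    return ModelFormat::Unknown;
}

Usage: detect_model_format("v1-5-pruned.ckpt") returns ModelFormat::Checkpoint for a ZIP-based PyTorch checkpoint regardless of its extension, which is exactly the property this patch is after. Trading the zip and nlohmann::json dependencies for fixed-offset magic checks makes the sketch cheaper but less robust: a ZIP archive that is not actually a checkpoint, or a file that merely starts with '{' after a plausible length field, would be misclassified.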