Skip to content

Commit

Permalink
fix: detect model format base on file content
Browse files Browse the repository at this point in the history
  • Loading branch information
leejet committed Dec 3, 2023
1 parent 8a87b27 commit f99bcd1
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "ggml/ggml-backend.h"
#include "ggml/ggml.h"

#define ST_HEADER_SIZE_LEN 8

uint64_t read_u64(uint8_t* buffer) {
// little endian
uint64_t value = 0;
Expand Down Expand Up @@ -533,17 +535,89 @@ std::map<char, int> unicode_to_byte() {
return byte_decoder;
}

bool is_zip_file(const std::string& file_path) {
struct zip_t* zip = zip_open(file_path.c_str(), 0, 'r');
if (zip == NULL) {
return false;
}
zip_close(zip);
return true;
}

bool is_gguf_file(const std::string& file_path) {
std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
return false;
}

char magic[4];

file.read(magic, sizeof(magic));
if (!file) {
return false;
}
for (uint32_t i = 0; i < sizeof(magic); i++) {
if (magic[i] != GGUF_MAGIC[i]) {
return false;
}
}

return true;
}

bool is_safetensors_file(const std::string& file_path) {
std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
return false;
}

// get file size
file.seekg(0, file.end);
size_t file_size_ = file.tellg();
file.seekg(0, file.beg);

// read header size
if (file_size_ <= ST_HEADER_SIZE_LEN) {
return false;
}

uint8_t header_size_buf[ST_HEADER_SIZE_LEN];
file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN);
if (!file) {
return false;
}

size_t header_size_ = read_u64(header_size_buf);
if (header_size_ >= file_size_) {
return false;
}

// read header
std::vector<char> header_buf;
header_buf.resize(header_size_ + 1);
header_buf[header_size_] = '\0';
file.read(header_buf.data(), header_size_);
if (!file) {
return false;
}
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
if (header_.is_discarded()) {
return false;
}
return true;
}

bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
if (is_directory(file_path)) {
LOG_INFO("load %s using diffusers format", file_path.c_str());
return init_from_diffusers_file(file_path, prefix);
} else if (ends_with(file_path, ".gguf")) {
} else if (is_gguf_file(file_path)) {
LOG_INFO("load %s using gguf format", file_path.c_str());
return init_from_gguf_file(file_path, prefix);
} else if (ends_with(file_path, ".safetensors")) {
} else if (is_safetensors_file(file_path)) {
LOG_INFO("load %s using safetensors format", file_path.c_str());
return init_from_safetensors_file(file_path, prefix);
} else if (ends_with(file_path, ".ckpt")) {
} else if (is_zip_file(file_path)) {
LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix);
} else {
Expand Down Expand Up @@ -593,8 +667,6 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s

/*================================================= SafeTensorsModelLoader ==================================================*/

#define ST_HEADER_SIZE_LEN 8

ggml_type str_to_ggml_type(const std::string& dtype) {
ggml_type ttype = GGML_TYPE_COUNT;
if (dtype == "F16") {
Expand Down

0 comments on commit f99bcd1

Please sign in to comment.