Skip to content

Commit

Permalink
feat: add SDXL support (leejet#117)
Browse files Browse the repository at this point in the history
* add SDXL support

* fix the issue with generating large images
  • Loading branch information
leejet authored Dec 28, 2023
1 parent 004dfbe commit 78ad76f
Show file tree
Hide file tree
Showing 5 changed files with 669 additions and 347 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "ggml"]
path = ggml
url = https://github.com/FSSRepo/ggml.git
url = https://github.com/leejet/ggml.git
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in

- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
- SD1.x and SD2.x support
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) support
- SD1.x, SD2.x and SDXL support
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
- 16-bit, 32-bit float support
- 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
Expand Down Expand Up @@ -302,3 +302,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- [k-diffusion](https://github.com/crowsonkb/k-diffusion)
- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
- [generative-models](https://github.com/Stability-AI/generative-models/)
2 changes: 1 addition & 1 deletion ggml
48 changes: 32 additions & 16 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ const char* unused_tensors[] = {
"cond_stage_model.transformer.text_model.embeddings.position_ids",
"cond_stage_model.model.logit_scale",
"cond_stage_model.model.text_projection",
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
"conditioner.embedders.0.model.logit_scale",
"conditioner.embedders.0.model.text_projection",
"conditioner.embedders.1.model.logit_scale",
"model.diffusion_model.time_embedding.cond_proj.weight",
"unet.time_embedding.cond_proj.weight",
"model_ema.decay",
Expand All @@ -100,11 +101,11 @@ bool is_unused_tensor(std::string name) {
}

std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
{"cond_stage_model.model.ln_final.bias", "cond_stage_model.transformer.text_model.final_layer_norm.bias"},
{"cond_stage_model.model.ln_final.weight", "cond_stage_model.transformer.text_model.final_layer_norm.weight"},
{"cond_stage_model.model.positional_embedding", "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight"},
{"cond_stage_model.model.token_embedding.weight", "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"},

{"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"},
{"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"},
{"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
{"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
{"model.text_projection", "transformer.text_model.text_projection"},
};

std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
Expand Down Expand Up @@ -133,11 +134,21 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {

std::string convert_open_clip_to_hf_clip(const std::string& name) {
std::string new_name = name;
std::string prefix;
if (starts_with(new_name, "conditioner.embedders.0.")) {
new_name = "cond_stage_model." + new_name.substr(strlen("conditioner.embedders.0."));
prefix = "cond_stage_model.";
new_name = new_name.substr(strlen("conditioner.embedders.0."));
} else if (starts_with(new_name, "conditioner.embedders.1.")) {
prefix = "cond_stage_model.1.";
new_name = new_name.substr(strlen("conditioner.embedders.0."));
} else if (starts_with(new_name, "cond_stage_model.")) {
prefix = "cond_stage_model.";
new_name = new_name.substr(strlen("cond_stage_model."));
} else {
return new_name;
}
std::string open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks.";
std::string hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers.";
std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";

if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[new_name];
Expand All @@ -156,7 +167,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
}
}

return new_name;
return prefix + new_name;
}

std::string convert_vae_decoder_name(const std::string& name) {
Expand Down Expand Up @@ -358,7 +369,7 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)

std::string convert_tensor_name(const std::string& name) {
std::string new_name;
if (starts_with(name, "cond_stage_model.model") || starts_with(name, "conditioner.embedders.0.model")) {
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) {
new_name = convert_open_clip_to_hf_clip(name);
} else if (starts_with(name, "first_stage_model.decoder")) {
new_name = convert_vae_decoder_name(name);
Expand Down Expand Up @@ -419,7 +430,7 @@ void preprocess_tensor(TensorStorage tensor_storage,

tensor_storage.name = new_name;

if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);
Expand All @@ -431,7 +442,7 @@ void preprocess_tensor(TensorStorage tensor_storage,

processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());

} else if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
} else if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);
Expand Down Expand Up @@ -1163,15 +1174,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}

SDVersion ModelLoader::get_sd_version() {
// return VERSION_1_x;
TensorStorage token_embedding_weight;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
return VERSION_XL;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight") {
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
tensor_storage.name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") {
token_embedding_weight = tensor_storage;
break;
// break;
}
}
if (token_embedding_weight.ne[0] == 768) {
Expand Down Expand Up @@ -1275,7 +1291,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
}

for (auto& tensor_storage : processed_tensor_storages) {
// LOG_DEBUG("%s", name.c_str());
// LOG_DEBUG("%s", tensor_storage.name.c_str());

ggml_tensor* dst_tensor = NULL;

Expand Down
Loading

0 comments on commit 78ad76f

Please sign in to comment.