From afea457eda9c1267d3aa2f3b5da3bb52775fe411 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Sat, 6 Apr 2024 05:12:03 -0400
Subject: [PATCH] fix: support more SDXL LoRA names (#216)

* apply pmid lora only once for multiple txt2img calls

* add better support for SDXL LoRA

* fix for some sdxl lora, like lcm-lora-xl

---------

Co-authored-by: bssrdf
Co-authored-by: leejet
---
 examples/cli/main.cpp | 21 ++++++++++-----------
 lora.hpp              | 13 +++++++++----
 model.cpp             | 17 +++++++++++++----
 stable-diffusion.cpp  |  2 +-
 4 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index aad6ae0f..a3dc042c 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -686,27 +686,26 @@ int main(int argc, const char* argv[]) {
         // Resize input image ...
         if (params.height % 64 != 0 || params.width % 64 != 0) {
             int resized_height = params.height + (64 - params.height % 64);
-            int resized_width = params.width + (64 - params.width % 64);
+            int resized_width  = params.width + (64 - params.width % 64);
 
-            uint8_t *resized_image_buffer = (uint8_t *)malloc(resized_height * resized_width * 3);
+            uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
             if (resized_image_buffer == NULL) {
                 fprintf(stderr, "error: allocate memory for resize input image\n");
                 free(input_image_buffer);
                 return 1;
             }
-            stbir_resize(input_image_buffer, params.width, params.height, 0,
-                         resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
-                         3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
-                         STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
-                         STBIR_FILTER_BOX, STBIR_FILTER_BOX,
-                         STBIR_COLORSPACE_SRGB, nullptr
-            );
+            stbir_resize(input_image_buffer, params.width, params.height, 0,
+                         resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
+                         3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
+                         STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
+                         STBIR_FILTER_BOX, STBIR_FILTER_BOX,
+                         STBIR_COLORSPACE_SRGB, nullptr);
 
             // Save resized result
             free(input_image_buffer);
             input_image_buffer = resized_image_buffer;
-            params.height = resized_height;
-            params.width = resized_width;
+            params.height      = resized_height;
+            params.width       = resized_width;
         }
     }
 
diff --git a/lora.hpp b/lora.hpp
index 15a4dcd6..4713c117 100644
--- a/lora.hpp
+++ b/lora.hpp
@@ -11,7 +11,7 @@ struct LoraModel : public GGMLModule {
     std::string file_path;
     ModelLoader model_loader;
     bool load_failed = false;
-    bool applied = false;
+    bool applied     = false;
 
     LoraModel(ggml_backend_t backend,
               ggml_type wtype,
@@ -91,10 +91,15 @@ struct LoraModel : public GGMLModule {
             k_tensor = k_tensor.substr(0, k_pos);
             replace_all_chars(k_tensor, '.', '_');
             // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
-            if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {  // fix for SDXL
-                k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
+            std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
+            if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
+                if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {
+                    // fix for some sdxl lora, like lcm-lora-xl
+                    k_tensor     = "model_diffusion_model_output_blocks_2_1_conv";
+                    lora_up_name = "lora." + k_tensor + ".lora_up.weight";
+                }
             }
-            std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
+
             std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
             std::string alpha_name     = "lora." + k_tensor + ".alpha";
             std::string scale_name     = "lora." + k_tensor + ".scale";
diff --git a/model.cpp b/model.cpp
index 3ed0171d..c8cc5e32 100644
--- a/model.cpp
+++ b/model.cpp
@@ -211,6 +211,8 @@ std::string convert_sdxl_lora_name(std::string tensor_name) {
         {"unet", "model_diffusion_model"},
         {"te2", "cond_stage_model_1_transformer"},
         {"te1", "cond_stage_model_transformer"},
+        {"text_encoder_2", "cond_stage_model_1_transformer"},
+        {"text_encoder", "cond_stage_model_transformer"},
     };
     for (auto& pair_i : sdxl_lora_name_lookup) {
         if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
@@ -446,18 +448,25 @@ std::string convert_tensor_name(const std::string& name) {
         } else {
             new_name = name;
         }
-    } else if (contains(name, "lora_up") || contains(name, "lora_down") || contains(name, "lora.up") || contains(name, "lora.down")) {
+    } else if (contains(name, "lora_up") || contains(name, "lora_down") ||
+               contains(name, "lora.up") || contains(name, "lora.down") ||
+               contains(name, "lora_linear")) {
         size_t pos = new_name.find(".processor");
         if (pos != std::string::npos) {
             new_name.replace(pos, strlen(".processor"), "");
         }
-        pos = new_name.find_last_of('_');
+        pos = new_name.rfind("lora");
         if (pos != std::string::npos) {
-            std::string name_without_network_parts = new_name.substr(0, pos);
-            std::string network_part = new_name.substr(pos + 1);
+            std::string name_without_network_parts = new_name.substr(0, pos - 1);
+            std::string network_part               = new_name.substr(pos);
             // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
             std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
+            new_key             = convert_sdxl_lora_name(new_key);
             replace_all_chars(new_key, '.', '_');
+            size_t npos = network_part.rfind("_linear_layer");
+            if (npos != std::string::npos) {
+                network_part.replace(npos, strlen("_linear_layer"), "");
+            }
             if (starts_with(network_part, "lora.")) {
                 network_part = "lora_" + network_part.substr(5);
             }
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 5ee2d56b..b489b499 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1610,7 +1610,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
     if (sd_ctx->sd->stacked_id && !sd_ctx->sd->pmid_lora->applied) {
         t0 = ggml_time_ms();
         sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
-        t1 = ggml_time_ms();
+        t1                             = ggml_time_ms();
         sd_ctx->sd->pmid_lora->applied = true;
         LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
         if (sd_ctx->sd->free_params_immediately) {
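
For reference, a rough standalone sketch of the renaming flow the model.cpp hunks implement. The sample tensor name is illustrative (lcm-lora-xl ships names in this "lora_linear_layer" diffusers form), and the convert_diffusers_name_to_compvis step is deliberately skipped here, so the printed key is not the final internal name; the sketch only demonstrates the rfind("lora") split, the "_linear_layer" strip, and the new text_encoder prefix lookup added by this patch.

    #include <cstring>
    #include <iostream>
    #include <string>
    #include <utility>

    // Stand-in for the helper of the same name in model.cpp.
    static void replace_all_chars(std::string& str, char from, char to) {
        for (auto& c : str) {
            if (c == from) {
                c = to;
            }
        }
    }

    // Mirrors convert_sdxl_lora_name after this patch: diffusers-style
    // "text_encoder"/"text_encoder_2" prefixes map to the same internal names
    // as the short "te1"/"te2" forms. "text_encoder_2" must be matched before
    // "text_encoder", since the latter is a prefix of the former.
    static std::string convert_sdxl_lora_name(std::string tensor_name) {
        const std::pair<std::string, std::string> sdxl_lora_name_lookup[] = {
            {"unet", "model_diffusion_model"},
            {"te2", "cond_stage_model_1_transformer"},
            {"te1", "cond_stage_model_transformer"},
            {"text_encoder_2", "cond_stage_model_1_transformer"},
            {"text_encoder", "cond_stage_model_transformer"},
        };
        for (auto& pair_i : sdxl_lora_name_lookup) {
            if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
                tensor_name.replace(0, pair_i.first.length(), pair_i.second);
                break;
            }
        }
        return tensor_name;
    }

    int main() {
        // Illustrative diffusers-style SDXL LoRA tensor name.
        std::string name = "text_encoder.text_model.encoder.layers.0.self_attn"
                           ".q_proj.lora_linear_layer.up.weight";

        // Split at the last "lora" marker rather than the last '_', so module
        // names that themselves contain underscores are no longer mangled.
        size_t pos = name.rfind("lora");
        if (pos == std::string::npos) {
            return 1;
        }
        std::string name_without_network_parts = name.substr(0, pos - 1);
        std::string network_part               = name.substr(pos);

        // The real code runs convert_diffusers_name_to_compvis first; that
        // step is omitted to keep the sketch short.
        std::string new_key = convert_sdxl_lora_name(name_without_network_parts);
        replace_all_chars(new_key, '.', '_');

        // Strip the diffusers "_linear_layer" infix:
        // "lora_linear_layer.up.weight" -> "lora.up.weight" -> "lora_up.weight"
        size_t npos = network_part.rfind("_linear_layer");
        if (npos != std::string::npos) {
            network_part.replace(npos, strlen("_linear_layer"), "");
        }
        if (network_part.compare(0, 5, "lora.") == 0) {
            network_part = "lora_" + network_part.substr(5);
        }

        std::cout << new_key << "." << network_part << std::endl;
        // Prints (with the compvis remapping step skipped):
        // cond_stage_model_transformer_text_model_encoder_layers_0_self_attn_q_proj.lora_up.weight
    }

Splitting at rfind("lora") instead of find_last_of('_') is the key change: with the old split, a name like the one above was cut inside "lora_linear_layer", producing a network part the loader in lora.hpp could never match.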