Skip to content

Commit

Permalink
fix: support more SDXL LoRA names (leejet#216)
Browse files Browse the repository at this point in the history
* apply pmid lora only once for multiple txt2img calls

* add better support for SDXL LoRA

* fix for some sdxl lora, like lcm-lora-xl

---------

Co-authored-by: bssrdf <[email protected]>
Co-authored-by: leejet <[email protected]>
  • Loading branch information
3 people authored Apr 6, 2024
1 parent 646e776 commit afea457
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
21 changes: 10 additions & 11 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,27 +686,26 @@ int main(int argc, const char* argv[]) {
// Resize input image ...
if (params.height % 64 != 0 || params.width % 64 != 0) {
int resized_height = params.height + (64 - params.height % 64);
int resized_width = params.width + (64 - params.width % 64);
int resized_width = params.width + (64 - params.width % 64);

uint8_t *resized_image_buffer = (uint8_t *)malloc(resized_height * resized_width * 3);
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(input_image_buffer);
return 1;
}
stbir_resize(input_image_buffer, params.width, params.height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr
);
stbir_resize(input_image_buffer, params.width, params.height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);

// Save resized result
free(input_image_buffer);
input_image_buffer = resized_image_buffer;
params.height = resized_height;
params.width = resized_width;
params.height = resized_height;
params.width = resized_width;
}
}

Expand Down
13 changes: 9 additions & 4 deletions lora.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct LoraModel : public GGMLModule {
std::string file_path;
ModelLoader model_loader;
bool load_failed = false;
bool applied = false;
bool applied = false;

LoraModel(ggml_backend_t backend,
ggml_type wtype,
Expand Down Expand Up @@ -91,10 +91,15 @@ struct LoraModel : public GGMLModule {
k_tensor = k_tensor.substr(0, k_pos);
replace_all_chars(k_tensor, '.', '_');
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {
// fix for some sdxl lora, like lcm-lora-xl
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
lora_up_name = "lora." + k_tensor + ".lora_up.weight";
}
}
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";

std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
std::string alpha_name = "lora." + k_tensor + ".alpha";
std::string scale_name = "lora." + k_tensor + ".scale";
Expand Down
17 changes: 13 additions & 4 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ std::string convert_sdxl_lora_name(std::string tensor_name) {
{"unet", "model_diffusion_model"},
{"te2", "cond_stage_model_1_transformer"},
{"te1", "cond_stage_model_transformer"},
{"text_encoder_2", "cond_stage_model_1_transformer"},
{"text_encoder", "cond_stage_model_transformer"},
};
for (auto& pair_i : sdxl_lora_name_lookup) {
if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
Expand Down Expand Up @@ -446,18 +448,25 @@ std::string convert_tensor_name(const std::string& name) {
} else {
new_name = name;
}
} else if (contains(name, "lora_up") || contains(name, "lora_down") || contains(name, "lora.up") || contains(name, "lora.down")) {
} else if (contains(name, "lora_up") || contains(name, "lora_down") ||
contains(name, "lora.up") || contains(name, "lora.down") ||
contains(name, "lora_linear")) {
size_t pos = new_name.find(".processor");
if (pos != std::string::npos) {
new_name.replace(pos, strlen(".processor"), "");
}
pos = new_name.find_last_of('_');
pos = new_name.rfind("lora");
if (pos != std::string::npos) {
std::string name_without_network_parts = new_name.substr(0, pos);
std::string network_part = new_name.substr(pos + 1);
std::string name_without_network_parts = new_name.substr(0, pos - 1);
std::string network_part = new_name.substr(pos);
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
new_key = convert_sdxl_lora_name(new_key);
replace_all_chars(new_key, '.', '_');
size_t npos = network_part.rfind("_linear_layer");
if (npos != std::string::npos) {
network_part.replace(npos, strlen("_linear_layer"), "");
}
if (starts_with(network_part, "lora.")) {
network_part = "lora_" + network_part.substr(5);
}
Expand Down
2 changes: 1 addition & 1 deletion stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1610,7 +1610,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->stacked_id && !sd_ctx->sd->pmid_lora->applied) {
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
t1 = ggml_time_ms();
sd_ctx->sd->pmid_lora->applied = true;
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
Expand Down

0 comments on commit afea457

Please sign in to comment.