fix: avoid double free and fix sdxl lora naming conversion
* Fixed a double free issue when running multiple backends on the CPU, e.g. CLIP
and the primary backend: both *_backend pointers would end up pointing to the
same backend, resulting in a segfault when the StableDiffusionGGML destructor
freed it twice.
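
For illustration, a minimal sketch of the failure mode this fixes (the setup
shape is an assumption; ggml_backend_cpu_init and ggml_backend_free are real
ggml calls):

    ggml_backend_t backend      = ggml_backend_cpu_init();
    ggml_backend_t clip_backend = backend;  // CLIP falls back to the primary CPU backend
    // before this fix, the destructor then did:
    ggml_backend_free(backend);
    ggml_backend_free(clip_backend);        // same pointer freed twice -> segfault
    // the fix frees clip_backend only when clip_backend != backend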

* Improved logging to allow for a color switch on the command line interface.
Changed the base log_printf function so it no longer bakes the log level
directly into the log buffer, as that information is already passed to the
logging function via the level parameter, and it is easier to add it there
than to strip it out.
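
As background, the color support uses standard ANSI SGR escape sequences,
where "\033[<n>;1m" starts a bold colored span and "\033[0m" resets it. A
minimal, self-contained illustration (not part of the commit):

    #include <cstdio>

    int main() {
        std::printf("\033[31;1m[%-5s]\033[0m %s\n", "ERROR", "tag rendered in bold red");
        std::printf("[%-5s] %s\n", "ERROR", "plain tag, as printed without --color");
        return 0;
    }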

* Added a fix for certain SDXL LoRAs that don't seem to follow the expected
naming convention, by converting the tensor names during LoRA model loading.
Added logging of useful LoRA loading information. Also had to increase the
base size of the GGML graph, as the existing size produced an insufficient
graph memory error when using SDXL LoRAs.

* small fixes

---------

Co-authored-by: leejet <[email protected]>
grauho and leejet authored Mar 20, 2024
1 parent a469688 commit 48bcce4
Showing 7 changed files with 102 additions and 32 deletions.
44 changes: 38 additions & 6 deletions examples/cli/main.cpp
@@ -103,6 +103,7 @@ struct SDParams {
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
     bool canny_preprocess = false;
+    bool color = false;
     int upscale_repeats = 1;
 };

@@ -469,6 +470,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             exit(0);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
+        } else if (arg == "--color") {
+            params.color = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
@@ -572,18 +575,47 @@ std::string get_image_params(SDParams params, int64_t seed) {
     return parameter_string;
 }
 
+/* Enables Printing the log level tag in color using ANSI escape codes */
 void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
     SDParams* params = (SDParams*)data;
-    if (!params->verbose && level <= SD_LOG_DEBUG) {
+    int tag_color;
+    const char* level_str;
+    FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
+
+    if (!log || (!params->verbose && level <= SD_LOG_DEBUG)) {
         return;
     }
-    if (level <= SD_LOG_INFO) {
-        fputs(log, stdout);
-        fflush(stdout);
+
+    switch (level) {
+        case SD_LOG_DEBUG:
+            tag_color = 37;
+            level_str = "DEBUG";
+            break;
+        case SD_LOG_INFO:
+            tag_color = 34;
+            level_str = "INFO";
+            break;
+        case SD_LOG_WARN:
+            tag_color = 35;
+            level_str = "WARN";
+            break;
+        case SD_LOG_ERROR:
+            tag_color = 31;
+            level_str = "ERROR";
+            break;
+        default: /* Potential future-proofing */
+            tag_color = 33;
+            level_str = "?????";
+            break;
+    }
+
+    if (params->color == true) {
+        fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
     } else {
-        fputs(log, stderr);
-        fflush(stderr);
+        fprintf(out_stream, "[%-5s] ", level_str);
     }
+    fputs(log, out_stream);
+    fflush(out_stream);
 }

 int main(int argc, const char* argv[]) {
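
Usage note: with the --color switch wired up above, a hypothetical invocation
such as

    ./sd -m model.safetensors -p "a lovely cat" --color -v

(binary name and the -m/-p flags assumed from the project's existing CLI)
prints each log line with a color-coded level tag; without --color the
[LEVEL] tag is printed plain.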
7 changes: 6 additions & 1 deletion ggml_extend.hpp
@@ -759,8 +759,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
 // virtual struct ggml_cgraph* get_ggml_cgraph() = 0;
 // };
 
+/*
 #define MAX_PARAMS_TENSOR_NUM 10240
 #define MAX_GRAPH_SIZE 10240
+*/
+/* SDXL with LoRA requires more space */
+#define MAX_PARAMS_TENSOR_NUM 15360
+#define MAX_GRAPH_SIZE 15360
 
 struct GGMLModule {
 protected:
@@ -1308,4 +1313,4 @@ class MultiheadAttention : public GGMLBlock {
     }
 };
 
-#endif // __GGML_EXTEND__HPP__
\ No newline at end of file
+#endif // __GGML_EXTEND__HPP__
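
As context for the size bump above: ggml work graphs are allocated with a
fixed node capacity, and a graph holding fewer nodes than the forward pass
actually builds fails with ggml's insufficient-graph-memory error, which is
what SDXL LoRAs triggered at 10240. A one-line sketch of where such a
constant is (assumed to be) consumed inside GGMLModule; ggml_new_graph_custom
is the real ggml API:

    // the graph's node capacity must cover every node the forward pass creates
    struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MAX_GRAPH_SIZE, false);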
24 changes: 20 additions & 4 deletions lora.hpp
@@ -75,7 +75,7 @@ struct LoraModel : public GGMLModule {
         return true;
     }
 
-    struct ggml_cgraph* build_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
+    struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
 
         std::set<std::string> applied_lora_tensors;
@@ -90,7 +90,7 @@ struct LoraModel : public GGMLModule {
                 k_tensor = k_tensor.substr(0, k_pos);
                 replace_all_chars(k_tensor, '.', '_');
                 // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
-                if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL
+                if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {  // fix for SDXL
                     k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
                 }
@@ -155,21 +155,37 @@ struct LoraModel : public GGMLModule {
             ggml_build_forward_expand(gf, final_weight);
         }
 
+        size_t total_lora_tensors_count = 0;
+        size_t applied_lora_tensors_count = 0;
+
         for (auto& kv : lora_tensors) {
+            total_lora_tensors_count++;
             if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
                 LOG_WARN("unused lora tensor %s", kv.first.c_str());
+            } else {
+                applied_lora_tensors_count++;
             }
         }
+        /* Don't worry if this message shows up twice in the logs per LoRA,
+         * this function is called once to calculate the required buffer size
+         * and then again to actually generate a graph to be used */
+        if (applied_lora_tensors_count != total_lora_tensors_count) {
+            LOG_WARN("Only (%lu / %lu) LoRA tensors have been applied",
+                     applied_lora_tensors_count, total_lora_tensors_count);
+        } else {
+            LOG_DEBUG("(%lu / %lu) LoRA tensors applied successfully",
+                      applied_lora_tensors_count, total_lora_tensors_count);
+        }
 
         return gf;
     }

     void apply(std::map<std::string, struct ggml_tensor*> model_tensors, int n_threads) {
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(model_tensors);
+            return build_lora_graph(model_tensors);
         };
         GGMLModule::compute(get_graph, n_threads, true);
     }
 };
 
-#endif // __LORA_HPP__
\ No newline at end of file
+#endif // __LORA_HPP__
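
For orientation, a hedged sketch of how this apply() entry point is driven
once per LoRA (the caller shape loosely follows apply_lora in
stable-diffusion.cpp below; the exact constructor arguments are an
assumption):

    LoraModel lora(backend, model_data_type, lora_path);  // assumed ctor signature
    if (lora.load_from_file()) {
        lora.multiplier = multiplier;
        // per the comment above, the graph is built twice: once to size the
        // compute buffer, once to actually run
        lora.apply(model_tensors, n_threads);
    }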
23 changes: 22 additions & 1 deletion model.cpp
@@ -204,6 +204,23 @@ std::string convert_vae_decoder_name(const std::string& name) {
     return name;
 }
 
+/* If not an SDXL LoRA, the "unet" prefix will have already been replaced by
+ * this point, and "te2" and "te1" don't seem to appear in non-SDXL models, only "te_" */
+std::string convert_sdxl_lora_name(std::string tensor_name) {
+    const std::pair<std::string, std::string> sdxl_lora_name_lookup[] = {
+        {"unet", "model_diffusion_model"},
+        {"te2", "cond_stage_model_1_transformer"},
+        {"te1", "cond_stage_model_transformer"},
+    };
+    for (auto& pair_i : sdxl_lora_name_lookup) {
+        if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
+            tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second);
+            break;
+        }
+    }
+    return tensor_name;
+}
+
 std::unordered_map<std::string, std::unordered_map<std::string, std::string>> suffix_conversion_underline = {
     {
         "attentions",
@@ -415,8 +432,12 @@ std::string convert_tensor_name(const std::string& name) {
         if (pos != std::string::npos) {
             std::string name_without_network_parts = name.substr(5, pos - 5);
             std::string network_part = name.substr(pos + 1);
+
             // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
             std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_');
+            /* For dealing with the new SDXL LoRA tensor naming convention */
+            new_key = convert_sdxl_lora_name(new_key);
+
             if (new_key.empty()) {
                 new_name = name;
             } else {
@@ -1641,4 +1662,4 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
     }
     bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
     return success;
-}
\ No newline at end of file
+}
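
A self-contained check of the new mapping, using hypothetical tensor names
(the function body is copied verbatim from the diff above):

    #include <cassert>
    #include <regex>
    #include <string>
    #include <utility>

    std::string convert_sdxl_lora_name(std::string tensor_name) {
        const std::pair<std::string, std::string> sdxl_lora_name_lookup[] = {
            {"unet", "model_diffusion_model"},
            {"te2", "cond_stage_model_1_transformer"},
            {"te1", "cond_stage_model_transformer"},
        };
        for (auto& pair_i : sdxl_lora_name_lookup) {
            if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
                tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second);
                break;
            }
        }
        return tensor_name;
    }

    int main() {
        assert(convert_sdxl_lora_name("unet_input_blocks_4_1_proj_in") ==
               "model_diffusion_model_input_blocks_4_1_proj_in");
        assert(convert_sdxl_lora_name("te2_text_model_encoder_layers_0_mlp_fc1") ==
               "cond_stage_model_1_transformer_text_model_encoder_layers_0_mlp_fc1");
        // a plain "te_" prefix matches neither "te2" nor "te1", so it is left unchanged
        assert(convert_sdxl_lora_name("te_text_model_encoder_layers_0_mlp_fc1") ==
               "te_text_model_encoder_layers_0_mlp_fc1");
        return 0;
    }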
20 changes: 13 additions & 7 deletions stable-diffusion.cpp
@@ -122,10 +122,16 @@ class StableDiffusionGGML {
     }
 
     ~StableDiffusionGGML() {
+        if (clip_backend != backend) {
+            ggml_backend_free(clip_backend);
+        }
+        if (control_net_backend != backend) {
+            ggml_backend_free(control_net_backend);
+        }
+        if (vae_backend != backend) {
+            ggml_backend_free(vae_backend);
+        }
         ggml_backend_free(backend);
-        ggml_backend_free(clip_backend);
-        ggml_backend_free(control_net_backend);
-        ggml_backend_free(vae_backend);
     }
 
     bool load_from_file(const std::string& model_path,
@@ -521,9 +527,7 @@ class StableDiffusionGGML {
 
         int64_t t1 = ggml_time_ms();
 
-        LOG_INFO("lora '%s' applied, taking %.2fs",
-                 lora_name.c_str(),
-                 (t1 - t0) * 1.0f / 1000);
+        LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
     }

@@ -546,6 +550,8 @@ void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
             }
         }
 
+        LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
+
         for (auto& kv : lora_state_diff) {
             apply_lora(kv.first, kv.second);
         }
@@ -2109,4 +2115,4 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
     LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000);
 
     return result_images;
-}
\ No newline at end of file
+}
2 changes: 1 addition & 1 deletion stable-diffusion.h
@@ -201,4 +201,4 @@ SD_API uint8_t* preprocess_canny(uint8_t* img,
 }
 #endif
 
-#endif // __STABLE_DIFFUSION_H__
\ No newline at end of file
+#endif // __STABLE_DIFFUSION_H__
14 changes: 2 additions & 12 deletions util.cpp
@@ -366,18 +366,8 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo
     va_list args;
     va_start(args, format);
 
-    const char* level_str = "DEBUG";
-    if (level == SD_LOG_INFO) {
-        level_str = "INFO ";
-    } else if (level == SD_LOG_WARN) {
-        level_str = "WARN ";
-    } else if (level == SD_LOG_ERROR) {
-        level_str = "ERROR";
-    }
-
     static char log_buffer[LOG_BUFFER_SIZE + 1];
 
-    int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line);
+    int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
 
     if (written >= 0 && written < LOG_BUFFER_SIZE) {
         vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
@@ -572,4 +562,4 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
     }
 
     return result;
-}
\ No newline at end of file
+}
