diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index e5b082ce..7d60ff77 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -103,6 +103,7 @@ struct SDParams { bool clip_on_cpu = false; bool vae_on_cpu = false; bool canny_preprocess = false; + bool color = false; int upscale_repeats = 1; }; @@ -469,6 +470,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(0); } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; + } else if (arg == "--color") { + params.color = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -572,18 +575,47 @@ std::string get_image_params(SDParams params, int64_t seed) { return parameter_string; } +/* Enables Printing the log level tag in color using ANSI escape codes */ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { SDParams* params = (SDParams*)data; - if (!params->verbose && level <= SD_LOG_DEBUG) { + int tag_color; + const char* level_str; + FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout; + + if (!log || (!params->verbose && level <= SD_LOG_DEBUG)) { return; } - if (level <= SD_LOG_INFO) { - fputs(log, stdout); - fflush(stdout); + + switch (level) { + case SD_LOG_DEBUG: + tag_color = 37; + level_str = "DEBUG"; + break; + case SD_LOG_INFO: + tag_color = 34; + level_str = "INFO"; + break; + case SD_LOG_WARN: + tag_color = 35; + level_str = "WARN"; + break; + case SD_LOG_ERROR: + tag_color = 31; + level_str = "ERROR"; + break; + default: /* Potential future-proofing */ + tag_color = 33; + level_str = "?????"; + break; + } + + if (params->color == true) { + fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str); } else { - fputs(log, stderr); - fflush(stderr); + fprintf(out_stream, "[%-5s] ", level_str); } + fputs(log, out_stream); + fflush(out_stream); } int main(int argc, const char* argv[]) { diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 25a7cdca..70e1d17e 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -759,8 +759,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( // virtual struct ggml_cgraph* get_ggml_cgraph() = 0; // }; +/* #define MAX_PARAMS_TENSOR_NUM 10240 #define MAX_GRAPH_SIZE 10240 +*/ +/* SDXL with LoRA requires more space */ +#define MAX_PARAMS_TENSOR_NUM 15360 +#define MAX_GRAPH_SIZE 15360 struct GGMLModule { protected: @@ -1308,4 +1313,4 @@ class MultiheadAttention : public GGMLBlock { } }; -#endif // __GGML_EXTEND__HPP__ \ No newline at end of file +#endif // __GGML_EXTEND__HPP__ diff --git a/lora.hpp b/lora.hpp index 1336e829..06b37bbc 100644 --- a/lora.hpp +++ b/lora.hpp @@ -75,7 +75,7 @@ struct LoraModel : public GGMLModule { return true; } - struct ggml_cgraph* build_graph(std::map model_tensors) { + struct ggml_cgraph* build_lora_graph(std::map model_tensors) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); std::set applied_lora_tensors; @@ -90,7 +90,7 @@ struct LoraModel : public GGMLModule { k_tensor = k_tensor.substr(0, k_pos); replace_all_chars(k_tensor, '.', '_'); // LOG_DEBUG("k_tensor %s", k_tensor.c_str()); - if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL + if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL k_tensor = "model_diffusion_model_output_blocks_2_1_conv"; } std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; @@ -155,21 +155,37 @@ struct LoraModel : public GGMLModule { ggml_build_forward_expand(gf, final_weight); } + size_t total_lora_tensors_count = 0; + size_t applied_lora_tensors_count = 0; + for (auto& kv : lora_tensors) { + total_lora_tensors_count++; if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) { LOG_WARN("unused lora tensor %s", kv.first.c_str()); + } else { + applied_lora_tensors_count++; } } + /* Don't worry if this message shows up twice in the logs per LoRA, + * this function is called once to calculate the required buffer size + * and then again to actually generate a graph to be used */ + if (applied_lora_tensors_count != total_lora_tensors_count) { + LOG_WARN("Only (%lu / %lu) LoRA tensors have been applied", + applied_lora_tensors_count, total_lora_tensors_count); + } else { + LOG_DEBUG("(%lu / %lu) LoRA tensors applied successfully", + applied_lora_tensors_count, total_lora_tensors_count); + } return gf; } void apply(std::map model_tensors, int n_threads) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(model_tensors); + return build_lora_graph(model_tensors); }; GGMLModule::compute(get_graph, n_threads, true); } }; -#endif // __LORA_HPP__ \ No newline at end of file +#endif // __LORA_HPP__ diff --git a/model.cpp b/model.cpp index 78b1dc31..3ed0171d 100644 --- a/model.cpp +++ b/model.cpp @@ -204,6 +204,23 @@ std::string convert_vae_decoder_name(const std::string& name) { return name; } +/* If not a SDXL LoRA the unet" prefix will have already been replaced by this + * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ +std::string convert_sdxl_lora_name(std::string tensor_name) { + const std::pair sdxl_lora_name_lookup[] = { + {"unet", "model_diffusion_model"}, + {"te2", "cond_stage_model_1_transformer"}, + {"te1", "cond_stage_model_transformer"}, + }; + for (auto& pair_i : sdxl_lora_name_lookup) { + if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) { + tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second); + break; + } + } + return tensor_name; +} + std::unordered_map> suffix_conversion_underline = { { "attentions", @@ -415,8 +432,12 @@ std::string convert_tensor_name(const std::string& name) { if (pos != std::string::npos) { std::string name_without_network_parts = name.substr(5, pos - 5); std::string network_part = name.substr(pos + 1); + // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_'); + /* For dealing with the new SDXL LoRA tensor naming convention */ + new_key = convert_sdxl_lora_name(new_key); + if (new_key.empty()) { new_name = name; } else { @@ -1641,4 +1662,4 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa } bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type); return success; -} \ No newline at end of file +} diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d622dde..9d100861 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -122,10 +122,16 @@ class StableDiffusionGGML { } ~StableDiffusionGGML() { + if (clip_backend != backend) { + ggml_backend_free(clip_backend); + } + if (control_net_backend != backend) { + ggml_backend_free(control_net_backend); + } + if (vae_backend != backend) { + ggml_backend_free(vae_backend); + } ggml_backend_free(backend); - ggml_backend_free(clip_backend); - ggml_backend_free(control_net_backend); - ggml_backend_free(vae_backend); } bool load_from_file(const std::string& model_path, @@ -521,9 +527,7 @@ class StableDiffusionGGML { int64_t t1 = ggml_time_ms(); - LOG_INFO("lora '%s' applied, taking %.2fs", - lora_name.c_str(), - (t1 - t0) * 1.0f / 1000); + LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000); } void apply_loras(const std::unordered_map& lora_state) { @@ -546,6 +550,8 @@ class StableDiffusionGGML { } } + LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); + for (auto& kv : lora_state_diff) { apply_lora(kv.first, kv.second); } @@ -2109,4 +2115,4 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000); return result_images; -} \ No newline at end of file +} diff --git a/stable-diffusion.h b/stable-diffusion.h index 1b3cd14a..36960041 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -201,4 +201,4 @@ SD_API uint8_t* preprocess_canny(uint8_t* img, } #endif -#endif // __STABLE_DIFFUSION_H__ \ No newline at end of file +#endif // __STABLE_DIFFUSION_H__ diff --git a/util.cpp b/util.cpp index 96310cb6..0755cc32 100644 --- a/util.cpp +++ b/util.cpp @@ -366,18 +366,8 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo va_list args; va_start(args, format); - const char* level_str = "DEBUG"; - if (level == SD_LOG_INFO) { - level_str = "INFO "; - } else if (level == SD_LOG_WARN) { - level_str = "WARN "; - } else if (level == SD_LOG_ERROR) { - level_str = "ERROR"; - } - static char log_buffer[LOG_BUFFER_SIZE + 1]; - - int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line); + int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line); if (written >= 0 && written < LOG_BUFFER_SIZE) { vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args); @@ -572,4 +562,4 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) { } return result; -} \ No newline at end of file +}