From 004dfbef270b19c9516950d2bc1d918efdbd1427 Mon Sep 17 00:00:00 2001 From: Steward Garcia <57494570+FSSRepo@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:46:48 -0500 Subject: [PATCH] feat: implement ESRGAN upscaler + Metal Backend (#104) * add esrgan upscaler * add sd_tiling * support metal backend * add clip_skip --------- Co-authored-by: leejet --- CMakeLists.txt | 11 + README.md | 31 +- examples/cli/main.cpp | 41 +- ggml | 2 +- model.cpp | 14 +- model.h | 2 +- stable-diffusion.cpp | 847 ++++++++++++++++++++++++++++++++++++++++-- stable-diffusion.h | 6 +- 8 files changed, 915 insertions(+), 39 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b119ee6e..95d59d61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,9 @@ endif() #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUBLAS "sd: cuda backend" OFF) +option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) +option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -33,6 +35,15 @@ if(SD_CUBLAS) message("Use CUBLAS as backend stable-diffusion") set(GGML_CUBLAS ON) add_definitions(-DSD_USE_CUBLAS) + if(SD_FAST_SOFTMAX) + set(GGML_CUDA_FAST_SOFTMAX ON) + endif() +endif() + +if(SD_METAL) + message("Use Metal as backend stable-diffusion") + set(GGML_METAL ON) + add_definitions(-DSD_USE_METAL) endif() if(SD_FLASH_ATTN) diff --git a/README.md b/README.md index a0765dab..cc6938bf 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB. - AVX, AVX2 and AVX512 support for x86 architectures -- Full CUDA backend for GPU acceleration. +- Full CUDA and Metal backend for GPU acceleration. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models - No need to convert to `.ggml` or `.gguf` anymore! - Flash Attention for memory usage optimization (only cpu for now) @@ -27,6 +27,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora) - Latent Consistency Models support (LCM/LCM-LoRA) - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd) +- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN) +- VAE tiling processing for reduce memory usage - Sampling method - `Euler A` - `Euler` @@ -51,7 +53,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - The current implementation of ggml_conv_2d is slow and has high memory usage - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler +- [ ] Implement Textual Inversion (embeddings) +- [ ] Implement Inpainting support - [ ] k-quants support ## Usage @@ -112,6 +115,15 @@ cmake .. -DSD_CUBLAS=ON cmake --build . --config Release ``` +##### Using Metal + +Using Metal makes the computation run on the GPU. Currently, there are some issues with Metal when performing operations on very large matrices, making it highly inefficient at the moment. Performance improvements are expected in the near future. + +``` +cmake .. -DSD_METAL=ON +cmake --build . --config Release +``` + ### Using Flash Attention Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing. @@ -124,7 +136,7 @@ cmake --build . --config Release ### Run ``` -usage: sd [arguments] +usage: ./bin/sd [arguments] arguments: -h, --help show this help message and exit @@ -134,6 +146,7 @@ arguments: -m, --model [MODEL] path to model --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) If not specified, the default is the type of the weight file. --lora-model-dir [DIR] lora model directory @@ -153,6 +166,8 @@ arguments: -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate. --schedule {discrete, karras} Denoiser sigma schedule (default: discrete) + --clip-skip N number of layers to skip of clip model (default: 0) + --vae-tiling process vae in tiles to reduce memory usage -v, --verbose print extra info ``` @@ -240,6 +255,16 @@ curl -L -O https://huggingface.co/madebyollin/taesd/blob/main/diffusion_pytorch_ sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --taesd ../models/diffusion_pytorch_model.safetensors ``` +## Using ESRGAN to upscale results + +You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon. + +- Specify the model path using the `--upscale-model PATH` parameter. example: + +```bash +sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --upscale-model ../models/RealESRGAN_x4plus_anime_6B.pth +``` + ### Docker #### Building using Docker diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 68824dd9..6264d6e2 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -59,6 +59,7 @@ struct SDParams { std::string model_path; std::string vae_path; std::string taesd_path; + std::string esrgan_path; ggml_type wtype = GGML_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; @@ -67,6 +68,7 @@ struct SDParams { std::string prompt; std::string negative_prompt; float cfg_scale = 7.0f; + int clip_skip = -1; // <= 0 represents unspecified int width = 512; int height = 512; int batch_count = 1; @@ -78,6 +80,7 @@ struct SDParams { RNGType rng_type = CUDA_RNG; int64_t seed = 42; bool verbose = false; + bool vae_tiling = false; }; void print_params(SDParams params) { @@ -88,11 +91,13 @@ void print_params(SDParams params) { printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified"); printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); + printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); printf(" sample_method: %s\n", sample_method_str[params.sample_method]); @@ -102,6 +107,7 @@ void print_params(SDParams params) { printf(" rng: %s\n", rng_type_to_str[params.rng_type]); printf(" seed: %ld\n", params.seed); printf(" batch_count: %d\n", params.batch_count); + printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); } void print_usage(int argc, const char* argv[]) { @@ -115,6 +121,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -m, --model [MODEL] path to model\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); @@ -134,6 +141,9 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate.\n"); printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n"); + printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); + printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); + printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" -v, --verbose print extra info\n"); } @@ -185,6 +195,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.taesd_path = argv[i]; + } else if (arg == "--upscale-model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.esrgan_path = argv[i]; } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; @@ -270,6 +286,14 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.sample_steps = std::stoi(argv[i]); + } else if (arg == "--clip-skip") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_skip = std::stoi(argv[i]); + } else if (arg == "--vae-tiling") { + params.vae_tiling = true; } else if (arg == "-b" || arg == "--batch-count") { if (++i >= argc) { invalid_arg = true; @@ -458,9 +482,9 @@ int main(int argc, const char* argv[]) { } } - StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, true, params.lora_model_dir, params.rng_type); + StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); - if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) { + if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) { return 1; } @@ -488,6 +512,19 @@ int main(int argc, const char* argv[]) { params.seed); } + if (params.esrgan_path.size() > 0) { + // TODO: support more ESRGAN models, making it easier to set up ESRGAN models. + /* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible + See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py + + To avoid this, the upscaler needs to be separated from the stable diffusion pipeline. + However, a considerable amount of work would be required for this. It might be better + to opt for a complete project refactoring that facilitates the easier assignment of parameters. + */ + params.width *= 4; + params.height *= 4; + } + if (results.size() == 0 || results.size() != params.batch_count) { LOG_ERROR("generate failed"); return 1; diff --git a/ggml b/ggml index 70474c68..a0c2ec77 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 70474c6890c015b53dc10a2300ae35246cc73589 +Subproject commit a0c2ec77a5ef8e630aff65bc535d13b9805cb929 diff --git a/model.cpp b/model.cpp index 41d3347b..f8f0752c 100644 --- a/model.cpp +++ b/model.cpp @@ -14,6 +14,10 @@ #include "ggml/ggml-backend.h" #include "ggml/ggml.h" +#ifdef SD_USE_METAL +#include "ggml-metal.h" +#endif + #define ST_HEADER_SIZE_LEN 8 uint64_t read_u64(uint8_t* buffer) { @@ -1197,7 +1201,7 @@ std::string ModelLoader::load_merges() { return merges_utf8_str; } -bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { +bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) { bool success = true; for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) { std::string file_path = file_paths_[file_index]; @@ -1285,11 +1289,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { continue; } - ggml_backend_t backend = ggml_get_backend(dst_tensor); - size_t nbytes_to_read = tensor_storage.nbytes_to_read(); - if (backend == NULL || ggml_backend_is_cpu(backend)) { + if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) +#ifdef SD_USE_METAL + || ggml_backend_is_metal(backend) +#endif + ) { // for the CPU and Metal backend, we can copy directly into the tensor if (tensor_storage.type == dst_tensor->type) { GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); diff --git a/model.h b/model.h index 4df7f8bf..7966e9aa 100644 --- a/model.h +++ b/model.h @@ -116,7 +116,7 @@ class ModelLoader { SDVersion get_sd_version(); ggml_type get_sd_wtype(); std::string load_merges(); - bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb); + bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); int64_t cal_mem_size(ggml_backend_t backend); ~ModelLoader() = default; }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index d0c499f3..70cd79a7 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -23,6 +23,10 @@ #include "ggml-cuda.h" #endif +#ifdef SD_USE_METAL +#include "ggml-metal.h" +#endif + #include "model.h" #include "rng.h" #include "rng_philox.h" @@ -75,6 +79,13 @@ std::string sd_get_system_info() { return ss.str(); } +static void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { + (void)level; + (void)user_data; + fputs(text, stderr); + fflush(stderr); +} + void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr rng) { uint32_t n = (uint32_t)ggml_nelements(tensor); std::vector random_numbers = rng->randn(n); @@ -335,6 +346,54 @@ void sd_image_to_tensor(const uint8_t* image_data, } } +void ggml_split_tensor_2d(struct ggml_tensor* input, + struct ggml_tensor* output, + int x, + int y) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + float value = ggml_tensor_get_f32(input, ix + x, iy + y, k); + ggml_tensor_set_f32(output, value, ix, iy, k); + } + } + } +} + +void ggml_merge_tensor_2d(struct ggml_tensor* input, + struct ggml_tensor* output, + int x, + int y, + int overlap) { + int64_t width = input->ne[0]; + int64_t height = input->ne[1]; + int64_t channels = input->ne[2]; + GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + float new_value = ggml_tensor_get_f32(input, ix, iy, k); + if (overlap > 0) { // blend colors in overlapped area + float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); + if (x > 0 && ix < overlap) { // in overlapped horizontal + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (ix / (1.0f * overlap)), x + ix, y + iy, k); + continue; + } + if (y > 0 && iy < overlap) { // in overlapped vertical + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (iy / (1.0f * overlap)), x + ix, y + iy, k); + continue; + } + } + ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k); + } + } + } +} + float ggml_tensor_mean(struct ggml_tensor* src) { float mean = 0.0f; int64_t nelements = ggml_nelements(src); @@ -393,6 +452,71 @@ void ggml_tensor_scale_output(struct ggml_tensor* src) { } } +typedef std::function on_tile_process; + +// Tiling +void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { + int input_width = input->ne[0]; + int input_height = input->ne[1]; + int output_width = output->ne[0]; + int output_height = output->ne[1]; + GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 + + int tile_overlap = (int32_t)(tile_size * tile_overlap_factor); + int non_tile_overlap = tile_size - tile_overlap; + + struct ggml_init_params params = {}; + params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk + params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk + params.mem_size += 3 * ggml_tensor_overhead(); + params.mem_buffer = NULL; + params.no_alloc = false; + + LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + + // draft context + struct ggml_context* tiles_ctx = ggml_init(params); + if (!tiles_ctx) { + LOG_ERROR("ggml_init() failed"); + return; + } + + // tiling + ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1); + ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1); + on_processing(input_tile, NULL, true); + int num_tiles = (input_width * input_height) / (non_tile_overlap * non_tile_overlap); + LOG_INFO("processing %i tiles", num_tiles); + pretty_progress(1, num_tiles, 0.0f); + int tile_count = 1; + bool last_y = false, last_x = false; + float last_time = 0.0f; + for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) { + if (y + tile_size >= input_height) { + y = input_height - tile_size; + last_y = true; + } + for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) { + if (x + tile_size >= input_width) { + x = input_width - tile_size; + last_x = true; + } + int64_t t1 = ggml_time_ms(); + ggml_split_tensor_2d(input, input_tile, x, y); + on_processing(input_tile, output_tile, false); + ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale); + int64_t t2 = ggml_time_ms(); + last_time = (t2 - t1) / 1000.0f; + pretty_progress(tile_count, num_tiles, last_time); + tile_count++; + } + last_x = false; + } + if (tile_count < num_tiles) { + pretty_progress(num_tiles, num_tiles, last_time); + } +} + struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, struct ggml_tensor* a) { return ggml_group_norm(ctx, a, 32); @@ -481,6 +605,15 @@ std::pair, std::string> extract_and_remov return std::make_pair(filename2multiplier, text); } +void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + #ifdef SD_USE_CUBLAS + ggml_backend_tensor_get_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); + #else + ggml_backend_tensor_get(tensor, data, offset, size); + #endif +} + /*================================================== CLIPTokenizer ===================================================*/ const std::string UNK_TOKEN = "<|endoftext|>"; @@ -1042,6 +1175,7 @@ struct CLIPTextModel { int32_t intermediate_size = 3072; // 4096 for SD 2.x int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x int32_t num_hidden_layers = 12; // 24 for SD 2.x + int32_t clip_skip = 1; // embeddings struct ggml_tensor* position_ids; @@ -1200,8 +1334,9 @@ struct CLIPTextModel { ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] // transformer + int layer_idx = num_hidden_layers - clip_skip; for (int i = 0; i < num_hidden_layers; i++) { - if (version == VERSION_2_x && i == num_hidden_layers - 1) { // layer: "penultimate" + if (i == layer_idx + 1) { break; } x = resblocks[i].forward(ctx0, x); // [N, n_token, hidden_size] @@ -1273,12 +1408,18 @@ struct CLIPTextModel { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_output->data, 0, ggml_nbytes(work_output)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_output->data, 0, ggml_nbytes(work_output)); return work_output; } @@ -1682,7 +1823,7 @@ struct SpatialTransformer { { x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) q = ggml_scale_inplace(ctx, q, attn_scale); #endif q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] @@ -1699,7 +1840,7 @@ struct SpatialTransformer { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] #else struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] @@ -1730,7 +1871,7 @@ struct SpatialTransformer { x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) q = ggml_scale_inplace(ctx, q, attn_scale); #endif q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] @@ -1746,7 +1887,7 @@ struct SpatialTransformer { v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] #else struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] @@ -2544,13 +2685,19 @@ struct UNetModel { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent)); } void end() { @@ -3349,13 +3496,19 @@ struct AutoEncoderKL { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); } void end() { @@ -3366,7 +3519,7 @@ struct AutoEncoderKL { }; /* - + =================================== TinyAutoEncoder =================================== References: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_tiny.py https://github.com/madebyollin/taesd/blob/main/taesd.py @@ -3939,7 +4092,7 @@ struct TinyAutoEncoder { return true; }; - bool success = model_loader.load_tensors(on_new_tensor_cb); + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); bool some_tensor_not_init = false; @@ -4026,13 +4179,542 @@ struct TinyAutoEncoder { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); + } + + void end() { + ggml_allocr_free(compute_alloc); + ggml_backend_buffer_free(compute_buffer); + compute_alloc = NULL; + } +}; + +/* + =================================== ESRGAN =================================== + References: + https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py + https://github.com/XPixelGroup/BasicSR/blob/v1.4.2/basicsr/archs/rrdbnet_arch.py + +*/ + +struct ResidualDenseBlock { + int num_features; + int num_grow_ch; + ggml_tensor* conv1_w; // [num_grow_ch, num_features, 3, 3] + ggml_tensor* conv1_b; // [num_grow_ch] + + ggml_tensor* conv2_w; // [num_grow_ch, num_features + num_grow_ch, 3, 3] + ggml_tensor* conv2_b; // [num_grow_ch] + + ggml_tensor* conv3_w; // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3] + ggml_tensor* conv3_b; // [num_grow_ch] + + ggml_tensor* conv4_w; // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3] + ggml_tensor* conv4_b; // [num_grow_ch] + + ggml_tensor* conv5_w; // [num_features, num_features + 4 * num_grow_ch, 3, 3] + ggml_tensor* conv5_b; // [num_features] + + ResidualDenseBlock() {} + + ResidualDenseBlock(int num_feat, int n_grow_ch) { + num_features = num_feat; + num_grow_ch = n_grow_ch; + } + + size_t calculate_mem_size() { + size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv1_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv1_b + + mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv2_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv2_b + + mem_size += (num_features + 2 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv3_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv3_w + + mem_size += (num_features + 3 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv4_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv4_w + + mem_size += (num_features + 4 * num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv5_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv5_w + + return mem_size; + } + + int get_num_tensors() { + int num_tensors = 10; + return num_tensors; + } + + void init_params(ggml_context* ctx) { + conv1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_grow_ch); + conv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + num_grow_ch, num_grow_ch); + conv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 2 * num_grow_ch, num_grow_ch); + conv3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv4_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 3 * num_grow_ch, num_grow_ch); + conv4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv5_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 4 * num_grow_ch, num_features); + conv5_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + } + + void map_by_name(std::map& tensors, std::string prefix) { + tensors[prefix + "conv1.weight"] = conv1_w; + tensors[prefix + "conv1.bias"] = conv1_b; + + tensors[prefix + "conv2.weight"] = conv2_w; + tensors[prefix + "conv2.bias"] = conv2_b; + + tensors[prefix + "conv3.weight"] = conv3_w; + tensors[prefix + "conv3.bias"] = conv3_b; + + tensors[prefix + "conv4.weight"] = conv4_w; + tensors[prefix + "conv4.bias"] = conv4_b; + + tensors[prefix + "conv5.weight"] = conv5_w; + tensors[prefix + "conv5.bias"] = conv5_b; + } + + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { + // x1 = self.lrelu(self.conv1(x)) + ggml_tensor* x1 = ggml_nn_conv_2d(ctx, x, conv1_w, conv1_b, 1, 1, 1, 1); + x1 = ggml_leaky_relu(ctx, x1, 0.2f, true); + + // x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + ggml_tensor* x_cat = ggml_concat(ctx, x, x1); + ggml_tensor* x2 = ggml_nn_conv_2d(ctx, x_cat, conv2_w, conv2_b, 1, 1, 1, 1); + x2 = ggml_leaky_relu(ctx, x2, 0.2f, true); + + // x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x_cat = ggml_concat(ctx, x_cat, x2); + ggml_tensor* x3 = ggml_nn_conv_2d(ctx, x_cat, conv3_w, conv3_b, 1, 1, 1, 1); + x3 = ggml_leaky_relu(ctx, x3, 0.2f, true); + + // x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x_cat = ggml_concat(ctx, x_cat, x3); + ggml_tensor* x4 = ggml_nn_conv_2d(ctx, x_cat, conv4_w, conv4_b, 1, 1, 1, 1); + x4 = ggml_leaky_relu(ctx, x4, 0.2f, true); + + // self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + x_cat = ggml_concat(ctx, x_cat, x4); + ggml_tensor* x5 = ggml_nn_conv_2d(ctx, x_cat, conv5_w, conv5_b, 1, 1, 1, 1); + + // return x5 * 0.2 + x + x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x); + return x5; + } +}; + +struct EsrganBlock { + ResidualDenseBlock rd_blocks[3]; + int num_residual_blocks = 3; + + EsrganBlock() {} + + EsrganBlock(int num_feat, int num_grow_ch) { + for (int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i] = ResidualDenseBlock(num_feat, num_grow_ch); + } + } + + int get_num_tensors() { + int num_tensors = 0; + for (int i = 0; i < num_residual_blocks; i++) { + num_tensors += rd_blocks[i].get_num_tensors(); + } + return num_tensors; + } + + size_t calculate_mem_size() { + size_t mem_size = 0; + for (int i = 0; i < num_residual_blocks; i++) { + mem_size += rd_blocks[i].calculate_mem_size(); + } + return mem_size; + } + + void init_params(ggml_context* ctx) { + for (int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i].init_params(ctx); + } + } + + void map_by_name(std::map& tensors, std::string prefix) { + for (int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) + "."); + } + } + + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x) { + ggml_tensor* out = x; + for (int i = 0; i < num_residual_blocks; i++) { + // out = self.rdb...(x) + out = rd_blocks[i].forward(ctx, out_scale, out); + } + // return out * 0.2 + x + out = ggml_add(ctx, ggml_scale(ctx, out, out_scale), x); + return out; + } +}; + +struct ESRGAN { + int scale = 4; // default RealESRGAN_x4plus_anime_6B + int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B + int in_channels = 3; + int out_channels = 3; + int num_features = 64; // default RealESRGAN_x4plus_anime_6B + int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B + int tile_size = 128; // avoid cuda OOM for 4gb VRAM + + ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3] + ggml_tensor* conv_first_b; // [num_features] + + EsrganBlock body_blocks[6]; + ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_body_b; // [num_features] + + // upsample + ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up1_b; // [num_features] + ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up2_b; // [num_features] + + ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_hr_b; // [num_features] + ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3] + ggml_tensor* conv_last_b; // [out_channels] + + ggml_context* ctx; + bool decode_only = false; + ggml_backend_buffer_t params_buffer; + ggml_backend_buffer_t compute_buffer; // for compute + struct ggml_allocr* compute_alloc = NULL; + + int memory_buffer_size = 0; + ggml_type wtype; + ggml_backend_t backend = NULL; + + ESRGAN() { + for (int i = 0; i < num_blocks; i++) { + body_blocks[i] = EsrganBlock(num_features, num_grow_ch); + } + } + + size_t calculate_mem_size() { + size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b + + for (int i = 0; i < num_blocks; i++) { + mem_size += body_blocks[i].calculate_mem_size(); + } + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w + + // upsample + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up1_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up1_b + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up2_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up2_b + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_hr_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_hr_b + + mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_last_w + mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_last_b + return mem_size; + } + + int get_num_tensors() { + int num_tensors = 12; + for (int i = 0; i < num_blocks; i++) { + num_tensors += body_blocks[i].get_num_tensors(); + } + return num_tensors; + } + + bool init(ggml_backend_t backend_) { + this->backend = backend_; + memory_buffer_size = calculate_mem_size(); + memory_buffer_size += 1024; // overhead + int num_tensors = get_num_tensors(); + + LOG_DEBUG("ESRGAN params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); + + struct ggml_init_params params; + params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); + params.mem_buffer = NULL; + params.no_alloc = true; + + params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); + + ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + return true; + } + + void alloc_params() { + ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); + conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features); + conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_body_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_body_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_hr_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_hr_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_last_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels); + conv_last_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + for (int i = 0; i < num_blocks; i++) { + body_blocks[i].init_params(ctx); + } + + // alloc all tensors linked to this context + for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->data == NULL) { + ggml_allocr_alloc(alloc, t); + } + } + ggml_allocr_free(alloc); + } + + bool load_from_file(const std::string& file_path, ggml_backend_t backend) { + LOG_INFO("loading esrgan from '%s'", file_path.c_str()); + + if (!init(backend)) { + return false; + } + + std::map esrgan_tensors; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + // prepare memory for the weights + { + alloc_params(); + map_by_name(esrgan_tensors); + } + + std::set tensor_names_in_file; + + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + tensor_names_in_file.insert(name); + + struct ggml_tensor* real; + if (esrgan_tensors.find(name) != esrgan_tensors.end()) { + real = esrgan_tensors[name]; + } else { + LOG_ERROR("unknown tensor '%s' in model file", name.data()); + return true; + } + + if ( + real->ne[0] != tensor_storage.ne[0] || + real->ne[1] != tensor_storage.ne[1] || + real->ne[2] != tensor_storage.ne[2] || + real->ne[3] != tensor_storage.ne[3]) { + LOG_ERROR( + "tensor '%s' has wrong shape in model file: " + "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]", + name.c_str(), + (int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3], + (int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]); + return false; + } + + *dst_tensor = real; + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); + + bool some_tensor_not_init = false; + + for (auto pair : esrgan_tensors) { + if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) { + LOG_ERROR("tensor '%s' not in model file", pair.first.c_str()); + some_tensor_not_init = true; + } + } + + if (some_tensor_not_init) { + return false; + } + + LOG_INFO("esrgan model loaded"); + return success; + } + + void map_by_name(std::map& tensors) { + tensors["conv_first.weight"] = conv_first_w; + tensors["conv_first.bias"] = conv_first_b; + + for (int i = 0; i < num_blocks; i++) { + body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) + "."); + } + + tensors["conv_body.weight"] = conv_body_w; + tensors["conv_body.bias"] = conv_body_b; + + tensors["conv_up1.weight"] = conv_up1_w; + tensors["conv_up1.bias"] = conv_up1_b; + tensors["conv_up2.weight"] = conv_up2_w; + tensors["conv_up2.bias"] = conv_up2_b; + tensors["conv_hr.weight"] = conv_hr_w; + tensors["conv_hr.bias"] = conv_hr_b; + + tensors["conv_last.weight"] = conv_last_w; + tensors["conv_last.bias"] = conv_last_b; + } + + ggml_tensor* forward(ggml_context* ctx0, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { + // feat = self.conv_first(feat) + auto h = ggml_nn_conv_2d(ctx0, x, conv_first_w, conv_first_b, 1, 1, 1, 1); + + auto body_h = h; + // self.body(feat) + for (int i = 0; i < num_blocks; i++) { + body_h = body_blocks[i].forward(ctx0, out_scale, body_h); + } + + // body_feat = self.conv_body(self.body(feat)) + body_h = ggml_nn_conv_2d(ctx0, body_h, conv_body_w, conv_body_b, 1, 1, 1, 1); + + // feat = feat + body_feat + h = ggml_add(ctx0, h, body_h); + + // upsample + // feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + h = ggml_upscale(ctx0, h, 2); + h = ggml_nn_conv_2d(ctx0, h, conv_up1_w, conv_up1_b, 1, 1, 1, 1); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + // feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + h = ggml_upscale(ctx0, h, 2); + h = ggml_nn_conv_2d(ctx0, h, conv_up2_w, conv_up2_b, 1, 1, 1, 1); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + // out = self.conv_last(self.lrelu(self.conv_hr(feat))) + h = ggml_nn_conv_2d(ctx0, h, conv_hr_w, conv_hr_b, 1, 1, 1, 1); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + h = ggml_nn_conv_2d(ctx0, h, conv_last_w, conv_last_b, 1, 1, 1, 1); + return h; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + struct ggml_context* ctx0 = ggml_init(params); + + struct ggml_cgraph* gf = ggml_new_graph(ctx0); + + struct ggml_tensor* x_ = NULL; + struct ggml_tensor* os = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(compute_alloc, os); + if (!ggml_allocr_is_measure(compute_alloc)) { + float scale = 0.2f; + ggml_backend_tensor_set(os, &scale, 0, sizeof(scale)); + } + + // it's performing a compute, check if backend isn't cpu + if (!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + x_ = ggml_dup_tensor(ctx0, x); + ggml_allocr_alloc(compute_alloc, x_); + + // pass data to device backend + if (!ggml_allocr_is_measure(compute_alloc)) { + ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); + } + } else { + x_ = x; + } + + struct ggml_tensor* out = forward(ctx0, os, x); + + ggml_build_forward_expand(gf, out); + ggml_free(ctx0); + + return gf; + } + + void begin(struct ggml_tensor* x) { + // calculate the amount of memory required + // alignment required by the backend + compute_alloc = ggml_allocr_new_measure_from_backend(backend); + + struct ggml_cgraph* gf = build_graph(x); + + // compute the required memory + size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf); + + // recreate the allocator with the required memory + ggml_allocr_free(compute_alloc); + + LOG_DEBUG("ESRGAN compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); + + compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size); + compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); + } + + void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* x) { + ggml_allocr_reset(compute_alloc); + + struct ggml_cgraph* gf = build_graph(x); + ggml_allocr_alloc_graph(compute_alloc, gf); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + + ggml_backend_graph_compute(backend, gf); + +#ifdef GGML_PERF + ggml_graph_print(gf); +#endif + ggml_tensor* out = gf->nodes[gf->n_nodes - 1]; + ggml_backend_tensor_get_and_sync(backend, out, work_result->data, 0, ggml_nbytes(out)); } void end() { @@ -4105,7 +4787,7 @@ struct LoraModel { return true; }; - model_loader.load_tensors(on_new_tensor_cb); + model_loader.load_tensors(on_new_tensor_cb, backend); LOG_DEBUG("finished loaded lora"); ggml_allocr_free(alloc); @@ -4241,6 +4923,13 @@ struct LoraModel { if (ggml_backend_is_cpu(backend)) { ggml_backend_cpu_set_n_threads(backend, n_threads); } + +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); ggml_allocr_free(compute_alloc); ggml_backend_buffer_free(buffer_compute_lora); @@ -4392,6 +5081,7 @@ class StableDiffusionGGML { UNetModel diffusion_model; AutoEncoderKL first_stage_model; bool use_tiny_autoencoder = false; + bool vae_tiling = false; std::map tensors; @@ -4407,6 +5097,10 @@ class StableDiffusionGGML { TinyAutoEncoder tae_first_stage; std::string taesd_path; + ESRGAN esrgan_upscaler; + std::string esrgan_path; + bool upscale_output = false; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, @@ -4439,18 +5133,25 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule schedule) { + Schedule schedule, + int clip_skip) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(); + backend = ggml_backend_cuda_init(0); +#endif +#ifdef SD_USE_METAL + LOG_DEBUG("Using Metal backend"); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + backend = ggml_backend_metal_init(); #endif + if (!backend) { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#ifdef SD_USE_CUBLAS - LOG_WARN("Flash Attention not supported with CUDA"); +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); #endif @@ -4475,8 +5176,15 @@ class StableDiffusionGGML { LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); return false; } - cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); - diffusion_model = UNetModel(version); + if (clip_skip <= 0) { + clip_skip = 1; + if (version == VERSION_2_x) { + clip_skip = 2; + } + } + cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); + cond_stage_model.text_model.clip_skip = clip_skip; + diffusion_model = UNetModel(version); LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { model_data_type = model_loader.get_sd_wtype(); @@ -4593,7 +5301,7 @@ class StableDiffusionGGML { // print_ggml_tensor(alphas_cumprod_tensor); - bool success = model_loader.load_tensors(on_new_tensor_cb); + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); if (!success) { LOG_ERROR("load tensors from file failed"); ggml_free(ctx); @@ -4681,6 +5389,11 @@ class StableDiffusionGGML { } LOG_DEBUG("finished loaded file"); ggml_free(ctx); + if (upscale_output) { + if (!esrgan_upscaler.load_from_file(esrgan_path, backend)) { + return false; + } + } if (use_tiny_autoencoder) { return tae_first_stage.load_from_file(taesd_path, backend); } @@ -5341,15 +6054,39 @@ class StableDiffusionGGML { } else { ggml_tensor_scale_input(x); } - first_stage_model.begin(x, decode); - first_stage_model.compute(result, n_threads, x, decode); + if (vae_tiling && decode) { // TODO: support tiling vae encode + // split latent in 32x32 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if (init) { + first_stage_model.begin(in, decode); + } else { + first_stage_model.compute(out, n_threads, in, decode); + } + }; + sd_tiling(x, result, 8, 32, 0.5f, on_tiling); + } else { + first_stage_model.begin(x, decode); + first_stage_model.compute(result, n_threads, x, decode); + } first_stage_model.end(); if (decode) { ggml_tensor_scale_output(result); } } else { - tae_first_stage.begin(x, decode); - tae_first_stage.compute(result, n_threads, x, decode); + if (vae_tiling && decode) { // TODO: support tiling vae encode + // split latent in 64x64 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if (init) { + tae_first_stage.begin(in, decode); + } else { + tae_first_stage.compute(out, n_threads, in, decode); + } + }; + sd_tiling(x, result, 8, 64, 0.5f, on_tiling); + } else { + tae_first_stage.begin(x, decode); + tae_first_stage.compute(result, n_threads, x, decode); + } tae_first_stage.end(); } int64_t t1 = ggml_time_ms(); @@ -5360,6 +6097,41 @@ class StableDiffusionGGML { return result; } + uint8_t* upscale(ggml_tensor* image) { + int output_width = image->ne[0] * esrgan_upscaler.scale; + int output_height = image->ne[1] * esrgan_upscaler.scale; + LOG_INFO("upscaling from (%i x %i) to (%i x %i)", image->ne[0], image->ne[1], output_width, output_height); + struct ggml_init_params params; + params.mem_size = output_width * output_height * 3 * sizeof(float); // upscaled + params.mem_size += 1 * ggml_tensor_overhead(); + params.mem_buffer = NULL; + params.no_alloc = false; + // draft context + struct ggml_context* upscale_ctx = ggml_init(params); + if (!upscale_ctx) { + LOG_ERROR("ggml_init() failed"); + return NULL; + } + LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if (init) { + esrgan_upscaler.begin(in); + } else { + esrgan_upscaler.compute(out, n_threads, in); + } + }; + int64_t t0 = ggml_time_ms(); + sd_tiling(image, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, 0.25f, on_tiling); + esrgan_upscaler.end(); + ggml_tensor_clamp(upscaled, 0.f, 1.f); + uint8_t* upscaled_data = sd_tensor_to_image(upscaled); + ggml_free(upscale_ctx); + int64_t t3 = ggml_time_ms(); + LOG_INFO("image upscaled, taking %.2fs", (t3 - t0) / 1000.0f); + return upscaled_data; + } + ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) { return compute_first_stage(work_ctx, x, false); } @@ -5374,7 +6146,9 @@ class StableDiffusionGGML { StableDiffusion::StableDiffusion(int n_threads, bool vae_decode_only, std::string taesd_path, + std::string esrgan_path, bool free_params_immediately, + bool vae_tiling, std::string lora_model_dir, RNGType rng_type) { sd = std::make_shared(n_threads, @@ -5384,13 +6158,17 @@ StableDiffusion::StableDiffusion(int n_threads, rng_type); sd->use_tiny_autoencoder = taesd_path.size() > 0; sd->taesd_path = taesd_path; + sd->upscale_output = esrgan_path.size() > 0; + sd->esrgan_path = esrgan_path; + sd->vae_tiling = vae_tiling; } bool StableDiffusion::load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule s) { - return sd->load_from_file(model_path, vae_path, wtype, s); + Schedule s, + int clip_skip) { + return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip); } std::vector StableDiffusion::txt2img(std::string prompt, @@ -5487,11 +6265,12 @@ std::vector StableDiffusion::txt2img(std::string prompt, LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000); LOG_INFO("decoding %zu latents", final_latents.size()); + std::vector decoded_images; // collect decoded images for (size_t i = 0; i < final_latents.size(); i++) { t1 = ggml_time_ms(); struct ggml_tensor* img = sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */); if (img != NULL) { - results.push_back(sd_tensor_to_image(img)); + decoded_images.push_back(img); } int64_t t2 = ggml_time_ms(); LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); @@ -5502,6 +6281,16 @@ std::vector StableDiffusion::txt2img(std::string prompt, if (sd->free_params_immediately && !sd->use_tiny_autoencoder) { sd->first_stage_model.destroy(); } + if (sd->upscale_output) { + LOG_INFO("upscaling %" PRId64 " images", decoded_images.size()); + } + for (size_t i = 0; i < decoded_images.size(); i++) { + if (sd->upscale_output) { + results.push_back(sd->upscale(decoded_images[i])); + } else { + results.push_back(sd_tensor_to_image(decoded_images[i])); + } + } ggml_free(work_ctx); LOG_INFO( "txt2img completed in %.2fs", @@ -5608,7 +6397,11 @@ std::vector StableDiffusion::img2img(const uint8_t* init_img_data, struct ggml_tensor* img = sd->decode_first_stage(work_ctx, x_0); if (img != NULL) { - result.push_back(sd_tensor_to_image(img)); + if (sd->upscale_output) { + result.push_back(sd->upscale(img)); + } else { + result.push_back(sd_tensor_to_image(img)); + } } int64_t t4 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index c94f6c7c..3ae012f9 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -4,6 +4,7 @@ #include #include #include +#include "ggml/ggml.h" #include "ggml/ggml.h" @@ -41,14 +42,17 @@ class StableDiffusion { StableDiffusion(int n_threads = -1, bool vae_decode_only = false, std::string taesd_path = "", + std::string esrgan_path = "", bool free_params_immediately = false, + bool vae_tiling = false, std::string lora_model_dir = "", RNGType rng_type = STD_DEFAULT_RNG); bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule d = DEFAULT); + Schedule d = DEFAULT, + int clip_skip = -1); std::vector txt2img( std::string prompt,