From 3f93f59ec49d9055c536e95dac934d7dc7d5cd20 Mon Sep 17 00:00:00 2001 From: William Bradford Clark Date: Tue, 19 Mar 2024 19:45:27 -0400 Subject: [PATCH] Support stable-diffusion-webui-forge Hijacks the forge_sample function and calls it separately for each individual text conditioning. The outputs are then recombined into a format that the extension already knows how to work with. This doesn't use forge's patcher or take advantage of all of forge's performance optimizations brought over from ComfyUI via ldm_patched, but it does guarantee backwards compatibility for base A1111 users, and allows Forge users to use the extension without maintaining a separate install of A1111. For a discussion of some of the challenges of manipulating denoised latents more directly using built-in functionality of Forge, see: https://github.com/hako-mikan/sd-webui-regional-prompter/issues/299 --- README.md | 1 + lib_neutral_prompt/cfg_denoiser_hijack.py | 115 +++++++++++++++++++--- 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 4a2cb4e..0c0383b 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff ## Features +- Now compatible with [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)!
- [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword - saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons) - semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py
index a8605ab..d670442 100644
--- a/lib_neutral_prompt/cfg_denoiser_hijack.py
+++ b/lib_neutral_prompt/cfg_denoiser_hijack.py
@@ -212,27 +212,125 @@ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
     return vector * (torch.abs(vector) >= top_k).to(vector.dtype)
 
 
-sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
-    module=sd_samplers,
-    hijacker_attribute='__neutral_prompt_hijacker',
-    on_uninstall=script_callbacks.on_script_unloaded,
-)
+try:
+    from modules_forge import forge_sampler
+    forge = True
+except ImportError:
+    forge = False
 
 
-@sd_samplers_hijacker.hijack('create_sampler')
-def create_sampler_hijack(name: str, model, original_function):
-    sampler = original_function(name, model)
-    if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
-        if global_state.is_enabled:
-            warn_unsupported_sampler()
+if forge:
+    from ldm_patched.modules.conds import CONDRegular
+    from ldm_patched.modules.samplers import sampling_function
 
-        return sampler
+    forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
+        module=forge_sampler,
+        hijacker_attribute='__forge_sample_hijacker',
+        on_uninstall=script_callbacks.on_script_unloaded,
+    )
+
+    @forge_sampler_hijacker.hijack('forge_sample')
+    def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function):
+        """Run the sampler once per text conditioning and recombine the results.
+
+        Falls through to the original ``forge_sample`` while the extension is
+        disabled. Otherwise each cond, and finally the uncond, is denoised on
+        its own at a guidance scale of 1.0 and the outputs are merged through
+        ``CondDeltaVisitor`` / ``AuxCondDeltaVisitor`` and ``cfg_rescale``.
+        """
+        if not global_state.is_enabled:
+            return original_function(self, denoiser_params, cond_scale, cond_composition)
+
+        unet = self.inner_model.inner_model.forge_objects.unet
+        model = unet.model
+        control = unet.controlnet_linked_list
+        extra_concat_condition = unet.extra_concat_condition
+        model_options = unet.model_options
+        x = denoiser_params.x
+        timestep = denoiser_params.sigma
+        seed = self.p.seeds[0]
+
+        uncond = forge_sampler.cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
+        conds = forge_sampler.cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
+        # Uncond goes last so it is read back as denoised[-1]; presumably a single entry.
+        conds = conds + uncond
+
+        # Which concat conditioning to use does not depend on the current cond.
+        if extra_concat_condition is not None:
+            image_cond_in = extra_concat_condition
+        else:
+            image_cond_in = denoiser_params.image_cond
+
+        denoised = []
+
+        for current_cond in conds:
+            cond = [current_cond]
+            cond[0]['strength'] = 1.0
+
+            if isinstance(image_cond_in, torch.Tensor):
+                if image_cond_in.shape[0] == x.shape[0] \
+                        and image_cond_in.shape[2] == x.shape[2] \
+                        and image_cond_in.shape[3] == x.shape[3]:
+                    cond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
+
+            if control is not None:
+                cond[0]['control'] = control
+
+            for modifier in model_options.get('conditioning_modifiers', []):
+                model, x, timestep, _, cond, cond_scale, model_options, seed = modifier(model, x, timestep, None, cond, cond_scale, model_options, seed)
+
+            # Rebind to a copy instead of writing the flag into the dict that
+            # was read from the unet, so the change does not leak past this call.
+            model_options = {**model_options, "disable_cfg1_optimization": True}
+
+            result = sampling_function(model, x, timestep, None, cond, 1.0, model_options, seed)
+            denoised.append(result)
 
-    sampler.model_wrap_cfg.combine_denoised = functools.partial(
-        combine_denoised_hijack,
-        original_function=sampler.model_wrap_cfg.combine_denoised
+        denoised_uncond = denoised[-1]
+        denoised_conds = torch.stack(denoised[:-1], dim=0)
+        denoised_cond = torch.mean(denoised_conds, dim=0)
+        forge_denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
+
+        cond_indices = cond_composition[0]
+        prompt = global_state.prompt_exprs[0]
+        # Split the stacked per-cond outputs along dim 1 once, not on every iteration.
+        conds_per_batch_item = denoised_conds.unbind(dim=1)
+
+        for batch_i in range(denoised_uncond.shape[0]):
+            args = CombineDenoiseArgs(conds_per_batch_item[batch_i], denoised_uncond[batch_i], cond_indices)
+            cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
+            aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
+            cfg_cond = forge_denoised[batch_i] + aux_cond_delta * cond_scale
+            forge_denoised[batch_i] = cfg_rescale(cfg_cond, denoised_uncond[batch_i] + cond_delta + aux_cond_delta)
+
+        return forge_denoised
+
+
+else:
+    sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
+        module=sd_samplers,
+        hijacker_attribute='__neutral_prompt_hijacker',
+        on_uninstall=script_callbacks.on_script_unloaded,
     )
-    return sampler
+
+
+    @sd_samplers_hijacker.hijack('create_sampler')
+    def create_sampler_hijack(name: str, model, original_function):
+        sampler = original_function(name, model)
+
+
+        if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
+            if global_state.is_enabled:
+                warn_unsupported_sampler()
+
+            return sampler
+
+        sampler.model_wrap_cfg.combine_denoised = functools.partial(
+            combine_denoised_hijack,
+            original_function=sampler.model_wrap_cfg.combine_denoised
+        )
+
+        return sampler
 
 
 def warn_unsupported_sampler():