Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support stable-diffusion-webui-forge #69

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff

## Features

- Now compatible with [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)!
- [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
- saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons)
- semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
Expand Down
111 changes: 95 additions & 16 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import sys
import textwrap

from modules_forge.forge_sampler import cond_from_a1111_to_patched_ldm


def combine_denoised_hijack(
x_out: torch.Tensor,
Expand Down Expand Up @@ -212,27 +214,104 @@ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
return vector * (torch.abs(vector) >= top_k).to(vector.dtype)


sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
# Detect whether we are running inside stable-diffusion-webui-forge:
# these two modules exist only in the forge fork, so a successful import
# means forge, and an ImportError means vanilla a1111 webui.
try:
    from ldm_patched.modules import samplers
    from modules_forge import forge_sampler
    forge = True
except ImportError:
    forge = False


if forge:
    # Forge build: intercept forge's own sampling entry points
    # (forge_sampler.forge_sample and samplers.calc_cond_uncond_batch)
    # instead of a1111's sd_samplers module.
    forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
        module=forge_sampler,
        hijacker_attribute='__forge_sampler_hijacker',
        on_uninstall=script_callbacks.on_script_unloaded,
    )
    samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
        module=samplers,
        hijacker_attribute='__samplers_hijacker',
        on_uninstall=script_callbacks.on_script_unloaded,
    )


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
sampler = original_function(name, model)
if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
if global_state.is_enabled:
warn_unsupported_sampler()
@forge_sampler_hijacker.hijack('forge_sample')
def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function):
    """Stash extension conditioning data on the forge unet before sampling.

    Stores the prompt `cond_composition` and the a1111 uncond (converted to
    the patched-ldm cond format) in the unet's `model_options` dict so the
    `calc_cond_uncond_batch` hijack below can retrieve them, then defers to
    the original forge sampling function.
    """
    if not global_state.is_enabled:
        # Extension disabled: run the stock forge sampler untouched.
        return original_function(self, denoiser_params, cond_scale, cond_composition)

    # Hoist the long attribute chain; `model_options` is the live dict on the
    # forge unet, so these writes are visible to the sampling code downstream.
    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
    model_options['cond_composition'] = cond_composition
    model_options['uncond'] = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)

    return original_function(self, denoiser_params, cond_scale, cond_composition)


@samplers_hijacker.hijack('calc_cond_uncond_batch')
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options, original_function):
    """Denoise every cond individually, then recombine with the prompt visitors.

    Instead of letting forge average the conds, each cond (and the uncond) is
    run through the original function in isolation; the per-cond epsilons are
    then recombined per batch sample by the prompt expression visitors
    (Perp-Neg / salience / top-k semantics).

    Returns:
        (denoised_cond, denoised_uncond) tensors, both shaped like a single
        denoised batch (B, C, H, W).
    """
    if not global_state.is_enabled:
        return original_function(model, cond, uncond, x_in, timestep, model_options)

    # Retrieve the data stashed by the forge_sample hijack above.
    cond_composition = model_options['cond_composition']
    uncond = model_options['uncond']

    # The uncond is appended last, so its output ends up last in the results.
    cond.extend(uncond)
    remove_last_cond = (len(cond) % 2) == 1
    if remove_last_cond:
        # Pad to an even count so conds can be processed two at a time;
        # the padded duplicate is dropped from the results again below.
        cond.append(cond[-1])

    for elem in cond:
        # Neutralize forge's per-cond weighting; prompt weights are applied
        # by the combine visitors instead.
        elem['strength'] = 1.0

    # Run the original function on pairs of conds; it returns one denoised
    # latent per cond passed in, preserving order.
    denoised_latents = [
        denoised
        for first_cond, second_cond in zip(cond[::2], cond[1::2])
        for denoised in original_function(model, [first_cond], [second_cond], x_in, timestep, model_options)
    ]

    if remove_last_cond:
        denoised_latents = denoised_latents[:-1]

    # B, C, H, W — the uncond result (it was extended onto `cond` last).
    denoised_uncond = denoised_latents[-1]

    # N, B, C, H, W -> (B*N), C, H, W, grouped so that each batch sample's
    # N cond outputs are contiguous.
    denoised_in_conds = torch.stack(denoised_latents[:-1], dim=0)
    denoised_in_conds = denoised_in_conds.transpose(0, 1).reshape(-1, *denoised_in_conds.shape[2:])

    # Recombine per batch sample using the parsed prompt expressions.
    denoised_cond = denoised_uncond.clone()
    for batch_i in range(denoised_uncond.shape[0]):
        prompt = global_state.prompt_exprs[batch_i]
        args = CombineDenoiseArgs(denoised_in_conds, denoised_uncond[batch_i], cond_composition[batch_i])
        cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
        aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
        denoised_cond[batch_i] += cond_delta + aux_cond_delta

    return denoised_cond, denoised_uncond
else:
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
return sampler


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
    """Create the sampler via the original factory, then splice our
    combine_denoised implementation into its CFG denoiser when supported."""
    sampler = original_function(name, model)
    denoiser = getattr(sampler, 'model_wrap_cfg', None)

    if denoiser is None or not hasattr(denoiser, 'combine_denoised'):
        # This sampler does not expose the k-diffusion CFG hook we patch;
        # only warn when the extension is actually enabled.
        if global_state.is_enabled:
            warn_unsupported_sampler()
        return sampler

    # Wrap the existing combine_denoised, keeping it reachable as the
    # fallback via `original_function`.
    denoiser.combine_denoised = functools.partial(
        combine_denoised_hijack,
        original_function=denoiser.combine_denoised,
    )
    return sampler


def warn_unsupported_sampler():
Expand Down
10 changes: 9 additions & 1 deletion lib_neutral_prompt/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import gradio as gr
import dataclasses

try:
    # Probe for a forge-only module to detect stable-diffusion-webui-forge;
    # only the boolean outcome is kept, so the module itself is deleted.
    from modules_forge import forge_sampler
    del forge_sampler
    forge = True
except ImportError:
    forge = False


txt2img_prompt_textbox = None
img2img_prompt_textbox = None
Expand All @@ -29,7 +37,7 @@ class AccordionInterface:
def __post_init__(self):
self.is_rendered = False

self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0, visible=not forge, interactive=not forge)
self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
self.aux_prompt_type = gr.Dropdown(label='Prompt type', choices=list(prompt_types.keys()), value=next(iter(prompt_types.keys())), tooltip=prompt_types_tooltip, elem_id=self.get_elem_id('formatter_prompt_type'))
Expand Down