Support stable-diffusion-webui-forge #69

Open · wants to merge 10 commits into base: main

Changes from all commits
1 change: 1 addition & 0 deletions README.md

```diff
@@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff
 
 ## Features
 
+- Now compatible with [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)!
 - [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
 - saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons)
 - semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
```
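For context, the keywords in this list are the extension's prompt-level API. A hypothetical prompt combining them might look like the following — the keywords are from the README, but the prompt text and weights are invented for illustration, following the usual a1111 `AND` weight syntax:

```text
a sunny beach at golden hour, photorealistic
AND_PERP cartoon, flat illustration :1.2
AND_SALT dramatic storm clouds :0.8
```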
178 changes: 158 additions & 20 deletions lib_neutral_prompt/cfg_denoiser_hijack.py

```diff
@@ -7,6 +7,8 @@
 import sys
 import textwrap
 
+from modules_forge.forge_sampler import cond_from_a1111_to_patched_ldm
+
 
 def combine_denoised_hijack(
     x_out: torch.Tensor,
@@ -49,8 +51,7 @@ def get_webui_denoised(
         sliced_batch_cond_indices.append(sliced_cond_indices)
         sliced_batch_x_out.extend(sliced_x_out)
 
-    sliced_batch_x_out += list(uncond)
-    sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
+    sliced_batch_x_out = torch.stack(sliced_batch_x_out + list(uncond), dim=0)
     return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)
@@ -207,32 +208,169 @@ def get_salience(vector: torch.Tensor) -> torch.Tensor:
 
 
 def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
-    k = int(torch.numel(vector) * (1 - k_ratio))
+    k = int(vector.numel() * (1 - k_ratio))
     top_k, _ = torch.kthvalue(torch.abs(torch.flatten(vector)), k)
-    return vector * (torch.abs(vector) >= top_k).to(vector.dtype)
+    return vector * (vector.abs() >= top_k).to(vector.dtype)
 
 
-sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
-    module=sd_samplers,
-    hijacker_attribute='__neutral_prompt_hijacker',
-    on_uninstall=script_callbacks.on_script_unloaded,
-)
-
-
-@sd_samplers_hijacker.hijack('create_sampler')
-def create_sampler_hijack(name: str, model, original_function):
-    sampler = original_function(name, model)
-    if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
-        if global_state.is_enabled:
-            warn_unsupported_sampler()
-
-        return sampler
-
-    sampler.model_wrap_cfg.combine_denoised = functools.partial(
-        combine_denoised_hijack,
-        original_function=sampler.model_wrap_cfg.combine_denoised
-    )
-
-    return sampler
+try:
+    from ldm_patched.modules import samplers
+    from modules_forge import forge_sampler
+    forge = True
+except ImportError:
+    forge = False
+
+
+if forge:
+    forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
+        module=forge_sampler,
+        hijacker_attribute='__forge_sampler_hijacker',
+        on_uninstall=script_callbacks.on_script_unloaded,
+    )
+    samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
+        module=samplers,
+        hijacker_attribute='__samplers_hijacker',
+        on_uninstall=script_callbacks.on_script_unloaded,
+    )
+
+
+    @forge_sampler_hijacker.hijack('forge_sample')
+    def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function):
+        if not global_state.is_enabled:
+            return original_function(self, denoiser_params, cond_scale, cond_composition)
+
+        self.inner_model.inner_model.forge_objects.unet.model_options['cond_composition'] = cond_composition
+        self.inner_model.inner_model.forge_objects.unet.model_options['uncond'] = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
+
+        return original_function(self, denoiser_params, cond_scale, cond_composition)
+
+
+    def sampling_function_hijack(model, x, timestep, uncond, cond, cond_scale, model_options, seed, original_function):
+        if not global_state.is_enabled or not global_state.prompt_exprs:
+            return original_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
+
+        prompt = global_state.prompt_exprs[0]
+        original_strengths, new_strengths = prompt.accept(ForgeStrengthOverride(), cond, 0, False)
+
+        model_options['neutral_prompt_override'] = True
+        model_options['original_strengths'] = original_strengths
+        return original_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
+
+
+    class ForgeStrengthOverride:
+        def visit_leaf_prompt(
+            self,
+            that: neutral_prompt_parser.LeafPrompt,
+            cond: List[dict],
+            index: int,
+            is_parent_aux: bool,
+        ) -> tuple:
+            original_strength = cond[index].get('strength', 1.0)
+            new_strength = original_strength * float(not is_parent_aux and that.conciliation is None)
+            cond[index]['strength'] = new_strength
+            return [original_strength], [new_strength]
+
+        def visit_composite_prompt(
+            self,
+            that: neutral_prompt_parser.CompositePrompt,
+            cond: List[dict],
+            index: int,
+            is_parent_aux: bool,
+        ) -> tuple:
+            original_strengths = []
+            new_strengths = []
+
+            for child in that.children:
+                child_original_strengths, child_new_strengths = child.accept(ForgeStrengthOverride(), cond, index, is_parent_aux or that.conciliation is not None)
+                original_strengths.extend(child_original_strengths)
+                new_strengths.extend(child_new_strengths)
+
+                index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
+
+            return original_strengths, new_strengths
+
+
+    samplers_hijacker.hijack('sampling_function')(sampling_function_hijack)
+    forge_sampler_hijacker.hijack('sampling_function')(sampling_function_hijack)
+
+
+    @samplers_hijacker.hijack('calc_cond_uncond_batch')
+    def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options, original_function):
+        if not global_state.is_enabled or 'neutral_prompt_override' not in model_options:
+            return original_function(model, cond, uncond, x_in, timestep, model_options)
+
+        cond_composition = model_options['cond_composition']
+        original_strengths = model_options['original_strengths']
+        if uncond is None:
+            uncond = model_options['uncond']
+
+        for i in range(len(cond)):
+            cond[i]['strength'] = original_strengths[i]
+
+        cond = cond.copy()
+        cond.extend(uncond)
+        discard_last = (len(cond) % 2) == 1
+        if discard_last:
+            cond.append(cond[-1])
+
+        for i in range(len(cond)):
+            cond[i] = cond[i].copy()
+            cond[i]['strength'] = 1.0
+
+        denoised_latents = [
+            denoised
+            for first_cond, second_cond in zip(cond[::2], cond[1::2])
+            for denoised in original_function(model, [first_cond], [second_cond], x_in, timestep, model_options)
+        ]
+
+        if discard_last:
+            denoised_latents = denoised_latents[:-1]
+
+        # B, C, H, W
+        denoised_uncond = denoised_latents[-1]
+
+        # N, B, C, H, W
+        denoised_in_conds = torch.stack(denoised_latents[:-1], dim=0)
+        denoised_in_conds = denoised_in_conds.transpose(0, 1).reshape(-1, *denoised_in_conds.shape[2:])
+
+        denoised_cond = denoised_uncond.clone()
+        for batch_i in range(denoised_uncond.shape[0]):
+            prompt = global_state.prompt_exprs[batch_i]
+            args = CombineDenoiseArgs(denoised_in_conds, denoised_uncond[batch_i], cond_composition[batch_i])
+            cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
+            aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
+            denoised_cond[batch_i] += cond_delta + aux_cond_delta
+
+        # consume 'neutral_prompt_override' before returning, in case another extension calls the method
+        # outside of CFG sampling; for example: extensions-builtin/sd_forge_sag
+        del model_options['neutral_prompt_override']
+
+        return denoised_cond, denoised_uncond
+else:
+    sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
+        module=sd_samplers,
+        hijacker_attribute='__neutral_prompt_hijacker',
+        on_uninstall=script_callbacks.on_script_unloaded,
+    )
+
+
+    @sd_samplers_hijacker.hijack('create_sampler')
+    def create_sampler_hijack(name: str, model, original_function):
+        sampler = original_function(name, model)
+
+        if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
+            if global_state.is_enabled:
+                warn_unsupported_sampler()
+
+            return sampler
+
+        sampler.model_wrap_cfg.combine_denoised = functools.partial(
+            combine_denoised_hijack,
+            original_function=sampler.model_wrap_cfg.combine_denoised
+        )
+
+        return sampler
 
 
 def warn_unsupported_sampler():
```
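Two patterns in this file are worth unpacking. First, the `hijacker.ModuleHijacker` helper: its implementation lives elsewhere in the extension, but from the call sites above it replaces a module attribute with a wrapper and hands the original callable to the replacement as `original_function`. A minimal sketch of that contract — a simplified stand-in inferred from usage, not the real class:

```python
import functools


class ModuleHijackerSketch:
    """Simplified stand-in for hijacker.ModuleHijacker, inferred from its call sites."""

    def __init__(self, module):
        self.module = module

    def hijack(self, attribute):
        def decorator(replacement):
            original = getattr(self.module, attribute)

            @functools.wraps(original)
            def wrapper(*args, **kwargs):
                # forward to the replacement, exposing the original callable
                return replacement(*args, original_function=original, **kwargs)

            setattr(self.module, attribute, wrapper)
            return replacement

        return decorator
```

Second, `calc_cond_uncond_batch` relies on a pair-and-pad trick: the original function denoises exactly one cond and one uncond per call, so the hijack packs an arbitrary number of conds into pairs, duplicating the last entry when the count is odd and discarding the duplicate result afterwards. Distilled into a standalone sketch (all names here are hypothetical):

```python
def process_in_pairs(items, process_pair):
    """Run a two-at-a-time function over a list of any length."""
    discard_last = len(items) % 2 == 1
    if discard_last:
        items = items + [items[-1]]  # pad to an even length

    results = [
        result
        for first, second in zip(items[::2], items[1::2])
        for result in process_pair(first, second)  # yields one result per input
    ]

    return results[:-1] if discard_last else results
```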
10 changes: 9 additions & 1 deletion lib_neutral_prompt/ui.py

```diff
@@ -4,6 +4,14 @@
 import gradio as gr
 import dataclasses
 
+try:
+    # import a forge-specific module
+    from modules_forge import forge_sampler
+    del forge_sampler
+    forge = True
+except ImportError:
+    forge = False
+
 
 txt2img_prompt_textbox = None
 img2img_prompt_textbox = None
@@ -29,7 +37,7 @@ class AccordionInterface:
     def __post_init__(self):
         self.is_rendered = False
 
-        self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
+        self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0, visible=not forge, interactive=not forge)
         self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
         self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
         self.aux_prompt_type = gr.Dropdown(label='Prompt type', choices=list(prompt_types.keys()), value=next(iter(prompt_types.keys())), tooltip=prompt_types_tooltip, elem_id=self.get_elem_id('formatter_prompt_type'))
```
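The ui.py change uses the same feature-detection idiom as the denoiser hijack: probe for a module that only exists in the Forge fork, then branch on the result. A self-contained sketch of the hidden-slider behavior — the try/except and slider mirror the diff, while the `gr.Blocks` wrapper is added here purely for illustration:

```python
import gradio as gr

try:
    from modules_forge import forge_sampler  # present only in webui-forge
    del forge_sampler  # only its existence matters, not the module itself
    forge = True
except ImportError:
    forge = False

with gr.Blocks() as demo:
    # On Forge the extension takes a different sampling path, so the PR hides
    # and disables the CFG rescale control rather than silently ignoring it.
    cfg_rescale = gr.Slider(
        label='CFG rescale', minimum=0, maximum=1, value=0,
        visible=not forge, interactive=not forge,
    )
```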