
Support stable-diffusion-webui-forge #69

Open

wants to merge 10 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -4,6 +4,7 @@ Neutral prompt is an a1111 webui extension that adds alternative composable diff

## Features

- Now compatible with [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge)!
- [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
- saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons)
- semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
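For illustration, these keywords can be mixed in a single positive prompt (a hypothetical example, not taken from this diff):

a futuristic city at night
AND_PERP blurry, low quality
AND_SALT neon signs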
126 changes: 110 additions & 16 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
@@ -212,27 +212,121 @@ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
return vector * (torch.abs(vector) >= top_k).to(vector.dtype)


try:
    from modules_forge import forge_sampler
    forge = True
except ImportError:
    forge = False


if forge:
    from ldm_patched.modules.samplers import sampling_function
    from ldm_patched.modules.conds import CONDRegular  # assumed import for the c_concat wrapper used below

    forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
        module=forge_sampler,
        hijacker_attribute='__forge_sample_hijacker',
        on_uninstall=script_callbacks.on_script_unloaded,
    )

    @forge_sampler_hijacker.hijack('forge_sample')
    def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_function):
        if not global_state.is_enabled:
            return original_function(self, denoiser_params, cond_scale, cond_composition)

        model = self.inner_model.inner_model.forge_objects.unet.model
        control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
        extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
        x = denoiser_params.x
        timestep = denoiser_params.sigma
        model_options = self.inner_model.inner_model.forge_objects.unet.model_options
        seed = self.p.seeds[0]

        uncond = forge_sampler.cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
        conds = forge_sampler.cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
        # append the uncond last, so denoised[-1] below is the unconditional prediction
        conds += uncond

        # denoise each cond separately so the prompt expression tree can recombine the epsilons afterwards
        denoised = []

        for current_cond in conds:
            cond = [current_cond]
            cond[0]['strength'] = 1.0

            if extra_concat_condition is not None:
                image_cond_in = extra_concat_condition
            else:
                image_cond_in = denoiser_params.image_cond

            if isinstance(image_cond_in, torch.Tensor):
                if image_cond_in.shape[0] == x.shape[0] \
                        and image_cond_in.shape[2] == x.shape[2] \
                        and image_cond_in.shape[3] == x.shape[3]:
                    cond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in)

            if control is not None:
                cond[0]['control'] = control

            for modifier in model_options.get('conditioning_modifiers', []):
                model, x, timestep, _, cond, cond_scale, model_options, seed = modifier(model, x, timestep, None, cond, cond_scale, model_options, seed)

            model_options["disable_cfg1_optimization"] = True

            result = sampling_function(model, x, timestep, None, cond, 1.0, model_options, seed)
            denoised.append(result)

        cond_indices = cond_composition[0]
        prompt = global_state.prompt_exprs[0]

        # B, C, H, W
        denoised_uncond = denoised[-1]

        # N, B, C, H, W
        denoised_conds = torch.stack(denoised[:-1], dim=0)

        # N, 1, 1, 1, 1
        weights = torch.tensor([weight for _, weight in cond_indices], device=denoised_uncond.device)
        weights /= weights.abs().sum()
        weights = weights.view(-1, 1, 1, 1, 1)

        # B, C, H, W
        assert denoised_conds.shape[0] == weights.shape[0]
        denoised_cond = (denoised_conds * weights).sum(dim=0)
        forge_denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

        for batch_i in range(denoised_uncond.shape[0]):
            args = CombineDenoiseArgs(denoised_conds.unbind(dim=1)[batch_i], denoised_uncond[batch_i], cond_indices)
            cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
            aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
            cfg_cond = forge_denoised[batch_i] + aux_cond_delta * cond_scale
            forge_denoised[batch_i] = cfg_rescale(cfg_cond, denoised_uncond[batch_i] + cond_delta + aux_cond_delta)

        return forge_denoised


else:
    sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
        module=sd_samplers,
        hijacker_attribute='__neutral_prompt_hijacker',
        on_uninstall=script_callbacks.on_script_unloaded,
    )

    @sd_samplers_hijacker.hijack('create_sampler')
    def create_sampler_hijack(name: str, model, original_function):
        sampler = original_function(name, model)
        if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
            if global_state.is_enabled:
                warn_unsupported_sampler()

            return sampler

        sampler.model_wrap_cfg.combine_denoised = functools.partial(
            combine_denoised_hijack,
            original_function=sampler.model_wrap_cfg.combine_denoised
        )

        return sampler


def warn_unsupported_sampler():
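
As a standalone sketch of the weighted CFG blend implemented in forge_sample above (dummy tensors and a made-up guidance scale; shapes follow the comments in the diff):

import torch

# hypothetical inputs: 2 conds plus 1 uncond, batch of 1, 4 channels, 8x8 latents
denoised = [torch.randn(1, 4, 8, 8) for _ in range(3)]
cond_indices = [(0, 0.8), (1, 0.2)]  # (prompt index, weight) pairs, as in cond_composition

denoised_uncond = denoised[-1]                      # B, C, H, W
denoised_conds = torch.stack(denoised[:-1], dim=0)  # N, B, C, H, W

weights = torch.tensor([weight for _, weight in cond_indices])
weights /= weights.abs().sum()          # normalize weights by their absolute sum
weights = weights.view(-1, 1, 1, 1, 1)  # N, 1, 1, 1, 1

denoised_cond = (denoised_conds * weights).sum(dim=0)            # weighted merge -> B, C, H, W
cfg = denoised_uncond + (denoised_cond - denoised_uncond) * 7.0  # CFG at a guidance scale of 7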