Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support stable-diffusion-webui-forge #69

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
53 changes: 24 additions & 29 deletions lib_neutral_prompt/cfg_denoiser_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import sys
import textwrap

from modules_forge.forge_sampler import cond_from_a1111_to_patched_ldm


def combine_denoised_hijack(
x_out: torch.Tensor,
Expand Down Expand Up @@ -223,13 +225,12 @@ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
if forge:
forge_sampler_hijacker = hijacker.ModuleHijacker.install_or_get(
module=forge_sampler,
hijacker_attribute='__forge_sample_hijacker',
hijacker_attribute='__forge_sampler_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)

ldm_patched_modules_hijacker = hijacker.ModuleHijacker.install_or_get(
samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=samplers,
hijacker_attribute='__ldm_patched_modules_hijacker',
hijacker_attribute='__samplers_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)

Expand All @@ -240,58 +241,53 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition, original_f
return original_function(self, denoiser_params, cond_scale, cond_composition)

self.inner_model.inner_model.forge_objects.unet.model_options['cond_composition'] = cond_composition
ljleb marked this conversation as resolved.
Show resolved Hide resolved
self.inner_model.inner_model.forge_objects.unet.model_options['uncond'] = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)

return original_function(self, denoiser_params, cond_scale, cond_composition)

@ldm_patched_modules_hijacker.hijack('samplers')

@samplers_hijacker.hijack('calc_cond_uncond_batch')
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options, original_function):
if not global_state.is_enabled:
return original_function(model, cond, uncond, x_in, timestep, model_options)

cond_composition = model_options['cond_composition']
uncond = model_options['uncond']
ljleb marked this conversation as resolved.
Show resolved Hide resolved

cond.extend(uncond)
discard_last = (len(cond) % 2) == 1
if discard_last:
cond.append(cond[0])
remove_last_cond = (len(cond) % 2) == 1
if remove_last_cond:
cond.append(cond[-1])

weights = []
for elem in cond:
weights.append((elem['strength'] if 'strength' in elem else 1.0))
elem['strength'] = 1.0

denoised_latents = []
cond_it = iter(cond)
for first_cond, second_cond in zip(cond_it, cond_it):
first_latent, second_latent = original_function(model, first_cond, second_cond, x_in, timestep, model_options)
denoised_latents.extend([first_latent, second_latent])
denoised_latents = [
denoised
for first_cond, second_cond in zip(cond[::2], cond[1::2])
for denoised in original_function(model, [first_cond], [second_cond], x_in, timestep, model_options)
]

if discard_last:
if remove_last_cond:
denoised_latents = denoised_latents[:-1]
weights = weights[:-1]

# B, C, H, W
denoised_uncond = denoised_latents[-1]

# N, B, C, H, W
denoised_conds = torch.stack(denoised[:-1], dim=0)
denoised_in_conds = torch.stack(denoised_latents[:-1], dim=0)
denoised_in_conds = denoised_in_conds.transpose(0, 1).reshape(-1, *denoised_in_conds.shape[2:])

# N, 1, 1, 1, 1
weights = torch.tensor(weights, device=denoised_uncond.device).view(-1, 1, 1, 1, 1)

# B, C, H, W
denoised_cond = (denoised_conds * weights).sum(dim=0)
denoised_cond = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

denoised_cond = denoised_uncond.clone()
for batch_i in range(denoised_uncond.shape[0]):
prompt = global_state.prompt_exprs[batch_i]
args = CombineDenoiseArgs(denoised_conds.unbind(dim=1)[batch_i], denoised_uncond[batch_i], cond_composition[batch_i])
args = CombineDenoiseArgs(denoised_in_conds, denoised_uncond[batch_i], cond_composition[batch_i])
cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
denoised_cond[batch_i] += aux_cond_delta * cond_scale
denoised_cond[batch_i] += cond_delta + aux_cond_delta

return denoised_cond, denoised_uncond


else:
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
Expand All @@ -304,7 +300,6 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options, o
def create_sampler_hijack(name: str, model, original_function):
sampler = original_function(name, model)


if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
if global_state.is_enabled:
warn_unsupported_sampler()
Expand Down