Skip to content

Commit

Permalink
Update prompt_injection.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DataCTE authored Jun 15, 2024
1 parent d6699be commit c67bf72
Showing 1 changed file with 174 additions and 46 deletions.
220 changes: 174 additions & 46 deletions prompt_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,123 @@
import comfy.model_patcher
import comfy.samplers
import torch
import torch.nn.functional as F

def build_patch(patchedBlocks, weight=1.0, sigma_start=0.0, sigma_end=1.0):
def prompt_injection_patch(n, context_attn1: torch.Tensor, value_attn1, extra_options):
(block, block_index) = extra_options.get('block', (None,None))
sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9

batch_prompt = n.shape[0] // len(extra_options["cond_or_uncond"])

if sigma <= sigma_start and sigma >= sigma_end:
if (block and f'{block}:{block_index}' in patchedBlocks and patchedBlocks[f'{block}:{block_index}']):
c = context_attn1[0]
if c.dim() == 2:
c = c.unsqueeze(0)
cond = torch.stack(
(
c,
patchedBlocks[f'{block}:{block_index}'][0][0].to(context_attn1.device)
)
).to(dtype=context_attn1.dtype)

return n, cond, cond * weight
if context_attn1.dim() == 3:
c = context_attn1[0].unsqueeze(0)
else:
c = context_attn1[0][0].unsqueeze(0)
b = patchedBlocks[f'{block}:{block_index}'][0][0].repeat(c.shape[0], 1, 1).to(context_attn1.device)
out = torch.stack((c, b)).to(dtype=context_attn1.dtype) * weight
out = out.repeat(1, batch_prompt, 1, 1) * weight

return n, out, out

return n, context_attn1, value_attn1
return prompt_injection_patch

def build_svd_patch(patchedBlocks, weight=1.0, sigma_start=0.0, sigma_end=1.0):
def prompt_injection_patch(n, context_attn1: torch.Tensor, value_attn1, extra_options):
(block, block_index) = extra_options.get('block', (None, None))
sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9

if sigma_start <= sigma <= sigma_end:
if block and f'{block}:{block_index}' in patchedBlocks and patchedBlocks[f'{block}:{block_index}']:
if context_attn1.dim() == 3:
c = context_attn1[0].unsqueeze(0)
else:
c = context_attn1[0][0].unsqueeze(0)
b = patchedBlocks[f'{block}:{block_index}'][0][0].repeat(c.shape[0], 1, 1).to(context_attn1.device)

# Interpolate to match the sizes
if c.size() != b.size():
b = F.interpolate(b.unsqueeze(0), size=c.size()[1:], mode='nearest').squeeze(0)

out = torch.cat((c, b), dim=-1).to(dtype=context_attn1.dtype) * weight
return n, out # Ensure exactly two values are returned for SVD
return n, context_attn1, value_attn1 # Ensure exactly three values are returned

return prompt_injection_patch

class SVDPromptInjection:
@classmethod
def INPUT_TYPES(s):
return {
"required": {"model": ("MODEL",)},
"optional": {
"all": ("CONDITIONING",),
"time_embed": ("CONDITIONING",),
"label_emb": ("CONDITIONING",),
"input_blocks_0": ("CONDITIONING",),
"input_blocks_1": ("CONDITIONING",),
"input_blocks_2": ("CONDITIONING",),
"input_blocks_3": ("CONDITIONING",),
"input_blocks_4": ("CONDITIONING",),
"input_blocks_5": ("CONDITIONING",),
"input_blocks_6": ("CONDITIONING",),
"input_blocks_7": ("CONDITIONING",),
"input_blocks_8": ("CONDITIONING",),
"middle_block_0": ("CONDITIONING",),
"middle_block_1": ("CONDITIONING",),
"middle_block_2": ("CONDITIONING",),
"output_blocks_0": ("CONDITIONING",),
"output_blocks_1": ("CONDITIONING",),
"output_blocks_2": ("CONDITIONING",),
"output_blocks_3": ("CONDITIONING",),
"output_blocks_4": ("CONDITIONING",),
"output_blocks_5": ("CONDITIONING",),
"output_blocks_6": ("CONDITIONING",),
"output_blocks_7": ("CONDITIONING",),
"output_blocks_8": ("CONDITIONING",),
"weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"

def patch(self, model: comfy.model_patcher.ModelPatcher, all=None, time_embed=None, label_emb=None, input_blocks_0=None, input_blocks_1=None, input_blocks_2=None, input_blocks_3=None, input_blocks_4=None, input_blocks_5=None, input_blocks_6=None, input_blocks_7=None, input_blocks_8=None, middle_block_0=None, middle_block_1=None, middle_block_2=None, output_blocks_0=None, output_blocks_1=None, output_blocks_2=None, output_blocks_3=None, output_blocks_4=None, output_blocks_5=None, output_blocks_6=None, output_blocks_7=None, output_blocks_8=None, weight=1.0, start_at=0.0, end_at=1.0):
if not any((all, time_embed, label_emb, input_blocks_0, input_blocks_1, input_blocks_2, input_blocks_3, input_blocks_4, input_blocks_5, input_blocks_6, input_blocks_7, input_blocks_8, middle_block_0, middle_block_1, middle_block_2, output_blocks_0, output_blocks_1, output_blocks_2, output_blocks_3, output_blocks_4, output_blocks_5, output_blocks_6, output_blocks_7, output_blocks_8)):
return (model,)

m = model.clone()
sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

patchedBlocks = {}
blocks = {
'time_embed': [0],
'label_emb': [0],
'input_blocks': list(range(9)),
'middle_block': list(range(3)),
'output_blocks': list(range(9))
}

for block in blocks:
for index in blocks[block]:
block_name = f"{block}_{index}"
value = locals().get(block_name, None)
if value is None:
value = all
if value is not None:
patchedBlocks[f"{block}:{index}"] = value

m.set_model_attn2_patch(build_svd_patch(patchedBlocks, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))

return (m,)

class PromptInjection:
@classmethod
def INPUT_TYPES(s):
Expand All @@ -37,17 +131,18 @@ def INPUT_TYPES(s):
"model": ("MODEL",),
},
"optional": {
"in4": ("CONDITIONING",),
"in5": ("CONDITIONING",),
"in7": ("CONDITIONING",),
"in8": ("CONDITIONING",),
"mid0": ("CONDITIONING",),
"out0": ("CONDITIONING",),
"out1": ("CONDITIONING",),
"out2": ("CONDITIONING",),
"out3": ("CONDITIONING",),
"out4": ("CONDITIONING",),
"out5": ("CONDITIONING",),
"all": ("CONDITIONING",),
"input_4": ("CONDITIONING",),
"input_5": ("CONDITIONING",),
"input_7": ("CONDITIONING",),
"input_8": ("CONDITIONING",),
"middle_0": ("CONDITIONING",),
"output_0": ("CONDITIONING",),
"output_1": ("CONDITIONING",),
"output_2": ("CONDITIONING",),
"output_3": ("CONDITIONING",),
"output_4": ("CONDITIONING",),
"output_5": ("CONDITIONING",),
"weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
Expand All @@ -59,26 +154,57 @@ def INPUT_TYPES(s):

CATEGORY = "advanced/model"

def patch(self, model: comfy.model_patcher.ModelPatcher, in4=None, in5=None, in7=None, in8=None, mid0=None, out0=None, out1=None, out2=None, out3=None, out4=None, out5=None, weight=1.0, start_at=0.0, end_at=1.0):
def patch(self, model: comfy.model_patcher.ModelPatcher, all=None, input_4=None, input_5=None, input_7=None, input_8=None, middle_0=None, output_0=None, output_1=None, output_2=None, output_3=None, output_4=None, output_5=None, weight=1.0, start_at=0.0, end_at=1.0):
if not any((all, input_4, input_5, input_7, input_8, middle_0, output_0, output_1, output_2, output_3, output_4, output_5)):
return (model,)

m = model.clone()
sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

if any((in4, in5, in7, in8, mid0, out0, out1, out2, out3, out4, out5)):
patchedBlocks = {
'input:4': in4,
'input:5': in5,
'input:7': in7,
'input:8': in8,
'middle:0': mid0,
'output:0': out0,
'output:1': out1,
'output:2': out2,
'output:3': out3,
'output:4': out4,
'output:5': out5,
patchedBlocks = {}
blocks = {'input': [4, 5, 7, 8], 'middle': [0], 'output': [0, 1, 2, 3, 4, 5]}

for block in blocks:
for index in blocks[block]:
value = locals()[f"{block}_{index}"] if locals()[f"{block}_{index}"] is not None else all
if value is not None:
patchedBlocks[f"{block}:{index}"] = value

m.set_model_attn2_patch(build_patch(patchedBlocks, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))

return (m,)

class SimplePromptInjection:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
},
"optional": {
"block": (["input:4", "input:5", "input:7", "input:8", "middle:0", "output:0", "output:1", "output:2", "output:3", "output:4", "output:5"],),
"conditioning": ("CONDITIONING",),
"weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}
m.set_model_attn2_patch(build_patch(patchedBlocks, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model: comfy.model_patcher.ModelPatcher, block, conditioning=None, weight=1.0, start_at=0.0, end_at=1.0):
if conditioning is None:
return (model,)

m = model.clone()
sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

m.set_model_attn2_patch(build_patch({f"{block}": conditioning}, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))

return (m,)

Expand All @@ -90,8 +216,7 @@ def INPUT_TYPES(s):
"model": ("MODEL",),
},
"optional": {
"block": (("output", "middle", "input"),),
"index": ("INT", {"default": 0, "min": 0, "max": 8, "step": 1}),
"block": (["input:4", "input:5", "input:7", "input:8", "middle:0", "output:0", "output:1", "output:2", "output:3", "output:4", "output:5"],),
"conditioning": ("CONDITIONING",),
"weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
Expand All @@ -104,15 +229,15 @@ def INPUT_TYPES(s):

CATEGORY = "advanced/model"

def patch(self, model: comfy.model_patcher.ModelPatcher, block, index, conditioning=None, weight=1.0, start_at=0.0, end_at=1.0):
if not conditioning:
def patch(self, model: comfy.model_patcher.ModelPatcher, block, conditioning=None, weight=1.0, start_at=0.0, end_at=1.0):
if conditioning is None:
return (model,)

m = model.clone()
sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

m.set_model_attn2_patch(build_patch({f"{block}:{index}": conditioning}, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))
m.set_model_attn2_patch(build_patch({f"{block}": conditioning}, weight=weight, sigma_start=sigma_start, sigma_end=sigma_end))

return (m,)

Expand Down Expand Up @@ -156,14 +281,17 @@ def patch(self, model: comfy.model_patcher.ModelPatcher, locations: str, conditi

return (m,)


NODE_CLASS_MAPPINGS = {
"PromptInjection": PromptInjection,
"SimplePromptInjection": SimplePromptInjection,
"AdvancedPromptInjection": AdvancedPromptInjection
"AdvancedPromptInjection": AdvancedPromptInjection,
"SVDPromptInjection": SVDPromptInjection
}

NODE_DISPLAY_NAME_MAPPINGS = {
"PromptInjection": "Inject Prompt in Attention",
"SimplePromptInjection": "Inject Prompt in Attention (simple)",
"AdvancedPromptInjection": "Inject Prompt in Attention (advanced)"
"PromptInjection": "Attn2 Prompt Injection",
"SimplePromptInjection": "Attn2 Prompt Injection (simple)",
"AdvancedPromptInjection": "Attn2 Prompt Injection (advanced)",
"SVDPromptInjection": "Attn2 SVD Prompt Injection"
}

0 comments on commit c67bf72

Please sign in to comment.