Merge pull request #403 from Kosinkadink/hook_update_fix
Fix Lora Hooks for latest ComfyUI versions
Kosinkadink authored Jun 17, 2024
2 parents e2313c4 + 89e449d commit 8f9d582
Showing 1 changed file with 50 additions and 22 deletions.

animatediff/model_injection.py
@@ -141,12 +141,19 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
-                current_patches.append((strength_patch, patches[key], strength_model))
+                current_patches.append((strength_patch, patches[k], strength_model, offset))
                 current_hooked_patches[key] = current_patches
         self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
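
The core of the change, repeated in each hunk of this file: recent ComfyUI versions may pass patch keys either as plain state_dict key strings or as (key, offset) tuples, so each loop now normalizes the key before the state_dict lookup and threads the offset through as a fourth element of the stored patch tuple. A minimal sketch of that normalization (the helper name and the example key are illustrative, not part of the commit):

    def normalize_patch_key(k):
        """Split a ComfyUI patch key into (state_dict_key, offset)."""
        if isinstance(k, str):
            return k, None    # old-style key: a bare state_dict key, no offset
        return k[0], k[1]     # new-style key: a (state_dict_key, offset) tuple

    # Both forms resolve to the same state_dict key:
    # normalize_patch_key("diffusion_model.out.weight")            -> ("diffusion_model.out.weight", None)
    # normalize_patch_key(("diffusion_model.out.weight", (0, 4)))  -> ("diffusion_model.out.weight", (0, 4))

Note also that the state_dict is now fetched once before the loop instead of once per key, which avoids rebuilding it on every iteration.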
@@ -160,13 +167,20 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
         # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 # take difference between desired weight and existing weight to get diff
-                current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
+                current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
                 current_hooked_patches[key] = current_patches
         self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
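
For the as-diffs variant, the stored patch is the delta between the desired weight and the model's current weight, so applying it at strength 1.0 reproduces the desired weight exactly. A toy illustration of that arithmetic, under the assumption that the stored one-element (delta,) tuple is later applied as weight + strength_patch * delta (the actual application happens elsewhere, in ComfyUI's weight-patching logic):

    import torch

    existing = torch.zeros(4)            # stands in for comfy.utils.get_attr(self.model, key)
    desired = torch.full((4,), 0.5)      # the weight the hooked lora wants

    delta = desired - existing           # what the hunk stores: (patches[k] - existing,)
    patched = existing + 1.0 * delta     # strength_patch=1.0 recovers the desired weight
    assert torch.equal(patched, desired)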
@@ -433,7 +447,7 @@ def __init__(self, m: ModelPatcher):
         if hasattr(m, "object_patches_backup"):
             self.object_patches_backup = m.object_patches_backup
         # lora hook stuff
-        self.hooked_patches = {} # binds LoraHook to specific keys
+        self.hooked_patches: dict[HookRef] = {} # binds LoraHook to specific keys
         self.patches_backup = {}
         self.hooked_backup: dict[str, tuple[Tensor, torch.device]] = {}

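The bookkeeping dicts are now keyed by each hook's hook_ref rather than by the LoraHook object itself (note the dict[HookRef] annotation above). A plausible reading of why, sketched with stand-in classes rather than the project's real definitions: hook objects may be copied or cloned, but clones share one HookRef, so a ref-keyed dict still resolves to the same patches:

    class HookRef: ...

    class LoraHook:
        def __init__(self, hook_ref=None):
            self.hook_ref = hook_ref if hook_ref is not None else HookRef()
        def clone(self):
            return LoraHook(self.hook_ref)   # a clone shares the original's ref

    hooked_patches = {}
    a = LoraHook()
    b = a.clone()
    hooked_patches[a.hook_ref] = ["some patch"]
    assert hooked_patches.get(b.hook_ref) == ["some patch"]   # clone still resolves
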
@@ -485,16 +499,23 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         '''
         Based on add_patches, but for hooked weights.
         '''
-        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
+        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
-                current_patches.append((strength_patch, patches[key], strength_model))
+                current_patches.append((strength_patch, patches[k], strength_model, offset))
                 current_hooked_patches[key] = current_patches
-        self.hooked_patches[lora_hook] = current_hooked_patches
+        self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
         self.patches_uuid = uuid.uuid4()
         return list(p)
@@ -503,17 +524,24 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
         '''
         Based on add_hooked_patches, but intended for using a model's weights as lora hook.
         '''
-        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
+        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 # take difference between desired weight and existing weight to get diff
-                current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
+                current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
                 current_hooked_patches[key] = current_patches
-        self.hooked_patches[lora_hook] = current_hooked_patches
+        self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
         self.patches_uuid = uuid.uuid4()
         return list(p)
@@ -526,7 +554,7 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
         combined_patches = {}
         if lora_hooks is not None:
             for hook in lora_hooks.hooks:
-                hook_patches: dict = self.hooked_patches.get(hook, {})
+                hook_patches: dict = self.hooked_patches.get(hook.hook_ref, {})
                 for key in hook_patches.keys():
                     current_patches: list[tuple] = combined_patches.get(key, [])
                     current_patches.extend(hook_patches[key])
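
Finally, get_combined_hooked_patches merges the per-hook patch lists key by key: every hook in the group contributes its patches for a given weight, and lists for the same key are simply concatenated. A minimal dict-level sketch of that merge (function and variable names are illustrative):

    def combine_hooked_patches(hooked_patches, hook_refs):
        combined = {}
        for ref in hook_refs:
            for key, patch_list in hooked_patches.get(ref, {}).items():
                combined.setdefault(key, []).extend(patch_list)
        return combined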
