From 0a7e2d28418f994114bdbe55b87b3317a2adadff Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Sun, 16 Jun 2024 21:10:14 -0500
Subject: [PATCH 1/2] Fixed lora hooks for newest ComfyUI versions

---
 animatediff/model_injection.py | 68 ++++++++++++++++++++++++----------
 1 file changed, 48 insertions(+), 20 deletions(-)

diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py
index 353adf8..a7fa270 100644
--- a/animatediff/model_injection.py
+++ b/animatediff/model_injection.py
@@ -141,12 +141,19 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
-                current_patches.append((strength_patch, patches[key], strength_model))
+                current_patches.append((strength_patch, patches[k], strength_model, offset))
                 current_hooked_patches[key] = current_patches
         self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
@@ -160,13 +167,20 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
         # TODO: make this work with timestep scheduling
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 # take difference between desired weight and existing weight to get diff
-                current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
+                current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
                 current_hooked_patches[key] = current_patches
         self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
@@ -485,16 +499,23 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         '''
         Based on add_patches, but for hooked weights.
         '''
-        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
+        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
-                current_patches.append((strength_patch, patches[key], strength_model))
+                current_patches.append((strength_patch, patches[k], strength_model, offset))
                 current_hooked_patches[key] = current_patches
-        self.hooked_patches[lora_hook] = current_hooked_patches
+        self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
         self.patches_uuid = uuid.uuid4()
         return list(p)
@@ -503,17 +524,24 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
         '''
         Based on add_hooked_patches, but intended for using a model's weights as lora hook.
         '''
-        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
+        current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
-        for key in patches:
-            model_sd = self.model.state_dict()
+        model_sd = self.model.state_dict()
+        for k in patches:
+            offset = None
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+
             if key in model_sd:
-                p.add(key)
+                p.add(k)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 # take difference between desired weight and existing weight to get diff
-                current_patches.append((strength_patch, (patches[key]-comfy.utils.get_attr(self.model, key),), strength_model))
+                current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
                 current_hooked_patches[key] = current_patches
-        self.hooked_patches[lora_hook] = current_hooked_patches
+        self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
         # since should care about these patches too to determine if same model, reroll patches_uuid
         self.patches_uuid = uuid.uuid4()
         return list(p)

From 89e449d757f97aa10fbf10d6e712793c81708678 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Sun, 16 Jun 2024 21:35:58 -0500
Subject: [PATCH 2/2] Fix small error in updated code for ModelPatcherCLIPHooks

---
 animatediff/model_injection.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py
index a7fa270..e5a666c 100644
--- a/animatediff/model_injection.py
+++ b/animatediff/model_injection.py
@@ -447,7 +447,7 @@ def __init__(self, m: ModelPatcher):
         if hasattr(m, "object_patches_backup"):
             self.object_patches_backup = m.object_patches_backup
         # lora hook stuff
-        self.hooked_patches = {} # binds LoraHook to specific keys
+        self.hooked_patches: dict[HookRef] = {} # binds LoraHook to specific keys
         self.patches_backup = {}
         self.hooked_backup: dict[str, tuple[Tensor, torch.device]] = {}

@@ -554,7 +554,7 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
         combined_patches = {}
         if lora_hooks is not None:
             for hook in lora_hooks.hooks:
-                hook_patches: dict = self.hooked_patches.get(hook, {})
+                hook_patches: dict = self.hooked_patches.get(hook.hook_ref, {})
                 for key in hook_patches.keys():
                     current_patches: list[tuple] = combined_patches.get(key, [])
                     current_patches.extend(hook_patches[key])