diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 416197586a1..f859a50d4c3 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -287,13 +287,13 @@ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
         for k in sd:
             weight = sd[k]
             try:
-                comfy.utils.set_attr(self.control_model, k, weight)
+                comfy.utils.set_attr_param(self.control_model, k, weight)
             except:
                 pass

         for k in self.control_weights:
             if k not in {"lora_controlnet"}:
-                comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
+                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

     def copy(self):
         c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f29781f319a..604e347799c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -176,10 +176,9 @@ def model_state_dict(self, filter_prefix=None):

     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
-            old = comfy.utils.get_attr(self.model, k)
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
-            comfy.utils.set_attr(self.model, k, self.object_patches[k])

         if patch_weights:
             model_sd = self.model_state_dict()
@@ -203,7 +202,7 @@ def patch_model(self, device_to=None, patch_weights=True):
                 if inplace_update:
                     comfy.utils.copy_to_param(self.model, key, out_weight)
                 else:
-                    comfy.utils.set_attr(self.model, key, out_weight)
+                    comfy.utils.set_attr_param(self.model, key, out_weight)
                 del temp_weight

         if device_to is not None:
@@ -342,7 +341,7 @@ def unpatch_model(self, device_to=None):
                 comfy.utils.copy_to_param(self.model, k, self.backup[k])
         else:
             for k in keys:
-                comfy.utils.set_attr(self.model, k, self.backup[k])
+                comfy.utils.set_attr_param(self.model, k, self.backup[k])

         self.backup = {}

diff --git a/comfy/utils.py b/comfy/utils.py
index 41f730c8ecf..5deb14cd2de 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -294,8 +294,11 @@ def set_attr(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
-    del prev
+    setattr(obj, attrs[-1], value)
+    return prev
+
+def set_attr_param(obj, attr, value):
+    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))

 def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it
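
A minimal sketch (not part of the patch) of how the refactored helpers behave after this change: set_attr now stores the raw value and returns the previous attribute, which patch_model uses to fill object_patches_backup in a single call, while set_attr_param keeps the old behavior of wrapping weights in a non-trainable torch.nn.Parameter. The helper bodies are reconstructed from the comfy/utils.py hunk above; the toy Sequential model and variable names below are assumptions used only for illustration.

import torch

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], value)
    return prev  # caller may keep this as a backup, as patch_model now does

def set_attr_param(obj, attr, value):
    # weight patching still wraps the tensor in a non-trainable Parameter
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))

# toy stand-in for self.model / self.control_model (assumption, not ComfyUI code)
model = torch.nn.Sequential(torch.nn.Linear(4, 4))

# weight patching: replace the Parameter and get the previous one back
old_weight = set_attr_param(model, "0.weight", torch.zeros(4, 4))

# object patching: store the raw object and keep the old attribute around so it
# can be restored later, mirroring object_patches_backup[k] = old in patch_model
old_module = set_attr(model, "0", torch.nn.Identity())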