Skip to content

Commit

Permalink
utils.set_attr can now be used to set any attribute.
Browse files Browse the repository at this point in the history
The old set_attr has been renamed to set_attr_param.
  • Loading branch information
comfyanonymous committed Mar 2, 2024
1 parent dce3555 commit 1abf837
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,13 @@ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
for k in sd:
weight = sd[k]
try:
comfy.utils.set_attr(self.control_model, k, weight)
comfy.utils.set_attr_param(self.control_model, k, weight)
except:
pass

for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
Expand Down
7 changes: 3 additions & 4 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,9 @@ def model_state_dict(self, filter_prefix=None):

def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = comfy.utils.get_attr(self.model, k)
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
comfy.utils.set_attr(self.model, k, self.object_patches[k])

if patch_weights:
model_sd = self.model_state_dict()
Expand All @@ -203,7 +202,7 @@ def patch_model(self, device_to=None, patch_weights=True):
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr(self.model, key, out_weight)
comfy.utils.set_attr_param(self.model, key, out_weight)
del temp_weight

if device_to is not None:
Expand Down Expand Up @@ -342,7 +341,7 @@ def unpatch_model(self, device_to=None):
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
comfy.utils.set_attr_param(self.model, k, self.backup[k])

self.backup = {}

Expand Down
7 changes: 5 additions & 2 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,11 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
setattr(obj, attrs[-1], value)
return prev

def set_attr_param(obj, attr, value):
    """Assign *value* at the dotted attribute path *attr* on *obj*, wrapped
    as a frozen (non-trainable) torch.nn.Parameter.

    Delegates to set_attr, which returns whatever previously lived at that
    path, so callers can keep a backup of the replaced attribute.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)

def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
Expand Down

0 comments on commit 1abf837

Please sign in to comment.