Skip to content

Commit

Permalink
Fix control loras breaking.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Mar 14, 2024
1 parent db8b59e commit 448d926
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 32 deletions.
4 changes: 2 additions & 2 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def cleanup(self):
super().cleanup()

class ControlLoraOps:
class Linear(torch.nn.Module):
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
Expand All @@ -220,7 +220,7 @@ def forward(self, input):
else:
return torch.nn.functional.linear(input, weight, bias)

class Conv2d(torch.nn.Module):
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
self,
in_channels,
Expand Down
40 changes: 10 additions & 30 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def cast_bias_weight(s, input):
weight = s.weight_function(weight)
return weight, bias

class CastWeightBiasOp:
    """Mixin supplying the default cast-weight/bias attributes shared by
    comfy ops.

    The diff in this commit adds it as a base class of every op
    (``Linear``, ``Conv2d``, ``Conv3d``, ``GroupNorm``, ``LayerNorm``,
    ``ConvTranspose2d`` and the ControlLora ops) so the three attributes
    below are defined in one place instead of being copy-pasted into
    each class.
    """

    # Presumably checked by each op's forward() to choose the casting
    # code path — TODO confirm against the full forward() bodies.
    comfy_cast_weights = False

    # Optional callables applied to the tensors inside cast_bias_weight()
    # (the visible hunk shows `weight = s.weight_function(weight)`);
    # None means no transform.
    bias_function = None
    weight_function = None

class disable_weight_init:
class Linear(torch.nn.Linear):
comfy_cast_weights = False
weight_function = None
bias_function = None

class Linear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of torch.nn.Linear.reset_parameters:
    # skips the default weight/bias initialization — presumably because
    # weights are loaded/cast elsewhere by the comfy ops machinery
    # (TODO confirm against the full file).
    return None

Expand All @@ -51,11 +51,7 @@ def forward(self, *args, **kwargs):
else:
return super().forward(*args, **kwargs)

class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
weight_function = None
bias_function = None

class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of torch.nn.Conv2d.reset_parameters:
    # skips the default weight/bias initialization — presumably because
    # weights are loaded/cast elsewhere by the comfy ops machinery
    # (TODO confirm against the full file).
    return None

Expand All @@ -69,11 +65,7 @@ def forward(self, *args, **kwargs):
else:
return super().forward(*args, **kwargs)

class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
weight_function = None
bias_function = None

class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of torch.nn.Conv3d.reset_parameters:
    # skips the default weight/bias initialization — presumably because
    # weights are loaded/cast elsewhere by the comfy ops machinery
    # (TODO confirm against the full file).
    return None

Expand All @@ -87,11 +79,7 @@ def forward(self, *args, **kwargs):
else:
return super().forward(*args, **kwargs)

class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
weight_function = None
bias_function = None

class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of torch.nn.GroupNorm.reset_parameters:
    # skips the default affine-parameter initialization — presumably
    # because weights are loaded/cast elsewhere by the comfy ops
    # machinery (TODO confirm against the full file).
    return None

Expand All @@ -106,11 +94,7 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
weight_function = None
bias_function = None

class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of torch.nn.LayerNorm.reset_parameters:
    # skips the default affine-parameter initialization — presumably
    # because weights are loaded/cast elsewhere by the comfy ops
    # machinery (TODO confirm against the full file).
    return None

Expand All @@ -128,11 +112,7 @@ def forward(self, *args, **kwargs):
else:
return super().forward(*args, **kwargs)

class ConvTranspose2d(torch.nn.ConvTranspose2d):
comfy_cast_weights = False
weight_function = None
bias_function = None

class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
    # Intentional no-op override of
    # torch.nn.ConvTranspose2d.reset_parameters: skips the default
    # weight/bias initialization — presumably because weights are
    # loaded/cast elsewhere by the comfy ops machinery
    # (TODO confirm against the full file).
    return None

Expand Down

0 comments on commit 448d926

Please sign in to comment.