diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 1a72412b19e..b6941d8c425 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -201,7 +201,7 @@ def cleanup(self): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -220,7 +220,7 @@ def forward(self, input): else: return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( self, in_channels, diff --git a/comfy/ops.py b/comfy/ops.py index cfdec355c20..eb6507682d1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -31,13 +31,13 @@ def cast_bias_weight(s, input): weight = s.weight_function(weight) return weight, bias +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None class disable_weight_init: - class Linear(torch.nn.Linear): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None @@ -51,11 +51,7 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) - class Conv2d(torch.nn.Conv2d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -69,11 +65,7 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) - class Conv3d(torch.nn.Conv3d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None @@ -87,11 +79,7 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) - class GroupNorm(torch.nn.GroupNorm): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -106,11 +94,7 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -128,11 +112,7 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) - class ConvTranspose2d(torch.nn.ConvTranspose2d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None