diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index ce9e917..0204557 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -30,7 +30,7 @@ def safe_save(
 
 
 class LoraInjectedLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=False, r=4):
+    def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1):
         super().__init__()
 
         if r > min(in_features, out_features):
@@ -40,6 +40,7 @@ def __init__(self, in_features, out_features, bias=False, r=4):
 
         self.linear = nn.Linear(in_features, out_features, bias)
         self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.dropout = nn.Dropout(dropout_p)
         self.lora_up = nn.Linear(r, out_features, bias=False)
         self.scale = 1.0
 
@@ -47,7 +48,10 @@ def __init__(self, in_features, out_features, bias=False, r=4):
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+        return (
+            self.linear(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
 
 
 class LoraInjectedConv2d(nn.Module):
@@ -63,6 +67,7 @@ def __init__(
         bias: bool = True,
         padding_mode: str = "zeros",
         r: int = 4,
+        dropout_p: float = 0.1
     ):
         super().__init__()
         if r > min(in_channels, out_channels):
@@ -91,6 +96,7 @@ def __init__(
             groups=groups,
             bias=False,
         )
+        self.dropout = nn.Dropout(dropout_p)
         self.lora_up = nn.Conv2d(
             in_channels=r,
             out_channels=out_channels,
@@ -105,7 +111,10 @@ def __init__(
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.conv(input) + self.lora_up(self.lora_down(input)) * self.scale
+        return (
+            self.conv(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
 
 
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
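
For context, a minimal usage sketch of the patched `LoraInjectedLinear`, assuming the module above is importable as `lora_diffusion.lora`; the feature sizes and `dropout_p` value below are illustrative, not part of the patch. Dropout is applied only to the low-rank branch, so in `eval()` mode it is a no-op and the layer reduces to `linear(x) + lora_up(lora_down(x)) * scale`.

```python
# Usage sketch (not part of the patch): exercises the new dropout_p argument.
import torch
from lora_diffusion.lora import LoraInjectedLinear

# Illustrative sizes; in practice these would match the projection dims being replaced.
layer = LoraInjectedLinear(320, 320, bias=False, r=4, dropout_p=0.1)
x = torch.randn(2, 77, 320)

layer.train()   # dropout active: the low-rank update is stochastically masked
y_train = layer(x)

layer.eval()    # dropout is identity: output is linear(x) + lora_up(lora_down(x)) * scale
y_eval = layer(x)

print(y_train.shape, y_eval.shape)  # torch.Size([2, 77, 320]) torch.Size([2, 77, 320])
```

Note that because `lora_up` is zero-initialized, the dropout only has an effect once the low-rank weights have been trained away from zero.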