adds dropout
cloneofsimo committed Jan 14, 2023
1 parent 204d6ed commit 583b1e7
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lora_diffusion/lora.py
@@ -30,7 +30,7 @@ def safe_save(
 
 
 class LoraInjectedLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=False, r=4):
+    def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1):
         super().__init__()
 
         if r > min(in_features, out_features):
@@ -40,14 +40,18 @@ def __init__(self, in_features, out_features, bias=False, r=4):
 
         self.linear = nn.Linear(in_features, out_features, bias)
         self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.dropout = nn.Dropout(dropout_p)
         self.lora_up = nn.Linear(r, out_features, bias=False)
         self.scale = 1.0
 
         nn.init.normal_(self.lora_down.weight, std=1 / r)
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+        return (
+            self.linear(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
 
 
 class LoraInjectedConv2d(nn.Module):
@@ -63,6 +67,7 @@ def __init__(
         bias: bool = True,
         padding_mode: str = "zeros",
         r: int = 4,
+        dropout_p : float = 0.1
     ):
         super().__init__()
         if r > min(in_channels, out_channels):
@@ -91,6 +96,7 @@ def __init__(
             groups=groups,
             bias=False,
         )
+        self.dropout = nn.Dropout(dropout_p)
         self.lora_up = nn.Conv2d(
             in_channels=r,
             out_channels=out_channels,
@@ -105,7 +111,10 @@ def __init__(
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.conv(input) + self.lora_up(self.lora_down(input)) * self.scale
+        return (
+            self.conv(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
 
 
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
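
For reference, a minimal usage sketch of the patched LoraInjectedLinear (not part of the commit; the feature sizes and input shape are illustrative assumptions, while the import path follows the changed file). The dropout is applied only on the low-rank branch, between lora_down and lora_up, so the output of the frozen base projection is never dropped; LoraInjectedConv2d is changed the same way.

import torch
from lora_diffusion.lora import LoraInjectedLinear

# Hypothetical sizes; dropout_p=0.1 matches the new default added in this commit.
layer = LoraInjectedLinear(320, 320, bias=True, r=4, dropout_p=0.1)

x = torch.randn(2, 77, 320)

layer.train()   # nn.Dropout is active: elements of the LoRA branch are randomly zeroed and rescaled
y_train = layer(x)

layer.eval()    # nn.Dropout becomes a no-op, so inference stays deterministic
y_eval = layer(x)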
