Skip to content

Commit

Permalink
fix load_state_dict failed in dylora
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Apr 14, 2023
1 parent 06a9f51 commit 92332eb
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions networks/dylora.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def forward(self, x):
def state_dict(self, destination=None, prefix="", keep_vars=False):
# state dictを通常のLoRAと同じにする:
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
sd = super().state_dict(destination, prefix, keep_vars)
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
if self.is_conv2d and not self.is_conv2d_3x3:
Expand All @@ -129,7 +129,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

i = 0
while True:
while True:
key_a = f"{self.lora_name}.lora_A.{i}"
key_b = f"{self.lora_name}.lora_B.{i}"
if key_a in sd:
Expand All @@ -140,10 +140,8 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
i += 1
return sd

def load_state_dict(self, state_dict, strict=True):
# 通常のLoRAと同じstate dictを読み込めるようにする
state_dict = state_dict.copy()

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

Expand All @@ -152,15 +150,19 @@ def load_state_dict(self, state_dict, strict=True):
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
else:
return

if self.is_conv2d and not self.is_conv2d_3x3:
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
state_dict.update(
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
)
state_dict.update(
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
)

super().load_state_dict(state_dict, strict=strict)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
Expand Down

0 comments on commit 92332eb

Please sign in to comment.