Skip to content

Commit

Permalink
Update resize_lora.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Symbiomatrix authored Aug 15, 2023
1 parent cf80210 commit 9d678a6
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions networks/resize_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,16 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
block_down_name = key.rsplit('lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue

# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
lora_up_weight = lora_sd.get(block_up_name + 'lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + 'alpha', None)

weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)

Expand Down Expand Up @@ -263,9 +263,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
verbose_str+=f"\n"

new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
o_lora_sd[block_down_name + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

block_down_name = None
block_up_name = None
Expand Down

0 comments on commit 9d678a6

Please sign in to comment.