From f2491ee0ac4727b49110764dd2862d8aa316b048 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Sep 2023 12:10:56 +0900 Subject: [PATCH] change block name so that it doesn't contain '.' at end --- networks/resize_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 39d4c9072..03fc545e7 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -219,7 +219,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn for key, value in tqdm(lora_sd.items()): weight_name = None if 'lora_down' in key: - block_down_name = key.rsplit('lora_down', 1)[0] + block_down_name = key.rsplit('.lora_down', 1)[0] weight_name = key.rsplit(".", 1)[-1] lora_down_weight = value else: @@ -227,8 +227,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn # find corresponding lora_up and alpha block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + 'lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + 'alpha', None) + lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + '.alpha', None) weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) @@ -263,9 +263,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn verbose_str+=f"\n" new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." 
+ "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) block_down_name = None block_up_name = None