diff --git a/scripts/lora_compvis.py b/scripts/lora_compvis.py index 7d58ac5..f8f940d 100644 --- a/scripts/lora_compvis.py +++ b/scripts/lora_compvis.py @@ -369,7 +369,9 @@ def merge_weights(weight, up_weight, down_weight): for t in ["q", "k", "v", "out"]: del state_dict[f"{lora_info.module_name}_{t}_proj.lora_down.weight"] del state_dict[f"{lora_info.module_name}_{t}_proj.lora_up.weight"] - del state_dict[f"{lora_info.module_name}_{t}_proj.alpha"] + alpha_key = f"{lora_info.module_name}_{t}_proj.alpha" + if alpha_key in state_dict: + del state_dict[alpha_key] else: # corresponding weight not exists: version mismatch pass