Skip to content

Commit

Permalink
Merge pull request #830 from kohya-ss/dev2
Browse files Browse the repository at this point in the history
add extension checking for resize_lora.py
  • Loading branch information
kohya-ss authored Sep 24, 2023
2 parents 4c6f312 + f2491ee commit 54500b8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions networks/resize_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
Expand Down Expand Up @@ -283,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn


def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")


def str_to_dtype(p):
if p == 'float':
return torch.float
Expand Down

0 comments on commit 54500b8

Please sign in to comment.