From 56b4ea963ee745b73be1ad53c602854e3ff7ed16 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 26 Oct 2024 22:01:10 +0900 Subject: [PATCH] Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722 --- README.md | 8 ++++++++ networks/resize_lora.py | 15 ++++++++------- networks/sdxl_merge_lora.py | 15 ++++++++------- networks/svd_merge_lora.py | 15 ++++++++------- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 0be2f9a70..7f8508dc0 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Oct 26, 2024 / 2024-10-26: + +- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue. +- It will be included in the next release. + +- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。 +- 以上は次回リリースに含まれます。 + ### Sep 13, 2024 / 2024-09-13: - `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). diff --git a/networks/resize_lora.py b/networks/resize_lora.py index d697baa4c..7df7ef0cc 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if model_util.is_safetensors(file_name): save_file(state_dict, file_name, metadata) else: @@ -349,12 +344,18 @@ def str_to_dtype(p): metadata["ss_network_dim"] = "Dynamic" metadata["ss_network_alpha"] = "Dynamic" + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 62f5a87d4..b147eb446 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, model, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(model, file_name, metadata=metadata) else: @@ -430,6 +425,12 @@ def str_to_dtype(p): else: state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) @@ -445,7 +446,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index b4b9e3bfd..c520e7f89 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -216,12 +216,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(state_dict, file_name, metadata=metadata) else: @@ -430,6 +425,12 @@ def str_to_dtype(p): args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) @@ -451,7 +452,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: