Skip to content

Commit

Permalink
Merge branch 'sd3' into sd3_5_support
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 26, 2024
2 parents 014064f + 8549669 commit 150579d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,14 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。


### Oct 26, 2024 / 2024-10-26:

- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
- It will be included in the next release.

- `svd_merge_lora.py``sdxl_merge_lora.py``resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
- 以上は次回リリースに含まれます。

### Sep 13, 2024 / 2024-09-13:

- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
Expand Down
15 changes: 8 additions & 7 deletions networks/resize_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype):
return sd, metadata


def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)

def save_to_file(file_name, state_dict, metadata):
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
Expand Down Expand Up @@ -349,12 +344,18 @@ def str_to_dtype(p):
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"

# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash

logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)


def setup_parser() -> argparse.ArgumentParser:
Expand Down
15 changes: 8 additions & 7 deletions networks/sdxl_merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype):
return sd, metadata


def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)

def save_to_file(file_name, model, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
Expand Down Expand Up @@ -430,6 +425,12 @@ def str_to_dtype(p):
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)

# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)

logger.info(f"calculating hashes and creating metadata...")

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
Expand All @@ -445,7 +446,7 @@ def str_to_dtype(p):
metadata.update(sai_metadata)

logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)


def setup_parser() -> argparse.ArgumentParser:
Expand Down
15 changes: 8 additions & 7 deletions networks/svd_merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,7 @@ def load_state_dict(file_name, dtype):
return sd, metadata


def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)

def save_to_file(file_name, state_dict, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
Expand Down Expand Up @@ -430,6 +425,12 @@ def str_to_dtype(p):
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
)

# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)

logger.info(f"calculating hashes and creating metadata...")

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
Expand All @@ -451,7 +452,7 @@ def str_to_dtype(p):
metadata.update(sai_metadata)

logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)


def setup_parser() -> argparse.ArgumentParser:
Expand Down

0 comments on commit 150579d

Please sign in to comment.