From 0fb9ecf1f39301d5362b7a5d143eabc949604128 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 25 Nov 2023 21:05:55 +0900
Subject: [PATCH] format by black, add ja comment

---
 networks/extract_lora_from_models.py | 39 +++++++++++++++++++---------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index 6c45d9770..6357df55d 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -13,8 +13,8 @@
 
 import lora
 
-#CLAMP_QUANTILE = 0.99
-#MIN_DIFF = 1e-1
+# CLAMP_QUANTILE = 0.99
+# MIN_DIFF = 1e-1
 
 
 def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
         torch.save(model, file_name)
 
 
-def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
+def svd(
+    model_org=None,
+    model_tuned=None,
+    save_to=None,
+    dim=4,
+    v2=None,
+    sdxl=None,
+    conv_dim=None,
+    v_parameterization=None,
+    device=None,
+    save_precision=None,
+    clamp_quantile=0.99,
+    min_diff=0.01,
+    no_metadata=False,
+):
     def str_to_dtype(p):
         if p == "float":
             return torch.float
@@ -39,9 +53,7 @@ def str_to_dtype(p):
             return torch.bfloat16
         return None
 
-    assert v2 != sdxl or (
-        not v2 and not sdxl
-    ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
+    assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
 
     if v_parameterization is None:
         v_parameterization = v2
@@ -199,9 +211,7 @@ def str_to_dtype(p):
 
     if not no_metadata:
         title = os.path.splitext(os.path.basename(save_to))[0]
-        sai_metadata = sai_model_spec.build_metadata(
-            None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
-        )
+        sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
         metadata.update(sai_metadata)
 
     lora_network_save.save_weights(save_to, save_dtype, metadata)
@@ -242,7 +252,11 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
     )
     parser.add_argument(
-        "--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+        "--save_to",
+        type=str,
+        default=None,
+        required=True,
+        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
     )
     parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
     parser.add_argument(
@@ -256,13 +270,14 @@ def setup_parser() -> argparse.ArgumentParser:
         "--clamp_quantile",
         type=float,
         default=0.99,
-        help="Quantile clamping value, float, (0-1). Default = 0.99",
+        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
    )
     parser.add_argument(
         "--min_diff",
         type=float,
         default=0.01,
-        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
+        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+        + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
     )
     parser.add_argument(
         "--no_metadata",