Skip to content

Commit

Permalink
format by black, add ja comment
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 25, 2023
1 parent 9795840 commit 0fb9ecf
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions networks/extract_lora_from_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import lora


#CLAMP_QUANTILE = 0.99
#MIN_DIFF = 1e-1
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1


def save_to_file(file_name, model, state_dict, dtype):
Expand All @@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)


def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
Expand All @@ -39,9 +53,7 @@ def str_to_dtype(p):
return torch.bfloat16
return None

assert v2 != sdxl or (
not v2 and not sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if v_parameterization is None:
v_parameterization = v2

Expand Down Expand Up @@ -199,9 +211,7 @@ def str_to_dtype(p):

if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, v2, v_parameterization, sdxl, True, False, time.time(), title=title
)
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)

lora_network_save.save_weights(save_to, save_dtype, metadata)
Expand Down Expand Up @@ -242,7 +252,11 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
parser.add_argument(
Expand All @@ -256,13 +270,14 @@ def setup_parser() -> argparse.ArgumentParser:
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99",
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument(
"--no_metadata",
Expand Down

0 comments on commit 0fb9ecf

Please sign in to comment.