Skip to content

Commit

Permalink
Support concat LoRA
Browse files Browse the repository at this point in the history
  • Loading branch information
laksjdjf authored Sep 28, 2023
1 parent 1e395ed commit 14aa292
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
39 changes: 34 additions & 5 deletions networks/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
module.weight = torch.nn.Parameter(weight)


def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}

Expand Down Expand Up @@ -158,26 +158,43 @@ def merge_lora_models(models, ratios, merge_dtype):
for key in lora_sd.keys():
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None

lora_module_name = key[: key.rfind(".lora_")]

base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]

scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。

if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale

# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]

print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
Expand Down Expand Up @@ -256,7 +273,7 @@ def str_to_dtype(p):
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

print(f"calculating hashes and creating metadata...")

Expand Down Expand Up @@ -317,7 +334,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
)

parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)

return parser


Expand Down
40 changes: 35 additions & 5 deletions networks/sdxl_merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
module.weight = torch.nn.Parameter(weight)


def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}

Expand Down Expand Up @@ -161,26 +161,44 @@ def merge_lora_models(models, ratios, merge_dtype):
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue

if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None

lora_module_name = key[: key.rfind(".lora_")]

base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]

scale = math.sqrt(alpha / base_alpha) * ratio

scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。

if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale

# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]

print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
Expand Down Expand Up @@ -252,7 +270,7 @@ def str_to_dtype(p):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

print(f"calculating hashes and creating metadata...")

Expand Down Expand Up @@ -307,6 +325,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)

return parser

Expand Down

0 comments on commit 14aa292

Please sign in to comment.