-
-
Notifications
You must be signed in to change notification settings - Fork 34
/
flux_extract_lora.py
221 lines (190 loc) · 7.46 KB
/
flux_extract_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Extract an approximating LoRA from two FLUX models via SVD of their weight difference.
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from .library import flux_utils, sai_model_spec
from .library.utils import MemoryEfficientSafeOpen
from .library.utils import setup_logging
from .networks import lora_flux
setup_logging()
import logging
logger = logging.getLogger(__name__)
from comfy.utils import ProgressBar
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, state_dict, metadata, dtype):
    """Write ``state_dict`` to a safetensors file, optionally casting tensors.

    Args:
        file_name: Destination path passed to ``safetensors.torch.save_file``.
        state_dict: Mapping of key -> value to save. It is NOT modified; when a
            cast is requested a converted copy is built instead.
        metadata: Metadata dict forwarded to ``save_file`` (safetensors expects
            string keys and string values).
        dtype: Target ``torch.dtype`` applied to every tensor value, or ``None``
            to save tensors in whatever dtype they already have.
    """
    if dtype is not None:
        # Build a new mapping instead of overwriting entries in the caller's dict;
        # non-tensor values are passed through untouched.
        state_dict = {
            key: value.to(dtype) if isinstance(value, torch.Tensor) else value
            for key, value in state_dict.items()
        }
    save_file(state_dict, file_name, metadata=metadata)
def _str_to_dtype(p):
    """Map a ``save_precision`` string to a torch dtype; anything else gives None."""
    return {"float": torch.float, "fp16": torch.float16, "bf16": torch.bfloat16}.get(p)


def _filter_flux_block_keys(keys):
    """Keep single/double block weight keys, dropping biases and norm parameters."""
    return [
        key
        for key in keys
        if ("single_block" in key or "double_block" in key)
        and ".bias" not in key
        and "norm" not in key
    ]


def _extract_lora_pair(mat, dim, clamp_quantile, store_device, save_dtype):
    """Factor a weight-difference matrix into a clamped (up, down) LoRA pair."""
    out_dim, in_dim = mat.size()[0:2]
    rank = min(dim, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
    mat = mat.squeeze()

    # Reduced SVD: only the leading `rank` factors are kept, so the full
    # (in_dim x in_dim) Vh produced by full_matrices=True would be wasted work/memory.
    U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
    up = U[:, :rank] @ torch.diag(S[:rank])
    down = Vh[:rank, :]

    # Clamp both factors symmetrically to +/- the `clamp_quantile` quantile of
    # their concatenated (signed) entries.
    dist = torch.cat([up.flatten(), down.flatten()])
    hi_val = torch.quantile(dist, clamp_quantile)
    low_val = -hi_val
    up = up.clamp(low_val, hi_val)
    down = down.clamp(low_val, hi_val)

    up = up.to(store_device, dtype=save_dtype).contiguous()
    down = down.to(store_device, dtype=save_dtype).contiguous()
    return up, down


def _build_lora_state_dict(lora_weights):
    """Convert {model_key: (up, down)} into LoRA state-dict entries with alpha = rank."""
    lora_sd = {}
    for key, (up_weight, down_weight) in lora_weights.items():
        lora_name = key.replace(".weight", "").replace(".", "_")
        lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])  # same as rank
    return lora_sd


def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    device=None,
    store_device='cpu',
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    mem_eff_safe_open=False,
):
    """Extract a LoRA approximating ``model_tuned - model_org`` and save it.

    Only weights whose key contains "single_block" or "double_block" are
    processed; keys containing ".bias" or "norm" are skipped. For each weight,
    the difference is computed in float32 and factored by truncated SVD into
    ``up = U[:, :r] @ diag(S[:r])`` and ``down = Vh[:r, :]``, where
    ``r = min(dim, out_dim, in_dim)``.

    Args:
        model_org: Path to the original model file.
        model_tuned: Path to the tuned model file; the LoRA encodes
            original -> tuned.
        save_to: Destination path of the LoRA file; also returned.
        dim: Maximum LoRA rank.
        device: Optional device the difference is moved to before the SVD.
        store_device: Device the extracted factors are kept on until saving.
        save_precision: "float", "fp16", "bf16", or None (keep computed dtype).
        clamp_quantile: Quantile used for the symmetric clamp of the factors.
        min_diff: Accepted for interface compatibility; currently unused.
        no_metadata: If True, skip the sai modelspec metadata; the minimal
            ``ss_*`` metadata is always written.
        mem_eff_safe_open: Use ``MemoryEfficientSafeOpen`` instead of
            ``safetensors.safe_open``.

    Returns:
        ``save_to``.
    """
    calc_dtype = torch.float
    save_dtype = _str_to_dtype(save_precision)

    # open models
    if not mem_eff_safe_open:
        # use original safetensors.safe_open
        open_fn = lambda fn: safe_open(fn, framework="pt")
    else:
        logger.info("Using memory efficient safe_open")
        open_fn = lambda fn: MemoryEfficientSafeOpen(fn)

    lora_weights = {}
    with open_fn(model_org) as fo:
        keys = _filter_flux_block_keys(fo.keys())
        comfy_pbar = ProgressBar(len(keys))
        with open_fn(model_tuned) as ft:
            for key in tqdm(keys):
                # get tensors and calculate difference
                value_o = fo.get_tensor(key)
                value_t = ft.get_tensor(key)
                mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
                del value_o, value_t

                if device:
                    mat = mat.to(device)

                up, down = _extract_lora_pair(mat, dim, clamp_quantile, store_device, save_dtype)
                print(f"key: {key}, U: {up.size()}, Vh: {down.size()}")
                comfy_pbar.update(1)
                lora_weights[key] = (up, down)
                del mat  # release the (possibly on-device) difference before the next key

    lora_sd = _build_lora_state_dict(lora_weights)

    # minimum metadata
    net_kwargs = {}
    metadata = {
        "ss_v2": str(False),
        "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
        "ss_network_module": "networks.lora_flux",
        "ss_network_dim": str(dim),
        "ss_network_alpha": str(float(dim)),
        "ss_network_args": json.dumps(net_kwargs),
    }
    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, metadata, save_dtype)
    logger.info(f"LoRA weights saved to {save_to}")
    return save_to
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Every destination name matches a keyword parameter of ``svd`` so the parsed
    namespace can be forwarded directly as ``svd(**vars(args))``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        # When omitted, svd() keeps the computed dtype, which is float32.
        help="precision in saving, float if omitted / 保存時に精度を変更して保存する、省略時はfloat",
    )
    parser.add_argument(
        "--model_org",
        type=str,
        default=None,
        required=True,
        help="Original model: safetensors file / 元モデル、safetensors",
    )
    parser.add_argument(
        "--model_tuned",
        type=str,
        default=None,
        required=True,
        # Only safetensors files can be opened (safe_open), so ckpt is not advertised.
        help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、safetensors",
    )
    parser.add_argument(
        "--mem_eff_safe_open",
        action="store_true",
        help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
        " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        required=True,
        help="destination file name: safetensors file / 保存先のファイル名、safetensors",
    )
    parser.add_argument(
        "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)"
    )
    parser.add_argument(
        "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
    )
    parser.add_argument(
        "--clamp_quantile",
        type=float,
        default=0.99,
        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
    )
    # NOTE: svd() accepts `min_diff` but never uses it, so no --min_diff flag is exposed.
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )
    return parser
if __name__ == "__main__":
    # CLI entry point: every parsed flag maps 1:1 onto a keyword argument of svd().
    cli_args = setup_parser().parse_args()
    svd(**vars(cli_args))