diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 3bec0a6b9..581931645 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -998,21 +998,25 @@ def __init__( # FreeU self.freeU = False - self.freeUSl = 0.5 - self.freeURThres = 0.5 - self.freeUBl = 0.5 + self.freeUB1 = 1.0 + self.freeUB2 = 1.0 + self.freeUS1 = 1.0 + self.freeUS2 = 1.0 + self.freeURThres = 1 # implementation of FreeU # FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497 - def set_free_u_enabled(self, enabled: bool, bl=0.5, sl=0.5, rthresh=0.5): - print(f"FreeU: {enabled}, bl={bl}, sl={sl}, rthresh={rthresh}") + def set_free_u_enabled(self, enabled: bool, b1=1.0, b2=1.0, s1=1.0, s2=1.0, rthresh=1): + print(f"FreeU: {enabled}, b1={b1}, b2={b2}, s1={s1}, s2={s2}, rthresh={rthresh}") self.freeU = enabled - self.freeUSl = sl + self.freeUB1 = b1 + self.freeUB2 = b2 + self.freeUS1 = s1 + self.freeUS2 = s2 self.freeURThres = rthresh - self.freeUBl = bl - def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5): + def spectral_modulation(self, skip_feature, sl=1.0, rthresh=1): """ スキップ特徴を周波数領域で修正する関数 @@ -1024,6 +1028,9 @@ def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5): import torch.fft + r""" + # 論文に従った実装 + org_dtype = skip_feature.dtype if org_dtype == torch.bfloat16: skip_feature = skip_feature.to(torch.float32) @@ -1060,9 +1067,35 @@ def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5): modified_skip_feature = torch.fft.ifftn(F_prime, dim=(2, 3)) modified_skip_feature = modified_skip_feature.real # 実部のみを取得 + """ - if org_dtype == torch.bfloat16: - modified_skip_feature = modified_skip_feature.to(org_dtype) + # 公式リポジトリの実装 + + org_dtype = skip_feature.dtype + + x = skip_feature + threshold = rthresh + scale = sl + + # FFT + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + + modified_skip_feature = x_filtered + + # if org_dtype == torch.bfloat16: + modified_skip_feature = modified_skip_feature.to(org_dtype) return modified_skip_feature @@ -1151,7 +1184,12 @@ def call_module(module, h, emb, context): h = call_module(module, h, emb, context) if self.freeU: - h_mod = self.spectral_modulation(h, self.freeUSl, self.freeURThres) + ch = h.shape[1] + s = self.freeUS1 if ch == 1280 else (self.freeUS2 if ch == 640 else 1.0) + if s == 1.0: + h_mod = h + else: + h_mod = self.spectral_modulation(h, s, self.freeURThres) hs.append(h_mod) else: hs.append(h) @@ -1161,7 +1199,12 @@ def call_module(module, h, emb, context): for module in self.output_blocks: if self.freeU: ch = h.shape[1] - h[:, : ch // 2] = h[:, : ch // 2] * self.freeUBl + if ch == 1280: + h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB1 + elif ch == 640: + h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB2 + # else: + # print(f"disable freeU: {ch}") h = torch.cat([h, hs.pop()], dim=1) h = call_module(module, h, emb, context) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 2e465e223..f5a09aa5a 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1523,7 +1523,7 @@ def __getattr__(self, item): # freeU # unet.set_free_u_enabled(False, 1.0, 1.0, 0) - unet.set_free_u_enabled(True, 1.4, 1.0, 10) + unet.set_free_u_enabled(True, 1.1, 1.2, 0.9, 0.2) # networkを組み込む if args.network_module: