Skip to content

Commit

Permalink
update as official impl
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Sep 23, 2023
1 parent 40525d4 commit 2bdcd9b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
67 changes: 55 additions & 12 deletions library/sdxl_original_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,21 +998,25 @@ def __init__(

# FreeU
self.freeU = False
self.freeUSl = 0.5
self.freeURThres = 0.5
self.freeUBl = 0.5
self.freeUB1 = 1.0
self.freeUB2 = 1.0
self.freeUS1 = 1.0
self.freeUS2 = 1.0
self.freeURThres = 1

# implementation of FreeU
# FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497

def set_free_u_enabled(self, enabled: bool, bl=0.5, sl=0.5, rthresh=0.5):
print(f"FreeU: {enabled}, bl={bl}, sl={sl}, rthresh={rthresh}")
def set_free_u_enabled(self, enabled: bool, b1=1.0, b2=1.0, s1=1.0, s2=1.0, rthresh=1):
print(f"FreeU: {enabled}, b1={b1}, b2={b2}, s1={s1}, s2={s2}, rthresh={rthresh}")
self.freeU = enabled
self.freeUSl = sl
self.freeUB1 = b1
self.freeUB2 = b2
self.freeUS1 = s1
self.freeUS2 = s2
self.freeURThres = rthresh
self.freeUBl = bl

def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5):
def spectral_modulation(self, skip_feature, sl=1.0, rthresh=1):
"""
スキップ特徴を周波数領域で修正する関数
Expand All @@ -1024,6 +1028,9 @@ def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5):

import torch.fft

r"""
# 論文に従った実装
org_dtype = skip_feature.dtype
if org_dtype == torch.bfloat16:
skip_feature = skip_feature.to(torch.float32)
Expand Down Expand Up @@ -1060,9 +1067,35 @@ def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5):
modified_skip_feature = torch.fft.ifftn(F_prime, dim=(2, 3))
modified_skip_feature = modified_skip_feature.real # 実部のみを取得
"""

if org_dtype == torch.bfloat16:
modified_skip_feature = modified_skip_feature.to(org_dtype)
# 公式リポジトリの実装

org_dtype = skip_feature.dtype

x = skip_feature
threshold = rthresh
scale = sl

# FFT
x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W), device=x.device)

crow, ccol = H // 2, W // 2
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
x_freq = x_freq * mask

# IFFT
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

modified_skip_feature = x_filtered

# if org_dtype == torch.bfloat16:
modified_skip_feature = modified_skip_feature.to(org_dtype)

return modified_skip_feature

Expand Down Expand Up @@ -1151,7 +1184,12 @@ def call_module(module, h, emb, context):
h = call_module(module, h, emb, context)

if self.freeU:
h_mod = self.spectral_modulation(h, self.freeUSl, self.freeURThres)
ch = h.shape[1]
s = self.freeUS1 if ch == 1280 else (self.freeUS2 if ch == 640 else 1.0)
if s == 1.0:
h_mod = h
else:
h_mod = self.spectral_modulation(h, s, self.freeURThres)
hs.append(h_mod)
else:
hs.append(h)
Expand All @@ -1161,7 +1199,12 @@ def call_module(module, h, emb, context):
for module in self.output_blocks:
if self.freeU:
ch = h.shape[1]
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUBl
if ch == 1280:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB1
elif ch == 640:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB2
# else:
# print(f"disable freeU: {ch}")

h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
Expand Down
2 changes: 1 addition & 1 deletion sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ def __getattr__(self, item):

# freeU
# unet.set_free_u_enabled(False, 1.0, 1.0, 0)
unet.set_free_u_enabled(True, 1.4, 1.0, 10)
unet.set_free_u_enabled(True, 1.1, 1.2, 0.9, 0.2)

# networkを組み込む
if args.network_module:
Expand Down

0 comments on commit 2bdcd9b

Please sign in to comment.