
Add EMA #893

Closed · wants to merge 16 commits
44 changes: 44 additions & 0 deletions fine_tune.py
@@ -37,6 +37,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
from library.train_util import EMAModel


def train(args):
@@ -245,6 +246,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

if args.enable_ema:
#ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float32
ema = EMAModel(trainable_params, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
ema.to(accelerator.device, dtype=weight_dtype)
ema = accelerator.prepare(ema)
else:
ema = None

# the accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -375,6 +384,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.enable_ema:
with torch.no_grad(), accelerator.autocast():
ema.step(trainable_params)
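
Per shadow parameter, `ema.step` applies the usual exponential-moving-average update in place; a minimal stand-alone illustration of the same rule (the toy tensors below are placeholders, not part of the patch):

import torch

d = 0.999                                   # current decay, as returned by EMAModel.get_decay
shadow = torch.tensor([1.0])                # one shadow (EMA) parameter
param = torch.tensor([0.0])                 # the corresponding live parameter
shadow.sub_((1.0 - d) * (shadow - param))   # what EMAModel.step does per parameter
# ...which is the same as shadow_new = d * shadow_old + (1 - d) * param
assert torch.allclose(shadow, torch.tensor([d * 1.0 + (1.0 - d) * 0.0]))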

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -429,6 +441,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
temp_name = None
if args.enable_ema and not args.ema_save_only_ema_weights and ((epoch + 1) % args.save_every_n_epochs == 0):
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
@@ -444,13 +459,34 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.unwrap_model(unet),
vae,
)
if args.enable_ema and ((epoch + 1) % args.save_every_n_epochs == 0):
args.output_name = temp_name if temp_name else args.output_name
with ema.ema_parameters(trainable_params):
print("Saving EMA:")
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

is_main_process = accelerator.is_main_process
if is_main_process:
unet = accelerator.unwrap_model(unet)
text_encoder = accelerator.unwrap_model(text_encoder)
if args.enable_ema:
ema = accelerator.unwrap_model(ema)

accelerator.end_training()

@@ -461,6 +497,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
if args.enable_ema and not args.ema_save_only_ema_weights:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae)
args.output_name = temp_name
if args.enable_ema:
print("Saving EMA:")
ema.copy_to(trainable_params)
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
…
22 changes: 22 additions & 0 deletions library/sdxl_train_util.py
@@ -277,6 +277,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
vae,
logit_scale,
ckpt_info,
ema=None,
params_to_replace=None,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
@@ -306,6 +308,10 @@ def diffusers_saver(out_dir):
save_dtype=save_dtype,
)

temp_name = None
if args.enable_ema and not args.ema_save_only_ema_weights and ema:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"

train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
@@ -318,6 +324,22 @@ def diffusers_saver(out_dir):
sd_saver,
diffusers_saver,
)
args.output_name = temp_name if temp_name else args.output_name
if args.enable_ema and ema:
with ema.ema_parameters(params_to_replace):
print("Saving EMA:")
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
save_stable_diffusion_format,
use_safetensors,
epoch,
num_train_epochs,
global_step,
sd_saver,
diffusers_saver,
)


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
…
144 changes: 144 additions & 0 deletions library/train_util.py
@@ -18,6 +18,7 @@
Sequence,
Tuple,
Union,
Iterable,
)
from accelerate import Accelerator, InitProcessGroupKwargs
import gc
@@ -68,6 +69,7 @@
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
import contextlib

# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -2292,6 +2294,135 @@ def load_text_encoder_outputs_from_disk(npz_path):

# endregion


# based mostly on https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py
class EMAModel:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay: float, beta=0, max_train_steps=10000):
parameters = self.get_params_list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.optimization_step = 0
self.collected_params = None
if beta < 0:
raise ValueError('ema_exp_beta must be >= 0')
self.beta = beta
self.max_train_steps = max_train_steps
print(f"len(self.shadow_params): {len(self.shadow_params)}")

def get_params_list(self, parameters: Iterable[torch.nn.Parameter]):
parameters = list(parameters)
if isinstance(parameters[0], dict):
params_list = []
for m in parameters:
params_list.extend(list(m["params"]))
return params_list
else:
return parameters

def get_decay(self, optimization_step: int) -> float:
"""
Get current decay for the exponential moving average.
"""
if self.beta == 0:
return min(self.decay, (1 + optimization_step) / (10 + optimization_step))
else:
# exponential schedule, scaled to max_train_steps
x = optimization_step / self.max_train_steps
return min(self.decay, self.decay * (1 - np.exp(-x * self.beta)))

def step(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Update currently maintained parameters.

Call this every time the parameters are updated, such as the result of
the `optimizer.step()` call.
"""
parameters = self.get_params_list(parameters)
one_minus_decay = 1.0 - self.get_decay(self.optimization_step)
self.optimization_step += 1
#print(f" {one_minus_decay}")
#with torch.no_grad():
for s_param, param in zip(self.shadow_params, parameters, strict=True):
tmp = (s_param - param)
#print(torch.sum(tmp))
# tmp will be a new tensor so we can do in-place
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)

def copy_to(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Copy current averaged parameters into given collection of parameters.
"""
parameters = self.get_params_list(parameters)
for s_param, param in zip(self.shadow_params, parameters, strict=True):
# print(f"diff: {torch.sum(s_param) - torch.sum(param)}")
param.data.copy_(s_param.data)

def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
"""
self.shadow_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.shadow_params
]
return

def store(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Save the current parameters for restoring later.
"""
parameters = self.get_params_list(parameters)
self.collected_params = [
param.clone()
for param in parameters
]

def restore(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
"""
if self.collected_params is None:
raise RuntimeError(
"This ExponentialMovingAverage has no `store()`ed weights "
"to `restore()`"
)
parameters = self.get_params_list(parameters)
for c_param, param in zip(self.collected_params, parameters, strict=True):
param.data.copy_(c_param.data)

@contextlib.contextmanager
def ema_parameters(self, parameters: Iterable[torch.nn.Parameter] = None):
r"""
Context manager for validation/inference with averaged parameters.

Equivalent to:
ema.store()
ema.copy_to()
try:
...
finally:
ema.restore()
"""
parameters = self.get_params_list(parameters)
self.store(parameters)
self.copy_to(parameters)
try:
yield
finally:
self.restore(parameters)
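
Taken together, the intended call pattern matches what fine_tune.py and sdxl_train.py do above; a minimal self-contained sketch with a toy model (the model, optimizer, and checkpoint path below are placeholders, not part of the patch):

import torch
from library.train_util import EMAModel

model = torch.nn.Linear(4, 4)                          # toy stand-in for the U-Net / text encoders
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model.parameters(), decay=0.999, beta=0, max_train_steps=1000)

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    with torch.no_grad():
        ema.step(model.parameters())                   # update the shadow weights after each optimizer step

# Temporarily swap the EMA weights into the model (e.g. for sampling or saving), then restore the live ones:
with ema.ema_parameters(model.parameters()):
    torch.save(model.state_dict(), "ema_checkpoint.pt")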


# region module replacement
"""
Module replacement for speed-up
@@ -2932,6 +3063,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# default=None,
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
# )
parser.add_argument(
"--enable_ema", action="store_true", help="Enable EMA (Exponential Moving Average) of model parameters / モデルパラメータのEMA(指数移動平均)を有効にする "
)
parser.add_argument(
"--ema_decay", type=float, default=0.999, help="Max EMA decay. Typical values: 0.999 - 0.9999 / 最大EMA減衰。標準的な値: 0.999 - 0.9999 "
)
parser.add_argument(
"--ema_exp_beta", type=float, default=15, help="Choose EMA decay schedule. By default: (1+x)/(10+x). If beta is set: use exponential schedule scaled to max_train_steps. If beta>0, recommended values are around 10-15 "
+ "/ EMAの減衰スケジュールを設定する。デフォルト:(1+x)/(10+x)。beta が設定されている場合: max_train_steps にスケーリングされた指数スケジュールを使用する。beta>0 の場合、推奨値は 10-15 程度。 "
)
parser.add_argument(
"--ema_save_only_ema_weights", action="store_true", help="By default both EMA and non-EMA weights are saved. If enabled, saves only EMA / デフォルトでは、EMAウェイトと非EMAウェイトの両方が保存される。有効にすると、EMAのみが保存される "
)
parser.add_argument(
"--multires_noise_discount",
type=float,
…
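
For reference, the two decay schedules selected by --ema_exp_beta correspond to EMAModel.get_decay above; a small sketch comparing them numerically (the decay cap, beta, and step counts are arbitrary examples, not part of the patch):

import numpy as np

def default_schedule(step, cap=0.999):
    # beta == 0: warm-up via (1 + step) / (10 + step), clipped at the cap
    return min(cap, (1 + step) / (10 + step))

def exponential_schedule(step, cap=0.999, beta=15, max_train_steps=10000):
    # beta > 0: cap * (1 - exp(-beta * step / max_train_steps)), clipped at the cap
    return min(cap, cap * (1 - np.exp(-beta * step / max_train_steps)))

for step in (10, 100, 1000, 5000, 10000):
    print(step, round(default_schedule(step), 4), round(exponential_schedule(step), 4))
# The default schedule climbs quickly and reaches the cap around step 9000,
# while the exponential schedule stays lower early on and only approaches the cap late in training.
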
40 changes: 40 additions & 0 deletions sdxl_train.py
@@ -40,6 +40,7 @@
apply_debiased_estimation,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.train_util import EMAModel


UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
@@ -394,6 +395,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

if args.enable_ema:
#ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float
ema = EMAModel(params_to_optimize, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
ema.to(accelerator.device, dtype=weight_dtype)
ema = accelerator.prepare(ema)
else:
ema = None
# the accelerator apparently takes care of this for us
if train_unet:
unet = accelerator.prepare(unet)
@@ -590,6 +598,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.enable_ema:
with torch.no_grad(), accelerator.autocast():
ema.step(params_to_optimize)

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -630,6 +641,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
logit_scale,
ckpt_info,
ema=ema,
params_to_replace=params_to_optimize,
)

current_loss = loss.detach().item()  # this is a mean, so batch size shouldn't matter
@@ -676,6 +689,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
logit_scale,
ckpt_info,
ema=ema,
params_to_replace=params_to_optimize,
)

sdxl_train_util.sample_images(
@@ -695,6 +710,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet = accelerator.unwrap_model(unet)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
text_encoder2 = accelerator.unwrap_model(text_encoder2)
if args.enable_ema:
ema = accelerator.unwrap_model(ema)

accelerator.end_training()

@@ -705,6 +722,29 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
if args.enable_ema and not args.ema_save_only_ema_weights:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
sdxl_train_util.save_sd_model_on_train_end(
args,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
global_step,
text_encoder1,
text_encoder2,
unet,
vae,
logit_scale,
ckpt_info,
)
args.output_name = temp_name
if args.enable_ema:
print("Saving EMA:")
ema.copy_to(params_to_optimize)

sdxl_train_util.save_sd_model_on_train_end(
args,
src_path,
…
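
In use, the feature is driven entirely by the new flags added to add_training_arguments; a hypothetical launch of one of the affected scripts from Python (the script path and the omitted dataset/model arguments are placeholders, not part of the patch):

import subprocess
import sys

cmd = [
    sys.executable, "sdxl_train.py",
    # ... the usual model / dataset / optimizer arguments go here ...
    "--enable_ema",                   # keep an EMA copy of the trained parameters
    "--ema_decay", "0.9995",          # cap on the EMA decay
    "--ema_exp_beta", "10",           # beta > 0 selects the exponential warm-up schedule
    "--ema_save_only_ema_weights",    # save only EMA checkpoints, skipping the "-non-EMA" ones
]
subprocess.run(cmd, check=True)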