diff --git a/flux_train.py b/flux_train.py
index d2a9b3f32..56b057085 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -310,7 +310,7 @@ def train(args):
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")

     else:
-        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize, model=flux)

     # prepare dataloader
     # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
diff --git a/library/adam_mini.py b/library/adam_mini.py
new file mode 100644
index 000000000..7cb63804b
--- /dev/null
+++ b/library/adam_mini.py
@@ -0,0 +1,259 @@
+import math
+from typing import Iterable, Tuple, Union, Optional
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.distributed._tensor import Replicate
+
+
+class Adam_mini(torch.optim.Optimizer):
+    def __init__(
+        self,
+        named_parameters: Iterable[Tuple[str, nn.Parameter]],
+        lr: Union[float, torch.Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        *,
+        model_sharding: bool = True,
+        dim: int = 2048,
+        n_heads: int = 32,
+        n_kv_heads: Optional[int] = None,
+    ):
+        '''
+        named_parameters: model.named_parameters()
+
+        lr: learning rate
+
+        betas: same betas as Adam
+
+        eps: same eps as Adam
+
+        weight_decay: weight_decay coefficient
+
+        model_sharding: set to True if you are using model parallelism with more than 1 GPU, including FSDP and ZeRO-1/2/3 in DeepSpeed. Set to False otherwise.
+
+        dim: dimension of the hidden features. Can be left unspecified if you are training non-transformer models.
+
+        n_heads: number of attention heads. Can be left unspecified if you are training non-transformer models.
+
+        n_kv_heads: number of heads for Key and Value, or equivalently the number of query groups in Grouped Query Attention (also known as "n_query_groups"). If not specified, it defaults to n_heads. Can be left unspecified if you are training non-transformer models.
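+
+        Usage sketch (illustrative only; `model`, `hidden_dim` and `num_heads` below are
+        placeholders for your own module and its transformer dimensions, not part of this API):
+
+            optimizer = Adam_mini(model.named_parameters(), lr=1e-3, betas=(0.9, 0.999),
+                                  weight_decay=0.1, dim=hidden_dim, n_heads=num_heads)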
+        '''
+        self.dim = dim
+        self.n_heads = n_heads
+        if n_kv_heads is not None:
+            assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}"
+            self.n_kv_heads = n_kv_heads
+        else:
+            self.n_kv_heads = n_heads
+
+        self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+        self.model_sharding = model_sharding
+        if self.model_sharding:
+            print("=====>>> Adam-mini is using model_sharding")
+
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not self.dim == int(self.dim):
+            raise ValueError("Invalid dim value: {}".format(self.dim))
+        if not self.n_heads == int(self.n_heads):
+            raise ValueError("Invalid n_heads value: {}".format(self.n_heads))
+        if not self.n_kv_heads == int(self.n_kv_heads):
+            raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads))
+
+        optim_groups = []
+        count_embd = count_output = count_wq = count_wk = 0
+        for param_name, param in named_parameters:
+            if not param.requires_grad:
+                continue
+            print('Adam-mini found the param block with name:', param_name)
+            state = {}
+            state["name"] = param_name
+            state["params"] = param
+            if "norm" in param_name or "ln_f" in param_name:
+                state["weight_decay"] = 0.0
+            else:
+                state["weight_decay"] = weight_decay
+            if "embed" in param_name or "wte" in param_name or "embd" in param_name or "layer.weight" in param_name:
+                count_embd += 1
+            if "lm_head.weight" in param_name or "output.weight" in param_name or "final_layer.linear.weight" in param_name or "final_layer.adaLN_modulation.1.weight" in param_name:
+                count_output += 1
+            if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name or "attn.qkv.weight" in param_name:
+                count_wq += 1
+                assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
+                state["head_numel"] = self.dim * self.dim // self.n_heads
+            if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name or "attn.proj.weight" in param_name:
+                count_wk += 1
+                assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
+                state["head_numel"] = self.dim * self.dim // self.n_heads
+            optim_groups.append(state)
+
+        print(
+            f'Adam-mini found {count_embd} embedding layers, {count_output} output layers, {count_wq} Queries, {count_wk} Keys.')
+
+        if count_embd == 0:
+            # warning
+            print(
+                "=====>>> Warning by Adam-mini: No embedding layer found. If you are training Transformers, please check the name of your embedding layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your embedding layer'). ")
+        if count_output == 0:
+            # warning
+            print(
+                "=====>>> Warning by Adam-mini: No output layer found. If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your output layer'). Please ignore this warning if you are using weight-tying.")
+        if count_wq == 0:
+            # warning
+            print(
+                "=====>>> Warning by Adam-mini: No Query found. If you are training Transformers, please check the name of your Query in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Query'). ")
+
+        if count_wk == 0:
+            # warning
+            print(
+                "=====>>> Warning by Adam-mini: No Key found. If you are training Transformers, please check the name of your Key in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Key').")
+
+        if count_output + count_embd + count_wq + count_wk == 0:
+            print(
+                "=====>>> Warning by Adam-mini: you are using the default PyTorch partition for Adam-mini. It can cause training instability on large-scale Transformers.")
+
+        # embd_blocks, including embd and output layers. Use normal AdamW updates for these blocks
+        self.embd_names = {"embed", "embd", "wte", "lm_head.weight", "output.weight"}
+        # Query and Keys, will assign lrs by heads
+        self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"}
+
+        defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps)
+        super().__init__(optim_groups, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+        for group in self.param_groups:
+            beta1 = group["beta1"]
+            beta2 = group["beta2"]
+            lr = group["lr"]
+            name = group["name"]
+            eps = group["eps"]
+
+            for p in group["params"]:
+
+                state = self.state[p]
+                if any(embd_name in name for embd_name in self.embd_names):  # this is for embedding and output layer
+                    if p.grad is None:
+                        continue
+                    if len(state) == 0:
+                        state["m"] = torch.zeros_like(p, dtype=torch.float32)
+                        state["step"] = 0
+                        state["v"] = torch.zeros_like(p, dtype=torch.float32)
+
+                    grad = p.grad.to(torch.float32)
+                    state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+                    state["step"] += 1
+                    if group["weight_decay"] > 0.0:
+                        p.mul_(1 - lr * group["weight_decay"])
+                    state["m"].lerp_(grad, 1 - beta1)
+                    bias_correction_1 = 1 - beta1 ** state["step"]
+                    bias_correction_2 = 1 - beta2 ** state["step"]
+                    bias_correction_2_sqrt = math.sqrt(bias_correction_2)
+                    h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps)
+                    stepsize = lr / bias_correction_1
+                    p.addcdiv_(state["m"], h, value=-stepsize)
+                elif any(wqk_name in name for wqk_name in self.wqk_names):  # this is for query and key
+                    if p.grad is None:
+                        continue
+                    head_numel = group["head_numel"]
+                    if len(state) == 0:
+                        m = torch.zeros_like(p, dtype=torch.float32)
+                        state["m"] = m.view(-1, head_numel)
+                        state["head"] = state["m"].size(0)
+                        state["step"] = 0
+                        # NOTE: We must use `zeros_like` for vmean to be a
+                        # DTensor (not `torch.Tensor`) for DTensor parameters.
+ # state["vmean"] = torch.zeros(state["head"]) + state["vmean"] = torch.zeros_like(state["m"][0:state["head"], 0:1]) + + grad = p.grad.to(torch.float32) + head = state["head"] + grad = grad.view(head, head_numel) + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(head, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + else: # other blocks + if len(state) == 0: + block_numel = torch.tensor(p.numel(), dtype=torch.float32, device=p.device) + reduced = False + if (self.world_size > 1) and (self.model_sharding is True): + tensor_list = [torch.zeros_like(block_numel) for _ in range(self.world_size)] + + dist.all_gather(tensor_list, block_numel) + s = 0 + block_numel = 0 + for d in tensor_list: + if (d > 0): + s = s + 1 + block_numel = block_numel + d + if (s >= 2): + reduced = True + + state["m"] = torch.zeros_like(p, dtype=torch.float32) + state["step"] = 0 + state["reduced"] = reduced + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # state["vmean"] = torch.zeros(1, device=p.device) + # state["vmean"] = p.new_zeros(1) + state["vmean"] = torch.zeros_like(torch.sum(p * p)) + state["block_numel"] = block_numel.item() + if p.grad is None: + tmp_lr = torch.zeros_like(torch.sum(p * p)) + else: + grad = p.grad.to(torch.float32) + tmp_lr = torch.sum(grad * grad) + + if (state["reduced"]): + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + # print(f"...... 
dist all reduce.......") + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + if (p.grad is None): + continue + tmp_lr = tmp_lr / state["block_numel"] + + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["step"] += 1 + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = (1 / bias_correction_1) / h + update = state["m"] * (stepsize.to(state["m"].device)) + update.mul_(lr) + p.add_(-update) + return loss \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index f4ac8740a..a2ac73d63 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3250,6 +3250,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -4272,7 +4278,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params, model=None): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type @@ -4545,6 +4551,26 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "AdamMini".lower(): + logger.info(f"use AdamMini optimizer | {optimizer_kwargs}") + try: + import library.adam_mini as adam_mini + optimizer_class = adam_mini.Adam_mini + except ImportError: + raise ImportError("No adam-mini / adam-mini がインストールされていないようです") + + named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad] + + optimizer_kwargs["dim"] = 722 + optimizer_kwargs["n_heads"] = 19 + + optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs) + optimizer.embd_names.add("layer.weight") + optimizer.embd_names.add("final_layer.linear.weight") + optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight") + optimizer.wqk_names.add("attn") + optimizer.wqk_names.add('mlp') + if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) @@ -5762,6 +5788,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/requirements.txt b/requirements.txt index 4ee19b3ee..d86b8da83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 
@@ pytorch-lightning==1.9.0 bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 +adam-mini==1.0.1 tensorboard safetensors==0.4.2 # gradio==3.16.2 diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a118..ce9500b0b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) diff --git a/train_network.py b/train_network.py index 086b314a5..303e2d3a4 100644 --- a/train_network.py +++ b/train_network.py @@ -73,7 +73,7 @@ def generate_step_logs( lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): - if lr_descriptions is not None: + if lr_descriptions is not None and i < len(lr_descriptions): lr_desc = lr_descriptions[i] else: idx = i - (0 if args.network_train_unet_only else -1) @@ -471,7 +471,7 @@ def train(self, args): # v = len(v) # accelerator.print(f"trainable_params: {k} = {v}") - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params, network) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset