From 7f3c7c533d50af983e2a79b740a5afbcf845206a Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 26 Jun 2024 19:25:06 +0800 Subject: [PATCH 01/12] Update sd3_train.py --- sd3_train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 0721b2ae4..1291e9e9d 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,6 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" - assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] From 8b1653548f8f219e5be2cde96f65a8813cf9ea1f Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 29 Jun 2024 15:32:54 +0800 Subject: [PATCH 02/12] add freeze block lr --- library/train_util.py | 20 ++++++++++++++++++++ sd3_train.py | 4 ++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 96d32e3bc..d2ab95a05 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3061,6 +3061,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_layers_to_freeze", + type=int, + default=None, + help="num_last_layers_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5598,6 +5604,20 @@ def sample_image_inference( pass +def freeze_blocks_lr(model, num_last_layers_to_freeze, base_lr, block_name="x_block"): + bottom_layers = list(model.children())[-num_last_layers_to_freeze:] + + params_to_optimize = [] + + for layer in reversed(bottom_layers): + for name, param in layer.named_parameters(): + if block_name in name: + params_to_optimize.append({"params": [param], "lr": 0.0}) + else: + params_to_optimize.append({"params": [param], "lr": base_lr}) + + return params_to_optimize + # endregion diff --git a/sd3_train.py b/sd3_train.py index 9a7de2393..8bff476a6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,7 +60,6 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" - # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -283,7 +282,8 @@ def train(args): # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize = train_util.freeze_blocks_lr(mmdit, args.num_last_layers_to_freeze,args.args.learning_rate) + # params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From ad9f9c163482a08d10b11e47e46f3b9771ff0a7e Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 29 Jun 2024 15:55:51 +0800 Subject: [PATCH 03/12] Update train_util.py --- library/train_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d2ab95a05..b448b1382 100644 --- 
a/library/train_util.py +++ b/library/train_util.py @@ -5605,14 +5605,14 @@ def sample_image_inference( def freeze_blocks_lr(model, num_last_layers_to_freeze, base_lr, block_name="x_block"): - bottom_layers = list(model.children())[-num_last_layers_to_freeze:] - params_to_optimize = [] + frozen_params_count = 0 - for layer in reversed(bottom_layers): - for name, param in layer.named_parameters(): - if block_name in name: + for module in reversed(model.children()): + for name, param in module.named_parameters(): + if block_name in name and frozen_params_count < num_last_layers_to_freeze: params_to_optimize.append({"params": [param], "lr": 0.0}) + frozen_params_count += 1 else: params_to_optimize.append({"params": [param], "lr": base_lr}) From 110205fb4f2d004b8879eea4e6af1fe3555c2bfb Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 29 Jun 2024 17:34:04 +0800 Subject: [PATCH 04/12] update --- library/train_util.py | 27 ++++++++++++++------------- sd3_train.py | 6 ++++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b448b1382..3e127f74f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3062,10 +3062,10 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) parser.add_argument( - "--num_last_layers_to_freeze", + "--num_last_block_to_freeze", type=int, default=None, - help="num_last_layers_to_freeze", + help="num_last_block_to_freeze", ) @@ -5604,19 +5604,20 @@ def sample_image_inference( pass -def freeze_blocks_lr(model, num_last_layers_to_freeze, base_lr, block_name="x_block"): - params_to_optimize = [] - frozen_params_count = 0 +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): - for module in reversed(model.children()): - for name, param in module.named_parameters(): - if block_name in name and frozen_params_count < num_last_layers_to_freeze: - params_to_optimize.append({"params": [param], "lr": 0.0}) - frozen_params_count += 1 - else: - params_to_optimize.append({"params": [param], "lr": base_lr}) + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) - return params_to_optimize + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False # endregion diff --git a/sd3_train.py b/sd3_train.py index 8bff476a6..3f9f9689e 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -277,13 +277,15 @@ def train(args): if not train_mmdit: mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize = train_util.freeze_blocks_lr(mmdit, args.num_last_layers_to_freeze,args.args.learning_rate) - # params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # 
params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From 1449a3bc04b6cd4eccb20cbf8a04062607fb4037 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 12:18:56 +0800 Subject: [PATCH 05/12] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 7ba073855..f396f6427 100644 --- a/train_network.py +++ b/train_network.py @@ -72,7 +72,7 @@ def generate_step_logs( lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): - if lr_descriptions is not None: + if lr_descriptions is not None and i < len(lr_descriptions): lr_desc = lr_descriptions[i] else: idx = i - (0 if args.network_train_unet_only else -1) @@ -364,7 +364,7 @@ def train(self, args): # v = len(v) # accelerator.print(f"trainable_params: {k} = {v}") - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params, network) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 From c4478b8dc154cfeac426123be4395806d072d587 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Thu, 18 Jul 2024 00:11:40 +0800 Subject: [PATCH 06/12] Update train_util.py --- library/train_util.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 760be33eb..b1c5cd881 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3990,7 +3990,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params, model=None): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type @@ -4263,6 +4263,15 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "AdamMini".lower(): + logger.info(f"use AdamMini optimizer | {optimizer_kwargs}") + try: + import pytorch_optimizer + optimizer_class = pytorch_optimizer.AdamMini + except ImportError: + raise ImportError("No adam-mini / adam-mini がインストールされていないようです") + optimizer = optimizer_class(model, lr=lr, **optimizer_kwargs) + if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) From bdc630ab60c77be205ca22401166da2d4f286390 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Tue, 13 Aug 2024 23:53:55 +0800 Subject: [PATCH 07/12] Update train_util.py --- library/train_util.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b1c5cd881..72c0e50bc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4266,11 +4266,19 @@ def get_optimizer(args, trainable_params, model=None): elif optimizer_type == "AdamMini".lower(): logger.info(f"use AdamMini optimizer | {optimizer_kwargs}") try: - import pytorch_optimizer - optimizer_class = pytorch_optimizer.AdamMini + import adam_mini + optimizer_class = adam_mini.Adam_mini except ImportError: raise ImportError("No adam-mini / adam-mini がインストールされていないようです") - optimizer = optimizer_class(model, lr=lr, **optimizer_kwargs) + + # 
trainable_params → named_parameters + named_params = [(f"{model}.{name}", param) for name, param in model.named_parameters() if param in trainable_params] + + optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs) + optimizer.embd_names.add("to_out") + optimizer.wqk_names.add("to_q") + optimizer.wqk_names.add('to_k') + optimizer.wqk_names.add('to_v') if optimizer is None: # 任意のoptimizerを使う From c07697c1ffc0769553263c414bacd53ebfdd2509 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Tue, 13 Aug 2024 23:54:47 +0800 Subject: [PATCH 08/12] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index e99775b8a..0ce2f7471 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.43.0 prodigyopt==1.0 lion-pytorch==0.0.6 +adam-mini==1.0.1 tensorboard safetensors==0.4.2 # gradio==3.16.2 From 3dc8ec38043231c6eefd6fed9b7745c1d7944327 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 14 Aug 2024 00:54:07 +0800 Subject: [PATCH 09/12] update bnb for 124 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ce2f7471..8edc0429f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 adam-mini==1.0.1 From ee9ef2040ad0d54adcceb89f86ebe0c7032ebb00 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 17 Aug 2024 17:52:06 +0800 Subject: [PATCH 10/12] update adam-mini --- library/adam_mini.py | 259 ++++++++++++++++++++++++++++++++++++++++++ library/train_util.py | 16 +-- 2 files changed, 265 insertions(+), 10 deletions(-) create mode 100644 library/adam_mini.py diff --git a/library/adam_mini.py b/library/adam_mini.py new file mode 100644 index 000000000..d521bbfba --- /dev/null +++ b/library/adam_mini.py @@ -0,0 +1,259 @@ +import math +from typing import Iterable, Tuple, Union, Optional + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.distributed._tensor import Replicate + + +class Adam_mini(torch.optim.Optimizer): + def __init__( + self, + named_parameters: Iterable[Tuple[str, nn.Parameter]], + lr: Union[float, torch.Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + *, + model_sharding: bool = True, + dim: int = 2048, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + ): + ''' + named_parameters: model.named_parameters() + + lr: learning rate + + betas: same betas as Adam + + eps: same eps as Adam + + weight_decay: weight_decay coefficient + + model_sharding: set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1,2,3 in Deepspeed. Set to False if otherwise. + + dim: dimension for hidden feature. Could be unspecified if you are training non-transformer models. + + n_heads: number of attention heads. Could be unspecified if you are training non-transformer models. + + n_kv_heads: number of head for Key and Value. Or equivalently, number of query groups in Group query Attention. Also known as "n_query_groups". If not specified, it will be equal to n_head. Could be unspecified if you are training non-transformer models. 
+ ''' + self.dim = dim + self.n_heads = n_heads + if n_kv_heads is not None: + assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}" + self.n_kv_heads = n_kv_heads + else: + self.n_kv_heads = n_heads + + self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + self.model_sharding = model_sharding + if self.model_sharding: + print("=====>>> Adam-mini is using model_sharding") + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not self.dim == int(self.dim): + raise ValueError("Invalid dim value: {}".format(self.dim)) + if not self.n_heads == int(self.n_heads): + raise ValueError("Invalid n_heads value: {}".format(self.n_heads)) + if not self.n_kv_heads == int(self.n_kv_heads): + raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads)) + + optim_groups = [] + count_embd = count_output = count_wq = count_wk = 0 + for param_name, param in named_parameters: + if not param.requires_grad: + continue + print('Adam-mini found the param block with name:', param_name) + state = {} + state["name"] = param_name + state["params"] = param + if "norm" in param_name or "ln_f" in param_name: + state["weight_decay"] = 0.0 + else: + state["weight_decay"] = weight_decay + if "embed" in param_name or "wte" in param_name or "embd" in param_name: + count_embd += 1 + if "lm_head.weight" in param_name or "output.weight" in param_name: + count_output += 1 + if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name: + count_wq += 1 + assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}" + state["head_numel"] = self.dim * self.dim // self.n_heads + if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name: + count_wk += 1 + assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}" + state["head_numel"] = self.dim * self.dim // self.n_heads + optim_groups.append(state) + + print( + f'Adam-mini found {count_embd} embedding layers, {count_output} output layers, {count_wq} Querys, {count_wk} Keys.') + + if count_embd == 0: + # warning + print( + "=====>>> Warning by Adam-mini: No embedding layer found. If you are training Transformers, please check the name of your embedding layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your embedding layer'). ") + if count_output == 0: + # warning + print( + "=====>>> Warning by Adam-mini: No output layer found. If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your output layer'). Please ignore this warning if you are using weight-tying.") + if count_wq == 0: + # warning + print( + "=====>>> Warning by Adam-mini: No Query found. 
If you are training Transformers, please check the name of your Query in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Query'). ") + + if count_wk == 0: + # warning + print( + "=====>>> Warning by Adam-mini: No Key found. If you are training Transformers, please check the name of your Key in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Key').") + + if count_output + count_embd + count_wq + count_wk == 0: + print( + "=====>>> Warning by Adam-mini: you are using default PyTorch partition for Adam-mini. It can cause training instability on large-scale Transformers.") + + # embd_blocks, including embd and output layers. Use normal adamW updates for these blocks + self.embd_names = {"embed", "embd", "wte", "lm_head.weight", "output.weight"} + # Query and Keys, will assign lrs by heads + self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"} + + defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps) + super().__init__(optim_groups, defaults) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + beta1 = group["beta1"] + beta2 = group["beta2"] + lr = group["lr"] + name = group["name"] + eps = group["eps"] + + for p in group["params"]: + + state = self.state[p] + if any(embd_name in name for embd_name in self.embd_names): # this is for embedding and output layer + if p.grad is None: + continue + if len(state) == 0: + state["m"] = torch.zeros_like(p, dtype=torch.float32) + state["step"] = 0 + state["v"] = torch.zeros_like(p, dtype=torch.float32) + + grad = p.grad.to(torch.float32) + state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = lr / bias_correction_1 + p.addcdiv_(state["m"], h, value=-stepsize) + elif any(wqk_name in name for wqk_name in self.wqk_names): # this is for query and key + if p.grad is None: + continue + head_numel = group["head_numel"] + if len(state) == 0: + m = torch.zeros_like(p, dtype=torch.float32) + state["m"] = m.view(-1, head_numel) + state["head"] = state["m"].size(0) + state["step"] = 0 + # NOTE: We must use `zeros_like` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. 
+ # state["vmean"] = torch.zeros(state["head"]) + state["vmean"] = torch.zeros_like(state["m"][0:state["head"], 0:1]) + + grad = p.grad.to(torch.float32) + head = state["head"] + grad = grad.view(head, head_numel) + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(head, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + else: # other blocks + if len(state) == 0: + block_numel = torch.tensor(p.numel(), dtype=torch.float32, device=p.device) + reduced = False + if (self.world_size > 1) and (self.model_sharding is True): + tensor_list = [torch.zeros_like(block_numel) for _ in range(self.world_size)] + + dist.all_gather(tensor_list, block_numel) + s = 0 + block_numel = 0 + for d in tensor_list: + if (d > 0): + s = s + 1 + block_numel = block_numel + d + if (s >= 2): + reduced = True + + state["m"] = torch.zeros_like(p, dtype=torch.float32) + state["step"] = 0 + state["reduced"] = reduced + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # state["vmean"] = torch.zeros(1, device=p.device) + # state["vmean"] = p.new_zeros(1) + state["vmean"] = torch.zeros_like(torch.sum(p * p)) + state["block_numel"] = block_numel.item() + if p.grad is None: + tmp_lr = torch.zeros_like(torch.sum(p * p)) + else: + grad = p.grad.to(torch.float32) + tmp_lr = torch.sum(grad * grad) + + if (state["reduced"]): + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + # print(f"...... 
dist all reduce.......") + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + if (p.grad is None): + continue + tmp_lr = tmp_lr / state["block_numel"] + + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["step"] += 1 + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = (1 / bias_correction_1) / h + update = state["m"] * (stepsize.to(state["m"].device)) + update.mul_(lr) + p.add_(-update) + return loss \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index 72c0e50bc..950d4d00d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4266,19 +4266,15 @@ def get_optimizer(args, trainable_params, model=None): elif optimizer_type == "AdamMini".lower(): logger.info(f"use AdamMini optimizer | {optimizer_kwargs}") try: - import adam_mini + import library.adam_mini as adam_mini optimizer_class = adam_mini.Adam_mini except ImportError: raise ImportError("No adam-mini / adam-mini がインストールされていないようです") - - # trainable_params → named_parameters - named_params = [(f"{model}.{name}", param) for name, param in model.named_parameters() if param in trainable_params] - - optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs) - optimizer.embd_names.add("to_out") - optimizer.wqk_names.add("to_q") - optimizer.wqk_names.add('to_k') - optimizer.wqk_names.add('to_v') + + optimizer = optimizer_class(model.named_parameters(), lr=lr, **optimizer_kwargs) + optimizer.embd_names.add("embed") + optimizer.wqk_names.add("attn") + optimizer.wqk_names.add('mlp') if optimizer is None: # 任意のoptimizerを使う From 843e7e6012447f87380658b1e73fd77e7aa1e49d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 18 Aug 2024 07:19:47 +0800 Subject: [PATCH 11/12] Update flux_train.py --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index d2a9b3f32..56b057085 100644 --- a/flux_train.py +++ b/flux_train.py @@ -310,7 +310,7 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") else: - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize, model=flux) # prepare dataloader # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset From 02af7a2a12abaa47b0cd71294b9dd722c0212496 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 18 Aug 2024 07:37:13 +0800 Subject: [PATCH 12/12] update --- library/adam_mini.py | 8 ++++---- library/train_util.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/library/adam_mini.py b/library/adam_mini.py index d521bbfba..7cb63804b 100644 --- a/library/adam_mini.py +++ b/library/adam_mini.py @@ -81,15 +81,15 @@ def __init__( state["weight_decay"] = 0.0 else: state["weight_decay"] = weight_decay - if "embed" in param_name or "wte" in param_name or "embd" in param_name: + if "embed" in param_name or "wte" in param_name or "embd" in param_name or "layer.weight" in param_name: count_embd += 1 - if "lm_head.weight" in param_name or "output.weight" in param_name: + if "lm_head.weight" in param_name or "output.weight" in param_name or "final_layer.linear.weight" in param_name or "final_layer.adaLN_modulation.1.weight" in param_name: count_output += 1 - if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name: + if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name or "attn.qkv.weight" in param_name: count_wq += 1 assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}" state["head_numel"] = self.dim * self.dim // self.n_heads - if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name: + if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name or "attn.proj.weight" in param_name: count_wk += 1 assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}" state["head_numel"] = self.dim * self.dim // self.n_heads diff --git a/library/train_util.py b/library/train_util.py index 8b8bba281..a2ac73d63 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4559,8 +4559,15 @@ def get_optimizer(args, trainable_params, model=None): except ImportError: raise ImportError("No adam-mini / adam-mini がインストールされていないようです") - optimizer = optimizer_class(model.named_parameters(), lr=lr, **optimizer_kwargs) - optimizer.embd_names.add("embed") + named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad] + + optimizer_kwargs["dim"] = 722 + optimizer_kwargs["n_heads"] = 19 + + optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs) + optimizer.embd_names.add("layer.weight") + optimizer.embd_names.add("final_layer.linear.weight") + optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight") optimizer.wqk_names.add("attn") optimizer.wqk_names.add('mlp')
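For reference, a minimal sketch of how the block-freezing path from PATCHES 02-04 fits together. The `ToyMMDiT` module and its layer names below are hypothetical stand-ins for SD3's MMDiT; only the `block_name="x_block"` default and the freeze/filter logic come from the patches.

```python
# Minimal sketch, not part of the patch series: ToyMMDiT is a hypothetical stand-in
# for SD3's MMDiT. freeze_blocks reproduces PATCH 04's logic: collect parameters whose
# name contains block_name, then stop gradients for the last num_last_block_to_freeze of them.
import torch.nn as nn


def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
    filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
    num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)
    start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
    for _, param in filtered_blocks[start_freezing_from:]:
        param.requires_grad = False


class ToyMMDiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.x_block = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # name matches block_name
        self.context_block = nn.Linear(8, 8)                            # left trainable


model = ToyMMDiT()
freeze_blocks(model, num_last_block_to_freeze=2)

# As in PATCH 04's sd3_train.py, only parameters that still require grad are handed to the optimizer.
params_to_optimize = [{"params": [p for p in model.parameters() if p.requires_grad], "lr": 1e-4}]
print(sum(p.requires_grad for p in model.parameters()), "of", len(list(model.parameters())), "params trainable")
```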
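The end state of the AdamMini branch in `train_util.get_optimizer` after PATCHES 10-12 can be summarized as the sketch below. It assumes the vendored `library/adam_mini.py` from PATCH 10 is importable; `model` is the network now passed via the new `model=` argument (e.g. `model=flux` in PATCH 11), and the `dim=722` / `n_heads=19` values are the ones PATCH 12 hard-codes for the MMDiT. The helper name `build_adam_mini` is illustrative, not part of the patches.

```python
# Sketch of the final AdamMini branch (PATCHES 10-12), assuming library/adam_mini.py
# from PATCH 10 is importable. Not a drop-in replacement for get_optimizer.
import library.adam_mini as adam_mini


def build_adam_mini(model, lr, **optimizer_kwargs):
    # Adam-mini consumes (name, param) pairs so it can partition updates by parameter name.
    named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

    # PATCH 12 fixes these for the MMDiT; they size the per-head groups for Q/K-like weights.
    optimizer_kwargs["dim"] = 722
    optimizer_kwargs["n_heads"] = 19

    optimizer = adam_mini.Adam_mini(named_params, lr=lr, **optimizer_kwargs)

    # Name hints added after construction: embedding/output-like layers keep
    # per-element Adam moments, while attention and MLP weights use the
    # reduced (per-head / per-block) second moment.
    optimizer.embd_names.add("layer.weight")
    optimizer.embd_names.add("final_layer.linear.weight")
    optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight")
    optimizer.wqk_names.add("attn")
    optimizer.wqk_names.add("mlp")
    return optimizer
```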
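To make the three branches of `Adam_mini.step()` in the vendored `library/adam_mini.py` concrete: parameters matched by `embd_names` keep one second-moment entry per element (plain AdamW), parameters matched by `wqk_names` keep one entry per attention head, and every other block shares a single scalar. The toy tensor below only illustrates the statistic each branch accumulates before the beta2 EMA; the 722/19 shapes follow PATCH 12's MMDiT settings and are otherwise arbitrary.

```python
# Illustration only: the per-step second-moment statistic each Adam_mini.step()
# branch accumulates, shown on a toy 722x722 weight (PATCH 12's dim, 19 heads).
import torch

grad = torch.randn(722, 722)            # e.g. a Q-projection gradient, dim x dim
n_heads = 19
head_numel = grad.numel() // n_heads    # dim * dim // n_heads, as in Adam_mini.__init__

# embd_names / output layers: standard AdamW, one v entry per element
v_elementwise = grad.pow(2)                                                    # shape (722, 722)

# wqk_names (Q/K, plus "attn"/"mlp" after PATCH 12): one v entry per head
v_per_head = grad.view(n_heads, head_numel).pow(2).mean(dim=1, keepdim=True)   # shape (19, 1)

# all remaining blocks: a single shared v scalar for the whole parameter tensor
v_per_block = grad.pow(2).sum() / grad.numel()                                 # scalar
```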