From 02af7a2a12abaa47b0cd71294b9dd722c0212496 Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Sun, 18 Aug 2024 07:37:13 +0800
Subject: [PATCH] update

---
 library/adam_mini.py  |  8 ++++----
 library/train_util.py | 11 +++++++++--
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/library/adam_mini.py b/library/adam_mini.py
index d521bbfba..7cb63804b 100644
--- a/library/adam_mini.py
+++ b/library/adam_mini.py
@@ -81,15 +81,15 @@ def __init__(
                 state["weight_decay"] = 0.0
             else:
                 state["weight_decay"] = weight_decay
-            if "embed" in param_name or "wte" in param_name or "embd" in param_name:
+            if "embed" in param_name or "wte" in param_name or "embd" in param_name or "layer.weight" in param_name:
                 count_embd += 1
-            if "lm_head.weight" in param_name or "output.weight" in param_name:
+            if "lm_head.weight" in param_name or "output.weight" in param_name or "final_layer.linear.weight" in param_name or "final_layer.adaLN_modulation.1.weight" in param_name:
                 count_output += 1
-            if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name:
+            if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name or "attn.qkv.weight" in param_name:
                 count_wq += 1
                 assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                 state["head_numel"] = self.dim * self.dim // self.n_heads
-            if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name:
+            if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name or "attn.proj.weight" in param_name:
                 count_wk += 1
                 assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                 state["head_numel"] = self.dim * self.dim // self.n_heads
diff --git a/library/train_util.py b/library/train_util.py
index 8b8bba281..a2ac73d63 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4559,8 +4559,15 @@ def get_optimizer(args, trainable_params, model=None):
         except ImportError:
             raise ImportError("No adam-mini / adam-mini がインストールされていないようです")
 
-        optimizer = optimizer_class(model.named_parameters(), lr=lr, **optimizer_kwargs)
-        optimizer.embd_names.add("embed")
+        named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
+
+        optimizer_kwargs["dim"] = 722
+        optimizer_kwargs["n_heads"] = 19
+
+        optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs)
+        optimizer.embd_names.add("layer.weight")
+        optimizer.embd_names.add("final_layer.linear.weight")
+        optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight")
         optimizer.wqk_names.add("attn")
         optimizer.wqk_names.add('mlp')
 