
Commit

update
sdbds committed Aug 17, 2024
1 parent 843e7e6 commit 02af7a2
Showing 2 changed files with 13 additions and 6 deletions.
library/adam_mini.py: 8 changes (4 additions, 4 deletions)
@@ -81,15 +81,15 @@ def __init__(
                     state["weight_decay"] = 0.0
                 else:
                     state["weight_decay"] = weight_decay
-                if "embed" in param_name or "wte" in param_name or "embd" in param_name:
+                if "embed" in param_name or "wte" in param_name or "embd" in param_name or "layer.weight" in param_name:
                     count_embd += 1
-                if "lm_head.weight" in param_name or "output.weight" in param_name:
+                if "lm_head.weight" in param_name or "output.weight" in param_name or "final_layer.linear.weight" in param_name or "final_layer.adaLN_modulation.1.weight" in param_name:
                     count_output += 1
-                if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name:
+                if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name or "attn.qkv.weight" in param_name:
                     count_wq += 1
                     assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                     state["head_numel"] = self.dim * self.dim // self.n_heads
-                if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name:
+                if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name or "attn.proj.weight" in param_name:
                     count_wk += 1
                     assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                     state["head_numel"] = self.dim * self.dim // self.n_heads
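
For context, here is a minimal sketch, not part of the commit, of how the substring checks above route a parameter name into a group. It collapses the separate counters into a single first-match classifier; the matched substrings come from the condition lines in the diff, while the full parameter names in the usage example at the bottom are hypothetical.

def classify_param(param_name: str) -> str:
    # Embedding-like weights (no per-head split is recorded for these in the diff above).
    if any(s in param_name for s in ("embed", "wte", "embd", "layer.weight")):
        return "embedding"
    # Output heads, including the final_layer names added by this commit.
    if any(s in param_name for s in (
        "lm_head.weight", "output.weight",
        "final_layer.linear.weight", "final_layer.adaLN_modulation.1.weight",
    )):
        return "output head"
    # Query-side weights: this branch also records head_numel for a per-head split.
    if any(s in param_name for s in (
        "q_proj.weight", "wq.weight", "attn_qkv.lora_down", "attn_proj.lora_down", "attn.qkv.weight",
    )):
        return "wq group"
    # Key/MLP-side weights, including the attn.proj.weight name added by this commit.
    if any(s in param_name for s in (
        "k_proj.weight", "wk.weight", "attn_qkv.lora_up", "attn_proj.lora_up", "mlp", "attn.proj.weight",
    )):
        return "wk/mlp group"
    return "default"

# Hypothetical full parameter names, for illustration only:
print(classify_param("final_layer.linear.weight"))  # output head
print(classify_param("blocks.0.attn.qkv.weight"))   # wq group
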
library/train_util.py: 11 changes (9 additions, 2 deletions)
@@ -4559,8 +4559,15 @@ def get_optimizer(args, trainable_params, model=None):
         except ImportError:
             raise ImportError("No adam-mini / adam-mini がインストールされていないようです")
 
-        optimizer = optimizer_class(model.named_parameters(), lr=lr, **optimizer_kwargs)
-        optimizer.embd_names.add("embed")
+        named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
+
+        optimizer_kwargs["dim"] = 722
+        optimizer_kwargs["n_heads"] = 19
+
+        optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs)
+        optimizer.embd_names.add("layer.weight")
+        optimizer.embd_names.add("final_layer.linear.weight")
+        optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight")
         optimizer.wqk_names.add("attn")
         optimizer.wqk_names.add('mlp')
 
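
As a quick sanity check, outside the commit, the hard-coded dim=722 and n_heads=19 passed through optimizer_kwargs have to satisfy the divisibility assert in library/adam_mini.py above; the named_params comprehension additionally drops parameters with requires_grad set to False, so the optimizer only sees weights that will actually be updated.

# Values taken from the diff above; the assert mirrors the one in adam_mini.py.
dim, n_heads = 722, 19
assert (dim * dim) % n_heads == 0, f"{dim} {n_heads}"
print(dim * dim // n_heads)  # head_numel: 521284 // 19 == 27436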
