Commit e1e683d

fix lint bug
yingtongxiong committed Sep 6, 2023
1 parent 89167f1 commit e1e683d
Showing 3 changed files with 4 additions and 4 deletions.
internlm/core/naive_amp.py (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ def __init__(
        self._world_size = 1
        self._sync_buf = False
        self._first_eval_run = False

        # not-norm parameters
        self.not_norm = []
        # norm parameters
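For context on the `not_norm` list initialized here: NaiveAMPModel keeps normalization parameters in a separate `norm` list so they can be treated differently under mixed precision (see the `norm_fp32` handling in the other two files). A minimal sketch of how such a partition might be built; `split_params_by_norm` is a hypothetical helper, and matching norm layers by class name is an assumption, not necessarily the repo's rule:

import torch.nn as nn

def split_params_by_norm(model: nn.Module):
    """Hypothetical helper: partition parameters into norm / non-norm lists."""
    norm, not_norm = [], []
    for module in model.modules():
        # Assumption: a norm layer is any module whose class name
        # contains "norm" (LayerNorm, RMSNorm, ...).
        is_norm = "norm" in type(module).__name__.lower()
        # recurse=False assigns each parameter exactly once.
        for param in module.parameters(recurse=False):
            (norm if is_norm else not_norm).append(param)
    return norm, not_norm

On a toy nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)), this puts the Linear weight and bias in not_norm and the LayerNorm weight and bias in norm.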
internlm/model/modeling_internlm.py (2 changes: 1 addition & 1 deletion)
@@ -195,7 +195,7 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_

        if self.residual_in_fp32:
            residual1 = residual1.to(torch.float32)

        if self.norm_fp32:
            hidden_states = hidden_states.to(self.dtype)

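The context lines here show a standard mixed-precision residual pattern: the residual stream may be accumulated in fp32 for numerical stability, and hidden states are cast back to the compute dtype when the preceding norm ran in fp32. A minimal sketch of that pattern, with illustrative names rather than the repo's exact API:

import torch

def cast_for_next_block(hidden_states: torch.Tensor, residual: torch.Tensor,
                        dtype: torch.dtype, residual_in_fp32: bool, norm_fp32: bool):
    # Keep the residual stream in fp32 so small per-layer updates are
    # not lost to half-precision rounding across many layers.
    if residual_in_fp32:
        residual = residual.to(torch.float32)
    # If the norm computed in fp32, return its output to the compute
    # dtype (e.g. bf16) so the following matmuls stay in half precision.
    if norm_fp32:
        hidden_states = hidden_states.to(dtype)
    return hidden_states, residual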
internlm/train/training_internlm.py (4 changes: 2 additions & 2 deletions)
@@ -101,12 +101,12 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):

    adam_cfg = gpc.config.adam
    if gpc.config.model.get("norm_fp32"):
-       params=[
+       params = [
            {"params": model.not_norm, "weight_decay": adam_cfg.weight_decay, "name": "default"},
            {"params": model.norm, "weight_decay": adam_cfg.weight_decay, "name": "norm"},
        ]
    else:
-       params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
+       params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
    naive_optimizer = torch.optim.AdamW(
        params=params,
        lr=adam_cfg.lr,
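The fix itself is cosmetic (the linter wants spaces around `=`), but the surrounding code shows the payoff of the norm/not-norm split: when `norm_fp32` is set, AdamW receives two named parameter groups so norm parameters can be handled separately downstream. A self-contained sketch of the same parameter-group pattern; the toy model and the `lr` / `weight_decay` values are placeholders, not the repo's config:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Split parameters the way the "default" / "norm" groups above imply.
norm_params = [p for m in model.modules() if isinstance(m, nn.LayerNorm)
               for p in m.parameters(recurse=False)]
other_params = [p for m in model.modules() if not isinstance(m, nn.LayerNorm)
                for p in m.parameters(recurse=False)]

params = [
    {"params": other_params, "weight_decay": 0.01, "name": "default"},
    {"params": norm_params, "weight_decay": 0.01, "name": "norm"},
]
# Extra keys such as "name" are preserved in optimizer.param_groups,
# which lets later code (hooks, logging) identify each group.
optimizer = torch.optim.AdamW(params=params, lr=1e-4)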
