diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index a660359a..37c2dd8c 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -53,7 +53,7 @@ def __init__(
         self._world_size = 1
         self._sync_buf = False
         self._first_eval_run = False
-
+        # not-norm parameters
         self.not_norm = []
         # norm parameters
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 02fbdc4d..ac36878f 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -195,7 +195,7 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_
         if self.residual_in_fp32:
             residual1 = residual1.to(torch.float32)
-
+        if self.norm_fp32:
             hidden_states = hidden_states.to(self.dtype)
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index c5ce710b..827e75f4 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -101,12 +101,12 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     adam_cfg = gpc.config.adam
     if gpc.config.model.get("norm_fp32"):
-        params=[
+        params = [
            {"params": model.not_norm, "weight_decay": adam_cfg.weight_decay, "name": "default"},
            {"params": model.norm, "weight_decay": adam_cfg.weight_decay, "name": "norm"},
        ]
     else:
-        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
+        params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
     naive_optimizer = torch.optim.AdamW(
         params=params,
         lr=adam_cfg.lr,
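
Note (not part of the diff): a minimal sketch of the param-group split that initialize_optimizer builds when norm_fp32 is enabled. The toy model, hyperparameter values, and the isinstance(nn.LayerNorm) split below are illustrative assumptions; the patch itself collects the not_norm/norm lists inside naive_amp.py rather than walking modules like this.

# Illustrative sketch only; placeholders, not the repo's exact collection logic.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

# Split parameters into "norm" and "default" buckets, mirroring the shape of the
# param-group list passed to AdamW when norm_fp32 is set.
norm_params, not_norm_params = [], []
for module in model.modules():
    bucket = norm_params if isinstance(module, nn.LayerNorm) else not_norm_params
    bucket.extend(module.parameters(recurse=False))

params = [
    {"params": not_norm_params, "weight_decay": 0.01, "name": "default"},
    {"params": norm_params, "weight_decay": 0.01, "name": "norm"},
]
optimizer = torch.optim.AdamW(params=params, lr=1e-4)

# torch.optim keeps extra keys such as "name" in each param group, so a
# mixed-precision wrapper can later single out the "norm" group and keep it in fp32.
for group in optimizer.param_groups:
    print(group["name"], len(group["params"]))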