add norm_fp32 in config
yingtongxiong committed Sep 6, 2023
1 parent 6ad317e commit 0296425
Showing 5 changed files with 41 additions and 36 deletions.
1 change: 1 addition & 0 deletions configs/7B_sft.py
@@ -129,6 +129,7 @@
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
norm_fp32=False, # whether to use fp32 for the norm layers when dtype is torch.float16 or torch.bfloat16
)
"""
zero1 parallel:
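For reference, enabling the feature just means flipping the new flag in the model dict of configs/7B_sft.py. A minimal sketch of the relevant slice of that dict (the bfloat16 dtype shown here is an assumption, not part of this commit; the other fields are taken from the hunk above):

    model = dict(
        dtype="torch.bfloat16",  # assumed half-precision dtype; norm_fp32 has no effect when dtype is torch.float32
        layer_norm_epsilon=1e-5,
        use_flash_attn=True,
        num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
        norm_fp32=True,  # keep the norm layers in fp32 while the rest of the model runs in half precision
    )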
18 changes: 10 additions & 8 deletions internlm/core/naive_amp.py
@@ -36,17 +36,14 @@ def __init__(
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True,
dtype=torch.float16,
norm_fp32: bool = False,
):
super().__init__()
self.model = model.to(dtype)
self._output_to_fp32 = output_to_fp32
self._sync_buf = sync_buffer
self.dtype = dtype

# not-norm parameters
self.not_norm = []
# norm parameters
self.norm = []
self.norm_fp32 = norm_fp32

if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
self._process_group = gpc.get_group(parallel_mode)
@@ -57,9 +54,14 @@ def __init__(
self._sync_buf = False
self._first_eval_run = False

if self.dtype in [torch.float16, torch.bfloat16]:
# set the norm weight dtype to fp32
self.set_norm_fp32(self.model)
if self.norm_fp32:
# not-norm parameters
self.not_norm = []
# norm parameters
self.norm = []
if self.dtype in [torch.float16, torch.bfloat16]:
# set the norm weight dtype to fp32
self.set_norm_fp32(self.model)

def set_norm_fp32(self, module):
if len(list(module.children())) == 0:
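The body of set_norm_fp32 is cut off in the hunk above. A minimal sketch of what such a method could look like inside NaiveAMPModel is given below; it assumes `import torch` and that norm layers are recognised by their class name (both assumptions, since the real implementation is not visible here). It walks the module tree, keeps the parameters of leaf modules whose class name contains "norm" in fp32 and collects them into self.norm, and puts everything else into self.not_norm:

    def set_norm_fp32(self, module):
        # sketch only: recognising norm layers by class name is an assumption
        if len(list(module.children())) == 0:  # leaf module
            if "norm" in type(module).__name__.lower():  # e.g. LayerNorm / RMSNorm
                module.to(torch.float32)                 # keep norm weights in fp32
                self.norm.extend(module.parameters())
            else:
                self.not_norm.extend(module.parameters())
        else:
            for child in module.children():
                self.set_norm_fp32(child)

Whatever the exact test is, the two lists are what initialize_optimizer later uses to build separate AdamW parameter groups.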
8 changes: 7 additions & 1 deletion internlm/initialize/launch.py
@@ -214,7 +214,13 @@ def args_sanity_check():
"torch.float32",
"torch.tf32",
]

if "norm_fp32" not in model:
logger.warning("norm_fp32 is not set, use False by defalut!")
model._add_item("norm_fp32", False)
else:
if gpc.config.model.dtype is torch.float32:
gpc.config.model.norm_fp32 = False

if "checkpoint" in model:
if model.checkpoint is True:
model.checkpoint = 1
40 changes: 15 additions & 25 deletions internlm/model/modeling_internlm.py
@@ -77,6 +77,8 @@ def __init__(
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
self.layer_idx = layer_idx
self.use_flash_attn = use_flash_attn
self.dtype = dtype
self.norm_fp32 = gpc.config.model.get("norm_fp32")

head_dim = hidden_size // num_attention_heads
self.mixer = MHA(
@@ -187,36 +189,17 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_
"inference_params": inference_params,
}

# def _dropout_and_norm_attn(_hidden_states):
# _dropped = self.dropout1(_hidden_states)
# _residual = _dropped
# _hidden_states = self.norm1(_residual.float())
# return _residual, _hidden_states

# if self.dropout_selective_checkpoint:
# residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
# else:
# residual, hidden_states = _dropout_and_norm_attn(hidden_states)

dropped1 = self.dropout1(hidden_states)
residual1 = dropped1
hidden_states = self.norm1(residual1.float())

if self.residual_in_fp32:
residual1 = residual1.to(torch.float32)

if self.norm_fp32:
hidden_states = hidden_states.to(self.dtype)

hidden_states = self.mixer(hidden_states.half(), **mixer_kwargs)

# def _dropout_and_norm_ffn(_residual, _hidden_states):
# _dropped = self.dropout2(_hidden_states)
# _residual = (_dropped + _residual) if _residual is not None else _dropped
# _hidden_states = self.norm2(_residual.float())
# return _residual, _hidden_states

# if self.dropout_selective_checkpoint:
# residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
# else:
# residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
hidden_states = self.mixer(hidden_states, **mixer_kwargs)

dropped2 = self.dropout2(hidden_states)
residual2 = (dropped2 + residual1) if residual1 is not None else dropped2
@@ -225,7 +208,10 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_
if self.residual_in_fp32:
residual2 = residual2.to(torch.float32)

hidden_states = self.mlp(hidden_states.half())
if self.norm_fp32:
hidden_states = hidden_states.to(self.dtype)

hidden_states = self.mlp(hidden_states)

return hidden_states + residual2

@@ -356,6 +342,8 @@ def __init__(
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output
self.dtype = dtype
self.norm_fp32 = gpc.config.model.get("norm_fp32")

def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
# attention_mask: compute attention on the places where the value is 1
@@ -392,7 +380,9 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "head"):
hidden_states = self.head(hidden_states.half())
if self.norm_fp32:
hidden_states = hidden_states.to(self.dtype)
hidden_states = self.head(hidden_states)

if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
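The recurring change in this file is a cast pattern: the norm still runs on an fp32 view of the activations, but the result is now cast back to the configured self.dtype before the next half-precision module (mixer, mlp, head) instead of the old hard-coded .half(). A self-contained toy of that pattern, using stand-in modules rather than the InternLM classes:

    import torch
    import torch.nn as nn

    dtype = torch.bfloat16
    norm = nn.LayerNorm(16)              # norm weights stay in fp32
    proj = nn.Linear(16, 16).to(dtype)   # stand-in for the half-precision mixer / mlp / head

    x = torch.randn(4, 16, dtype=dtype)
    h = norm(x.float())                  # norm computed in fp32
    h = h.to(dtype)                      # cast back to the model dtype, replacing the old .half()
    y = proj(h)
    print(y.dtype)                       # torch.bfloat16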
10 changes: 8 additions & 2 deletions internlm/train/training_internlm.py
@@ -59,6 +59,7 @@ def initialize_model():
output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
norm_fp32=False,
)
for _m in model
]
@@ -69,6 +70,7 @@ def initialize_model():
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
norm_fp32=False,
)

# This sync is very important, cause the model weights kept in optimizer are copied
@@ -98,11 +100,15 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
param_bcast_sync_handler = None

adam_cfg = gpc.config.adam
if gpc.config.model.get("norm_fp32"):
    params = [
        {"params": model.not_norm, "weight_decay": adam_cfg.weight_decay, "name": "default"},
        {"params": model.norm, "weight_decay": adam_cfg.weight_decay, "name": "norm"},
    ]
else:
    params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
naive_optimizer = torch.optim.AdamW(
    params=params,
    lr=adam_cfg.lr,
    betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
    eps=adam_cfg.adam_eps,
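When norm_fp32 is enabled, the optimizer is built from two parameter groups so the fp32 norm weights can be treated separately from the half-precision ones. A minimal standalone sketch of that setup, with a plain nn.Module and placeholder hyperparameters in place of NaiveAMPModel and gpc.config.adam:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    norm_params = [p for m in model.modules() if isinstance(m, nn.LayerNorm) for p in m.parameters()]
    other_params = [p for m in model.modules() if isinstance(m, nn.Linear) for p in m.parameters()]

    optimizer = torch.optim.AdamW(
        params=[
            {"params": other_params, "weight_decay": 0.01, "name": "default"},
            {"params": norm_params, "weight_decay": 0.01, "name": "norm"},
        ],
        lr=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

The extra "name" key is simply stored on each param group by torch.optim, so downstream code can tell the norm group apart from the rest.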
