Commit

remove fp32-specific processing in optimizer
yingtongxiong committed Sep 6, 2023
1 parent 7f687bf commit 30572b0
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -94,10 +94,7 @@ def __init__(
         param_bcast_sync_handler: ParamBcastSyncHandler = None,
     ):
         # DynamicGradScaler related args
-        if gpc.config.model.dtype is torch.float32:
-            initial_scale = 1
-        else:
-            initial_scale = grad_scal_cfg.fp16.initial_scale
+        initial_scale = grad_scal_cfg.fp16.initial_scale
         min_scale = grad_scal_cfg.fp16.min_scale
         growth_interval = grad_scal_cfg.fp16.growth_interval
         growth_factor = grad_scal_cfg.growth_factor
@@ -571,8 +568,7 @@ def _step(self, closure=None, norms=None):
             found_inf = True

         loss_scale = float(self.loss_scale.item())  # backup
-        if gpc.config.model.dtype is not torch.float32:
-            self.grad_scaler.update(found_inf)
+        self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -621,11 +617,10 @@ def _step(self, closure=None, norms=None):
             global_norm_groups[group_name] = norm**0.5

         # the following operations are performed only on the rank to which parameters are assigned.
-        if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
-                self._unscale_and_clip_grads(
-                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
-                )
+        if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+            self._unscale_and_clip_grads(
+                single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+            )

         # update the parameters
         timer("step").start()
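For readers skimming the diff: after this change the step path always builds the grad scaler from the fp16 config and always runs grad_scaler.update() plus the unscale-and-clip pass, regardless of gpc.config.model.dtype. The sketch below illustrates that path in isolation. SimpleDynamicGradScaler, unscale_and_clip_, and the backoff_factor default are illustrative assumptions for this sketch, not InternLM's actual DynamicGradScaler or _unscale_and_clip_grads implementation.

from typing import List

import torch


class SimpleDynamicGradScaler:
    """Toy stand-in for a dynamic grad scaler (assumed backoff/growth scheme)."""

    def __init__(
        self,
        initial_scale: float,
        min_scale: float,
        growth_interval: int,
        growth_factor: float,
        backoff_factor: float = 0.5,  # assumed default, not from the diff
    ):
        self.scale = float(initial_scale)
        self.min_scale = float(min_scale)
        self.growth_interval = growth_interval
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self._good_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: back the scale off (never below min_scale) and reset the streak.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            # Grow the scale after growth_interval consecutive overflow-free steps.
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor


def unscale_and_clip_(grads: List[torch.Tensor], total_norm: float,
                      loss_scale: float, clip_grad_norm: float) -> None:
    """Divide gradients by the loss scale and apply global-norm clipping in one multiply."""
    combined_scale = 1.0 / loss_scale
    if clip_grad_norm > 0:
        clip_coef = clip_grad_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            combined_scale *= clip_coef
    for grad in grads:
        grad.mul_(combined_scale)

Note that when the scale is 1, the unscale multiply is a no-op and only norm clipping has any effect, so the same code path can serve both scaled (fp16) and unscaled (fp32) training.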

0 comments on commit 30572b0
