diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 2c6237cd9a1a..72c3ec46ae75 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -464,23 +464,23 @@ def __init__(
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
-            optimizer,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            clip_grad_norm,
-            verbose,
-            reduce_bucket_size,
-            communication_dtype,
-            overlap_communication,
-            partition_grad,
-            cpu_offload,
-            dp_process_group,
-            forced_dtype,
+            optimizer=optimizer,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            clip_grad_norm=clip_grad_norm,
+            verbose=verbose,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            partition_grad=partition_grad,
+            cpu_offload=cpu_offload,
+            dp_process_group=dp_process_group,
+            forced_dtype=forced_dtype,
         )
 
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index dffa4ce164ef..088b67c8c533 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -262,6 +262,7 @@ def __init__(
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         cpu_offload: bool = False,
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -272,18 +273,19 @@ def __init__(
         self.precision = precision
         self.zero_optim_kwargs = dict(
             initial_scale=initial_scale,
+            min_scale=min_scale,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
-            min_scale=min_scale,
             max_scale=max_scale,
             clip_grad_norm=max_norm,
             reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
-            cpu_offload=cpu_offload,
             partition_grad=(stage == 2),
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
         )
 
         self.verbose = verbose
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index d9be7af17d15..e6974a6760ce 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -75,6 +75,7 @@ def __init__(
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
+        master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ def __init__(
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
 
+        # master weights copy
+        self._master_weights = master_weights
+
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -135,7 +139,6 @@ def __init__(
             self._working_param_groups[group_id] = group_params
 
             master_param_current_rank = self._create_master_param_current_rank(group_params)
-
             self._master_param_groups_of_current_rank[group_id] = master_param_current_rank
 
             # need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ def _create_master_param_current_rank(self, param_list):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                    # reset working params' ptr when no master weights
+                    if self._master_weights == False:
+                        param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
 
                 splited_params = padding_param.split(padding_param.numel() // self._world_size)
-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                # use fp32 when master_weights is True
+                if self._master_weights is True:
+                    splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                else:
+                    splited_param_current_rank = splited_params[self._local_rank]
                 params_current_rank.append(splited_param_current_rank)
                 self._param_store.link_master_and_working_param(splited_param_current_rank, param)
 
@@ -402,9 +412,7 @@ def step(self, closure=None):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
-
         grad_index = 0 if self._partition_grads else self._local_rank
-
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             real_working_params[group_id] = []
@@ -417,7 +425,12 @@ def step(self, closure=None):
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                    # no need to copy fp32 grad if master_weights is False
+                    grad = (
+                        grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        if self._master_weights
+                        else grads[grad_index]
+                    )
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
@@ -445,17 +458,16 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
-        dtype = real_working_params[0][0].dtype
+        # dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
 
-            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
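
Usage sketch (not part of the patch): how the new `master_weights` flag surfaces through `LowLevelZeroPlugin`. Only the `master_weights` keyword comes from this diff; the launcher call, model, and optimizer below are illustrative and assume the usual ColossalAI booster workflow under torchrun.

    # minimal sketch, assuming the standard booster workflow; run under torchrun
    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})

    # toy model/optimizer, purely illustrative
    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # master_weights=False keeps the sharded optimizer states in the working dtype
    # (e.g. fp16) instead of allocating fp32 master copies, saving memory at the
    # cost of lower-precision updates; the default (True) preserves old behavior.
    plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)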