diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 16bb2e9b80b1..a3039f97ba30 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -128,22 +128,20 @@ def save_unsharded_optimizer(
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        state_dict = optimizer.state_dict(pinned_state_dicts)
+        state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
                 from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(
-                    fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
-                )
+                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
                 save_nested(f_writer, state_dict)
                 self.async_writers.append(f_writer)
             else:
@@ -192,13 +190,15 @@ def save_sharded_optimizer(
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
+        sharded_state = optimizer.state_dict_shard(
+            max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
+        )

         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
@@ -227,7 +227,7 @@ def save_sharded_optimizer(
                 from colossalai.utils.safetensors import save_nested

                 f_writer = AsyncFileWriter(
-                    fp=open(checkpoint_file_path, "wb", buffering=0),
+                    checkpoint_file_path,
                     n_entries=self.N_WRITE_ENTRIES,
                     backend="pthread",
                 )
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 9e431d3559ac..b96c0c7b8557 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -72,7 +72,6 @@ def __init__(self):
     def _sync_io(self):
         for writer in self.async_writers:
             writer.synchronize()
-            writer.fp.close()
         self.async_writers.clear()

     def _sync_d2h(self):
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index ddfe5502f718..2545806775a4 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -56,7 +56,7 @@ def save_unsharded_model(
         if use_async:
             from tensornvme.async_file_io import AsyncFileWriter

-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
+            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
             self.async_writers.append(writer)
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 581575058cd9..da3199e12a40 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -690,9 +690,7 @@ def save_unsharded_model(

                 from colossalai.utils.safetensors import move_and_save

-                writer = AsyncFileWriter(
-                    open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
-                )
+                writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
                 if id(model) not in self.pinned_state_dicts:
                     self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
                 self.async_writers.append(writer)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index eb8bb2dcf35b..77b9faa0bcbf 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
         index_file.append_weight_map(key, shard_file)

         checkpoint_file_path = os.path.join(checkpoint, shard_file)
-        writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
+        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
         writers.append(writer)

         if pinned_state_dict is not None:
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 74b1817a0cba..3c67299bbe79 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -776,7 +776,9 @@ def pack_group(group):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
+    def state_dict(
+        self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
+    ) -> Dict:
         """Return a state_dict same with DDP

         Returns:
@@ -785,23 +787,29 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            working_param = self.master_to_working_param[id(param)]
+            pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                zero_state[param] = copy.deepcopy(state)
+            else:
+                zero_state[param] = {}
+
             if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
-            zero_state[param] = copy.deepcopy(state)
+
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    working_param = self.master_to_working_param[id(param)]
-                    pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)

@@ -837,7 +845,10 @@ def load_state_dict(self, state_dict: Dict):
         self.optim.load_state_dict(zero_state_dict)

     def state_dict_shard(
-        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+        self,
+        max_shard_size: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        only_on_master: bool = False,
     ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
            Only include the 'state' in state_dict.
@@ -862,25 +873,31 @@ def state_dict_shard(
                cnt += 1
         for param_idx, states in local_states.items():
             current_block_size = 0
-            current_block = copy.deepcopy(states)
             if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                current_block = copy.deepcopy(states)
+            else:
+                current_block = {}

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
                     current_block_size += calculate_tensor_size(state_tensor)

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py
index 521ec10bd09b..882b5e2c585f 100644
--- a/tests/test_checkpoint_io/test_safetensors_async_io.py
+++ b/tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -10,6 +10,7 @@
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

+
 from colossalai.testing import check_state_dict_equal
 from colossalai.utils import get_current_device

@@ -110,20 +111,20 @@ def test_save_load():
         }

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
         save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_flat(optimizer_saved_path)
         check_state_dict_equal(load_state_dict, optimizer_state_dict)

         optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
         save_nested(f_writer, optimizer_state_dict["state"])
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict_shard = load_flat(optimizer_shard_saved_path)
         check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

@@ -133,21 +134,21 @@ def test_save_load():
             "module.weight2": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
-        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
         save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)

         model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
         model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
         model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
         move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
         f_writer.sync_before_step()
         f_writer.synchronize()
-        f_writer.fp.close()
+        del f_writer
         load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
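
Usage sketch (not part of the patch): two call-pattern changes run through the hunks above. `AsyncFileWriter` is now constructed from a file path instead of an open file object, so there is no `writer.fp.close()` afterwards, and `LowLevelZeroOptimizer.state_dict` / `state_dict_shard` accept an `only_on_master` flag so every rank still joins the all-gather while only the master rank materializes the CPU copy. A minimal single-process illustration of the writer side, with a made-up state dict and output path:

# Sketch only: mirrors the updated AsyncFileWriter call pattern from the test above.
# The state-dict contents and the output path are illustrative, not taken from the repo.
import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

state = {"state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(1024)}}}

# The writer opens the file itself from a path; no fp=open(path, "wb", buffering=0) handle.
f_writer = AsyncFileWriter("optimizer.safetensors", n_entries=191, backend="pthread")
save_nested(f_writer, state)
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer  # the writer owns the file handle now, so there is no f_writer.fp.close()

# On the optimizer side (distributed setup required), a rank would call e.g.
#   zero_optimizer.state_dict(pinned_state_dicts, only_on_master=True)
# and non-master ranks get empty per-parameter dicts back while still
# participating in the all-gather.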