diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 2545806775a4..a0c7dd610753 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -13,7 +13,7 @@
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    async_save_state_dict_shards,
+    async_move_save_state_dict_shards,
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
@@ -189,7 +189,7 @@ def save_sharded_model(
 
         if use_async:
             pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
-            total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
+            total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint_path,
                 index_file=index_file,
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index da3199e12a40..75724fc78ea7 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -22,6 +22,7 @@
     to_unpadded_tensor,
 )
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
+from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat, save
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -88,7 +89,11 @@ def __init__(
 
     @staticmethod
     def _model_sharder(
-        model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
+        model: nn.Module,
+        prefix: str = "",
+        keep_vars: bool = False,
+        size_per_shard: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         # An internel method that breaks state_dict of model into shards within limited size.
 
@@ -102,6 +107,13 @@ def _model_sharder(
             if is_padded_tensor(param):
                 param = to_unpadded_tensor(param)
             param_ = gather_distributed_param(param, keep_vars=False)
+            if pinned_state_dicts is not None:
+                if (prefix + name) not in pinned_state_dicts:
+                    pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
+                pinned_state_dicts[prefix + name].copy_(param_)
+                param_ = pinned_state_dicts[prefix + name]
+            else:
+                param_ = param_.cpu()
             block, block_size = state_dict_sharder.append_param(prefix + name, param_)
             if block is not None:
                 yield block, block_size
@@ -111,6 +123,13 @@ def _model_sharder(
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
+                if pinned_state_dicts is not None:
+                    if (prefix + name) not in pinned_state_dicts:
+                        pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
+                    pinned_state_dicts[prefix + name].copy_(buffer)
+                    buffer = pinned_state_dicts[prefix + name]
+                else:
+                    buffer = buffer.cpu()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size
@@ -122,6 +141,13 @@ def _model_sharder(
             is not torch.nn.Module.get_extra_state
         ):
             extra_state = model.get_extra_state()
+            if pinned_state_dicts is not None:
+                if extra_state_key not in pinned_state_dicts:
+                    pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
+                pinned_state_dicts[extra_state_key].copy_(extra_state)
+                extra_state = pinned_state_dicts[extra_state_key]
+            else:
+                extra_state = extra_state.cpu()
             block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
@@ -136,6 +162,7 @@ def _optimizer_sharder(
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
         size_per_shard: int = 1024,
+        pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
     ):
         # An internel method that breaks state_dict of optimizer into shards within limited size.
 
@@ -153,6 +180,9 @@ def _optimizer_sharder(
                 working_param = param
 
             param_id = param_info["param2id"][id(working_param)]
+            if pinned_state_dicts is not None:
+                if param_id not in pinned_state_dicts:
+                    pinned_state_dicts[param_id] = {}
             original_shape = param_info["param2shape"][id(working_param)]
             state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
                 state,
@@ -162,6 +192,7 @@ def _optimizer_sharder(
                 tp_group=tp_group,
                 use_zero=use_zero,
                 inplace=False,
+                pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
             )
 
             block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
@@ -216,15 +247,32 @@ def save_sharded_model(
 
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+        control_saving = self.tp_rank == 0
+        if control_saving and use_async:
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = HybridParallelCheckpointIO._model_sharder(
+            model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
+        )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
-        control_saving = self.tp_rank == 0
 
         if self.pp_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+                total_size, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                    state_preprocess=False,
+                )
+                self.async_writers.extend(writers)
             else:
                 total_size = save_state_dict_shards(
                     sharded_state_dict=state_dict_shard,
@@ -259,24 +307,26 @@ def save_sharded_model(
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
             if use_async:
-                total_size, returned_state_dict, writers = async_save_state_dict_shards(
+                total_size, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                    state_preprocess=False,
+                )
+                self.async_writers.extend(writers)
+            else:
+                total_size = save_state_dict_shards(
                     sharded_state_dict=state_dict_shard,
                     checkpoint=checkpoint,
                     index_file=index_file,
                     base_filename=weights_name,
                     is_master=control_saving,
+                    use_safetensors=use_safetensors,
                     use_pp_format=True,
-                    n_write_entries=191,
                 )
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=weights_name,
-                is_master=control_saving,
-                use_safetensors=use_safetensors,
-                use_pp_format=True,
-            )
 
             if control_saving:
                 assert (
@@ -448,26 +498,48 @@ def save_sharded_optimizer(
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
        # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+        control_saving = self.dp_rank == 0 and self.tp_rank == 0
+
+        if use_async and control_saving:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        # pinned_state_dicts = None
         state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
             dp_group=self.global_dp_group,
             tp_group=self.tp_group,
             size_per_shard=size_per_shard,
+            pinned_state_dicts=pinned_state_dicts,
         )
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
-        control_saving = self.dp_rank == 0 and self.tp_rank == 0
 
         if self.pp_size == 1:
             # When pipeline is not used, save the optimizer shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=states_name,
-                is_master=control_saving,
-            )
+            if use_async:
+                total_size, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=control_saving,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                    use_pp_format=True,
+                    state_preprocess=True,
+                )
+                self.async_writers.extend(writers)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=control_saving,
+                )
 
             if control_saving:
                 # Store param groups.
@@ -499,17 +571,31 @@ def save_sharded_optimizer(
 
             # Manage filenames of sharded weights and index file for each pipeline stage.
             states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
 
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=states_name,
-                is_master=control_saving,
-                use_pp_format=True,
-            )
+            if use_async:
+                total_size, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=control_saving,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                    use_pp_format=True,
+                    state_preprocess=True,
+                )
+                self.async_writers.extend(writers)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=control_saving,
+                    use_pp_format=True,
+                )
 
             if control_saving:
                 assert (
@@ -622,7 +708,10 @@ def _get_param_id_from_optimizer_param(
                 continue
 
             file_path = os.path.join(ckpt_root_path, filename)
-            state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+            if file_path.endswith(".safetensors"):
+                state_dict = load_flat(file_path)
+            else:
+                state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
             load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
             loaded_file.add(filename)
 
@@ -672,7 +761,16 @@ def save_unsharded_model(
             # When pipeline is not used, let master rank directly save the collected state_dict.
             if self.tp_rank == 0:
                 if use_async:
-                    super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
+                    if id(model) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                    for name, param in state_dict.items():
+                        self.pinned_state_dicts[id(model)][name].copy_(param)
+                        state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                    self.async_writers.append(writer)
+                    save(f_writer=writer, state_dict=state_dict)
                 else:
                     save_state_dict(state_dict, checkpoint, use_safetensors)
         else:
@@ -688,13 +786,14 @@ def save_unsharded_model(
 
             if use_async:
                 from tensornvme.async_file_io import AsyncFileWriter
 
-                from colossalai.utils.safetensors import move_and_save
-
                 writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
                 if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+                for name, param in complete_state_dict.items():
+                    self.pinned_state_dicts[id(model)][name].copy_(param)
+                    complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
                 self.async_writers.append(writer)
-                move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+                save(writer, state_dict=complete_state_dict)
             else:
                 save_state_dict(complete_state_dict, checkpoint, use_safetensors)
 
@@ -759,6 +858,7 @@ def save_unsharded_optimizer(
             # gather complete state from tp shards & dp shards
             param_id = optimizer.param_info["param2id"][id(working_param)]
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
+
             local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
                 state,
                 working_param,
@@ -778,7 +878,20 @@ def save_unsharded_optimizer(
             ]
             state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
-                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
+                    flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                    if id(optimizer) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
+                    for k, v in flatten_state_dict.items():
+                        self.pinned_state_dicts[k].copy_(v)
+                        flatten_state_dict[k] = self.pinned_state_dicts[k]
+                    self.async_writers.append(writer)
+                    save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
             states_list = [None for _ in range(self.pp_size)]
@@ -794,7 +907,20 @@ def save_unsharded_optimizer(
                 state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
-                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
+                    flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                    if id(optimizer) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
+                    for k, v in flatten_state_dict.items():
+                        self.pinned_state_dicts[k].copy_(v)
+                        flatten_state_dict[k] = self.pinned_state_dicts[k]
+                    self.async_writers.append(writer)
+                    save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
     def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
         """
@@ -820,7 +946,10 @@ def _get_param_id_from_optimizer_param(
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
-        state_dict = load_state_dict(checkpoint)
+        if checkpoint.endswith(".safetensors"):
+            state_dict = load_flat(checkpoint)
+        else:
+            state_dict = load_state_dict(checkpoint)
 
         # Load param_groups.
         updated_groups = []
@@ -874,6 +1003,7 @@ def gather_from_sharded_optimizer_state(
         use_zero: bool,
         inplace: bool,
         device: torch.device = torch.device("cpu"),
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> OrderedDict:
         """
         With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -917,7 +1047,13 @@ def gather_from_sharded_optimizer_state(
                     v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
                     v = to_unpadded_tensor(v)
 
-            state_[k] = v.detach().clone().to(device)
+            if pinned_state_dicts is not None:
+                if k not in pinned_state_dicts:
+                    pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                pinned_state_dicts[k].copy_(v)
+                state_[k] = pinned_state_dicts[k]
+            else:
+                state_[k] = v.detach().clone().to(device)
 
         return state_
 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 77b9faa0bcbf..abd3efba1b84 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -19,7 +19,7 @@
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import move_and_save
+from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save, save
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -267,6 +267,65 @@ def save_state_dict_shards(
 
 
 def async_save_state_dict_shards(
+    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+    checkpoint: str,
+    index_file: "CheckpointIndexFile",
+    base_filename: str,
+    is_master: bool,
+    n_write_entries: int,
+    use_pp_format: bool = False,
+    state_preprocess: bool = False,
+) -> Tuple[int, list]:
+    """
+    Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
+    Args:
+        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+        checkpoint (str): The path of checkpoint directory as string.
+        index_file (CheckpointIndexFile): The index file object to be updated.
+        base_filename (str): Decides the prefix of filenames of shards.
+        is_master (bool): Whether current rank is main process.
+        n_write_entries (int): Number of write entries used by each async file writer.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    Returns:
+        Tuple[int, list]: the total size of shards and the list of async writers.
+    """
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    total_size = 0
+    shard_filenames = []
+    writers = []
+    for idx, shard_pair in enumerate(sharded_state_dict):
+        shard, current_size = shard_pair
+        # Just loop over the sharder and gather to other ranks if not master
+        if not is_master:
+            del shard
+            continue
+        shard_file = get_shard_filename(base_filename, idx)
+        total_size = total_size + current_size
+        for key in shard.keys():
+            index_file.append_weight_map(key, shard_file)
+        checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+        if state_preprocess:
+            state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
+        else:
+            state_dict = shard
+        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
+        writers.append(writer)
+
+        # Only save on master rank.
+        save(f_writer=writer, state_dict=state_dict)
+        shard_filenames.append(shard_file)
+        del shard
+
+    # Clean folder, delete unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+    return total_size, writers
+
+
+def async_move_save_state_dict_shards(
     sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     checkpoint: str,
     index_file: "CheckpointIndexFile",
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index 86d7924fb828..81d184f7681a 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -38,12 +38,13 @@
 ]
 
 
-@parameterize("shard", [True, False])
+@parameterize("shard", [False, True])
 @parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
+@parameterize("use_async", [False, True])
 @clear_cache_before_run()
-def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
+def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
@@ -85,8 +86,16 @@ def _preprocess_data(data):
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
+
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+        booster.save_optimizer(
+            optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()
 
         new_model = model_fn().cuda()
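
The core mechanism added above is that _model_sharder, _optimizer_sharder and gather_from_sharded_optimizer_state can now stage every gathered tensor into a per-key pinned CPU buffer that is allocated once and reused across checkpoints, so the device-to-host copy can feed the background AsyncFileWriter without reallocating host memory on every save. The following is a standalone, plain-PyTorch sketch of that pattern; the helper name stage_to_pinned is introduced here for illustration only and is not part of this PR or of the ColossalAI API.

from typing import Dict

import torch


def stage_to_pinned(name: str, tensor: torch.Tensor, pinned: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Copy `tensor` into a cached pinned CPU buffer keyed by `name` and return that buffer."""
    if name not in pinned:
        # First save only: allocate the pinned host buffer (pinning assumes a CUDA build of PyTorch).
        pinned[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
    # Every later save reuses the same buffer and only pays for the copy.
    pinned[name].copy_(tensor)
    return pinned[name]


# The cache lives across save calls, mirroring self.pinned_state_dicts in the diff.
pinned_cache: Dict[str, torch.Tensor] = {}
weight = torch.randn(4, 4, device="cuda" if torch.cuda.is_available() else "cpu")
host_copy = stage_to_pinned("layer.weight", weight, pinned_cache)
print(host_copy.device, host_copy.is_pinned())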
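
For reference, a minimal sketch of driving the new async path from user code, mirroring the updated test at the end of the diff. The launch call, plugin arguments, toy model and optimizer are illustrative assumptions and may differ across ColossalAI versions; only the use_async=True save calls, the .safetensors suffix for unsharded saves, and the checkpoint_io._sync_d2h()/_sync_io() flushes come from the test changes above. Run under torchrun.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # older releases may require launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="fp32")
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Unsharded async saves go through the safetensors path, so the paths carry the
# .safetensors suffix, exactly as the test does when `not shard and use_async`.
booster.save_model(model, "model.safetensors", shard=False, use_async=True)
booster.save_optimizer(optimizer, "optimizer.safetensors", shard=False, use_async=True)

# The async writers copy into pinned host memory and write in the background;
# flush the device-to-host copies and the file I/O before reading the files back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()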