From 12e1e8a76c51ce07036f2df49c64abccba8aae65 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 18 Nov 2024 17:52:24 +0800
Subject: [PATCH] [async io] support async io (#6137)

* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
---
 colossalai/booster/booster.py                 |   5 +-
 colossalai/booster/plugin/gemini_plugin.py    |  12 +-
 .../booster/plugin/low_level_zero_plugin.py   |  65 ++++++++-
 colossalai/booster/plugin/torch_ddp_plugin.py |   9 +-
 .../booster/plugin/torch_fsdp_plugin.py       |  12 +-
 .../checkpoint_io/checkpoint_io_base.py       |  20 ++-
 .../checkpoint_io/general_checkpoint_io.py    |   2 +
 .../hybrid_parallel_checkpoint_io.py          |   5 +-
 colossalai/checkpoint_io/moe_checkpoint.py    |   9 +-
 colossalai/checkpoint_io/utils.py             |   8 +-
 colossalai/testing/comparison.py              |   6 +-
 colossalai/utils/safetensors.py               | 102 ++++++++++++--
 colossalai/zero/low_level/low_level_optim.py  |  33 ++++-
 .../test_low_level_zero_checkpoint_io.py      |   5 +-
 .../test_safetensors_async_io.py              | 127 ++++++++++++++++++
 15 files changed, 374 insertions(+), 46 deletions(-)
 create mode 100644 tests/test_checkpoint_io/test_safetensors_async_io.py

diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index ad4047ee2fc5..43a3b75317ba 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -359,6 +359,7 @@ def save_optimizer(
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ) -> None:
         """
         Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ def save_optimizer(
             names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+        self.checkpoint_io.save_optimizer(
+            optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
+        )

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 35c51da0105a..30c1257ef14c 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -94,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ def load_sharded_model( return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) def save_sharded_optimizer( - self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: GeminiOptimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save sharded optimizer state dict to checkpoint folder. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index d4eb1bbed75a..12ffe5fe5adb 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -24,6 +24,7 @@ get_shard_filename, load_param_groups_into_optimizer, load_shard_state_dict, + load_state_dict, load_states_into_optimizer, save_param_groups, save_state_dict, @@ -113,7 +114,9 @@ def _hook_context(self): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False + ): """Save optimizer to checkpoint but only on master process. Args: @@ -125,9 +128,34 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, # the `state_dict` in LowLevelZeroOptimizer has communication # if only the master rank collect state_dict and save, # the communication on each rank would not match - state_dict = optimizer.state_dict() + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + state_dict = optimizer.state_dict(pinned_state_dicts) if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import save_nested + + f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread") + save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]}) + self.async_writers.append(f_writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + use_async = checkpoint.endswith(".safetensors") + if use_async: + from colossalai.utils.safetensors import load_flat + + checkpoint = load_flat(checkpoint) + else: + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) def save_sharded_optimizer( self, @@ -136,6 +164,7 @@ def save_sharded_optimizer( gather_dtensor: bool = False, prefix: str = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded Zero-optimizer checkpoint under the given checkpointing path. 
@@ -161,10 +190,16 @@ def save_sharded_optimizer( # state_dict only provide only 'param_groups' state_dict = optimizer.optim.state_dict() # state shard would be handled by the low-level zero optimizer - sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard) + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) @@ -184,7 +219,18 @@ def save_sharded_optimizer( checkpoint_file_path = os.path.join(checkpoint, shard_file) if self.coordinator.is_master(): - save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import save_nested + + f_writer = AsyncFileWriter( + fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread" + ) + save_nested(f_writer, shard) + self.async_writers.append(f_writer) + else: + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) # Wrap up index file. index_file.append_meta_data("total_size", total_size) @@ -223,7 +269,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + from colossalai.utils.safetensors import load_flat + + state_dict = load_flat(shard_file) + else: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) # shard state dict for param_idx, state in state_dict.items(): for k, v in state.items(): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 09830a2f9873..07be5b0516f6 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -52,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str) assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" super().load_unsharded_optimizer(optimizer, checkpoint) - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to checkpoint but only on master process. """ @@ -113,13 +115,16 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save optimizer to sharded checkpoint but only on master process. """ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" 
if self.coordinator.is_master(): - super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + super().save_sharded_optimizer( + optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async + ) def load_sharded_optimizer( self, diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d309370dd620..b80d6d4b6eb8 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -67,7 +67,9 @@ def save_unsharded_model( full_model_state = model.state_dict() utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to checkpoint but only on master process. """ @@ -157,7 +159,13 @@ def load_sharded_model( model.unwrap().load_state_dict(fsdp_state_dict, strict=False) def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save optimizer to checkpoint but only on master process. diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 6e4681f0ec2e..9e431d3559ac 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -213,6 +213,7 @@ def save_optimizer( gather_dtensor=True, prefix: str = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. @@ -229,11 +230,12 @@ def save_optimizer( prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ - if shard: - self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + self.save_sharded_optimizer( + optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async + ) else: - self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async) # ======================================================== # Abstract methods for model loading/saving implementation @@ -326,7 +328,13 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): @abstractmethod def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save optimizer to sharded checkpoint. @@ -340,7 +348,9 @@ def save_sharded_optimizer( """ @abstractmethod - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to unsharded checkpoint. 
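
The use_async flag added above is the public entry point for asynchronous optimizer checkpointing: Booster.save_optimizer forwards it to the plugin's CheckpointIO, which copies states into pinned host buffers and hands the file writes to background writers. A minimal usage sketch, not part of this patch; the launch and model setup are assumptions, and the _sync_d2h/_sync_io flush calls mirror the updated low-level zero checkpoint test later in this diff:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes a torchrun-style launch; older versions take a config dict
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
model, optimizer, *_ = booster.boost(model, optimizer)

# One step so the optimizer actually has state to checkpoint.
out = model(torch.randn(2, 8, device="cuda", dtype=next(model.parameters()).dtype))
booster.backward(out.sum(), optimizer)
optimizer.step()

# Unsharded async saves go through the new safetensors path, so use a .safetensors target.
booster.save_optimizer(optimizer, "optimizer.safetensors", shard=False, use_async=True)

# Flush pending device-to-host copies and file writes before reading the checkpoint back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()

The same flag applies to sharded saves (shard=True), where each state shard is written through its own async writer.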
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 580be91ca0d8..a2d1dd158afa 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -98,6 +98,7 @@ def save_sharded_optimizer( gather_dtensor: bool, prefix: str, size_per_shard: int, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -155,6 +156,7 @@ def save_unsharded_optimizer( optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, + use_async: bool = False, ): # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 49d4f35f9cc0..d66171c58ccd 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -416,6 +416,7 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -725,7 +726,9 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer state dict to a file with given path. diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 4cb0f300f65e..3b07856ca06c 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -369,6 +369,7 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -729,7 +730,13 @@ def save_unsharded_model( dist.barrier() # Copied from colossalai.moe - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool, + use_async: bool = False, + ): """ Save optimizer state dict to a file with given path. 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 0b4a88e78c70..09cce3059e80 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -24,9 +24,11 @@
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"

 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file


-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)

-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)

     param_group_file = GROUP_FILE_NAME
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 8f9cce246556..4cbb01163e5a 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict
+from typing import Any, List, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -78,7 +78,9 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
-                assert v1 == v2, f"{v1} not equals to {v2}"
+                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
+                    v2 = tuple(v2)
+                assert v1 == v2, f"{v1} does not equal {v2}. {type(v1)}, {type(v2)}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 0359541147f0..ad7d3be77d72 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -1,10 +1,11 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-from safetensors.torch import _TYPES
+from safetensors.torch import _TYPES, load_file, safe_open

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -27,34 +28,93 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+def flatten_dict(nested_dict, parent_key="", separator="^"):
+    """
+    Flatten a nested dictionary into a single-level dictionary whose keys are joined by the given separator.
+
+    nested_dict: The input nested dictionary.
+    parent_key: The parent key currently being processed.
+    separator: The separator used to join keys; defaults to '^' but can be set to another symbol.
+    :return: A flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: The flattened dictionary.
+    separator: The separator used during flattening; defaults to '^' but can be set to another symbol.
+    :return: The restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{keys[0]} cannot be converted to an integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0
+
+    if metadata is not None:
+        header["__metadata__"] = metadata
+
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)

-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")

-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra

-    n = len(metadata_buf)
+    n = len(header_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
@@ -81,3 +148,16 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+
+
+def load_flat(checkpoint_path):
+    with safe_open(checkpoint_path, framework="pt") as f:
framework="pt") as f: + metadata = f.metadata() + state_dict_load = load_file(checkpoint_path) + state_dict = unflatten_dict(state_dict_load) + if metadata is None: + return state_dict + metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items())) + combined_state_dict = {"state": state_dict} + combined_state_dict.update(metadata) + return combined_state_dict diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 26fff75fbfdf..e3c301640867 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -770,7 +770,7 @@ def pack_group(group): return {"state": packed_state, "param_groups": param_groups} - def state_dict(self) -> Dict: + def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict: """Return a state_dict same with DDP Returns: @@ -779,15 +779,23 @@ def state_dict(self) -> Dict: zero_state = dict() device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): + if pinned_state_dicts and param not in pinned_state_dicts: + pinned_state_dicts[param] = {} zero_state[param] = copy.deepcopy(state) for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": + if pinned_state_dicts and k not in pinned_state_dicts[param]: + pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu") working_param = self.master_to_working_param[id(param)] pg = self.param_to_pg[working_param] gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg) - param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu() - zero_state[param][k] = param_state + param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param) + if pinned_state_dicts: + pinned_state_dicts[param][k].copy_(param_state) + zero_state[param][k] = pinned_state_dicts[param][k] + else: + zero_state[param][k] = param_state.cpu() states_dict = self._pack_state(zero_state) @@ -822,7 +830,9 @@ def load_state_dict(self, state_dict: Dict): self.optim.load_state_dict(zero_state_dict) - def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]: + def state_dict_shard( + self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None + ) -> Iterator[Tuple[Dict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Only include the 'state' in state_dict. 
@@ -847,18 +857,27 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
-
+            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts is not None:
+                        if k not in pinned_state_dicts[param_idx]:
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
+                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                        current_block[k] = pinned_state_dicts[param_idx][k]
+                    else:
+                        current_block[k] = state_tensor.cpu()
                     current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index 5e3cc2bdc6b3..05dfcce4f674 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        if not shard and use_async:
+            optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
         booster.save_model(
             model,
             model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         )

         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
         dist.barrier()
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
             assert torch.equal(
                 working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
             )
-
         new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py
new file mode 100644
index 000000000000..31c69e961e30
--- /dev/null
+++ b/tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -0,0 +1,127 @@
+import tempfile
+from copy import deepcopy
+
+import torch
+
+from colossalai.utils.safetensors import load_flat, save_nested
+
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
+
+from colossalai.testing import check_state_dict_equal
+
+
+def 
test_save_load(): + with tempfile.TemporaryDirectory() as tempdir: + optimizer_state_dict = { + 0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, + 1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, + 2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, + } + # group_dict = {"param_groups": [0, 1, 2]} + group_dict = { + "param_groups": [ + { + "lr": 0.001, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "bias_correction": True, + "params": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + ], + } + ] + } + metadata = deepcopy(group_dict) + optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" + f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread") + + save_nested(f_writer, optimizer_state_dict, metadata) + f_writer.sync_before_step() + f_writer.synchronize() + f_writer.fp.close() + + load_state_dict = load_flat(optimizer_saved_path) + state_dict = load_state_dict["state"] + group = {"param_groups": load_state_dict["param_groups"]} + check_state_dict_equal(optimizer_state_dict, state_dict) + check_state_dict_equal(group_dict, group) + + model_state_dict = { + "module.weight0": torch.rand((1024, 1024)), + "module.weight1": torch.rand((1024, 1024)), + "module.weight2": torch.rand((1024, 1024)), + } + model_saved_path = f"{tempdir}/save_model.safetensors" + f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread") + save_nested(f_writer, model_state_dict) + f_writer.sync_before_step() + f_writer.synchronize() + f_writer.fp.close() + + load_state_dict = load_flat(model_saved_path) + check_state_dict_equal(model_state_dict, load_state_dict)
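
For reference, save_nested stores a nested optimizer state dict by flattening it with the '^' separator from flatten_dict, so state[0]["exp_avg"] is written under the safetensors key "0^exp_avg", while param_groups travel as JSON strings in the safetensors metadata block; load_flat reverses both steps. A small self-contained round-trip sketch in the same style as the test above (the writer settings and tensor sizes are illustrative):

import tempfile

import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import load_flat, save_nested

# Nested optimizer-style state: top-level keys are parameter indices.
state = {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(4, 4), "exp_avg_sq": torch.rand(4, 4)}}

with tempfile.TemporaryDirectory() as tempdir:
    path = f"{tempdir}/optimizer.safetensors"
    writer = AsyncFileWriter(fp=open(path, "wb"), n_entries=191, backend="pthread")
    # Pass param_groups as metadata; note that save_nested/prepare JSON-encodes the dict in place.
    save_nested(writer, state, {"param_groups": [{"lr": 1e-3, "params": [0]}]})
    writer.sync_before_step()
    writer.synchronize()
    writer.fp.close()

    loaded = load_flat(path)
    # The flat key "0^exp_avg" is unflattened back to loaded["state"][0]["exp_avg"];
    # "param_groups" is recovered from the JSON metadata block.
    assert torch.equal(loaded["state"][0]["exp_avg"], state[0]["exp_avg"])
    assert loaded["param_groups"] == [{"lr": 1e-3, "params": [0]}]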