[async io] support async io (hpcaitech#6137)
* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
flybird11111 authored Nov 18, 2024
1 parent 9664ed7 commit 12e1e8a
Showing 15 changed files with 374 additions and 46 deletions.
5 changes: 4 additions & 1 deletion colossalai/booster/booster.py
@@ -359,6 +359,7 @@ def save_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
) -> None:
"""
Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ def save_optimizer(
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
self.checkpoint_io.save_optimizer(
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
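For context, a minimal usage sketch of the new flag at the Booster level. The launch call, plugin choice, model, and optimizer below are placeholder assumptions; only the `use_async` keyword comes from this diff:

```python
# Hypothetical setup; only booster.save_optimizer(..., use_async=True) is what
# this commit adds. The plugin/model/optimizer choices are placeholders.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes the script is started with torchrun
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# New in this commit: use_async is forwarded to the plugin's CheckpointIO, which
# streams the optimizer state to a .safetensors file through background writers.
booster.save_optimizer(optimizer, "optimizer.safetensors", shard=False, use_async=True)

# The checkpoint may still be in flight at this point; the CheckpointIO's async
# writers need to be flushed before exit (the flush hook is not part of this diff).
```

Note that in this diff only the Low-Level ZeRO plugin's hunks show an actual asynchronous write path (pinned state dicts plus `AsyncFileWriter`); the other plugins' save methods are shown accepting and forwarding the new flag.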
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -94,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ def load_sharded_model(
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: GeminiOptimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer state dict to checkpoint folder.
65 changes: 58 additions & 7 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -24,6 +24,7 @@
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
@@ -113,7 +114,9 @@ def _hook_context(self):


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
):
"""Save optimizer to checkpoint but only on master process.
Args:
@@ -125,9 +128,34 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
# building the `state_dict` in LowLevelZeroOptimizer involves communication;
# if only the master rank collected the state_dict and saved it,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat

checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)

def save_sharded_optimizer(
self,
@@ -136,6 +164,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@
# state_dict provides only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)

@@ -184,7 +219,18 @@

checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -223,7 +269,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat

state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
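The Low-Level ZeRO hunks above all follow one pattern: reuse a per-optimizer cache of pinned host buffers (`self.pinned_state_dicts[id(optimizer)]`), pass it into `state_dict()` / `state_dict_shard()`, stream the result into a `.safetensors` file with tensornvme's `AsyncFileWriter` and `save_nested`, and on load pick the reader by file extension. A standalone sketch of that round trip, reusing only the calls visible in this diff; the tensors, file name, and `n_entries=32` are placeholders, and the flush/synchronize step is left as a comment because it is handled elsewhere by the CheckpointIO:

```python
import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import load_flat, save_nested

# A toy optimizer-style state dict: per-parameter states plus param_groups metadata.
state_dict = {
    "state": {0: {"exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}

# Asynchronous save: the writer streams in a background pthread, so it is kept
# alive in a list, mirroring self.async_writers.append(f_writer) in the plugin.
async_writers = []
f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=32, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
async_writers.append(f_writer)

# ... the CheckpointIO flushes/synchronizes its async writers before the file is read ...

# Load: a ".safetensors" suffix selects the flat safetensors reader; otherwise the
# plugin falls back to its regular load_state_dict helper (torch.load stands in here).
checkpoint = "optimizer.safetensors"
if checkpoint.endswith(".safetensors"):
    loaded = load_flat(checkpoint)
else:
    loaded = torch.load(checkpoint)
```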
9 changes: 7 additions & 2 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -52,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str)
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -113,13 +115,16 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(
optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def load_sharded_optimizer(
self,
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -67,7 +67,9 @@ def save_unsharded_model(
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -157,7 +159,13 @@ def load_sharded_model(
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to checkpoint but only on master process.
20 changes: 15 additions & 5 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -213,6 +213,7 @@ def save_optimizer(
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to checkpoint. Saving optimizer states is not compatible with safetensors.
@@ -229,11 +230,12 @@ def save_optimizer(
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""

if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
self.save_sharded_optimizer(
optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
else:
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

# ========================================================
# Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):

@abstractmethod
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ def save_sharded_optimizer(
"""

@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to unsharded checkpoint.
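Since `use_async` is now part of the abstract `save_sharded_optimizer` / `save_unsharded_optimizer` signatures, custom CheckpointIO implementations have to accept the keyword even if they ignore it. A minimal sketch under that assumption; the class name is illustrative and the import path for `GeneralCheckpointIO` is assumed:

```python
from pathlib import Path

import torch
from torch.optim import Optimizer

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed import path


class BlockingCheckpointIO(GeneralCheckpointIO):
    """Accepts the new use_async keyword but always saves synchronously."""

    def save_unsharded_optimizer(
        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
    ):
        # use_async is deliberately ignored here; an async-capable override would
        # queue a background writer and register it (as the ZeRO plugin does above).
        torch.save(optimizer.state_dict(), checkpoint)
```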
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -98,6 +98,7 @@ def save_sharded_optimizer(
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ def save_unsharded_optimizer(
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
5 changes: 4 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -416,6 +416,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer state dict to a file with given path.
9 changes: 8 additions & 1 deletion colossalai/checkpoint_io/moe_checkpoint.py
@@ -369,6 +369,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ def save_unsharded_model(
dist.barrier()

# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool,
use_async: bool = False,
):
"""
Save optimizer state dict to a file with given path.
8 changes: 5 additions & 3 deletions colossalai/checkpoint_io/utils.py
@@ -24,9 +24,11 @@
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_STATE_NAME = "optimizer.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"

# ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
return weights_name, save_index_file


def get_optimizer_base_filenames(prefix: str = None):
def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
states_name = add_prefix(states_name, prefix)

save_index_file = STATES_INDEX_NAME
save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)

param_group_file = GROUP_FILE_NAME
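With `use_safetensors=True` the optimizer state files and index switch from the `pytorch_optim.bin` family to the `optimizer.safetensors` family, while the param-group file stays a pickle. A quick sanity check of the prefix-less case, assuming `add_prefix` returns the names unchanged when no prefix is given:

```python
from colossalai.checkpoint_io.utils import get_optimizer_base_filenames

# Legacy (.bin) naming, used for synchronous saves.
states, index, groups = get_optimizer_base_filenames()
assert (states, index) == ("pytorch_optim.bin", "pytorch_optim.bin.index.json")

# Safetensors naming, selected when use_async=True sets use_safetensors=True.
states, index, groups = get_optimizer_base_filenames(use_safetensors=True)
assert (states, index) == ("optimizer.safetensors", "optimizer.safetensors.index.json")
assert groups == "pytorch_optim_group.bin"
```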
6 changes: 4 additions & 2 deletions colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict
from typing import Any, List, OrderedDict, Tuple

import torch
import torch.distributed as dist
@@ -78,7 +78,9 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} does not equal {v2}"
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
v2 = tuple(v2)
assert v1 == v2, f"{v1} does not equal {v2}. {type(v1)}, {type(v2)}"


def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
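The tuple coercion added to `check_state_dict_equal` is there because values that are tuples in memory (e.g. an Adam param group's `betas`) can come back as lists after a save/load round trip; coercing before comparing keeps the check about values rather than container types. A minimal illustration:

```python
# Illustration only: the same param-group entry before and after a round trip.
v1 = (0.9, 0.999)   # in-memory value, e.g. an Adam param group's betas
v2 = [0.9, 0.999]   # the same value read back from a serialized checkpoint

if isinstance(v1, tuple) and not isinstance(v2, tuple):
    v2 = tuple(v2)

assert v1 == v2, f"{v1} does not equal {v2}. {type(v1)}, {type(v2)}"
```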
(Remaining file diffs not loaded.)
