[checkpointio] support async save for general/zero #6141

Merged · 6 commits · Nov 19, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -117,7 +117,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
DISABLE_URING=1 pip install -v --no-cache-dir .

- name: Store TensorNVMe Cache
run: |
8 changes: 7 additions & 1 deletion colossalai/booster/booster.py
@@ -310,6 +310,7 @@ def save_model(
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""Save model to checkpoint.

@@ -324,6 +325,7 @@ def save_model(
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
"""
self.checkpoint_io.save_model(
model,
@@ -333,6 +335,7 @@ def save_model(
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
use_async=use_async,
)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
@@ -356,6 +359,7 @@ def save_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
) -> None:
"""
Save optimizer to checkpoint.
@@ -371,7 +375,9 @@ def save_optimizer(
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
self.checkpoint_io.save_optimizer(
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
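
Taken together, the booster.py changes simply thread the new flag through to the plugin's checkpoint IO. A minimal caller-side sketch follows; the toy model/optimizer, the file names, and the final synchronization step are illustrative assumptions, not code from this PR (the diff only shows that pending AsyncFileWriter objects are collected in self.async_writers):

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Assumes the usual distributed initialization has already been done,
# e.g. colossalai.launch_from_torch(...) under torchrun.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# use_async=True queues the writes on background threads instead of
# blocking the training loop on disk I/O.
booster.save_model(model, "model.safetensors", use_async=True)
booster.save_optimizer(optimizer, "optimizer.safetensors", use_async=True)

# The checkpoint IO tracks the pending writers in self.async_writers; they
# must be flushed before the files are read back (the synchronization call
# itself is not shown in this diff).
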
69 changes: 47 additions & 22 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -65,7 +65,14 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@@ -74,7 +81,10 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
@@ -84,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -112,6 +124,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save sharded model.
@@ -130,27 +143,33 @@

# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
if use_async:
super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
)

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
)

def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
):
@@ -161,7 +180,13 @@ def load_sharded_model(
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: GeminiOptimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer state dict to checkpoint folder.
74 changes: 64 additions & 10 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -24,6 +24,7 @@
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
@@ -113,7 +114,9 @@ def _hook_context(self):


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
):
"""Save optimizer to checkpoint but only on master process.

Args:
@@ -125,9 +128,34 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat

checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)

def save_sharded_optimizer(
self,
@@ -136,6 +164,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)

@@ -184,7 +219,18 @@

checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -223,7 +269,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat

state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
@@ -259,10 +310,12 @@ def load_sharded_model(
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

def save_sharded_model(
self,
@@ -272,11 +325,12 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
)

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
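
The zero plugin's async path is built from two primitives that appear verbatim in the diff above: TensorNVMe's AsyncFileWriter, which performs the writes on a background thread, and save_nested / load_flat from colossalai.utils.safetensors, which flatten the nested optimizer state into a safetensors file and rebuild it on load (selected by the .safetensors suffix). A standalone sketch of that round trip is below; the file name, the n_entries value, and the explicit flush/close are assumptions for illustration:

import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import load_flat, save_nested

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate optimizer state so there is something to serialize

state_dict = optimizer.state_dict()  # {"state": {...}, "param_groups": [...]}

# Mirrors the unsharded path above: nested state plus param_groups as metadata.
# n_entries=32 is a placeholder for N_WRITE_ENTRIES used by the plugin.
f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=32, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})

# Hypothetical flush + close; the plugin instead queues the writer in
# self.async_writers and synchronizes it later.
f_writer.synchronize()
f_writer.fp.close()

# The ".safetensors" suffix is what routes loading through load_flat above.
optimizer.load_state_dict(load_flat("optimizer.safetensors"))
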
26 changes: 21 additions & 5 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -33,13 +33,17 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
@@ -48,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str)
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -71,14 +77,21 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
)

def load_sharded_model(
@@ -102,13 +115,16 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(
optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def load_sharded_optimizer(
self,
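
Across the plugins touched here the pattern is identical: rank handling is untouched and use_async is simply forwarded to the underlying CheckpointIO, which queues one AsyncFileWriter per file. A hedged round-trip check of the public API is sketched below; _sync_io() is a hypothetical name for whatever flushes the writers queued in self.async_writers, which this diff does not show:

# Continuing the booster sketch above (illustrative, not from this PR's tests).
booster.save_optimizer(optimizer, "optimizer.safetensors", use_async=True)
booster.checkpoint_io._sync_io()  # hypothetical: wait for all queued AsyncFileWriter objects

# Loading detects the .safetensors suffix and goes through load_flat (see the zero plugin above).
booster.load_optimizer(optimizer, "optimizer.safetensors")
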