Skip to content

Commit

Permalink
3d-asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Nov 19, 2024
1 parent 12e1e8a commit 2821dc4
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 150 deletions.
172 changes: 134 additions & 38 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import flatten_dict, load_flat, move_and_save

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
Expand Down Expand Up @@ -224,7 +225,18 @@ def save_sharded_model(
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
Expand Down Expand Up @@ -259,24 +271,29 @@ def save_sharded_model(
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
Expand Down Expand Up @@ -455,19 +472,33 @@ def save_sharded_optimizer(
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0

if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)

if control_saving:
# Store param groups.
Expand Down Expand Up @@ -502,14 +533,28 @@ def save_sharded_optimizer(
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)

if control_saving:
assert (
Expand Down Expand Up @@ -620,9 +665,11 @@ def _get_param_id_from_optimizer_param(
# If this param's states has been loaded before, directly return.
if filename in loaded_file:
continue

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if file_path.endswith(".safetensors"):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)

Expand Down Expand Up @@ -672,7 +719,16 @@ def save_unsharded_model(
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from tensornvme.async_file_io import AsyncFileWriter

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(f_writer, state_dict=state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)])
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
Expand All @@ -688,13 +744,13 @@ def save_unsharded_model(
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
move_and_save(
writer, state_dict=complete_state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)]
)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

Expand All @@ -720,6 +776,7 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
# Load from checkpoint. Since the logic of breaking parameter shards along tp degree
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.

state_dict = load_state_dict(checkpoint)
model.load_state_dict(state_dict, strict=strict)

Expand Down Expand Up @@ -778,7 +835,25 @@ def save_unsharded_optimizer(
]
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

flatten_state_dict = flatten_dict(state_dict["state"])
if use_async and id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(
f_writer,
state_dict=flatten_state_dict,
metadata={"param_groups": state_dict["param_groups"]},
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
Expand All @@ -794,7 +869,25 @@ def save_unsharded_optimizer(
state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

flatten_state_dict = flatten_dict(state_dict["state"])
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(
f_writer,
state_dict=flatten_state_dict,
metadata={"param_groups": state_dict["param_groups"]},
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Expand All @@ -820,7 +913,10 @@ def _get_param_id_from_optimizer_param(
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)

# Load param_groups.
updated_groups = []
Expand Down
9 changes: 5 additions & 4 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import flatten_dict, move_and_save

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
Expand Down Expand Up @@ -314,14 +314,15 @@ def async_save_state_dict_shards(
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
writers.append(writer)

flatten_dicts = flatten_dict(shard)
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in flatten_dicts.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
sub_pinned_state_dict = create_pinned_state_dict(flatten_dicts)
returned_state_dict.update(sub_pinned_state_dict)

# Only save on master rank.
move_and_save(writer, shard, sub_pinned_state_dict)
move_and_save(writer, state_dict=flatten_dicts, state_dict_pinned=sub_pinned_state_dict)
shard_filenames.append(shard_file)
del shard

Expand Down
4 changes: 3 additions & 1 deletion colossalai/utils/safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def prepare(
data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
if metadata is not None:
print("metadata", metadata)
assert isinstance(metadata, dict)
for k, v in metadata.items():
metadata[k] = json.dumps(v)
Expand Down Expand Up @@ -134,9 +135,10 @@ def save_nested(
def move_and_save(
f_writer: AsyncFileWriter,
state_dict: Dict[str, torch.Tensor],
metadata: Optional[Dict[str, str]] = None,
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

f_writer.write(n.to_bytes(8, byteorder="little"))
Expand Down
Loading

0 comments on commit 2821dc4

Please sign in to comment.