Skip to content

Commit

Permalink
support asyncio for 3d
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Nov 26, 2024
1 parent ab856fd commit 29eb97e
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 49 deletions.
4 changes: 2 additions & 2 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict_shards,
async_move_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
Expand Down Expand Up @@ -189,7 +189,7 @@ def save_sharded_model(

if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
Expand Down
Loading

0 comments on commit 29eb97e

Please sign in to comment.