Commit
Fix distributed checkpoint errors when PyTorch CUDA is enabled. (#6421)
ysiraichi authored Jan 31, 2024
1 parent 010b6f0 commit 85c66f5
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -93,7 +93,10 @@ def _save_and_restore(self,
     model_out_state_dict = model_out.state_dict()
     dist_cp.save_state_dict(
         state_dict=model_in_state_dict,
-        storage_writer=dist_cp.FileSystemWriter(chkpt_path),
+        storage_writer=dist_cp.FileSystemWriter(
+            chkpt_path,
+            per_thread_copy_ahead=0,
+        ),
         planner=save_planner,
         no_dist=no_dist,
     )
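Both hunks make the same change: the storage writer is constructed with per_thread_copy_ahead=0, disabling the writer's per-thread copy-ahead staging of tensor data, which is presumably what misbehaves when PyTorch is built with CUDA but the tensors being checkpointed are XLA tensors. A minimal sketch of the fixed call pattern follows; the toy model, local path, and single-process (no_dist=True) save are illustrative, not part of the commit:

    import torch
    import torch.distributed.checkpoint as dist_cp

    # Illustrative model and checkpoint path -- not from the commit.
    model = torch.nn.Linear(4, 4)
    chkpt_path = "/tmp/chkpt"

    dist_cp.save_state_dict(
        state_dict=model.state_dict(),
        storage_writer=dist_cp.FileSystemWriter(
            chkpt_path,
            # Disable copy-ahead staging, matching the commit's fix.
            per_thread_copy_ahead=0,
        ),
        no_dist=True,  # single-process save, mirroring the updated test
    )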
5 changes: 4 additions & 1 deletion torch_xla/experimental/distributed_checkpoint/manager.py
@@ -220,7 +220,10 @@ def _save(self, step, state_dict):
     self._delete_chkpt_at_step(step)
     dist_cp.save_state_dict(
         state_dict=state_dict,
-        storage_writer=FsspecWriter(path),
+        storage_writer=FsspecWriter(
+            path,
+            per_thread_copy_ahead=0,
+        ),
         planner=xc.SPMDSavePlanner(),
         process_group=self.pg,
     )
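The manager-side hunk threads the same keyword through FsspecWriter (which, as the hunk shows, accepts the same per_thread_copy_ahead argument) inside CheckpointManager._save, so checkpoints taken through the manager also skip copy-ahead. A hedged sketch of driving that code through the public API, assuming the documented CheckpointManager(path, save_interval) constructor and that save() checkpoints only on interval boundaries; the path, interval, and state_dict are illustrative:

    import torch
    from torch_xla.experimental.distributed_checkpoint import CheckpointManager

    # Illustrative path, interval, and model -- not from the commit.
    model = torch.nn.Linear(4, 4)
    mgr = CheckpointManager("/tmp/chkpts", save_interval=10)

    for step in range(100):
        # ... training step ...
        # save() routes through the patched _save() shown above and, per the
        # torch_xla docs, only takes a checkpoint on save_interval boundaries.
        mgr.save(step, {"model": model.state_dict()})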
