diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 535dce59630..d8d19ad7c9e 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -93,7 +93,10 @@ def _save_and_restore(self, model_out_state_dict = model_out.state_dict() dist_cp.save_state_dict( state_dict=model_in_state_dict, - storage_writer=dist_cp.FileSystemWriter(chkpt_path), + storage_writer=dist_cp.FileSystemWriter( + chkpt_path, + per_thread_copy_ahead=0, + ), planner=save_planner, no_dist=no_dist, ) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 89bb20f5076..13b6abfacfc 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -220,7 +220,10 @@ def _save(self, step, state_dict): self._delete_chkpt_at_step(step) dist_cp.save_state_dict( state_dict=state_dict, - storage_writer=FsspecWriter(path), + storage_writer=FsspecWriter( + path, + per_thread_copy_ahead=0, + ), planner=xc.SPMDSavePlanner(), process_group=self.pg, )