From 85c66f5fb1d2d1c032672710316d2b76c0513bd8 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 31 Jan 2024 19:00:55 -0300 Subject: [PATCH] Fix distributed checkpoint errors when PyTorch CUDA is enabled. (#6421) --- test/spmd/test_xla_distributed_checkpoint.py | 5 ++++- torch_xla/experimental/distributed_checkpoint/manager.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 535dce59630..d8d19ad7c9e 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -93,7 +93,10 @@ def _save_and_restore(self, model_out_state_dict = model_out.state_dict() dist_cp.save_state_dict( state_dict=model_in_state_dict, - storage_writer=dist_cp.FileSystemWriter(chkpt_path), + storage_writer=dist_cp.FileSystemWriter( + chkpt_path, + per_thread_copy_ahead=0, + ), planner=save_planner, no_dist=no_dist, ) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 89bb20f5076..13b6abfacfc 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -220,7 +220,10 @@ def _save(self, step, state_dict): self._delete_chkpt_at_step(step) dist_cp.save_state_dict( state_dict=state_dict, - storage_writer=FsspecWriter(path), + storage_writer=FsspecWriter( + path, + per_thread_copy_ahead=0, + ), planner=xc.SPMDSavePlanner(), process_group=self.pg, )