diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 2cddc68c29a..a78057210ab 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -16,6 +16,7 @@
 import torch_xla.runtime as xr
 import torch_xla.distributed.spmd as xs
 
+from torch.distributed.checkpoint._fsspec_filesystem import *
 from torch.distributed.checkpoint.default_planner import (
     create_default_local_save_plan,
     create_default_global_save_plan,
@@ -75,6 +76,8 @@ def _save_and_restore(self,
                         model_out,
                         save_planner=None,
                         load_planner=None,
+                        storage_writer_cls=dist_cp.FileSystemWriter,
+                        storage_reader_cls=dist_cp.FileSystemReader,
                         is_sharded_cpu_state_dict=False,
                         chkpt_path=None):
     """
@@ -91,8 +94,9 @@ def _save_and_restore(self,
     model_out_state_dict = model_out.state_dict()
     dist_cp.save(
         state_dict=model_in_state_dict,
-        storage_writer=dist_cp.FileSystemWriter(
+        storage_writer=storage_writer_cls(
             chkpt_path,
+            sync_files=False,
             per_thread_copy_ahead=0,
         ),
         planner=save_planner,
@@ -103,7 +107,7 @@ def _save_and_restore(self,
 
     dist_cp.load(
         state_dict=model_out_state_dict,
-        storage_reader=dist_cp.FileSystemReader(chkpt_path),
+        storage_reader=storage_reader_cls(chkpt_path),
         planner=load_planner,
     )
     for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
@@ -156,6 +160,8 @@ def test_multihost_checkpoint(self):
         model2,
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner(),
+        storage_writer_cls=FsspecWriter,
+        storage_reader_cls=FsspecReader,
         chkpt_path=os.environ['CHKPT_PATH'])
 
   # Destroy the CPU process group after the test
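For context, here is a minimal standalone sketch of the save/restore path the parameterized test now exercises when the fsspec backend is selected. The `FsspecWriter`/`FsspecReader` classes, import path, and writer kwargs come from the diff; the toy model and local checkpoint path are illustrative assumptions, not part of the change (any fsspec-supported URL would work the same way).

```python
# Illustrative sketch only: the toy model and path are assumptions; the
# storage classes and kwargs mirror what the test passes in the diff above.
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint._fsspec_filesystem import (
    FsspecReader,
    FsspecWriter,
)

model_in = torch.nn.Linear(4, 4)
model_out = torch.nn.Linear(4, 4)
chkpt_path = "/tmp/chkpt"  # any fsspec-supported path works here

# Save through the fsspec-backed writer, with the same kwargs the test
# uses: sync_files=False skips file-sync calls on write, and
# per_thread_copy_ahead=0 disables copy-ahead buffering.
dist_cp.save(
    state_dict=model_in.state_dict(),
    storage_writer=FsspecWriter(
        chkpt_path,
        sync_files=False,
        per_thread_copy_ahead=0,
    ),
)

# Load in place into model_out's state dict, then apply it.
state_dict = model_out.state_dict()
dist_cp.load(
    state_dict=state_dict,
    storage_reader=FsspecReader(chkpt_path),
)
model_out.load_state_dict(state_dict)
```

Because `_save_and_restore` takes the writer/reader classes as parameters, the same assertion logic runs unchanged against either the local `dist_cp.FileSystemWriter`/`FileSystemReader` pair (the defaults) or the fsspec pair used in `test_multihost_checkpoint`.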