Commit

Use fsspec for multihost checkpoint (#6818)
jonb377 authored Mar 26, 2024
1 parent 73c31db commit 1ad6bb4
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -16,6 +16,7 @@
 import torch_xla.runtime as xr
 import torch_xla.distributed.spmd as xs
 
+from torch.distributed.checkpoint._fsspec_filesystem import *
 from torch.distributed.checkpoint.default_planner import (
     create_default_local_save_plan,
     create_default_global_save_plan,
@@ -75,6 +76,8 @@ def _save_and_restore(self,
                         model_out,
                         save_planner=None,
                         load_planner=None,
+                        storage_writer_cls=dist_cp.FileSystemWriter,
+                        storage_reader_cls=dist_cp.FileSystemReader,
                         is_sharded_cpu_state_dict=False,
                         chkpt_path=None):
     """
@@ -91,8 +94,9 @@
     model_out_state_dict = model_out.state_dict()
     dist_cp.save(
         state_dict=model_in_state_dict,
-        storage_writer=dist_cp.FileSystemWriter(
+        storage_writer=storage_writer_cls(
             chkpt_path,
+            sync_files=False,
             per_thread_copy_ahead=0,
         ),
         planner=save_planner,
@@ -103,7 +107,7 @@
 
     dist_cp.load(
         state_dict=model_out_state_dict,
-        storage_reader=dist_cp.FileSystemReader(chkpt_path),
+        storage_reader=storage_reader_cls(chkpt_path),
         planner=load_planner,
     )
     for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
@@ -156,6 +160,8 @@ def test_multihost_checkpoint(self):
         model2,
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner(),
+        storage_writer_cls=FsspecWriter,
+        storage_reader_cls=FsspecReader,
         chkpt_path=os.environ['CHKPT_PATH'])
 
     # Destroy the CPU process group after the test
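For context, here is a minimal single-process sketch of the fsspec checkpoint path this change exercises. It is not part of the commit: the gs:// bucket URL, the toy model, and the gcsfs backend are illustrative assumptions. FsspecWriter and FsspecReader are the classes the diff pulls in from torch.distributed.checkpoint._fsspec_filesystem, and the sync_files/per_thread_copy_ahead kwargs mirror the ones the test passes.

    # Sketch only (not from the commit): save/load a checkpoint via an fsspec URL.
    # Assumes an fsspec backend for the target scheme is installed
    # (e.g. gcsfs for gs:// paths) and that the path below is writable.
    import torch
    import torch.distributed.checkpoint as dist_cp
    from torch.distributed.checkpoint._fsspec_filesystem import (
        FsspecReader,
        FsspecWriter,
    )

    # Hypothetical destination; the multihost test instead reads
    # os.environ['CHKPT_PATH'].
    CHKPT_PATH = "gs://my-bucket/ckpt"

    model = torch.nn.Linear(8, 8)
    state_dict = {"model": model.state_dict()}

    # Save with the same writer kwargs as the test: sync_files=False skips
    # fsync-ing written files (remote fsspec file objects generally don't
    # support it), and per_thread_copy_ahead=0 disables copy-ahead buffering.
    dist_cp.save(
        state_dict=state_dict,
        storage_writer=FsspecWriter(
            CHKPT_PATH,
            sync_files=False,
            per_thread_copy_ahead=0,
        ),
    )

    # Restore in place into the existing state_dict, then into the model.
    dist_cp.load(
        state_dict=state_dict,
        storage_reader=FsspecReader(CHKPT_PATH),
    )
    model.load_state_dict(state_dict["model"])

In the test itself, the new storage_writer_cls/storage_reader_cls parameters thread these classes through _save_and_restore, and test_multihost_checkpoint points them at the path in CHKPT_PATH; the sketch omits the SPMDSavePlanner/SPMDLoadPlanner planners to stay self-contained.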
