Skip to content

Commit

Permalink
Address deprecation in torch.distributed.checkpoint (#6786)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored Mar 21, 2024
1 parent c923e8f commit d987775
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 17 deletions.
6 changes: 3 additions & 3 deletions docs/spmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ train_loader = pl.MpDeviceLoader(

PyTorch/XLA SPMD is compatible with the [torch.distributed.checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) library through a dedicated `Planner` instance. Users are able to synchronously save and load checkpoints through this common interface.

The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save_state_dict` and `load_state_dict` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.
The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save` and `load` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.

Here is a demonstration of the synchronous distributed checkpointing API:

Expand All @@ -249,7 +249,7 @@ state_dict = {
"optim": optim.state_dict(),
}

dist_cp.save_state_dict(
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
Expand All @@ -262,7 +262,7 @@ state_dict = {
"model": model.state_dict(),
}

dist_cp.load_state_dict(
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
Expand Down
16 changes: 5 additions & 11 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def _save_and_restore(self,
save_planner=None,
load_planner=None,
is_sharded_cpu_state_dict=False,
no_dist=True,
chkpt_path=None):
"""
Checkpoint model_in using the provided save_planner and load into model_out
Expand All @@ -90,24 +89,22 @@ def _save_and_restore(self,
if is_sharded_cpu_state_dict:
model_in_state_dict = _sharded_cpu_state_dict(model_in_state_dict)
model_out_state_dict = model_out.state_dict()
dist_cp.save_state_dict(
dist_cp.save(
state_dict=model_in_state_dict,
storage_writer=dist_cp.FileSystemWriter(
chkpt_path,
per_thread_copy_ahead=0,
),
planner=save_planner,
no_dist=no_dist,
)
# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertFalse(torch.allclose(p1.cpu(), p2.cpu()))

dist_cp.load_state_dict(
dist_cp.load(
state_dict=model_out_state_dict,
storage_reader=dist_cp.FileSystemReader(chkpt_path),
planner=load_planner,
no_dist=no_dist,
)
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertTrue(torch.allclose(p1.cpu(), p2.cpu()))
Expand Down Expand Up @@ -142,15 +139,13 @@ def test_resharding_different_device_mesh(self):
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipUnless(
{'CHKPT_PATH', 'MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE'
} <= os.environ.keys(),
'CHKPT_PATH and distributed config must be set for multihost checkpoint')
@unittest.skipUnless('CHKPT_PATH' in os.environ,
'CHKPT_PATH must be set for multihost checkpoint')
def test_multihost_checkpoint(self):
torch.manual_seed(42)

# Initialize the default CPU process group from the environment.
dist.init_process_group()
dist.init_process_group(backend='gloo', init_method='xla://')

model1 = self._get_sharded_model(mesh_shape=(1, self.n_devices))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices, 1))
Expand All @@ -160,7 +155,6 @@ def test_multihost_checkpoint(self):
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner(),
no_dist=False,
chkpt_path=os.environ['CHKPT_PATH'])

# Destroy the CPU process group after the test
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _save(self, step, state_dict):
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
dist_cp.save(
state_dict=state_dict,
storage_writer=FsspecWriter(
path,
Expand All @@ -244,7 +244,7 @@ def should_save(self, step: int) -> bool:
"""
preemption_detected = False
if self.chkpt_on_preemption and self.reached_preemption(step):
logging.warn(
logging.warning(
f"Preemption sync point reached at step {step}. Triggering a checkpoint."
)
preemption_detected = True
Expand Down Expand Up @@ -319,7 +319,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
tracked_steps = set(x.step for x in self._tracked_chkpts)
assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
path = self._get_path(step)
dist_cp.load_state_dict(
dist_cp.load(
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
Expand Down

0 comments on commit d987775

Please sign in to comment.