Add GKE support and various usability improvements in CheckpointManager (#5770)

* Add GKE support and various usability improvements in CheckpointManager

* Bug fix for async checkpointing fully sharded state dicts
jonb377 authored Nov 14, 2023
1 parent b3ed82c commit 5710a83
Showing 4 changed files with 97 additions and 66 deletions.
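
For orientation, here is a hedged usage sketch of the manager as it looks after this commit, based only on the constructor and method signatures visible in the manager.py diff below. The bucket path, model, loop bounds, and the gloo/`xla://` process-group setup are illustrative assumptions, not part of this commit.

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
# Assumes the package re-exports CheckpointManager; adjust the import if not.
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# The manager asserts dist.is_initialized(); a CPU-capable group such as gloo
# is one common choice (assumption, not taken from this diff).
dist.init_process_group(backend='gloo', init_method='xla://')

chkpt_mgr = CheckpointManager(
    path='gs://my-bucket/checkpoints',  # any fsspec-compatible base path
    save_interval=100,                  # checkpoint every 100 steps
    max_to_keep=3,                      # retain only the newest 3 checkpoints
    max_pending_async=1)                # at most one async checkpoint in flight

model = torch.nn.Linear(128, 128).to(xm.xla_device())
state_dict = {'model': model.state_dict()}

for step in range(1000):
    # ... training step ...
    # save_async moves shards to CPU up front and writes on a background
    # worker, so training is only blocked for the device-to-host transfer.
    chkpt_mgr.save_async(step, state_dict)

chkpt_mgr.join()  # block until all pending async checkpoints are written
```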
21 changes: 18 additions & 3 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -273,6 +273,20 @@ def test_save_state_dict_with_cpu_shards(self):
self.assertTrue(
isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for sharded test")
def test_cpu_state_dict_flattening(self):
# In the case of a nested state_dict with fully sharded parameters,
# _CpuShards should be treated as terminal nodes.
t = torch.randn(128, 128).to(xm.xla_device())
mesh = self._get_mesh((self.n_devices, 1))
xs.mark_sharding(t, mesh, (0, 1))
state_dict = _sharded_cpu_state_dict({'model': {'weight': t}})
planner = SPMDSavePlanner()
planner.set_up_planner(state_dict, True)
# model.weight should be flattened and tracked in the sharded state dict.
self.assertCountEqual(planner.sharded_state_dict, ["model.weight"])

def test_local_save_plan(self):

def _write_item_assertions(plan, n_devices, parameter_count):
@@ -433,13 +447,14 @@ def test_manager_async(self, tmpdir):

# Patch the manager's save method to block until this thread signals.
cond = threading.Condition()
old_save = chkpt_mgr.save
old_save = chkpt_mgr._save

def patched_save(*args, **kwargs):
cond.wait()
with cond:
cond.wait()
old_save(*args, **kwargs)

with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save):
with unittest.mock.patch.object(chkpt_mgr, '_save', patched_save):
chkpt_mgr.save_async(10, state_dict)

# No new steps should be tracked immediately after calling save_async
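The test fix above wraps `cond.wait()` in `with cond:` because `threading.Condition.wait()` must be called with the condition's lock held; otherwise it raises "cannot wait on un-acquired lock". Below is a self-contained sketch of the corrected pattern, unrelated to this repository; it uses `wait_for` with a flag only so the toy example cannot miss the notification.

```python
import threading

cond = threading.Condition()
ready = False
results = []

def worker():
    with cond:                        # acquire the lock before waiting
        cond.wait_for(lambda: ready)  # releases the lock while blocked
    results.append('done')

t = threading.Thread(target=worker)
t.start()

with cond:         # notify() likewise requires holding the lock
    ready = True
    cond.notify()

t.join()
assert results == ['done']
```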
5 changes: 4 additions & 1 deletion torch_xla/_internal/tpu.py
@@ -5,6 +5,7 @@
import os
import pathlib
import re
import socket
from typing import NamedTuple, Optional, List
from typing_extensions import TypedDict
import requests
@@ -299,10 +300,12 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
return worker_ips[master_worker_id]


def _spmd_find_master_ip(current_worker_ip: str) -> str:
def _spmd_find_master_ip(current_worker_hostname: str) -> str:
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
# Translate the hostname to an IP address, e.g. for TPUs on GKE.
current_worker_ip = socket.gethostbyname(current_worker_hostname)
ip_int = int(ip_address(current_worker_ip))
n_dev = xr.global_runtime_device_count()
local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
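A minimal sketch of the hostname handling added above: the worker hostname (e.g. a GKE pod hostname) is resolved to an IP address, which `_spmd_find_master_ip` then packs into an integer so it can be broadcast through a sharded tensor. `'localhost'` here is only a placeholder hostname.

```python
import socket
from ipaddress import ip_address

hostname = 'localhost'                    # placeholder for a TPU worker hostname
ip_str = socket.gethostbyname(hostname)   # e.g. '127.0.0.1'
ip_int = int(ip_address(ip_str))          # 2130706433; fits in a scalar tensor
assert str(ip_address(ip_int)) == ip_str  # round-trips back to the dotted form
```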
7 changes: 6 additions & 1 deletion torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -34,8 +34,13 @@
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]


# TODO(jonbolin): Logic here is modified from the upstream to enable async
# checkpointing. If the state_dict is comprised entirely of _CpuShards,
# flatten_state_dict will not actually flatten the dict.
# Once we can represent XLAShardedTensor on CPU, either directly or through
# DistributedTensor, we can reuse the upstream logic.
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
return isinstance(value, torch.Tensor)
return isinstance(value, torch.Tensor) or isinstance(value, _CpuShards)


def _traverse_state_dict(
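A simplified, standalone illustration (not the library's implementation) of the behavior the TODO above is after: when shard containers are treated as terminal leaves alongside tensors, a nested state_dict still flattens to dotted keys such as `model.weight`, which is what the new `test_cpu_state_dict_flattening` checks. `FakeCpuShards` is a hypothetical stand-in for `_CpuShards`.

```python
from typing import Any, Dict

class FakeCpuShards:  # hypothetical stand-in for _CpuShards
    pass

def flatten(state_dict: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
    flat = {}
    for key, value in state_dict.items():
        path = f'{prefix}{key}'
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f'{path}.'))
        else:
            flat[path] = value  # tensors and shard wrappers are leaves
    return flat

nested = {'model': {'weight': FakeCpuShards()}}
assert list(flatten(nested)) == ['model.weight']
```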
130 changes: 69 additions & 61 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -2,11 +2,11 @@
import logging
import os
import pickle
import queue
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback
@@ -16,6 +16,7 @@
from collections import deque
from fsspec.core import url_to_fs
from os.path import basename
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from ._helpers import _sharded_cpu_state_dict
@@ -81,7 +82,7 @@ class CheckpointManager:
step_period, as would be the case in auto checkpointing.
This class is inspired by Orbax's CheckpointManager, which can be found here:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
https://github.com/google/orbax/blob/efc079c/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

# The base path to write checkpoints to. Each checkpoint taken by the manager
@@ -102,7 +103,7 @@ def __init__(self,
path: str,
save_interval: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1,
max_pending_async: Optional[int] = 1,
process_group: dist.ProcessGroup = None,
chkpt_on_preemption: bool = True):
"""
@@ -116,11 +117,11 @@ def __init__(self,
CheckpointManager. When a new checkpoint will be taken, the
checkpoint for the lowest tracked step will be deleted.
Default: 0, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
max_pending_async: The maximum number of async checkpoints which can be
pending. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
the value to 2 can unblock training when there are transient
network issues which slow down the active checkpoint.
the value can unblock training when there are transient issues which
slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
process_group: The process group to use when coordinating the checkpoint.
@@ -132,31 +133,33 @@
"""
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
assert async_queue_size > 0, "async_queue_size must be positive"
assert max_pending_async > 0, "max_pending_async must be positive"
assert max_to_keep >= 0, "max_to_keep must be non-negative"

self.base_path = path
self.base_path = os.path.join(path, '') # Ensure the base path ends in '/'
self.save_interval = save_interval
self.max_to_keep = max_to_keep
self.chkpt_on_preemption = chkpt_on_preemption

self._tracked_chkpts = self._load_tracked_chkpts()
self._async_queue = queue.Queue(maxsize=async_queue_size)
self._alive = threading.Event()
self._alive.set()
self._chkpt_thread = threading.Thread(
target=self._async_worker, daemon=True)
self._chkpt_thread.start()

# Create a new group if none is provided
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

# Thread pool to run the async checkpoints. `_async_sem` is used to guard
# the number of pending checkpoints, and `_async_futures` tracks all
# futures returned by the pool.
self._async_worker_pool = ThreadPoolExecutor(max_workers=1)
self._async_sem = threading.Semaphore(max_pending_async)
self._async_futures = []
# Mutex to ensure only a single thread can write a checkpoint at a time.
self._save_mutex = threading.Lock()

self._tracked_chkpts = self._load_tracked_chkpts()

if self.chkpt_on_preemption:
# Initialize the distributed runtime for preemption detection
master_ip = xr.get_master_ip()
torch_xla._XLAC._ensure_xla_coordinator_initialized(
xr.process_index(), xr.process_count(), master_ip)
xr.process_index(), xr.process_count(), xr.get_master_ip())
torch_xla._XLAC._activate_preemption_sync_manager()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
Expand All @@ -166,36 +169,20 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
all_chkpts = []
invalid_paths = []
fs, raw_path = url_to_fs(self.base_path)
for path in fs.ls(raw_path, detail=False):
try:
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f:
all_chkpts.append(pickle.load(f))
except:
invalid_paths.append(path)
if not fs.exists(raw_path):
fs.mkdir(raw_path)
else:
for path in fs.ls(raw_path, detail=False):
try:
with fs.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f:
all_chkpts.append(pickle.load(f))
except:
invalid_paths.append(path)

if invalid_paths:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def __del__(self):
self._alive.clear()
# Send a sentinel value to tell the worker to exit, and wait for pending
# checkpoints to complete.
self._async_queue.put(None)
self._chkpt_thread.join()

def _async_worker(self):
while self._alive.is_set():
try:
item = self._async_queue.get()
if item:
step, state_dict = item
self.save(step, state_dict, force=True)
except:
traceback.print_exc()
finally:
self._async_queue.task_done()

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

Expand All @@ -215,6 +202,35 @@ def _release_oldest_checkpoints(self):
oldest_chkpt = self._tracked_chkpts.popleft()
self._delete_chkpt_at_step(oldest_chkpt.step)

def _wait_for_data(self):
xm.mark_step()
xm.wait_device_ops()

def _save(self, step, state_dict):
"""
The actual checkpointing logic, which is shared between async and
synchronous checkpointing.
The caller must ensure that data is accessible within the state_dict before
calling, which can be achieved with `self._wait_for_data`.
"""
with self._save_mutex:
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
process_group=self.pg,
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
self._tracked_chkpts.append(metadata)
if dist.get_rank(self.pg) == 0:
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
pickle.dump(metadata, f)
self._release_oldest_checkpoints()

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step. A
@@ -247,20 +263,8 @@ def save(self,
True if a checkpoint was taken and False otherwise.
"""
if self.should_save(step) or force:
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
process_group=self.pg,
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
pickle.dump(metadata, f)
self._tracked_chkpts.append(metadata)
self._release_oldest_checkpoints()
self._wait_for_data()
self._save(step, state_dict)
return True
return False

@@ -288,9 +292,13 @@ def save_async(self,
True if a checkpoint was taken and False otherwise.
"""
if self.should_save(step) or force:
self._wait_for_data()
# Move the state_dict to CPU
cpu_state_dict = _sharded_cpu_state_dict(state_dict)
self._async_queue.put((step, cpu_state_dict))
self._async_sem.acquire()
future = self._async_worker_pool.submit(self._save, step, cpu_state_dict)
future.add_done_callback(lambda _: self._async_sem.release())
self._async_futures.append(future)
return True
return False

@@ -322,8 +330,8 @@ def all_steps(self) -> List[int]:
return sorted(x.step for x in self._tracked_chkpts)

def join(self):
""" Wait for all pending async checkpoints to complete. """
self._async_queue.join()
""" Wait for any pending async checkpoints to complete. """
wait(self._async_futures)

def reached_preemption(self, step: int) -> bool:
""" Returns True if a preemption has been detected at the given step. """
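The queue-and-worker-thread design is replaced above by a small concurrency pattern: a single-worker `ThreadPoolExecutor` performs the writes, a semaphore bounds how many checkpoints may be pending, and `concurrent.futures.wait()` backs `join()`. A self-contained sketch of that pattern follows, with `slow_write` standing in for the actual checkpoint write.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait

MAX_PENDING = 1
pool = ThreadPoolExecutor(max_workers=1)
sem = threading.Semaphore(MAX_PENDING)
futures = []

def slow_write(step: int) -> int:
    time.sleep(0.1)  # stands in for writing a checkpoint to storage
    return step

def submit_async(step: int) -> None:
    sem.acquire()    # blocks the caller if too many writes are in flight
    fut = pool.submit(slow_write, step)
    fut.add_done_callback(lambda _: sem.release())
    futures.append(fut)

for step in range(5):
    submit_async(step)

wait(futures)        # the join(): block until every pending write finishes
assert [f.result() for f in futures] == list(range(5))
```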
