From 5710a83c239d0cf89335b88808828eb80ce7dbe0 Mon Sep 17 00:00:00 2001 From: jonb377 Date: Tue, 14 Nov 2023 15:21:58 +1100 Subject: [PATCH] Add GKE support and various usability improvements in CheckpointManager (#5770) * Add GKE support and various usability improvements in CheckpointManager * Bug fix for async checkpointing fully sharded state dicts --- test/spmd/test_xla_distributed_checkpoint.py | 21 ++- torch_xla/_internal/tpu.py | 5 +- .../distributed_checkpoint/_helpers.py | 7 +- .../distributed_checkpoint/manager.py | 130 ++++++++++-------- 4 files changed, 97 insertions(+), 66 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 910b35f324b..55fd7d9c155 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -273,6 +273,20 @@ def test_save_state_dict_with_cpu_shards(self): self.assertTrue( isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards)) + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for sharded test") + def test_cpu_state_dict_flattening(self): + # In the case of a nested state_dict with fully sharded parameters, + # _CpuShards should be treated as terminal nodes. + t = torch.randn(128, 128).to(xm.xla_device()) + mesh = self._get_mesh((self.n_devices, 1)) + xs.mark_sharding(t, mesh, (0, 1)) + state_dict = _sharded_cpu_state_dict({'model': {'weight': t}}) + planner = SPMDSavePlanner() + planner.set_up_planner(state_dict, True) + # model.weight should be flattened and tracked in the sharded state dict. + self.assertCountEqual(planner.sharded_state_dict, ["model.weight"]) + def test_local_save_plan(self): def _write_item_assertions(plan, n_devices, parameter_count): @@ -433,13 +447,14 @@ def test_manager_async(self, tmpdir): # Patch the manager's save method to block until this thread signals. cond = threading.Condition() - old_save = chkpt_mgr.save + old_save = chkpt_mgr._save def patched_save(*args, **kwargs): - cond.wait() + with cond: + cond.wait() old_save(*args, **kwargs) - with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save): + with unittest.mock.patch.object(chkpt_mgr, '_save', patched_save): chkpt_mgr.save_async(10, state_dict) # No new steps should be tracked immediately after calling save_async diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 108ca7945a3..385566b1d35 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -5,6 +5,7 @@ import os import pathlib import re +import socket from typing import NamedTuple, Optional, List from typing_extensions import TypedDict import requests @@ -299,10 +300,12 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str: return worker_ips[master_worker_id] -def _spmd_find_master_ip(current_worker_ip: str) -> str: +def _spmd_find_master_ip(current_worker_hostname: str) -> str: import torch_xla.runtime as xr import torch_xla.distributed.spmd as xs from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + # Translate the hostname to an IP address, e.g. for TPUs on GKE. + current_worker_ip = socket.gethostbyname(current_worker_hostname) ip_int = int(ip_address(current_worker_ip)) n_dev = xr.global_runtime_device_count() local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices()) diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py index 6ab2da163ac..62c3c6f2ee0 100644 --- a/torch_xla/experimental/distributed_checkpoint/_helpers.py +++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py @@ -34,8 +34,13 @@ CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] +# TODO(jonbolin): Logic here is modified from the upstream to enable async +# checkpointing. If the state_dict is comprised entirely of _CpuShards, +# flatten_state_dict will not actually flatten the dict. +# Once we can represent XLAShardedTensor on CPU, either directly or through +# DistributedTensor, we can reuse the upstream logic. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: - return isinstance(value, torch.Tensor) + return isinstance(value, torch.Tensor) or isinstance(value, _CpuShards) def _traverse_state_dict( diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 0eaf184910a..89bb20f5076 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -2,11 +2,11 @@ import logging import os import pickle -import queue import threading import torch.distributed as dist import torch.distributed.checkpoint as dist_cp import torch_xla +import torch_xla.core.xla_model as xm import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc import traceback @@ -16,6 +16,7 @@ from collections import deque from fsspec.core import url_to_fs from os.path import basename +from concurrent.futures import ThreadPoolExecutor, wait from typing import Deque, List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE from ._helpers import _sharded_cpu_state_dict @@ -81,7 +82,7 @@ class CheckpointManager: step_period, as would be the case in auto checkpointing. This class is inspired by Orbax's CheckpointManager, which can be found here: - https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py + https://github.com/google/orbax/blob/efc079c/checkpoint/orbax/checkpoint/checkpoint_manager.py """ # The base path to write checkpoints to. Each checkpoint taken by the manager @@ -102,7 +103,7 @@ def __init__(self, path: str, save_interval: int, max_to_keep: Optional[int] = 0, - async_queue_size: Optional[int] = 1, + max_pending_async: Optional[int] = 1, process_group: dist.ProcessGroup = None, chkpt_on_preemption: bool = True): """ @@ -116,11 +117,11 @@ def __init__(self, CheckpointManager. When a new checkpoint will be taken, the checkpoint for the lowest tracked step will be deleted. Default: 0, indicating no upper bound on the number of checkpoints. - async_queue_size: The size of the execution queue which processes async - checkpoints. This should be a small value to ensure training doesn't + max_pending_async: The maximum number of async checkpoints which can be + pending. This should be a small value to ensure training doesn't get too far ahead of the last finished checkpoint, but increasing - the value to 2 can unblock training when there are transient - network issues which slow down the active checkpoint. + the value can unblock training when there are transient issues which + slow down the active checkpoint. Default: 1, which only allows a single async checkpoint to be pending at a time. process_group: The process group to use when coordinating the checkpoint. @@ -132,31 +133,33 @@ def __init__(self, """ assert dist.is_initialized(), "A process group is required." assert save_interval > 0, "save_interval must be positive" - assert async_queue_size > 0, "async_queue_size must be positive" + assert max_pending_async > 0, "max_pending_async must be positive" assert max_to_keep >= 0, "max_to_keep must be non-negative" - self.base_path = path + self.base_path = os.path.join(path, '') # Ensure the base path ends in '/' self.save_interval = save_interval self.max_to_keep = max_to_keep self.chkpt_on_preemption = chkpt_on_preemption - self._tracked_chkpts = self._load_tracked_chkpts() - self._async_queue = queue.Queue(maxsize=async_queue_size) - self._alive = threading.Event() - self._alive.set() - self._chkpt_thread = threading.Thread( - target=self._async_worker, daemon=True) - self._chkpt_thread.start() - # Create a new group if none is provided # TODO(jonbolin): Verify subgroup on GPU backend self.pg = process_group or dist.new_group() + # Thread pool to run the async checkpoints. `_async_sem` is used to guard + # the number of pending checkpoints, and `_async_futures` tracks all + # futures returned by the pool. + self._async_worker_pool = ThreadPoolExecutor(max_workers=1) + self._async_sem = threading.Semaphore(max_pending_async) + self._async_futures = [] + # Mutex to ensure only a single thread can write a checkpoint at a time. + self._save_mutex = threading.Lock() + + self._tracked_chkpts = self._load_tracked_chkpts() + if self.chkpt_on_preemption: # Initialize the distributed runtime for preemption detection - master_ip = xr.get_master_ip() torch_xla._XLAC._ensure_xla_coordinator_initialized( - xr.process_index(), xr.process_count(), master_ip) + xr.process_index(), xr.process_count(), xr.get_master_ip()) torch_xla._XLAC._activate_preemption_sync_manager() def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: @@ -166,36 +169,20 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: all_chkpts = [] invalid_paths = [] fs, raw_path = url_to_fs(self.base_path) - for path in fs.ls(raw_path, detail=False): - try: - with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f: - all_chkpts.append(pickle.load(f)) - except: - invalid_paths.append(path) + if not fs.exists(raw_path): + fs.mkdir(raw_path) + else: + for path in fs.ls(raw_path, detail=False): + try: + with fs.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f: + all_chkpts.append(pickle.load(f)) + except: + invalid_paths.append(path) if invalid_paths: logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') return deque(sorted(all_chkpts, key=lambda m: m.ts)) - def __del__(self): - self._alive.clear() - # Send a sentinel value to tell the worker to exit, and wait for pending - # checkpoints to complete. - self._async_queue.put(None) - self._chkpt_thread.join() - - def _async_worker(self): - while self._alive.is_set(): - try: - item = self._async_queue.get() - if item: - step, state_dict = item - self.save(step, state_dict, force=True) - except: - traceback.print_exc() - finally: - self._async_queue.task_done() - def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) @@ -215,6 +202,35 @@ def _release_oldest_checkpoints(self): oldest_chkpt = self._tracked_chkpts.popleft() self._delete_chkpt_at_step(oldest_chkpt.step) + def _wait_for_data(self): + xm.mark_step() + xm.wait_device_ops() + + def _save(self, step, state_dict): + """ + The actual checkpointing logic, which is shared between async and + synchronous checkpointing. + + The caller must ensure that data is accessible within the state_dict before + calling, which can be achieved with `self._wait_for_data`. + """ + with self._save_mutex: + path = self._get_path(step) + # Delete any existing checkpoint at the current step. + self._delete_chkpt_at_step(step) + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=FsspecWriter(path), + planner=xc.SPMDSavePlanner(), + process_group=self.pg, + ) + metadata = _CheckpointMetadata(step=step, ts=datetime.now()) + self._tracked_chkpts.append(metadata) + if dist.get_rank(self.pg) == 0: + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: + pickle.dump(metadata, f) + self._release_oldest_checkpoints() + def should_save(self, step: int) -> bool: """ Returns true if a checkpoint should be saved for the current step. A @@ -247,20 +263,8 @@ def save(self, True if a checkpoint was taken and False otherwise. """ if self.should_save(step) or force: - path = self._get_path(step) - # Delete any existing checkpoint at the current step. - self._delete_chkpt_at_step(step) - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=FsspecWriter(path), - planner=xc.SPMDSavePlanner(), - process_group=self.pg, - ) - metadata = _CheckpointMetadata(step=step, ts=datetime.now()) - with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: - pickle.dump(metadata, f) - self._tracked_chkpts.append(metadata) - self._release_oldest_checkpoints() + self._wait_for_data() + self._save(step, state_dict) return True return False @@ -288,9 +292,13 @@ def save_async(self, True if a checkpoint was taken and False otherwise. """ if self.should_save(step) or force: + self._wait_for_data() # Move the state_dict to CPU cpu_state_dict = _sharded_cpu_state_dict(state_dict) - self._async_queue.put((step, cpu_state_dict)) + self._async_sem.acquire() + future = self._async_worker_pool.submit(self._save, step, cpu_state_dict) + future.add_done_callback(lambda _: self._async_sem.release()) + self._async_futures.append(future) return True return False @@ -322,8 +330,8 @@ def all_steps(self) -> List[int]: return sorted(x.step for x in self._tracked_chkpts) def join(self): - """ Wait for all pending async checkpoints to complete. """ - self._async_queue.join() + """ Wait for any pending async checkpoints to complete. """ + wait(self._async_futures) def reached_preemption(self, step: int) -> bool: """ Returns True if a preemption has been detected at the given step. """