Skip to content

Commit

Permalink
Move hooks and fsdp modules onto state rather than trainer (#3522)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackZ-db authored Aug 6, 2024
1 parent 3aa266f commit 2fdfd12
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 22 deletions.
35 changes: 28 additions & 7 deletions composer/checkpoint/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.distributed.checkpoint as DCP
Expand Down Expand Up @@ -139,7 +139,18 @@ def load_checkpoint(
assert model is not None
assert model_child_path is not None
model_load_path = os.path.join(load_path, model_child_path)
load_model_checkpoint(model, load_path=model_load_path, load_options=load_options)
if state is not None:
state.automicrobatch_fsdp_hook_handles, state.fsdp_modules = load_model_checkpoint(
model,
load_path=model_load_path,
load_options=load_options,
)
else:
load_model_checkpoint(
model,
load_path=model_load_path,
load_options=load_options,
)

if load_options.load_optimizer:
assert optim_child_path is not None
Expand All @@ -159,7 +170,7 @@ def load_model_checkpoint(
load_path: Optional[str] = None,
load_options: Optional[Union[CheckpointLoadOptions, Dict]] = None,
seed: int = 42,
):
) -> Tuple[list, dict]:
"""Load a a model checkpoint from the specified path into the model.
Args:
Expand All @@ -178,10 +189,13 @@ def load_model_checkpoint(
if load_options.include_keys is not None or load_options.ignore_keys is not None:
load_options.strict = False

automicrobatch_fsdp_hook_handles = []
fsdp_modules = {}

if load_options.sharded_checkpoint:
if not _is_model_fsdp(model):
if load_options.shard_as_needed_during_load:
_shard_with_fsdp(
automicrobatch_fsdp_hook_handles, fsdp_modules = _shard_with_fsdp(
model,
fsdp_config=load_options.fsdp_config,
precision=load_options.precision,
Expand All @@ -205,7 +219,13 @@ def load_model_checkpoint(
load_options.fsdp_config.update({'sync_module_states': True})
else:
load_options.fsdp_config.sync_module_states = True
_shard_with_fsdp(model, fsdp_config=load_options.fsdp_config, precision=load_options.precision, seed=seed)
automicrobatch_fsdp_hook_handles, fsdp_modules = _shard_with_fsdp(
model,
fsdp_config=load_options.fsdp_config,
precision=load_options.precision,
seed=seed,
)
return automicrobatch_fsdp_hook_handles, fsdp_modules


def _shard_with_fsdp(
Expand All @@ -214,18 +234,19 @@ def _shard_with_fsdp(
fsdp_config: Optional[Union[FSDPConfig, dict]] = None,
precision: Optional[str] = None,
seed: int = 42,
):
) -> Tuple[list, dict]:
if fsdp_config is None:
fsdp_config = FSDPConfig()
if isinstance(fsdp_config, dict):
fsdp_config = FSDPConfig(**fsdp_config)
with reproducibility.seed_context(seed):
prepare_fsdp_module(
automicrobatch_fsdp_hook_handles, fsdp_modules = prepare_fsdp_module(
model,
optimizers=optimizer,
fsdp_config=fsdp_config,
precision=precision,
)
return automicrobatch_fsdp_hook_handles, fsdp_modules


def _load_sharded_model_checkpoint(
Expand Down
5 changes: 4 additions & 1 deletion composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,9 @@ def __init__(
self.fsdp_config = parallelism_config.fsdp if parallelism_config is not None else None
self.tp_config = parallelism_config.tp if parallelism_config is not None else None

self.automicrobatch_fsdp_hook_handles = []
self.fsdp_modules = {}

self._validate_parallelism_configs()

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
Expand Down Expand Up @@ -1387,7 +1390,7 @@ def load_model_state(
with reproducibility.seed_context(self.rank_zero_seed):
from composer.distributed import prepare_fsdp_module

prepare_fsdp_module(
self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
self.model,
self.optimizers,
self.fsdp_config,
Expand Down
34 changes: 20 additions & 14 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,6 @@ def __init__(
self.cumulative_alloc_retries = 0
self.num_consecutive_thrashes = 0
self.num_consecutive_non_OOM_batches = 0
self.automicrobatch_fsdp_hook_handles = []

if auto_microbatching and profiler:
raise ValueError(
Expand Down Expand Up @@ -1766,7 +1765,7 @@ def __init__(
if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only:
# Init with globally fixed seed so all HSDP replicas have the same initial weights
with reproducibility.seed_context(self.state.rank_zero_seed):
self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
model,
optimizers,
self.state.fsdp_config,
Expand Down Expand Up @@ -1937,7 +1936,7 @@ def __init__(
):
# Init with globally fixed seed so all HSDP replicas have the same initial weights
with reproducibility.seed_context(self.state.rank_zero_seed):
self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
model,
optimizers,
self.state.fsdp_config,
Expand Down Expand Up @@ -2917,8 +2916,11 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
all_ranks_finished = all_ranks_finished_tensor.item() == 1
if found_cuda_oom == 1:
# Readd sync hooks if they were previously turned off
if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(
self.state.fsdp_modules,
sync_hook,
)
_adjust_device_train_microbatch_size(self.state)
self.num_consecutive_thrashes = 0
self.num_consecutive_non_OOM_batches = 0
Expand All @@ -2934,8 +2936,11 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
)
if self.num_consecutive_thrashes >= 2:
# Readd sync hooks if they were previously turned off
if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(
self.state.fsdp_modules,
sync_hook,
)
_adjust_device_train_microbatch_size(self.state)
self.num_consecutive_thrashes = 0
continue
Expand All @@ -2949,12 +2954,12 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
)
self.num_consecutive_non_OOM_batches += 1
if self.state.fsdp_enabled and len(
self.automicrobatch_fsdp_hook_handles,
self.state.automicrobatch_fsdp_hook_handles,
) > 0 and self.num_consecutive_non_OOM_batches >= 3:
patch_unshard_for_automicrobatching(auto_microbatch_size_found=True)
for handle in self.automicrobatch_fsdp_hook_handles:
for handle in self.state.automicrobatch_fsdp_hook_handles:
handle.remove()
self.automicrobatch_fsdp_hook_handles.clear()
self.state.automicrobatch_fsdp_hook_handles.clear()
if torch.cuda.is_available():
memory_stats = torch.cuda.memory_stats()
self.cumulative_alloc_retries = memory_stats['num_alloc_retries']
Expand Down Expand Up @@ -3753,10 +3758,11 @@ def _eval_loop(
self.state.dataloader_len = original_num_batches

# If training occurs after evaluation, readd hooks in case of memory spike
sync_hook = _create_sync_hook(self.state)
if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0:
self.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.fsdp_modules, sync_hook)
self.num_consecutive_non_OOM_batches = 0
if self.state.auto_microbatching:
sync_hook = _create_sync_hook(self.state)
if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.state.fsdp_modules, sync_hook)
self.num_consecutive_non_OOM_batches = 0

def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool:
"""Determines based on precision when to use grad scaling.
Expand Down
98 changes: 98 additions & 0 deletions tests/checkpoint/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import pytest
from torch.utils.data import DataLoader

from composer.checkpoint.load import (
load_checkpoint,
Expand All @@ -26,8 +27,12 @@
get_optim_state_dict,
get_resumption_state_dict,
)
from composer.trainer import Trainer
from composer.utils import dist
from tests.checkpoint.helpers import init_model, init_model_and_optimizer, init_state
from tests.common import (
RandomClassificationDataset,
)
from tests.common.compare import deep_compare


Expand Down Expand Up @@ -333,3 +338,96 @@ def test_load_checkpoint(
deep_compare(original_model_state_dict, new_state_dict)
deep_compare(original_optim_state_dict, new_optim_state_dict)
deep_compare(original_resumption_state, new_resumption_state, ignore_keys=['rng', 'run_name'])


@pytest.mark.gpu
@pytest.mark.parametrize(
'world_size,sharded_model,sharded_checkpoint,shard_as_needed_during_load',
[
# Loading an unsharded checkpoint into an unsharded model on a single GPU (not sharding after)
pytest.param(1, False, False, False, marks=pytest.mark.world_size(1)),
# Loading a sharded checkpoint into a sharded model in distributed setting
pytest.param(2, True, True, False, marks=pytest.mark.world_size(2)),
# Loading a sharded checkpoint into an unsharded model (sharding it before load)
pytest.param(2, False, True, True, marks=pytest.mark.world_size(2)),
# Loading an unsharded checkpoint into an unsharded model and sharding it after.
pytest.param(2, False, False, True, marks=pytest.mark.world_size(2)),
# The other three permutations of the above tests are:
# 2 gpu, Sharded model, sharded checkpoint, with additional sharding -> no need to shard already sharded model
# 2 gpu, Sharded model, unsharded checkpoint, with additional sharding -> no need to shard already sharded model
# 2 gpu, Unsharded model, unsharded checkpoint, without additional sharding -> no need to try this on 2 gpus
],
)
def test_load_model_checkpoint_and_eval(
world_size: int,
tmp_path: Path,
sharded_model: bool,
sharded_checkpoint: bool,
shard_as_needed_during_load: bool,
):
if sharded_model and not sharded_checkpoint:
pytest.xfail(
'Loading an unsharded checkpoint into a sharded model is not supported and causes OOMs when running with these tests',
)
# Ensure all ranks use the same path
destination_dir = os.path.join(tmp_path, str(uuid.uuid4())[:8])
destination_dir = dist.all_gather_object(destination_dir)[0]

# Save a model checkpoint
model, _ = init_model(use_composer_model=True, use_fsdp=sharded_checkpoint, device='cuda')
save_path = os.path.join(destination_dir, 'model.pt') if not sharded_checkpoint else destination_dir
saved_path = save_model_to_disk(model, save_path, sharded_checkpoint=sharded_checkpoint)

# Get the original model's state dict
original_state_dict = get_model_state_dict(model, sharded_state_dict=False)
# Load the model checkpoint
new_model, _ = init_model(use_composer_model=True, use_fsdp=sharded_model, device='cuda')
if saved_path is not None:
load_path = saved_path if not sharded_checkpoint else str(Path(saved_path).parent)
else:
load_path = ''

if not sharded_model and sharded_checkpoint and not shard_as_needed_during_load:
context_manager = pytest.raises(ValueError)
else:
context_manager = contextlib.nullcontext()

with context_manager:
load_model_checkpoint(
new_model,
load_path=load_path,
load_options=dict(
sharded_checkpoint=sharded_checkpoint,
shard_as_needed_during_load=shard_as_needed_during_load,
),
)
# Check if model is sharded when it should be
if shard_as_needed_during_load:
assert _is_model_fsdp(new_model), 'Model should be sharded after load'

# Get the new model's state dict
new_state_dict = get_model_state_dict(new_model, sharded_state_dict=False)

if dist.get_global_rank() == 0:
deep_compare(original_state_dict, new_state_dict)

dataset = RandomClassificationDataset(
shape=(8,),
size=100,
num_classes=3,
)

trainer = Trainer(
eval_dataloader=DataLoader(
dataset=dataset,
sampler=dist.get_sampler(dataset),
),
model=new_model, # type: ignore
)

# Evaluate the model
trainer.eval()

0 comments on commit 2fdfd12

Please sign in to comment.