[fix auto-microbatch] FSDP reshard and cleanup after OOM to fix the cuda memory leak (#3030)

* reshard and cleanup

* format

* fix

* cleanup unit test

* comments

* more test

* fix the warning

* add numerical correctness test

* Apply suggestions from code review

Co-authored-by: Mihir Patel <[email protected]>

* lint

* fix test warning

* revert irrelevant change

---------

Co-authored-by: Mihir Patel <[email protected]>
bigning and mvpatel2000 authored Feb 22, 2024
1 parent a606314 commit 2133c17
Showing 2 changed files with 129 additions and 2 deletions.
18 changes: 18 additions & 0 deletions composer/trainer/trainer.py
@@ -30,6 +30,8 @@
import torch.utils.data
from torch._dynamo import OptimizedModule
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LRScheduler
@@ -232,6 +234,21 @@ def _is_cuda_oom(e: RuntimeError):
    return False


def _fsdp_reshard_and_cleanup(model: torch.nn.Module):
    """Manually reshard and clean up an FSDP model.

    When an exception like OOM happens, _post_backward_final_callback, which
    is registered as a backward callback, will not run. We manually call it to
    clean up loose memory.
    """
    for __, module in model.named_modules():
        if isinstance(module, FullyShardedDataParallel):
            if module.check_is_root():
                # Only call _post_backward_final_callback on the root module. It will
                # traverse and reshard all FSDP sub-modules.
                _post_backward_final_callback(module, module)


def _adjust_device_train_microbatch_size(state: State):
    """Adjust device_train_microbatch_size if we encounter OOM.
@@ -259,6 +276,7 @@ def _adjust_device_train_microbatch_size(state: State):
        optimizer.zero_grad(set_to_none=True)
    if state.scaler is not None:
        state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
    _fsdp_reshard_and_cleanup(state.model)
    torch.cuda.empty_cache()


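For readers skimming the trainer.py change: the new _fsdp_reshard_and_cleanup(state.model) call slots into Composer's auto-microbatching retry path. The sketch below is a simplified, hypothetical reconstruction of that flow; the wrapper function retry_batch_with_smaller_microbatch, the run_one_batch callable, and the halving policy are illustrative assumptions, while _is_cuda_oom, _fsdp_reshard_and_cleanup, and the cleanup steps come directly from this diff.

import torch

from composer.trainer.trainer import _fsdp_reshard_and_cleanup, _is_cuda_oom


def retry_batch_with_smaller_microbatch(state, run_one_batch):
    # Hypothetical retry loop: run_one_batch performs forward + backward + optimizer step.
    while True:
        try:
            run_one_batch(state)
            return
        except RuntimeError as e:
            if not _is_cuda_oom(e):
                raise
            # The backward pass aborted, so FSDP's _post_backward_final_callback never
            # ran and the unsharded flat params are still resident. Reshard them
            # manually before retrying, otherwise every retry leaks GPU memory.
            _fsdp_reshard_and_cleanup(state.model)
            for optimizer in state.optimizers:
                optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            # Assumed policy for the sketch: halve the per-device microbatch size and retry.
            state.device_train_microbatch_size = max(1, state.device_train_microbatch_size // 2)

Without the _fsdp_reshard_and_cleanup call, each OOM-and-retry cycle left the full (unsharded) parameters allocated, which is the CUDA memory leak this commit fixes.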
113 changes: 111 additions & 2 deletions tests/trainer/test_fsdp.py
@@ -10,7 +10,7 @@
from torch.utils.data import DataLoader

from composer.models import ComposerClassifier, ComposerModel
from composer.trainer.trainer import Trainer
from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup
from composer.utils import dist
from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
                          world_size)
@@ -232,10 +232,11 @@ def __init__(self, num_features: int = 128, device: str = 'cuda'):
        super().__init__()
        self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
        self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.ReLU(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

@@ -282,3 +283,111 @@ def test_fsdp_act_ckpt_offload(
        assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper)
    else:
        assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper)


@pytest.mark.gpu
@world_size(2)
def test_fsdp_reshard_after_oom(world_size: int):
    model = SimpleMLP(num_features=128)
    model.relu._fsdp_wrap = False  # pyright: ignore[reportGeneralTypeIssues]

    def oom_hook(*args):
        raise RuntimeError('CUDA out of memory.')

    model.fc2.register_full_backward_hook(oom_hook)

    trainer = Trainer(
        model=model,
        fsdp_config={},
        max_duration='3ba',
    )
    fsdp_model = trainer.state.model

    x = torch.rand([2, 128])
    output = fsdp_model(x)
    with pytest.raises(Exception):
        # Backward triggers the fake OOM exception,
        # which prevents FSDP reshard and cleanup.
        torch.sum(output).backward()

    fc2_flat_param = fsdp_model.fc2._flat_param

    # Without cleanup, model.fc2's flat param is still in an unsharded state;
    # the full param has not been freed.
    assert fc2_flat_param.data_ptr() != fc2_flat_param._local_shard.data_ptr()
    assert fc2_flat_param._full_param_padded.numel() > 0

    _fsdp_reshard_and_cleanup(fsdp_model)
    assert fc2_flat_param.data_ptr() == fc2_flat_param._local_shard.data_ptr()
    assert fc2_flat_param._full_param_padded._typed_storage()._size() == 0


@pytest.mark.gpu
@world_size(2)
def test_fsdp_same_state_after_oom_reshard(world_size: int):
    # Test numerical correctness after continuing to train with a smaller batch size after OOM.
    model = SimpleMLP(num_features=2)
    model.fc1._fsdp_wrap = True  # pyright: ignore[reportGeneralTypeIssues]
    model.fc2._fsdp_wrap = True  # pyright: ignore[reportGeneralTypeIssues]
    model.relu._fsdp_wrap = False  # pyright: ignore[reportGeneralTypeIssues]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    trainer = Trainer(
        model=model,
        fsdp_config={},
        dist_timeout=20,
        optimizers=optimizer,
        seed=1,
    )
    fsdp_model = trainer.state.model

    state_dict = fsdp_model.state_dict()

    oom_model = SimpleMLP(num_features=2)
    oom_model.fc1._fsdp_wrap = True  # pyright: ignore[reportGeneralTypeIssues]
    oom_model.fc2._fsdp_wrap = True  # pyright: ignore[reportGeneralTypeIssues]
    oom_model.relu._fsdp_wrap = False  # pyright: ignore[reportGeneralTypeIssues]
    oom_model_optimizer = torch.optim.SGD(oom_model.parameters(), lr=0.1)

    def oom_hook(module, grad_input, grad_output):
        if grad_output[0].shape[0] >= 4:
            raise RuntimeError('CUDA out of memory.')

    oom_handle = oom_model.fc2.register_full_backward_hook(oom_hook)
    oom_trainer = Trainer(
        model=oom_model,
        fsdp_config={},
        dist_timeout=20,
        optimizers=oom_model_optimizer,
        seed=1,
    )

    fsdp_oom_model = oom_trainer.state.model
    fsdp_oom_model.load_state_dict(state_dict)

    x = torch.rand([4, 2])

    # Run fwd + bwd + optimizer step on the normal model.
    output_0 = fsdp_model(x)
    torch.sum(output_0).backward()
    optimizer.step()

    # Run fwd + bwd + optimizer step on the OOM model.
    output = fsdp_oom_model(x)
    with pytest.raises(Exception):
        torch.sum(output).backward()
    # Cleanup after OOM.
    _fsdp_reshard_and_cleanup(fsdp_oom_model)
    oom_model_optimizer.zero_grad(set_to_none=True)

    oom_handle.remove()
    output = fsdp_oom_model(x)
    torch.sum(output).backward()
    oom_model_optimizer.step()

    # Run another fwd on both models and check
    # that the outputs are the same.
    output_1 = fsdp_model(x)
    output_2 = fsdp_oom_model(x)

    assert torch.equal(output_1, output_2)
