Fetching arguments for FSDP (#2710)
* args for fetching

* add a unit test

* test

* fix test name

* switch to world size 2

* fix tests

* gate

* warnings

* lint

* typo

* simplify

* remove comment

* wrap

* rerun tests
mvpatel2000 authored Nov 14, 2023
1 parent 0f07bf9 commit c6a6216
Showing 2 changed files with 57 additions and 5 deletions.
20 changes: 20 additions & 0 deletions composer/trainer/dist_strategy.py
@@ -133,9 +133,11 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('activation_checkpointing_reentrant', True)
fsdp_config.setdefault('activation_cpu_offload', False)
fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
fsdp_config.setdefault('backward_prefetch_limit', 1)
fsdp_config.setdefault('cpu_offload', False)
fsdp_config.setdefault('flatten_parameters', True)
fsdp_config.setdefault('forward_prefetch', False)
fsdp_config.setdefault('forward_prefetch_limit', 1)
fsdp_config.setdefault('ignored_modules', None)
fsdp_config.setdefault('keep_low_precision_grads', False)
fsdp_config.setdefault('limit_all_gathers', True)
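
As context, not part of the diff: a minimal sketch of how the two new defaults behave. `setdefault` only fills in keys the caller has not supplied, so a user-provided value always wins over the default of 1. `_fill_prefetch_defaults` is a hypothetical stand-in for just this slice of `set_fsdp_default`.

from typing import Any, Dict


def _fill_prefetch_defaults(fsdp_config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the two setdefault calls above."""
    fsdp_config.setdefault('backward_prefetch_limit', 1)
    fsdp_config.setdefault('forward_prefetch_limit', 1)
    return fsdp_config


# Default applied when the key is missing; an explicit user value is kept.
assert _fill_prefetch_defaults({})['forward_prefetch_limit'] == 1
assert _fill_prefetch_defaults({'forward_prefetch_limit': 2})['forward_prefetch_limit'] == 2
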
@@ -508,6 +510,24 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
**kwargs,
)

if hasattr(fsdp_obj, '_exec_order_data'):
    if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
        fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
    else:
        warnings.warn('FSDP._exec_order_data does not have attribute _forward_prefetch_limit '
                      'which is unexpected and will result in `forward_prefetch_limit` from FSDP '
                      'config being ignored. Please open an issue to Composer to report this.')
    if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
        fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
    else:
        warnings.warn('FSDP._exec_order_data does not have attribute _backward_prefetch_limit '
                      'which is unexpected and will result in `backward_prefetch_limit` from FSDP '
                      'config being ignored. Please open an issue to Composer to report this.')
else:
    warnings.warn('FSDP does not have attribute _exec_order_data which is unexpected and will '
                  'result in `forward_prefetch_limit` and `backward_prefetch_limit` from FSDP '
                  'config being ignored. Please open an issue to Composer to report this.')

# Activation Checkpointing
if activation_checkpointing or activation_cpu_offload:
if not activation_checkpointing_reentrant:
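Not part of the diff: a minimal sketch of the user-facing side of this change, mirroring the unit test added in the next file. The two new keys are assumed to cap how many FSDP all-gathers are prefetched ahead of the current module in the forward and backward passes (they map onto the private `_exec_order_data` attributes patched above); every other key keeps the value filled in by `set_fsdp_default`.

# Sketch of the new user-facing keys in Composer's fsdp_config.
fsdp_config = {
    'forward_prefetch_limit': 2,   # assumption: modules prefetched ahead in forward
    'backward_prefetch_limit': 2,  # assumption: modules prefetched ahead in backward
}
# Passed to Composer's Trainer as `fsdp_config=fsdp_config`; the test below
# shows a complete, runnable version of this wiring.
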
42 changes: 37 additions & 5 deletions tests/trainer/test_fsdp.py
@@ -9,18 +9,21 @@
from composer.models import ComposerClassifier
from composer.trainer.trainer import Trainer
from composer.utils import dist
from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel
from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
world_size)


@pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.parametrize('device', ['cpu', 'meta'])
@pytest.mark.parametrize('reentrant', [True, False])
@pytest.mark.filterwarnings('ignore::UserWarning')
@world_size(2)
@pytest.mark.gpu
@pytest.mark.filterwarnings('ignore:The passed in model appears to have tied weights.*:UserWarning')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, device: str, reentrant: bool):
def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, reentrant: bool, world_size: int,
device: str):
"""test FSDP device initialization for a simple model with weight tying and a model where two modules
from separate submodules have weight tying applied. This test also covers both 'cpu' and
'meta' devices. This is because 'meta' will result in deferred initialization until FSDP is initialized
@@ -62,15 +65,16 @@ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision:
@pytest.mark.parametrize('model', [SimpleModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', device: str = 'meta'):
def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', world_size: int):
"""
This test is intended to test FSDP for meta initialization when there are attributes
that are `None` and ensure we don't raise nasty UserWarnings.
"""
num_classes = 2
model = model(num_features=1, num_classes=num_classes, device=device, bias=False)
model = model(num_features=1, num_classes=num_classes, device='meta', bias=False)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -85,3 +89,31 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
},
max_duration='3ba',
)


@pytest.mark.parametrize('forward_prefetch_limit', [1, 2])
@pytest.mark.parametrize('backward_prefetch_limit', [1, 2])
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limit: int, world_size: int):
model = SimpleModel()
model.fc1._fsdp_wrap = True
model.fc2._fsdp_wrap = True
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
    model=model,
    optimizers=optimizer,
    train_dataloader=dataloader,
    fsdp_config={
        'forward_prefetch_limit': forward_prefetch_limit,
        'backward_prefetch_limit': backward_prefetch_limit,
    },
    max_duration='3ba',
)

trainer.fit()
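
A possible follow-up, not part of this commit: after `trainer.fit()`, one could walk the wrapped modules and confirm the limits actually reached FSDP's execution-order bookkeeping. The attribute names below are the same private PyTorch internals patched in `dist_strategy.py` above, so this is a best-effort sketch that may break on newer PyTorch releases.

# Hypothetical verification, not in this commit.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

for module in trainer.state.model.modules():
    if isinstance(module, FSDP) and hasattr(module, '_exec_order_data'):
        assert module._exec_order_data._forward_prefetch_limit == forward_prefetch_limit
        assert module._exec_order_data._backward_prefetch_limit == backward_prefetch_limit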
