Commit
Merge branch 'dev' into mvpatel2000/bump-version-17
dakinggg authored Nov 15, 2023
2 parents 204cd5f + c6a6216 commit 0f6d153
Showing 11 changed files with 146 additions and 27 deletions.
9 changes: 9 additions & 0 deletions composer/algorithms/seq_length_warmup/README.md
@@ -44,6 +44,9 @@ def training_loop(model, train_loader):
<!--pytest.mark.gpu-->
<!--
```python
import os
previous_platform_env = os.environ["MOSAICML_PLATFORM"]
os.environ["MOSAICML_PLATFORM"] = "false"
from tests.common.models import configure_tiny_bert_hf_model
from tests.common.datasets import dummy_bert_lm_dataloader
@@ -64,6 +67,12 @@ trainer = Trainer(model=model,

trainer.fit()
```
<!--pytest-codeblocks:cont-->
<!--
```python
os.environ["MOSAICML_PLATFORM"] = previous_platform_env
```
-->

### Implementation Details

2 changes: 1 addition & 1 deletion composer/callbacks/memory_monitor.py
@@ -90,7 +90,7 @@ def init(self, state: State, logger: Logger) -> None:
# Not relying on `torch.cuda.is_available()` since the model could be on CPU.
model_device = next(state.model.parameters()).device

if model_device.type != 'cuda':
if model_device.type not in ('cuda', 'meta'):
warnings.warn(f'The memory monitor only works on CUDA devices, but the model is on {model_device.type}.')

def after_train_batch(self, state: State, logger: Logger):
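The widened check means a model whose parameters live on the meta device (as happens with FSDP deferred initialization) no longer triggers the CUDA warning. A minimal sketch of the device check in isolation (the Linear layer below is illustrative, not from the commit):

```python
import torch

# Illustrative module; with device='meta' the parameters carry no storage yet.
model = torch.nn.Linear(4, 4, device='meta')
model_device = next(model.parameters()).device

# 'meta' is now accepted alongside 'cuda', so no warning is emitted for it.
assert model_device.type in ('cuda', 'meta')
```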
6 changes: 5 additions & 1 deletion composer/callbacks/nan_monitor.py
@@ -3,7 +3,7 @@

"""Callback for catching loss NaNs."""

from typing import Sequence
from typing import Dict, Sequence

import torch

@@ -24,5 +24,9 @@ def after_loss(self, state: State, logger: Logger):
for loss in state.loss:
if torch.isnan(loss).any():
raise RuntimeError('Train loss contains a NaN.')
elif isinstance(state.loss, Dict):
for k, v in state.loss.items():
if torch.isnan(v).any():
raise RuntimeError(f'Train loss {k} contains a NaN.')
else:
raise TypeError(f'Loss is of type {type(state.loss)}, but should be a tensor or a sequence of tensors')
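The new branch mirrors the existing sequence handling for dict-valued losses. A minimal sketch of the check applied to an illustrative loss dict (the keys and values here are made up, not from the commit):

```python
import torch

# Illustrative multi-term loss, e.g. from a model that returns several loss components.
loss = {'lm_loss': torch.tensor(2.3), 'aux_loss': torch.tensor(float('nan'))}

try:
    for k, v in loss.items():
        if torch.isnan(v).any():
            raise RuntimeError(f'Train loss {k} contains a NaN.')
except RuntimeError as e:
    print(e)  # Train loss aux_loss contains a NaN.
```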
4 changes: 2 additions & 2 deletions composer/core/state.py
@@ -349,7 +349,7 @@ class State(Serializable):
before the dataloader is evaluated. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
``0``.
device_train_microbatch_size (int): The size of each train microbatch per device.
loss (torch.Tensor | Sequence[torch.Tensor]): The most recently computed loss.
loss (torch.Tensor | Sequence[torch.Tensor] | Dict[Any, torch.Tensor]): The most recently computed loss.
model (torch.nn.Module): The training model.
.. note::
@@ -547,7 +547,7 @@ def __init__(

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
35 changes: 22 additions & 13 deletions composer/loggers/mosaicml_logger.py
@@ -12,6 +12,7 @@
import os
import time
import warnings
from concurrent.futures import wait
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List, Optional

@@ -57,22 +58,25 @@ class MosaicMLLogger(LoggerDestination):
Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics.
(default: ``None``)
ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
"""

def __init__(
self,
log_interval: int = 60,
ignore_keys: Optional[List[str]] = None,
ignore_exceptions: bool = False,
) -> None:
self.log_interval = log_interval
self.ignore_keys = ignore_keys
self.ignore_exceptions = ignore_exceptions
self._enabled = dist.get_global_rank() == 0
if self._enabled:
self.allowed_fails_left = 3
self.time_last_logged = 0
self.train_dataloader_len = None
self.time_failed_count_adjusted = 0
self.buffered_metadata: Dict[str, Any] = {}
self._futures = []

self.run_name = os.environ.get(RUN_NAME_ENV_VAR)
if self.run_name is not None:
log.info(f'Logging to mosaic run {self.run_name}')
@@ -140,20 +144,25 @@ def _flush_metadata(self, force_flush: bool = False) -> None:
"""Flush buffered metadata to MosaicML if enough time has passed since last flush."""
if self._enabled and (time.time() - self.time_last_logged > self.log_interval or force_flush):
try:
mcli.update_run_metadata(self.run_name, self.buffered_metadata)
f = mcli.update_run_metadata(self.run_name, self.buffered_metadata, future=True, protect=True)
self.buffered_metadata = {}
self.time_last_logged = time.time()
# If we have not failed in the last hour, increase the allowed fails. This increases
# robustness to transient network issues.
if time.time() - self.time_failed_count_adjusted > 3600 and self.allowed_fails_left < 3:
self.allowed_fails_left += 1
self.time_failed_count_adjusted = time.time()
except Exception as e:
log.error(f'Failed to log metadata to Mosaic with error: {e}')
self.allowed_fails_left -= 1
self.time_failed_count_adjusted = time.time()
if self.allowed_fails_left <= 0:
self._futures.append(f)
done, incomplete = wait(self._futures, timeout=0.01)
log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}')
# Raise any exceptions
for f in done:
if f.exception() is not None:
raise f.exception() # type: ignore
self._futures = list(incomplete)
except Exception:
log.exception('Failed to log metadata to Mosaic') # Prints out full traceback
if self.ignore_exceptions:
log.info('Ignoring exception and disabling MosaicMLLogger.')
self._enabled = False
else:
log.info('Raising exception. To ignore exceptions, set ignore_exceptions=True.')
raise

def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]:
"""Calculates training progress metrics.
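With `future=True`, mcli hands back a `concurrent.futures.Future` instead of blocking, and `_flush_metadata` drains completed futures on each call. A sketch of that bookkeeping using a stand-in for the mcli call (`fake_update` below is hypothetical; only the Future/wait pattern matches the change):

```python
from concurrent.futures import Future, wait

def fake_update(run_name, metadata):
    # Stand-in for mcli.update_run_metadata(..., future=True, protect=True),
    # which returns a Future rather than blocking on the network call.
    f = Future()
    f.set_result(None)
    return f

futures = [fake_update('my-run', {'step': 10})]

# Poll briefly so the training loop is not stalled by metadata uploads.
done, incomplete = wait(futures, timeout=0.01)
for f in done:
    if f.exception() is not None:
        raise f.exception()  # surface upload failures to the caller
futures = list(incomplete)
```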
20 changes: 20 additions & 0 deletions composer/trainer/dist_strategy.py
@@ -133,9 +133,11 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('activation_checkpointing_reentrant', True)
fsdp_config.setdefault('activation_cpu_offload', False)
fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
fsdp_config.setdefault('backward_prefetch_limit', 1)
fsdp_config.setdefault('cpu_offload', False)
fsdp_config.setdefault('flatten_parameters', True)
fsdp_config.setdefault('forward_prefetch', False)
fsdp_config.setdefault('forward_prefetch_limit', 1)
fsdp_config.setdefault('ignored_modules', None)
fsdp_config.setdefault('keep_low_precision_grads', False)
fsdp_config.setdefault('limit_all_gathers', True)
@@ -508,6 +510,24 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
**kwargs,
)

if hasattr(fsdp_obj, '_exec_order_data'):
if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
else:
warnings.warn('FSDP._exec_order_data does not have attribute _forward_prefetch_limit '
'which is unexpected and will result in `forward_prefetch_limit` from FSDP '
'config being ignored. Please open an issue to Composer to report this.')
if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
else:
warnings.warn('FSDP._exec_order_data does not have attribute _backward_prefetch_limit '
'which is unexpected and will result in `backward_prefetch_limit` from FSDP '
'config being ignored. Please open an issue to Composer to report this.')
else:
warnings.warn('FSDP does not have attribute _exec_order_data which is unexpected and will '
'result in `forward_prefetch_limit` and `backward_prefetch_limit` from FSDP '
'config being ignored. Please open an issue to Composer to report this.')

# Activation Checkpointing
if activation_checkpointing or activation_cpu_offload:
if not activation_checkpointing_reentrant:
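Because `_exec_order_data` is a private FSDP structure, each write is wrapped in hasattr guards so Composer degrades to a warning if PyTorch reshapes its internals. A stripped-down sketch of that guard pattern with stand-in objects (nothing here touches real FSDP):

```python
import warnings

class FakeExecOrderData:  # stand-in for FSDP's private _exec_order_data
    _forward_prefetch_limit = 1
    _backward_prefetch_limit = 1

class FakeFSDP:  # stand-in for the wrapped FSDP module
    _exec_order_data = FakeExecOrderData()

fsdp_obj = FakeFSDP()
fsdp_config = {'forward_prefetch_limit': 2, 'backward_prefetch_limit': 2}

if hasattr(fsdp_obj, '_exec_order_data'):
    eod = fsdp_obj._exec_order_data
    if hasattr(eod, '_forward_prefetch_limit'):
        eod._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
    if hasattr(eod, '_backward_prefetch_limit'):
        eod._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
else:
    warnings.warn('FSDP internals changed; prefetch limits from fsdp_config are ignored.')
```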
1 change: 1 addition & 0 deletions docs/source/trainer/logging.rst
@@ -35,6 +35,7 @@ and also saves them to the file

os.environ["WANDB_MODE"] = "disabled"
os.environ["COMET_API_KEY"] = "<comet_api_key>"
os.environ["MLFLOW_TRACKING_URI"] = ""

.. testcode::
:skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
2 changes: 1 addition & 1 deletion setup.py
@@ -87,7 +87,7 @@ def package_files(prefix: str, directory: str, extension: str):
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<23',
'importlib-metadata>=5.0.0,<7',
'mosaicml-cli>=0.5.8,<0.6',
'mosaicml-cli>=0.5.25,<0.6',
]
extra_deps = {}

5 changes: 4 additions & 1 deletion tests/fixtures/autouse_fixtures.py
@@ -5,6 +5,7 @@
import logging
import os
import pathlib
from concurrent.futures import Future

import mcli
import pytest
@@ -118,7 +119,9 @@ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
@pytest.fixture(autouse=True)
def mapi_fixture(monkeypatch):
# Composer auto-adds mosaicml logger when running on platform. Disable logging for tests.
mock_update = lambda *args, **kwargs: None
future_obj = Future()
future_obj.set_result(None)
mock_update = lambda *args, **kwargs: future_obj
monkeypatch.setattr(mcli, 'update_run_metadata', mock_update)


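The mock now has to return a completed Future because `_flush_metadata` calls `wait()` and `f.exception()` on whatever mcli returns. A quick sanity check of the mocked shape (a sketch, not part of the fixture):

```python
from concurrent.futures import Future, wait

future_obj = Future()
future_obj.set_result(None)
mock_update = lambda *args, **kwargs: future_obj

done, _ = wait([mock_update('run-name', {'key': 'value'})], timeout=0.01)
assert all(f.exception() is None for f in done)
```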
47 changes: 44 additions & 3 deletions tests/loggers/test_mosaicml_logger.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures import Future
from typing import Type
from unittest.mock import MagicMock

@@ -23,10 +24,26 @@

class MockMAPI:

def __init__(self):
def __init__(self, simulate_exception: bool = False):
self.run_metadata = {}

def update_run_metadata(self, run_name, new_metadata):
self.simulate_exception = simulate_exception

def update_run_metadata(self, run_name, new_metadata, future=False, protect=True):
if future:
# Simulate asynchronous behavior using Future
future_obj = Future()
try:
self._update_metadata(run_name, new_metadata)
future_obj.set_result(None) # Set a result to indicate completion
except Exception as e:
future_obj.set_exception(e) # Set an exception if something goes wrong
return future_obj
else:
self._update_metadata(run_name, new_metadata)

def _update_metadata(self, run_name, new_metadata):
if self.simulate_exception:
raise RuntimeError('Simulated exception')
if run_name not in self.run_metadata:
self.run_metadata[run_name] = {}
for k, v in new_metadata.items():
@@ -94,6 +111,30 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callba
assert len(mock_mapi.run_metadata.keys()) == 0


@world_size(1, 2)
@pytest.mark.parametrize('ignore_exceptions', [True, False])
def test_logged_data_exception_handling(monkeypatch, world_size: int, ignore_exceptions: bool):
"""Test that exceptions in MAPI are raised properly."""
mock_mapi = MockMAPI(simulate_exception=True)
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
run_name = 'small_chungus'
monkeypatch.setenv('RUN_NAME', run_name)

logger = MosaicMLLogger(ignore_exceptions=ignore_exceptions)
if dist.get_global_rank() != 0:
assert logger._enabled is False
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
elif ignore_exceptions:
assert logger._enabled is True
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
else:
with pytest.raises(RuntimeError, match='Simulated exception'):
assert logger._enabled is True
logger._flush_metadata(force_flush=True)


def test_metric_partial_filtering(monkeypatch):
mock_mapi = MockMAPI()
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
42 changes: 37 additions & 5 deletions tests/trainer/test_fsdp.py
@@ -9,18 +9,21 @@
from composer.models import ComposerClassifier
from composer.trainer.trainer import Trainer
from composer.utils import dist
from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel
from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
world_size)


@pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.parametrize('device', ['cpu', 'meta'])
@pytest.mark.parametrize('reentrant', [True, False])
@pytest.mark.filterwarnings('ignore::UserWarning')
@world_size(2)
@pytest.mark.gpu
@pytest.mark.filterwarnings('ignore:The passed in model appears to have tied weights.*:UserWarning')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, device: str, reentrant: bool):
def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, reentrant: bool, world_size: int,
device: str):
"""test FSDP device initialization for a simple model with weight tying and a model where two modules
from separate submodules have weight tying applied. This test also covers both 'cpu' and
'meta' devices. This is because 'meta' will result in deferred initialization until FSDP is initialized
@@ -62,15 +65,16 @@ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision:
@pytest.mark.parametrize('model', [SimpleModel])
@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', device: str = 'meta'):
def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', world_size: int):
"""
This test is intended to test FSDP for meta initialization when there are attributes
that are `None` and ensure we don't raise nasty UserWarnings.
"""
num_classes = 2
model = model(num_features=1, num_classes=num_classes, device=device, bias=False)
model = model(num_features=1, num_classes=num_classes, device='meta', bias=False)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -85,3 +89,31 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
},
max_duration='3ba',
)


@pytest.mark.parametrize('forward_prefetch_limit', [1, 2])
@pytest.mark.parametrize('backward_prefetch_limit', [1, 2])
@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limit: int, world_size: int):
model = SimpleModel()
model.fc1._fsdp_wrap = True
model.fc2._fsdp_wrap = True
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
fsdp_config={
'forward_prefetch_limit': forward_prefetch_limit,
'backward_prefetch_limit': backward_prefetch_limit,
},
max_duration='3ba',
)

trainer.fit()
