Skip to content

Commit

Permalink
Switch mosaicml logger to use futures to enable better error handling (
Browse files Browse the repository at this point in the history
  • Loading branch information
j316chuck authored Nov 14, 2023
1 parent 834838d commit 0f07bf9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 18 deletions.
9 changes: 9 additions & 0 deletions composer/algorithms/seq_length_warmup/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def training_loop(model, train_loader):
<!--pytest.mark.gpu-->
<!--
```python
import os
previous_platform_env = os.environ["MOSAICML_PLATFORM"]
os.environ["MOSAICML_PLATFORM"] = "false"
from tests.common.models import configure_tiny_bert_hf_model
from tests.common.datasets import dummy_bert_lm_dataloader
Expand All @@ -64,6 +67,12 @@ trainer = Trainer(model=model,

trainer.fit()
```
<!--pytest-codeblocks:cont-->
<!--
```python
os.environ["MOSAICML_PLATFORM"] = previous_platform_env
```
-->

### Implementation Details

Expand Down
35 changes: 22 additions & 13 deletions composer/loggers/mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import time
import warnings
from concurrent.futures import wait
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List, Optional

Expand Down Expand Up @@ -57,22 +58,25 @@ class MosaicMLLogger(LoggerDestination):
Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics.
(default: ``None``)
ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
"""

def __init__(
self,
log_interval: int = 60,
ignore_keys: Optional[List[str]] = None,
ignore_exceptions: bool = False,
) -> None:
self.log_interval = log_interval
self.ignore_keys = ignore_keys
self.ignore_exceptions = ignore_exceptions
self._enabled = dist.get_global_rank() == 0
if self._enabled:
self.allowed_fails_left = 3
self.time_last_logged = 0
self.train_dataloader_len = None
self.time_failed_count_adjusted = 0
self.buffered_metadata: Dict[str, Any] = {}
self._futures = []

self.run_name = os.environ.get(RUN_NAME_ENV_VAR)
if self.run_name is not None:
log.info(f'Logging to mosaic run {self.run_name}')
Expand Down Expand Up @@ -140,20 +144,25 @@ def _flush_metadata(self, force_flush: bool = False) -> None:
"""Flush buffered metadata to MosaicML if enough time has passed since last flush."""
if self._enabled and (time.time() - self.time_last_logged > self.log_interval or force_flush):
try:
mcli.update_run_metadata(self.run_name, self.buffered_metadata)
f = mcli.update_run_metadata(self.run_name, self.buffered_metadata, future=True, protect=True)
self.buffered_metadata = {}
self.time_last_logged = time.time()
# If we have not failed in the last hour, increase the allowed fails. This increases
# robustness to transient network issues.
if time.time() - self.time_failed_count_adjusted > 3600 and self.allowed_fails_left < 3:
self.allowed_fails_left += 1
self.time_failed_count_adjusted = time.time()
except Exception as e:
log.error(f'Failed to log metadata to Mosaic with error: {e}')
self.allowed_fails_left -= 1
self.time_failed_count_adjusted = time.time()
if self.allowed_fails_left <= 0:
self._futures.append(f)
done, incomplete = wait(self._futures, timeout=0.01)
log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}')
# Raise any exceptions
for f in done:
if f.exception() is not None:
raise f.exception() # type: ignore
self._futures = list(incomplete)
except Exception:
log.exception('Failed to log metadata to Mosaic') # Prints out full traceback
if self.ignore_exceptions:
log.info('Ignoring exception and disabling MosaicMLLogger.')
self._enabled = False
else:
log.info('Raising exception. To ignore exceptions, set ignore_exceptions=True.')
raise

def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]:
"""Calculates training progress metrics.
Expand Down
1 change: 1 addition & 0 deletions docs/source/trainer/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ and also saves them to the file

os.environ["WANDB_MODE"] = "disabled"
os.environ["COMET_API_KEY"] = "<comet_api_key>"
os.environ["MLFLOW_TRACKING_URI"] = ""

.. testcode::
:skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def package_files(prefix: str, directory: str, extension: str):
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<23',
'importlib-metadata>=5.0.0,<7',
'mosaicml-cli>=0.5.8,<0.6',
'mosaicml-cli>=0.5.25,<0.6',
]
extra_deps = {}

Expand Down
5 changes: 4 additions & 1 deletion tests/fixtures/autouse_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import pathlib
from concurrent.futures import Future

import mcli
import pytest
Expand Down Expand Up @@ -118,7 +119,9 @@ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
@pytest.fixture(autouse=True)
def mapi_fixture(monkeypatch):
# Composer auto-adds mosaicml logger when running on platform. Disable logging for tests.
mock_update = lambda *args, **kwargs: None
future_obj = Future()
future_obj.set_result(None)
mock_update = lambda *args, **kwargs: future_obj
monkeypatch.setattr(mcli, 'update_run_metadata', mock_update)


Expand Down
47 changes: 44 additions & 3 deletions tests/loggers/test_mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures import Future
from typing import Type
from unittest.mock import MagicMock

Expand All @@ -23,10 +24,26 @@

class MockMAPI:

def __init__(self):
def __init__(self, simulate_exception: bool = False):
self.run_metadata = {}

def update_run_metadata(self, run_name, new_metadata):
self.simulate_exception = simulate_exception

def update_run_metadata(self, run_name, new_metadata, future=False, protect=True):
if future:
# Simulate asynchronous behavior using Future
future_obj = Future()
try:
self._update_metadata(run_name, new_metadata)
future_obj.set_result(None) # Set a result to indicate completion
except Exception as e:
future_obj.set_exception(e) # Set an exception if something goes wrong
return future_obj
else:
self._update_metadata(run_name, new_metadata)

def _update_metadata(self, run_name, new_metadata):
if self.simulate_exception:
raise RuntimeError('Simulated exception')
if run_name not in self.run_metadata:
self.run_metadata[run_name] = {}
for k, v in new_metadata.items():
Expand Down Expand Up @@ -94,6 +111,30 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callba
assert len(mock_mapi.run_metadata.keys()) == 0


@world_size(1, 2)
@pytest.mark.parametrize('ignore_exceptions', [True, False])
def test_logged_data_exception_handling(monkeypatch, world_size: int, ignore_exceptions: bool):
"""Test that exceptions in MAPI are raised properly."""
mock_mapi = MockMAPI(simulate_exception=True)
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
run_name = 'small_chungus'
monkeypatch.setenv('RUN_NAME', run_name)

logger = MosaicMLLogger(ignore_exceptions=ignore_exceptions)
if dist.get_global_rank() != 0:
assert logger._enabled is False
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
elif ignore_exceptions:
assert logger._enabled is True
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
else:
with pytest.raises(RuntimeError, match='Simulated exception'):
assert logger._enabled is True
logger._flush_metadata(force_flush=True)


def test_metric_partial_filtering(monkeypatch):
mock_mapi = MockMAPI()
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
Expand Down

0 comments on commit 0f07bf9

Please sign in to comment.