From 0f07bf9398e4cf2e05c2bc3d1c261a6c1be195f7 Mon Sep 17 00:00:00 2001 From: Charles Tang Date: Tue, 14 Nov 2023 15:55:21 -0800 Subject: [PATCH] Switch mosaicml logger to use futures to enable better error handling (#2702) --- .../algorithms/seq_length_warmup/README.md | 9 ++++ composer/loggers/mosaicml_logger.py | 35 +++++++++----- docs/source/trainer/logging.rst | 1 + setup.py | 2 +- tests/fixtures/autouse_fixtures.py | 5 +- tests/loggers/test_mosaicml_logger.py | 47 +++++++++++++++++-- 6 files changed, 81 insertions(+), 18 deletions(-) diff --git a/composer/algorithms/seq_length_warmup/README.md b/composer/algorithms/seq_length_warmup/README.md index e26724225c..84aee2c697 100644 --- a/composer/algorithms/seq_length_warmup/README.md +++ b/composer/algorithms/seq_length_warmup/README.md @@ -44,6 +44,9 @@ def training_loop(model, train_loader): + ### Implementation Details diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py index a23bfd9310..d14b3d81c5 100644 --- a/composer/loggers/mosaicml_logger.py +++ b/composer/loggers/mosaicml_logger.py @@ -12,6 +12,7 @@ import os import time import warnings +from concurrent.futures import wait from functools import reduce from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -57,22 +58,25 @@ class MosaicMLLogger(LoggerDestination): Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics. (default: ``None``) + ignore_exceptions: Flag to disable logging exceptions. Defaults to False. """ def __init__( self, log_interval: int = 60, ignore_keys: Optional[List[str]] = None, + ignore_exceptions: bool = False, ) -> None: self.log_interval = log_interval self.ignore_keys = ignore_keys + self.ignore_exceptions = ignore_exceptions self._enabled = dist.get_global_rank() == 0 if self._enabled: - self.allowed_fails_left = 3 self.time_last_logged = 0 self.train_dataloader_len = None - self.time_failed_count_adjusted = 0 self.buffered_metadata: Dict[str, Any] = {} + self._futures = [] + self.run_name = os.environ.get(RUN_NAME_ENV_VAR) if self.run_name is not None: log.info(f'Logging to mosaic run {self.run_name}') @@ -140,20 +144,25 @@ def _flush_metadata(self, force_flush: bool = False) -> None: """Flush buffered metadata to MosaicML if enough time has passed since last flush.""" if self._enabled and (time.time() - self.time_last_logged > self.log_interval or force_flush): try: - mcli.update_run_metadata(self.run_name, self.buffered_metadata) + f = mcli.update_run_metadata(self.run_name, self.buffered_metadata, future=True, protect=True) self.buffered_metadata = {} self.time_last_logged = time.time() - # If we have not failed in the last hour, increase the allowed fails. This increases - # robustness to transient network issues. - if time.time() - self.time_failed_count_adjusted > 3600 and self.allowed_fails_left < 3: - self.allowed_fails_left += 1 - self.time_failed_count_adjusted = time.time() - except Exception as e: - log.error(f'Failed to log metadata to Mosaic with error: {e}') - self.allowed_fails_left -= 1 - self.time_failed_count_adjusted = time.time() - if self.allowed_fails_left <= 0: + self._futures.append(f) + done, incomplete = wait(self._futures, timeout=0.01) + log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}') + # Raise any exceptions + for f in done: + if f.exception() is not None: + raise f.exception() # type: ignore + self._futures = list(incomplete) + except Exception: + log.exception('Failed to log metadata to Mosaic') # Prints out full traceback + if self.ignore_exceptions: + log.info('Ignoring exception and disabling MosaicMLLogger.') self._enabled = False + else: + log.info('Raising exception. To ignore exceptions, set ignore_exceptions=True.') + raise def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]: """Calculates training progress metrics. diff --git a/docs/source/trainer/logging.rst b/docs/source/trainer/logging.rst index 48ca8c47b2..7458f89ff6 100644 --- a/docs/source/trainer/logging.rst +++ b/docs/source/trainer/logging.rst @@ -35,6 +35,7 @@ and also saves them to the file os.environ["WANDB_MODE"] = "disabled" os.environ["COMET_API_KEY"] = "" + os.environ["MLFLOW_TRACKING_URI"] = "" .. testcode:: :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED diff --git a/setup.py b/setup.py index 78f8af7746..065104c75b 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def package_files(prefix: str, directory: str, extension: str): 'py-cpuinfo>=8.0.0,<10', 'packaging>=21.3.0,<23', 'importlib-metadata>=5.0.0,<7', - 'mosaicml-cli>=0.5.8,<0.6', + 'mosaicml-cli>=0.5.25,<0.6', ] extra_deps = {} diff --git a/tests/fixtures/autouse_fixtures.py b/tests/fixtures/autouse_fixtures.py index fc836be679..f44a74c363 100644 --- a/tests/fixtures/autouse_fixtures.py +++ b/tests/fixtures/autouse_fixtures.py @@ -5,6 +5,7 @@ import logging import os import pathlib +from concurrent.futures import Future import mcli import pytest @@ -118,7 +119,9 @@ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch): @pytest.fixture(autouse=True) def mapi_fixture(monkeypatch): # Composer auto-adds mosaicml logger when running on platform. Disable logging for tests. - mock_update = lambda *args, **kwargs: None + future_obj = Future() + future_obj.set_result(None) + mock_update = lambda *args, **kwargs: future_obj monkeypatch.setattr(mcli, 'update_run_metadata', mock_update) diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py index 1258e47e21..5461829ab5 100644 --- a/tests/loggers/test_mosaicml_logger.py +++ b/tests/loggers/test_mosaicml_logger.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +from concurrent.futures import Future from typing import Type from unittest.mock import MagicMock @@ -23,10 +24,26 @@ class MockMAPI: - def __init__(self): + def __init__(self, simulate_exception: bool = False): self.run_metadata = {} - - def update_run_metadata(self, run_name, new_metadata): + self.simulate_exception = simulate_exception + + def update_run_metadata(self, run_name, new_metadata, future=False, protect=True): + if future: + # Simulate asynchronous behavior using Future + future_obj = Future() + try: + self._update_metadata(run_name, new_metadata) + future_obj.set_result(None) # Set a result to indicate completion + except Exception as e: + future_obj.set_exception(e) # Set an exception if something goes wrong + return future_obj + else: + self._update_metadata(run_name, new_metadata) + + def _update_metadata(self, run_name, new_metadata): + if self.simulate_exception: + raise RuntimeError('Simulated exception') if run_name not in self.run_metadata: self.run_metadata[run_name] = {} for k, v in new_metadata.items(): @@ -94,6 +111,30 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callba assert len(mock_mapi.run_metadata.keys()) == 0 +@world_size(1, 2) +@pytest.mark.parametrize('ignore_exceptions', [True, False]) +def test_logged_data_exception_handling(monkeypatch, world_size: int, ignore_exceptions: bool): + """Test that exceptions in MAPI are raised properly.""" + mock_mapi = MockMAPI(simulate_exception=True) + monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata) + run_name = 'small_chungus' + monkeypatch.setenv('RUN_NAME', run_name) + + logger = MosaicMLLogger(ignore_exceptions=ignore_exceptions) + if dist.get_global_rank() != 0: + assert logger._enabled is False + logger._flush_metadata(force_flush=True) + assert logger._enabled is False + elif ignore_exceptions: + assert logger._enabled is True + logger._flush_metadata(force_flush=True) + assert logger._enabled is False + else: + with pytest.raises(RuntimeError, match='Simulated exception'): + assert logger._enabled is True + logger._flush_metadata(force_flush=True) + + def test_metric_partial_filtering(monkeypatch): mock_mapi = MockMAPI() monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)