From 21bc3db87b2b4d932f45d8a98d301bdd43a5c4d8 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Fri, 2 Feb 2024 14:03:37 -0800 Subject: [PATCH] Add OOM observer with memory visualizations (#2958) * add oomobserver * update docstring * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * use pyskip * call trainer fit * fix ci * Update composer/callbacks/oom_observer.py Co-authored-by: Charles Tang * addresss comments * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * add test wiht snapshot * update doc * fix typo * use log info * fix format * fix format * fix ci * fix cpu test * fix ci * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * update test * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel * use warnings * add pytest filter user warnings in cpu callback tests * fix typo --------- Co-authored-by: Mihir Patel Co-authored-by: Charles Tang --- composer/callbacks/__init__.py | 2 + composer/callbacks/oom_observer.py | 178 ++++++++++++++++++ composer/trainer/trainer.py | 6 +- docs/source/trainer/callbacks.rst | 1 + tests/callbacks/callback_settings.py | 8 +- tests/callbacks/test_callbacks.py | 6 + .../test_loggers_across_callbacks.py | 1 + tests/callbacks/test_oom_observer.py | 88 +++++++++ tests/loggers/test_mosaicml_logger.py | 1 + tests/loggers/test_wandb_logger.py | 1 + 10 files changed, 287 insertions(+), 5 deletions(-) create mode 100644 composer/callbacks/oom_observer.py create mode 100644 tests/callbacks/test_oom_observer.py diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py index ee1ca0ae18..16a50a31a9 100644 --- a/composer/callbacks/__init__.py +++ b/composer/callbacks/__init__.py @@ -18,6 +18,7 @@ from composer.callbacks.memory_snapshot import MemorySnapshot from composer.callbacks.mlperf import MLPerfCallback from composer.callbacks.nan_monitor import NaNMonitor +from composer.callbacks.oom_observer import OOMObserver from composer.callbacks.optimizer_monitor import OptimizerMonitor from composer.callbacks.runtime_estimator import RuntimeEstimator from composer.callbacks.speed_monitor import SpeedMonitor @@ -42,4 +43,5 @@ 'Generate', 'FreeOutputs', 'MemorySnapshot', + 'OOMObserver', ] diff --git a/composer/callbacks/oom_observer.py b/composer/callbacks/oom_observer.py new file mode 100644 index 0000000000..87d818df1a --- /dev/null +++ b/composer/callbacks/oom_observer.py @@ -0,0 +1,178 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a memory snapshot during an OutOfMemory exception.""" + +import logging +import os +import pickle +import warnings +from typing import Optional + +import torch.cuda +from packaging import version + +from composer import State +from composer.core import Callback, State +from composer.loggers import Logger +from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri + +log = logging.getLogger(__name__) + +__all__ = ['OOMObserver'] + + +class OOMObserver(Callback): + """Generate visualizations of the state of allocated memory during an OutOfMemory exception. + + This callback registers an observer with the allocator that will be called everytime it is about to raise an OutOfMemoryError before any memory has been release while unwinding the exception. OOMObserver is attached to the Trainer at init stage. The visualizations include a snapshot of the memory state, a trace plot, a segment plot, a segment flamegraph, and a memory flamegraph. + + Example: + .. doctest:: + + >>> from composer import Trainer + >>> from composer.callbacks import OOMObserver + >>> # constructing trainer object with this callback + >>> trainer = Trainer( + ... model=model, + ... train_dataloader=train_dataloader, + ... eval_dataloader=eval_dataloader, + ... optimizers=optimizer, + ... max_duration="1ep", + ... callbacks=[OOMObserver()], + ... ) + + .. note:: + OOMObserver is only supported for GPU devices. + + Args: + max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000. + folder (str, optional): A format string describing the folder containing the memory visualization files. + Defaults to ``'{{run_name}}/torch_traces'``. + filename (str, optional): A format string describing the prefix used to name the memory visualization files. + Defaults to ``'rank{{rank}}_oom'``. + remote_file_name (str, optional): A format string describing the prefix for the memory visualization remote file name. + Defaults to ``'{{run_name}}/oom_traces/rank{{rank}}_oom'``. + + Whenever a trace file is saved, it is also uploaded as a file according to this format string. + The same format variables as for ``filename`` are available. + + .. seealso:: :doc:`Uploading Files` for notes for file uploading. + + Leading slashes (``'/'``) will be stripped. + + To disable uploading trace files, set this parameter to ``None``. + overwrite (bool, optional): Whether to override existing memory snapshots. Defaults to False. + + If False, then the trace folder as determined by ``folder`` must be empty. + """ + + def __init__( + self, + max_entries: int = 100000, + folder: str = '{run_name}/torch_traces', + filename: str = 'rank{rank}_oom', + remote_file_name: Optional[str] = '{run_name}/oom_traces/rank{rank}_oom', + overwrite: bool = False, + ) -> None: + self.max_entries = max_entries + self.folder = folder + self.folder_name = None + self.filename = filename + self.remote_file_name = remote_file_name + self.overwrite = overwrite + if remote_file_name: + self.remote_file_name = remote_file_name + _, _, self.remote_path_in_bucket = parse_uri(remote_file_name) + else: + self.remote_path_in_bucket = None + + if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'): # type: ignore + # OOMObserver is only supported in torch v2.1.0 or higher + self._enabled = True + else: + self._enabled = False + warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.') + + def init(self, state: State, logger: Logger) -> None: + if not self._enabled: + return + # Not relying on `torch.cuda.is_available()` since the model could be on CPU. + model_device = next(state.model.parameters()).device + + if model_device.type not in ('cuda', 'meta'): + warnings.warn( + f'OOMObserver only works on CUDA devices, but the model is on {model_device.type}. Disabling OOMObserver.' + ) + self._enabled = False + else: + self.folder_name = format_name_with_dist(self.folder, state.run_name) + os.makedirs(self.folder_name, exist_ok=True) + if not self.overwrite: + ensure_folder_is_empty(self.folder_name) + + def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int): + # Snapshot right after an OOM happened + log.warning('Out Of Memory (OOM) observed') + + assert self.filename + assert self.folder_name, 'folder_name must be set in init' + filename = os.path.join( + self.folder_name, + format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp)) + + try: + snapshot_file = filename + '_snapshot.pickle' + trace_plot_file = filename + '_trace_plot.html' + segment_plot_file = filename + '_segment_plot.html' + segment_flamegraph_file = filename + '_segment_flamegraph.svg' + memory_flamegraph_file = filename + '_memory_flamegraph.svg' + log.info(f'Dumping OOMObserver visualizations') + + snapshot = torch.cuda.memory._snapshot() + # No data was recorded - avoids a `ValueError` in `trace_plot` + if all(len(t) == 0 for t in snapshot['device_traces']): + log.info(f'No allocation is recorded in memory snapshot)') + return + + with open(snapshot_file, 'wb') as fd: + pickle.dump(snapshot, fd) + + with open(trace_plot_file, 'w+') as fd: + fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore + + with open(segment_plot_file, 'w+') as fd: + fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore + + with open(segment_flamegraph_file, 'w+') as fd: + fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore + + with open(memory_flamegraph_file, 'w+') as fd: + fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore + + log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM') + + if self.remote_path_in_bucket is not None: + for f in [ + snapshot_file, trace_plot_file, segment_plot_file, segment_flamegraph_file, + memory_flamegraph_file + ]: + remote_file_name = (self.remote_path_in_bucket + os.path.basename(f)).lstrip('/') + log.info(f'Uploading memory visualization to remote: {remote_file_name} from {f}') + try: + logger.upload_file(remote_file_name=remote_file_name, file_path=f, overwrite=self.overwrite) + except FileExistsError as e: + raise FileExistsError( + f'Uploading memory visualizations failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite memory visualizations with Trainer, set save_overwrite to True.' + ) from e + + except Exception as e: + log.error(f'Failed to capture memory snapshot {e}') + + if self._enabled: + torch.cuda.memory._record_memory_history( + True, # type: ignore + trace_alloc_max_entries=self.max_entries, + trace_alloc_record_context=True) + torch._C._cuda_attach_out_of_memory_observer(oom_observer) # type: ignore + log.info('OOMObserver is enabled and registered') diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 99dd2d0437..5dd97bda64 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -36,7 +36,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchmetrics import Metric -from composer.callbacks import CheckpointSaver, MemorySnapshot, OptimizerMonitor +from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching) @@ -1072,9 +1072,9 @@ def __init__( loggers.append(remote_ud) self.state.profiler.bind_to_state(self.state) - # MemorySnapshot + # MemorySnapshot, OOMObserver for cb in self.state.callbacks: - if isinstance(cb, MemorySnapshot): + if isinstance(cb, MemorySnapshot) or isinstance(cb, OOMObserver): if cb.remote_file_name: remote_ud = maybe_create_remote_uploader_downloader_from_uri(uri=cb.remote_file_name, loggers=loggers) diff --git a/docs/source/trainer/callbacks.rst b/docs/source/trainer/callbacks.rst index 7210ef72e8..9f6f26a9dd 100644 --- a/docs/source/trainer/callbacks.rst +++ b/docs/source/trainer/callbacks.rst @@ -51,6 +51,7 @@ components of training. ~optimizer_monitor.OptimizerMonitor ~memory_monitor.MemoryMonitor ~memory_snapshot.MemorySnapshot + ~oom_observer.OOMObserver ~nan_monitor.NaNMonitor ~image_visualizer.ImageVisualizer ~mlperf.MLPerfCallback diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py index 492b5988be..55ad1c641c 100644 --- a/tests/callbacks/callback_settings.py +++ b/tests/callbacks/callback_settings.py @@ -12,8 +12,8 @@ import composer.profiler from composer import Callback from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, ImageVisualizer, - MemoryMonitor, MemorySnapshot, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, - ThresholdStopper) + MemoryMonitor, MemorySnapshot, MLPerfCallback, OOMObserver, SpeedMonitor, + SystemMetricsMonitor, ThresholdStopper) from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger, RemoteUploaderDownloader, TensorboardLogger, WandBLogger) from composer.models.base import ComposerModel @@ -132,6 +132,10 @@ pytest.mark.filterwarnings( r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning') ], + OOMObserver: [ + pytest.mark.filterwarnings( + r'ignore:The oom observer only works on CUDA devices, but the model is on cpu:UserWarning') + ], MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')], WandBLogger: [ pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'), diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 695be08c55..f0ddbe43cc 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -43,12 +43,14 @@ class TestCallbacks: def setup_class(cls): pytest.importorskip('wandb', reason='WandB is optional.') + @pytest.mark.filterwarnings('ignore::UserWarning') def test_callback_is_constructable(self, cb_cls: Type[Callback]): cb_kwargs = get_cb_kwargs(cb_cls) cb = cb_cls(**cb_kwargs) assert isinstance(cb_cls, type) assert isinstance(cb, cb_cls) + @pytest.mark.filterwarnings('ignore::UserWarning') def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: State): """Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) @@ -69,6 +71,7 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: S engine.run_event(Event.FIT_START) engine.run_event(Event.FIT_END) + @pytest.mark.filterwarnings('ignore::UserWarning') def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State): """Test that callbacks do not crash when .close() and .post_close() are called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) @@ -85,6 +88,7 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State): engine.close() engine.close() + @pytest.mark.filterwarnings('ignore::UserWarning') def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: State): """Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order.""" cb_kwargs = get_cb_kwargs(cb_cls) @@ -136,6 +140,7 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int): torch_prof_memory_filename=None), ) + @pytest.mark.filterwarnings('ignore::UserWarning') def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool): del _remote # unused. `_remote` must be passed through to parameterize the test markers. cb_kwargs = get_cb_kwargs(cb_cls) @@ -143,6 +148,7 @@ def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, trainer = self._get_trainer(cb, device_train_microbatch_size) trainer.fit() + @pytest.mark.filterwarnings('ignore::UserWarning') def test_trains_multiple_calls(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool): """ Tests that training with multiple fits complete. diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py index 92363e7aa5..1c58babf0b 100644 --- a/tests/callbacks/test_loggers_across_callbacks.py +++ b/tests/callbacks/test_loggers_across_callbacks.py @@ -15,6 +15,7 @@ @pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True)) @pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True)) +@pytest.mark.filterwarnings('ignore::UserWarning') def test_loggers_on_callbacks(logger_cls: Type[LoggerDestination], callback_cls: Type[Callback]): if logger_cls in [ProgressBarLogger, ConsoleLogger, SlackLogger]: pytest.skip() diff --git a/tests/callbacks/test_oom_observer.py b/tests/callbacks/test_oom_observer.py new file mode 100644 index 0000000000..5fbb5bd8a3 --- /dev/null +++ b/tests/callbacks/test_oom_observer.py @@ -0,0 +1,88 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import pathlib + +import pytest +import torch +from packaging import version +from torch.utils.data import DataLoader + +from composer import State, Trainer +from composer.callbacks import MemorySnapshot, OOMObserver +from composer.loggers import LoggerDestination +from composer.trainer import Trainer +from tests.common import RandomClassificationDataset, SimpleModel + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), + reason='OOM Observer requires PyTorch 2.1 or higher') +def test_oom_observer_warnings_on_cpu_models(): + ob = OOMObserver() + with pytest.warns(UserWarning): + Trainer( + model=SimpleModel(), + callbacks=ob, + train_dataloader=DataLoader(RandomClassificationDataset()), + max_duration='1ba', + device='cpu', + ) + assert ob._enabled is False + + +class FileUploaderTracker(LoggerDestination): + + def __init__(self) -> None: + self.uploaded_files = [] + + def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool): + del state, overwrite # unused + self.uploaded_files.append((remote_file_name, file_path)) + + +@pytest.mark.gpu +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), + reason='OOM Observer requires PyTorch 2.1 or higher') +def test_oom_observer(): + # Construct the callbacks + oom_observer = OOMObserver() + simple_model = SimpleModel() + file_tracker_destination = FileUploaderTracker() + + with pytest.raises(torch.cuda.OutOfMemoryError): + trainer = Trainer( + model=simple_model, + loggers=file_tracker_destination, + callbacks=oom_observer, + train_dataloader=DataLoader(RandomClassificationDataset()), + max_duration='2ba', + ) + + # trigger OOM + torch.empty(1024 * 1024 * 1024 * 1024, device='cuda') + + trainer.fit() + + assert len(file_tracker_destination.uploaded_files) == 5 + + +@pytest.mark.gpu +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), + reason='OOM Observer requires PyTorch 2.1 or higher') +def test_oom_observer_with_memory_snapshot(): + # Construct the callbacks + oom_observer = OOMObserver() + memory_snapshot = MemorySnapshot(skip_batches=0, interval='1ba') + simple_model = SimpleModel() + file_tracker_destination = FileUploaderTracker() + + trainer = Trainer( + model=simple_model, + loggers=file_tracker_destination, + callbacks=[oom_observer, memory_snapshot], + train_dataloader=DataLoader(RandomClassificationDataset()), + max_duration='2ba', + ) + + trainer.fit() + assert len(file_tracker_destination.uploaded_files) == 1 diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py index ba35beac8b..0834e3dbf0 100644 --- a/tests/loggers/test_mosaicml_logger.py +++ b/tests/loggers/test_mosaicml_logger.py @@ -85,6 +85,7 @@ def test_format_data_to_json_serializable(): @pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True)) @world_size(1, 2) +@pytest.mark.filterwarnings('ignore::UserWarning') def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callback], world_size): """Test that all logged data is json serializable, which is a requirement to use MAPI.""" diff --git a/tests/loggers/test_wandb_logger.py b/tests/loggers/test_wandb_logger.py index 1ccfc5e53a..c9cfe0fc6c 100644 --- a/tests/loggers/test_wandb_logger.py +++ b/tests/loggers/test_wandb_logger.py @@ -247,6 +247,7 @@ def test_wandb_log_metrics(test_wandb_logger): @pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True)) +@pytest.mark.filterwarnings('ignore::UserWarning') def test_logged_data_is_json_serializable(callback_cls: Type[Callback]): """Test that all logged data is json serializable, which is a requirement to use wandb.""" pytest.importorskip('wandb', reason='wandb is optional')