forked from mosaicml/composer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add OOM observer with memory visualizations (mosaicml#2958)
* add oomobserver * update docstring * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * use pyskip * call trainer fit * fix ci * Update composer/callbacks/oom_observer.py Co-authored-by: Charles Tang <[email protected]> * addresss comments * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * add test wiht snapshot * update doc * fix typo * use log info * fix format * fix format * fix ci * fix cpu test * fix ci * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * update test * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update tests/callbacks/test_oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * Update composer/callbacks/oom_observer.py Co-authored-by: Mihir Patel <[email protected]> * use warnings * add pytest filter user warnings in cpu callback tests * fix typo --------- Co-authored-by: Mihir Patel <[email protected]> Co-authored-by: Charles Tang <[email protected]>
- Loading branch information
1 parent
e4ee99e
commit 21bc3db
Showing
10 changed files
with
287 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Generate a memory snapshot during an OutOfMemory exception.""" | ||
|
||
import logging | ||
import os | ||
import pickle | ||
import warnings | ||
from typing import Optional | ||
|
||
import torch.cuda | ||
from packaging import version | ||
|
||
from composer import State | ||
from composer.core import Callback, State | ||
from composer.loggers import Logger | ||
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
__all__ = ['OOMObserver'] | ||
|
||
|
||
class OOMObserver(Callback): | ||
"""Generate visualizations of the state of allocated memory during an OutOfMemory exception. | ||
This callback registers an observer with the allocator that will be called everytime it is about to raise an OutOfMemoryError before any memory has been release while unwinding the exception. OOMObserver is attached to the Trainer at init stage. The visualizations include a snapshot of the memory state, a trace plot, a segment plot, a segment flamegraph, and a memory flamegraph. | ||
Example: | ||
.. doctest:: | ||
>>> from composer import Trainer | ||
>>> from composer.callbacks import OOMObserver | ||
>>> # constructing trainer object with this callback | ||
>>> trainer = Trainer( | ||
... model=model, | ||
... train_dataloader=train_dataloader, | ||
... eval_dataloader=eval_dataloader, | ||
... optimizers=optimizer, | ||
... max_duration="1ep", | ||
... callbacks=[OOMObserver()], | ||
... ) | ||
.. note:: | ||
OOMObserver is only supported for GPU devices. | ||
Args: | ||
max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000. | ||
folder (str, optional): A format string describing the folder containing the memory visualization files. | ||
Defaults to ``'{{run_name}}/torch_traces'``. | ||
filename (str, optional): A format string describing the prefix used to name the memory visualization files. | ||
Defaults to ``'rank{{rank}}_oom'``. | ||
remote_file_name (str, optional): A format string describing the prefix for the memory visualization remote file name. | ||
Defaults to ``'{{run_name}}/oom_traces/rank{{rank}}_oom'``. | ||
Whenever a trace file is saved, it is also uploaded as a file according to this format string. | ||
The same format variables as for ``filename`` are available. | ||
.. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading. | ||
Leading slashes (``'/'``) will be stripped. | ||
To disable uploading trace files, set this parameter to ``None``. | ||
overwrite (bool, optional): Whether to override existing memory snapshots. Defaults to False. | ||
If False, then the trace folder as determined by ``folder`` must be empty. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
max_entries: int = 100000, | ||
folder: str = '{run_name}/torch_traces', | ||
filename: str = 'rank{rank}_oom', | ||
remote_file_name: Optional[str] = '{run_name}/oom_traces/rank{rank}_oom', | ||
overwrite: bool = False, | ||
) -> None: | ||
self.max_entries = max_entries | ||
self.folder = folder | ||
self.folder_name = None | ||
self.filename = filename | ||
self.remote_file_name = remote_file_name | ||
self.overwrite = overwrite | ||
if remote_file_name: | ||
self.remote_file_name = remote_file_name | ||
_, _, self.remote_path_in_bucket = parse_uri(remote_file_name) | ||
else: | ||
self.remote_path_in_bucket = None | ||
|
||
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'): # type: ignore | ||
# OOMObserver is only supported in torch v2.1.0 or higher | ||
self._enabled = True | ||
else: | ||
self._enabled = False | ||
warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.') | ||
|
||
def init(self, state: State, logger: Logger) -> None: | ||
if not self._enabled: | ||
return | ||
# Not relying on `torch.cuda.is_available()` since the model could be on CPU. | ||
model_device = next(state.model.parameters()).device | ||
|
||
if model_device.type not in ('cuda', 'meta'): | ||
warnings.warn( | ||
f'OOMObserver only works on CUDA devices, but the model is on {model_device.type}. Disabling OOMObserver.' | ||
) | ||
self._enabled = False | ||
else: | ||
self.folder_name = format_name_with_dist(self.folder, state.run_name) | ||
os.makedirs(self.folder_name, exist_ok=True) | ||
if not self.overwrite: | ||
ensure_folder_is_empty(self.folder_name) | ||
|
||
def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int): | ||
# Snapshot right after an OOM happened | ||
log.warning('Out Of Memory (OOM) observed') | ||
|
||
assert self.filename | ||
assert self.folder_name, 'folder_name must be set in init' | ||
filename = os.path.join( | ||
self.folder_name, | ||
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp)) | ||
|
||
try: | ||
snapshot_file = filename + '_snapshot.pickle' | ||
trace_plot_file = filename + '_trace_plot.html' | ||
segment_plot_file = filename + '_segment_plot.html' | ||
segment_flamegraph_file = filename + '_segment_flamegraph.svg' | ||
memory_flamegraph_file = filename + '_memory_flamegraph.svg' | ||
log.info(f'Dumping OOMObserver visualizations') | ||
|
||
snapshot = torch.cuda.memory._snapshot() | ||
# No data was recorded - avoids a `ValueError` in `trace_plot` | ||
if all(len(t) == 0 for t in snapshot['device_traces']): | ||
log.info(f'No allocation is recorded in memory snapshot)') | ||
return | ||
|
||
with open(snapshot_file, 'wb') as fd: | ||
pickle.dump(snapshot, fd) | ||
|
||
with open(trace_plot_file, 'w+') as fd: | ||
fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore | ||
|
||
with open(segment_plot_file, 'w+') as fd: | ||
fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore | ||
|
||
with open(segment_flamegraph_file, 'w+') as fd: | ||
fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore | ||
|
||
with open(memory_flamegraph_file, 'w+') as fd: | ||
fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore | ||
|
||
log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM') | ||
|
||
if self.remote_path_in_bucket is not None: | ||
for f in [ | ||
snapshot_file, trace_plot_file, segment_plot_file, segment_flamegraph_file, | ||
memory_flamegraph_file | ||
]: | ||
remote_file_name = (self.remote_path_in_bucket + os.path.basename(f)).lstrip('/') | ||
log.info(f'Uploading memory visualization to remote: {remote_file_name} from {f}') | ||
try: | ||
logger.upload_file(remote_file_name=remote_file_name, file_path=f, overwrite=self.overwrite) | ||
except FileExistsError as e: | ||
raise FileExistsError( | ||
f'Uploading memory visualizations failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite memory visualizations with Trainer, set save_overwrite to True.' | ||
) from e | ||
|
||
except Exception as e: | ||
log.error(f'Failed to capture memory snapshot {e}') | ||
|
||
if self._enabled: | ||
torch.cuda.memory._record_memory_history( | ||
True, # type: ignore | ||
trace_alloc_max_entries=self.max_entries, | ||
trace_alloc_record_context=True) | ||
torch._C._cuda_attach_out_of_memory_observer(oom_observer) # type: ignore | ||
log.info('OOMObserver is enabled and registered') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pathlib | ||
|
||
import pytest | ||
import torch | ||
from packaging import version | ||
from torch.utils.data import DataLoader | ||
|
||
from composer import State, Trainer | ||
from composer.callbacks import MemorySnapshot, OOMObserver | ||
from composer.loggers import LoggerDestination | ||
from composer.trainer import Trainer | ||
from tests.common import RandomClassificationDataset, SimpleModel | ||
|
||
|
||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), | ||
reason='OOM Observer requires PyTorch 2.1 or higher') | ||
def test_oom_observer_warnings_on_cpu_models(): | ||
ob = OOMObserver() | ||
with pytest.warns(UserWarning): | ||
Trainer( | ||
model=SimpleModel(), | ||
callbacks=ob, | ||
train_dataloader=DataLoader(RandomClassificationDataset()), | ||
max_duration='1ba', | ||
device='cpu', | ||
) | ||
assert ob._enabled is False | ||
|
||
|
||
class FileUploaderTracker(LoggerDestination): | ||
|
||
def __init__(self) -> None: | ||
self.uploaded_files = [] | ||
|
||
def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool): | ||
del state, overwrite # unused | ||
self.uploaded_files.append((remote_file_name, file_path)) | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), | ||
reason='OOM Observer requires PyTorch 2.1 or higher') | ||
def test_oom_observer(): | ||
# Construct the callbacks | ||
oom_observer = OOMObserver() | ||
simple_model = SimpleModel() | ||
file_tracker_destination = FileUploaderTracker() | ||
|
||
with pytest.raises(torch.cuda.OutOfMemoryError): | ||
trainer = Trainer( | ||
model=simple_model, | ||
loggers=file_tracker_destination, | ||
callbacks=oom_observer, | ||
train_dataloader=DataLoader(RandomClassificationDataset()), | ||
max_duration='2ba', | ||
) | ||
|
||
# trigger OOM | ||
torch.empty(1024 * 1024 * 1024 * 1024, device='cuda') | ||
|
||
trainer.fit() | ||
|
||
assert len(file_tracker_destination.uploaded_files) == 5 | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), | ||
reason='OOM Observer requires PyTorch 2.1 or higher') | ||
def test_oom_observer_with_memory_snapshot(): | ||
# Construct the callbacks | ||
oom_observer = OOMObserver() | ||
memory_snapshot = MemorySnapshot(skip_batches=0, interval='1ba') | ||
simple_model = SimpleModel() | ||
file_tracker_destination = FileUploaderTracker() | ||
|
||
trainer = Trainer( | ||
model=simple_model, | ||
loggers=file_tracker_destination, | ||
callbacks=[oom_observer, memory_snapshot], | ||
train_dataloader=DataLoader(RandomClassificationDataset()), | ||
max_duration='2ba', | ||
) | ||
|
||
trainer.fit() | ||
assert len(file_tracker_destination.uploaded_files) == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.