Skip to content

Commit

Permalink
refactor memory snapshot (mosaicml#2960)
Browse files Browse the repository at this point in the history
  • Loading branch information
cli99 authored Feb 3, 2024
1 parent 8b9e18e commit 7b9c42e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
6 changes: 3 additions & 3 deletions composer/callbacks/memory_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def __init__(
else:
self.remote_path_in_bucket = None

if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore
# memory snapshot is only supported in torch v2.1.0-rc1 or higher
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'): # type: ignore
# MemorySnapshot is only supported in torch v2.1.0-rc1 or higher
self._enabled = True
else:
self._enabled = False
log.warning('Memory snapshot is supported after PyTorch 2.1.0. Skipping memory snapshot callback.')
warnings.warn('Memory snapshot is supported after PyTorch 2.1.0. Skipping memory snapshot callback.')

def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
Expand Down
21 changes: 7 additions & 14 deletions tests/callbacks/test_memory_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,13 @@
from composer.callbacks import MemorySnapshot
from composer.loggers import LoggerDestination
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel, device
from tests.common import RandomClassificationDataset, SimpleModel


@device('cpu', 'gpu')
def test_memory_snapshot_warnings_on_cpu_models(device: str):
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
# memory snapshot is supported after PyTorch 2.1.0.
return
# Error if the user sets device=cpu even when cuda is available
del device # unused. always using cpu
with pytest.warns(UserWarning, match='The memory snapshot only works on CUDA devices'):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
reason='OOM Observer requires PyTorch 2.1 or higher')
def test_memory_snapshot_warnings_on_cpu_models():
with pytest.warns(UserWarning):
Trainer(
model=SimpleModel(),
callbacks=MemorySnapshot(),
Expand All @@ -44,16 +40,13 @@ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Pa

@pytest.mark.gpu
@pytest.mark.parametrize('interval', ['1ba'])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
reason='OOM Observer requires PyTorch 2.1 or higher')
def test_memory_snapshot(interval: str):
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
# memory snapshot is supported after PyTorch 2.1.0.
return
# Construct the callbacks
skip_batches = 0
memory_snapshot = MemorySnapshot(skip_batches=skip_batches, interval=interval)

simple_model = SimpleModel()

file_tracker_destination = FileUploaderTracker()

# Construct the trainer and train
Expand Down

0 comments on commit 7b9c42e

Please sign in to comment.