Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
vchiley committed Feb 22, 2024
1 parent 3a263e9 commit 1192ec5
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/callbacks/test_memory_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from composer.callbacks import MemoryMonitor
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer
from composer.utils import dist, get_device
from tests.common import RandomClassificationDataset, SimpleModel


Expand Down Expand Up @@ -38,3 +41,37 @@ def test_memory_monitor_gpu():
num_memory_monitor_calls = len(in_memory_logger.data['memory/peak_allocated_mem'])

assert num_memory_monitor_calls == int(trainer.state.timestamp.batch)


@pytest.mark.gpu
@pytest.mark.world_size(2)
def test_dist_memory_monitor_gpu():
    """Check the ``_max`` peak-memory metric when one rank uses extra memory.

    A large tensor is allocated on local rank 1 only, then a two-batch
    training run is executed with a ``MemoryMonitor`` configured with
    ``dist_aggregate_batch_interval=1``. On local rank 0, the last logged
    ``memory/peak_allocated_mem_max`` value must exceed the last logged
    ``memory/peak_allocated_mem`` value by at least the size of the extra
    allocation, minus a 0.5 tolerance.

    NOTE(review): this assumes the ``_max`` metric is the maximum across ranks
    and that both metrics are logged in gigabytes -- confirm against
    ``MemoryMonitor``.
    """
    dist.initialize_dist(get_device(None))

    # Callback under test and the logger used to read back what it reports.
    monitor_callback = MemoryMonitor(dist_aggregate_batch_interval=1)
    metrics_logger = InMemoryLogger()

    # Add extra memory usage to rank 1.
    # About 1B elements at 32 bits (4 bytes) each is about 4GB.
    bytes_per_element = 4
    num_elements = 2**30
    expected_extra_mem_usage_gb = bytes_per_element * num_elements / 1e9
    if dist.get_local_rank() == 1:
        # Bound to a name so the allocation stays referenced during training.
        ballast = torch.randn(num_elements, device='cuda')

    train_dataset = RandomClassificationDataset()
    train_sampler = DistributedSampler(dataset=train_dataset)
    train_loader = DataLoader(dataset=train_dataset, sampler=train_sampler)

    trainer = Trainer(
        model=SimpleModel(),
        callbacks=monitor_callback,
        loggers=metrics_logger,
        train_dataloader=train_loader,
        max_duration='2ba',
    )
    trainer.fit()

    # Take the most recent entry for each metric; its last field is the value.
    logged = metrics_logger.data
    latest_local_entry = logged['memory/peak_allocated_mem'][-1]
    local_peak = latest_local_entry[-1]
    latest_max_entry = logged['memory/peak_allocated_mem_max'][-1]
    aggregated_peak = latest_max_entry[-1]

    # Allow some slack between the nominal tensor size and the measured gap.
    gb_buffer = 0.5
    min_expected_gap_gb = expected_extra_mem_usage_gb - gb_buffer

    # Only rank 0 is checked: it did not allocate the extra tensor, so its own
    # peak should sit well below the reported maximum.
    if dist.get_local_rank() == 0:
        assert local_peak <= aggregated_peak - min_expected_gap_gb

0 comments on commit 1192ec5

Please sign in to comment.