Commit

feat: interleave_batches
sehoffmann committed Mar 20, 2024
1 parent ffc1a5c commit e124ee8
Showing 2 changed files with 52 additions and 1 deletion.
34 changes: 34 additions & 0 deletions dmlcloud/util/data.py
@@ -1,6 +1,7 @@
from typing import Iterable

import numpy as np
import torch
import torch.distributed as dist
import xarray as xr
from torch.utils.data import get_worker_info, IterableDataset
@@ -104,3 +105,36 @@ def __iter__(self):
world_size,
self.load,
)


def interleave_batches(
iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False
) -> Iterable[torch.Tensor]:
"""
Interleaves batches from an iterable of batches.
Important: Returned batches must be used immediately or copied to avoid overwriting.
"""

batches = []
memory = None
batch_size = None
slice_size = None
for batch in iterable:
if memory is None:
batch_size = batch.shape[0]
slice_size = batch_size // num_batches
if batch_size % num_batches != 0:
raise ValueError(f'Batch dimension ({batch_size}) must be divisible by num_batches={num_batches}')
memory = torch.empty(
(num_batches, *batch.shape), dtype=batch.dtype, device=batch.device, pin_memory=pin_memory
)

batches.append(batch)

if len(batches) == num_batches:
for i in range(num_batches):
for j in range(num_batches):
memory[i, j * slice_size : (j + 1) * slice_size] = batches[j][i * slice_size : (i + 1) * slice_size]
batches = []
for i in range(num_batches):
yield memory[i]
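
A minimal sketch of what the helper yields (the concrete tensors and numbers below are illustrative only, not part of the commit): with num_batches=2, every pair of incoming batches is split into two slices and recombined, so each yielded batch contains one slice from each input. Because all yields are views into the same preallocated buffer, they must be consumed or cloned before the next pair of batches is processed.

import torch
from dmlcloud.util.data import interleave_batches

# Two illustrative batches of 4 elements each; num_batches=2 gives slice_size=2.
pairs = [torch.tensor([0, 1, 2, 3]), torch.tensor([10, 11, 12, 13])]

out = [b.clone() for b in interleave_batches(pairs, num_batches=2)]
# out[0] -> tensor([ 0,  1, 10, 11])  (first slice of each input batch)
# out[1] -> tensor([ 2,  3, 12, 13])  (second slice of each input batch)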
19 changes: 18 additions & 1 deletion test/test_data.py
@@ -2,8 +2,9 @@

import numpy as np
import pytest
import torch
import xarray as xr
from dmlcloud.util.data import shard_indices, sharded_xr_dataset, ShardedXrDataset
from dmlcloud.util.data import interleave_batches, shard_indices, sharded_xr_dataset, ShardedXrDataset
from numpy.testing import assert_array_equal
from torch.utils.data import DataLoader, IterableDataset

@@ -324,5 +325,21 @@ def test_XrShardedDataset_length(self):
assert len(torch_ds4) == 1


class TestInterleaveBatches:
def test_basic(self):
batches = [
torch.arange(0, 8),
torch.arange(8, 16),
torch.arange(16, 24),
torch.arange(24, 32),
]
interleaved_batches = list(t.clone() for t in interleave_batches(batches, num_batches=2))
assert len(interleaved_batches) == 4
assert {t.item() for t in interleaved_batches[0]} == {0, 1, 2, 3, 8, 9, 10, 11}
assert {t.item() for t in interleaved_batches[1]} == {4, 5, 6, 7, 12, 13, 14, 15}
assert {t.item() for t in interleaved_batches[2]} == {16, 17, 18, 19, 24, 25, 26, 27}
assert {t.item() for t in interleaved_batches[3]} == {20, 21, 22, 23, 28, 29, 30, 31}


if __name__ == '__main__':
sys.exit(pytest.main([__file__]))
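
Since interleave_batches accepts any iterable of tensors, one plausible use is to wrap the output of a DataLoader; this is an assumption for illustration, not something shown in this commit, and the dataset and parameters below are hypothetical.

import torch
from torch.utils.data import DataLoader, TensorDataset

from dmlcloud.util.data import interleave_batches

# Hypothetical stand-in for whatever the surrounding pipeline actually loads.
data = torch.arange(64, dtype=torch.float32).reshape(32, 2)
loader = DataLoader(TensorDataset(data), batch_size=8)

# DataLoader yields tuples, so unwrap the single tensor per batch first.
tensors = (batch[0] for batch in loader)

for batch in interleave_batches(tensors, num_batches=2):
    # Each yielded batch is a view into a reused buffer; copy it
    # (e.g. via .clone() or a device transfer) before the next
    # group of batches is assembled.
    batch = batch.clone()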
