def interleave_batches(
    iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False
) -> Iterable[torch.Tensor]:
    """
    Interleaves batches from an iterable of batches.

    The input is consumed in groups of ``num_batches`` consecutive batches. Every batch of a group is cut
    into ``num_batches`` equally sized slices along dimension 0, and the i-th output batch of the group is
    assembled from the i-th slice of each input batch (in input order). Output batches have the same shape
    as the input batches.

    Important: Returned batches must be used immediately or copied to avoid overwriting. They are views
    into one buffer that is allocated once and reused for every group.

    Trailing batches that do not fill a complete group of ``num_batches`` are not yielded.

    Since this is a generator, all checks run lazily while iterating, not when the function is called.

    Args:
        iterable: Batches to interleave. All batches must share the shape of the first batch.
        num_batches: Number of consecutive batches to mix. Must be >= 1 and divide the batch dimension.
        pin_memory: Forwarded to ``torch.empty`` when the reusable buffer is allocated.

    Yields:
        The interleaved batches, ``num_batches`` per complete group.

    Raises:
        ValueError: If ``num_batches`` is smaller than 1, if the batch dimension is not divisible by
            ``num_batches``, or if a batch does not have the same shape as the first batch.
    """
    if num_batches < 1:
        raise ValueError(f'num_batches must be a positive integer, got {num_batches}')

    batches = []
    memory = None
    slice_size = None
    for batch in iterable:
        if memory is None:
            # The first batch defines the layout of the reusable output buffer.
            batch_size = batch.shape[0]
            if batch_size % num_batches != 0:
                raise ValueError(f'Batch dimension ({batch_size}) must be divisible by num_batches={num_batches}')
            slice_size = batch_size // num_batches
            memory = torch.empty(
                (num_batches, *batch.shape), dtype=batch.dtype, device=batch.device, pin_memory=pin_memory
            )
        elif batch.shape != memory.shape[1:]:
            # Without this check a shorter batch could be silently broadcast into the buffer
            # (a length-1 slice is expanded), duplicating samples instead of failing.
            raise ValueError(
                f'All batches must have the same shape: expected {tuple(memory.shape[1:])}, '
                f'got {tuple(batch.shape)}'
            )

        batches.append(batch)

        if len(batches) == num_batches:
            for i in range(num_batches):
                # Slice i of every collected batch ends up in output batch i.
                source_rows = slice(i * slice_size, (i + 1) * slice_size)
                for j, source in enumerate(batches):
                    memory[i, j * slice_size : (j + 1) * slice_size] = source[source_rows]
            batches = []
            for i in range(num_batches):
                yield memory[i]
class TestInterleaveBatches:
    """Tests for ``interleave_batches``."""

    def test_basic(self):
        """Interleave four batches of eight consecutive integers in groups of two."""
        inputs = [torch.arange(start, start + 8) for start in range(0, 32, 8)]

        # Clone each result: the yielded tensors are only valid until the next group is produced.
        outputs = [result.clone() for result in interleave_batches(inputs, num_batches=2)]

        # Only membership is compared, not the order of elements within a batch.
        expected = [
            {0, 1, 2, 3, 8, 9, 10, 11},
            {4, 5, 6, 7, 12, 13, 14, 15},
            {16, 17, 18, 19, 24, 25, 26, 27},
            {20, 21, 22, 23, 28, 29, 30, 31},
        ]
        assert len(outputs) == 4
        for output, wanted in zip(outputs, expected):
            assert set(output.tolist()) == wanted


if __name__ == '__main__':
    sys.exit(pytest.main([__file__]))