Skip to content

Commit

Permalink
feat(tmp): interleave_dict_batches
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Apr 5, 2024
1 parent 16de972 commit 2a3eeb0
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,43 @@ def interleave_batches(
batches = []
for i in range(num_batches):
yield memory[i]


def interleave_dict_batches(
    iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False
) -> Iterable[torch.Tensor]:
    """
    Interleave groups of ``num_batches`` consecutive dict batches.

    Each element of ``iterable`` is used as a mapping from key to tensor
    (``.items()`` is called on it), even though the annotation mentions plain
    tensors. For every group of ``num_batches`` consecutive batches, each
    tensor is cut into ``num_batches`` equal slices along dimension 0, and the
    ``i``-th output batch is the concatenation of the ``i``-th slice of every
    input batch in the group (in input order).

    Important: Returned batches must be used immediately or copied to avoid
    overwriting. The yielded tensors are views into buffers that are allocated
    once and refilled for every group.

    Trailing batches that do not fill a complete group of ``num_batches`` are
    not yielded.

    Args:
        iterable: Iterable of dict batches. Buffer shapes, dtypes, devices and
            the key set are taken from the first batch.
        num_batches: Number of consecutive batches to interleave. With ``1``
            the input batches are passed through unchanged.
        pin_memory: Forwarded to ``torch.empty`` when allocating the buffers.

    Yields:
        Dicts with the keys of the first batch, mapping to interleaved tensors.

    Raises:
        ValueError: If ``num_batches`` is smaller than 1, or if the leading
            dimension of a tensor in the first batch is not divisible by
            ``num_batches``.
    """
    if num_batches < 1:
        raise ValueError('num_batches must be greater than 0')

    if num_batches == 1:
        yield from iterable
        # Stop here: without this return the loop below would walk the
        # iterable a second time and re-yield every batch for re-iterable
        # inputs such as lists.
        return

    batches = []
    memory = {}
    slice_size = {}
    for batch in iterable:
        if not memory:
            # Lazily allocate one reusable buffer per key, shaped after the first batch.
            for k, tensor in batch.items():
                batch_size = tensor.shape[0]
                if batch_size % num_batches != 0:
                    raise ValueError(f'Batch dimension ({batch_size}) must be divisible by num_batches={num_batches}')
                slice_size[k] = batch_size // num_batches
                memory[k] = torch.empty(
                    (num_batches, *tensor.shape), dtype=tensor.dtype, device=tensor.device, pin_memory=pin_memory
                )

        batches.append(batch)

        if len(batches) == num_batches:
            for k, buffer in memory.items():
                size = slice_size[k]
                for i in range(num_batches):
                    for j in range(num_batches):
                        # Slice i of input batch j becomes slot j of output batch i.
                        source = batches[j][k][i * size : (i + 1) * size]
                        buffer[i, j * size : (j + 1) * size] = source
            batches = []
            for i in range(num_batches):
                yield {k: buffer[i] for k, buffer in memory.items()}

0 comments on commit 2a3eeb0

Please sign in to comment.