Skip to content

Commit

Permalink
feat: shard_sequence() and ShardedSequenceDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed May 10, 2024
1 parent 85e8443 commit 3203f25
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion dmlcloud/util/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -54,6 +54,18 @@ def chunk_and_shard_indices(
return chunks


def shard_sequence(
    sequence: Sequence,
    rank: int,
    world_size: int,
    shuffle: bool = False,
    even_shards: bool = True,
    seed: int = 0,
):
    """Return the elements of ``sequence`` selected for the shard of ``rank``.

    The positions to pick are computed by ``shard_indices`` from
    ``len(sequence)``, ``rank`` and ``world_size``; ``shuffle``,
    ``even_shards`` and ``seed`` are forwarded to it unchanged.

    Returns:
        A new list holding ``sequence[i]`` for every index ``i`` produced by
        ``shard_indices``, in the order the indices were produced.
    """
    positions = shard_indices(
        len(sequence),
        rank,
        world_size,
        shuffle=shuffle,
        even_shards=even_shards,
        seed=seed,
    )
    selected = []
    for position in positions:
        selected.append(sequence[position])
    return selected


def sharded_xr_dataset(
ds: xr.Dataset | xr.DataArray,
dim: str,
Expand Down Expand Up @@ -94,6 +106,40 @@ def sharded_xr_dataset(
yield chunk


class ShardedSequenceDataset(IterableDataset):

def __init__(
self,
sequence: Sequence,
shuffle: bool = False,
even_shards: bool = True,
seed: int = 0,
rank: int | None = None,
world_size: int | None = None,
):
self.sequence = sequence
self.shuffle = shuffle
self.even_shards = even_shards
self.seed = seed
self.rank = rank if rank is not None else dist.get_rank()
self.world_size = world_size if world_size is not None else dist.get_world_size()
self.epoch = 0

def set_epoch(self, epoch: int):
self.epoch = epoch

def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
rank = self.rank
world_size = self.world_size
else:
rank = self.rank * worker_info.num_workers + worker_info.id
world_size = self.world_size * worker_info.num_workers
shards = shard_sequence(self.sequence, rank, world_size, shuffle=self.shuffle, even_shards=self.even_shards, seed=self.seed + self.epoch)
return iter(shards)


class ShardedXrDataset(IterableDataset):
def __init__(
self,
Expand Down

0 comments on commit 3203f25

Please sign in to comment.