diff --git a/composer/utils/dist.py b/composer/utils/dist.py index 26b135217f..b5832740a6 100644 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -592,6 +592,7 @@ def get_sampler( shuffle: bool = False, num_replicas: Optional[int] = None, rank: Optional[int] = None, + seed: int = 0, ): """Constructs a :class:`~torch.utils.data.distributed.DistributedSampler` for a dataset. @@ -620,6 +621,7 @@ def get_sampler( shuffle=shuffle, num_replicas=get_world_size() if num_replicas is None else num_replicas, rank=get_global_rank() if rank is None else rank, + seed=seed, )