From 23e6a2ee46ff6e4e863f064b6d265b01807d666f Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 28 Nov 2024 03:49:50 +0000 Subject: [PATCH] Expose `DistributedSampler` RNG seed argument (#3724) --- composer/utils/dist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/composer/utils/dist.py b/composer/utils/dist.py index 26b135217f..b5832740a6 100644 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -592,6 +592,7 @@ def get_sampler( shuffle: bool = False, num_replicas: Optional[int] = None, rank: Optional[int] = None, + seed: int = 0, ): """Constructs a :class:`~torch.utils.data.distributed.DistributedSampler` for a dataset. @@ -620,6 +621,7 @@ def get_sampler( shuffle=shuffle, num_replicas=get_world_size() if num_replicas is None else num_replicas, rank=get_global_rank() if rank is None else rank, + seed=seed, )