From 574e9043ff517933d46aa866ffc227ba38c08fdc Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 2 Apr 2024 15:00:34 +0200 Subject: [PATCH] fix: test regression --- test/test_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_data.py b/test/test_data.py index 8a5d778..7ac9aad 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -152,9 +152,10 @@ def test_shuffled(self): world_size = 3 chunk_size = 15 - chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0)) - chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0)) - chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0)) + shard = partial(sharded_xr_dataset, ds, 'x', chunk_size, world_size=world_size, shuffle=True, seed=0) + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) assert len(chunks_1) == 2 assert len(chunks_2) == 2