From 8784045af71d4e6c8f2907faf0c4601b4d492aa9 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Wed, 20 Mar 2024 15:36:06 +0100 Subject: [PATCH] fix: data tests on macos --- test/test_data.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_data.py b/test/test_data.py index 227bf22..a5c7853 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -9,6 +9,16 @@ from torch.utils.data import DataLoader, IterableDataset +class _Unzip(IterableDataset): + def __init__(self, ds): + self.ds = ds + + def __iter__(self): + for chunk in self.ds: + arr = chunk.to_array().values[0] + yield from arr + + class TestSharding: def test_types(self): indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) @@ -94,15 +104,6 @@ def test_shuffled(self): assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1)) def test_XrShardedDataset_multiprocessing(self): - class _Unzip(IterableDataset): - def __init__(self, ds): - self.ds = ds - - def __iter__(self): - for chunk in self.ds: - arr = chunk.to_array().values[0] - yield from arr - xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() # Simple case: 2 workers, world_size=1