diff --git a/test/test_data.py b/test/test_data.py index 227bf22..a5c7853 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -9,6 +9,16 @@ from torch.utils.data import DataLoader, IterableDataset +class _Unzip(IterableDataset): + def __init__(self, ds): + self.ds = ds + + def __iter__(self): + for chunk in self.ds: + arr = chunk.to_array().values[0] + yield from arr + + class TestSharding: def test_types(self): indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) @@ -94,15 +104,6 @@ def test_shuffled(self): assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1)) def test_XrShardedDataset_multiprocessing(self): - class _Unzip(IterableDataset): - def __init__(self, ds): - self.ds = ds - - def __iter__(self): - for chunk in self.ds: - arr = chunk.to_array().values[0] - yield from arr - xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() # Simple case: 2 workers, world_size=1