Skip to content

Commit

Permalink
fix: data tests on macos
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 20, 2024
1 parent e124ee8 commit 8784045
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
from torch.utils.data import DataLoader, IterableDataset


class _Unzip(IterableDataset):
def __init__(self, ds):
self.ds = ds

def __iter__(self):
for chunk in self.ds:
arr = chunk.to_array().values[0]
yield from arr


class TestSharding:
def test_types(self):
indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
Expand Down Expand Up @@ -94,15 +104,6 @@ def test_shuffled(self):
assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1))

def test_XrShardedDataset_multiprocessing(self):
class _Unzip(IterableDataset):
def __init__(self, ds):
self.ds = ds

def __iter__(self):
for chunk in self.ds:
arr = chunk.to_array().values[0]
yield from arr

xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()

# Simple case: 2 workers, world_size=1
Expand Down

0 comments on commit 8784045

Please sign in to comment.