Skip to content

Commit

Permalink
Remove datapipe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Dec 19, 2024
1 parent 0c5a97b commit a75ee76
Showing 1 changed file with 2 additions and 57 deletions.
59 changes: 2 additions & 57 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import pytest
from pvnet.data.datamodule import DataModule
from pvnet.data.wind_datamodule import WindDataModule
from pvnet.data.pv_site_datamodule import PVSiteDataModule
import os
from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey


def test_init():
Expand All @@ -17,58 +12,6 @@ def test_init():
val_period=[None, None],
)


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_wind_init():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/data/sample_batches",
)


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_wind_init_with_nwp_filter():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/test_data/sample_wind_batches",
nwp_channels={"ecmwf": ["t2m", "v200"]},
)
dataloader = iter(dm.train_dataloader())

batch = next(dataloader)
batch_channels = batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names]
print(batch_channels)
for v in ["t2m", "v200"]:
assert v in batch_channels
assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_pv_site_init():
dm = PVSiteDataModule(
configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir=None,
)


def test_iter():
dm = DataModule(
configuration=None,
Expand Down Expand Up @@ -104,3 +47,5 @@ def test_iter_multiprocessing():

# Make sure we've served 2 batches
assert served_batches == 2

# TODO add test cases with some netcdfs premade samples

0 comments on commit a75ee76

Please sign in to comment.