Skip to content

Commit

Permalink
Remove old datapipe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Dec 19, 2024
1 parent 8e7e0d0 commit ff5bbf0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
31 changes: 15 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import pvnet
from pvnet.data.datamodule import DataModule
from pvnet.data.wind_datamodule import WindDataModule

import pvnet.models.multimodal.encoders.encoders3d
import pvnet.models.multimodal.linear_networks.networks
Expand Down Expand Up @@ -158,21 +157,21 @@ def sample_pv_batch():
# old batches. For now we use the old batches to test the site encoder models
return torch.load("tests/test_data/presaved_batches/train/000000.pt")


@pytest.fixture()
def sample_wind_batch():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/test_data/sample_wind_batches",
)
batch = next(iter(dm.train_dataloader()))
return batch
# TODO update this test once we add the loading logic for the Site dataset
# @pytest.fixture()
# def sample_wind_batch():
# dm = WindDataModule(
# configuration=None,
# batch_size=2,
# num_workers=0,
# prefetch_factor=None,
# train_period=[None, None],
# val_period=[None, None],
# test_period=[None, None],
# batch_dir="tests/test_data/sample_wind_batches",
# )
# batch = next(iter(dm.train_dataloader()))
# return batch


@pytest.fixture()
Expand Down
16 changes: 8 additions & 8 deletions tests/models/multimodal/site_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar
batch_size=8,
)


def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
_test_model_forward(
sample_wind_batch,
SingleAttentionNetwork,
site_encoder_sensor_model_kwargs,
batch_size=2,
)
# TODO once we have updated the sample batches for sites include this test
# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
# _test_model_forward(
# sample_wind_batch,
# SingleAttentionNetwork,
# site_encoder_sensor_model_kwargs,
# batch_size=2,
# )


# Test model backward on all models
Expand Down

0 comments on commit ff5bbf0

Please sign in to comment.