diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index c8a92a3..799a3e2 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -5,8 +5,9 @@ from ocf_datapipes.load import OpenGSP from ocf_datapipes.training.pvnet import normalize_gsp from ocf_datapipes.utils.consts import BatchKey -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService -from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper +from torch.utils.data import DataLoader +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.iter import FileLister, Zipper # https://github.com/pytorch/pytorch/issues/973 torch.multiprocessing.set_sharing_strategy("file_system") @@ -127,7 +128,7 @@ def __init__( gsp_zarr_path: str, batch_size=16, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, ): """Datamodule for training pvnet_summation. @@ -143,10 +144,19 @@ def __init__( self.batch_size = batch_size self.batch_dir = batch_dir - self.readingservice_config = dict( + self._common_dataloader_kwargs = dict( + shuffle=False, # shuffled in datapipe step + batch_size=None, # batched in datapipe step + sampler=None, + batch_sampler=None, num_workers=num_workers, - multiprocessing_context="spawn", - worker_prefetch_cnt=prefetch_factor, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=prefetch_factor, + persistent_workers=False, ) def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False): @@ -218,17 +228,14 @@ def train_dataloader(self, shuffle=True, add_filename=False): datapipe = self._get_premade_batches_datapipe( "train", shuffle=shuffle, add_filename=add_filename ) - - rs = MultiProcessingReadingService(**self.readingservice_config) - return DataLoader2(datapipe, reading_service=rs) + return DataLoader(datapipe, **self._common_dataloader_kwargs) def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( "val", shuffle=shuffle, add_filename=add_filename ) - rs = MultiProcessingReadingService(**self.readingservice_config) - return DataLoader2(datapipe, reading_service=rs) + return DataLoader(datapipe, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader""" @@ -243,7 +250,7 @@ def __init__( batch_dir: str, batch_size=16, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, ): """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation. @@ -257,10 +264,19 @@ def __init__( self.batch_size = batch_size self.batch_dir = batch_dir - self.readingservice_config = dict( + self._common_dataloader_kwargs = dict( + shuffle=False, # shuffled in datapipe step + batch_size=None, # batched in datapipe step + sampler=None, + batch_sampler=None, num_workers=num_workers, - multiprocessing_context="spawn", - worker_prefetch_cnt=prefetch_factor, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=prefetch_factor, + persistent_workers=False, ) def _get_premade_batches_datapipe(self, subdir, shuffle=False): @@ -290,9 +306,7 @@ def train_dataloader(self, shuffle=True): "train", shuffle=shuffle, ) - - rs = MultiProcessingReadingService(**self.readingservice_config) - return DataLoader2(datapipe, reading_service=rs) + return DataLoader(datapipe, **self._common_dataloader_kwargs) def val_dataloader(self, shuffle=False): """Construct val dataloader""" @@ -300,8 +314,7 @@ def val_dataloader(self, shuffle=False): "val", shuffle=shuffle, ) - rs = MultiProcessingReadingService(**self.readingservice_config) - return DataLoader2(datapipe, reading_service=rs) + return DataLoader(datapipe, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader""" diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index c467fcd..dc09875 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -267,4 +267,4 @@ def configure_optimizers(self): if self.lr is not None: # Use learning rate found by learning rate finder callback self._optimizer.lr = self.lr - return self._optimizer(self.parameters()) + return self._optimizer(self) diff --git a/requirements.txt b/requirements.txt index 65dc78b..3b2fad3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ nowcasting_utils ocf_datapipes>=1.2.44 -pvnet>=2.1.16 +pvnet>=2.4.0 ocf_ml_metrics numpy pandas @@ -8,7 +8,7 @@ matplotlib xarray ipykernel h5netcdf -torch>=2.0 +torch>=2.1.1 lightning>=2.0.1 torchdata pytest diff --git a/tests/conftest.py b/tests/conftest.py index a3343e7..f22a0f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,7 +92,7 @@ def sample_datamodule(sample_data): gsp_zarr_path=gsp_zarr_dir, batch_size=2, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, ) return dm @@ -106,9 +106,10 @@ def sample_batch(sample_datamodule): @pytest.fixture() def model_kwargs(): + # These kwargs define the pvnet model which the summation model uses kwargs = dict( model_name="openclimatefix/pvnet_v2", - model_version="898630f3f8cd4e8506525d813dd61c6d8de86144", + model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec", ) return kwargs diff --git a/tests/data/sample_batches/train/000000.pt b/tests/data/sample_batches/train/000000.pt index e904fc8..6e99981 100644 Binary files a/tests/data/sample_batches/train/000000.pt and b/tests/data/sample_batches/train/000000.pt differ diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 9aa2c27..540133d 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -10,7 +10,7 @@ def test_init(sample_data): gsp_zarr_path=gsp_zarr_dir, batch_size=2, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, ) @@ -22,7 +22,7 @@ def test_iter(sample_data): gsp_zarr_path=gsp_zarr_dir, batch_size=2, num_workers=0, - prefetch_factor=2, + prefetch_factor=None, ) batch = next(iter(dm.train_dataloader()))