Merge pull request #11 from openclimatefix/pytorch_datapipes
Remove torchdata
dfulu authored Nov 22, 2023
2 parents 9158f34 + a6f2048 commit 2db8f05
Showing 6 changed files with 41 additions and 27 deletions.
53 changes: 33 additions & 20 deletions pvnet_summation/data/datamodule.py
@@ -5,8 +5,9 @@
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.utils.consts import BatchKey
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, Zipper

# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy("file_system")
@@ -127,7 +128,7 @@ def __init__(
gsp_zarr_path: str,
batch_size=16,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
):
"""Datamodule for training pvnet_summation.
@@ -143,10 +144,19 @@ def __init__(
self.batch_size = batch_size
self.batch_dir = batch_dir

self.readingservice_config = dict(
self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False):
@@ -218,17 +228,14 @@ def train_dataloader(self, shuffle=True, add_filename=False):
datapipe = self._get_premade_batches_datapipe(
"train", shuffle=shuffle, add_filename=add_filename
)

rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False, add_filename=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val", shuffle=shuffle, add_filename=add_filename
)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
@@ -243,7 +250,7 @@ def __init__(
batch_dir: str,
batch_size=16,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
):
"""Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.
@@ -257,10 +264,19 @@ def __init__(
self.batch_size = batch_size
self.batch_dir = batch_dir

self.readingservice_config = dict(
self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
multiprocessing_context="spawn",
worker_prefetch_cnt=prefetch_factor,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
@@ -290,18 +306,15 @@ def train_dataloader(self, shuffle=True):
"train",
shuffle=shuffle,
)

rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val",
shuffle=shuffle,
)
rs = MultiProcessingReadingService(**self.readingservice_config)
return DataLoader2(datapipe, reading_service=rs)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
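For orientation (this sketch is not part of the commit), the pattern the datamodule moves to is a plain torch.utils.data.DataLoader wrapped around a pre-built IterDataPipe, with shuffling and batching handled inside the datapipe rather than by the loader. The IterableWrapper below is a stand-in for whatever _get_premade_batches_datapipe actually builds, and only a subset of the keyword arguments from the diff is shown.

```python
# Minimal, illustrative sketch of the DataLoader-over-IterDataPipe pattern.
# IterableWrapper is a stand-in for the datapipe built by
# _get_premade_batches_datapipe in the real datamodule.
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

datapipe = IterableWrapper([{"sample": i} for i in range(8)])

dataloader = DataLoader(
    datapipe,
    shuffle=False,         # shuffling happens inside the datapipe
    batch_size=None,       # batching happens inside the datapipe
    num_workers=0,
    prefetch_factor=None,  # must stay None while num_workers == 0
    pin_memory=False,
    persistent_workers=False,
)

for batch in dataloader:
    print(batch)
```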
2 changes: 1 addition & 1 deletion pvnet_summation/models/base_model.py
@@ -267,4 +267,4 @@ def configure_optimizers(self):
if self.lr is not None:
# Use learning rate found by learning rate finder callback
self._optimizer.lr = self.lr
return self._optimizer(self.parameters())
return self._optimizer(self)
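The one-line change above hands the optimizer factory the whole LightningModule instead of only its parameters, in step with the pvnet>=2.4.0 bump below. As a hedged illustration only (the class and argument names here are invented, not PVNet's actual API), a factory with that calling convention might look like:

```python
import torch


class AdamWFactory:
    """Hypothetical optimizer factory illustrating the `factory(model)` call style."""

    def __init__(self, lr=1e-4, weight_decay=0.01):
        self.lr = lr
        self.weight_decay = weight_decay

    def __call__(self, model):
        # Receiving the full model (not just model.parameters()) lets the
        # factory build parameter groups, e.g. exempt biases from weight decay.
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            (no_decay if name.endswith("bias") else decay).append(param)
        return torch.optim.AdamW(
            [
                {"params": decay, "weight_decay": self.weight_decay},
                {"params": no_decay, "weight_decay": 0.0},
            ],
            lr=self.lr,
        )
```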
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,14 +1,14 @@
nowcasting_utils
ocf_datapipes>=1.2.44
pvnet>=2.1.16
pvnet>=2.4.0
ocf_ml_metrics
numpy
pandas
matplotlib
xarray
ipykernel
h5netcdf
torch>=2.0
torch>=2.1.1
lightning>=2.0.1
torchdata
pytest
5 changes: 3 additions & 2 deletions tests/conftest.py
@@ -92,7 +92,7 @@ def sample_datamodule(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
)

return dm
@@ -106,9 +106,10 @@ def sample_batch(sample_datamodule):

@pytest.fixture()
def model_kwargs():
# These kwargs define the pvnet model which the summation model uses
kwargs = dict(
model_name="openclimatefix/pvnet_v2",
model_version="898630f3f8cd4e8506525d813dd61c6d8de86144",
model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec",
)
return kwargs

Binary file modified tests/data/sample_batches/train/000000.pt
4 changes: 2 additions & 2 deletions tests/data/test_datamodule.py
@@ -10,7 +10,7 @@ def test_init(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
)


@@ -22,7 +22,7 @@ def test_iter(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
prefetch_factor=2,
prefetch_factor=None,
)

batch = next(iter(dm.train_dataloader()))
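These tests (and conftest.py above) switch from prefetch_factor=2 to prefetch_factor=None because, on the torch versions pinned in requirements.txt, DataLoader only accepts a prefetch_factor when num_workers > 0. A quick, illustrative check of that constraint:

```python
# Illustrative only: shows why prefetch_factor must be None with num_workers=0
# on recent torch (the >=2.1.1 pinned in requirements.txt).
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

datapipe = IterableWrapper(range(4))

# Single-process loading: leave prefetch_factor as None.
DataLoader(datapipe, batch_size=None, num_workers=0, prefetch_factor=None)

# Specifying a prefetch_factor without worker processes raises ValueError.
try:
    DataLoader(datapipe, batch_size=None, num_workers=0, prefetch_factor=2)
except ValueError as err:
    print(err)
```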
