update to torch datapipes
dfulu committed Nov 22, 2023
1 parent a26cbe9 commit 800667c
Showing 3 changed files with 37 additions and 23 deletions.
54 changes: 34 additions & 20 deletions pvnet_summation/data/datamodule.py
@@ -5,8 +5,10 @@
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.utils.consts import BatchKey
- from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
- from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper
+ from torch.utils.data import DataLoader
+ from torch.utils.data.datapipes._decorator import functional_datapipe
+ from torch.utils.data.datapipes.datapipe import IterDataPipe
+ from torch.utils.data.datapipes.iter import FileLister, Zipper

# https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy("file_system")
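
The torchdata imports are swapped for torch's built-in datapipe machinery, which exposes the same FileLister/Zipper pipes and the same registration pattern for custom pipes. As a rough sketch of the pattern these imports support (the pipe below is purely illustrative, not one of the repository's pipes):

# Illustrative only: a toy pipe registered via torch's functional_datapipe decorator.
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import IterableWrapper


@functional_datapipe("scale_by")
class ScalerIterDataPipe(IterDataPipe):
    """Multiplies every item yielded by the source pipe by a constant factor."""

    def __init__(self, source_datapipe, factor):
        self.source_datapipe = source_datapipe
        self.factor = factor

    def __iter__(self):
        for item in self.source_datapipe:
            yield item * self.factor


# After registration the pipe is available as a chained method on any IterDataPipe:
pipe = IterableWrapper([1, 2, 3]).scale_by(10)
assert list(pipe) == [10, 20, 30]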
@@ -127,7 +129,7 @@ def __init__(
gsp_zarr_path: str,
batch_size=16,
num_workers=0,
- prefetch_factor=2,
+ prefetch_factor=None,
):
"""Datamodule for training pvnet_summation.
@@ -143,10 +145,19 @@
self.batch_size = batch_size
self.batch_dir = batch_dir

- self.readingservice_config = dict(
+ self._common_dataloader_kwargs = dict(
+ shuffle=False, # shuffled in datapipe step
+ batch_size=None, # batched in datapipe step
+ sampler=None,
+ batch_sampler=None,
num_workers=num_workers,
- multiprocessing_context="spawn",
- worker_prefetch_cnt=prefetch_factor,
+ collate_fn=None,
+ pin_memory=False,
+ drop_last=False,
+ timeout=0,
+ worker_init_fn=None,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=False,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False):
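
These keyword arguments are passed straight through to torch.utils.data.DataLoader in the *_dataloader methods below. Because every item the datapipe yields is already a complete batch, batch_size=None and collate_fn=None disable the loader's automatic batching, while num_workers and prefetch_factor take over the role of the old MultiProcessingReadingService settings. A minimal sketch of the same wiring outside the datamodule, assuming a hypothetical directory of pre-saved batch files:

# Sketch only; the path, file mask, and loading step are placeholders, not repository values.
import torch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import FileLister

datapipe = FileLister("batches/train", masks="*.pt").map(torch.load)
dataloader = DataLoader(
    datapipe,
    batch_size=None,    # each element of the pipe is already a full batch
    collate_fn=None,    # so no re-batching or collation is applied
    num_workers=2,      # replaces MultiProcessingReadingService(num_workers=...)
    prefetch_factor=2,  # replaces worker_prefetch_cnt; only valid when num_workers > 0
)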
@@ -218,17 +229,14 @@ def train_dataloader(self, shuffle=True, add_filename=False):
datapipe = self._get_premade_batches_datapipe(
"train", shuffle=shuffle, add_filename=add_filename
)

- rs = MultiProcessingReadingService(**self.readingservice_config)
- return DataLoader2(datapipe, reading_service=rs)
+ return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False, add_filename=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val", shuffle=shuffle, add_filename=add_filename
)
- rs = MultiProcessingReadingService(**self.readingservice_config)
- return DataLoader2(datapipe, reading_service=rs)
+ return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
@@ -243,7 +251,7 @@ def __init__(
batch_dir: str,
batch_size=16,
num_workers=0,
- prefetch_factor=2,
+ prefetch_factor=None,
):
"""Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.
@@ -257,10 +265,19 @@ def __init__(
self.batch_size = batch_size
self.batch_dir = batch_dir

- self.readingservice_config = dict(
+ self._common_dataloader_kwargs = dict(
+ shuffle=False, # shuffled in datapipe step
+ batch_size=None, # batched in datapipe step
+ sampler=None,
+ batch_sampler=None,
num_workers=num_workers,
- multiprocessing_context="spawn",
- worker_prefetch_cnt=prefetch_factor,
+ collate_fn=None,
+ pin_memory=False,
+ drop_last=False,
+ timeout=0,
+ worker_init_fn=None,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=False,
)

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
@@ -290,18 +307,15 @@ def train_dataloader(self, shuffle=True):
"train",
shuffle=shuffle,
)

- rs = MultiProcessingReadingService(**self.readingservice_config)
- return DataLoader2(datapipe, reading_service=rs)
+ return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val",
shuffle=shuffle,
)
- rs = MultiProcessingReadingService(**self.readingservice_config)
- return DataLoader2(datapipe, reading_service=rs)
+ return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -92,7 +92,7 @@ def sample_datamodule(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
- prefetch_factor=2,
+ prefetch_factor=None,
)

return dm
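
The fixture now passes prefetch_factor=None because torch's DataLoader only accepts a prefetch_factor when num_workers > 0; with num_workers=0, as used throughout these tests, any other value raises a ValueError in current torch releases. A small hedged illustration of that constraint:

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

datapipe = IterableWrapper(range(4))

# Fine: single-process loading leaves prefetch_factor at its default (None).
DataLoader(datapipe, batch_size=None, num_workers=0, prefetch_factor=None)

# Raises ValueError: prefetch_factor can only be set when num_workers > 0.
# DataLoader(datapipe, batch_size=None, num_workers=0, prefetch_factor=2)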
4 changes: 2 additions & 2 deletions tests/data/test_datamodule.py
@@ -10,7 +10,7 @@ def test_init(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
- prefetch_factor=2,
+ prefetch_factor=None,
)


@@ -22,7 +22,7 @@ def test_iter(sample_data):
gsp_zarr_path=gsp_zarr_dir,
batch_size=2,
num_workers=0,
- prefetch_factor=2,
+ prefetch_factor=None,
)

batch = next(iter(dm.train_dataloader()))
