Skip to content

Commit

Permalink
Merge branch 'pytorch_datapipes' of https://github.com/openclimatefix/PVNet into pytorch_datapipes
Browse files (browse the repository at this point in the history)
  • Loading branch information
dfulu committed Nov 21, 2023
2 parents c7585c7 + f8ba1f7 commit 3839fe8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 30 deletions.
19 changes: 8 additions & 11 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch.utils.data.datapipes._decorator import functional_datapipe

from lightning.pytorch import LightningDataModule

from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
Expand Down Expand Up @@ -121,20 +120,20 @@ def __init__(
self.test_period = [
None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
]

self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
batch_sampler=None,
num_workers=num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False
persistent_workers=False,
)

def _get_datapipe(self, start_time, end_time):
Expand Down Expand Up @@ -195,12 +194,10 @@ def val_dataloader(self):
datapipe = self._get_datapipe(*self.val_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)


# Builds the DataLoader for the test split from one of two datapipe sources.
def test_dataloader(self):
"""Construct test dataloader"""
# A configured batch_dir selects the "test" datapipe from the premade-batches
# helper (presumably batches already saved to disk -- confirm against
# _get_premade_batches_datapipe, which is not visible in this view).
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("test")
else:
# No batch_dir: build the datapipe from the (start, end) pair in test_period.
datapipe = self._get_datapipe(*self.test_period)
# The shared kwargs set shuffle=False and batch_size=None because shuffling
# and batching are done in the datapipe step, not by the DataLoader.
return DataLoader(datapipe, **self._common_dataloader_kwargs)

23 changes: 11 additions & 12 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader
def main(config: DictConfig):
"Constructs and saves validation and training batches."
config_dm = config.datamodule

print_config(config, resolve=False)

# Set up directory
Expand All @@ -100,20 +100,20 @@ def main(config: DictConfig):

dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=config_dm.num_workers,
batch_sampler=None,
num_workers=config_dm.num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=config_dm.prefetch_factor,
persistent_workers=False
persistent_workers=False,
)
if config.num_val_batches>0:

if config.num_val_batches > 0:
print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
Expand All @@ -128,9 +128,8 @@ def main(config: DictConfig):
num_batches=config.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
)

if config.num_train_batches>0:

if config.num_train_batches > 0:
print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
Expand All @@ -150,4 +149,4 @@ def main(config: DictConfig):


if __name__ == "__main__":
main()
main()
13 changes: 6 additions & 7 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,19 @@ def main(config: DictConfig):
os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config.batch_output_dir}/val")


dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=config_dm.num_workers,
batch_sampler=None,
num_workers=config_dm.num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=config_dm.prefetch_factor,
persistent_workers=False
persistent_workers=False,
)

print("----- Saving val batches -----")
Expand Down

0 comments on commit 3839fe8

Please sign in to comment.