add pvnet output datamodule
dfulu committed Jul 20, 2023
1 parent 1493fc4 commit 3b801ee
Showing 1 changed file with 83 additions and 4 deletions.
87 changes: 83 additions & 4 deletions pvnet_summation/data/datamodule.py
@@ -153,7 +153,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
        if shuffle:
            file_pipeline = file_pipeline.shuffle(buffer_size=1000)
        if add_filename:
-            file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5)
+            file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=50)

        sample_pipeline = file_pipeline.sharding_filter().map(torch.load)
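
For background: sharding_filter() marks the point at which DataLoader2 workers split the stream, so the torch.load that follows runs on disjoint subsets of files rather than loading every file in every worker. A minimal sketch of that sharding behaviour, using only torchdata's public API:

    from torchdata.datapipes.iter import IterableWrapper

    files = IterableWrapper([f"{i:06}.pt" for i in range(6)])
    shard = files.sharding_filter()
    # DataLoader2 normally applies this per worker; done manually here for illustration
    shard.apply_sharding(num_of_instances=2, instance_id=0)
    assert list(shard) == ["000000.pt", "000002.pt", "000004.pt"]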

@@ -167,11 +167,13 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
            .compute()
        )

-        sample_pipeline, dp = sample_pipeline.fork(2, buffer_size=5)
+        sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)

-        times_datapipe, dp = GetBatchTime(dp).fork(2, buffer_size=5)
+        times_datapipe, times_datapipe_copy = (
+            GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5)
+        )

-        national_targets_datapipe = GetNationalPVLive(gsp_data, dp)
+        national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)

        # Compile the samples
        if add_filename:
@@ -198,6 +200,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
        )

        return data_pipeline
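
For background on the changes above: fork(2, buffer_size=n) splits a datapipe into two children that share one bounded buffer, and reading one child more than n elements ahead of the other overflows that buffer (recent torchdata versions raise a BufferError). That is presumably why the anonymous dp forks were renamed to explicit copies and the filename fork's buffer was raised from 5 to 50. A minimal sketch of the fork semantics:

    from torchdata.datapipes.iter import IterableWrapper

    dp_a, dp_b = IterableWrapper(range(100)).fork(2, buffer_size=50)
    for a, b in zip(dp_a, dp_b):  # consumed in lockstep, so the buffer never overflows
        assert a == b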


    def train_dataloader(self, shuffle=True, add_filename=False):
        """Construct train dataloader"""
@@ -210,6 +213,7 @@ def train_dataloader(self, shuffle=True, add_filename=False):
        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)


    def val_dataloader(self, shuffle=False, add_filename=False):
        """Construct val dataloader"""
        datapipe = self._get_premade_batches_datapipe(
@@ -220,6 +224,81 @@ def val_dataloader(self, shuffle=False, add_filename=False):
        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)


    def test_dataloader(self):
        """Construct test dataloader"""
        raise NotImplementedError


class PVNetPresavedDataModule(LightningDataModule):
    """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation."""

    def __init__(
        self,
        batch_dir: str,
        batch_size=16,
        num_workers=0,
        prefetch_factor=2,
    ):
"""Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.
Args:
batch_dir: Path to the directory of pre-saved batches.
batch_size: Batch size.
num_workers: Number of workers to use in multiprocess batch loading.
prefetch_factor: Number of data will be prefetched at the end of each worker process.
"""
        super().__init__()
        self.batch_size = batch_size
        self.batch_dir = batch_dir

        self.readingservice_config = dict(
            num_workers=num_workers,
            multiprocessing_context="spawn",
            worker_prefetch_cnt=prefetch_factor,
        )

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        # Load presaved concurrent sample batches
        file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

        if shuffle:
            file_pipeline = file_pipeline.shuffle(buffer_size=1000)

        sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

        if self.batch_size is not None:
            data_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
            data_pipeline = DictApply(
                data_pipeline,
                pvnet_outputs=torch.stack,
                national_targets=torch.stack,
                times=torch.stack,
            )
        else:
            data_pipeline = sample_pipeline

        return data_pipeline
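
PivotDictList and DictApply are project-specific datapipes rather than torchdata built-ins. From their use here they appear to pivot a batch (a list of per-sample dicts) into a dict of lists and then stack selected keys into tensors; the sketch below is a hypothetical re-implementation of that assumed behaviour, for illustration only:

    import torch

    def pivot_dict_list(samples):
        # [{"k": v1}, {"k": v2}] -> {"k": [v1, v2]}
        return {key: [s[key] for s in samples] for key in samples[0]}

    def dict_apply(batch, **fns):
        # Apply the given function (e.g. torch.stack) to matching keys only
        return {k: fns[k](v) if k in fns else v for k, v in batch.items()}

    samples = [{"pvnet_outputs": torch.zeros(2, 3)} for _ in range(4)]
    batch = dict_apply(pivot_dict_list(samples), pvnet_outputs=torch.stack)
    assert batch["pvnet_outputs"].shape == (4, 2, 3)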

    def train_dataloader(self, shuffle=True):
        """Construct train dataloader"""
        datapipe = self._get_premade_batches_datapipe(
            "train",
            shuffle=shuffle,
        )

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def val_dataloader(self, shuffle=False):
        """Construct val dataloader"""
        datapipe = self._get_premade_batches_datapipe(
            "val",
            shuffle=shuffle,
        )
        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def test_dataloader(self):
        """Construct test dataloader"""
        raise NotImplementedError
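
A usage sketch for the new datamodule. The directory path is a placeholder; the train/val subdirectory names and the batch keys are taken from the code above, and the main-guard is needed because the reading service spawns worker processes:

    from pvnet_summation.data.datamodule import PVNetPresavedDataModule

    if __name__ == "__main__":
        dm = PVNetPresavedDataModule(
            batch_dir="/path/to/presaved_batches",  # placeholder; expects train/ and val/ subdirs of *.pt files
            batch_size=16,
            num_workers=2,
        )
        for batch in dm.train_dataloader():
            # Keys follow the DictApply call in _get_premade_batches_datapipe
            print(batch["pvnet_outputs"].shape, batch["national_targets"].shape, batch["times"].shape)
            break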
