diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index ec5eabf..64ba4e8 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -153,7 +153,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) if add_filename: - file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) + file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=50) sample_pipeline = file_pipeline.sharding_filter().map(torch.load) @@ -167,11 +167,13 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals .compute() ) - sample_pipeline, dp = sample_pipeline.fork(2, buffer_size=5) + sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5) - times_datapipe, dp = GetBatchTime(dp).fork(2, buffer_size=5) + times_datapipe, times_datapipe_copy = ( + GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5) + ) - national_targets_datapipe = GetNationalPVLive(gsp_data, dp) + national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy) # Compile the samples if add_filename: @@ -198,6 +200,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals ) return data_pipeline + def train_dataloader(self, shuffle=True, add_filename=False): """Construct train dataloader""" @@ -210,6 +213,7 @@ def train_dataloader(self, shuffle=True, add_filename=False): rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) + def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( @@ -220,6 +224,81 @@ def val_dataloader(self, shuffle=False, add_filename=False): rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) + + def test_dataloader(self): + 
class PVNetPresavedDataModule(LightningDataModule):
    """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation."""

    def __init__(
        self,
        batch_dir: str,
        batch_size=16,
        num_workers=0,
        prefetch_factor=2,
    ):
        """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.

        Args:
            batch_dir: Path to the directory of pre-saved batches. The `train` and `val`
                subdirectories of this path are read by the corresponding dataloaders.
            batch_size: Number of pre-saved samples to group into one batch. If None, the
                samples are yielded one at a time without batching or stacking.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Passed to the reading service as `worker_prefetch_cnt`.
        """
        super().__init__()
        self.batch_size = batch_size
        self.batch_dir = batch_dir

        # Keyword arguments forwarded unchanged to MultiProcessingReadingService
        self.readingservice_config = dict(
            num_workers=num_workers,
            multiprocessing_context="spawn",
            worker_prefetch_cnt=prefetch_factor,
        )

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        """Build the datapipe which loads the pre-saved samples found in `subdir`.

        Args:
            subdir: Name of the subdirectory of `batch_dir` to read `*.pt` files from.
            shuffle: Whether to shuffle the order of the files (buffer of 1000 files).

        Returns:
            A datapipe of loaded samples, batched and stacked when `batch_size` is not None.
        """
        # Load presaved concurrent sample batches
        file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

        if shuffle:
            file_pipeline = file_pipeline.shuffle(buffer_size=1000)

        # The sharding filter is applied to the file paths, i.e. before anything is loaded.
        # The same variable is used for the batched and unbatched pipelines so that the
        # return value is always bound, whatever the value of `batch_size`.
        data_pipeline = file_pipeline.sharding_filter().map(torch.load)

        if self.batch_size is not None:
            # NOTE(review): presumably PivotDictList turns each list of per-sample dicts
            # into a single dict of lists - confirm against its definition.
            data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
            data_pipeline = DictApply(
                data_pipeline,
                pvnet_outputs=torch.stack,
                national_targets=torch.stack,
                times=torch.stack,
            )

        return data_pipeline

    def train_dataloader(self, shuffle=True):
        """Construct train dataloader"""
        datapipe = self._get_premade_batches_datapipe(
            "train",
            shuffle=shuffle,
        )

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def val_dataloader(self, shuffle=False):
        """Construct val dataloader"""
        datapipe = self._get_premade_batches_datapipe(
            "val",
            shuffle=shuffle,
        )

        rs = MultiProcessingReadingService(**self.readingservice_config)
        return DataLoader2(datapipe, reading_service=rs)

    def test_dataloader(self):
        """Construct test dataloader"""
        raise NotImplementedError