From 8a42d1b89b70d6030bf9c1c36bf37f2972af78b7 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 15:28:25 +0000 Subject: [PATCH 1/7] refactor datapipe --- pvnet_summation/data/datamodule.py | 180 ++++++++++++++++++++++------- 1 file changed, 137 insertions(+), 43 deletions(-) diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index d4f3a8b..ec5eabf 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -7,6 +7,7 @@ from ocf_datapipes.utils.consts import BatchKey from ocf_datapipes.load import OpenGSP from ocf_datapipes.training.pvnet import normalize_gsp +from torchdata.datapipes.iter import Zipper from pvnet.data.datamodule import ( copy_batch_to_device, @@ -18,44 +19,100 @@ class GetNationalPVLive(IterDataPipe): - def __init__(self, gsp_data, sample_datapipe, return_times=False): + """Select national output targets for given times""" + def __init__(self, gsp_data, times_datapipe): + """Select national output targets for given times + + Args: + gsp_data: xarray Dataarray of the national outputs + times_datapipe: IterDataPipe yeilding arrays of target times. + """ self.gsp_data = gsp_data - self.sample_datapipe = sample_datapipe - self.return_times = return_times + self.times_datapipe = times_datapipe def __iter__(self): + gsp_data = self.gsp_data - for sample in self.sample_datapipe: - # Times for each GSP in the sample batch should be the same - take first - id0 = sample[BatchKey.gsp_t0_idx] - times = sample[BatchKey.gsp_time_utc][0, id0+1:] + for times in self.times_datapipe: national_outputs = torch.as_tensor( gsp_data.sel(time_utc=times.cpu().numpy().astype("datetime64[s]")).values ) - - if self.return_times: - yield national_outputs, times - else: - yield national_outputs + yield national_outputs + +class GetBatchTime(IterDataPipe): + """Extract the valid times from the concurrent sample batch""" + + def __init__(self, sample_datapipe): + """Extract the valid times from the concurrent sample batch + + Args: + sample_datapipe: IterDataPipe yeilding concurrent sample batches + """ + self.sample_datapipe = sample_datapipe + + def __iter__(self): + for sample in self.sample_datapipe: + # Times for each GSP in the sample batch should be the same - take first + id0 = sample[BatchKey.gsp_t0_idx] + times = sample[BatchKey.gsp_time_utc][0, id0+1:] + yield times -class ReorganiseBatch(IterDataPipe): - """Reoragnise batches for pvnet_summation""" + +class PivotDictList(IterDataPipe): + """Convert list of dicts to dict of lists""" + def __init__(self, source_datapipe): - """Reoragnise batches for pvnet_summation + """Convert list of dicts to dict of lists Args: - source_datapipe: Zipped datapipe of list[tuple(NumpyBatch, national_outputs)] + source_datapipe: """ self.source_datapipe = source_datapipe def __iter__(self): - for batch in self.source_datapipe: - yield dict( - pvnet_inputs = [sample[0] for sample in batch], - national_targets = torch.stack([sample[1] for sample in batch]), - times = torch.stack([sample[2] for sample in batch]), - ) + for list_of_dicts in self.source_datapipe: + keys = list_of_dicts[0].keys() + batch_dict = {k: [d[k] for d in list_of_dicts] for k in keys} + yield batch_dict + + +class DictApply(IterDataPipe): + """Apply functions to elements of a dictionary and return processed dictionary.""" + + def __init__(self, source_datapipe, **transforms): + """Apply functions to elements of a dictionary and return processed dictionary. 
+ + Args: + source_datapipe: Datapipe which yields dicts + **transforms: key-function pairs + """ + self.source_datapipe = source_datapipe + self.transforms = transforms + + def __iter__(self): + for d in self.source_datapipe: + for key, function in self.transforms.items(): + d[key] = function(d[key]) + yield d + + +class ZipperDict(IterDataPipe): + """Yield samples from multiple datapipes as a dict""" + + def __init__(self, **datapipes): + """Yield samples from multiple datapipes as a dict. + + Args: + **datapipes: Named datapipes + """ + self.keys = list(datapipes.keys()) + self.source_datapipes = Zipper(*[datapipes[key] for key in self.keys]) + + def __iter__(self): + for outputs in self.source_datapipes: + yield {key: value for key, value in zip(self.keys, outputs)} + class DataModule(LightningDataModule): """Datamodule for training pvnet_summation.""" @@ -88,45 +145,82 @@ def __init__( worker_prefetch_cnt=prefetch_factor, ) - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) - if shuffle: - data_pipeline = data_pipeline.shuffle(buffer_size=1000) + def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False): - data_pipeline = data_pipeline.sharding_filter().map(torch.load) + # Load presaved concurrent sample batches + file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) - # Add the national target - data_pipeline, dp = data_pipeline.fork(2, buffer_size=5) + if shuffle: + file_pipeline = file_pipeline.shuffle(buffer_size=1000) + if add_filename: + file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) - gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp) - gsp_data = next(iter(gsp_datapipe)).sel(gsp_id=0).compute() + sample_pipeline = file_pipeline.sharding_filter().map(torch.load) - national_targets_datapipe, times_datapipe = ( - GetNationalPVLive(gsp_data, dp, return_times=True).unzip(sequence_length=2) + # Find national outout simultaneous to concurrent samples + gsp_data = ( + next(iter( + OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path) + .map(normalize_gsp) + )) + .sel(gsp_id=0) + .compute() ) - data_pipeline = data_pipeline.zip(national_targets_datapipe, times_datapipe) - data_pipeline = ReorganiseBatch(data_pipeline.batch(self.batch_size)) + sample_pipeline, dp = sample_pipeline.fork(2, buffer_size=5) + + times_datapipe, dp = GetBatchTime(dp).fork(2, buffer_size=5) + + national_targets_datapipe = GetNationalPVLive(gsp_data, dp) + + # Compile the samples + if add_filename: + data_pipeline = ZipperDict( + pvnet_inputs = sample_pipeline, + national_targets = national_targets_datapipe, + times = times_datapipe, + filepath = file_pipeline_copy, + ) + else: + data_pipeline = ZipperDict( + pvnet_inputs = sample_pipeline, + national_targets = national_targets_datapipe, + times = times_datapipe, + ) + + if self.batch_size is not None: + + data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size)) + data_pipeline = DictApply( + data_pipeline, + national_targets=torch.stack, + times=torch.stack, + ) return data_pipeline - def train_dataloader(self): + def train_dataloader(self, shuffle=True, add_filename=False): """Construct train dataloader""" - datapipe = self._get_premade_batches_datapipe("train", shuffle=True) + datapipe = self._get_premade_batches_datapipe( + "train", + shuffle=shuffle, + add_filename=add_filename + ) rs = 
MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) - def val_dataloader(self): + def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" - datapipe = self._get_premade_batches_datapipe("val") - + datapipe = self._get_premade_batches_datapipe( + "val", + shuffle=shuffle, + add_filename=add_filename + ) rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) def test_dataloader(self): """Construct test dataloader""" - datapipe = self._get_premade_batches_datapipe("test") + raise NotImplementedError - rs = MultiProcessingReadingService(**self.readingservice_config) - return DataLoader2(datapipe, reading_service=rs) From 1493fc4519df6244bf48b5dbe2fdca174ae621d5 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 16:11:20 +0000 Subject: [PATCH 2/7] add presaving routine --- configs/config.yaml | 4 +++ pvnet_summation/training.py | 53 ++++++++++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index fbbe429..08c8fb4 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -9,6 +9,10 @@ defaults: - logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) - hydra: default.yaml +# Whether to loop through the PVNet outputs and save them out before training +presave_pvnet_outputs: True + + # enable color logging # - override hydra/hydra_logging: colorlog # - override hydra/job_logging: colorlog diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index f7f5324..51ef5cf 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -15,9 +15,12 @@ from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm from pvnet import utils +from pvnet_summation.data.datamodule import PVNetPresavedDataModule + log = utils.get_logger(__name__) torch.set_default_dtype(torch.float32) @@ -64,6 +67,51 @@ def train(config: DictConfig) -> Optional[float]: # Init lightning model log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate(config.model) + + # Presave batches + if config.get("presave_pvnet_outputs", False): + + # Set batch size to None so batching is skipped + datamodule.batch_size = None + + save_dir = ( + f"{config.datamodule.batch_dir}/" + f"{config.model.model_name}/" + f"{config.model.model_version}" + ) + + log.info(f"Saving PVNet outputs to {save_dir}") + + os.makedirs(f"{save_dir}/train") + os.makedirs(f"{save_dir}/val") + + for dataloader_func, split in [ + (datamodule.train_dataloader, "train"), + (datamodule.val_dataloader, "val") + ]: + log.info(f"Saving {split} outputs") + dataloader = dataloader_func(shuffle=False, add_filename=True) + + for concurrent_sample_dict in tqdm(dataloader): + # Run though model and remove + pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] + del concurrent_sample_dict["pvnet_inputs"] + concurrent_sample_dict["pvnet_outputs"] = pvnet_out + + # Save pvnet prediction sample + filepath = concurrent_sample_dict.pop("filepath") + sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) + sample_path = f"{save_dir}{sample_rel_path}" + torch.save(concurrent_sample_dict, sample_path) + + + + datamodule = PVNetPresavedDataModule( + batch_dir=save_dir, + 
batch_size=config.datamodule.batch_size, + num_workers=config.datamodule.num_workers, + prefetch_factor=config.datamodule.prefetch_factor + ) # Init lightning loggers loggers: list[Logger] = [] @@ -104,6 +152,7 @@ def train(config: DictConfig) -> Optional[float]: OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml") break + trainer: Trainer = hydra.utils.instantiate( config.trainer, logger=loggers, @@ -114,10 +163,6 @@ def train(config: DictConfig) -> Optional[float]: # Train the model completely trainer.fit(model=model, datamodule=datamodule) - if config.test_after_training: - # Evaluate model on test set, using the best model achieved during training - log.info("Starting testing!") - trainer.test(model=model, datamodule=datamodule, ckpt_path="best") # Make sure everything closed properly log.info("Finalizing!") From 3b801ee0ea3bb98fb27742f56523db1ac15db5a8 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 16:11:51 +0000 Subject: [PATCH 3/7] add pvnet output datamodule --- pvnet_summation/data/datamodule.py | 87 ++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index ec5eabf..64ba4e8 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -153,7 +153,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) if add_filename: - file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) + file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=50) sample_pipeline = file_pipeline.sharding_filter().map(torch.load) @@ -167,11 +167,13 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals .compute() ) - sample_pipeline, dp = sample_pipeline.fork(2, buffer_size=5) + sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5) - times_datapipe, dp = GetBatchTime(dp).fork(2, buffer_size=5) + times_datapipe, times_datapipe_copy = ( + GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5) + ) - national_targets_datapipe = GetNationalPVLive(gsp_data, dp) + national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy) # Compile the samples if add_filename: @@ -198,6 +200,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals ) return data_pipeline + def train_dataloader(self, shuffle=True, add_filename=False): """Construct train dataloader""" @@ -210,6 +213,7 @@ def train_dataloader(self, shuffle=True, add_filename=False): rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) + def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( @@ -220,6 +224,81 @@ def val_dataloader(self, shuffle=False, add_filename=False): rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) + + def test_dataloader(self): + """Construct test dataloader""" + raise NotImplementedError + + +class PVNetPresavedDataModule(LightningDataModule): + """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.""" + + def __init__( + self, + batch_dir: str, + batch_size=16, + num_workers=0, + prefetch_factor=2, + ): + """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation. 
+ + Args: + batch_dir: Path to the directory of pre-saved batches. + batch_size: Batch size. + num_workers: Number of workers to use in multiprocess batch loading. + prefetch_factor: Number of data will be prefetched at the end of each worker process. + """ + super().__init__() + self.batch_size = batch_size + self.batch_dir = batch_dir + + self.readingservice_config = dict( + num_workers=num_workers, + multiprocessing_context="spawn", + worker_prefetch_cnt=prefetch_factor, + ) + + def _get_premade_batches_datapipe(self, subdir, shuffle=False): + + # Load presaved concurrent sample batches + file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) + + if shuffle: + file_pipeline = file_pipeline.shuffle(buffer_size=1000) + + sample_pipeline = file_pipeline.sharding_filter().map(torch.load) + + if self.batch_size is not None: + + data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size)) + data_pipeline = DictApply( + data_pipeline, + pvnet_outputs=torch.stack, + national_targets=torch.stack, + times=torch.stack, + ) + + return data_pipeline + + def train_dataloader(self, shuffle=True): + """Construct train dataloader""" + datapipe = self._get_premade_batches_datapipe( + "train", + shuffle=shuffle, + ) + + rs = MultiProcessingReadingService(**self.readingservice_config) + return DataLoader2(datapipe, reading_service=rs) + + def val_dataloader(self, shuffle=False): + """Construct val dataloader""" + datapipe = self._get_premade_batches_datapipe( + "val", + shuffle=shuffle, + ) + rs = MultiProcessingReadingService(**self.readingservice_config) + return DataLoader2(datapipe, reading_service=rs) + def test_dataloader(self): """Construct test dataloader""" raise NotImplementedError From 60113597d511dfcd0c7c3d78e1f766cacf568e0c Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 16:13:15 +0000 Subject: [PATCH 4/7] add support for different inputs - either raw batches of the output of PVNet --- pvnet_summation/models/base_model.py | 9 ++++++--- pvnet_summation/models/model.py | 9 +++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index 0d10862..dec091f 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -135,7 +135,8 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): def training_step(self, batch, batch_idx): """Run training step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -152,7 +153,8 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -201,7 +203,8 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] losses = self._calculate_common_losses(y, y_hat) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index b978ef4..901621e 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -66,8 +66,13 @@ def __init__( def forward(self, x): - """Run central model forward""" - pvnet_out = self.predict_pvnet_batch(x) + """Run model forward""" + + if "pvnet_outputs" in x: + 
pvnet_out = x["pvnet_outputs"] + else: + pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs']) + pvnet_out = torch.flatten(pvnet_out, start_dim=1) out = self.model(pvnet_out) From 1f877d0b89c363973ccf97f9390dd9c8af4eb7f3 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 16:31:55 +0000 Subject: [PATCH 5/7] fix fork buffer bug --- pvnet_summation/data/datamodule.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 64ba4e8..23279c0 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -152,10 +152,13 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) + + file_pipeline = file_pipeline.sharding_filter() + if add_filename: - file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=50) + file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) - sample_pipeline = file_pipeline.sharding_filter().map(torch.load) + sample_pipeline = file_pipeline.map(torch.load) # Find national outout simultaneous to concurrent samples gsp_data = ( From cc4707dfc976a1dd7d2f68c42fcdbfdd64714758 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Fri, 21 Jul 2023 09:14:30 +0000 Subject: [PATCH 6/7] fix dataloader bug and make saved outputs reusable --- pvnet_summation/data/datamodule.py | 8 ++--- pvnet_summation/training.py | 55 ++++++++++++++++++------------ 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 23279c0..b91d0c4 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -273,15 +273,15 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): if self.batch_size is not None: - data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size)) - data_pipeline = DictApply( - data_pipeline, + batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size)) + batch_pipeline = DictApply( + batch_pipeline, pvnet_outputs=torch.stack, national_targets=torch.stack, times=torch.stack, ) - return data_pipeline + return batch_pipeline def train_dataloader(self, shuffle=True): """Construct train dataloader""" diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 51ef5cf..5e65549 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -71,8 +71,7 @@ def train(config: DictConfig) -> Optional[float]: # Presave batches if config.get("presave_pvnet_outputs", False): - # Set batch size to None so batching is skipped - datamodule.batch_size = None + save_dir = ( f"{config.datamodule.batch_dir}/" @@ -80,29 +79,41 @@ def train(config: DictConfig) -> Optional[float]: f"{config.model.model_version}" ) - log.info(f"Saving PVNet outputs to {save_dir}") - os.makedirs(f"{save_dir}/train") - os.makedirs(f"{save_dir}/val") - for dataloader_func, split in [ - (datamodule.train_dataloader, "train"), - (datamodule.val_dataloader, "val") - ]: - log.info(f"Saving {split} outputs") - dataloader = dataloader_func(shuffle=False, add_filename=True) + if os.path.isdir(save_dir): + log.info( + f"PVNet output directory already exists: {save_dir}\n" + "Skipping saving new outputs. The existing saved outputs will be loaded." 
+ ) + + else: + log.info(f"Saving PVNet outputs to {save_dir}") + + os.makedirs(f"{save_dir}/train") + os.makedirs(f"{save_dir}/val") - for concurrent_sample_dict in tqdm(dataloader): - # Run though model and remove - pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] - del concurrent_sample_dict["pvnet_inputs"] - concurrent_sample_dict["pvnet_outputs"] = pvnet_out - - # Save pvnet prediction sample - filepath = concurrent_sample_dict.pop("filepath") - sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) - sample_path = f"{save_dir}{sample_rel_path}" - torch.save(concurrent_sample_dict, sample_path) + # Set batch size to None so batching is skipped + datamodule.batch_size = None + + for dataloader_func, split in [ + (datamodule.train_dataloader, "train"), + (datamodule.val_dataloader, "val") + ]: + log.info(f"Saving {split} outputs") + dataloader = dataloader_func(shuffle=False, add_filename=True) + + for concurrent_sample_dict in tqdm(dataloader): + # Run though model and remove + pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] + del concurrent_sample_dict["pvnet_inputs"] + concurrent_sample_dict["pvnet_outputs"] = pvnet_out + + # Save pvnet prediction sample + filepath = concurrent_sample_dict.pop("filepath") + sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) + sample_path = f"{save_dir}{sample_rel_path}" + torch.save(concurrent_sample_dict, sample_path) From 2143caa3382655319690914f4a506853bce9ceb7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jul 2023 10:00:31 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- configs/config.yaml | 4 +- configs/datamodule/default.yaml | 2 +- configs/model/default.yaml | 1 - pvnet_summation/__init__.py | 2 +- pvnet_summation/data/datamodule.py | 157 ++++++++++++--------------- pvnet_summation/models/base_model.py | 40 +++---- pvnet_summation/models/model.py | 39 +++---- pvnet_summation/training.py | 40 +++---- pvnet_summation/utils.py | 7 +- requirements.txt | 2 +- run.py | 5 +- tests/conftest.py | 39 +++---- tests/data/test_datamodule.py | 32 +++--- tests/test_end2end.py | 2 +- 14 files changed, 159 insertions(+), 213 deletions(-) diff --git a/configs/config.yaml b/configs/config.yaml index 08c8fb4..0cb36a1 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -10,8 +10,8 @@ defaults: - hydra: default.yaml # Whether to loop through the PVNet outputs and save them out before training -presave_pvnet_outputs: True - +presave_pvnet_outputs: + True # enable color logging # - override hydra/hydra_logging: colorlog diff --git a/configs/datamodule/default.yaml b/configs/datamodule/default.yaml index 3ed7921..4e9f6af 100644 --- a/configs/datamodule/default.yaml +++ b/configs/datamodule/default.yaml @@ -3,4 +3,4 @@ batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins" gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr" batch_size: 8 num_workers: 20 -prefetch_factor: 2 \ No newline at end of file +prefetch_factor: 2 diff --git a/configs/model/default.yaml b/configs/model/default.yaml index 0418c70..d481aaf 100644 --- a/configs/model/default.yaml +++ b/configs/model/default.yaml @@ -18,7 +18,6 @@ output_network_kwargs: res_block_layers: 2 dropout_frac: 0.0 - # Foreast and time settings forecast_minutes: 480 diff --git a/pvnet_summation/__init__.py b/pvnet_summation/__init__.py index 
ed2c3c5..ed53582 100644 --- a/pvnet_summation/__init__.py +++ b/pvnet_summation/__init__.py @@ -1 +1 @@ -"""PVNet_summation""" \ No newline at end of file +"""PVNet_summation""" diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index b91d0c4..029a952 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -2,36 +2,30 @@ import torch from lightning.pytorch import LightningDataModule -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService -from torchdata.datapipes.iter import FileLister, IterDataPipe -from ocf_datapipes.utils.consts import BatchKey from ocf_datapipes.load import OpenGSP from ocf_datapipes.training.pvnet import normalize_gsp -from torchdata.datapipes.iter import Zipper +from ocf_datapipes.utils.consts import BatchKey +from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService +from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper -from pvnet.data.datamodule import ( - copy_batch_to_device, - batch_to_tensor, - split_batches, -) # https://github.com/pytorch/pytorch/issues/973 -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") class GetNationalPVLive(IterDataPipe): """Select national output targets for given times""" + def __init__(self, gsp_data, times_datapipe): """Select national output targets for given times - + Args: gsp_data: xarray Dataarray of the national outputs times_datapipe: IterDataPipe yeilding arrays of target times. """ self.gsp_data = gsp_data self.times_datapipe = times_datapipe - + def __iter__(self): - gsp_data = self.gsp_data for times in self.times_datapipe: national_outputs = torch.as_tensor( @@ -42,54 +36,54 @@ def __iter__(self): class GetBatchTime(IterDataPipe): """Extract the valid times from the concurrent sample batch""" - + def __init__(self, sample_datapipe): """Extract the valid times from the concurrent sample batch - + Args: sample_datapipe: IterDataPipe yeilding concurrent sample batches """ self.sample_datapipe = sample_datapipe - + def __iter__(self): for sample in self.sample_datapipe: - # Times for each GSP in the sample batch should be the same - take first + # Times for each GSP in the sample batch should be the same - take first id0 = sample[BatchKey.gsp_t0_idx] - times = sample[BatchKey.gsp_time_utc][0, id0+1:] + times = sample[BatchKey.gsp_time_utc][0, id0 + 1 :] yield times - + class PivotDictList(IterDataPipe): """Convert list of dicts to dict of lists""" - + def __init__(self, source_datapipe): """Convert list of dicts to dict of lists - + Args: - source_datapipe: + source_datapipe: """ self.source_datapipe = source_datapipe - + def __iter__(self): for list_of_dicts in self.source_datapipe: keys = list_of_dicts[0].keys() batch_dict = {k: [d[k] for d in list_of_dicts] for k in keys} yield batch_dict - - + + class DictApply(IterDataPipe): """Apply functions to elements of a dictionary and return processed dictionary.""" - + def __init__(self, source_datapipe, **transforms): """Apply functions to elements of a dictionary and return processed dictionary. 
- + Args: source_datapipe: Datapipe which yields dicts **transforms: key-function pairs """ self.source_datapipe = source_datapipe self.transforms = transforms - + def __iter__(self): for d in self.source_datapipe: for key, function in self.transforms.items(): @@ -99,21 +93,21 @@ def __iter__(self): class ZipperDict(IterDataPipe): """Yield samples from multiple datapipes as a dict""" - + def __init__(self, **datapipes): """Yield samples from multiple datapipes as a dict. - + Args: **datapipes: Named datapipes """ self.keys = list(datapipes.keys()) self.source_datapipes = Zipper(*[datapipes[key] for key in self.keys]) - + def __iter__(self): for outputs in self.source_datapipes: yield {key: value for key, value in zip(self.keys, outputs)} - + class DataModule(LightningDataModule): """Datamodule for training pvnet_summation.""" @@ -144,95 +138,83 @@ def __init__( multiprocessing_context="spawn", worker_prefetch_cnt=prefetch_factor, ) - + def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False): - # Load presaved concurrent sample batches file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) - + if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) - + file_pipeline = file_pipeline.sharding_filter() - + if add_filename: file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) - + sample_pipeline = file_pipeline.map(torch.load) - + # Find national outout simultaneous to concurrent samples gsp_data = ( - next(iter( - OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path) - .map(normalize_gsp) - )) + next(iter(OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp))) .sel(gsp_id=0) .compute() ) - + sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5) - - times_datapipe, times_datapipe_copy = ( - GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5) + + times_datapipe, times_datapipe_copy = GetBatchTime(sample_pipeline_copy).fork( + 2, buffer_size=5 ) - + national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy) - + # Compile the samples if add_filename: data_pipeline = ZipperDict( - pvnet_inputs = sample_pipeline, - national_targets = national_targets_datapipe, - times = times_datapipe, - filepath = file_pipeline_copy, + pvnet_inputs=sample_pipeline, + national_targets=national_targets_datapipe, + times=times_datapipe, + filepath=file_pipeline_copy, ) else: data_pipeline = ZipperDict( - pvnet_inputs = sample_pipeline, - national_targets = national_targets_datapipe, - times = times_datapipe, - ) - + pvnet_inputs=sample_pipeline, + national_targets=national_targets_datapipe, + times=times_datapipe, + ) + if self.batch_size is not None: - data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size)) data_pipeline = DictApply( - data_pipeline, - national_targets=torch.stack, + data_pipeline, + national_targets=torch.stack, times=torch.stack, ) - + return data_pipeline - def train_dataloader(self, shuffle=True, add_filename=False): """Construct train dataloader""" datapipe = self._get_premade_batches_datapipe( - "train", - shuffle=shuffle, - add_filename=add_filename + "train", shuffle=shuffle, add_filename=add_filename ) rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) - def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( - "val", - shuffle=shuffle, - add_filename=add_filename - ) + "val", 
shuffle=shuffle, add_filename=add_filename + ) rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) - def test_dataloader(self): """Construct test dataloader""" raise NotImplementedError - - + + class PVNetPresavedDataModule(LightningDataModule): """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation.""" @@ -260,34 +242,32 @@ def __init__( multiprocessing_context="spawn", worker_prefetch_cnt=prefetch_factor, ) - + def _get_premade_batches_datapipe(self, subdir, shuffle=False): - # Load presaved concurrent sample batches file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) - + if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) - - sample_pipeline = file_pipeline.sharding_filter().map(torch.load) - + + sample_pipeline = file_pipeline.sharding_filter().map(torch.load) + if self.batch_size is not None: - batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size)) batch_pipeline = DictApply( batch_pipeline, pvnet_outputs=torch.stack, - national_targets=torch.stack, + national_targets=torch.stack, times=torch.stack, ) - + return batch_pipeline def train_dataloader(self, shuffle=True): """Construct train dataloader""" datapipe = self._get_premade_batches_datapipe( - "train", - shuffle=shuffle, + "train", + shuffle=shuffle, ) rs = MultiProcessingReadingService(**self.readingservice_config) @@ -296,13 +276,12 @@ def train_dataloader(self, shuffle=True): def val_dataloader(self, shuffle=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( - "val", - shuffle=shuffle, - ) + "val", + shuffle=shuffle, + ) rs = MultiProcessingReadingService(**self.readingservice_config) return DataLoader2(datapipe, reading_service=rs) def test_dataloader(self): """Construct test dataloader""" raise NotImplementedError - diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index dec091f..16debf4 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -1,28 +1,22 @@ """Base model for all PVNet submodels""" -import json import logging -import os -from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional -import hydra +import lightning.pytorch as pl import torch import wandb from nowcasting_utils.models.loss import WeightedLosses -import lightning.pytorch as pl -from torch import nn -from pvnet.models.base_model import PVNetModelHubMixin, BaseModel as PVNetBaseModel - -from pvnet.models.utils import ( +from pvnet.models.base_model import BaseModel as PVNetBaseModel +from pvnet.models.base_model import PVNetModelHubMixin +from pvnet.models.utils import ( MetricAccumulator, PredAccumulator, ) - - from pvnet.optimizers import AbstractOptimizer + from pvnet_summation.utils import plot_forecasts -#from pvnet.models.base_model import BaseModel as PVNetBaseModel +# from pvnet.models.base_model import BaseModel as PVNetBaseModel logger = logging.getLogger(__name__) @@ -54,7 +48,7 @@ def __init__( None the output is a single value. 
""" pl.LightningModule.__init__(self) - PVNetModelHubMixin.__init__(self) + PVNetModelHubMixin.__init__(self) self._optimizer = optimizer @@ -74,20 +68,20 @@ def __init__( self._accumulated_y = PredAccumulator() self._accumulated_y_hat = PredAccumulator() self._accumulated_times = PredAccumulator() - + self.pvnet_model = PVNetBaseModel.from_pretrained( model_name, revision=model_version, ) self.pvnet_model.requires_grad_(False) - + def predict_pvnet_batch(self, batch): gsp_batches = [] for sample in batch: preds = self.pvnet_model(sample) gsp_batches += [preds] return torch.stack(gsp_batches) - + @property def pvnet_output_shape(self): if self.pvnet_model.use_quantile_regression: @@ -95,7 +89,6 @@ def pvnet_output_shape(self): else: return (317, self.pvnet_model.forecast_len_30) - def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): """Internal function to accumulate training batches and log results. @@ -135,7 +128,7 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): def training_step(self, batch, batch_idx): """Run training step""" - + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -150,10 +143,10 @@ def training_step(self, batch, batch_idx): else: opt_target = losses["MAE/train"] return opt_target - + def validation_step(self, batch: dict, batch_idx): """Run validation step""" - + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -203,7 +196,7 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - + y_hat = self.forward(batch) y = batch["national_targets"] @@ -220,10 +213,9 @@ def test_step(self, batch, batch_idx): return logged_losses - def configure_optimizers(self): """Configure the optimizers using learning rate found with LR finder if used""" if self.lr is not None: # Use learning rate found by learning rate finder callback self._optimizer.lr = self.lr - return self._optimizer(self.parameters()) \ No newline at end of file + return self._optimizer(self.parameters()) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index 901621e..9214508 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -3,21 +3,17 @@ from typing import Optional import numpy as np -import torch - import pvnet -from pvnet_summation.models.base_model import BaseModel -from pvnet.optimizers import AbstractOptimizer -from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet +import torch from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork +from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet +from pvnet.optimizers import AbstractOptimizer - +from pvnet_summation.models.base_model import BaseModel class Model(BaseModel): - """Neural network which combines GSP predictions from PVNet - - """ + """Neural network which combines GSP predictions from PVNet""" name = "pvnet_summation_model" @@ -29,8 +25,7 @@ def __init__( output_quantiles: Optional[list[float]] = None, output_network: AbstractLinearNetwork = DefaultFCNet, output_network_kwargs: dict = dict(), - optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), - + optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), ): """Neural network which combines GSP predictions from PVNet @@ -46,16 +41,10 @@ def __init__( optimizer (AbstractOptimizer): Optimizer """ - super().__init__( - forecast_minutes, - model_name, - model_version, - optimizer, - output_quantiles - ) + 
super().__init__(forecast_minutes, model_name, model_version, optimizer, output_quantiles) in_features = np.product(self.pvnet_output_shape) - + self.model = output_network( in_features=in_features, out_features=self.num_output_features, @@ -64,21 +53,19 @@ def __init__( self.save_hyperparameters() - def forward(self, x): """Run model forward""" - + if "pvnet_outputs" in x: pvnet_out = x["pvnet_outputs"] else: - pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs']) - + pvnet_out = self.predict_pvnet_batch(x["pvnet_inputs"]) + pvnet_out = torch.flatten(pvnet_out, start_dim=1) out = self.model(pvnet_out) - + if self.use_quantile_regression: # Shape: batch_size, seq_length * num_quantiles out = out.reshape(out.shape[0], self.forecast_len_30, len(self.output_quantiles)) - - return out + return out diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 5e65549..8388be7 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -15,9 +15,8 @@ from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf -from tqdm import tqdm - from pvnet import utils +from tqdm import tqdm from pvnet_summation.data.datamodule import PVNetPresavedDataModule @@ -67,45 +66,42 @@ def train(config: DictConfig) -> Optional[float]: # Init lightning model log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate(config.model) - + # Presave batches if config.get("presave_pvnet_outputs", False): - - - save_dir = ( f"{config.datamodule.batch_dir}/" f"{config.model.model_name}/" f"{config.model.model_version}" ) - - - + if os.path.isdir(save_dir): log.info( f"PVNet output directory already exists: {save_dir}\n" "Skipping saving new outputs. The existing saved outputs will be loaded." 
) - + else: log.info(f"Saving PVNet outputs to {save_dir}") - - os.makedirs(f"{save_dir}/train") + + os.makedirs(f"{save_dir}/train") os.makedirs(f"{save_dir}/val") - - # Set batch size to None so batching is skipped + + # Set batch size to None so batching is skipped datamodule.batch_size = None for dataloader_func, split in [ - (datamodule.train_dataloader, "train"), - (datamodule.val_dataloader, "val") + (datamodule.train_dataloader, "train"), + (datamodule.val_dataloader, "val"), ]: log.info(f"Saving {split} outputs") dataloader = dataloader_func(shuffle=False, add_filename=True) for concurrent_sample_dict in tqdm(dataloader): # Run though model and remove - pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] + pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[ + 0 + ] del concurrent_sample_dict["pvnet_inputs"] concurrent_sample_dict["pvnet_outputs"] = pvnet_out @@ -114,14 +110,12 @@ def train(config: DictConfig) -> Optional[float]: sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) sample_path = f"{save_dir}{sample_rel_path}" torch.save(concurrent_sample_dict, sample_path) - - - + datamodule = PVNetPresavedDataModule( batch_dir=save_dir, - batch_size=config.datamodule.batch_size, + batch_size=config.datamodule.batch_size, num_workers=config.datamodule.num_workers, - prefetch_factor=config.datamodule.prefetch_factor + prefetch_factor=config.datamodule.prefetch_factor, ) # Init lightning loggers @@ -163,7 +157,6 @@ def train(config: DictConfig) -> Optional[float]: OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml") break - trainer: Trainer = hydra.utils.instantiate( config.trainer, logger=loggers, @@ -174,7 +167,6 @@ def train(config: DictConfig) -> Optional[float]: # Train the model completely trainer.fit(model=model, datamodule=datamodule) - # Make sure everything closed properly log.info("Finalizing!") utils.finish( diff --git a/pvnet_summation/utils.py b/pvnet_summation/utils.py index 711dc3e..e9c7b79 100644 --- a/pvnet_summation/utils.py +++ b/pvnet_summation/utils.py @@ -8,7 +8,6 @@ def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None): """Plot a batch of data and the forecast from that batch""" - times_utc = times.cpu().numpy().squeeze().astype("datetime64[s]") times_utc = [pd.to_datetime(t) for t in times_utc] y = y.cpu().numpy() @@ -25,9 +24,7 @@ def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None): ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$") if quantiles is None: - ax.plot( - times_utc[i], y_hat[i], marker=".", color="r", label=r"$\hat{y}$" - ) + ax.plot(times_utc[i], y_hat[i], marker=".", color="r", label=r"$\hat{y}$") else: cm = pylab.get_cmap("twilight") for nq, q in enumerate(quantiles): @@ -57,4 +54,4 @@ def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None): plt.suptitle(title) plt.tight_layout() - return fig \ No newline at end of file + return fig diff --git a/requirements.txt b/requirements.txt index d80d384..bf8a5e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,4 @@ tqdm rich omegaconf hydra-core -python-dotenv \ No newline at end of file +python-dotenv diff --git a/run.py b/run.py index 0c2b3c4..dbcceb2 100644 --- a/run.py +++ b/run.py @@ -12,8 +12,8 @@ pass import logging -import sys import os +import sys # Tired of seeing these warnings import warnings @@ -34,9 +34,10 @@ def main(config: DictConfig): """Runs training""" # Imports should be nested inside @hydra.main to optimize tab 
completion # Read more here: https://github.com/facebookresearch/hydra/issues/934 - from pvnet_summation.training import train from pvnet.utils import extras, print_config + from pvnet_summation.training import train + # A couple of optional utilities: # - disabling python warnings # - easier access to debug mode diff --git a/tests/conftest.py b/tests/conftest.py index 1ffb578..dfdffa7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,44 +17,41 @@ from pvnet_summation.data.datamodule import DataModule - @pytest.fixture() def sample_data(): - # Copy small batches to fake 317 GSPs in each with tempfile.TemporaryDirectory() as tmpdirname: os.makedirs(f"{tmpdirname}/train") os.makedirs(f"{tmpdirname}/val") - + # Grab times from batch to make national output zarr times = [] - + file_n = 0 for file in glob.glob("tests/data/sample_batches/train/*.pt"): - batch = torch.load(file) - + this_batch = {} for i in range(batch[BatchKey.gsp_time_utc].shape[0]): - # Duplicate sample to fake 317 GSPs + # Duplicate sample to fake 317 GSPs for key in batch.keys(): if isinstance(batch[key], torch.Tensor): n_dims = len(batch[key].shape) - repeats = (317,) + tuple(1 for dim in range(n_dims-1)) - this_batch[key] = batch[key][i:i+1].repeat(repeats)[:317] + repeats = (317,) + tuple(1 for dim in range(n_dims - 1)) + this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317] else: this_batch[key] = batch[key] - + # Save fopr both train and val torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt") torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt") - + file_n += 1 times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")] - + times = np.unique(np.sort(np.concatenate(times))) - + da_output = xr.DataArray( data=np.random.uniform(size=(len(times), 1)), dims=["datetime_gmt", "gsp_id"], @@ -63,7 +60,7 @@ def sample_data(): gsp_id=[0], ), ) - + da_cap = xr.DataArray( data=np.ones((len(times), 1)), dims=["datetime_gmt", "gsp_id"], @@ -72,7 +69,7 @@ def sample_data(): gsp_id=[0], ), ) - + ds = xr.Dataset( data_vars=dict( generation_mw=da_output, @@ -80,9 +77,9 @@ def sample_data(): capacity_mwp=da_cap, ), ) - + ds.to_zarr(f"{tmpdirname}/gsp.zarr") - + yield tmpdirname, f"{tmpdirname}/gsp.zarr" @@ -97,7 +94,7 @@ def sample_datamodule(sample_data): num_workers=0, prefetch_factor=2, ) - + return dm @@ -111,8 +108,8 @@ def sample_batch(sample_datamodule): def model_kwargs(): kwargs = dict( forecast_minutes=480, - model_name= "openclimatefix/pvnet_v2", - model_version= "898630f3f8cd4e8506525d813dd61c6d8de86144", + model_name="openclimatefix/pvnet_v2", + model_version="898630f3f8cd4e8506525d813dd61c6d8de86144", ) return kwargs @@ -126,4 +123,4 @@ def model(model_kwargs): @pytest.fixture() def quantile_model(model_kwargs): model = Model(output_quantiles=[0.1, 0.5, 0.9], **model_kwargs) - return model \ No newline at end of file + return model diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index b23c3bc..9aa2c27 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -13,6 +13,7 @@ def test_init(sample_data): prefetch_factor=2, ) + def test_iter(sample_data): batch_dir, gsp_zarr_dir = sample_data @@ -23,20 +24,21 @@ def test_iter(sample_data): num_workers=0, prefetch_factor=2, ) - + batch = next(iter(dm.train_dataloader())) - + # batch size is 2 - assert len(batch['pvnet_inputs'])==2 - + assert len(batch["pvnet_inputs"]) == 2 + # 317 GSPs in each sample # 21 timestamps for each GSP from -120 mins to +480 mins - assert 
batch['pvnet_inputs'][0][BatchKey.gsp_time_utc].shape==(317,21) - - assert batch['times'].shape==(2, 16) - - assert batch['national_targets'].shape==(2, 16) - + assert batch["pvnet_inputs"][0][BatchKey.gsp_time_utc].shape == (317, 21) + + assert batch["times"].shape == (2, 16) + + assert batch["national_targets"].shape == (2, 16) + + def test_iter_multiprocessing(sample_data): batch_dir, gsp_zarr_dir = sample_data @@ -47,15 +49,15 @@ def test_iter_multiprocessing(sample_data): num_workers=2, prefetch_factor=2, ) - + for batch in dm.train_dataloader(): # batch size is 2 - assert len(batch['pvnet_inputs'])==2 + assert len(batch["pvnet_inputs"]) == 2 # 317 GSPs in each sample # 21 timestamps for each GSP from -120 mins to +480 mins - assert batch['pvnet_inputs'][0][BatchKey.gsp_time_utc].shape==(317,21) + assert batch["pvnet_inputs"][0][BatchKey.gsp_time_utc].shape == (317, 21) - assert batch['times'].shape==(2, 16) + assert batch["times"].shape == (2, 16) - assert batch['national_targets'].shape==(2, 16) \ No newline at end of file + assert batch["national_targets"].shape == (2, 16) diff --git a/tests/test_end2end.py b/tests/test_end2end.py index 856f883..b054d01 100644 --- a/tests/test_end2end.py +++ b/tests/test_end2end.py @@ -3,4 +3,4 @@ def test_model_trainer_fit(model, sample_datamodule): trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True) - trainer.fit(model=model, datamodule=sample_datamodule) \ No newline at end of file + trainer.fit(model=model, datamodule=sample_datamodule)
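
For reference, a minimal sketch of the presave-then-train flow introduced in this series, assuming the module paths added above. The data paths, worker counts, and the fixed "pvnet_outputs" save directory are illustrative placeholders (the training script derives the save directory from model_name and model_version), and the PVNet model name and revision are the example values used in tests/conftest.py. This is a sketch of intended usage, not a verbatim excerpt of the training script.

import os

import torch
from tqdm import tqdm

from pvnet_summation.data.datamodule import DataModule, PVNetPresavedDataModule
from pvnet_summation.models.model import Model

# Placeholder paths - point these at the local concurrent-batch directory and
# the national PVLive zarr.
batch_dir = "/path/to/concurrent_batches"
gsp_zarr_path = "/path/to/pv_gsp.zarr"
save_dir = f"{batch_dir}/pvnet_outputs"  # placeholder; training.py builds this from model_name/model_version

# Summation model wrapping a frozen, pretrained PVNet (name/revision as in tests/conftest.py).
model = Model(
    forecast_minutes=480,
    model_name="openclimatefix/pvnet_v2",
    model_version="898630f3f8cd4e8506525d813dd61c6d8de86144",
)

# Step 1: stream the raw concurrent sample batches and save the frozen PVNet
# outputs to disk, mirroring the presaving loop in pvnet_summation/training.py.
datamodule = DataModule(
    batch_dir=batch_dir,
    gsp_zarr_path=gsp_zarr_path,
    batch_size=8,
    num_workers=0,
    prefetch_factor=2,
)
datamodule.batch_size = None  # skip batching so each yielded dict is a single concurrent sample

for dataloader_func, split in [
    (datamodule.train_dataloader, "train"),
    (datamodule.val_dataloader, "val"),
]:
    os.makedirs(f"{save_dir}/{split}", exist_ok=True)
    dataloader = dataloader_func(shuffle=False, add_filename=True)
    for sample in tqdm(dataloader):
        # Replace the heavy PVNet inputs with the much smaller PVNet predictions.
        sample["pvnet_outputs"] = model.predict_pvnet_batch([sample["pvnet_inputs"]])[0]
        del sample["pvnet_inputs"]
        filepath = sample.pop("filepath")
        torch.save(sample, f"{save_dir}{filepath.removeprefix(batch_dir)}")

# Step 2: train the summation head against the presaved outputs, which skips the
# expensive PVNet forward pass on every epoch.
presaved_datamodule = PVNetPresavedDataModule(
    batch_dir=save_dir,
    batch_size=8,
    num_workers=0,
    prefetch_factor=2,
)
batch = next(iter(presaved_datamodule.train_dataloader()))
y_hat = model(batch)  # forward() uses batch["pvnet_outputs"] directly when present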