diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 23279c0..b91d0c4 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -273,15 +273,15 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): if self.batch_size is not None: - data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size)) - data_pipeline = DictApply( - data_pipeline, + batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size)) + batch_pipeline = DictApply( + batch_pipeline, pvnet_outputs=torch.stack, national_targets=torch.stack, times=torch.stack, ) - return data_pipeline + return batch_pipeline def train_dataloader(self, shuffle=True): """Construct train dataloader""" diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 51ef5cf..5e65549 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -71,8 +71,7 @@ def train(config: DictConfig) -> Optional[float]: # Presave batches if config.get("presave_pvnet_outputs", False): - # Set batch size to None so batching is skipped - datamodule.batch_size = None + save_dir = ( f"{config.datamodule.batch_dir}/" @@ -80,29 +79,41 @@ def train(config: DictConfig) -> Optional[float]: f"{config.model.model_version}" ) - log.info(f"Saving PVNet outputs to {save_dir}") - os.makedirs(f"{save_dir}/train") - os.makedirs(f"{save_dir}/val") - for dataloader_func, split in [ - (datamodule.train_dataloader, "train"), - (datamodule.val_dataloader, "val") - ]: - log.info(f"Saving {split} outputs") - dataloader = dataloader_func(shuffle=False, add_filename=True) + if os.path.isdir(save_dir): + log.info( + f"PVNet output directory already exists: {save_dir}\n" + "Skipping saving new outputs. The existing saved outputs will be loaded." + ) + + else: + log.info(f"Saving PVNet outputs to {save_dir}") + + os.makedirs(f"{save_dir}/train") + os.makedirs(f"{save_dir}/val") - for concurrent_sample_dict in tqdm(dataloader): - # Run though model and remove - pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] - del concurrent_sample_dict["pvnet_inputs"] - concurrent_sample_dict["pvnet_outputs"] = pvnet_out - - # Save pvnet prediction sample - filepath = concurrent_sample_dict.pop("filepath") - sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) - sample_path = f"{save_dir}{sample_rel_path}" - torch.save(concurrent_sample_dict, sample_path) + # Set batch size to None so batching is skipped + datamodule.batch_size = None + + for dataloader_func, split in [ + (datamodule.train_dataloader, "train"), + (datamodule.val_dataloader, "val") + ]: + log.info(f"Saving {split} outputs") + dataloader = dataloader_func(shuffle=False, add_filename=True) + + for concurrent_sample_dict in tqdm(dataloader): + # Run though model and remove + pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0] + del concurrent_sample_dict["pvnet_inputs"] + concurrent_sample_dict["pvnet_outputs"] = pvnet_out + + # Save pvnet prediction sample + filepath = concurrent_sample_dict.pop("filepath") + sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir) + sample_path = f"{save_dir}{sample_rel_path}" + torch.save(concurrent_sample_dict, sample_path)