Commit cc4707d

fix dataloader bug and make saved outputs reusable
dfulu committed Jul 21, 2023
1 parent 1f877d0 commit cc4707d
Showing 2 changed files with 37 additions and 26 deletions.
pvnet_summation/data/datamodule.py (8 changes: 4 additions & 4 deletions)
@@ -273,15 +273,15 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
 
         if self.batch_size is not None:
 
-            data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
-            data_pipeline = DictApply(
-                data_pipeline,
+            batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
+            batch_pipeline = DictApply(
+                batch_pipeline,
                 pvnet_outputs=torch.stack,
                 national_targets=torch.stack,
                 times=torch.stack,
             )
 
-        return data_pipeline
+        return batch_pipeline
 
     def train_dataloader(self, shuffle=True):
         """Construct train dataloader"""
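Note: the rename above appears to be the dataloader fix — batching is now applied to the upstream sample_pipeline, and the batched result is kept in a separate batch_pipeline rather than rebinding data_pipeline. PivotDictList and DictApply are defined elsewhere in the repo; the sketch below illustrates the pivot-and-stack behaviour this step relies on, with plain functions standing in for the datapipe wrappers (their semantics are assumed from context, not taken from this diff).

import torch

def pivot_dict_list(samples):
    # Turn a list of per-sample dicts into one dict of lists,
    # e.g. [{"a": x1}, {"a": x2}] -> {"a": [x1, x2]}
    return {key: [s[key] for s in samples] for key in samples[0]}

def dict_apply(batch, **fns):
    # Apply a function to selected keys of the batch dict, leaving other keys untouched
    return {k: fns[k](v) if k in fns else v for k, v in batch.items()}

# Two toy samples standing in for the sample pipeline output
samples = [
    {"pvnet_outputs": torch.zeros(3), "national_targets": torch.zeros(2), "times": torch.zeros(1)},
    {"pvnet_outputs": torch.ones(3), "national_targets": torch.ones(2), "times": torch.ones(1)},
]

batch = pivot_dict_list(samples)
batch = dict_apply(batch, pvnet_outputs=torch.stack, national_targets=torch.stack, times=torch.stack)
print(batch["pvnet_outputs"].shape)  # torch.Size([2, 3])

Here torch.stack gives each value a leading batch dimension, so downstream code sees one tensor per key rather than a list of per-sample tensors.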
pvnet_summation/training.py (55 changes: 33 additions & 22 deletions)
@@ -71,38 +71,49 @@ def train(config: DictConfig) -> Optional[float]:
     # Presave batches
     if config.get("presave_pvnet_outputs", False):
 
-        # Set batch size to None so batching is skipped
-        datamodule.batch_size = None
-
         save_dir = (
             f"{config.datamodule.batch_dir}/"
             f"{config.model.model_name}/"
             f"{config.model.model_version}"
         )
 
-        log.info(f"Saving PVNet outputs to {save_dir}")
-
-        os.makedirs(f"{save_dir}/train")
-        os.makedirs(f"{save_dir}/val")
-
-        for dataloader_func, split in [
-            (datamodule.train_dataloader, "train"),
-            (datamodule.val_dataloader, "val")
-        ]:
-            log.info(f"Saving {split} outputs")
-            dataloader = dataloader_func(shuffle=False, add_filename=True)
-
-            for concurrent_sample_dict in tqdm(dataloader):
-                # Run through the model and remove the raw inputs
-                pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0]
-                del concurrent_sample_dict["pvnet_inputs"]
-                concurrent_sample_dict["pvnet_outputs"] = pvnet_out
-
-                # Save pvnet prediction sample
-                filepath = concurrent_sample_dict.pop("filepath")
-                sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir)
-                sample_path = f"{save_dir}{sample_rel_path}"
-                torch.save(concurrent_sample_dict, sample_path)
+        if os.path.isdir(save_dir):
+            log.info(
+                f"PVNet output directory already exists: {save_dir}\n"
+                "Skipping saving new outputs. The existing saved outputs will be loaded."
+            )
+
+        else:
+            log.info(f"Saving PVNet outputs to {save_dir}")
+
+            os.makedirs(f"{save_dir}/train")
+            os.makedirs(f"{save_dir}/val")
+
+            # Set batch size to None so batching is skipped
+            datamodule.batch_size = None
+
+            for dataloader_func, split in [
+                (datamodule.train_dataloader, "train"),
+                (datamodule.val_dataloader, "val")
+            ]:
+                log.info(f"Saving {split} outputs")
+                dataloader = dataloader_func(shuffle=False, add_filename=True)
+
+                for concurrent_sample_dict in tqdm(dataloader):
+                    # Run through the model and remove the raw inputs
+                    pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0]
+                    del concurrent_sample_dict["pvnet_inputs"]
+                    concurrent_sample_dict["pvnet_outputs"] = pvnet_out
+
+                    # Save pvnet prediction sample
+                    filepath = concurrent_sample_dict.pop("filepath")
+                    sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir)
+                    sample_path = f"{save_dir}{sample_rel_path}"
+                    torch.save(concurrent_sample_dict, sample_path)
 
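Note: the os.path.isdir guard is what makes the saved outputs reusable — previously a second run would hit os.makedirs on the existing directory and raise FileExistsError; now saving is skipped and the existing files are used. A minimal sketch of reading one presaved sample back, assuming the layout shown in the diff ({batch_dir}/{model_name}/{model_version}/{split}/...); the concrete directory name and glob pattern here are hypothetical.

import glob

import torch

# Hypothetical path following the save layout above
save_dir = "batches/pvnet_model/v1"

for sample_path in sorted(glob.glob(f"{save_dir}/train/*")):
    sample = torch.load(sample_path)
    # Each saved dict holds the precomputed PVNet outputs and the national
    # targets; the heavy "pvnet_inputs" were deleted before saving
    pvnet_outputs = sample["pvnet_outputs"]
    national_targets = sample["national_targets"]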
