Skip to content

Commit

Permalink
add presaving routine
Browse files Browse the repository at this point in the history
dfulu committed Jul 20, 2023
1 parent 8a42d1b commit 1493fc4
Showing 2 changed files with 53 additions and 4 deletions.
4 changes: 4 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -9,6 +9,10 @@ defaults:
- logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`)
- hydra: default.yaml

# Whether to loop through the PVNet outputs and save them out before training
presave_pvnet_outputs: True


# enable color logging
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog
53 changes: 49 additions & 4 deletions pvnet_summation/training.py
Original file line number Diff line number Diff line change
@@ -15,9 +15,12 @@
from lightning.pytorch.loggers import Logger
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from pvnet import utils

from pvnet_summation.data.datamodule import PVNetPresavedDataModule

log = utils.get_logger(__name__)

torch.set_default_dtype(torch.float32)
@@ -64,6 +67,51 @@ def train(config: DictConfig) -> Optional[float]:
# Init lightning model
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(config.model)

# Presave batches
if config.get("presave_pvnet_outputs", False):

# Set batch size to None so batching is skipped
datamodule.batch_size = None

save_dir = (
f"{config.datamodule.batch_dir}/"
f"{config.model.model_name}/"
f"{config.model.model_version}"
)

log.info(f"Saving PVNet outputs to {save_dir}")

os.makedirs(f"{save_dir}/train")
os.makedirs(f"{save_dir}/val")

for dataloader_func, split in [
(datamodule.train_dataloader, "train"),
(datamodule.val_dataloader, "val")
]:
log.info(f"Saving {split} outputs")
dataloader = dataloader_func(shuffle=False, add_filename=True)

for concurrent_sample_dict in tqdm(dataloader):
# Run the sample through the PVNet model and replace its inputs with the model outputs
pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0]
del concurrent_sample_dict["pvnet_inputs"]
concurrent_sample_dict["pvnet_outputs"] = pvnet_out

# Save pvnet prediction sample
filepath = concurrent_sample_dict.pop("filepath")
sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir)
sample_path = f"{save_dir}{sample_rel_path}"
torch.save(concurrent_sample_dict, sample_path)



datamodule = PVNetPresavedDataModule(
batch_dir=save_dir,
batch_size=config.datamodule.batch_size,
num_workers=config.datamodule.num_workers,
prefetch_factor=config.datamodule.prefetch_factor
)

# Init lightning loggers
loggers: list[Logger] = []
@@ -104,6 +152,7 @@ def train(config: DictConfig) -> Optional[float]:
OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")
break


trainer: Trainer = hydra.utils.instantiate(
config.trainer,
logger=loggers,
@@ -114,10 +163,6 @@ def train(config: DictConfig) -> Optional[float]:
# Train the model completely
trainer.fit(model=model, datamodule=datamodule)

if config.test_after_training:
# Evaluate model on test set, using the best model achieved during training
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path="best")

# Make sure everything closed properly
log.info("Finalizing!")

0 comments on commit 1493fc4

Please sign in to comment.