diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
index a0252258..55bee9ee 100644
--- a/scripts/save_concurrent_batches.py
+++ b/scripts/save_concurrent_batches.py
@@ -2,14 +2,16 @@
 Constructs batches where each batch includes all GSPs and only a single timestamp.
 
 Currently a slightly hacky implementation due to the way the configs are done. This script will use
-the same config file currently set to train the model.
+the same config file currently set to train the model. In the datamodule config file it is possible
+to set the batch_output_dir and the number of train/val batches; these can also be overridden on
+the command line as shown in the example below.
 
 use:
 ```
 python save_concurrent_batches.py \
-    +batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
-    +num_train_batches=20_000 \
-    +num_val_batches=4_000
+    datamodule.batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+    datamodule.num_train_batches=20_000 \
+    datamodule.num_val_batches=4_000
```
 
 """
@@ -157,12 +159,12 @@ def main(config: DictConfig):
     config_dm = config.datamodule
 
     # Set up directory
-    os.makedirs(config.batch_output_dir, exist_ok=False)
+    os.makedirs(config_dm.batch_output_dir, exist_ok=False)
 
-    with open(f"{config.batch_output_dir}/datamodule.yaml", "w") as f:
+    with open(f"{config_dm.batch_output_dir}/datamodule.yaml", "w") as f:
         f.write(OmegaConf.to_yaml(config.datamodule))
 
-    shutil.copyfile(config_dm.configuration, f"{config.batch_output_dir}/data_configuration.yaml")
+    shutil.copyfile(config_dm.configuration, f"{config_dm.batch_output_dir}/data_configuration.yaml")
 
     dataloader_kwargs = dict(
         shuffle=False,
@@ -179,39 +181,39 @@ def main(config: DictConfig):
         persistent_workers=False,
     )
 
-    if config.num_val_batches > 0:
+    if config_dm.num_val_batches > 0:
         print("----- Saving val batches -----")
-        os.mkdir(f"{config.batch_output_dir}/val")
+        os.mkdir(f"{config_dm.batch_output_dir}/val")
 
         val_batch_pipe = _get_datapipe(
             config_dm.configuration,
             *config_dm.val_period,
-            config.num_val_batches,
+            config_dm.num_val_batches,
         )
 
         _save_batches_with_dataloader(
             batch_pipe=val_batch_pipe,
-            batch_dir=f"{config.batch_output_dir}/val",
-            num_batches=config.num_val_batches,
+            batch_dir=f"{config_dm.batch_output_dir}/val",
+            num_batches=config_dm.num_val_batches,
             dataloader_kwargs=dataloader_kwargs,
         )
 
-    if config.num_train_batches > 0:
+    if config_dm.num_train_batches > 0:
         print("----- Saving train batches -----")
-        os.mkdir(f"{config.batch_output_dir}/train")
+        os.mkdir(f"{config_dm.batch_output_dir}/train")
 
         train_batch_pipe = _get_datapipe(
             config_dm.configuration,
             *config_dm.train_period,
-            config.num_train_batches,
+            config_dm.num_train_batches,
         )
 
         _save_batches_with_dataloader(
             batch_pipe=train_batch_pipe,
-            batch_dir=f"{config.batch_output_dir}/train",
-            num_batches=config.num_train_batches,
+            batch_dir=f"{config_dm.batch_output_dir}/train",
+            num_batches=config_dm.num_train_batches,
             dataloader_kwargs=dataloader_kwargs,
         )
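
The net effect of this diff is that `batch_output_dir`, `num_train_batches`, and `num_val_batches` move from ad-hoc top-level keys (previously injected with Hydra's `+key=...` override syntax) into the `datamodule` section of the config tree, which the script already binds as `config_dm = config.datamodule`. Below is a minimal sketch of how the relocated keys resolve, using OmegaConf directly rather than a full Hydra launch; the nested key names match the diff, but the values and the `configuration` path are placeholder assumptions, not the project's defaults:

```python
from omegaconf import OmegaConf

# Sketch (assumed values) of the new layout: the batch-output settings now
# live under `datamodule` in the same config file used to train the model.
base = OmegaConf.create(
    {
        "datamodule": {
            "configuration": "configs/datamodule/example_configuration.yaml",
            "batch_output_dir": "/tmp/concurrent_batches",
            "num_train_batches": 0,
            "num_val_batches": 0,
        }
    }
)

# Hydra-style dotted overrides such as `datamodule.num_val_batches=4000`
# merge into the existing nested section.
overrides = OmegaConf.from_dotlist(["datamodule.num_val_batches=4000"])
config = OmegaConf.merge(base, overrides)

config_dm = config.datamodule  # same access pattern as in main()
print(config_dm.batch_output_dir, config_dm.num_val_batches)
```

Because the keys now exist in the checked-in datamodule config rather than being created at launch time, plain `datamodule.key=value` overrides merge into them, which is why the `+` prefix disappears from the usage example in the docstring.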