Skip to content

Commit

Permalink
Update datamodule config parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukhil Patel authored and Sukhil Patel committed Jul 1, 2024
1 parent 8541d34 commit c9fed4b
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions scripts/save_concurrent_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
Constructs batches where each batch includes all GSPs and only a single timestamp.
Currently a slightly hacky implementation due to the way the configs are done. This script will use
the same config file currently set to train the model.
the same config file currently set to train the model. In the datamodule config file it is possible
to set the batch_output_dir and the number of train/val batches; they can also be overridden on the
command line, as shown in the example below.
use:
```
python save_concurrent_batches.py \
+batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
+num_train_batches=20_000 \
+num_val_batches=4_000
datamodule.batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
datamodule.num_train_batches=20_000 \
datamodule.num_val_batches=4_000
```
"""
Expand Down Expand Up @@ -157,12 +159,12 @@ def main(config: DictConfig):
config_dm = config.datamodule

# Set up directory
os.makedirs(config.batch_output_dir, exist_ok=False)
os.makedirs(config_dm.batch_output_dir, exist_ok=False)

with open(f"{config.batch_output_dir}/datamodule.yaml", "w") as f:
with open(f"{config_dm.batch_output_dir}/datamodule.yaml", "w") as f:
f.write(OmegaConf.to_yaml(config.datamodule))

shutil.copyfile(config_dm.configuration, f"{config.batch_output_dir}/data_configuration.yaml")
shutil.copyfile(config_dm.configuration, f"{config_dm.batch_output_dir}/data_configuration.yaml")

dataloader_kwargs = dict(
shuffle=False,
Expand All @@ -179,39 +181,39 @@ def main(config: DictConfig):
persistent_workers=False,
)

if config.num_val_batches > 0:
if config_dm.num_val_batches > 0:
print("----- Saving val batches -----")

os.mkdir(f"{config.batch_output_dir}/val")
os.mkdir(f"{config_dm.batch_output_dir}/val")

val_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.val_period,
config.num_val_batches,
config_dm.num_val_batches,
)

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
batch_dir=f"{config_dm.batch_output_dir}/val",
num_batches=config_dm.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
)

if config.num_train_batches > 0:
if config_dm.num_train_batches > 0:
print("----- Saving train batches -----")

os.mkdir(f"{config.batch_output_dir}/train")
os.mkdir(f"{config_dm.batch_output_dir}/train")

train_batch_pipe = _get_datapipe(
config_dm.configuration,
*config_dm.train_period,
config.num_train_batches,
config_dm.num_train_batches,
)

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
batch_dir=f"{config_dm.batch_output_dir}/train",
num_batches=config_dm.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
)

Expand Down

0 comments on commit c9fed4b

Please sign in to comment.