Skip to content

Commit

Permalink
Update config example parameters (#198)
Browse files Browse the repository at this point in the history
Update config example for batch creation
  • Loading branch information
Sukh-P authored May 17, 2024
1 parent f9831cd commit 130846c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,15 @@ Where `FULL-PATH-TO-REPO` represent the whole path to the PVNet repo on your loc

### Running the batch creation script

Run the save_batches.py script to create batches with the following example arguments as:
Run the save_batches.py script to create batches when the parameters are set in the datamodule config (`streamed_batches.yaml` in this example):

```
python scripts/save_batches.py datamodule=streamed_batches +batch_output_dir="./output" +num_train_batches=10 +num_val_batches=5
python scripts/save_batches.py
```
or with the following example arguments to override the config values:

```
python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
```

In this script the datamodule argument looks for a config under `PVNet/configs/datamodule`. The example configs provided are "premade_batches" and "streamed_batches".
Expand Down
5 changes: 5 additions & 0 deletions configs.example/datamodule/streamed_batches.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ configuration: "PLACEHOLDER.yaml"
num_workers: 20
prefetch_factor: 2
batch_size: 8
batch_output_dir: "PLACEHOLDER"
num_train_batches: 2
num_val_batches: 1


train_period:
- null
- "2022-05-07"
Expand Down
39 changes: 23 additions & 16 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
the same config file currently set to train the model.
use:
```
python save_batches.py
```
if setting all values in the datamodule config file, or
```
python save_batches.py \
+batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
datamodule.batch_output_dir="/mnt/disks/bigbatches/batches_v0" \
datamodule.batch_size=2 \
datamodule.num_workers=2 \
+num_train_batches=0 \
+num_val_batches=2
datamodule.num_train_batches=0 \
datamodule.num_val_batches=2
```
if you want to override these values, for example.
"""
# This is needed to get multiprocessing/multiple workers to behave
try:
Expand Down Expand Up @@ -110,12 +115,14 @@ def main(config: DictConfig):
print_config(config, resolve=False)

# Set up directory
os.makedirs(config.batch_output_dir, exist_ok=False)
os.makedirs(config_dm.batch_output_dir, exist_ok=False)

with open(f"{config.batch_output_dir}/datamodule.yaml", "w") as f:
f.write(OmegaConf.to_yaml(config.datamodule))
with open(f"{config_dm.batch_output_dir}/datamodule.yaml", "w") as f:
f.write(OmegaConf.to_yaml(config_dm))

shutil.copyfile(config_dm.configuration, f"{config.batch_output_dir}/data_configuration.yaml")
shutil.copyfile(
config_dm.configuration, f"{config_dm.batch_output_dir}/data_configuration.yaml"
)

dataloader_kwargs = dict(
shuffle=False,
Expand All @@ -132,8 +139,8 @@ def main(config: DictConfig):
persistent_workers=False,
)

if config.num_val_batches > 0:
os.mkdir(f"{config.batch_output_dir}/val")
if config_dm.num_val_batches > 0:
os.mkdir(f"{config_dm.batch_output_dir}/val")
print("----- Saving val batches -----")

val_batch_pipe = _get_datapipe(
Expand All @@ -145,14 +152,14 @@ def main(config: DictConfig):

_save_batches_with_dataloader(
batch_pipe=val_batch_pipe,
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
batch_dir=f"{config_dm.batch_output_dir}/val",
num_batches=config_dm.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
output_format="torch" if config.renewable == "pv" else "netcdf",
)

if config.num_train_batches > 0:
os.mkdir(f"{config.batch_output_dir}/train")
if config_dm.num_train_batches > 0:
os.mkdir(f"{config_dm.batch_output_dir}/train")
print("----- Saving train batches -----")

train_batch_pipe = _get_datapipe(
Expand All @@ -164,8 +171,8 @@ def main(config: DictConfig):

_save_batches_with_dataloader(
batch_pipe=train_batch_pipe,
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
batch_dir=f"{config_dm.batch_output_dir}/train",
num_batches=config_dm.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
output_format="torch" if config.renewable == "pv" else "netcdf",
)
Expand Down

0 comments on commit 130846c

Please sign in to comment.