Skip to content

Commit

Permalink
add option to save bathces
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 26, 2024
1 parent 90008b7 commit 8bd2039
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ This example runs the application and writes the results to stdout
DB_URL={DB_URL} NWP_ZARR_PATH={NWP_ZARR_PATH} poetry run app
```

To save batches, you need to set the `SAVE_BATCHES_DIR` environment variable to directory.
```
### Starting a local database using docker
```bash
Expand Down
4 changes: 4 additions & 0 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
download_satellite_data,
populate_data_config_sources,
process_and_cache_nwp,
save_batch,
set_night_time_zeros,
worker_init_fn,
)
Expand Down Expand Up @@ -92,6 +93,9 @@ def predict(self, site_id: str, timestamp: dt.datetime):
with torch.no_grad():
for i, batch in enumerate(self.dataloader):
log.info(f"Predicting for batch: {i}")

# save batch
save_batch(batch=batch, i=i, model_name=self.name)

# Run batch through model
device_batch = copy_batch_to_device(batch_to_tensor(batch), DEVICE)
Expand Down
28 changes: 28 additions & 0 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Useful functions for setting up PVNet model"""
import logging
import os
from typing import Optional

import fsspec
import numpy as np
import torch
import xarray as xr
import yaml
from ocf_datapipes.batch import BatchKey
Expand Down Expand Up @@ -191,3 +193,29 @@ def set_night_time_zeros(batch, preds, sun_elevation_limit=0.0):
preds[sun_elevation < sun_elevation_limit] = 0

return preds


def save_batch(batch, i: int, model_name, save_batches_dir: Optional[str] = None):
"""
Save batch to SAVE_BATCHES_DIR if set
Args:
batch: The batch to save
i: The index of the batch
model_name: The name of the
save_batches_dir: The directory to save the batch to,
defaults to environment variable SAVE_BATCHES_DIR
"""

if save_batches_dir is None:
save_batches_dir = os.getenv("SAVE_BATCHES_DIR", None)

if save_batches_dir:
log.info(f"Saving batch {i} to {save_batches_dir}")

local_filename = f'batch_{i}_{model_name}.pt'
torch.save(batch, local_filename)

fs = fsspec.open(save_batches_dir).fs
fs.put(local_filename, f"{save_batches_dir}/{local_filename}")

20 changes: 19 additions & 1 deletion tests/models/pvnet/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
""" Tests for utils for pvnet"""
import numpy as np
import os
from ocf_datapipes.batch import BatchKey
import tempfile

from india_forecast_app.models.pvnet.utils import set_night_time_zeros
from india_forecast_app.models.pvnet.utils import set_night_time_zeros, save_batch


def test_set_night_time_zeros():

Check failure on line 10 in tests/models/pvnet/test_utils.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (I001)

tests/models/pvnet/test_utils.py:2:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -26,3 +28,19 @@ def test_set_night_time_zeros():
assert np.all(preds[:, 2:, :] == 0)
# check that all values are positive
assert np.all(preds[:, :2, :] > 0)


def test_save_batch():

Check failure on line 33 in tests/models/pvnet/test_utils.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (D103)

tests/models/pvnet/test_utils.py:33:5: D103 Missing docstring in public function

# set up batch
batch = {"key": "value"}
i = 1
model_name = "test_model_name"

# create temp folder
with tempfile.TemporaryDirectory() as temp_dir:
save_batch(batch, i, model_name, save_batches_dir=temp_dir)

# check that batch is saved
assert os.path.exists(f"{temp_dir}/batch_{i}_{model_name}.pt")

0 comments on commit 8bd2039

Please sign in to comment.