From 8bd2039d0d13fd1cde830f39540ba1c44c60a1f4 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Tue, 26 Nov 2024 16:28:47 +0000 Subject: [PATCH] add option to save batches --- README.md | 3 +++ india_forecast_app/models/pvnet/model.py | 4 ++++ india_forecast_app/models/pvnet/utils.py | 28 ++++++++++++++++++++++++ tests/models/pvnet/test_utils.py | 20 ++++++++++++++++++- 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 570908d..38a1f9f 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,9 @@ This example runs the application and writes the results to stdout DB_URL={DB_URL} NWP_ZARR_PATH={NWP_ZARR_PATH} poetry run app ``` +To save batches, set the `SAVE_BATCHES_DIR` environment variable to the directory +(a local path or any location fsspec can write to) where the batches should be saved. + ### Starting a local database using docker ```bash diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index 086ecb0..8a32f98 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -38,6 +38,7 @@ download_satellite_data, populate_data_config_sources, process_and_cache_nwp, + save_batch, set_night_time_zeros, worker_init_fn, ) @@ -92,6 +93,9 @@ def predict(self, site_id: str, timestamp: dt.datetime): with torch.no_grad(): for i, batch in enumerate(self.dataloader): log.info(f"Predicting for batch: {i}") + + # save batch + save_batch(batch=batch, i=i, model_name=self.name) # Run batch through model device_batch = copy_batch_to_device(batch_to_tensor(batch), DEVICE) diff --git a/india_forecast_app/models/pvnet/utils.py b/india_forecast_app/models/pvnet/utils.py index 980d3a5..47f1b87 100644 --- a/india_forecast_app/models/pvnet/utils.py +++ b/india_forecast_app/models/pvnet/utils.py @@ -1,9 +1,11 @@ """Useful functions for setting up PVNet model""" import logging import os +from typing import Optional import fsspec import numpy as np +import torch import xarray as xr import yaml from ocf_datapipes.batch import BatchKey @@
-191,3 +193,41 @@ def set_night_time_zeros(batch, preds, sun_elevation_limit=0.0):
     preds[sun_elevation < sun_elevation_limit] = 0
 
     return preds
+
+
+def save_batch(batch, i: int, model_name, save_batches_dir: Optional[str] = None):
+    """
+    Save batch to SAVE_BATCHES_DIR if set
+
+    The batch is written with torch.save to a temporary directory and then copied
+    to the destination with fsspec, so no copy is left in the working directory.
+
+    Args:
+        batch: The batch to save
+        i: The index of the batch
+        model_name: The name of the model, used in the saved file name
+        save_batches_dir: The directory to save the batch to,
+            defaults to environment variable SAVE_BATCHES_DIR.
+            If neither is set, nothing is saved
+    """
+    # imported locally so this function does not rely on a module-level import
+    import tempfile
+
+    if save_batches_dir is None:
+        save_batches_dir = os.getenv("SAVE_BATCHES_DIR", None)
+
+    if not save_batches_dir:
+        return
+
+    log.info(f"Saving batch {i} to {save_batches_dir}")
+
+    filename = f"batch_{i}_{model_name}.pt"
+
+    # Stage the file in a temporary directory instead of the current working directory:
+    # the staging copy is always cleaned up and the working directory need not be writable
+    with tempfile.TemporaryDirectory() as staging_dir:
+        local_path = os.path.join(staging_dir, filename)
+        torch.save(batch, local_path)
+
+        fs = fsspec.open(save_batches_dir).fs
+        fs.put(local_path, f"{save_batches_dir}/{filename}")
diff --git a/tests/models/pvnet/test_utils.py b/tests/models/pvnet/test_utils.py
index 762a7f0..95d6d8d 100644
--- a/tests/models/pvnet/test_utils.py
+++ b/tests/models/pvnet/test_utils.py
@@ -1,8 +1,11 @@
 """ Tests for utils for pvnet"""
+import os
+import tempfile
+
 import numpy as np
 from ocf_datapipes.batch import BatchKey
 
-from india_forecast_app.models.pvnet.utils import set_night_time_zeros
+from india_forecast_app.models.pvnet.utils import save_batch, set_night_time_zeros
 
 
 def test_set_night_time_zeros():
@@ -26,3 +29,19 @@ def test_set_night_time_zeros():
     assert np.all(preds[:, 2:, :] == 0)
     # check that all values are positive
     assert np.all(preds[:, :2, :] > 0)
+
+
+def test_save_batch():
+    """Check that a batch is written to the requested directory"""
+
+    # set up batch
+    batch = {"key": "value"}
+    i = 1
+    model_name = "test_model_name"
+
+    # create temp folder
+    with tempfile.TemporaryDirectory() as temp_dir:
+        save_batch(batch, i, model_name, save_batches_dir=temp_dir)
+
+        # check that batch is saved
+        assert os.path.exists(f"{temp_dir}/batch_{i}_{model_name}.pt")