Skip to content

Commit

Permalink
Begun integrating windnet
Browse files Browse the repository at this point in the history
  • Loading branch information
confusedmatrix committed Feb 7, 2024
1 parent ad23fd9 commit 7de8328
Show file tree
Hide file tree
Showing 28 changed files with 5,192 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ format:

.PHONY: test
test:
poetry run pytest tests
poetry run pytest tests -W ignore::DeprecationWarning

.PHONY: docker.build
docker.build:
Expand Down
12 changes: 9 additions & 3 deletions india_forecast_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pvsite_datamodel.write import insert_forecast_values
from sqlalchemy.orm import Session

from .model import DummyModel
from .models import DummyModel, PVNetModel

log = logging.getLogger(__name__)

Expand All @@ -35,18 +35,24 @@ def get_sites(db_session: Session) -> list[SiteSQL]:
return sites


def get_model(asset_type: str):
def get_model(asset_type: str, timestamp: dt.datetime) -> PVNetModel:
"""
Instantiates and returns the forecast model ready for running inference
Args:
asset_type: One or "pv" or "wind"
timestamp: Datetime at which the forecast will be made
Returns:
A forecasting model
"""

model = DummyModel(asset_type)
# Only windnet is ready, so if asset_type is PV, continue using dummy model
if asset_type == "wind":
model = PVNetModel(asset_type, timestamp)
else:
model = DummyModel(asset_type, timestamp)

return model


Expand Down
6 changes: 6 additions & 0 deletions india_forecast_app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Available models for India forecast"""

from .dummy import DummyModel
from .pvnet.model import PVNetModel

__all__ = ['DummyModel', 'PVNetModel']
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Model classes (currently just allows for loading a dummy model)
Dummy Model class (generate a dummy forecast)
"""

import datetime as dt
Expand All @@ -17,9 +17,10 @@ def version(self):
"""Version number"""
return "0.0.0"

def __init__(self, asset_type: str):
def __init__(self, asset_type: str, timestamp: dt.datetime):
"""Initializer for the model"""
self.asset_type = asset_type
self.to = timestamp

def predict(self, site_id: str, timestamp: dt.datetime):
"""Make a prediction for the model"""
Expand Down
153 changes: 153 additions & 0 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
PVNet model class
"""

import datetime as dt
import logging
import os
import tempfile

import fsspec
import torch
from ocf_datapipes.batch import stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline as pv_base_pipeline
from ocf_datapipes.training.windnet import construct_sliced_data_pipeline as wind_base_pipeline
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

from .utils import populate_data_config_sources, worker_init_fn

# Global settings for running the model

# Model will use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WIND_MODEL_NAME = os.getenv("WIND_MODEL_NAME", default="openclimatefix/windnet_india")
WIND_MODEL_VERSION = os.getenv("WIND_MODEL_VERSION",
default="c6af802823edc5e87b22df680b41b0dcdb4869e1")

PV_MODEL_NAME = os.getenv("WIND_MODEL_NAME", default="openclimatefix/pvnet_india")
PV_MODEL_VERSION = os.getenv("WIND_MODEL_VERSION",
default="d194488203375e766253f0d2961010356de52eb9")

BATCH_SIZE = 10

log = logging.getLogger(__name__)


class PVNetModel:
"""
Instantiates a PVNet model for inference
"""

@property
def name(self):
"""Model name"""

return WIND_MODEL_NAME if self.asset_type == "wind" else PV_MODEL_NAME

@property
def version(self):
"""Model version"""

return WIND_MODEL_VERSION if self.asset_type == "wind" else PV_MODEL_VERSION

def __init__(self, asset_type: str, timestamp: dt.datetime):
"""Initializer for the model"""

self.asset_type = asset_type
self.t0 = timestamp
self.setup()

def setup(self):
"""Sets up the model ready for inference"""

self._prepare_data_sources()
self._create_dataloader()
self._load_model()

def predict(self, site_id: str, timestamp: dt.datetime):
"""Make a prediction for the model"""

return []

def _prepare_data_sources(self):
"""Pull and prepare data sources required for inference"""

log.info("Preparing data sources")

nwp_source_file_path = os.getenv("NWP_ZARR_PATH", default="")
fs = fsspec.open(nwp_source_file_path).fs
fs.get(nwp_source_file_path, "nwp.zarr", recursive=True)

# TODO load historic wind data

def _create_dataloader(self):
"""Setup dataloader with prepared data sources"""

log.info("Creating dataloader")

# Pull the data config from huggingface
data_config_filename = PVNetBaseModel.get_data_config(
self.name,
revision=self.version,
)

# Populate the data config with production data paths
temp_dir = tempfile.TemporaryDirectory()
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"

populate_data_config_sources(data_config_filename, populated_data_config_filename)

# Location and time datapipes
# TODO not sure what to use here for the location pipe - site uuid/location?
location_pipe = IterableWrapper([1])
t0_datapipe = IterableWrapper([self.t0])
# t0_datapipe = IterableWrapper([self.t0]).repeat(len(location_pipe))

location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()

# Batch datapipe
base_pipeline = wind_base_pipeline if self.asset_type == "wind" else pv_base_pipeline
batch_datapipe = (
# TODO wind return dict, whereas PV returns IterDataPipe - need to resolve this
# Perhaps see https://github.com/openclimatefix/ocf_datapipes/blob/main/ocf_datapipes/training/windnet.py#L328
base_pipeline(
config_filename=populated_data_config_filename,
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=False # TODO was True, but threw error as expecting GSP key to be defined
)
.batch(BATCH_SIZE)
.map(stack_np_examples_into_batch)
)

n_workers = os.cpu_count() - 1

# Set up dataloader for parallel loading
dataloader_kwargs = dict(
shuffle=False,
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=n_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=worker_init_fn,
prefetch_factor=None if n_workers == 0 else 2,
persistent_workers=False,
)

self.dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)

def _load_model(self):
"""Load model"""

log.info(f"Loading model: {self.name} - {self.version}")
self.model = PVNetBaseModel.from_pretrained(
self.name,
revision=self.version
).to(DEVICE)
50 changes: 50 additions & 0 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Useful functions for setting up PVNet model"""

import fsspec
import yaml


def worker_init_fn(worker_id):
"""
Clear reference to the loop and thread.
This is a nasty hack that was suggested but NOT recommended by the lead fsspec developer!
This appears necessary otherwise gcsfs hangs when used after forking multiple worker processes.
Only required for fsspec >= 0.9.0
See:
- https://github.com/fsspec/gcsfs/issues/379#issuecomment-839929801
- https://github.com/fsspec/filesystem_spec/pull/963#issuecomment-1131709948
TODO: Try deleting this two lines to make sure this is still relevant.
"""
fsspec.asyn.iothread[0] = None
fsspec.asyn.loop[0] = None


def populate_data_config_sources(input_path, output_path):
"""Re-save the data config and replace the source filepaths
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
"""
with open(input_path) as infile:
config = yaml.load(infile, Loader=yaml.FullLoader)

production_paths = {
# "wind": os.environ["DB_URL"],
"nwp": {"ecmwf": "nwp.zarr"}
}

if "nwp" in config["input_data"]:
nwp_config = config["input_data"]["nwp"]
for nwp_source in nwp_config.keys():
if nwp_config[nwp_source]["nwp_zarr_path"] != "":
assert "nwp" in production_paths, "Missing production path: nwp"
assert nwp_source in production_paths["nwp"], f"Missing NWP path: {nwp_source}"
nwp_config[nwp_source]["nwp_zarr_path"] = production_paths["nwp"][nwp_source]

# We do not need to set wind/PV path right now. This currently done through datapipes
# TODO - Move the wind/PV path to here?

with open(output_path, 'w') as outfile:
yaml.dump(config, outfile, default_flow_style=False)
Loading

0 comments on commit 7de8328

Please sign in to comment.