Skip to content

Commit

Permalink
multiprocessing fix and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Nov 24, 2023
1 parent 325b402 commit f8bd3f9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
36 changes: 7 additions & 29 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import typer
import xarray as xr
import dask
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
Expand All @@ -38,7 +39,7 @@

import pvnet_app
from pvnet_app.utils import (
worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts
worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts, preds_to_dataarray
)
from pvnet_app.data import regrid_nwp_data, download_sat_data, download_nwp_data

Expand Down Expand Up @@ -92,31 +93,7 @@
sql_logger.addHandler(logging.NullHandler())

# ---------------------------------------------------------------------------
# HELPER FUNCTIONS


def preds_to_dataarray(preds, model, valid_times, gsp_ids):

if model.use_quantile_regression:
output_labels = model.output_quantiles
output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
else:
output_labels = ["forecast_mw"]
normed_preds = normed_preds[..., np.newaxis]

da = xr.DataArray(
data=preds,
dims=["gsp_id", "target_datetime_utc", "output_label"],
coords=dict(
gsp_id=gsp_ids,
target_datetime_utc=valid_times,
output_label=output_labels,
),
)

return da

# APP MAIN

def app(
t0=None,
Expand Down Expand Up @@ -144,6 +121,9 @@ def app(

if num_workers == -1:
num_workers = os.cpu_count() - 1
if num_workers>0:
# Without this line the dataloader will hang if multiple workers are used
dask.config.set(scheduler='single-threaded')

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using {num_workers} workers")
Expand Down Expand Up @@ -233,9 +213,7 @@ def app(
location_pipe=location_pipe,
t0_datapipe=t0_datapipe,
production=True,
check_satellite_no_zeros=False,
block_sat=True,
block_nwp=False,
check_satellite_no_zeros=True,
)
.batch(batch_size)
.map(stack_np_examples_into_batch)
Expand Down
27 changes: 25 additions & 2 deletions pvnet_app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,31 @@ def populate_data_config_sources(input_path, output_path):

with open(output_path, 'w') as outfile:
yaml.dump(config, outfile, default_flow_style=False)


def preds_to_dataarray(preds, model, valid_times, gsp_ids):
"""Put numpy array of predictions into a dataarray"""

if model.use_quantile_regression:
output_labels = model.output_quantiles
output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
else:
output_labels = ["forecast_mw"]
normed_preds = normed_preds[..., np.newaxis]

da = xr.DataArray(
data=preds,
dims=["gsp_id", "target_datetime_utc", "output_label"],
coords=dict(
gsp_id=gsp_ids,
target_datetime_utc=valid_times,
output_label=output_labels,
),
)
return da



def convert_dataarray_to_forecasts(
forecast_values_dataarray: xr.DataArray, session: Session, model_name: str, version: str
) -> list[ForecastSQL]:
Expand Down Expand Up @@ -159,4 +182,4 @@ def convert_dataarray_to_forecasts(

forecasts.append(forecast)

return forecasts
return forecasts
3 changes: 2 additions & 1 deletion tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tempfile
import zarr
import os
import logging

from nowcasting_datamodel.models.forecast import (
ForecastSQL,
Expand Down Expand Up @@ -38,7 +39,7 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest):
# Run prediction
# This import needs to come after the environ vars have been set
from pvnet_app.app import app
app(gsp_ids=list(range(1, 318)))
app(gsp_ids=list(range(1, 318)), num_workers=2)

# Check forecasts have been made
# (317 GSPs + 1 National + GSP-sum) = 319 forecasts
Expand Down

0 comments on commit f8bd3f9

Please sign in to comment.