Skip to content

Commit

Permalink
Merge pull request #16 from openclimatefix/log_gsp_sum
Browse files Browse the repository at this point in the history
add optional sum-of-GSP saving
  • Loading branch information
dfulu authored Oct 20, 2023
2 parents e19702c + bf00d72 commit be3e4bb
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
38 changes: 38 additions & 0 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
model_name_ocf_db = "pvnet_v2"
use_adjuster = os.getenv("USE_ADJUSTER", "True").lower() == "true"

# If environmental variable is true, the sum-of-GSPs will be computed and saved under a different
# model name. This can be useful to compare against the summation model and therefore monitor its
# performance in production
save_gsp_sum = os.getenv("SAVE_GSP_SUM", "False").lower() == "true"
gsp_sum_model_name_ocf_db = "pvnet_gsp_sum"

# ---------------------------------------------------------------------------
# LOGGER
formatter = logging.Formatter(
Expand Down Expand Up @@ -213,6 +219,8 @@ def app(

logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using {num_workers} workers")
logger.info(f"Using adjduster: {use_adjuster}")
logger.info(f"Saving GSP sum: {save_gsp_sum}")

# Allow environment overwrite of model
model_name = os.getenv("APP_MODEL", default=default_model_name)
Expand Down Expand Up @@ -491,6 +499,18 @@ def app(
logger.info(
f"National forecast is {da_abs.sel(gsp_id=0, output_label='forecast_mw').values}"
)

if save_gsp_sum:
# Compute the sum if we are logging the sume of GSPs independently
logger.info("Summing across GSPs to for independent sum-of-GSP saving")
da_abs_sum_gsps = (
da_abs.sum(dim="gsp_id")
# Only select the central forecast for the GSP sum. The sums of different p-levels
# are not a meaningful qauntities
.sel(output_label=["forecast_mw"])
.expand_dims(dim="gsp_id", axis=0)
.assign_coords(gsp_id=[0])
)

# ---------------------------------------------------------------------------
# Escape clause for making predictions locally
Expand All @@ -514,6 +534,24 @@ def app(
update_gsp=True,
apply_adjuster=apply_adjuster,
)

if save_gsp_sum:
# Save the sum of GSPs independently - mainly for summation model monitoring
sql_forecasts = convert_dataarray_to_forecasts(
da_abs_sum_gsps,
session,
model_name=gsp_sum_model_name_ocf_db,
version=pvnet_app.__version__
)

save_sql_forecasts(
forecasts=sql_forecasts,
session=session,
update_national=True,
update_gsp=False,
apply_adjuster=False,
)


logger.info("Finished forecast")

Expand Down
14 changes: 8 additions & 6 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pvnet_app.app import app
import tempfile
import zarr
import os
Expand Down Expand Up @@ -36,21 +35,24 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest):
# Set model version
os.environ["APP_MODEL_VERSION"] = "96ac8c67fa8663844ddcfa82aece51ef94f34453"
os.environ["APP_SUMMATION_MODEL_VERSION"] = "4a145d74c725ffc72f482025d3418659a6869c94"
os.environ["SAVE_GSP_SUM"] = "True"

# Run prediction
# This import needs to come after the environ vars have been set
from pvnet_app.app import app
app(gsp_ids=list(range(1, 318)))

# Check forecasts have been made
# (317 GSPs + 1 National) = 318 forecasts
# (317 GSPs + 1 National + GSP-sum) = 319 forecasts
# Doubled for historic and forecast
forecasts = db_session.query(ForecastSQL).all()
assert len(forecasts) == 318 * 2
assert len(forecasts) == 319 * 2

# Check probabilistic added
assert "90" in forecasts[0].forecast_values[0].properties
assert "10" in forecasts[0].forecast_values[0].properties

# 318 GSPs * 16 time steps in forecast
assert len(db_session.query(ForecastValueSQL).all()) == 318 * 16
assert len(db_session.query(ForecastValueLatestSQL).all()) == 318 * 16
assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 318 * 16
assert len(db_session.query(ForecastValueSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueLatestSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 319 * 16

0 comments on commit be3e4bb

Please sign in to comment.