Refactor (#133) #minor
* refactor

* add no-satellite test
dfulu authored Sep 19, 2024
1 parent 264e040 commit 413ea67
Showing 9 changed files with 792 additions and 619 deletions.
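The commit message's second bullet adds a no-satellite test. That test lives elsewhere among this commit's 9 changed files, not in the app.py diff below; the sketch here only illustrates the scenario it exercises, with hypothetical names, assuming pytest:

import pandas as pd


def satellite_step(download_fn, preprocess_fn, t0):
    # Mirrors the new app.py logic below: download, then branch on availability
    sat_available = download_fn()
    if not sat_available:
        return pd.DatetimeIndex([])
    return preprocess_fn(t0)


def test_no_satellite_falls_back_to_empty_index():
    # Simulate download_all_sat_data() reporting that no data was found
    t0 = pd.Timestamp("2024-09-19 12:00")
    sat_datetimes = satellite_step(
        download_fn=lambda: False,
        preprocess_fn=lambda t0: pd.DatetimeIndex([t0]),
        t0=t0,
    )
    assert len(sat_datetimes) == 0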
187 changes: 51 additions & 136 deletions pvnet_app/app.py
@@ -1,13 +1,4 @@
"""App to run inference
This app expects these environment variables to be available:
- DB_URL
- NWP_UKV_ZARR_PATH
- NWP_ECMWF_ZARR_PATH
- SATELLITE_ZARR_PATH
- RUN_EXTRA_MODELS
- USE_ADJUSTER
- SAVE_GSP_SUM
"""App to run inference for PVNet models
"""

import logging
@@ -24,27 +15,21 @@
from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.models.base import Base_Forecast
from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
from nowcasting_datamodel.save.save import save as save_sql_forecasts
from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device
from pvnet.models.base_model import BaseModel as PVNetBaseModel
import sentry_sdk


import pvnet_app
from pvnet_app.config import get_union_of_configs, load_yaml_config, save_yaml_config
from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data
from pvnet_app.data.satellite import (
download_all_sat_data,
preprocess_sat_data,
check_model_inputs_available,
check_model_satellite_inputs_available,
)
from pvnet_app.forecast_compiler import ForecastCompiler
from pvnet_app.utils import (
convert_dataarray_to_forecasts,
find_min_satellite_delay_config,
save_yaml_config,
)

from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader
from pvnet_app.forecast_compiler import ForecastCompiler


# sentry
@@ -87,14 +72,16 @@
# If summation_model_name is set to None, a simple sum is computed instead
"summation": {
"name": "openclimatefix/pvnet_v2_summation",
"version": os.getenv('PVNET_V2_SUMMATION_VERSION',
"ffac655f9650b81865d96023baa15839f3ce26ec"),
"version": os.getenv(
'PVNET_V2_SUMMATION_VERSION',
"ffac655f9650b81865d96023baa15839f3ce26ec"
),
},
# Whether to use the adjuster for this model - for pvnet_v2 this is set by an environment variable
"use_adjuster": os.getenv("USE_ADJUSTER", "true").lower() == "true",
# Whether to save the GSP sum for this model - for pvnet_v2 this is set by an environment variable
"save_gsp_sum": os.getenv("SAVE_GSP_SUM", "false").lower() == "true",
# Where to log information through prediction steps for this model
# Whether to log information through prediction steps for this model
"verbose": True,
"save_gsp_to_forecast_value_last_seven_days": True,
},
@@ -228,9 +215,8 @@ def app(
The following are optional:
- PVNET_V2_VERSION, pvnet version, default is a version above
- PVNET_V2_SUMMATION_VERSION, the pvnet summation version, default is above
- USE_SATELLITE, option to get satellite data. defaults to true
- USE_ADJUSTER, option to use adjuster, defaults to true
- SAVE_GSP_SUM, option to save gsp sum, defaults to false
- SAVE_GSP_SUM, option to save gsp sum for pvnet_v2, defaults to false
- RUN_EXTRA_MODELS, option to run extra models, defaults to false
- DAY_AHEAD_MODEL, option to use day ahead model, defaults to false
- SENTRY_DSN, optional link to sentry
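Each boolean option above is read with the same os.getenv(...).lower() == "true" pattern that recurs through this diff. A minimal helper sketch of that pattern (env_flag itself is hypothetical and not part of the codebase):

import os


def env_flag(name: str, default: str = "false") -> bool:
    # Case-insensitive "true"/"false" parsing, matching the inline pattern in app.py
    return os.getenv(name, default).lower() == "true"


# e.g. use_day_ahead_model = env_flag("DAY_AHEAD_MODEL")
#      use_adjuster = env_flag("USE_ADJUSTER", default="true")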
@@ -253,26 +239,27 @@ def app(
dask.config.set(scheduler="single-threaded")

use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true"
use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
logger.info(f"Using satellite data: {use_satellite}")
logger.info(f"Using day ahead model: {use_day_ahead_model}")

if use_day_ahead_model:
logger.info(f"Using day ahead PVNet model")


logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}")
logger.info(f"Using {num_workers} workers")
logger.info(f"Using day ahead model: {use_day_ahead_model}")

# Filter the models to be run
if use_day_ahead_model:
logger.info(f"Using adjduster: {day_ahead_model_dict['pvnet_day_ahead']['use_adjuster']}")
logger.info(f"Saving GSP sum: {day_ahead_model_dict['pvnet_day_ahead']['save_gsp_sum']}")

model_to_run_dict = day_ahead_model_dict
main_model_key = "pvnet_day_ahead"
else:
logger.info(f"Using adjduster: {models_dict['pvnet_v2']['use_adjuster']}")
logger.info(f"Saving GSP sum: {models_dict['pvnet_v2']['save_gsp_sum']}")

# Used for temporarily storing things
if os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false":
model_to_run_dict = {"pvnet_v2": models_dict["pvnet_v2"]}
else:
model_to_run_dict = models_dict
main_model_key = "pvnet_v2"

logger.info(f"Using adjduster: {model_to_run_dict[main_model_key]['use_adjuster']}")
logger.info(f"Saving GSP sum: {model_to_run_dict[main_model_key]['save_gsp_sum']}")

temp_dir = tempfile.TemporaryDirectory()

# ---------------------------------------------------------------------------
@@ -291,9 +278,9 @@ def app(
# ---------------------------------------------------------------------------
# 1. Prepare data sources

logger.info("Loading GSP metadata")

# Get capacities from the database
logger.info("Loading capacities from the database")

db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False)
with db_connection.get_session() as session:
# Pandas series of most recent GSP capacities
@@ -305,18 +292,14 @@ def app(
national_capacity = get_latest_gsp_capacities(session, [0])[0]

# Download satellite data
if use_satellite:
logger.info("Downloading satellite data")
download_all_sat_data()

# Preprocess the satellite data and record the delay of the most recent non-nan timestep
all_satellite_datetimes, data_freq_minutes = preprocess_sat_data(
t0,
use_legacy=use_day_ahead_model
)
logger.info("Downloading satellite data")
sat_available = download_all_sat_data()

# Preprocess the satellite data if available and store available timesteps
if not sat_available:
sat_datetimes = pd.DatetimeIndex([])
else:
all_satellite_datetimes = []
data_freq_minutes = None
sat_datetimes = preprocess_sat_data(t0, use_legacy=use_day_ahead_model)
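download_all_sat_data now reports whether any satellite data was found, replacing the old USE_SATELLITE flag. A sketch of that contract, assuming the SATELLITE_ZARR_PATH variable from the old docstring still names the source (illustrative only, not the real body in pvnet_app.data.satellite):

import os


def download_all_sat_data_sketch() -> bool:
    # Return True only if satellite data was found and downloaded
    sat_path = os.getenv("SATELLITE_ZARR_PATH")
    if not sat_path:
        # Nothing configured: the app proceeds satellite-free
        return False
    # ... download the zarr store at sat_path to local disk here ...
    return True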

# Download NWP data
logger.info("Downloading NWP data")
@@ -328,32 +311,23 @@ def app(
# ---------------------------------------------------------------------------
# 2. Set up models

if use_day_ahead_model:
model_to_run_dict = {"pvnet_day_ahead": day_ahead_model_dict["pvnet_day_ahead"]}
# Remove extra models if not configured to run them
elif os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false":
model_to_run_dict = {"pvnet_v2": models_dict["pvnet_v2"]}
else:
model_to_run_dict = models_dict

# Prepare all the models which can be run
forecast_compilers = {}
data_config_filenames = []
for model_name, model_config in model_to_run_dict.items():
data_config_paths = []
for model_key, model_config in model_to_run_dict.items():
# First load the data config
data_config_filename = PVNetBaseModel.get_data_config(
data_config_path = PVNetBaseModel.get_data_config(
model_config["pvnet"]["name"],
revision=model_config["pvnet"]["version"],
)

# Check if the data available will allow the model to run
model_can_run = check_model_inputs_available(
data_config_filename, all_satellite_datetimes, t0, data_freq_minutes
)
model_can_run = check_model_satellite_inputs_available(data_config_path, t0, sat_datetimes)
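The renamed check narrows its scope from all inputs to satellite inputs only. A sketch of the kind of test it might apply; the 90-minute window and the zero-requirement shortcut are assumptions, since the real function reads the requirement from the model's data config:

import pandas as pd


def satellite_inputs_available_sketch(
    t0: pd.Timestamp,
    sat_datetimes: pd.DatetimeIndex,
    required_history_minutes: int = 90,
) -> bool:
    if required_history_minutes == 0:
        return True  # model does not use satellite data at all
    # Require at least one timestep inside the history window ending at t0
    window_start = t0 - pd.Timedelta(minutes=required_history_minutes)
    wanted = sat_datetimes[(sat_datetimes >= window_start) & (sat_datetimes <= t0)]
    return len(wanted) > 0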

if model_can_run:
# Set up a forecast compiler for the model
forecast_compilers[model_name] = ForecastCompiler(
forecast_compilers[model_key] = ForecastCompiler(
model_tag=model_key,
model_name=model_config["pvnet"]["name"],
model_version=model_config["pvnet"]["version"],
summation_name=model_config["summation"]["name"],
@@ -362,22 +336,23 @@ def app(
t0=t0,
gsp_capacities=gsp_capacities,
national_capacity=national_capacity,
apply_adjuster=model_config["use_adjuster"],
save_gsp_sum=model_config["save_gsp_sum"],
save_gsp_to_recent=model_config["save_gsp_to_forecast_value_last_seven_days"],
verbose=model_config["verbose"],
use_legacy=use_day_ahead_model,
)

# Store the config path so we can create batches suitable for all models
data_config_filenames.append(data_config_filename)
data_config_paths.append(data_config_path)
else:
warnings.warn(f"The model {model_name} cannot be run with input data available")
warnings.warn(f"The model {model_key} cannot be run with input data available")

if len(forecast_compilers) == 0:
raise Exception(f"No models were compatible with the available input data.")

# Find the config with satellite delay suitable for all models running
common_config = find_min_satellite_delay_config(
data_config_filenames,
use_satellite=use_satellite
)
# Find the config with values suitable for running all models
common_config = get_union_of_configs(data_config_paths)

# Save the common config
common_config_path = f"{temp_dir.name}/common_config_path.yaml"
@@ -405,7 +380,6 @@ def app(
batch_size=batch_size,
num_workers=num_workers,
)


# ---------------------------------------------------------------------------
# Make predictions
@@ -440,70 +414,11 @@ def app(
logger.info("Writing to database")

with db_connection.get_session() as session:
for model_name, forecast_compiler in forecast_compilers.items():
sql_forecasts = convert_dataarray_to_forecasts(
forecast_compiler.da_abs_all,
session,
model_name=model_name,
version=pvnet_app.__version__,
)
if model_to_run_dict[model_name]["save_gsp_to_forecast_value_last_seven_days"]:

save_sql_forecasts(
forecasts=sql_forecasts,
session=session,
update_national=True,
update_gsp=True,
apply_adjuster=model_to_run_dict[model_name]["use_adjuster"],
)
else:
# national
save_sql_forecasts(
forecasts=sql_forecasts[0:1],
session=session,
update_national=True,
update_gsp=False,
apply_adjuster=model_to_run_dict[model_name]["use_adjuster"],
)
save_sql_forecasts(
forecasts=sql_forecasts[1:],
session=session,
update_national=False,
update_gsp=True,
apply_adjuster=model_to_run_dict[model_name]["use_adjuster"],
save_to_last_seven_days=False,
)

if model_to_run_dict[model_name]["save_gsp_sum"]:
# Compute the sum if we are logging the sum of GSPs independently
da_abs_sum_gsps = (
forecast_compiler.da_abs_all.sel(gsp_id=slice(1, 317))
.sum(dim="gsp_id")
# Only select the central forecast for the GSP sum. The sums of different p-levels
# are not meaningful quantities
.sel(output_label=["forecast_mw"])
.expand_dims(dim="gsp_id", axis=0)
.assign_coords(gsp_id=[0])
)

# Save the sum of GSPs independently - mainly for summation model monitoring
sql_forecasts = convert_dataarray_to_forecasts(
da_abs_sum_gsps,
session,
model_name=f"{model_name}_gsp_sum",
version=pvnet_app.__version__,
)

save_sql_forecasts(
forecasts=sql_forecasts,
session=session,
update_national=True,
update_gsp=False,
apply_adjuster=False,
)
for forecast_compiler in forecast_compilers.values():
forecast_compiler.log_forecast_to_database(session=session)
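The lines deleted above, including the GSP-sum computation, now sit behind ForecastCompiler.log_forecast_to_database. The xarray chain they used is worth keeping in view; a runnable illustration with synthetic data (the target_datetime_utc dimension name is a guess; only the chain of operations comes from the deleted code):

import numpy as np
import xarray as xr

# Synthetic stand-in for forecast_compiler.da_abs_all: 3 GSPs, 1 timestep
da_abs_all = xr.DataArray(
    np.array([[[100.0]], [[40.0]], [[60.0]]]),
    dims=("gsp_id", "target_datetime_utc", "output_label"),
    coords={"gsp_id": [1, 2, 3], "output_label": ["forecast_mw"]},
)

# Same chain as the deleted code: sum GSPs 1-317, keep only the central
# forecast, and relabel the result as gsp_id 0 (national)
da_abs_sum_gsps = (
    da_abs_all.sel(gsp_id=slice(1, 317))
    .sum(dim="gsp_id")
    .sel(output_label=["forecast_mw"])
    .expand_dims(dim="gsp_id", axis=0)
    .assign_coords(gsp_id=[0])
)

assert float(da_abs_sum_gsps.squeeze()) == 200.0  # 100 + 40 + 60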

temp_dir.cleanup()
logger.info("Finished forecast")
temp_dir.cleanup()
logger.info("Finished forecast")


if __name__ == "__main__":