From 413ea67bd6fbf0d7e7702985fa1d36642ddd483a Mon Sep 17 00:00:00 2001 From: James Fulton <41546094+dfulu@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:34:01 +0100 Subject: [PATCH] Refactor (#133) #minor * refactor * add no-satellite test --- pvnet_app/app.py | 187 +++------- pvnet_app/config.py | 139 +++++++ pvnet_app/data/satellite.py | 284 ++++++++------ pvnet_app/dataloader.py | 2 +- pvnet_app/forecast_compiler.py | 346 +++++++++++++++--- pvnet_app/utils.py | 258 ------------- tests/conftest.py | 1 - .../{test_data.py => data/test_satellite.py} | 43 ++- tests/test_app.py | 151 +++++--- 9 files changed, 792 insertions(+), 619 deletions(-) create mode 100644 pvnet_app/config.py delete mode 100644 pvnet_app/utils.py rename tests/{test_data.py => data/test_satellite.py} (82%) diff --git a/pvnet_app/app.py b/pvnet_app/app.py index 9ae4844..95656e3 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -1,13 +1,4 @@ -"""App to run inference - -This app expects these evironmental variables to be available: - - DB_URL - - NWP_UKV_ZARR_PATH - - NWP_ECMWF_ZARR_PATH - - SATELLITE_ZARR_PATH - - RUN_EXTRA_MODELS - - USE_ADJUSTER - - SAVE_GSP_SUM +"""App to run inference for PVNet models """ import logging @@ -24,27 +15,21 @@ from nowcasting_datamodel.connection import DatabaseConnection from nowcasting_datamodel.models.base import Base_Forecast from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities -from nowcasting_datamodel.save.save import save as save_sql_forecasts from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device from pvnet.models.base_model import BaseModel as PVNetBaseModel import sentry_sdk import pvnet_app +from pvnet_app.config import get_union_of_configs, load_yaml_config, save_yaml_config from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data from pvnet_app.data.satellite import ( download_all_sat_data, preprocess_sat_data, - check_model_inputs_available, + check_model_satellite_inputs_available, ) -from pvnet_app.forecast_compiler import ForecastCompiler -from pvnet_app.utils import ( - convert_dataarray_to_forecasts, - find_min_satellite_delay_config, - save_yaml_config, -) - from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader +from pvnet_app.forecast_compiler import ForecastCompiler # sentry @@ -87,14 +72,16 @@ # If summation_model_name is set to None, a simple sum is computed instead "summation": { "name": "openclimatefix/pvnet_v2_summation", - "version": os.getenv('PVNET_V2_SUMMATION_VERSION', - "ffac655f9650b81865d96023baa15839f3ce26ec"), + "version": os.getenv( + 'PVNET_V2_SUMMATION_VERSION', + "ffac655f9650b81865d96023baa15839f3ce26ec" + ), }, # Whether to use the adjuster for this model - for pvnet_v2 is set by environmental variable "use_adjuster": os.getenv("USE_ADJUSTER", "true").lower() == "true", # Whether to save the GSP sum for this model - for pvnet_v2 is set by environmental variable "save_gsp_sum": os.getenv("SAVE_GSP_SUM", "false").lower() == "true", - # Where to log information through prediction steps for this model + # Whether to log information through prediction steps for this model "verbose": True, "save_gsp_to_forecast_value_last_seven_days": True, }, @@ -228,9 +215,8 @@ def app( The following are options - PVNET_V2_VERSION, pvnet version, default is a version above - PVNET_V2_SUMMATION_VERSION, the pvnet version, default is above - - USE_SATELLITE, option to get satellite data. 
defaults to true - USE_ADJUSTER, option to use adjuster, defaults to true - - SAVE_GSP_SUM, option to save gsp sum, defaults to false + - SAVE_GSP_SUM, option to save gsp sum for pvnet_v2, defaults to false - RUN_EXTRA_MODELS, option to run extra models, defaults to false - DAY_AHEAD_MODEL, option to use day ahead model, defaults to false - SENTRY_DSN, optional link to sentry @@ -253,26 +239,27 @@ def app( dask.config.set(scheduler="single-threaded") use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true" - use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true" - logger.info(f"Using satellite data: {use_satellite}") - logger.info(f"Using day ahead model: {use_day_ahead_model}") - - if use_day_ahead_model: - logger.info(f"Using day ahead PVNet model") - + logger.info(f"Using `pvnet` library version: {pvnet.__version__}") logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}") logger.info(f"Using {num_workers} workers") + logger.info(f"Using day ahead model: {use_day_ahead_model}") + # Filter the models to be run if use_day_ahead_model: - logger.info(f"Using adjduster: {day_ahead_model_dict['pvnet_day_ahead']['use_adjuster']}") - logger.info(f"Saving GSP sum: {day_ahead_model_dict['pvnet_day_ahead']['save_gsp_sum']}") - + model_to_run_dict = day_ahead_model_dict + main_model_key = "pvnet_day_ahead" else: - logger.info(f"Using adjduster: {models_dict['pvnet_v2']['use_adjuster']}") - logger.info(f"Saving GSP sum: {models_dict['pvnet_v2']['save_gsp_sum']}") - # Used for temporarily storing things + if os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false": + model_to_run_dict = {"pvnet_v2": models_dict["pvnet_v2"]} + else: + model_to_run_dict = models_dict + main_model_key = "pvnet_v2" + + logger.info(f"Using adjduster: {model_to_run_dict[main_model_key]['use_adjuster']}") + logger.info(f"Saving GSP sum: {model_to_run_dict[main_model_key]['save_gsp_sum']}") + temp_dir = tempfile.TemporaryDirectory() # --------------------------------------------------------------------------- @@ -291,9 +278,9 @@ def app( # --------------------------------------------------------------------------- # 1. Prepare data sources - logger.info("Loading GSP metadata") - # Get capacities from the database + logger.info("Loading capacities from the database") + db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False) with db_connection.get_session() as session: #  Pandas series of most recent GSP capacities @@ -305,18 +292,14 @@ def app( national_capacity = get_latest_gsp_capacities(session, [0])[0] # Download satellite data - if use_satellite: - logger.info("Downloading satellite data") - download_all_sat_data() - - # Preprocess the satellite data and record the delay of the most recent non-nan timestep - all_satellite_datetimes, data_freq_minutes = preprocess_sat_data( - t0, - use_legacy=use_day_ahead_model - ) + logger.info("Downloading satellite data") + sat_available = download_all_sat_data() + + # Preprocess the satellite data if available and store available timesteps + if not sat_available: + sat_datetimes = pd.DatetimeIndex([]) else: - all_satellite_datetimes = [] - data_freq_minutes = None + sat_datetimes = preprocess_sat_data(t0, use_legacy=use_day_ahead_model) # Download NWP data logger.info("Downloading NWP data") @@ -328,32 +311,23 @@ def app( # --------------------------------------------------------------------------- # 2. 
Set up models - if use_day_ahead_model: - model_to_run_dict = {"pvnet_day_ahead": day_ahead_model_dict["pvnet_day_ahead"]} - # Remove extra models if not configured to run them - elif os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false": - model_to_run_dict = {"pvnet_v2": models_dict["pvnet_v2"]} - else: - model_to_run_dict = models_dict - # Prepare all the models which can be run forecast_compilers = {} - data_config_filenames = [] - for model_name, model_config in model_to_run_dict.items(): + data_config_paths = [] + for model_key, model_config in model_to_run_dict.items(): # First load the data config - data_config_filename = PVNetBaseModel.get_data_config( + data_config_path = PVNetBaseModel.get_data_config( model_config["pvnet"]["name"], revision=model_config["pvnet"]["version"], ) # Check if the data available will allow the model to run - model_can_run = check_model_inputs_available( - data_config_filename, all_satellite_datetimes, t0, data_freq_minutes - ) + model_can_run = check_model_satellite_inputs_available(data_config_path, t0, sat_datetimes) if model_can_run: # Set up a forecast compiler for the model - forecast_compilers[model_name] = ForecastCompiler( + forecast_compilers[model_key] = ForecastCompiler( + model_tag=model_key, model_name=model_config["pvnet"]["name"], model_version=model_config["pvnet"]["version"], summation_name=model_config["summation"]["name"], @@ -362,22 +336,23 @@ def app( t0=t0, gsp_capacities=gsp_capacities, national_capacity=national_capacity, + apply_adjuster=model_config["use_adjuster"], + save_gsp_sum=model_config["save_gsp_sum"], + save_gsp_to_recent=model_config["save_gsp_to_forecast_value_last_seven_days"], verbose=model_config["verbose"], + use_legacy=use_day_ahead_model, ) # Store the config filename so we can create batches suitable for all models - data_config_filenames.append(data_config_filename) + data_config_paths.append(data_config_path) else: - warnings.warn(f"The model {model_name} cannot be run with input data available") + warnings.warn(f"The model {model_key} cannot be run with input data available") if len(forecast_compilers) == 0: raise Exception(f"No models were compatible with the available input data.") - # Find the config with satellite delay suitable for all models running - common_config = find_min_satellite_delay_config( - data_config_filenames, - use_satellite=use_satellite - ) + # Find the config with values suitable for running all models + common_config = get_union_of_configs(data_config_paths) # Save the commmon config common_config_path = f"{temp_dir.name}/common_config_path.yaml" @@ -405,7 +380,6 @@ def app( batch_size=batch_size, num_workers=num_workers, ) - # --------------------------------------------------------------------------- # Make predictions @@ -440,70 +414,11 @@ def app( logger.info("Writing to database") with db_connection.get_session() as session: - for model_name, forecast_compiler in forecast_compilers.items(): - sql_forecasts = convert_dataarray_to_forecasts( - forecast_compiler.da_abs_all, - session, - model_name=model_name, - version=pvnet_app.__version__, - ) - if model_to_run_dict[model_name]["save_gsp_to_forecast_value_last_seven_days"]: - - save_sql_forecasts( - forecasts=sql_forecasts, - session=session, - update_national=True, - update_gsp=True, - apply_adjuster=model_to_run_dict[model_name]["use_adjuster"], - ) - else: - # national - save_sql_forecasts( - forecasts=sql_forecasts[0:1], - session=session, - update_national=True, - update_gsp=False, - 
apply_adjuster=model_to_run_dict[model_name]["use_adjuster"], - ) - save_sql_forecasts( - forecasts=sql_forecasts[1:], - session=session, - update_national=False, - update_gsp=True, - apply_adjuster=model_to_run_dict[model_name]["use_adjuster"], - save_to_last_seven_days=False, - ) - - if model_to_run_dict[model_name]["save_gsp_sum"]: - # Compute the sum if we are logging the sume of GSPs independently - da_abs_sum_gsps = ( - forecast_compiler.da_abs_all.sel(gsp_id=slice(1, 317)) - .sum(dim="gsp_id") - # Only select the central forecast for the GSP sum. The sums of different p-levels - # are not a meaningful qauntities - .sel(output_label=["forecast_mw"]) - .expand_dims(dim="gsp_id", axis=0) - .assign_coords(gsp_id=[0]) - ) - - # Save the sum of GSPs independently - mainly for summation model monitoring - sql_forecasts = convert_dataarray_to_forecasts( - da_abs_sum_gsps, - session, - model_name=f"{model_name}_gsp_sum", - version=pvnet_app.__version__, - ) - - save_sql_forecasts( - forecasts=sql_forecasts, - session=session, - update_national=True, - update_gsp=False, - apply_adjuster=False, - ) + for forecast_compiler in forecast_compilers.values(): + forecast_compiler.log_forecast_to_database(session=session) - temp_dir.cleanup() - logger.info("Finished forecast") + temp_dir.cleanup() + logger.info("Finished forecast") if __name__ == "__main__": diff --git a/pvnet_app/config.py b/pvnet_app/config.py new file mode 100644 index 0000000..b66aaa1 --- /dev/null +++ b/pvnet_app/config.py @@ -0,0 +1,139 @@ +import yaml + +from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path + + +def load_yaml_config(path: str) -> dict: + """Load config file from path""" + with open(path) as file: + config = yaml.load(file, Loader=yaml.FullLoader) + return config + + +def save_yaml_config(config: dict, path: str) -> None: + """Save config file to path""" + with open(path, 'w') as file: + yaml.dump(config, file, default_flow_style=False) + + +def populate_config_with_data_data_filepaths(config: dict, gsp_path: str = "") -> dict: + """Populate the data source filepaths in the config + + Args: + config: The data config + gsp_path: For lagacy usage only + """ + + production_paths = { + "gsp": gsp_path, + "nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path}, + "satellite": sat_path, + } + + # Replace data sources + for source in ["gsp", "satellite"]: + if source in config["input_data"] : + if config["input_data"][source][f"{source}_zarr_path"]!="": + assert source in production_paths, f"Missing production path: {source}" + config["input_data"][source][f"{source}_zarr_path"] = production_paths[source] + + # NWP is nested so much be treated separately + if "nwp" in config["input_data"]: + nwp_config = config["input_data"]["nwp"] + for nwp_source in nwp_config.keys(): + if nwp_config[nwp_source]["nwp_zarr_path"]!="": + assert "nwp" in production_paths, "Missing production path: nwp" + assert nwp_source in production_paths["nwp"], f"Missing NWP path: {nwp_source}" + nwp_config[nwp_source]["nwp_zarr_path"] = production_paths["nwp"][nwp_source] + + return config + + +def overwrite_config_dropouts(config: dict) -> dict: + """Overwrite the config drouput parameters for production + + Args: + config: The data config + """ + + # Replace data sources + for source in ["satellite"]: + if source in config["input_data"] : + if config["input_data"][source][f"{source}_zarr_path"]!="": + config["input_data"][source][f"dropout_timedeltas_minutes"] = None + + # NWP is nested so much be treated separately + if "nwp" in 
config["input_data"]: + nwp_config = config["input_data"]["nwp"] + for nwp_source in nwp_config.keys(): + if nwp_config[nwp_source]["nwp_zarr_path"]!="": + nwp_config[nwp_source]["dropout_timedeltas_minutes"] = None + + return config + + +def modify_data_config_for_production( + input_path: str, + output_path: str, + gsp_path: str = "" +) -> None: + """Resave the data config with the data source filepaths and dropouts overwritten + + Args: + input_path: Path to input datapipes configuration file + output_path: Location to save the output configuration file + gsp_path: For lagacy usage only + """ + config = load_yaml_config(input_path) + + config = populate_config_with_data_data_filepaths(config, gsp_path=gsp_path) + config = overwrite_config_dropouts(config) + + save_yaml_config(config, output_path) + + +def get_union_of_configs(config_paths: list[str]) -> dict: + """Find the config which is able to run all models from a list of config paths + + Note that this implementation is very limited and will not work in general unless all models + have been trained on the same batches. We do not chck example if the satellite and NWP channels + are the same in the different configs, or whether the NWP time slices are the same. Many more + limitations not mentioned apply + """ + + # Load all the configs + configs = [load_yaml_config(config_path) for config_path in config_paths] + + # We will ammend this config according to the entries in the other configs + common_config = configs[0] + + for config in configs[1:]: + + if "satellite" in config["input_data"]: + + if "satellite" in common_config["input_data"]: + + # Find the minimum satellite delay across configs + common_config["input_data"]["satellite"]["live_delay_minutes"] = min( + common_config["input_data"]["satellite"]["live_delay_minutes"], + config["input_data"]["satellite"]["live_delay_minutes"] + ) + + + else: + # Add satellite to common config if not there already + common_config["input_data"]["satellite"] = config["input_data"]["satellite"] + + if "nwp" in config["input_data"]: + + # Add NWP to common config if not there already + if "nwp" not in common_config["input_data"]: + common_config["input_data"]["nwp"] = config["input_data"]["nwp"] + + else: + for nwp_key, nwp_conf in config["input_data"]["nwp"].items(): + # Add different NWP sources to common config if not there already + if nwp_key not in common_config["input_data"]["nwp"]: + common_config["input_data"]["nwp"][nwp_key] = nwp_conf + + return common_config \ No newline at end of file diff --git a/pvnet_app/data/satellite.py b/pvnet_app/data/satellite.py index b3ff264..fb5d8b9 100644 --- a/pvnet_app/data/satellite.py +++ b/pvnet_app/data/satellite.py @@ -1,11 +1,9 @@ import numpy as np import pandas as pd import xarray as xr -import xesmf as xe import logging import os import fsspec -from datetime import timedelta, datetime import ocf_blosc2 from ocf_datapipes.config.load import load_yaml_configuration @@ -13,71 +11,64 @@ logger = logging.getLogger(__name__) -this_dir = os.path.dirname(os.path.abspath(__file__)) - sat_5_path = "sat_5_min.zarr" sat_15_path = "sat_15_min.zarr" -def download_all_sat_data() -> None: - """Download the sat data""" +def download_all_sat_data() -> bool: + """Download the sat data and return whether it was successful + + Returns: + bool: Whether the download was successful + """ # Clean out old files os.system(f"rm -r {sat_path} {sat_5_path} {sat_15_path}") + # Set variable to track whether the satellite download is successful + sat_available = False + # 
download 5 minute satellite data - sat_download_path = os.environ["SATELLITE_ZARR_PATH"] - fs = fsspec.open(sat_download_path).fs - if fs.exists(sat_download_path): - fs.get(sat_download_path, "sat_5_min.zarr.zip") + sat_5_dl_path = os.environ["SATELLITE_ZARR_PATH"] + fs = fsspec.open(sat_5_dl_path).fs + if fs.exists(sat_5_dl_path): + sat_available = True + logger.info(f"Downloading 5-minute satellite data") + fs.get(sat_5_dl_path, "sat_5_min.zarr.zip") os.system(f"unzip -qq sat_5_min.zarr.zip -d {sat_5_path}") os.system(f"rm sat_5_min.zarr.zip") + else: + logger.info(f"No 5-minute data available") # Also download 15-minute satellite if it exists - sat_15_dl_path = ( - os.environ["SATELLITE_ZARR_PATH"] - .replace("sat.zarr", "sat_15.zarr") - .replace("latest.zarr", "latest_15.zarr") - ) + sat_15_dl_path = os.environ["SATELLITE_ZARR_PATH"].replace(".zarr", "_15.zarr") if fs.exists(sat_15_dl_path): - logger.info(f"Downloading 15-minute satellite data {sat_15_dl_path}") + sat_available = True + logger.info(f"Downloading 15-minute satellite data") fs.get(sat_15_dl_path, "sat_15_min.zarr.zip") - os.system(f"unzip sat_15_min.zarr.zip -d {sat_15_path}") + os.system(f"unzip -qq sat_15_min.zarr.zip -d {sat_15_path}") os.system(f"rm sat_15_min.zarr.zip") + else: + logger.info(f"No 15-minute data available") + + return sat_available -def _get_latest_time_and_mins_delay( - sat_zarr_path: str, - t0: pd.Timestamp, -) -> tuple[pd.Timestamp, int, pd.DatetimeIndex]: - """Get datetime info about the available satellite data +def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex: + """Get the datetimes of the satellite data Args: sat_zarr_path: The path to the satellite zarr - t0: The init-time of the forecast Returns: - pd.Timestamp: The most recent available satellite timestamp - int: The delay in minutes of the most recent timestamp pd.DatetimeIndex: All available satellite timestamps """ ds_sat = xr.open_zarr(sat_zarr_path) - latest_time = pd.to_datetime(ds_sat.time.max().item()) - delay = t0 - latest_time - delay_mins = int(delay.total_seconds() / 60) - all_datetimes = pd.to_datetime(ds_sat.time.values) - return latest_time, delay_mins, all_datetimes + return pd.to_datetime(ds_sat.time.values) -def combine_5_and_15_sat_data(t0) -> tuple[int, pd.DatetimeIndex]: +def combine_5_and_15_sat_data() -> None: """Select and/or combine the 5 and 15-minutely satellite data and move it to the expected path - - Args: - t0: The init-time of the forecast - - Returns: - int: The spacing between data samples in minutes - pd.DatetimeIndex: The available satellite timestamps """ # Check which satellite data exists @@ -89,88 +80,179 @@ def combine_5_and_15_sat_data(t0) -> tuple[int, pd.DatetimeIndex]: # Find the delay in the 5- and 15-minutely data if exists_5_minute: - latest_time_5, _, all_datetimes_5 = _get_latest_time_and_mins_delay( - sat_5_path, t0 - ) + datetimes_5min = get_satellite_timestamps(sat_5_path) logger.info( - f"Latest 5-minute timestamp is {latest_time_5} for t0 time {t0}. " - f"All the datetimes are {all_datetimes_5}" + f"Latest 5-minute timestamp is {datetimes_5min.max()}. " + f"All the datetimes are: \n{datetimes_5min}" ) else: - latest_time_5, all_datetimes_5 = datetime.min, [] logger.info("No 5-minute data was found.") if exists_15_minute: - latest_time_15, _, all_datetimes_15 = _get_latest_time_and_mins_delay( - sat_15_path, t0 - ) + datetimes_15min = get_satellite_timestamps(sat_15_path) logger.info( - f"Latest 5-minute timestamp is {latest_time_15} for t0 time {t0}. 
" - f"All the datetimes are {all_datetimes_15}" + f"Latest 5-minute timestamp is {datetimes_15min.max()}. " + f"All the datetimes are: \n{datetimes_15min}" ) else: - latest_time_15 = datetime.min logger.info("No 15-minute data was found.") + + # If both 5- and 15-minute data exists, use the most recent + if exists_5_minute and exists_15_minute: + use_5_minute = datetimes_5min.max() > datetimes_15min.max() + else: + # If only one exists, use that + use_5_minute = exists_5_minute - # Move the data with the most recent timestamp to the expected path - if latest_time_5 >= latest_time_15: + # Move the selected data to the expected path + if use_5_minute: logger.info(f"Using 5-minutely data.") os.system(f"mv {sat_5_path} {sat_path}") - data_freq_minutes = 5 - all_datetimes = all_datetimes_5 else: logger.info(f"Using 15-minutely data.") os.system(f"mv {sat_15_path} {sat_path}") - data_freq_minutes = 15 - all_datetimes = all_datetimes_15 - return data_freq_minutes, all_datetimes +def fill_1d_bool_gaps(x, max_gap): + """In a boolean array, fill consecutive False elements if their number is less than the gap_size + + Args: + x: A 1-dimensional boolean array + max_gap: integer of the maximum gap size which will be filled with True + + Returns: + A 1-dimensional boolean array + + Examples: + >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0]) + >>> fill_1d_bool_gaps(x, max_gap=2).astype(int) + array([0, 1, 1, 1, 1, 1, 1, 0]) + + >>> x = np.array([1, 0, 0, 0, 1, 0, 1, 0]) + >>> fill_1d_bool_gaps(x, max_gap=2).astype(int) + array([1, 0, 0, 0, 1, 1, 1, 0]) + """ + + should_fill = np.zeros(len(x), dtype=bool) + + i_start = None + + last_b = False + for i, b in enumerate(x): + if last_b and not b: + i_start = i + elif b and not last_b and i_start is not None: + if i - i_start <= max_gap: + should_fill[i_start:i] = True + i_start = None + last_b = b + + return np.logical_or(should_fill, x) + + +def interpolate_missing_satellite_timestamps(max_gap: pd.Timedelta) -> None: + """Interpolate missing satellite timestamps""" + + ds_sat = xr.open_zarr(sat_path) + + # If any of these times are missing, we will try to interpolate them + dense_times = pd.date_range( + ds_sat.time.values.min(), + ds_sat.time.values.max(), + freq="5min", + ) + + # Create mask array of which timestamps are available + timestamp_available = np.isin(dense_times, ds_sat.time) + + # If all the requested times are present we avoid running interpolation + if timestamp_available.all(): + logger.warning("No gaps in the available satllite sequence - no interpolation run") + return + + # If less than 2 of the buffer requested times are present we cannot infill + elif timestamp_available.sum() < 2: + logger.warning("Cannot run interpolate infilling with less than 2 time steps available") + return + + else: + logger.info("Some requested times are missing - running interpolation") + + # Compute before interpolation for efficiency + ds_sat = ds_sat.compute() + + # Run the interpolation to all 5-minute timestamps between the first and last + ds_interp = ds_sat.interp(time=dense_times, method="linear", assume_sorted=True) + + # Find the timestamps which are within max gap size + max_gap_steps = int(max_gap / pd.Timedelta("5min")) - 1 + valid_fill_times = fill_1d_bool_gaps(timestamp_available, max_gap_steps) + + # Mask the timestamps outside the max gap size + valid_fill_times_xr = xr.zeros_like(ds_interp.time, dtype=bool) + valid_fill_times_xr.values[:] = valid_fill_times + ds_sat = ds_interp.where(valid_fill_times_xr) + + time_was_filled = 
np.logical_and(valid_fill_times_xr, ~timestamp_available) + + if time_was_filled.any(): + infilled_times = time_was_filled.where(time_was_filled, drop=True) + logger.info( + "The following times were filled by interpolation:\n" + f"{infilled_times.time.values}" + ) + + if not valid_fill_times_xr.all(): + not_infilled_times = valid_fill_times_xr.where(~valid_fill_times_xr, drop=True) + logger.info( + "After interpolation the following times are still missing:\n" + f"{not_infilled_times.time.values}" + ) + + # Save the interpolated data + os.system(f"rm -rf {sat_path}") + ds_sat.to_zarr(sat_path) -def extend_satellite_data_with_nans(t0: pd.Timestamp) -> int: + +def extend_satellite_data_with_nans(t0: pd.Timestamp) -> None: """Fill the satellite data with NaNs out to time t0 Args: t0: The init-time of the forecast - - Returns: - int: The delay in minutes of the most recent timestamp """ # Find how delayed the satellite data is - _, delay_mins, _ = _get_latest_time_and_mins_delay(sat_path, t0) + ds_sat = xr.open_zarr(sat_path) + delay = t0 - pd.to_datetime(ds_sat.time).max() - if delay_mins > 0: - logger.info(f"Filling most recent {delay_mins} mins with NaNs") + if delay > pd.Timedelta(0): + logger.info(f"Filling most recent {delay} with NaNs") # Load into memory so we can delete it on disk - ds_sat = xr.open_zarr(sat_path).compute() + ds_sat = ds_sat.compute() - # Pad with zeros - fill_times = pd.date_range(t0 + timedelta(minutes=(-delay_mins + 5)), t0, freq="5min") + # We will fill the data with NaNs for these timestamps + fill_times = pd.date_range(t0 - delay + pd.Timedelta("5min"), t0, freq="5min") + # Extend the data with NaNs ds_sat = ds_sat.reindex(time=np.concatenate([ds_sat.time, fill_times]), fill_value=np.nan) # Re-save inplace os.system(f"rm -rf {sat_path}") ds_sat.to_zarr(sat_path) - return delay_mins - -def check_model_inputs_available( +def check_model_satellite_inputs_available( data_config_filename: str, - all_satellite_datetimes: pd.DatetimeIndex, - t0: pd.Timestamp, - data_freq_minutes: int, + t0: pd.Timestamp, + sat_datetimes: pd.DatetimeIndex, ) -> bool: """Checks whether the model can be run given the current satellite delay Args: data_config_filename: Path to the data configuration file - all_satellite_datetimes: All the satellite datetimes available t0: The init-time of the forecast - data_freq_minutes: The frequency of the satellite data. This can be 5 or 15 minutes. + available_sat_datetimes: The available satellite timestamps Returns: bool: Whether the satellite data satisfies that specified in the config @@ -182,7 +264,7 @@ def check_model_inputs_available( # check satellite if using if hasattr(data_config.input_data, "satellite"): - if data_config.input_data.satellite is not None: + if data_config.input_data.satellite: # Take into account how recently the model tries to slice satellite data from max_sat_delay_allowed_mins = data_config.input_data.satellite.live_delay_minutes @@ -194,37 +276,27 @@ def check_model_inputs_available( np.abs(data_config.input_data.satellite.dropout_timedeltas_minutes).max(), ) - # get start and end satellite times + # Get all expected datetimes history_minutes = data_config.input_data.satellite.history_minutes - # we only check every 15 minutes, as ocf_datapipes resample from 15 to 5 if necessary. 
- freq = f"{data_freq_minutes}min" - logger.info( - f"Checking satellite data for {t0=} with history {history_minutes=} " - f"and freq {freq=}, for {max_sat_delay_allowed_mins=}" - ) expected_datetimes = pd.date_range( - t0 - timedelta(minutes=int(history_minutes)), - t0 - timedelta(minutes=int(max_sat_delay_allowed_mins)), - freq=freq, + t0 - pd.Timedelta(f"{int(history_minutes)}min"), + t0 - pd.Timedelta(f"{int(max_sat_delay_allowed_mins)}min"), + freq="5min", ) - # Check if all expected datetimes are in the available satellite data - all_satellite_data_present = all( - [t in all_satellite_datetimes for t in expected_datetimes] - ) - if not all_satellite_data_present: - # log something. e,g x,y timestamps are missing - logger.info( - f"Missing satellite data for {expected_datetimes} in {all_satellite_datetimes}" - ) + # Check if any of the expected datetimes are missing + missing_time_steps = np.setdiff1d(expected_datetimes, sat_datetimes, assume_unique=True) - available = all_satellite_data_present + available = len(missing_time_steps) == 0 + + if len(missing_time_steps)>0: + logger.info(f"Some satellite timesteps for {t0=} missing: \n{missing_time_steps}") return available -def preprocess_sat_data(t0: pd.Timestamp, use_legacy: bool = False) -> tuple[pd.DatetimeIndex, int]: +def preprocess_sat_data(t0: pd.Timestamp, use_legacy: bool = False) -> pd.DatetimeIndex: """Combine and 5- and 15-minutely satellite data and extend to t0 if required Args: @@ -237,18 +309,24 @@ def preprocess_sat_data(t0: pd.Timestamp, use_legacy: bool = False) -> tuple[pd. """ # Deal with switching between the 5 and 15 minutely satellite data - data_freq_minutes, all_datetimes = combine_5_and_15_sat_data(t0) + combine_5_and_15_sat_data() + + # Interpolate missing satellite timestamps + interpolate_missing_satellite_timestamps(pd.Timedelta("15min")) - # Extend the satellite data with NaNs if needed by the model and record the delay of most recent - # non-nan timestamp - extend_satellite_data_with_nans(t0) - if not use_legacy: # scale the satellite data if not legacy. The legacy dataloader does production data scaling # inside it. 
The new dataloader does not scale_satellite_data() - return all_datetimes, data_freq_minutes + # Get the available satellite timestamps before we extend with NaNs + sat_timestamps = get_satellite_timestamps(sat_path) + + # Extend the satellite data with NaNs if needed by the model and record the delay of most recent + # non-nan timestamp + extend_satellite_data_with_nans(t0) + + return sat_timestamps def scale_satellite_data() -> None: diff --git a/pvnet_app/dataloader.py b/pvnet_app/dataloader.py index f20c473..ba380ee 100644 --- a/pvnet_app/dataloader.py +++ b/pvnet_app/dataloader.py @@ -5,7 +5,7 @@ from ocf_datapipes.batch import stack_np_examples_into_batch from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset -from pvnet_app.utils import modify_data_config_for_production +from pvnet_app.config import modify_data_config_for_production # Legacy imports - only used for legacy dataloader import os diff --git a/pvnet_app/forecast_compiler.py b/pvnet_app/forecast_compiler.py index 3a3a8a2..6d97d9d 100644 --- a/pvnet_app/forecast_compiler.py +++ b/pvnet_app/forecast_compiler.py @@ -1,28 +1,35 @@ +from datetime import timezone, datetime import warnings import logging -from datetime import timedelta import torch import numpy as np import pandas as pd import xarray as xr -from ocf_datapipes.batch import BatchKey +from ocf_datapipes.batch import BatchKey, NumpyBatch from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD import pvnet from pvnet.models.base_model import BaseModel as PVNetBaseModel from pvnet_summation.models.base_model import BaseModel as SummationBaseModel -from pvnet_app.utils import preds_to_dataarray +from sqlalchemy.orm import Session + +from nowcasting_datamodel.models import ForecastSQL, ForecastValue +from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location +from nowcasting_datamodel.read.read_models import get_model +from nowcasting_datamodel.save.save import save as save_sql_forecasts + +import pvnet_app logger = logging.getLogger(__name__) -# If the solar elevation is less than this the predictions are set to zero +# If the solar elevation (in degrees) is less than this the predictions are set to zero MIN_DAY_ELEVATION = 0 -_summation_mismatch_msg = ( +_model_mismatch_msg = ( "The PVNet version running in this app is {}/{}. The summation model running in this app was " "trained on outputs from PVNet version {}/{}. Combining these models may lead to an error if " "the shape of PVNet output doesn't match the expected shape of the summation model. 
Combining " @@ -34,19 +41,25 @@ class ForecastCompiler: """Class for making and compiling solar forecasts from for all GB GSPsn and national total""" def __init__( self, + model_tag: str, model_name: str, model_version: str, summation_name: str | None, - summation_version: str | None, + summation_version: str | None, device: torch.device, t0: pd.Timestamp, gsp_capacities: xr.DataArray, national_capacity: float, - verbose: bool = False + apply_adjuster: bool, + save_gsp_sum: bool, + save_gsp_to_recent: bool, + verbose: bool = False, + use_legacy: bool = False, ): """Class for making and compiling solar forecasts from for all GB GSPsn and national total Args: + model_tag: The name the model results will be saved to the database under model_name: Name of the huggingface repo where the PVNet model is stored model_version: Version of the PVNet model to run within the huggingface repo summation_name: Name of the huggingface repo where the summation model is stored @@ -55,56 +68,91 @@ def __init__( t0: The t0 time used to compile the results to numpy array gsp_capacities: DataArray of the solar capacities for all regional GSPs at t0 national_capacity: The national solar capacity at t0 + apply_adjuster: Whether to apply the adjuster when saving to database + save_gsp_sum: Whether to save the GSP sum + save_gsp_to_recent: Whether to save the GSP results to the + forecast_value_last_seven_days table verbose: Whether to log all messages throughout prediction and compilation + legacy: Whether to run legacy dataloader """ + + logger.info(f"Loading model: {model_name} - {model_version}") + + + # Store settings + self.model_tag = model_tag self.model_name = model_name self.model_version = model_version self.device = device - self.t0 = t0 self.gsp_capacities = gsp_capacities self.national_capacity = national_capacity + self.apply_adjuster = apply_adjuster + self.save_gsp_sum = save_gsp_sum + self.save_gsp_to_recent = save_gsp_to_recent self.verbose = verbose + self.use_legacy = use_legacy + + # Create stores for the predictions self.normed_preds = [] self.gsp_ids_each_batch = [] self.sun_down_masks = [] + # Load the GSP and summation models + self.model, self.summation_model = self.load_model( + model_name, + model_version, + summation_name, + summation_version, + device, + ) + # These are the valid times this forecast will predict for + self.valid_times = ( + t0 + pd.timedelta_range(start='30min', freq='30min', periods=self.model.forecast_len) + ) - logger.info(f"Loading model: {model_name} - {model_version}") - - self.model = PVNetBaseModel.from_pretrained( + @staticmethod + def load_model( + model_name: str, + model_version: str, + summation_name: str | None, + summation_version: str | None, + device: torch.device, + ): + """Load the GSP and summation models""" + + # Load the GSP level model + model = PVNetBaseModel.from_pretrained( model_id=model_name, revision=model_version, ).to(device) - + + # Load the summation model if summation_name is None: - self.summation_model = None + sum_model = None else: - self.summation_model = SummationBaseModel.from_pretrained( + sum_model = SummationBaseModel.from_pretrained( model_id=summation_name, revision=summation_version, ).to(device) - - if ( - (self.summation_model.pvnet_model_name, self.summation_model.pvnet_model_version) != - (model_name, model_version) - ): - warnings.warn( - _summation_mismatch_msg.format( - model_name, - model_version, - self.summation_model.pvnet_model_name, - self.summation_model.pvnet_model_version, - ) - ) - + + # Compare the 
current GSP model with the one the summation model was trained on + this_gsp_model = (model_name, model_version) + sum_expected_gsp_model = (sum_model.pvnet_model_name, sum_model.pvnet_model_version) + + if sum_expected_gsp_model!=this_gsp_model: + warnings.warn(_model_mismatch_msg.format(*this_gsp_model, *sum_expected_gsp_model)) + + return model, sum_model + - def log_info(self, message): + def log_info(self, message: str) -> None: """Maybe log message depending on verbosity""" if self.verbose: logger.info(message) - def predict_batch(self, batch): + + def predict_batch(self, batch: NumpyBatch) -> None: """Make predictions for a batch and store results internally""" self.log_info(f"Predicting for model: {self.model_name}-{self.model_version}") @@ -119,10 +167,16 @@ def predict_batch(self, batch): preds = self.model(batch).detach().cpu().numpy() # Calculate unnormalised elevation and sun-dowm mask - self.log_info("Zeroing predictions after sundown") - elevation = ( - batch[BatchKey.gsp_solar_elevation].cpu().numpy() * ELEVATION_STD + ELEVATION_MEAN - ) + self.log_info("Computing sundown mask") + if self.use_legacy: + # The old dataloader standardises the data + elevation = ( + batch[BatchKey.gsp_solar_elevation].cpu().numpy() * ELEVATION_STD + ELEVATION_MEAN + ) + else: + # The new dataloader normalises the data to [0, 1] + elevation = (batch[BatchKey.gsp_solar_elevation].cpu().numpy() - 0.5) * 180 + # We only need elevation mask for forecasted values, not history elevation = elevation[:, -preds.shape[1] :] sun_down_mask = elevation < MIN_DAY_ELEVATION @@ -136,22 +190,20 @@ def predict_batch(self, batch): self.log_info(f"Max prediction: {np.max(preds, axis=1)}") - def compile_forecasts(self): - """Compile all forecasts internally + def compile_forecasts(self) -> None: + """Compile all forecasts internally in a single DataArray - Compiles all the regional GSP-level forecasts, makes national forecast, and compiles all - into a Dataset + Steps: + - Compile all the GSP level forecasts + - Make national forecast + - Compile all forecasts into a DataArray stored inside the object as `da_abs_all` """ # Complie results from all batches normed_preds = np.concatenate(self.normed_preds) sun_down_masks = np.concatenate(self.sun_down_masks) gsp_ids_all_batches = np.concatenate(self.gsp_ids_each_batch).squeeze() - - n_times = normed_preds.shape[1] - - valid_times = pd.to_datetime([self.t0 + timedelta(minutes=30 * (i + 1)) for i in range(n_times)]) - + # Reorder GSPs which can end up shuffled if multiprocessing is used inds = gsp_ids_all_batches.argsort() @@ -160,14 +212,18 @@ def compile_forecasts(self): gsp_ids_all_batches = gsp_ids_all_batches[inds] # Merge batch results to xarray DataArray - da_normed = preds_to_dataarray(normed_preds, self.model, valid_times, gsp_ids_all_batches) + da_normed = self.preds_to_dataarray( + normed_preds, + self.model.output_quantiles, + gsp_ids_all_batches + ) da_sundown_mask = xr.DataArray( data=sun_down_masks, dims=["gsp_id", "target_datetime_utc"], coords=dict( gsp_id=gsp_ids_all_batches, - target_datetime_utc=valid_times, + target_datetime_utc=self.valid_times, ), ) @@ -201,11 +257,10 @@ def compile_forecasts(self): normed_national = self.summation_model(inputs).detach().squeeze().cpu().numpy() # Convert national predictions to DataArray - da_normed_national = preds_to_dataarray( + da_normed_national = self.preds_to_dataarray( normed_national[np.newaxis], - self.summation_model, - valid_times, - gsp_ids=[0] + self.summation_model.output_quantiles, + gsp_ids=[0], ) 
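        # The summation model output is wrapped as a DataArray with gsp_id=0 so that, after being
        # scaled by the national capacity below, it can be concatenated with the per-GSP forecasts
        # into `da_abs_all`.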
# Multiply normalised forecasts by capacities and clip negatives @@ -214,9 +269,196 @@ def compile_forecasts(self): # Apply sundown mask - All GSPs must be masked to mask national da_abs_national = da_abs_national.where(~da_sundown_mask.all(dim="gsp_id")).fillna(0.0) + self.log_info( + f"National forecast is {da_abs_national.sel(output_label='forecast_mw').values}" + ) + # Store the compiled predictions internally self.da_abs_all = xr.concat([da_abs_national, da_abs], dim="gsp_id") - self.log_info( - f"National forecast is {self.da_abs_all.sel(gsp_id=0, output_label='forecast_mw').values}" - ) \ No newline at end of file + + def preds_to_dataarray( + self, + preds: np.ndarray, + output_quantiles: list[float] | None, + gsp_ids: list[int], + ) -> xr.DataArray: + """Put numpy array of predictions into a dataarray""" + + if output_quantiles is not None: + output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in output_quantiles] + output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" + else: + output_labels = ["forecast_mw"] + preds = preds[..., np.newaxis] + + da = xr.DataArray( + data=preds, + dims=["gsp_id", "target_datetime_utc", "output_label"], + coords=dict( + gsp_id=gsp_ids, + target_datetime_utc=self.valid_times, + output_label=output_labels, + ), + ) + return da + + + def log_forecast_to_database(self, session: Session) -> None: + """Log the compiled forecast to the database""" + + self.log_info("Converting DataArray to list of ForecastSQL") + + sql_forecasts = self.convert_dataarray_to_forecasts( + self.da_abs_all, + session, + model_tag=self.model_tag, + version=pvnet_app.__version__, + ) + + self.log_info("Saving ForecastSQL to database") + + if self.save_gsp_to_recent: + + # Save all forecasts and save to last_seven_days table + save_sql_forecasts( + forecasts=sql_forecasts, + session=session, + update_national=True, + update_gsp=True, + apply_adjuster=self.apply_adjuster, + save_to_last_seven_days=True, + ) + else: + # Save national and save to last_seven_days table + save_sql_forecasts( + forecasts=sql_forecasts[0:1], + session=session, + update_national=True, + update_gsp=False, + apply_adjuster=self.apply_adjuster, + save_to_last_seven_days=True, + ) + + # Save GSP results but not to last_seven_dats table + save_sql_forecasts( + forecasts=sql_forecasts[1:], + session=session, + update_national=False, + update_gsp=True, + apply_adjuster=self.apply_adjuster, + save_to_last_seven_days=False, + ) + + if self.save_gsp_sum: + # Compute the sum if we are logging the sum of GSPs independently + da_abs_sum_gsps = ( + self.da_abs_all.sel(gsp_id=slice(1, 317)) + .sum(dim="gsp_id") + # Only select the central forecast for the GSP sum. The sums of different p-levels + # are not a meaningful qauntities + .sel(output_label=["forecast_mw"]) + .expand_dims(dim="gsp_id", axis=0) + .assign_coords(gsp_id=[0]) + ) + + # Save the sum of GSPs independently - mainly for summation model monitoring + gsp_sum_sql_forecasts = self.convert_dataarray_to_forecasts( + da_abs_sum_gsps, + session, + model_tag=f"{self.model_tag}_gsp_sum", + version=pvnet_app.__version__, + ) + + save_sql_forecasts( + forecasts=gsp_sum_sql_forecasts, + session=session, + update_national=True, + update_gsp=False, + apply_adjuster=False, + save_to_last_seven_days=True, + ) + + + + @staticmethod + def convert_dataarray_to_forecasts( + da_preds: xr.DataArray, session: Session, model_tag: str, version: str + ) -> list[ForecastSQL]: + """ + Make a ForecastSQL object from a DataArray. 
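        When called on `da_abs_all`, the first ForecastSQL in the returned list is the national
        forecast (gsp_id=0), followed by one per regional GSP; `log_forecast_to_database` relies
        on this ordering when slicing `sql_forecasts`.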
+ + Args: + da_preds: DataArray of forecasted values + session: Database session + model_key: the name of the model to saved to the database + version: The version of the model + Return: + List of ForecastSQL objects + """ + + assert "target_datetime_utc" in da_preds.coords + assert "gsp_id" in da_preds.coords + assert "forecast_mw" in da_preds.output_label + + # get last input data + input_data_last_updated = get_latest_input_data_last_updated(session=session) + + # get model name + model = get_model(name=model_tag, version=version, session=session) + + forecasts = [] + + for gsp_id in da_preds.gsp_id.values: + + # make forecast values + forecast_values = [] + + location = get_location(session=session, gsp_id=int(gsp_id)) + + da_gsp = da_preds.sel(gsp_id=gsp_id) + + for target_time in pd.to_datetime(da_gsp.target_datetime_utc.values): + + da_gsp_time = da_gsp.sel(target_datetime_utc=target_time) + + forecast_value_sql = ForecastValue( + target_time=target_time.replace(tzinfo=timezone.utc), + expected_power_generation_megawatts=( + da_gsp_time.sel(output_label="forecast_mw").item() + ), + ).to_orm() + + properties = {} + + if "forecast_mw_plevel_10" in da_gsp_time.output_label: + p10 = da_gsp_time.sel(output_label="forecast_mw_plevel_10").item() + # `p10` can be NaN if PVNet has probabilistic outputs and PVNet_summation + # doesn't, or vice versa. Do not log the value if NaN + if not np.isnan(p10): + properties["10"] = p10 + + if "forecast_mw_plevel_90" in da_gsp_time.output_label: + p90 = da_gsp_time.sel(output_label="forecast_mw_plevel_90").item() + + if not np.isnan(p90): + properties["90"] = p90 + + if len(properties)>0: + forecast_value_sql.properties = properties + + forecast_values.append(forecast_value_sql) + + # make forecast object + forecast = ForecastSQL( + model=model, + forecast_creation_time=datetime.now(tz=timezone.utc), + location=location, + input_data_last_updated=input_data_last_updated, + forecast_values=forecast_values, + historic=False, + ) + + forecasts.append(forecast) + + return forecasts \ No newline at end of file diff --git a/pvnet_app/utils.py b/pvnet_app/utils.py deleted file mode 100644 index 3e3b2c6..0000000 --- a/pvnet_app/utils.py +++ /dev/null @@ -1,258 +0,0 @@ -from datetime import timezone, datetime -import fsspec.asyn -import yaml -import os -import copy -import xarray as xr -import numpy as np -import pandas as pd -from sqlalchemy.orm import Session -import logging - -from nowcasting_datamodel.models import ( - ForecastSQL, - ForecastValue, -) -from nowcasting_datamodel.read.read import ( - get_latest_input_data_last_updated, - get_location, -) -from nowcasting_datamodel.read.read_models import get_model - -from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path - - -logger = logging.getLogger(__name__) - - - -def load_yaml_config(path: str) -> dict: - """Load config file from path""" - with open(path) as file: - config = yaml.load(file, Loader=yaml.FullLoader) - return config - - -def save_yaml_config(config: dict, path: str) -> None: - """Save config file to path""" - with open(path, 'w') as file: - yaml.dump(config, file, default_flow_style=False) - - -def populate_config_with_data_data_filepaths(config: dict, gsp_path: str = "") -> dict: - """Populate the data source filepaths in the config - - Args: - config: The data config - gsp_path: For lagacy usage only - """ - - production_paths = { - "gsp": gsp_path, - "nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path}, - "satellite": sat_path, - } - - # Replace data sources - for source in 
["gsp", "satellite"]: - if source in config["input_data"] : - if config["input_data"][source][f"{source}_zarr_path"]!="": - assert source in production_paths, f"Missing production path: {source}" - config["input_data"][source][f"{source}_zarr_path"] = production_paths[source] - - # NWP is nested so much be treated separately - if "nwp" in config["input_data"]: - nwp_config = config["input_data"]["nwp"] - for nwp_source in nwp_config.keys(): - if nwp_config[nwp_source]["nwp_zarr_path"]!="": - assert "nwp" in production_paths, "Missing production path: nwp" - assert nwp_source in production_paths["nwp"], f"Missing NWP path: {nwp_source}" - nwp_config[nwp_source]["nwp_zarr_path"] = production_paths["nwp"][nwp_source] - - return config - - -def overwrite_config_dropouts(config: dict) -> dict: - """Overwrite the config drouput parameters for production - - Args: - config: The data config - """ - - # Replace data sources - for source in ["satellite"]: - if source in config["input_data"] : - if config["input_data"][source][f"{source}_zarr_path"]!="": - config["input_data"][source][f"dropout_timedeltas_minutes"] = None - - # NWP is nested so much be treated separately - if "nwp" in config["input_data"]: - nwp_config = config["input_data"]["nwp"] - for nwp_source in nwp_config.keys(): - if nwp_config[nwp_source]["nwp_zarr_path"]!="": - nwp_config[nwp_source]["dropout_timedeltas_minutes"] = None - - return config - - -def modify_data_config_for_production( - input_path: str, - output_path: str, - gsp_path: str = "" -) -> None: - """Resave the data config with the data source filepaths and dropouts overwritten - - Args: - input_path: Path to input datapipes configuration file - output_path: Location to save the output configuration file - gsp_path: For lagacy usage only - """ - config = load_yaml_config(input_path) - - config = populate_config_with_data_data_filepaths(config, gsp_path=gsp_path) - config = overwrite_config_dropouts(config) - - save_yaml_config(config, output_path) - - -def find_min_satellite_delay_config(config_paths: list[str], use_satellite: bool = False) -> dict: - """Find the config with the minimum satallite delay across from list of config paths""" - - logger.info(f"Finding minimum satellite delay config from {config_paths}") - - # Load all the configs - configs = [load_yaml_config(config_path) for config_path in config_paths] - if not use_satellite: - logger.info("Not using satellite data, so returning first config") - return configs[0] - - min_sat_delay = np.inf - - for config in configs: - - if "satellite" in config["input_data"]: - min_sat_delay = min( - min_sat_delay, - config["input_data"]["satellite"]["live_delay_minutes"] - ) - - config = configs[0] - config["input_data"]["satellite"]["live_delay_minutes"] = min_sat_delay - return config - - -def preds_to_dataarray(preds, model, valid_times, gsp_ids): - """Put numpy array of predictions into a dataarray""" - - if model.use_quantile_regression: - output_labels = model.output_quantiles - output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles] - output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" - else: - output_labels = ["forecast_mw"] - preds = preds[..., np.newaxis] - - da = xr.DataArray( - data=preds, - dims=["gsp_id", "target_datetime_utc", "output_label"], - coords=dict( - gsp_id=gsp_ids, - target_datetime_utc=valid_times, - output_label=output_labels, - ), - ) - return da - - -def convert_dataarray_to_forecasts( - forecast_values_dataarray: xr.DataArray, 
session: Session, model_name: str, version: str -) -> list[ForecastSQL]: - """ - Make a ForecastSQL object from a DataArray. - - Args: - forecast_values_dataarray: Dataarray of forecasted values. Must have `target_datetime_utc` - `gsp_id`, and `output_label` coords. The `output_label` coords must have `"forecast_mw"` - as an element. - session: database session - model_name: the name of the model - version: the version of the model - Return: - List of ForecastSQL objects - """ - logger.debug("Converting DataArray to list of ForecastSQL") - - assert "target_datetime_utc" in forecast_values_dataarray.coords - assert "gsp_id" in forecast_values_dataarray.coords - assert "forecast_mw" in forecast_values_dataarray.output_label - - # get last input data - input_data_last_updated = get_latest_input_data_last_updated(session=session) - - # get model name - model = get_model(name=model_name, version=version, session=session) - - forecasts = [] - - for gsp_id in forecast_values_dataarray.gsp_id.values: - gsp_id = int(gsp_id) - # make forecast values - forecast_values = [] - - # get location - location = get_location(session=session, gsp_id=gsp_id) - - gsp_forecast_values_da = forecast_values_dataarray.sel(gsp_id=gsp_id) - - for target_time in pd.to_datetime(gsp_forecast_values_da.target_datetime_utc.values): - # add timezone - target_time_utc = target_time.replace(tzinfo=timezone.utc) - this_da = gsp_forecast_values_da.sel(target_datetime_utc=target_time) - - forecast_value_sql = ForecastValue( - target_time=target_time_utc, - expected_power_generation_megawatts=( - this_da.sel(output_label="forecast_mw").item() - ), - ).to_orm() - - forecast_value_sql.adjust_mw = 0.0 - - properties = {} - - if "forecast_mw_plevel_10" in gsp_forecast_values_da.output_label: - val = this_da.sel(output_label="forecast_mw_plevel_10").item() - # `val` can be NaN if PVNet has probabilistic outputs and PVNet_summation doesn't, - # or if PVNet_summation has probabilistic outputs and PVNet doesn't. 
- # Do not log the value if NaN - if not np.isnan(val): - properties["10"] = val - - if "forecast_mw_plevel_90" in gsp_forecast_values_da.output_label: - val = this_da.sel(output_label="forecast_mw_plevel_90").item() - - if not np.isnan(val): - properties["90"] = val - - if len(properties)>0: - forecast_value_sql.properties = properties - - forecast_values.append(forecast_value_sql) - - # make forecast object - forecast = ForecastSQL( - model=model, - forecast_creation_time=datetime.now(tz=timezone.utc), - location=location, - input_data_last_updated=input_data_last_updated, - forecast_values=forecast_values, - historic=False, - ) - - forecasts.append(forecast) - - return forecasts - - - - \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index a475669..0194e15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,7 +203,6 @@ def gsp_yields_and_systems(db_session, test_t0): installed_capacity_mw=123.0, ) - gsp_yield_sqls = [] # From 3 hours ago to 8.5 hours into future for minute in range(-3 * 60, 9 * 60, 30): gsp_yield_sql = GSPYield( diff --git a/tests/test_data.py b/tests/data/test_satellite.py similarity index 82% rename from tests/test_data.py rename to tests/data/test_satellite.py index 4801227..5a367bc 100644 --- a/tests/test_data.py +++ b/tests/data/test_satellite.py @@ -10,25 +10,23 @@ Note that I'm not sure these tests will work in parallel, due to files being saved in the same places """ +from datetime import datetime, timedelta + import os import tempfile -import pytest import zarr import numpy as np import pandas as pd import xarray as xr -from datetime import datetime, timedelta -from pvnet.models.base_model import BaseModel as PVNetBaseModel from pvnet_app.data.satellite import ( - check_model_inputs_available, download_all_sat_data, preprocess_sat_data, + check_model_satellite_inputs_available, sat_path, sat_5_path, sat_15_path, ) -from pvnet_app.app import models_dict def save_to_zarr_zip(ds, filename): @@ -37,14 +35,14 @@ def save_to_zarr_zip(ds, filename): ds.to_zarr(store, compute=True, mode="w", encoding=encoding, consolidated=True) -def check_timesteps(sat_path, expected_mins, skip_nans=False): +def check_timesteps(sat_path, expected_freq_mins): ds_sat = xr.open_zarr(sat_path) - if not isinstance(expected_mins, list): - expected_mins = [expected_mins] + if not isinstance(expected_freq_mins, list): + expected_freq_mins = [expected_freq_mins] dts = pd.to_datetime(ds_sat.time).diff()[1:] - assert (np.isin(dts, [np.timedelta64(m, "m") for m in expected_mins])).all(), dts + assert (np.isin(dts, [np.timedelta64(m, "m") for m in expected_freq_mins])).all(), dts def test_download_sat_5_data(sat_5_data): @@ -67,7 +65,7 @@ def test_download_sat_5_data(sat_5_data): assert not os.path.exists(sat_15_path) # Check the satellite data is 5-minutely - check_timesteps(sat_5_path, expected_mins=5) + check_timesteps(sat_5_path, expected_freq_mins=5) def test_download_sat_15_data(sat_15_data): @@ -91,7 +89,7 @@ def test_download_sat_15_data(sat_15_data): assert os.path.exists(sat_15_path) # Check the satellite data is 15-minutely - check_timesteps(sat_15_path, expected_mins=15) + check_timesteps(sat_15_path, expected_freq_mins=15) def test_download_sat_both_data(sat_5_data, sat_15_data): @@ -114,10 +112,10 @@ def test_download_sat_both_data(sat_5_data, sat_15_data): assert os.path.exists(sat_15_path) # Check this satellite data is 5-minutely - check_timesteps(sat_5_path, expected_mins=5) + check_timesteps(sat_5_path, expected_freq_mins=5) # Check this 
satellite data is 15-minutely - check_timesteps(sat_15_path, expected_mins=15) + check_timesteps(sat_15_path, expected_freq_mins=15) def test_preprocess_sat_data(sat_5_data, test_t0): @@ -138,7 +136,7 @@ def test_preprocess_sat_data(sat_5_data, test_t0): preprocess_sat_data(test_t0) # Check the satellite data is 5-minutely - check_timesteps(sat_path, expected_mins=5) + check_timesteps(sat_path, expected_freq_mins=5) def test_preprocess_sat_15_data(sat_15_data, test_t0): @@ -158,8 +156,8 @@ def test_preprocess_sat_15_data(sat_15_data, test_t0): preprocess_sat_data(test_t0) - # Check the satellite data being used is 15-minutely - check_timesteps(sat_path, expected_mins=15) + # We infill the satellite data to 5 minutes in the process step + check_timesteps(sat_path, expected_freq_mins=5) def test_preprocess_old_sat_5_data(sat_5_data_delayed, sat_15_data, test_t0): @@ -181,17 +179,18 @@ def test_preprocess_old_sat_5_data(sat_5_data_delayed, sat_15_data, test_t0): preprocess_sat_data(test_t0) - # Check the satellite data being used is 15-minutely - check_timesteps(sat_path, expected_mins=15) + # We infill the satellite data to 5 minutes in the process step + check_timesteps(sat_path, expected_freq_mins=5) + -def test_check_model_inputs_available(config_filename): +def test_check_model_satellite_inputs_available(config_filename): t0 = datetime(2023,1,1) sat_datetime_1 = pd.date_range(t0 - timedelta(minutes=120), t0- timedelta(minutes=5), freq="5min") sat_datetime_2 = pd.date_range(t0 - timedelta(minutes=120), t0 - timedelta(minutes=15), freq="5min") sat_datetime_3 = pd.date_range(t0 - timedelta(minutes=120), t0 - timedelta(minutes=35), freq="5min") - assert check_model_inputs_available(config_filename, sat_datetime_1, t0, 5 ) - assert check_model_inputs_available(config_filename, sat_datetime_2, t0, 5 ) - assert not check_model_inputs_available(config_filename, sat_datetime_3, t0,5 ) + assert check_model_satellite_inputs_available(config_filename, t0, sat_datetime_1) + assert check_model_satellite_inputs_available(config_filename, t0, sat_datetime_2) + assert not check_model_satellite_inputs_available(config_filename, t0, sat_datetime_3) diff --git a/tests/test_app.py b/tests/test_app.py index 8e0f81a..b61e8fb 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,7 +1,6 @@ import tempfile import zarr import os -import logging from nowcasting_datamodel.models.forecast import ( ForecastSQL, @@ -12,51 +11,60 @@ from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path from pvnet_app.data.satellite import sat_5_path, sat_15_path +from pvnet.models.base_model import BaseModel as PVNetBaseModel +from ocf_datapipes.config.load import load_yaml_configuration + + +def model_uses_satellite(model_config): + """Function to check model if model entry in model dictionary uses satellite data""" + + data_config_path = PVNetBaseModel.get_data_config( + model_config["pvnet"]["name"], + revision=model_config["pvnet"]["version"], + ) + data_config = load_yaml_configuration(data_config_path) + + return hasattr(data_config.input_data, "satellite") and data_config.input_data.satellite + def test_app( db_session, nwp_ukv_data, nwp_ecmwf_data, sat_5_data, gsp_yields_and_systems, me_latest ): - # Environment variable DB_URL is set in engine_url, which is called by db_session - # set NWP_ZARR_PATH - # save nwp_data to temporary file, and set NWP_ZARR_PATH - # SATELLITE_ZARR_PATH - # save sat_data to temporary file, and set SATELLITE_ZARR_PATH - # GSP data + """Test the app running the intraday models""" + 
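    # The test writes the fixture NWP and satellite data to a temporary directory, points the app
    # at those files via environment variables, runs the app for all 317 GSPs, and then checks
    # that the expected number of forecast rows were written to the database.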
     with tempfile.TemporaryDirectory() as tmpdirname:
+
+        os.chdir(tmpdirname)
+
         # The app loads sat and NWP data from environment variable
         # Save out data, and set paths as environmental variables
-        temp_nwp_path = f"{tmpdirname}/nwp_ukv.zarr"
+        temp_nwp_path = "temp_nwp_ukv.zarr"
         os.environ["NWP_UKV_ZARR_PATH"] = temp_nwp_path
         nwp_ukv_data.to_zarr(temp_nwp_path)
 
-        temp_nwp_path = f"{tmpdirname}/nwp_ecmwf.zarr"
+        temp_nwp_path = "temp_nwp_ecmwf.zarr"
         os.environ["NWP_ECMWF_ZARR_PATH"] = temp_nwp_path
         nwp_ecmwf_data.to_zarr(temp_nwp_path)
 
         # In production sat zarr is zipped
-        temp_sat_path = f"{tmpdirname}/sat.zarr.zip"
+        temp_sat_path = "temp_sat.zarr.zip"
         os.environ["SATELLITE_ZARR_PATH"] = temp_sat_path
-        store = zarr.storage.ZipStore(temp_sat_path, mode="x")
-        sat_5_data.to_zarr(store)
-        store.close()
+        with zarr.storage.ZipStore(temp_sat_path, mode="x") as store:
+            sat_5_data.to_zarr(store)
 
         # Set environmental variables
-        os.environ["SAVE_GSP_SUM"] = "True"
         os.environ["RUN_EXTRA_MODELS"] = "True"
+        os.environ["SAVE_GSP_SUM"] = "True"
+        os.environ["DAY_AHEAD_MODEL"] = "False"
 
         # Run prediction
-        # Thes import needs to come after the environ vars have been set
+        # These imports need to come after the environ vars have been set
         from pvnet_app.app import app, models_dict
 
         app(gsp_ids=list(range(1, 318)), num_workers=2)
 
-        os.system(f"rm {sat_5_path}")
-        os.system(f"rm {sat_15_path}")
-        os.system(f"rm -r {sat_path}")
-        os.system(f"rm -r {nwp_ukv_path}")
-        os.system(f"rm -r {nwp_ecmwf_path}")
         # Check correct number of forecasts have been made
         # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts
         # Forecast made with multiple models
@@ -78,9 +86,11 @@ def test_app(
 
         expected_forecast_results = 0
         for model_config in models_dict.values():
-            expected_forecast_results += 1  # national
+            # National
+            expected_forecast_results += 1
+            # GSP
             expected_forecast_results += 317 * model_config["save_gsp_to_forecast_value_last_seven_days"]
-            expected_forecast_results += model_config["save_gsp_sum"]  # gsp sum national
+            expected_forecast_results += model_config["save_gsp_sum"]  # optional Sum of GSPs
 
         assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == expected_forecast_results * 16
 
@@ -88,35 +98,26 @@
 def test_app_day_ahead_model(
     db_session, nwp_ukv_data, nwp_ecmwf_data, sat_5_data, gsp_yields_and_systems, me_latest
 ):
-    # Test app with day ahead model config
-    # Environment variable DB_URL is set in engine_url, which is called by db_session
-    # set NWP_ZARR_PATH
-    # save nwp_data to temporary file, and set NWP_ZARR_PATH
-    # SATELLITE_ZARR_PATH
-    # save sat_data to temporary file, and set SATELLITE_ZARR_PATH
-    # GSP data
+    """Test the app running the day ahead model"""
 
     with tempfile.TemporaryDirectory() as tmpdirname:
-        # The app loads sat and NWP data from environment variable
-        # Save out data, and set paths as environmental variables
-        temp_nwp_path = f"{tmpdirname}/nwp_ukv.zarr"
+
+        os.chdir(tmpdirname)
+
+        temp_nwp_path = "temp_nwp_ukv.zarr"
         os.environ["NWP_UKV_ZARR_PATH"] = temp_nwp_path
         nwp_ukv_data.to_zarr(temp_nwp_path)
 
-        temp_nwp_path = f"{tmpdirname}/nwp_ecmwf.zarr"
+        temp_nwp_path = "temp_nwp_ecmwf.zarr"
         os.environ["NWP_ECMWF_ZARR_PATH"] = temp_nwp_path
         nwp_ecmwf_data.to_zarr(temp_nwp_path)
 
-        # In production sat zarr is zipped
-        temp_sat_path = f"{tmpdirname}/sat.zarr.zip"
+        temp_sat_path = "temp_sat.zarr.zip"
        os.environ["SATELLITE_ZARR_PATH"] = temp_sat_path
-        store = zarr.storage.ZipStore(temp_sat_path, mode="x")
-        sat_5_data.to_zarr(store)
-        store.close()
+        with zarr.storage.ZipStore(temp_sat_path, mode="x") as store:
+            sat_5_data.to_zarr(store)
 
-        # Set environmental variables
         os.environ["DAY_AHEAD_MODEL"] = "True"
-        os.environ["SAVE_GSP_SUM"] = "True"
         os.environ["RUN_EXTRA_MODELS"] = "False"
 
         # Run prediction
@@ -125,11 +126,6 @@ def test_app_day_ahead_model(
 
         app(gsp_ids=list(range(1, 318)), num_workers=2)
 
-        os.system(f"rm {sat_5_path}")
-        os.system(f"rm {sat_15_path}")
-        os.system(f"rm -r {sat_path}")
-        os.system(f"rm -r {nwp_ukv_path}")
-        os.system(f"rm -r {nwp_ecmwf_path}")
         # Check correct number of forecasts have been made
         # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts
         # Forecast made with multiple models
@@ -145,7 +141,7 @@ def test_app_day_ahead_model(
         assert "90" in forecasts[0].forecast_values[0].properties
         assert "10" in forecasts[0].forecast_values[0].properties
 
-        # 318 GSPs * 72 time steps in forecast
+        # 72 time steps in forecast
         expected_forecast_timesteps = 72
 
         assert (
@@ -160,3 +156,66 @@ def test_app_day_ahead_model(
         len(db_session.query(ForecastValueSevenDaysSQL).all())
         == expected_forecast_results * expected_forecast_timesteps
     )
+
+def test_app_no_sat(
+    db_session, nwp_ukv_data, nwp_ecmwf_data, sat_5_data, gsp_yields_and_systems, me_latest
+):
+    """Test the app for the case when no satellite data is available"""
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+
+        os.chdir(tmpdirname)
+
+        temp_nwp_path = "temp_nwp_ukv.zarr"
+        os.environ["NWP_UKV_ZARR_PATH"] = temp_nwp_path
+        nwp_ukv_data.to_zarr(temp_nwp_path)
+
+        temp_nwp_path = "temp_nwp_ecmwf.zarr"
+        os.environ["NWP_ECMWF_ZARR_PATH"] = temp_nwp_path
+        nwp_ecmwf_data.to_zarr(temp_nwp_path)
+
+        # There is no satellite data available at the environ path
+        os.environ["SATELLITE_ZARR_PATH"] = "nonexistent_sat.zarr.zip"
+
+        os.environ["RUN_EXTRA_MODELS"] = "True"
+        os.environ["SAVE_GSP_SUM"] = "True"
+        os.environ["DAY_AHEAD_MODEL"] = "False"
+
+        # Run prediction
+        # These imports need to come after the environ vars have been set
+        from pvnet_app.app import app, models_dict
+
+        app(gsp_ids=list(range(1, 318)), num_workers=2)
+
+        # Only the models which don't use satellite will be run in this case
+        # The models below are the only ones which should have been run
+        no_sat_models_dict = {k: v for k, v in models_dict.items() if not model_uses_satellite(v)}
+
+        # Check correct number of forecasts have been made
+        # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts
+        # Forecast made with multiple models
+        expected_forecast_results = 0
+        for model_config in no_sat_models_dict.values():
+            expected_forecast_results += 318 + model_config["save_gsp_sum"]
+
+        forecasts = db_session.query(ForecastSQL).all()
+        # Doubled for historic and forecast
+        assert len(forecasts) == expected_forecast_results * 2
+
+        # Check probabilistic added
+        assert "90" in forecasts[0].forecast_values[0].properties
+        assert "10" in forecasts[0].forecast_values[0].properties
+
+        # 16 time steps in forecast
+        assert len(db_session.query(ForecastValueSQL).all()) == expected_forecast_results * 16
+        assert len(db_session.query(ForecastValueLatestSQL).all()) == expected_forecast_results * 16
+
+        expected_forecast_results = 0
+        for model_config in no_sat_models_dict.values():
+            # National
+            expected_forecast_results += 1
+            # GSP
+            expected_forecast_results += 317 * model_config["save_gsp_to_forecast_value_last_seven_days"]
+            expected_forecast_results += model_config["save_gsp_sum"]  # optional Sum of GSPs
+
+        assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == expected_forecast_results * 16
\ No newline at end of file