diff --git a/pvnet_app/app.py b/pvnet_app/app.py index bf872f0..3f48c49 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -29,6 +29,7 @@ ) from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader from pvnet_app.forecast_compiler import ForecastCompiler +from pvnet_app.model_configs.pydantic_models import get_all_models # sentry @@ -53,122 +54,6 @@ # Batch size used to make forecasts for all GSPs batch_size = 10 -# Dictionary of all models to run -# - The dictionary key will be used as the model name when saving to the database -# - The key "pvnet_v2" must be included -# - Batches are prepared only once, so the extra models must be able to run on the batches created -# to run the pvnet_v2 model -models_dict = { - - "pvnet_v2": { - # Huggingfacehub model repo and commit for PVNet (GSP-level model) - "pvnet": { - "name": "openclimatefix/pvnet_uk_region", - "version": os.getenv('PVNET_V2_VERSION', "ae0b8006841ac6227db873a1fc7f7331dc7dadb5"), - # We should only set PVNET_V2_VERSION in a short term solution, - # as its difficult to track which model is being used - }, - # Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model) - # If summation_model_name is set to None, a simple sum is computed instead - "summation": { - "name": "openclimatefix/pvnet_v2_summation", - "version": os.getenv( - 'PVNET_V2_SUMMATION_VERSION', - "ffac655f9650b81865d96023baa15839f3ce26ec" - ), - }, - # Whether to use the adjuster for this model - for pvnet_v2 is set by environmental variable - "use_adjuster": os.getenv("USE_ADJUSTER", "true").lower() == "true", - # Whether to save the GSP sum for this model - for pvnet_v2 is set by environmental variable - "save_gsp_sum": os.getenv("SAVE_GSP_SUM", "false").lower() == "true", - # Whether to log information through prediction steps for this model - "verbose": True, - "save_gsp_to_forecast_value_last_seven_days": True, - }, - - # Extra models which will be run on dev only - 
"pvnet_v2-sat0-samples-v1": { - "pvnet": { - "name": "openclimatefix/pvnet_uk_region", - "version": "8a7cc21b64d25ce1add7a8547674be3143b2e650", - }, - "summation": { - "name": "openclimatefix/pvnet_v2_summation", - "version": "dcfdc17fda8e48c387122614bec8b284eaa868b9", - }, - "use_adjuster": False, - "save_gsp_sum": False, - "verbose": False, - "save_gsp_to_forecast_value_last_seven_days": False, - }, - - # single source models - "pvnet_v2-sat0-only-samples-v1": { - "pvnet": { - "name": "openclimatefix/pvnet_uk_region", - "version": "d7ab648942c85b6788adcdbed44c91c4e1c5604a", - }, - "summation": { - "name": "openclimatefix/pvnet_v2_summation", - "version": "adbf9e7797fee9a5050beb8c13841696e72f99ef", - }, - "use_adjuster": False, - "save_gsp_sum": False, - "verbose": False, - "save_gsp_to_forecast_value_last_seven_days": False, - }, - - "pvnet_v2-ukv-only-samples-v1": { - "pvnet": { - "name": "openclimatefix/pvnet_uk_region", - "version": "eb73bf9a176a108f2e33b809f1f6993f893a4df9", - }, - "summation": { - "name": "openclimatefix/pvnet_v2_summation", - "version": "9002baf1e9dc1ec141f3c4a1fa8447b6316a4558", - }, - "use_adjuster": False, - "save_gsp_sum": False, - "verbose": False, - "save_gsp_to_forecast_value_last_seven_days": False, - }, - - "pvnet_v2-ecmwf-only-samples-v1": { - "pvnet": { - "name": "openclimatefix/pvnet_uk_region", - "version": "0bc344fafb2232fb0b6bb0bf419f0449fe11c643", - }, - "summation": { - "name": "openclimatefix/pvnet_v2_summation", - "version": "4fe6b1441b6dd549292c201ed85eee156ecc220c", - }, - "use_adjuster": False, - "save_gsp_sum": False, - "verbose": False, - "save_gsp_to_forecast_value_last_seven_days": False, - }, -} - -# The day ahead model has not yet been re-trained with data-sampler. 
-# It will be run with the legacy dataloader using ocf_datapipes
-day_ahead_model_dict = {
-    "pvnet_day_ahead": {
-        # Huggingfacehub model repo and commit for PVNet day ahead models
-        "pvnet": {
-            "name": "openclimatefix/pvnet_uk_region_day_ahead",
-            "version": "d87565731692a6003e43caac4feaed0f69e79272",
-        },
-        "summation": {
-            "name": "openclimatefix/pvnet_summation_uk_national_day_ahead",
-            "version": "ed60c5d32a020242ca4739dcc6dbc8864f783a08",
-        },
-        "use_adjuster": True,
-        "save_gsp_sum": True,
-        "verbose": True,
-        "save_gsp_to_forecast_value_last_seven_days": True,
-    },
-}
-
 # ---------------------------------------------------------------------------
 # LOGGER
 
@@ -221,6 +106,7 @@ def app(
     - DAY_AHEAD_MODEL, option to use day ahead model, defaults to false
     - SENTRY_DSN, optional link to sentry
     - ENVIRONMENT, the environment this is running in, defaults to local
+    - USE_ECMWF_ONLY, option to use ecmwf only model, defaults to false
 
     Args:
         t0 (datetime): Datetime at which forecast is made
@@ -239,26 +125,23 @@ def app(
     dask.config.set(scheduler="single-threaded")
 
     use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true"
+    use_ecmwf_only = os.getenv("USE_ECMWF_ONLY", "false").lower() == "true"
+    run_extra_models = os.getenv("RUN_EXTRA_MODELS", "false").lower() == "true"
 
     logger.info(f"Using `pvnet` library version: {pvnet.__version__}")
    logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}")
     logger.info(f"Using {num_workers} workers")
     logger.info(f"Using day ahead model: {use_day_ahead_model}")
+    logger.info(f"Using ecmwf only: {use_ecmwf_only}")
+    logger.info(f"Running extra models: {run_extra_models}")
 
-    # Filter the models to be run
-    if use_day_ahead_model:
-        model_to_run_dict = day_ahead_model_dict
-        main_model_key = "pvnet_day_ahead"
-    else:
-
-        if os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false":
-            model_to_run_dict = {"pvnet_v2": models_dict["pvnet_v2"]}
-        else:
-            model_to_run_dict = models_dict
-        
main_model_key = "pvnet_v2" + # load models + model_configs = get_all_models(get_ecmwf_only=use_ecmwf_only, + get_day_ahead_only=use_day_ahead_model, + run_extra_models=run_extra_models) - logger.info(f"Using adjduster: {model_to_run_dict[main_model_key]['use_adjuster']}") - logger.info(f"Saving GSP sum: {model_to_run_dict[main_model_key]['save_gsp_sum']}") + logger.info(f"Using adjuster: {model_configs[0].use_adjuster}") + logger.info(f"Saving GSP sum: {model_configs[0].save_gsp_sum}") temp_dir = tempfile.TemporaryDirectory() @@ -314,11 +197,11 @@ def app( # Prepare all the models which can be run forecast_compilers = {} data_config_paths = [] - for model_key, model_config in model_to_run_dict.items(): + for model_config in model_configs: # First load the data config data_config_path = PVNetBaseModel.get_data_config( - model_config["pvnet"]["name"], - revision=model_config["pvnet"]["version"], + model_config.pvnet.repo, + revision=model_config.pvnet.version, ) # Check if the data available will allow the model to run @@ -326,27 +209,19 @@ def app( if model_can_run: # Set up a forecast compiler for the model - forecast_compilers[model_key] = ForecastCompiler( - model_tag=model_key, - model_name=model_config["pvnet"]["name"], - model_version=model_config["pvnet"]["version"], - summation_name=model_config["summation"]["name"], - summation_version=model_config["summation"]["version"], + forecast_compilers[model_config.name] = ForecastCompiler( + model_config=model_config, device=device, t0=t0, gsp_capacities=gsp_capacities, national_capacity=national_capacity, - apply_adjuster=model_config["use_adjuster"], - save_gsp_sum=model_config["save_gsp_sum"], - save_gsp_to_recent=model_config["save_gsp_to_forecast_value_last_seven_days"], - verbose=model_config["verbose"], use_legacy=use_day_ahead_model, ) # Store the config filename so we can create batches suitable for all models data_config_paths.append(data_config_path) else: - warnings.warn(f"The model {model_key} cannot 
be run with input data available") + warnings.warn(f"The model {model_config.name} cannot be run with input data available") if len(forecast_compilers) == 0: raise Exception(f"No models were compatible with the available input data.") diff --git a/pvnet_app/forecast_compiler.py b/pvnet_app/forecast_compiler.py index 6d97d9d..330be30 100644 --- a/pvnet_app/forecast_compiler.py +++ b/pvnet_app/forecast_compiler.py @@ -1,27 +1,23 @@ -from datetime import timezone, datetime -import warnings import logging -import torch +import warnings +from datetime import timezone, datetime import numpy as np import pandas as pd +import torch import xarray as xr - +from nowcasting_datamodel.models import ForecastSQL, ForecastValue +from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location +from nowcasting_datamodel.read.read_models import get_model +from nowcasting_datamodel.save.save import save as save_sql_forecasts from ocf_datapipes.batch import BatchKey, NumpyBatch from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD - -import pvnet from pvnet.models.base_model import BaseModel as PVNetBaseModel from pvnet_summation.models.base_model import BaseModel as SummationBaseModel - from sqlalchemy.orm import Session -from nowcasting_datamodel.models import ForecastSQL, ForecastValue -from nowcasting_datamodel.read.read import get_latest_input_data_last_updated, get_location -from nowcasting_datamodel.read.read_models import get_model -from nowcasting_datamodel.save.save import save as save_sql_forecasts - import pvnet_app +from pvnet_app.model_configs.pydantic_models import Model logger = logging.getLogger(__name__) @@ -39,94 +35,80 @@ class ForecastCompiler: """Class for making and compiling solar forecasts from for all GB GSPsn and national total""" + def __init__( - self, - model_tag: str, - model_name: str, - model_version: str, - summation_name: str | None, - summation_version: str | None, - device: torch.device, - t0: pd.Timestamp, 
- gsp_capacities: xr.DataArray, - national_capacity: float, - apply_adjuster: bool, - save_gsp_sum: bool, - save_gsp_to_recent: bool, - verbose: bool = False, + self, + model_config: Model, + device: torch.device, + t0: pd.Timestamp, + gsp_capacities: xr.DataArray, + national_capacity: float, use_legacy: bool = False, ): """Class for making and compiling solar forecasts from for all GB GSPsn and national total - + Args: - model_tag: The name the model results will be saved to the database under - model_name: Name of the huggingface repo where the PVNet model is stored - model_version: Version of the PVNet model to run within the huggingface repo - summation_name: Name of the huggingface repo where the summation model is stored - summation_version: Version of the summation model to run within the huggingface repo + model_config: The configuration for the model device: Device to run the model on t0: The t0 time used to compile the results to numpy array gsp_capacities: DataArray of the solar capacities for all regional GSPs at t0 national_capacity: The national solar capacity at t0 - apply_adjuster: Whether to apply the adjuster when saving to database - save_gsp_sum: Whether to save the GSP sum - save_gsp_to_recent: Whether to save the GSP results to the - forecast_value_last_seven_days table - verbose: Whether to log all messages throughout prediction and compilation - legacy: Whether to run legacy dataloader + use_legacy: Whether to run legacy dataloader """ - + + model_name = model_config.pvnet.repo + model_version = model_config.pvnet.version + logger.info(f"Loading model: {model_name} - {model_version}") - - + # Store settings - self.model_tag = model_tag + self.model_tag = model_config.name self.model_name = model_name self.model_version = model_version self.device = device self.gsp_capacities = gsp_capacities self.national_capacity = national_capacity - self.apply_adjuster = apply_adjuster - self.save_gsp_sum = save_gsp_sum - self.save_gsp_to_recent = 
save_gsp_to_recent - self.verbose = verbose + self.apply_adjuster = model_config.use_adjuster + self.save_gsp_sum = model_config.save_gsp_sum + self.save_gsp_to_recent = model_config.save_gsp_to_recent + self.verbose = model_config.verbose self.use_legacy = use_legacy - + # Create stores for the predictions self.normed_preds = [] self.gsp_ids_each_batch = [] self.sun_down_masks = [] - + # Load the GSP and summation models self.model, self.summation_model = self.load_model( - model_name, - model_version, - summation_name, - summation_version, + model_name, + model_version, + model_config.summation.repo, + model_config.summation.version, device, ) - + # These are the valid times this forecast will predict for - self.valid_times = ( - t0 + pd.timedelta_range(start='30min', freq='30min', periods=self.model.forecast_len) + self.valid_times = t0 + pd.timedelta_range( + start="30min", freq="30min", periods=self.model.forecast_len ) - + @staticmethod def load_model( - model_name: str, - model_version: str, - summation_name: str | None, + model_name: str, + model_version: str, + summation_name: str | None, summation_version: str | None, - device: torch.device, + device: torch.device, ): """Load the GSP and summation models""" - # Load the GSP level model + # Load the GSP level model model = PVNetBaseModel.from_pretrained( model_id=model_name, revision=model_version, ).to(device) - + # Load the summation model if summation_name is None: sum_model = None @@ -135,31 +117,29 @@ def load_model( model_id=summation_name, revision=summation_version, ).to(device) - + # Compare the current GSP model with the one the summation model was trained on this_gsp_model = (model_name, model_version) sum_expected_gsp_model = (sum_model.pvnet_model_name, sum_model.pvnet_model_version) - - if sum_expected_gsp_model!=this_gsp_model: + + if sum_expected_gsp_model != this_gsp_model: warnings.warn(_model_mismatch_msg.format(*this_gsp_model, *sum_expected_gsp_model)) - + return model, sum_model - - + 
def log_info(self, message: str) -> None: """Maybe log message depending on verbosity""" if self.verbose: logger.info(message) - - + def predict_batch(self, batch: NumpyBatch) -> None: """Make predictions for a batch and store results internally""" - + self.log_info(f"Predicting for model: {self.model_name}-{self.model_version}") # Store GSP IDs for this batch for reordering later these_gsp_ids = batch[BatchKey.gsp_id].cpu().numpy() self.gsp_ids_each_batch += [these_gsp_ids] - + # TODO: This change should be moved inside PVNet batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].unsqueeze(1) @@ -176,7 +156,7 @@ def predict_batch(self, batch: NumpyBatch) -> None: else: # The new dataloader normalises the data to [0, 1] elevation = (batch[BatchKey.gsp_solar_elevation].cpu().numpy() - 0.5) * 180 - + # We only need elevation mask for forecasted values, not history elevation = elevation[:, -preds.shape[1] :] sun_down_mask = elevation < MIN_DAY_ELEVATION @@ -188,36 +168,33 @@ def predict_batch(self, batch: NumpyBatch) -> None: # Log max prediction self.log_info(f"GSP IDs: {these_gsp_ids}") self.log_info(f"Max prediction: {np.max(preds, axis=1)}") - - + def compile_forecasts(self) -> None: """Compile all forecasts internally in a single DataArray - + Steps: - Compile all the GSP level forecasts - Make national forecast - Compile all forecasts into a DataArray stored inside the object as `da_abs_all` """ - + # Complie results from all batches normed_preds = np.concatenate(self.normed_preds) sun_down_masks = np.concatenate(self.sun_down_masks) gsp_ids_all_batches = np.concatenate(self.gsp_ids_each_batch).squeeze() - + # Reorder GSPs which can end up shuffled if multiprocessing is used inds = gsp_ids_all_batches.argsort() normed_preds = normed_preds[inds] sun_down_masks = sun_down_masks[inds] gsp_ids_all_batches = gsp_ids_all_batches[inds] - + # Merge batch results to xarray DataArray da_normed = self.preds_to_dataarray( - normed_preds, - self.model.output_quantiles, - 
gsp_ids_all_batches + normed_preds, self.model.output_quantiles, gsp_ids_all_batches ) - + da_sundown_mask = xr.DataArray( data=sun_down_masks, dims=["gsp_id", "target_datetime_utc"], @@ -235,7 +212,7 @@ def compile_forecasts(self) -> None: # Apply sundown mask da_abs = da_abs.where(~da_sundown_mask).fillna(0.0) - + if self.summation_model is None: self.log_info("Summing across GSPs to produce national forecast") da_abs_national = ( @@ -258,8 +235,8 @@ def compile_forecasts(self) -> None: # Convert national predictions to DataArray da_normed_national = self.preds_to_dataarray( - normed_national[np.newaxis], - self.summation_model.output_quantiles, + normed_national[np.newaxis], + self.summation_model.output_quantiles, gsp_ids=[0], ) @@ -272,15 +249,14 @@ def compile_forecasts(self) -> None: self.log_info( f"National forecast is {da_abs_national.sel(output_label='forecast_mw').values}" ) - + # Store the compiled predictions internally self.da_abs_all = xr.concat([da_abs_national, da_abs], dim="gsp_id") - def preds_to_dataarray( - self, - preds: np.ndarray, - output_quantiles: list[float] | None, + self, + preds: np.ndarray, + output_quantiles: list[float] | None, gsp_ids: list[int], ) -> xr.DataArray: """Put numpy array of predictions into a dataarray""" @@ -302,11 +278,10 @@ def preds_to_dataarray( ), ) return da - - + def log_forecast_to_database(self, session: Session) -> None: """Log the compiled forecast to the database""" - + self.log_info("Converting DataArray to list of ForecastSQL") sql_forecasts = self.convert_dataarray_to_forecasts( @@ -315,11 +290,11 @@ def log_forecast_to_database(self, session: Session) -> None: model_tag=self.model_tag, version=pvnet_app.__version__, ) - + self.log_info("Saving ForecastSQL to database") if self.save_gsp_to_recent: - + # Save all forecasts and save to last_seven_days table save_sql_forecasts( forecasts=sql_forecasts, @@ -339,7 +314,7 @@ def log_forecast_to_database(self, session: Session) -> None: 
apply_adjuster=self.apply_adjuster, save_to_last_seven_days=True, ) - + # Save GSP results but not to last_seven_dats table save_sql_forecasts( forecasts=sql_forecasts[1:], @@ -378,8 +353,6 @@ def log_forecast_to_database(self, session: Session) -> None: apply_adjuster=False, save_to_last_seven_days=True, ) - - @staticmethod def convert_dataarray_to_forecasts( @@ -433,7 +406,7 @@ def convert_dataarray_to_forecasts( if "forecast_mw_plevel_10" in da_gsp_time.output_label: p10 = da_gsp_time.sel(output_label="forecast_mw_plevel_10").item() - # `p10` can be NaN if PVNet has probabilistic outputs and PVNet_summation + # `p10` can be NaN if PVNet has probabilistic outputs and PVNet_summation # doesn't, or vice versa. Do not log the value if NaN if not np.isnan(p10): properties["10"] = p10 @@ -444,7 +417,7 @@ def convert_dataarray_to_forecasts( if not np.isnan(p90): properties["90"] = p90 - if len(properties)>0: + if len(properties) > 0: forecast_value_sql.properties = properties forecast_values.append(forecast_value_sql) @@ -461,4 +434,4 @@ def convert_dataarray_to_forecasts( forecasts.append(forecast) - return forecasts \ No newline at end of file + return forecasts diff --git a/pvnet_app/model_configs/all_models.yaml b/pvnet_app/model_configs/all_models.yaml new file mode 100644 index 0000000..e1a31c2 --- /dev/null +++ b/pvnet_app/model_configs/all_models.yaml @@ -0,0 +1,64 @@ +# Here we define all the models that are available in the app +# Batches are prepared only once, so the extra models must be able to run on the batches created +# to run the pvnet_v2 model +models: + - name: pvnet_v2 + pvnet: + repo: openclimatefix/pvnet_uk_region + version: ae0b8006841ac6227db873a1fc7f7331dc7dadb5 + summation: + repo: openclimatefix/pvnet_v2_summation + version: ffac655f9650b81865d96023baa15839f3ce26ec + use_adjuster: True + save_gsp_sum: False + verbose: True + save_gsp_to_recent: true + - name: pvnet_v2-sat0-samples-v1 + pvnet: + repo: openclimatefix/pvnet_uk_region + version: 
8a7cc21b64d25ce1add7a8547674be3143b2e650
+    summation:
+      repo: openclimatefix/pvnet_v2_summation
+      version: dcfdc17fda8e48c387122614bec8b284eaa868b9
+  # single source models
+  - name: pvnet_v2-sat0-only-samples-v1
+    pvnet:
+      repo: openclimatefix/pvnet_uk_region
+      version: d7ab648942c85b6788adcdbed44c91c4e1c5604a
+    summation:
+      repo: openclimatefix/pvnet_v2_summation
+      version: adbf9e7797fee9a5050beb8c13841696e72f99ef
+
+  - name: pvnet_v2-ukv-only-samples-v1
+    pvnet:
+      repo: openclimatefix/pvnet_uk_region
+      version: eb73bf9a176a108f2e33b809f1f6993f893a4df9
+    summation:
+      repo: openclimatefix/pvnet_v2_summation
+      version: 9002baf1e9dc1ec141f3c4a1fa8447b6316a4558
+    uses_satellite_data: False
+
+  - name: pvnet_v2-ecmwf-only-samples-v1
+    pvnet:
+      repo: openclimatefix/pvnet_uk_region
+      version: 0bc344fafb2232fb0b6bb0bf419f0449fe11c643
+    summation:
+      repo: openclimatefix/pvnet_v2_summation
+      version: 4fe6b1441b6dd549292c201ed85eee156ecc220c
+    ecmwf_only: True
+    uses_satellite_data: False
+# The day ahead model has not yet been re-trained with data-sampler. 
+# It will be run with the legacy dataloader using ocf_datapipes + - name: pvnet_day_ahead + pvnet: + repo: openclimatefix/pvnet_uk_region_day_ahead + version: d87565731692a6003e43caac4feaed0f69e79272 + summation: + repo: openclimatefix/pvnet_summation_uk_national_day_ahead + version: ed60c5d32a020242ca4739dcc6dbc8864f783a08 + use_adjuster: True + save_gsp_sum: True + verbose: True + save_gsp_to_recent: True + day_ahead: True + diff --git a/pvnet_app/model_configs/pydantic_models.py b/pvnet_app/model_configs/pydantic_models.py new file mode 100644 index 0000000..3f3635b --- /dev/null +++ b/pvnet_app/model_configs/pydantic_models.py @@ -0,0 +1,123 @@ +""" A pydantic model for the ML models""" +import os +import logging + +from typing import List, Optional + +import fsspec +from pyaml_env import parse_config +from pydantic import BaseModel, Field, field_validator + +log = logging.getLogger(__name__) + + +class ModelHF(BaseModel): + repo: str = Field(..., title="Repo name", description="The HF Repo") + version: str = Field(..., title="Repo version", description="The HF version") + + +class Model(BaseModel): + """One ML Model""" + + name: str = Field(..., title="Model Name", description="The name of the model") + pvnet: ModelHF = Field(..., title="PVNet", description="The PVNet model") + summation: ModelHF = Field(..., title="Summation", description="The Summation model") + + use_adjuster: Optional[bool] = Field( + False, title="Use Adjuster", description="Whether to use the adjuster model" + ) + save_gsp_sum: Optional[bool] = Field( + False, title="Save GSP Sum", description="Whether to save the GSP sum" + ) + verbose: Optional[bool] = Field( + False, title="Verbose", description="Whether to print verbose output" + ) + save_gsp_to_recent: Optional[bool] = Field( + False, + title="Save GSP to Forecast Value Last Seven Days", + description="Whether to save the GSP to Forecast Value Last Seven Days", + ) + day_ahead: Optional[bool] = Field( + False, title="Day Ahead", 
description="If this model is day ahead or not"
+    )
+
+    ecmwf_only: Optional[bool] = Field(
+        False, title="ECMWF Only", description="If this model is only using ecmwf data"
+    )
+
+    uses_satellite_data: Optional[bool] = Field(
+        True, title="Uses Satellite Data", description="If this model uses satellite data"
+    )
+
+
+class Models(BaseModel):
+    """A group of ml models"""
+
+    models: List[Model] = Field(
+        ..., title="Models", description="A list of models to use for the forecast"
+    )
+
+    @field_validator("models")
+    @classmethod
+    def name_must_be_unique(cls, v: List[Model]) -> List[Model]:
+        """Ensure that all model names are unique"""
+        names = [model.name for model in v]
+        unique_names = set(names)
+
+        if len(names) != len(unique_names):
+            raise Exception(f"Model names must be unique, names are {names}")
+        return v
+
+
+def get_all_models(
+    get_ecmwf_only: Optional[bool] = False,
+    get_day_ahead_only: Optional[bool] = False,
+    run_extra_models: Optional[bool] = False,
+) -> List[Model]:
+    """
+    Returns all the models for a given client
+
+    Args:
+        get_ecmwf_only: If only the ECMWF model should be returned
+        get_day_ahead_only: If only the day ahead model should be returned
+        run_extra_models: If extra models should be run
+    """
+
+    # load models from yaml file
+    filename = os.path.dirname(os.path.abspath(__file__)) + "/all_models.yaml"
+
+    with fsspec.open(filename, mode="r") as stream:
+        models = parse_config(data=stream)
+        models = Models(**models)
+
+    models = config_pvnet_v2_model(models)
+
+    if get_ecmwf_only:
+        log.info("Using ECMWF model only")
+        models.models = [model for model in models.models if model.ecmwf_only]
+
+    if get_day_ahead_only:
+        log.info("Using Day Ahead model only")
+        models.models = [model for model in models.models if model.day_ahead]
+    else:
+        log.info("Not using Day Ahead model")
+        models.models = [model for model in models.models if not model.day_ahead]
+
+    if not run_extra_models and not get_day_ahead_only and not get_ecmwf_only:
+        log.info("Not running extra models")
+        models.models = [model for model in models.models if model.name == "pvnet_v2"]
+
+    return models.models
+
+
+def config_pvnet_v2_model(models):
+    """Function to adjust pvnet model"""
+    # special case for environment variables
+    use_adjuster = os.getenv("USE_ADJUSTER", "true").lower() == "true"
+    save_gsp_sum = os.getenv("SAVE_GSP_SUM", "false").lower() == "true"
+    # find index where name=pvnet_v2
+    pvnet_v2_index = 0
+    models.models[pvnet_v2_index].use_adjuster = use_adjuster
+    models.models[pvnet_v2_index].save_gsp_sum = save_gsp_sum
+
+    return models
diff --git a/tests/model_configs/test_pydantic_models.py b/tests/model_configs/test_pydantic_models.py
new file mode 100644
index 0000000..585a55e
--- /dev/null
+++ b/tests/model_configs/test_pydantic_models.py
@@ -0,0 +1,30 @@
+""" Test for getting all ml models"""
+from pvnet_app.model_configs.pydantic_models import get_all_models
+
+
+def test_get_all_models():
+    """Test for getting all models"""
+    models = get_all_models()
+    assert len(models) == 1
+    assert models[0].name == "pvnet_v2"
+
+
+def test_get_all_models_get_ecmwf_only():
+    """Test for getting all models with ecmwf_only"""
+    models = get_all_models(get_ecmwf_only=True)
+    assert len(models) == 1
+    assert models[0].ecmwf_only
+
+
+def test_get_all_models_get_day_ahead_only():
+    """Test for getting all models with day_ahead"""
+    models = get_all_models(get_day_ahead_only=True)
+    assert len(models) == 1
+    assert models[0].day_ahead
+
+
+def test_get_all_models_run_extra_models():
+    """Test for getting all models with run_extra_models"""
+    models = get_all_models(run_extra_models=True)
+    assert len(models) == 5
+
diff --git a/tests/test_app.py b/tests/test_app.py
index b61e8fb..797a3cd 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -14,17 +14,7 @@
 from pvnet.models.base_model import BaseModel as PVNetBaseModel
 from ocf_datapipes.config.load import load_yaml_configuration
 
-
-def 
model_uses_satellite(model_config): - """Function to check model if model entry in model dictionary uses satellite data""" - - data_config_path = PVNetBaseModel.get_data_config( - model_config["pvnet"]["name"], - revision=model_config["pvnet"]["version"], - ) - data_config = load_yaml_configuration(data_config_path) - - return hasattr(data_config.input_data, "satellite") and data_config.input_data.satellite +from pvnet_app.model_configs.pydantic_models import get_all_models @@ -61,16 +51,18 @@ def test_app( # Run prediction # These imports need to come after the environ vars have been set - from pvnet_app.app import app, models_dict + from pvnet_app.app import app app(gsp_ids=list(range(1, 318)), num_workers=2) + all_models = get_all_models(run_extra_models=True) + # Check correct number of forecasts have been made # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts # Forecast made with multiple models expected_forecast_results = 0 - for model_config in models_dict.values(): - expected_forecast_results += 318 + model_config["save_gsp_sum"] + for model_config in all_models: + expected_forecast_results += 318 + model_config.save_gsp_sum forecasts = db_session.query(ForecastSQL).all() # Doubled for historic and forecast @@ -85,12 +77,12 @@ def test_app( assert len(db_session.query(ForecastValueLatestSQL).all()) == expected_forecast_results * 16 expected_forecast_results = 0 - for model_config in models_dict.values(): + for model_config in all_models: # National expected_forecast_results += 1 # GSP - expected_forecast_results += 317 * model_config["save_gsp_to_forecast_value_last_seven_days"] - expected_forecast_results += model_config["save_gsp_sum"] # optional Sum of GSPs + expected_forecast_results += 317 * model_config.save_gsp_to_recent + expected_forecast_results += model_config.save_gsp_sum # optional Sum of GSPs assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == expected_forecast_results * 16 @@ -122,16 +114,18 @@ def 
test_app_day_ahead_model( # Run prediction # Thes import needs to come after the environ vars have been set - from pvnet_app.app import app, day_ahead_model_dict + from pvnet_app.app import app app(gsp_ids=list(range(1, 318)), num_workers=2) + all_models = get_all_models(get_day_ahead_only=True) + # Check correct number of forecasts have been made # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts # Forecast made with multiple models expected_forecast_results = 0 - for model_config in day_ahead_model_dict.values(): - expected_forecast_results += 318 + model_config["save_gsp_sum"] + for model_config in all_models: + expected_forecast_results += 318 + model_config.save_gsp_sum forecasts = db_session.query(ForecastSQL).all() # Doubled for historic and forecast @@ -183,20 +177,21 @@ def test_app_no_sat( # Run prediction # Thes import needs to come after the environ vars have been set - from pvnet_app.app import app, models_dict + from pvnet_app.app import app app(gsp_ids=list(range(1, 318)), num_workers=2) # Only the models which don't use satellite will be run in this case # The models below are the only ones which should have been run - no_sat_models_dict = {k: v for k, v in models_dict.items() if not model_uses_satellite(v)} + all_models = get_all_models(run_extra_models=True) + all_models = [model for model in all_models if not model.uses_satellite_data] # Check correct number of forecasts have been made # (317 GSPs + 1 National + maybe GSP-sum) = 318 or 319 forecasts # Forecast made with multiple models expected_forecast_results = 0 - for model_config in no_sat_models_dict.values(): - expected_forecast_results += 318 + model_config["save_gsp_sum"] + for model_config in all_models: + expected_forecast_results += 318 + model_config.save_gsp_sum forecasts = db_session.query(ForecastSQL).all() # Doubled for historic and forecast @@ -211,11 +206,11 @@ def test_app_no_sat( assert len(db_session.query(ForecastValueLatestSQL).all()) == 
expected_forecast_results * 16 expected_forecast_results = 0 - for model_config in no_sat_models_dict.values(): + for model_config in all_models: # National expected_forecast_results += 1 # GSP - expected_forecast_results += 317 * model_config["save_gsp_to_forecast_value_last_seven_days"] - expected_forecast_results += model_config["save_gsp_sum"] # optional Sum of GSPs + expected_forecast_results += 317 * model_config.save_gsp_to_recent + expected_forecast_results += model_config.save_gsp_sum # optional Sum of GSPs assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == expected_forecast_results * 16 \ No newline at end of file