diff --git a/pvnet_app/app.py b/pvnet_app/app.py
index 768fdbc..169bb62 100644
--- a/pvnet_app/app.py
+++ b/pvnet_app/app.py
@@ -11,51 +11,42 @@
 """
 
 import logging
-import warnings
 import os
 import tempfile
+import warnings
 from datetime import timedelta
-
-import numpy as np
+import dask
 import pandas as pd
+import pvnet
 import torch
 import typer
-import dask
 from nowcasting_datamodel.connection import DatabaseConnection
-from nowcasting_datamodel.save.save import save as save_sql_forecasts
-from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
 from nowcasting_datamodel.models.base import Base_Forecast
-from ocf_datapipes.load import OpenGSPFromDatabase
-from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline
-from ocf_datapipes.batch import stack_np_examples_into_batch, batch_to_tensor, copy_batch_to_device
-
-from torch.utils.data import DataLoader
-from torch.utils.data.datapipes.iter import IterableWrapper
-
-import pvnet
+from nowcasting_datamodel.read.read_gsp import get_latest_gsp_capacities
+from nowcasting_datamodel.save.save import save as save_sql_forecasts
+from ocf_datapipes.batch import batch_to_tensor, copy_batch_to_device
 from pvnet.models.base_model import BaseModel as PVNetBaseModel
-from pvnet.utils import GSPLocationLookup
 import sentry_sdk
+
 import pvnet_app
-from pvnet_app.utils import (
-    worker_init_fn,
-    populate_data_config_sources,
-    convert_dataarray_to_forecasts,
-    find_min_satellite_delay_config,
-    save_yaml_config,
-)
+from pvnet_app.data.nwp import download_all_nwp_data, preprocess_nwp_data
 from pvnet_app.data.satellite import (
     download_all_sat_data,
     preprocess_sat_data,
     check_model_inputs_available,
 )
-from pvnet_app.data.nwp import (
-    download_all_nwp_data,
-    preprocess_nwp_data,
-)
 from pvnet_app.forecast_compiler import ForecastCompiler
+from pvnet_app.utils import (
+    populate_data_config_sources,
+    convert_dataarray_to_forecasts,
+    find_min_satellite_delay_config,
+    save_yaml_config,
+)
+
+from pvnet_app.dataloader import get_legacy_dataloader, get_dataloader
+
 
 # sentry
 sentry_sdk.init(
@@ -84,11 +75,12 @@
 # - Batches are prepared only once, so the extra models must be able to run on the batches created
 # to run the pvnet_v2 model
 models_dict = {
+
     "pvnet_v2": {
         # Huggingfacehub model repo and commit for PVNet (GSP-level model)
         "pvnet": {
             "name": "openclimatefix/pvnet_uk_region",
-            "version": os.getenv('PVNET_V2_VERSION', "62e5e20ab793cee7cf94eadac870d2199501a730"),
+            "version": os.getenv('PVNET_V2_VERSION', "ae0b8006841ac6227db873a1fc7f7331dc7dadb5"),
             # We should only set PVNET_V2_VERSION in a short term solution,
             # as its difficult to track which model is being used
         },
@@ -107,11 +99,12 @@
         "verbose": True,
         "save_gsp_to_forecast_value_last_seven_days": True,
     },
+
     # Extra models which will be run on dev only
-    "pvnet_v2-sat0min-v12-batches": {
+    "pvnet_v2-sat0-samples-v1": {
         "pvnet": {
             "name": "openclimatefix/pvnet_uk_region",
-            "version": "dce387462ee08401355f33f53e86461dd59663e2",
+            "version": "8a7cc21b64d25ce1add7a8547674be3143b2e650",
         },
         "summation": {
             "name": "openclimatefix/pvnet_v2_summation",
@@ -122,11 +115,12 @@
         "verbose": False,
         "save_gsp_to_forecast_value_last_seven_days": False,
     },
+
     # single source models
-    "pvnet_v2-sat_delay0_only-v12-batches": {
+    "pvnet_v2-sat0-only-samples-v1": {
         "pvnet": {
             "name": "openclimatefix/pvnet_uk_region",
-            "version": "ea6ad2cf84152969c768788586df227976890f31",
+            "version": "d7ab648942c85b6788adcdbed44c91c4e1c5604a",
        },
        "summation": {
"openclimatefix/pvnet_v2_summation", @@ -138,10 +132,10 @@ "save_gsp_to_forecast_value_last_seven_days": False, }, - "pvnet_v2-ukv_only-v12-batches": { + "pvnet_v2-ukv-only-samples-v1": { "pvnet": { "name": "openclimatefix/pvnet_uk_region", - "version": "35d55181a82440bdd087f380d650bfd0b64bd322", + "version": "eb73bf9a176a108f2e33b809f1f6993f893a4df9", }, "summation": { "name": "openclimatefix/pvnet_v2_summation", @@ -153,10 +147,10 @@ "save_gsp_to_forecast_value_last_seven_days": False, }, - "pvnet_v2-ecmwf_only-v12-batches": { + "pvnet_v2-ecmwf-only-samples-v1": { "pvnet": { "name": "openclimatefix/pvnet_uk_region", - "version": "c14f7427d9854d63430aa936ce45f55d3818d033", + "version": "0bc344fafb2232fb0b6bb0bf419f0449fe11c643", }, "summation": { "name": "openclimatefix/pvnet_v2_summation", @@ -169,6 +163,8 @@ }, } +# The day ahead model has not yet been re-trained with data-sampler. +# It will be run with the legacy dataloader using ocf_datapipes day_ahead_model_dict = { "pvnet_day_ahead": { # Huggingfacehub model repo and commit for PVNet day ahead models @@ -257,19 +253,19 @@ def app( # Without this line the dataloader will hang if multiple workers are used dask.config.set(scheduler="single-threaded") - day_ahead_model_used = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true" + use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true" use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true" logger.info(f"Using satellite data: {use_satellite}") - logger.info(f"Using day ahead model: {day_ahead_model_used}") + logger.info(f"Using day ahead model: {use_day_ahead_model}") - if day_ahead_model_used: + if use_day_ahead_model: logger.info(f"Using day ahead PVNet model") logger.info(f"Using `pvnet` library version: {pvnet.__version__}") logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}") logger.info(f"Using {num_workers} workers") - if day_ahead_model_used: + if use_day_ahead_model: logger.info(f"Using adjduster: {day_ahead_model_dict['pvnet_day_ahead']['use_adjuster']}") logger.info(f"Saving GSP sum: {day_ahead_model_dict['pvnet_day_ahead']['save_gsp_sum']}") @@ -296,26 +292,19 @@ def app( # --------------------------------------------------------------------------- # 1. 
     # 1. Prepare data sources
 
-    # Make pands Series of most recent GSP effective capacities
     logger.info("Loading GSP metadata")
-    ds_gsp = next(iter(OpenGSPFromDatabase()))
-
     # Get capacities from the database
     db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False)
     with db_connection.get_session() as session:
         #  Pandas series of most recent GSP capacities
-        now_minis_two_days = pd.Timestamp.now(tz="UTC") - timedelta(days=2)
         gsp_capacities = get_latest_gsp_capacities(
-            session=session, gsp_ids=gsp_ids, datetime_utc=now_minis_two_days
+            session=session, gsp_ids=gsp_ids, datetime_utc=t0-timedelta(days=2)
         )
 
         # National capacity is needed if using summation model
         national_capacity = get_latest_gsp_capacities(session, [0])[0]
 
-    # Set up ID location query object
-    gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)
-
     # Download satellite data
     if use_satellite:
         logger.info("Downloading satellite data")
@@ -325,7 +314,7 @@ def app(
         all_satellite_datetimes, data_freq_minutes = preprocess_sat_data(t0)
     else:
         all_satellite_datetimes = []
-        data_freq_minutes = 5
+        data_freq_minutes = None
 
     # Download NWP data
     logger.info("Downloading NWP data")
@@ -337,7 +326,7 @@ def app(
     # ---------------------------------------------------------------------------
     # 2. Set up models
 
-    if day_ahead_model_used:
+    if use_day_ahead_model:
         model_to_run_dict = {"pvnet_day_ahead": day_ahead_model_dict["pvnet_day_ahead"]}
     # Remove extra models if not configured to run them
     elif os.getenv("RUN_EXTRA_MODELS", "false").lower() == "false":
@@ -393,47 +382,25 @@ def app(
     # Set up data loader
     logger.info("Creating DataLoader")
 
-    # Populate the data config with production data paths
-    populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"
-
-    populate_data_config_sources(common_config_path, populated_data_config_filename)
-
-    # Location and time datapipes
-    location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids])
-    t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe))
-
-    location_pipe = location_pipe.sharding_filter()
-    t0_datapipe = t0_datapipe.sharding_filter()
-
-    # Batch datapipe
-    batch_datapipe = (
-        construct_sliced_data_pipeline(
-            config_filename=populated_data_config_filename,
-            location_pipe=location_pipe,
-            t0_datapipe=t0_datapipe,
-            production=True,
+    if use_day_ahead_model:
+        # The current day ahead model uses the legacy dataloader
+        dataloader = get_legacy_dataloader(
+            config_filename=common_config_path,
+            t0=t0,
+            gsp_ids=gsp_ids,
+            batch_size=batch_size,
+            num_workers=num_workers,
         )
-        .batch(batch_size)
-        .map(stack_np_examples_into_batch)
-    )
-
-    # Set up dataloader for parallel loading
-    dataloader_kwargs = dict(
-        shuffle=False,
-        batch_size=None,  # batched in datapipe step
-        sampler=None,
-        batch_sampler=None,
-        num_workers=num_workers,
-        collate_fn=None,
-        pin_memory=False,
-        drop_last=False,
-        timeout=0,
-        worker_init_fn=worker_init_fn,
-        prefetch_factor=None if num_workers == 0 else 2,
-        persistent_workers=False,
-    )
-
-    dataloader = DataLoader(batch_datapipe, **dataloader_kwargs)
+
+    else:
+        dataloader = get_dataloader(
+            config_filename=common_config_path,
+            t0=t0,
+            gsp_ids=gsp_ids,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
 
     # ---------------------------------------------------------------------------
     # Make predictions
@@ -459,7 +426,7 @@ def app(
     # Escape clause for making predictions locally
     if not write_predictions:
         temp_dir.cleanup()
-        if not day_ahead_model_used:
+        if not use_day_ahead_model:
forecast_compilers["pvnet_v2"].da_abs_all return forecast_compilers["pvnet_day_ahead"].da_abs_all diff --git a/pvnet_app/dataloader.py b/pvnet_app/dataloader.py new file mode 100644 index 0000000..da4615e --- /dev/null +++ b/pvnet_app/dataloader.py @@ -0,0 +1,123 @@ +from pathlib import Path +import os + +import pandas as pd + +from torch.utils.data import DataLoader +from ocf_datapipes.batch import stack_np_examples_into_batch +from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset + +from pvnet_app.utils import populate_data_config_sources + +# Legacy imports +from ocf_datapipes.load import OpenGSPFromDatabase +from torch.utils.data.datapipes.iter import IterableWrapper +from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline +from ocf_datapipes.batch import BatchKey +from pvnet.utils import GSPLocationLookup + + + +def get_dataloader( + config_filename: str, + t0: pd.Timestamp, + gsp_ids: list[int], + batch_size: int, + num_workers: int, +): + + # Populate the data config with production data paths + populated_data_config_filename = Path(config_filename).parent / "data_config.yaml" + + populate_data_config_sources(config_filename, populated_data_config_filename) + + dataset = PVNetUKRegionalDataset( + config_filename=populated_data_config_filename, + start_time=t0, + end_time=t0, + gsp_ids=gsp_ids, + ) + + # Set up dataloader for parallel loading + dataloader_kwargs = dict( + shuffle=False, + batch_size=batch_size, + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=stack_np_examples_into_batch, + pin_memory=False, + drop_last=False, + timeout=0, + prefetch_factor=None if num_workers == 0 else 2, + persistent_workers=False, + ) + + return DataLoader(dataset, **dataloader_kwargs) + + +def legacy_squeeze(batch): + batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].squeeze(1) + return batch + + +def get_legacy_dataloader( + config_filename: str, + t0: pd.Timestamp, + gsp_ids: list[int], + batch_size: int, + num_workers: int, +): + + # Populate the data config with production data paths + populated_data_config_filename = Path(config_filename).parent / "data_config.yaml" + + populate_data_config_sources( + config_filename, + populated_data_config_filename, + gsp_path=os.environ["DB_URL"], + + ) + + # Set up ID location query object + ds_gsp = next(iter(OpenGSPFromDatabase())) + gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) + + # Location and time datapipes + location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids]) + t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe)) + + location_pipe = location_pipe.sharding_filter() + t0_datapipe = t0_datapipe.sharding_filter() + + # Batch datapipe + batch_datapipe = ( + construct_sliced_data_pipeline( + config_filename=populated_data_config_filename, + location_pipe=location_pipe, + t0_datapipe=t0_datapipe, + production=True, + ) + .batch(batch_size) + .map(stack_np_examples_into_batch) + .map(legacy_squeeze) + ) + + # Set up dataloader for parallel loading + dataloader_kwargs = dict( + shuffle=False, + batch_size=None, # batched in datapipe step + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=None if num_workers == 0 else 2, + persistent_workers=False, + ) + + return DataLoader(batch_datapipe, **dataloader_kwargs) + diff --git a/pvnet_app/forecast_compiler.py b/pvnet_app/forecast_compiler.py index 
index 9142578..3a3a8a2 100644
--- a/pvnet_app/forecast_compiler.py
+++ b/pvnet_app/forecast_compiler.py
@@ -69,6 +69,7 @@ def __init__(
 
         self.sun_down_masks = []
+
 
         logger.info(f"Loading model: {model_name} - {model_version}")
 
         self.model = PVNetBaseModel.from_pretrained(
@@ -103,7 +104,6 @@ def log_info(self, message):
         if self.verbose:
             logger.info(message)
 
-
     def predict_batch(self, batch):
         """Make predictions for a batch and store results internally"""
 
@@ -111,6 +111,9 @@ def predict_batch(self, batch):
         # Store GSP IDs for this batch for reordering later
         these_gsp_ids = batch[BatchKey.gsp_id].cpu().numpy()
         self.gsp_ids_each_batch += [these_gsp_ids]
+
+        # TODO: This change should be moved inside PVNet
+        batch[BatchKey.gsp_id] = batch[BatchKey.gsp_id].unsqueeze(1)
 
         # Run batch through model
         preds = self.model(batch).detach().cpu().numpy()
diff --git a/pvnet_app/utils.py b/pvnet_app/utils.py
index ba99b26..e78b5fa 100644
--- a/pvnet_app/utils.py
+++ b/pvnet_app/utils.py
@@ -25,21 +25,6 @@
 logger = logging.getLogger(__name__)
 
 
-def worker_init_fn(worker_id):
-    """
-    Clear reference to the loop and thread.
-    This is a nasty hack that was suggested but NOT recommended by the lead fsspec developer!
-    This appears necessary otherwise gcsfs hangs when used after forking multiple worker processes.
-    Only required for fsspec >= 0.9.0
-    See:
-    - https://github.com/fsspec/gcsfs/issues/379#issuecomment-839929801
-    - https://github.com/fsspec/filesystem_spec/pull/963#issuecomment-1131709948
-    TODO: Try deleting this two lines to make sure this is still relevant.
-    """
-    fsspec.asyn.iothread[0] = None
-    fsspec.asyn.loop[0] = None
-
 
 def load_yaml_config(path):
     """Load config file from path"""
@@ -54,24 +39,24 @@ def save_yaml_config(config, path):
         yaml.dump(config, file, default_flow_style=False)
 
 
-def populate_data_config_sources(input_path, output_path):
+def populate_data_config_sources(input_path, output_path, gsp_path=""):
     """Resave the data config and replace the source filepaths
 
     Args:
         input_path: Path to input datapipes configuration file
         output_path: Location to save the output configuration file
+        gsp_path: For legacy usage only
     """
     config = load_yaml_config(input_path)
 
     production_paths = {
-        "gsp": os.environ["DB_URL"],
+        "gsp": gsp_path,
         "nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path},
         "satellite": sat_path,
-        # TODO: include hrvsatellite
     }
 
     # Replace data sources
-    for source in ["gsp", "satellite", "hrvsatellite"]:
+    for source in ["gsp", "satellite"]:
         if source in config["input_data"] :
             if config["input_data"][source][f"{source}_zarr_path"]!="":
                 assert source in production_paths, f"Missing production path: {source}"
diff --git a/requirements.txt b/requirements.txt
index 29f8b52..00183b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ torch[cpu]==2.3.1
 PVNet-summation==0.3.2
 pvnet==3.0.50
 ocf_datapipes==3.3.39
+ocf_data_sampler==0.0.19
 nowcasting_datamodel>=1.5.45
 fsspec[s3]
 xarray
diff --git a/tests/conftest.py b/tests/conftest.py
index 44e6250..a475669 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,7 +28,7 @@ def test_t0():
 @pytest.fixture(scope="session")
 def engine_url():
     """Database engine, this includes the table creation."""
-    with PostgresContainer("postgres:14.5") as postgres:
+    with PostgresContainer("postgres:16.1") as postgres:
         url = postgres.get_connection_url()
         os.environ["DB_URL"] = url
 
@@ -85,8 +85,8 @@ def make_nwp_data(shell_path, varname, test_t0):
     # Load dataset which only contains coordinates, but no data
     ds = xr.open_zarr(shell_path)
 
-    # Last init time was at least 2 hours ago and floor to 3-hour interval
-    t0_datetime_utc = (test_t0 - timedelta(hours=2)).floor(timedelta(hours=3))
+    # Last init time was at least 8 hours ago and floor to 3-hour interval
+    t0_datetime_utc = (test_t0 - timedelta(hours=8)).floor(timedelta(hours=3))
     ds.init_time.values[:] = pd.date_range(
         t0_datetime_utc - timedelta(hours=3 * (len(ds.init_time) - 1)),
         t0_datetime_utc,
diff --git a/tests/test_data/test.yaml b/tests/test_data/test.yaml
index 2ac5c64..c18c91e 100644
--- a/tests/test_data/test.yaml
+++ b/tests/test_data/test.yaml
@@ -27,7 +27,7 @@ input_data:
     ecmwf:
       dropout_fraction: 1.0
       dropout_timedeltas_minutes:
-      - -360
+      - -60
       forecast_minutes: 480
       history_minutes: 120
       index_by_id: false
@@ -62,7 +62,7 @@ input_data:
     ukv:
      dropout_fraction: 1.0
      dropout_timedeltas_minutes:
-      - -180
+      - -60
      forecast_minutes: 480
      history_minutes: 120
      index_by_id: false