From 26426850073df3d1ab48a96158bcec2c119e06cb Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 16 Oct 2024 11:54:22 +0100 Subject: [PATCH 1/2] fix for us legacy --- pvnet_app/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet_app/app.py b/pvnet_app/app.py index 306ef0b..2d26998 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -218,7 +218,7 @@ def app( t0=t0, gsp_capacities=gsp_capacities, national_capacity=national_capacity, - use_legacy=use_day_ahead_model, + use_legacy=not use_ocf_data_sampler, ) # Store the config filename so we can create batches suitable for all models From 54616d58dbddaca8266196bf23757c87a19b7a4f Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 16 Oct 2024 12:08:23 +0100 Subject: [PATCH 2/2] run blacks --- pvnet_app/app.py | 34 ++++----- pvnet_app/config.py | 81 +++++++++++----------- pvnet_app/consts.py | 2 +- pvnet_app/data/nwp.py | 5 +- pvnet_app/data/satellite.py | 54 ++++++++------- pvnet_app/dataloader.py | 37 +++++----- pvnet_app/model_configs/pydantic_models.py | 14 ++-- 7 files changed, 113 insertions(+), 114 deletions(-) diff --git a/pvnet_app/app.py b/pvnet_app/app.py index 2d26998..8c7219c 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -34,9 +34,7 @@ # sentry sentry_sdk.init( - dsn=os.getenv("SENTRY_DSN"), - environment=os.getenv("ENVIRONMENT", "local"), - traces_sample_rate=1 + dsn=os.getenv("SENTRY_DSN"), environment=os.getenv("ENVIRONMENT", "local"), traces_sample_rate=1 ) sentry_sdk.set_tag("app_name", "pvnet_app") @@ -129,7 +127,7 @@ def app( use_ecmwf_only = os.getenv("USE_ECMWF_ONLY", "false").lower() == "true" run_extra_models = os.getenv("RUN_EXTRA_MODELS", "false").lower() == "true" use_ocf_data_sampler = os.getenv("USE_OCF_DATA_SAMPLER", "true").lower() == "true" - + logger.info(f"Using `pvnet` library version: {pvnet.__version__}") logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}") logger.info(f"Using {num_workers} workers") @@ -138,10 +136,12 @@ def app( logger.info(f"Running extra models: {run_extra_models}") # load models - model_configs = get_all_models(get_ecmwf_only=use_ecmwf_only, - get_day_ahead_only=use_day_ahead_model, - run_extra_models=run_extra_models, - use_ocf_data_sampler=use_ocf_data_sampler) + model_configs = get_all_models( + get_ecmwf_only=use_ecmwf_only, + get_day_ahead_only=use_day_ahead_model, + run_extra_models=run_extra_models, + use_ocf_data_sampler=use_ocf_data_sampler, + ) logger.info(f"Using adjuster: {model_configs[0].use_adjuster}") logger.info(f"Saving GSP sum: {model_configs[0].save_gsp_sum}") @@ -166,12 +166,12 @@ def app( # Get capacities from the database logger.info("Loading capacities from the database") - + db_connection = DatabaseConnection(url=os.getenv("DB_URL"), base=Base_Forecast, echo=False) with db_connection.get_session() as session: #  Pandas series of most recent GSP capacities gsp_capacities = get_latest_gsp_capacities( - session=session, gsp_ids=gsp_ids, datetime_utc=t0-timedelta(days=2) + session=session, gsp_ids=gsp_ids, datetime_utc=t0 - timedelta(days=2) ) # National capacity is needed if using summation model @@ -241,21 +241,21 @@ def app( logger.info("Creating DataLoader") if not use_ocf_data_sampler: - logger.info('Making OCF datapipes dataloader') + logger.info("Making OCF datapipes dataloader") # The current day ahead model uses the legacy dataloader dataloader = get_legacy_dataloader( - config_filename=common_config_path, - t0=t0, + config_filename=common_config_path, + t0=t0, gsp_ids=gsp_ids, 
batch_size=batch_size, num_workers=num_workers, ) - + else: - logger.info('Making OCF Data Sampler dataloader') + logger.info("Making OCF Data Sampler dataloader") dataloader = get_dataloader( - config_filename=common_config_path, - t0=t0, + config_filename=common_config_path, + t0=t0, gsp_ids=gsp_ids, batch_size=batch_size, num_workers=num_workers, diff --git a/pvnet_app/config.py b/pvnet_app/config.py index b66aaa1..bad2ce2 100644 --- a/pvnet_app/config.py +++ b/pvnet_app/config.py @@ -12,7 +12,7 @@ def load_yaml_config(path: str) -> dict: def save_yaml_config(config: dict, path: str) -> None: """Save config file to path""" - with open(path, 'w') as file: + with open(path, "w") as file: yaml.dump(config, file, default_flow_style=False) @@ -23,29 +23,29 @@ def populate_config_with_data_data_filepaths(config: dict, gsp_path: str = "") - config: The data config gsp_path: For lagacy usage only """ - + production_paths = { "gsp": gsp_path, "nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path}, "satellite": sat_path, - } - + } + # Replace data sources for source in ["gsp", "satellite"]: - if source in config["input_data"] : - if config["input_data"][source][f"{source}_zarr_path"]!="": + if source in config["input_data"]: + if config["input_data"][source][f"{source}_zarr_path"] != "": assert source in production_paths, f"Missing production path: {source}" config["input_data"][source][f"{source}_zarr_path"] = production_paths[source] - + # NWP is nested so much be treated separately if "nwp" in config["input_data"]: nwp_config = config["input_data"]["nwp"] for nwp_source in nwp_config.keys(): - if nwp_config[nwp_source]["nwp_zarr_path"]!="": + if nwp_config[nwp_source]["nwp_zarr_path"] != "": assert "nwp" in production_paths, "Missing production path: nwp" assert nwp_source in production_paths["nwp"], f"Missing NWP path: {nwp_source}" nwp_config[nwp_source]["nwp_zarr_path"] = production_paths["nwp"][nwp_source] - + return config @@ -55,27 +55,25 @@ def overwrite_config_dropouts(config: dict) -> dict: Args: config: The data config """ - + # Replace data sources for source in ["satellite"]: - if source in config["input_data"] : - if config["input_data"][source][f"{source}_zarr_path"]!="": + if source in config["input_data"]: + if config["input_data"][source][f"{source}_zarr_path"] != "": config["input_data"][source][f"dropout_timedeltas_minutes"] = None - + # NWP is nested so much be treated separately if "nwp" in config["input_data"]: nwp_config = config["input_data"]["nwp"] for nwp_source in nwp_config.keys(): - if nwp_config[nwp_source]["nwp_zarr_path"]!="": + if nwp_config[nwp_source]["nwp_zarr_path"] != "": nwp_config[nwp_source]["dropout_timedeltas_minutes"] = None - + return config - - + + def modify_data_config_for_production( - input_path: str, - output_path: str, - gsp_path: str = "" + input_path: str, output_path: str, gsp_path: str = "" ) -> None: """Resave the data config with the data source filepaths and dropouts overwritten @@ -85,16 +83,16 @@ def modify_data_config_for_production( gsp_path: For lagacy usage only """ config = load_yaml_config(input_path) - + config = populate_config_with_data_data_filepaths(config, gsp_path=gsp_path) config = overwrite_config_dropouts(config) - + save_yaml_config(config, output_path) - - + + def get_union_of_configs(config_paths: list[str]) -> dict: """Find the config which is able to run all models from a list of config paths - + Note that this implementation is very limited and will not work in general unless all models have been trained on the 
same batches. We do not chck example if the satellite and NWP channels are the same in the different configs, or whether the NWP time slices are the same. Many more @@ -103,37 +101,36 @@ def get_union_of_configs(config_paths: list[str]) -> dict: # Load all the configs configs = [load_yaml_config(config_path) for config_path in config_paths] - + # We will ammend this config according to the entries in the other configs - common_config = configs[0] - + common_config = configs[0] + for config in configs[1:]: - + if "satellite" in config["input_data"]: - + if "satellite" in common_config["input_data"]: - - # Find the minimum satellite delay across configs + + # Find the minimum satellite delay across configs common_config["input_data"]["satellite"]["live_delay_minutes"] = min( common_config["input_data"]["satellite"]["live_delay_minutes"], - config["input_data"]["satellite"]["live_delay_minutes"] + config["input_data"]["satellite"]["live_delay_minutes"], ) - else: - # Add satellite to common config if not there already + # Add satellite to common config if not there already common_config["input_data"]["satellite"] = config["input_data"]["satellite"] - + if "nwp" in config["input_data"]: - - # Add NWP to common config if not there already + + # Add NWP to common config if not there already if "nwp" not in common_config["input_data"]: common_config["input_data"]["nwp"] = config["input_data"]["nwp"] - + else: for nwp_key, nwp_conf in config["input_data"]["nwp"].items(): - # Add different NWP sources to common config if not there already + # Add different NWP sources to common config if not there already if nwp_key not in common_config["input_data"]["nwp"]: common_config["input_data"]["nwp"][nwp_key] = nwp_conf - - return common_config \ No newline at end of file + + return common_config diff --git a/pvnet_app/consts.py b/pvnet_app/consts.py index ff54f0a..0a32205 100644 --- a/pvnet_app/consts.py +++ b/pvnet_app/consts.py @@ -1,3 +1,3 @@ sat_path = "sat.zarr" nwp_ukv_path = "nwp_ukv.zarr" -nwp_ecmwf_path = "nwp_ecmwf.zarr" \ No newline at end of file +nwp_ecmwf_path = "nwp_ecmwf.zarr" diff --git a/pvnet_app/data/nwp.py b/pvnet_app/data/nwp.py index 11d3ac2..31c49bb 100644 --- a/pvnet_app/data/nwp.py +++ b/pvnet_app/data/nwp.py @@ -21,7 +21,9 @@ def _download_nwp_data(source, destination): fs.get(source, destination, recursive=True) -def download_all_nwp_data(download_ukv: Optional[bool] = True, download_ecmwf: Optional[bool] = True): +def download_all_nwp_data( + download_ukv: Optional[bool] = True, download_ecmwf: Optional[bool] = True +): """Download the NWP data""" if download_ukv: _download_nwp_data(os.environ["NWP_UKV_ZARR_PATH"], nwp_ukv_path) @@ -163,4 +165,3 @@ def preprocess_nwp_data(use_ukv: Optional[bool] = True, use_ecmwf: Optional[bool fix_ecmwf_data() else: logger.info(f"Skipping ECMWF data preprocessing") - diff --git a/pvnet_app/data/satellite.py b/pvnet_app/data/satellite.py index 711fadb..ed56d0d 100644 --- a/pvnet_app/data/satellite.py +++ b/pvnet_app/data/satellite.py @@ -18,7 +18,7 @@ def download_all_sat_data() -> bool: """Download the sat data and return whether it was successful - + Returns: bool: Whether the download was successful """ @@ -29,8 +29,7 @@ def download_all_sat_data() -> bool: # Set variable to track whether the satellite download is successful sat_available = False if "SATELLITE_ZARR_PATH" not in os.environ: - logger.info("SATELLITE_ZARR_PATH has not be set. " - "No satellite data will be downloaded.") + logger.info("SATELLITE_ZARR_PATH has not be set. 
" "No satellite data will be downloaded.") return False # download 5 minute satellite data @@ -55,16 +54,16 @@ def download_all_sat_data() -> bool: os.system(f"rm sat_15_min.zarr.zip") else: logger.info(f"No 15-minute data available") - + return sat_available def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex: """Get the datetimes of the satellite data - + Args: sat_zarr_path: The path to the satellite zarr - + Returns: pd.DatetimeIndex: All available satellite timestamps """ @@ -73,8 +72,7 @@ def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex: def combine_5_and_15_sat_data() -> None: - """Select and/or combine the 5 and 15-minutely satellite data and move it to the expected path - """ + """Select and/or combine the 5 and 15-minutely satellite data and move it to the expected path""" # Check which satellite data exists exists_5_minute = os.path.exists(sat_5_path) @@ -101,7 +99,7 @@ def combine_5_and_15_sat_data() -> None: ) else: logger.info("No 15-minute data was found.") - + # If both 5- and 15-minute data exists, use the most recent if exists_5_minute and exists_15_minute: use_5_minute = datetimes_5min.max() > datetimes_15min.max() @@ -132,7 +130,7 @@ def fill_1d_bool_gaps(x, max_gap): >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0]) >>> fill_1d_bool_gaps(x, max_gap=2).astype(int) array([0, 1, 1, 1, 1, 1, 1, 0]) - + >>> x = np.array([1, 0, 0, 0, 1, 0, 1, 0]) >>> fill_1d_bool_gaps(x, max_gap=2).astype(int) array([1, 0, 0, 0, 1, 1, 1, 0]) @@ -157,7 +155,7 @@ def fill_1d_bool_gaps(x, max_gap): def interpolate_missing_satellite_timestamps(max_gap: pd.Timedelta) -> None: """Interpolate missing satellite timestamps""" - + ds_sat = xr.open_zarr(sat_path) # If any of these times are missing, we will try to interpolate them @@ -182,13 +180,13 @@ def interpolate_missing_satellite_timestamps(max_gap: pd.Timedelta) -> None: else: logger.info("Some requested times are missing - running interpolation") - + # Compute before interpolation for efficiency ds_sat = ds_sat.compute() - + # Run the interpolation to all 5-minute timestamps between the first and last ds_interp = ds_sat.interp(time=dense_times, method="linear", assume_sorted=True) - + # Find the timestamps which are within max gap size max_gap_steps = int(max_gap / pd.Timedelta("5min")) - 1 valid_fill_times = fill_1d_bool_gaps(timestamp_available, max_gap_steps) @@ -197,9 +195,9 @@ def interpolate_missing_satellite_timestamps(max_gap: pd.Timedelta) -> None: valid_fill_times_xr = xr.zeros_like(ds_interp.time, dtype=bool) valid_fill_times_xr.values[:] = valid_fill_times ds_sat = ds_interp.where(valid_fill_times_xr) - + time_was_filled = np.logical_and(valid_fill_times_xr, ~timestamp_available) - + if time_was_filled.any(): infilled_times = time_was_filled.where(time_was_filled, drop=True) logger.info( @@ -218,10 +216,12 @@ def interpolate_missing_satellite_timestamps(max_gap: pd.Timedelta) -> None: os.system(f"rm -rf {sat_path}") ds_sat.to_zarr(sat_path) - -def extend_satellite_data_with_nans(t0: pd.Timestamp, satellite_data_path: Optional[str] = sat_path) -> None: + +def extend_satellite_data_with_nans( + t0: pd.Timestamp, satellite_data_path: Optional[str] = sat_path +) -> None: """Fill the satellite data with NaNs out to time t0 - + Args: t0: The init-time of the forecast """ @@ -235,8 +235,10 @@ def extend_satellite_data_with_nans(t0: pd.Timestamp, satellite_data_path: Optio logger.info(f"Filling most recent {delay} with NaNs") if delay > pd.Timedelta("3h"): - logger.warning("The satellite data is delayed by more 
than 3 hours. " - "Will only infill last 3 hours.") + logger.warning( + "The satellite data is delayed by more than 3 hours. " + "Will only infill last 3 hours." + ) delay = pd.Timedelta("3h") # Load into memory so we can delete it on disk @@ -254,7 +256,7 @@ def extend_satellite_data_with_nans(t0: pd.Timestamp, satellite_data_path: Optio def check_model_satellite_inputs_available( - data_config_filename: str, + data_config_filename: str, t0: pd.Timestamp, sat_datetimes: pd.DatetimeIndex, ) -> bool: @@ -264,7 +266,7 @@ def check_model_satellite_inputs_available( data_config_filename: Path to the data configuration file t0: The init-time of the forecast available_sat_datetimes: The available satellite timestamps - + Returns: bool: Whether the satellite data satisfies that specified in the config """ @@ -301,7 +303,7 @@ def check_model_satellite_inputs_available( available = len(missing_time_steps) == 0 - if len(missing_time_steps)>0: + if len(missing_time_steps) > 0: logger.info(f"Some satellite timesteps for {t0=} missing: \n{missing_time_steps}") return available @@ -309,11 +311,11 @@ def check_model_satellite_inputs_available( def preprocess_sat_data(t0: pd.Timestamp, use_legacy: bool = False) -> pd.DatetimeIndex: """Combine and 5- and 15-minutely satellite data and extend to t0 if required - + Args: t0: The init-time of the forecast use_legacy: Whether to prepare the data as required for the legacy dataloader - + Returns: pd.DatetimeIndex: The available satellite timestamps int: The spacing between data samples in minutes diff --git a/pvnet_app/dataloader.py b/pvnet_app/dataloader.py index ba380ee..ce1bc98 100644 --- a/pvnet_app/dataloader.py +++ b/pvnet_app/dataloader.py @@ -16,23 +16,22 @@ from pvnet.utils import GSPLocationLookup - def get_dataloader( - config_filename: str, - t0: pd.Timestamp, - gsp_ids: list[int], + config_filename: str, + t0: pd.Timestamp, + gsp_ids: list[int], batch_size: int, num_workers: int, ): - - # Populate the data config with production data paths + + # Populate the data config with production data paths modified_data_config_filename = Path(config_filename).parent / "data_config.yaml" - + modify_data_config_for_production(config_filename, modified_data_config_filename) - + dataset = PVNetUKRegionalDataset( - config_filename=modified_data_config_filename, - start_time=t0, + config_filename=modified_data_config_filename, + start_time=t0, end_time=t0, gsp_ids=gsp_ids, ) @@ -61,27 +60,26 @@ def legacy_squeeze(batch): def get_legacy_dataloader( - config_filename: str, - t0: pd.Timestamp, - gsp_ids: list[int], + config_filename: str, + t0: pd.Timestamp, + gsp_ids: list[int], batch_size: int, num_workers: int, ): - + # Populate the data config with production data paths populated_data_config_filename = Path(config_filename).parent / "data_config.yaml" - + modify_data_config_for_production( - config_filename, + config_filename, populated_data_config_filename, gsp_path=os.environ["DB_URL"], - ) - + # Set up ID location query object ds_gsp = next(iter(OpenGSPFromDatabase())) gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - + # Location and time datapipes location_pipe = IterableWrapper([gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids]) t0_datapipe = IterableWrapper([t0]).repeat(len(location_pipe)) @@ -119,4 +117,3 @@ def get_legacy_dataloader( ) return DataLoader(batch_datapipe, **dataloader_kwargs) - diff --git a/pvnet_app/model_configs/pydantic_models.py b/pvnet_app/model_configs/pydantic_models.py index c0780b2..b70129e 100644 --- 
a/pvnet_app/model_configs/pydantic_models.py +++ b/pvnet_app/model_configs/pydantic_models.py @@ -45,17 +45,17 @@ class Model(BaseModel): False, title="ECMWF ONly", description="If this model is only using ecmwf data" ) - uses_satellite_data: Optional[bool] = Field( + uses_satellite_data: Optional[bool] = Field( True, title="Uses Satellite Data", description="If this model uses satellite data" ) uses_ocf_data_sampler: Optional[bool] = Field( - True, title="Uses OCF Data Sampler", description="If this model uses data sampler, old one uses ocf_datapipes" + True, + title="Uses OCF Data Sampler", + description="If this model uses data sampler, old one uses ocf_datapipes", ) - - class Models(BaseModel): """A group of ml models""" @@ -67,7 +67,7 @@ class Models(BaseModel): @classmethod def name_must_be_unique(cls, v: List[Model]) -> List[Model]: """Ensure that all model names are unique, respect to using ocf_data_sampler or not""" - names = [(model.name,model.uses_ocf_data_sampler) for model in v] + names = [(model.name, model.uses_ocf_data_sampler) for model in v] unique_names = set(names) if len(names) != len(unique_names): @@ -122,7 +122,9 @@ def get_all_models( log.info("Not using OCF Data Sampler, using ocf_datapipes") models.models = [model for model in models.models if not model.uses_ocf_data_sampler] - log.info(f"Got the following models: {[(model.name, model.uses_ocf_data_sampler) for model in models.models]}") + log.info( + f"Got the following models: {[(model.name, f'uses_ocf_data_sampler={model.uses_ocf_data_sampler}') for model in models.models]}" + ) return models.models
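
For reference, the behavioural change in patch 1/2 can be summarised with a short sketch. This is not code from the repository: the helper name below is an illustrative assumption; only USE_OCF_DATA_SAMPLER, use_legacy, get_dataloader and get_legacy_dataloader appear in the hunks above.

import os


def resolve_pipeline_flags() -> dict:
    """Sketch (illustrative only) of the flag logic in pvnet_app.app.app() after this patch."""
    # Read from the environment, defaulting to the ocf-data-sampler path, as in app.py.
    use_ocf_data_sampler = os.getenv("USE_OCF_DATA_SAMPLER", "true").lower() == "true"

    return {
        # Patch 1/2: satellite preprocessing now follows the dataloader choice
        # (use_legacy=not use_ocf_data_sampler) instead of the day-ahead model flag.
        "use_legacy": not use_ocf_data_sampler,
        # app() builds the ocf_datapipes dataloader when the sampler is disabled,
        # and the ocf-data-sampler dataloader otherwise.
        "dataloader": "get_legacy_dataloader" if not use_ocf_data_sampler else "get_dataloader",
    }


if __name__ == "__main__":
    print(resolve_pipeline_flags())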