From 05e05ce546ce2d9f9b94863f89ca19c279f00042 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 13 Jun 2024 10:25:37 +0000 Subject: [PATCH 1/3] add filter to config for when trained on overcomplete batches --- pvnet/models/base_model.py | 75 +++++++++++++++++++++++++-- pvnet/models/multimodal/multimodal.py | 1 + scripts/checkpoint_to_huggingface.py | 3 +- 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 46706c00..2dd4be5f 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -42,7 +42,7 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): """Resave the data config and replace the filepaths with a placeholder. - + Args: input_path: Path to input datapipes configuration file output_path: Location to save the output configuration file @@ -77,6 +77,69 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): with open(output_path, "w") as outfile: yaml.dump(config, outfile, default_flow_style=False) + + + +def minimize_data_config(input_path, output_path, model): + """Strip out parts of the data config which aren't used by the model + + Args: + input_path: Path to input datapipes configuration file + output_path: Location to save the output configuration file + """ + with open(input_path) as cfg: + config = yaml.load(cfg, Loader=yaml.FullLoader) + + if "nwp" in config["input_data"]: + if not model.include_nwp: + del config["input_data"]["nwp"] + else: + for nwp_source in config["input_data"]["nwp"].keys(): + nwp_config = config["input_data"]["nwp"][nwp_source] + + if nwp_source not in model.nwp_encoders_dict: + # If not used, delete this source from the config + del config["input_data"]["nwp"][nwp_source] + else: + # Replace the image size + nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels + nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size + nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size + + # Replace the forecast minutes + nwp_config["forecast_minutes"] = ( + model.nwp_encoders_dict[nwp_source].sequence_length - + nwp_config["history_minutes"]/nwp_config["time_resolution_minutes"] - 1 + )*nwp_config["time_resolution_minutes"] + + + if "satellite" in config["input_data"]: + if not model.include_sat: + del config["input_data"]["satellite"] + else: + sat_config = config["input_data"]["satellite"] + + # Replace the image size + sat_pixel_size = model.sat_encoder.image_size_pixels + sat_config["satellite_image_size_pixels_height"] = sat_pixel_size + sat_config["satellite_image_size_pixels_width"] = sat_pixel_size + + # Replace the satellite delay + sat_config["live_delay_minutes"] = model.min_sat_delay_minutes + + if "pv" in config["input_data"]: + if not model.include_pv: + del config["input_data"]["pv"] + + if "gsp" in config["input_data"]: + gsp_config = config["input_data"]["gsp"] + + # Replace the forecast minutes + gsp_config["forecast_minutes"] = model.forecast_minutes + + with open(output_path, "w") as outfile: + yaml.dump(config, outfile, default_flow_style=False) + class PVNetModelHubMixin(PyTorchModelHubMixin): @@ -207,7 +270,13 @@ def save_pretrained( # Save cleaned datapipes configuration file if data_config is not None: - make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME) + new_data_config_path = save_directory / DATA_CONFIG_NAME + + # Replace the input filenames with place holders + make_clean_data_config(data_config, new_data_config_path) + + # Taylor the data config to the model being saved + minimize_data_config(new_data_config_path, new_data_config_path, self) # Creating and saving model card. card_data = ModelCardData(language="en", license="mit", library_name="pytorch") @@ -580,7 +649,7 @@ def on_validation_epoch_end(self): """Run on epoch end""" horizon_maes_dict = self._horizon_maes.flush() - + # Create the horizon accuracy curve if isinstance(self.logger, pl.loggers.WandbLogger): per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)] diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index ad4c11d6..2501b61e 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -143,6 +143,7 @@ def __init__( self.embedding_dim = embedding_dim self.add_image_embedding_channel = add_image_embedding_channel self.interval_minutes = interval_minutes + self.min_sat_delay_minutes = min_sat_delay_minutes self.adapt_batches = adapt_batches super().__init__( diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index 6ad2ec81..ee876a5f 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -15,6 +15,7 @@ from pvnet.load_model import get_model_from_checkpoints wandb_repo = "openclimatefix/pvnet2.1" +huggingface_repo = "openclimatefix/pvnet_uk_region" def push_to_huggingface( @@ -66,7 +67,7 @@ def push_to_huggingface( data_config=data_config, wandb_ids=wandb_ids, push_to_hub=push_to_hub, - repo_id=wandb_repo if push_to_hub else None, + repo_id=huggingface_repo if push_to_hub else None, ) if local_path is None: From 2f74109072742ad3abfd481be6a1694c3dd3325d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:34:23 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 2dd4be5f..eb5b0243 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -42,7 +42,7 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): """Resave the data config and replace the filepaths with a placeholder. - + Args: input_path: Path to input datapipes configuration file output_path: Location to save the output configuration file @@ -77,26 +77,25 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): with open(output_path, "w") as outfile: yaml.dump(config, outfile, default_flow_style=False) - - + def minimize_data_config(input_path, output_path, model): """Strip out parts of the data config which aren't used by the model - + Args: input_path: Path to input datapipes configuration file output_path: Location to save the output configuration file """ with open(input_path) as cfg: config = yaml.load(cfg, Loader=yaml.FullLoader) - + if "nwp" in config["input_data"]: if not model.include_nwp: del config["input_data"]["nwp"] else: for nwp_source in config["input_data"]["nwp"].keys(): nwp_config = config["input_data"]["nwp"][nwp_source] - + if nwp_source not in model.nwp_encoders_dict: # If not used, delete this source from the config del config["input_data"]["nwp"][nwp_source] @@ -105,35 +104,35 @@ def minimize_data_config(input_path, output_path, model): nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size - + # Replace the forecast minutes nwp_config["forecast_minutes"] = ( - model.nwp_encoders_dict[nwp_source].sequence_length - - nwp_config["history_minutes"]/nwp_config["time_resolution_minutes"] - 1 - )*nwp_config["time_resolution_minutes"] - - + model.nwp_encoders_dict[nwp_source].sequence_length + - nwp_config["history_minutes"] / nwp_config["time_resolution_minutes"] + - 1 + ) * nwp_config["time_resolution_minutes"] + if "satellite" in config["input_data"]: if not model.include_sat: del config["input_data"]["satellite"] else: sat_config = config["input_data"]["satellite"] - + # Replace the image size sat_pixel_size = model.sat_encoder.image_size_pixels sat_config["satellite_image_size_pixels_height"] = sat_pixel_size sat_config["satellite_image_size_pixels_width"] = sat_pixel_size - + # Replace the satellite delay sat_config["live_delay_minutes"] = model.min_sat_delay_minutes if "pv" in config["input_data"]: if not model.include_pv: del config["input_data"]["pv"] - + if "gsp" in config["input_data"]: gsp_config = config["input_data"]["gsp"] - + # Replace the forecast minutes gsp_config["forecast_minutes"] = model.forecast_minutes @@ -141,7 +140,6 @@ def minimize_data_config(input_path, output_path, model): yaml.dump(config, outfile, default_flow_style=False) - class PVNetModelHubMixin(PyTorchModelHubMixin): """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. @@ -271,10 +269,10 @@ def save_pretrained( # Save cleaned datapipes configuration file if data_config is not None: new_data_config_path = save_directory / DATA_CONFIG_NAME - + # Replace the input filenames with place holders make_clean_data_config(data_config, new_data_config_path) - + # Taylor the data config to the model being saved minimize_data_config(new_data_config_path, new_data_config_path, self) @@ -649,7 +647,7 @@ def on_validation_epoch_end(self): """Run on epoch end""" horizon_maes_dict = self._horizon_maes.flush() - + # Create the horizon accuracy curve if isinstance(self.logger, pl.loggers.WandbLogger): per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)] From ee352d7897064af22a1fd13301efc018de91c12e Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 13 Jun 2024 11:21:55 +0000 Subject: [PATCH 3/3] linting --- pvnet/models/base_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index eb5b0243..10939bf5 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -85,6 +85,7 @@ def minimize_data_config(input_path, output_path, model): Args: input_path: Path to input datapipes configuration file output_path: Location to save the output configuration file + model: The PVNet model object """ with open(input_path) as cfg: config = yaml.load(cfg, Loader=yaml.FullLoader)