diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index a375216e..b4a23d3d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -79,6 +79,68 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"): yaml.dump(config, outfile, default_flow_style=False) +def minimize_data_config(input_path, output_path, model): + """Strip out parts of the data config which aren't used by the model + + Args: + input_path: Path to input datapipes configuration file + output_path: Location to save the output configuration file + model: The PVNet model object + """ + with open(input_path) as cfg: + config = yaml.load(cfg, Loader=yaml.FullLoader) + + if "nwp" in config["input_data"]: + if not model.include_nwp: + del config["input_data"]["nwp"] + else: + for nwp_source in config["input_data"]["nwp"].keys(): + nwp_config = config["input_data"]["nwp"][nwp_source] + + if nwp_source not in model.nwp_encoders_dict: + # If not used, delete this source from the config + del config["input_data"]["nwp"][nwp_source] + else: + # Replace the image size + nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels + nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size + nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size + + # Replace the forecast minutes + nwp_config["forecast_minutes"] = ( + model.nwp_encoders_dict[nwp_source].sequence_length + - nwp_config["history_minutes"] / nwp_config["time_resolution_minutes"] + - 1 + ) * nwp_config["time_resolution_minutes"] + + if "satellite" in config["input_data"]: + if not model.include_sat: + del config["input_data"]["satellite"] + else: + sat_config = config["input_data"]["satellite"] + + # Replace the image size + sat_pixel_size = model.sat_encoder.image_size_pixels + sat_config["satellite_image_size_pixels_height"] = sat_pixel_size + sat_config["satellite_image_size_pixels_width"] = sat_pixel_size + + # Replace the satellite delay + sat_config["live_delay_minutes"] = model.min_sat_delay_minutes + + if "pv" in config["input_data"]: + if not model.include_pv: + del config["input_data"]["pv"] + + if "gsp" in config["input_data"]: + gsp_config = config["input_data"]["gsp"] + + # Replace the forecast minutes + gsp_config["forecast_minutes"] = model.forecast_minutes + + with open(output_path, "w") as outfile: + yaml.dump(config, outfile, default_flow_style=False) + + class PVNetModelHubMixin(PyTorchModelHubMixin): """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. @@ -207,7 +269,13 @@ def save_pretrained( # Save cleaned datapipes configuration file if data_config is not None: - make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME) + new_data_config_path = save_directory / DATA_CONFIG_NAME + + # Replace the input filenames with place holders + make_clean_data_config(data_config, new_data_config_path) + + # Taylor the data config to the model being saved + minimize_data_config(new_data_config_path, new_data_config_path, self) # Creating and saving model card. card_data = ModelCardData(language="en", license="mit", library_name="pytorch") diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 806ce6d9..36b1c1a1 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -146,6 +146,7 @@ def __init__( self.embedding_dim = embedding_dim self.add_image_embedding_channel = add_image_embedding_channel self.interval_minutes = interval_minutes + self.min_sat_delay_minutes = min_sat_delay_minutes self.adapt_batches = adapt_batches super().__init__( diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index 6ad2ec81..ee876a5f 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -15,6 +15,7 @@ from pvnet.load_model import get_model_from_checkpoints wandb_repo = "openclimatefix/pvnet2.1" +huggingface_repo = "openclimatefix/pvnet_uk_region" def push_to_huggingface( @@ -66,7 +67,7 @@ def push_to_huggingface( data_config=data_config, wandb_ids=wandb_ids, push_to_hub=push_to_hub, - repo_id=wandb_repo if push_to_hub else None, + repo_id=huggingface_repo if push_to_hub else None, ) if local_path is None: