Skip to content

Commit

Permalink
Merge pull request #219 from openclimatefix/upload_config_fix
Browse files Browse the repository at this point in the history
Add filter to config for when trained on overcomplete batches
  • Loading branch information
dfulu authored Jun 13, 2024
2 parents c17f0e4 + ee352d7 commit c67d36f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
70 changes: 69 additions & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,68 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
yaml.dump(config, outfile, default_flow_style=False)


def minimize_data_config(input_path, output_path, model):
    """Strip out parts of the data config which aren't used by the model

    Loads the YAML data config from `input_path`, removes the input sources the
    model does not use, overwrites the image sizes, forecast horizons and satellite
    delay with the values stored on the model, and dumps the result to
    `output_path`. The input file is fully read and closed before the output file
    is opened, so both paths may point at the same file.

    Args:
        input_path: Path to input datapipes configuration file
        output_path: Location to save the output configuration file
        model: The PVNet model object. The attributes read here are `include_nwp`,
            `nwp_encoders_dict`, `include_sat`, `sat_encoder`,
            `min_sat_delay_minutes`, `include_pv` and `forecast_minutes`.
    """
    # NOTE(review): FullLoader is not meant for untrusted input; this presumably
    # only ever reads a locally produced config - confirm before reusing elsewhere.
    with open(input_path) as cfg:
        config = yaml.load(cfg, Loader=yaml.FullLoader)

    input_data = config["input_data"]

    if "nwp" in input_data:
        if not model.include_nwp:
            del input_data["nwp"]
        else:
            # Iterate over a snapshot of the keys. Unused sources are deleted inside
            # the loop, and deleting from a dict while iterating over its live key
            # view raises "RuntimeError: dictionary changed size during iteration".
            for nwp_source in list(input_data["nwp"].keys()):
                nwp_config = input_data["nwp"][nwp_source]

                if nwp_source not in model.nwp_encoders_dict:
                    # If not used, delete this source from the config
                    del input_data["nwp"][nwp_source]
                else:
                    nwp_encoder = model.nwp_encoders_dict[nwp_source]

                    # Replace the image size
                    nwp_pixel_size = nwp_encoder.image_size_pixels
                    nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size
                    nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size

                    # Replace the forecast minutes. The encoder's sequence length
                    # covers the history steps, the forecast steps and one extra
                    # step, so subtract the history steps and 1 to get the number
                    # of forecast steps, then convert steps back to minutes.
                    # NOTE(review): `/` is true division, so the stored value is a
                    # float (e.g. 480.0) - confirm consumers accept that.
                    nwp_config["forecast_minutes"] = (
                        nwp_encoder.sequence_length
                        - nwp_config["history_minutes"] / nwp_config["time_resolution_minutes"]
                        - 1
                    ) * nwp_config["time_resolution_minutes"]

    if "satellite" in input_data:
        if not model.include_sat:
            del input_data["satellite"]
        else:
            sat_config = input_data["satellite"]

            # Replace the image size
            sat_pixel_size = model.sat_encoder.image_size_pixels
            sat_config["satellite_image_size_pixels_height"] = sat_pixel_size
            sat_config["satellite_image_size_pixels_width"] = sat_pixel_size

            # Replace the satellite delay
            sat_config["live_delay_minutes"] = model.min_sat_delay_minutes

    if "pv" in input_data:
        if not model.include_pv:
            del input_data["pv"]

    if "gsp" in input_data:
        gsp_config = input_data["gsp"]

        # Replace the forecast minutes
        gsp_config["forecast_minutes"] = model.forecast_minutes

    with open(output_path, "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)


class PVNetModelHubMixin(PyTorchModelHubMixin):
"""
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
Expand Down Expand Up @@ -207,7 +269,13 @@ def save_pretrained(

# Save cleaned datapipes configuration file
if data_config is not None:
make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME)
new_data_config_path = save_directory / DATA_CONFIG_NAME

# Replace the input filenames with placeholders
make_clean_data_config(data_config, new_data_config_path)

# Tailor the data config to the model being saved
minimize_data_config(new_data_config_path, new_data_config_path, self)

# Creating and saving model card.
card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
Expand Down
1 change: 1 addition & 0 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
self.embedding_dim = embedding_dim
self.add_image_embedding_channel = add_image_embedding_channel
self.interval_minutes = interval_minutes
self.min_sat_delay_minutes = min_sat_delay_minutes
self.adapt_batches = adapt_batches

super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pvnet.load_model import get_model_from_checkpoints

wandb_repo = "openclimatefix/pvnet2.1"
huggingface_repo = "openclimatefix/pvnet_uk_region"


def push_to_huggingface(
Expand Down Expand Up @@ -66,7 +67,7 @@ def push_to_huggingface(
data_config=data_config,
wandb_ids=wandb_ids,
push_to_hub=push_to_hub,
repo_id=wandb_repo if push_to_hub else None,
repo_id=huggingface_repo if push_to_hub else None,
)

if local_path is None:
Expand Down

0 comments on commit c67d36f

Please sign in to comment.