Skip to content

Commit

Permalink
add filter to config for when trained on overcomplete batches
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Jun 13, 2024
1 parent 1c8a7b4 commit 05e05ce
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
75 changes: 72 additions & 3 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
"""Resave the data config and replace the filepaths with a placeholder.
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
Expand Down Expand Up @@ -77,6 +77,69 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):

with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)



def minimize_data_config(input_path, output_path, model):
"""Strip out parts of the data config which aren't used by the model
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
"""
with open(input_path) as cfg:
config = yaml.load(cfg, Loader=yaml.FullLoader)

if "nwp" in config["input_data"]:
if not model.include_nwp:
del config["input_data"]["nwp"]
else:
for nwp_source in config["input_data"]["nwp"].keys():
nwp_config = config["input_data"]["nwp"][nwp_source]

if nwp_source not in model.nwp_encoders_dict:
# If not used, delete this source from the config
del config["input_data"]["nwp"][nwp_source]
else:
# Replace the image size
nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels
nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size
nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size

# Replace the forecast minutes
nwp_config["forecast_minutes"] = (
model.nwp_encoders_dict[nwp_source].sequence_length -
nwp_config["history_minutes"]/nwp_config["time_resolution_minutes"] - 1
)*nwp_config["time_resolution_minutes"]


if "satellite" in config["input_data"]:
if not model.include_sat:
del config["input_data"]["satellite"]
else:
sat_config = config["input_data"]["satellite"]

# Replace the image size
sat_pixel_size = model.sat_encoder.image_size_pixels
sat_config["satellite_image_size_pixels_height"] = sat_pixel_size
sat_config["satellite_image_size_pixels_width"] = sat_pixel_size

# Replace the satellite delay
sat_config["live_delay_minutes"] = model.min_sat_delay_minutes

if "pv" in config["input_data"]:
if not model.include_pv:
del config["input_data"]["pv"]

if "gsp" in config["input_data"]:
gsp_config = config["input_data"]["gsp"]

# Replace the forecast minutes
gsp_config["forecast_minutes"] = model.forecast_minutes

with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)



class PVNetModelHubMixin(PyTorchModelHubMixin):
Expand Down Expand Up @@ -207,7 +270,13 @@ def save_pretrained(

# Save cleaned datapipes configuration file
if data_config is not None:
make_clean_data_config(data_config, save_directory / DATA_CONFIG_NAME)
new_data_config_path = save_directory / DATA_CONFIG_NAME

# Replace the input filenames with place holders
make_clean_data_config(data_config, new_data_config_path)

# Taylor the data config to the model being saved
minimize_data_config(new_data_config_path, new_data_config_path, self)

# Creating and saving model card.
card_data = ModelCardData(language="en", license="mit", library_name="pytorch")
Expand Down Expand Up @@ -580,7 +649,7 @@ def on_validation_epoch_end(self):
"""Run on epoch end"""

horizon_maes_dict = self._horizon_maes.flush()

# Create the horizon accuracy curve
if isinstance(self.logger, pl.loggers.WandbLogger):
per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]
Expand Down
1 change: 1 addition & 0 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
self.embedding_dim = embedding_dim
self.add_image_embedding_channel = add_image_embedding_channel
self.interval_minutes = interval_minutes
self.min_sat_delay_minutes = min_sat_delay_minutes
self.adapt_batches = adapt_batches

super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pvnet.load_model import get_model_from_checkpoints

wandb_repo = "openclimatefix/pvnet2.1"
huggingface_repo = "openclimatefix/pvnet_uk_region"


def push_to_huggingface(
Expand Down Expand Up @@ -66,7 +67,7 @@ def push_to_huggingface(
data_config=data_config,
wandb_ids=wandb_ids,
push_to_hub=push_to_hub,
repo_id=wandb_repo if push_to_hub else None,
repo_id=huggingface_repo if push_to_hub else None,
)

if local_path is None:
Expand Down

0 comments on commit 05e05ce

Please sign in to comment.