Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 13, 2024
1 parent 05e05ce commit 2f74109
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):
"""Resave the data config and replace the filepaths with a placeholder.
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
Expand Down Expand Up @@ -77,26 +77,25 @@ def make_clean_data_config(input_path, output_path, placeholder="PLACEHOLDER"):

with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)




def minimize_data_config(input_path, output_path, model):
"""Strip out parts of the data config which aren't used by the model
Args:
input_path: Path to input datapipes configuration file
output_path: Location to save the output configuration file
"""
with open(input_path) as cfg:
config = yaml.load(cfg, Loader=yaml.FullLoader)

if "nwp" in config["input_data"]:
if not model.include_nwp:
del config["input_data"]["nwp"]
else:
for nwp_source in config["input_data"]["nwp"].keys():
nwp_config = config["input_data"]["nwp"][nwp_source]

if nwp_source not in model.nwp_encoders_dict:
# If not used, delete this source from the config
del config["input_data"]["nwp"][nwp_source]
Expand All @@ -105,43 +104,42 @@ def minimize_data_config(input_path, output_path, model):
nwp_pixel_size = model.nwp_encoders_dict[nwp_source].image_size_pixels
nwp_config["nwp_image_size_pixels_height"] = nwp_pixel_size
nwp_config["nwp_image_size_pixels_width"] = nwp_pixel_size

# Replace the forecast minutes
nwp_config["forecast_minutes"] = (
model.nwp_encoders_dict[nwp_source].sequence_length -
nwp_config["history_minutes"]/nwp_config["time_resolution_minutes"] - 1
)*nwp_config["time_resolution_minutes"]

model.nwp_encoders_dict[nwp_source].sequence_length
- nwp_config["history_minutes"] / nwp_config["time_resolution_minutes"]
- 1
) * nwp_config["time_resolution_minutes"]

if "satellite" in config["input_data"]:
if not model.include_sat:
del config["input_data"]["satellite"]
else:
sat_config = config["input_data"]["satellite"]

# Replace the image size
sat_pixel_size = model.sat_encoder.image_size_pixels
sat_config["satellite_image_size_pixels_height"] = sat_pixel_size
sat_config["satellite_image_size_pixels_width"] = sat_pixel_size

# Replace the satellite delay
sat_config["live_delay_minutes"] = model.min_sat_delay_minutes

if "pv" in config["input_data"]:
if not model.include_pv:
del config["input_data"]["pv"]

if "gsp" in config["input_data"]:
gsp_config = config["input_data"]["gsp"]

# Replace the forecast minutes
gsp_config["forecast_minutes"] = model.forecast_minutes

with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)



class PVNetModelHubMixin(PyTorchModelHubMixin):
"""
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
Expand Down Expand Up @@ -271,10 +269,10 @@ def save_pretrained(
# Save cleaned datapipes configuration file
if data_config is not None:
new_data_config_path = save_directory / DATA_CONFIG_NAME

# Replace the input filenames with place holders
make_clean_data_config(data_config, new_data_config_path)

# Taylor the data config to the model being saved
minimize_data_config(new_data_config_path, new_data_config_path, self)

Expand Down Expand Up @@ -649,7 +647,7 @@ def on_validation_epoch_end(self):
"""Run on epoch end"""

horizon_maes_dict = self._horizon_maes.flush()

# Create the horizon accuracy curve
if isinstance(self.logger, pl.loggers.WandbLogger):
per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]
Expand Down

0 comments on commit 2f74109

Please sign in to comment.