From 944787f49734ca7d87d93513b0582f7521377029 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Tue, 18 Jun 2024 09:26:41 +0000 Subject: [PATCH] minimal updates --- pvnet_summation/models/base_model.py | 7 +++++++ pvnet_summation/training.py | 2 ++ requirements.txt | 4 ++-- scripts/checkpoint_to_huggingface.py | 9 +++++---- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index 50e23a1..caec670 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -91,6 +91,8 @@ def __init__( ) else: self.pvnet_output_shape = (317, self.pvnet_model.forecast_len) + + self.use_weighted_loss = False def predict_pvnet_batch(self, batch): """Use PVNet model to create predictions for batch""" @@ -184,6 +186,11 @@ def validation_step(self, batch: dict, batch_idx): losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) + + # Store these to make horizon accuracy plot + self._horizon_maes.append( + {i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)} + ) logged_losses = {f"{k}/val": v for k, v in losses.items()} diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 73bcd2a..5c723fa 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -149,6 +149,8 @@ def train(config: DictConfig) -> Optional[float]: for callback in callbacks: log.info(f"{callback}") if isinstance(callback, ModelCheckpoint): + # Need to call the .experiment property to initialise the logger + wandb_logger.experiment callback.dirpath = "/".join( callback.dirpath.split("/")[:-1] + [wandb_logger.version] ) diff --git a/requirements.txt b/requirements.txt index c31f661..6436837 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -ocf_datapipes>=3.3.19 -pvnet>=3.0.25 +ocf_datapipes>=3.3.33 +pvnet>=3.0.45 numpy pandas matplotlib diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index 0a5c09c..9004762 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -1,5 +1,5 @@ """Command line tool to push locally save model checkpoints to huggingface - + use: python checkpoint_to_huggingface.py "path/to/model/checkpoints" \ --local-path="~/tmp/this_model" \ @@ -56,9 +56,9 @@ def push_to_huggingface( # Only one epoch (best) saved per model files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt") assert len(files) == 1 - checkpoint = torch.load(files[0]) + checkpoint = torch.load(files[0], map_location="cpu") else: - checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt") + checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu") model.load_state_dict(state_dict=checkpoint["state_dict"]) @@ -72,7 +72,8 @@ def push_to_huggingface( model.save_pretrained( model_output_dir, config=model_config, - wandb_model_code=wandb_id, + data_config=None, + wandb_ids=wandb_id, push_to_hub=push_to_hub, repo_id="openclimatefix/pvnet_v2_summation" if push_to_hub else None, card_template_path=(