Skip to content

Commit

Permalink
Merge pull request #114 from openclimatefix/bug_fixes
Browse files Browse the repository at this point in the history
Two small bugs
  • Loading branch information
dfulu authored Dec 22, 2023
2 parents a2a4f7b + 298e10c commit 8dd0f33
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ def _calculate_val_losses(self, y, y_hat):
# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)

common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
common_metrics_each_step = common_metrics(
predictions=y_hat.cpu().numpy(), target=y.cpu().numpy()
)
mse_each_step = common_metrics_each_step["rmse"] ** 2
mae_each_step = common_metrics_each_step["mae"]

Expand Down
2 changes: 2 additions & 0 deletions pvnet/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def train(config: DictConfig) -> Optional[float]:
for callback in callbacks:
log.info(f"{callback}")
if isinstance(callback, ModelCheckpoint):
# Need to call the .experiment property to initialise the logger
wandb_logger.experiment
callback.dirpath = "/".join(
callback.dirpath.split("/")[:-1] + [wandb_logger.version]
)
Expand Down

0 comments on commit 8dd0f33

Please sign in to comment.