Skip to content

Commit

Permalink
Temporary hotfix to be able to run xps
Browse files Browse the repository at this point in the history
  • Loading branch information
colon3ltocard committed Sep 10, 2024
1 parent 78f971b commit b3302a7
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions py4cast/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def __init__(self, hparams: ArLightningHyperParam, *args, **kwargs):

self.loss.prepare(self, statics.interior_mask, hparams.dataset_info)

save_path = self.hparams["hparams"].save_path
max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)

@rank_zero_only
def log_hparams_tb(self):
if self.logger:
Expand Down Expand Up @@ -450,11 +455,9 @@ def on_validation_start(self):
l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info)
metrics = {"mae": l1_loss}
save_path = self.hparams["hparams"].save_path
max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
self.rmse_metric = MetricRMSE()
self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)

#self.rmse_metric = MetricRMSE()
#self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
self.valid_plotters = [
StateErrorPlot(metrics, prefix="Validation"),
PredictionTimestepPlot(
Expand Down Expand Up @@ -510,8 +513,8 @@ def validation_step(self, batch: ItemBatch, batch_idx):
plotter.update(self, prediction=prediction, target=target)
self.psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_metric.update(prediction, target)
self.acc_metric.update(prediction, target)
#self.rmse_metric.update(prediction, target)
#self.acc_metric.update(prediction, target)

def on_validation_epoch_end(self):
"""
Expand All @@ -522,8 +525,8 @@ def on_validation_epoch_end(self):
dict_metrics = dict()
dict_metrics.update(self.psd_plot_metric.compute())
dict_metrics.update(self.rmse_psd_plot_metric.compute())
dict_metrics.update(self.rmse_metric.compute())
dict_metrics.update(self.acc_metric.compute())
#dict_metrics.update(self.rmse_metric.compute())
#dict_metrics.update(self.acc_metric.compute())
for name, elmnt in dict_metrics.items():
if isinstance(elmnt, matplotlib.figure.Figure):
self.logger.experiment.add_figure(f"{name}", elmnt, self.current_epoch)
Expand Down Expand Up @@ -561,8 +564,8 @@ def on_test_start(self):
max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
self.rmse_metric = MetricRMSE()
self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
#self.rmse_metric = MetricRMSE()
#self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
self.test_plotters = [
StateErrorPlot(metrics, save_path=save_path),
SpatialErrorPlot(),
Expand All @@ -586,8 +589,8 @@ def test_step(self, batch: ItemBatch, batch_idx):
plotter.update(self, prediction=prediction, target=target)
self.psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_metric.update(prediction, target)
self.acc_metric.update(prediction, target)
#self.rmse_metric.update(prediction, target)
#self.acc_metric.update(prediction, target)

@cached_property
def interior_2d(self) -> torch.Tensor:
Expand All @@ -611,8 +614,8 @@ def on_test_epoch_end(self):
# and: https://github.com/Lightning-AI/pytorch-lightning/issues/18803
self.psd_plot_metric.compute()
self.rmse_psd_plot_metric.compute()
self.rmse_metric.compute()
self.acc_metric.compute()
#self.rmse_metric.compute()
#self.acc_metric.compute()
# Notify plotters that the test epoch end
for plotter in self.test_plotters:
plotter.on_step_end(self)

0 comments on commit b3302a7

Please sign in to comment.