Skip to content

Commit

Permalink
Fix ci (#15)
Browse files Browse the repository at this point in the history
Fixes CI
  • Loading branch information
colon3ltocard authored Sep 11, 2024
1 parent 4c58184 commit 5510c3b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 25 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements_dev.txt
- name: Lint with ruff
pip install -r requirements_lint.txt
pip install --timeout 1000 pyg-lib==0.4.0 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.2 torch-geometric==2.3.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html
- name: Lint
run: |
ruff check py4cast
./lint.sh .
- name: Integration Test with pytest
run: |
export PY4CAST_ROOTDIR=`pwd`
coverage run -p -m pytest tests/
coverage run -p bin/train.py --model halfunet --model_conf config/models/halfunet32.json --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
coverage run -p bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
coverage run -p bin/train.py --model halfunet --model_conf config/models/halfunet32.json --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1 --num_workers 1
coverage run -p bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1 --num_workers 1
coverage combine
coverage report --fail-under=60
coverage report --ignore-errors --fail-under=60
28 changes: 14 additions & 14 deletions py4cast/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from py4cast.datasets.base import DatasetInfo, ItemBatch, NamedTensor
from py4cast.losses import ScaledLoss, WeightedLoss
from py4cast.metrics import MetricACC, MetricPSDK, MetricPSDVar, MetricRMSE
from py4cast.metrics import MetricPSDK, MetricPSDVar
from py4cast.models import build_model_from_settings, get_model_kls_and_settings
from py4cast.models.base import expand_to_batch
from py4cast.observer import (
Expand Down Expand Up @@ -455,9 +455,9 @@ def on_validation_start(self):
l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info)
metrics = {"mae": l1_loss}
save_path = self.hparams["hparams"].save_path
#self.rmse_metric = MetricRMSE()
#self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)

# self.rmse_metric = MetricRMSE()
# self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
self.valid_plotters = [
StateErrorPlot(metrics, prefix="Validation"),
PredictionTimestepPlot(
Expand Down Expand Up @@ -513,8 +513,8 @@ def validation_step(self, batch: ItemBatch, batch_idx):
plotter.update(self, prediction=prediction, target=target)
self.psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
#self.rmse_metric.update(prediction, target)
#self.acc_metric.update(prediction, target)
# self.rmse_metric.update(prediction, target)
# self.acc_metric.update(prediction, target)

def on_validation_epoch_end(self):
"""
Expand All @@ -525,8 +525,8 @@ def on_validation_epoch_end(self):
dict_metrics = dict()
dict_metrics.update(self.psd_plot_metric.compute())
dict_metrics.update(self.rmse_psd_plot_metric.compute())
#dict_metrics.update(self.rmse_metric.compute())
#dict_metrics.update(self.acc_metric.compute())
# dict_metrics.update(self.rmse_metric.compute())
# dict_metrics.update(self.acc_metric.compute())
for name, elmnt in dict_metrics.items():
if isinstance(elmnt, matplotlib.figure.Figure):
self.logger.experiment.add_figure(f"{name}", elmnt, self.current_epoch)
Expand Down Expand Up @@ -564,8 +564,8 @@ def on_test_start(self):
max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
#self.rmse_metric = MetricRMSE()
#self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
# self.rmse_metric = MetricRMSE()
# self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
self.test_plotters = [
StateErrorPlot(metrics, save_path=save_path),
SpatialErrorPlot(),
Expand All @@ -589,8 +589,8 @@ def test_step(self, batch: ItemBatch, batch_idx):
plotter.update(self, prediction=prediction, target=target)
self.psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
#self.rmse_metric.update(prediction, target)
#self.acc_metric.update(prediction, target)
# self.rmse_metric.update(prediction, target)
# self.acc_metric.update(prediction, target)

@cached_property
def interior_2d(self) -> torch.Tensor:
Expand All @@ -614,8 +614,8 @@ def on_test_epoch_end(self):
# and: https://github.com/Lightning-AI/pytorch-lightning/issues/18803
self.psd_plot_metric.compute()
self.rmse_psd_plot_metric.compute()
#self.rmse_metric.compute()
#self.acc_metric.compute()
# self.rmse_metric.compute()
# self.acc_metric.compute()
# Notify plotters that the test epoch end
for plotter in self.test_plotters:
plotter.on_step_end(self)
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ wandb>=0.13.10
matplotlib>=3.7.0
scipy>=1.10.0
pytorch-lightning>=2.1.2
lightning
lightning==2.2.2
shapely>=2.0.1
networkx>=3.0
Cartopy>=0.22.0
Expand All @@ -12,14 +12,14 @@ tueplots>=0.0.8
plotly>=5.15.0
tornado>=6.3.3
cython>=3
cfgrib
cfgrib==0.9.14.0
dataclasses-json==0.6.4
xarray
xarray==2024.7.0
argparse-dataclass==2.0.0
tensorboard==2.13.0
typer==0.9.0
netCDF4==1.6.5
tensorboard-plugin-profile
tensorboard-plugin-profile==2.17.0
torch-tb-profiler==0.4.1
einops==0.7.0
torchinfo==1.8.0
Expand All @@ -29,7 +29,7 @@ coverage==7.6.1
onnx==1.16.1
onnxruntime==1.18.1
onnxruntime-gpu==1.18.1
onnxscript
onnxscript==0.1.0.dev20240905
monai==1.3.1
gif==23.3.0
scikit-image==0.24.0
Expand Down

0 comments on commit 5510c3b

Please sign in to comment.