diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 78cec83d..faa2c942 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -28,14 +28,17 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install -r requirements.txt
-        pip install -r requirements_dev.txt
-    - name: Lint with ruff
+        pip install -r requirements_lint.txt
+        pip install --timeout 1000 pyg-lib==0.4.0 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.2 torch-geometric==2.3.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html
+
+    - name: Lint
       run: |
-        ruff check py4cast
+        ./lint.sh .
     - name: Integration Test with pytest
       run: |
+        export PY4CAST_ROOTDIR=`pwd`
         coverage run -p -m pytest tests/
-        coverage run -p bin/train.py --model halfunet --model_conf config/models/halfunet32.json --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
-        coverage run -p bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
+        coverage run -p bin/train.py --model halfunet --model_conf config/models/halfunet32.json --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1 --num_workers 1
+        coverage run -p bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1 --num_workers 1
         coverage combine
-        coverage report --fail-under=60
+        coverage report --ignore-errors --fail-under=60
diff --git a/py4cast/lightning.py b/py4cast/lightning.py
index 08284a0d..a324337f 100644
--- a/py4cast/lightning.py
+++ b/py4cast/lightning.py
@@ -17,7 +17,7 @@
 from py4cast.datasets.base import DatasetInfo, ItemBatch, NamedTensor
 from py4cast.losses import ScaledLoss, WeightedLoss
-from py4cast.metrics import MetricACC, MetricPSDK, MetricPSDVar, MetricRMSE
+from py4cast.metrics import MetricPSDK, MetricPSDVar
 from py4cast.models import build_model_from_settings, get_model_kls_and_settings
 from py4cast.models.base import expand_to_batch
 from py4cast.observer import (
@@ -455,9 +455,9 @@ def on_validation_start(self):
         l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info)
         metrics = {"mae": l1_loss}
         save_path = self.hparams["hparams"].save_path
-
-        #self.rmse_metric = MetricRMSE()
-        #self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
+
+        # self.rmse_metric = MetricRMSE()
+        # self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
         self.valid_plotters = [
             StateErrorPlot(metrics, prefix="Validation"),
             PredictionTimestepPlot(
@@ -513,8 +513,8 @@ def validation_step(self, batch: ItemBatch, batch_idx):
             plotter.update(self, prediction=prediction, target=target)
         self.psd_plot_metric.update(prediction, target, self.original_shape)
         self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
-        #self.rmse_metric.update(prediction, target)
-        #self.acc_metric.update(prediction, target)
+        # self.rmse_metric.update(prediction, target)
+        # self.acc_metric.update(prediction, target)

     def on_validation_epoch_end(self):
         """
@@ -525,8 +525,8 @@ def on_validation_epoch_end(self):
         dict_metrics = dict()
         dict_metrics.update(self.psd_plot_metric.compute())
         dict_metrics.update(self.rmse_psd_plot_metric.compute())
-        #dict_metrics.update(self.rmse_metric.compute())
-        #dict_metrics.update(self.acc_metric.compute())
+        # dict_metrics.update(self.rmse_metric.compute())
+        # dict_metrics.update(self.acc_metric.compute())
         for name, elmnt in dict_metrics.items():
             if isinstance(elmnt, matplotlib.figure.Figure):
                 self.logger.experiment.add_figure(f"{name}", elmnt, self.current_epoch)
@@ -564,8 +564,8 @@ def on_test_start(self):
         max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
         self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
         self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
-        #self.rmse_metric = MetricRMSE()
-        #self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
+        # self.rmse_metric = MetricRMSE()
+        # self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
         self.test_plotters = [
             StateErrorPlot(metrics, save_path=save_path),
             SpatialErrorPlot(),
@@ -589,8 +589,8 @@ def test_step(self, batch: ItemBatch, batch_idx):
             plotter.update(self, prediction=prediction, target=target)
         self.psd_plot_metric.update(prediction, target, self.original_shape)
         self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
-        #self.rmse_metric.update(prediction, target)
-        #self.acc_metric.update(prediction, target)
+        # self.rmse_metric.update(prediction, target)
+        # self.acc_metric.update(prediction, target)

     @cached_property
     def interior_2d(self) -> torch.Tensor:
@@ -614,8 +614,8 @@ def on_test_epoch_end(self):
         # and: https://github.com/Lightning-AI/pytorch-lightning/issues/18803
         self.psd_plot_metric.compute()
         self.rmse_psd_plot_metric.compute()
-        #self.rmse_metric.compute()
-        #self.acc_metric.compute()
+        # self.rmse_metric.compute()
+        # self.acc_metric.compute()
         # Notify plotters that the test epoch end
         for plotter in self.test_plotters:
             plotter.on_step_end(self)
diff --git a/requirements.txt b/requirements.txt
index 39abc173..0e07e858 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ wandb>=0.13.10
 matplotlib>=3.7.0
 scipy>=1.10.0
 pytorch-lightning>=2.1.2
-lightning
+lightning==2.2.2
 shapely>=2.0.1
 networkx>=3.0
 Cartopy>=0.22.0
@@ -12,14 +12,14 @@ tueplots>=0.0.8
 plotly>=5.15.0
 tornado>=6.3.3
 cython>=3
-cfgrib
+cfgrib==0.9.14.0
 dataclasses-json==0.6.4
-xarray
+xarray==2024.7.0
 argparse-dataclass==2.0.0
 tensorboard==2.13.0
 typer==0.9.0
 netCDF4==1.6.5
-tensorboard-plugin-profile
+tensorboard-plugin-profile==2.17.0
 torch-tb-profiler==0.4.1
 einops==0.7.0
 torchinfo==1.8.0
@@ -29,7 +29,7 @@ coverage==7.6.1
 onnx==1.16.1
 onnxruntime==1.18.1
 onnxruntime-gpu==1.18.1
-onnxscript
+onnxscript==0.1.0.dev20240905
 monai==1.3.1
 gif==23.3.0
 scikit-image==0.24.0