diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
deleted file mode 100644
index 6a32045e..00000000
--- a/.gitlab-ci.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-stages:
-  - lint
-  - test
-
-linting:
-  stage: lint
-  image: python:3.8.16
-  script:
-    - pip install --trusted-host pypi.org --trusted-host files.pythonhosted.org -r requirements_lint.txt
-    - ./lint.sh .
-
-complexity:
-  stage: lint
-  image: python:3.9.16
-  script:
-    - pip install --trusted-host pypi.org --trusted-host files.pythonhosted.org sourcery-analytics
-    - sourcery-analytics assess .
-
-unit_test:
-  stage: test
-  image: pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
-  script:
-    - pip install --timeout 1000 --trusted-host pypi.org --trusted-host files.pythonhosted.org -r requirements.txt
-    - pip install --timeout 1000 --trusted-host data.pyg.org --trusted-host pypi.org --trusted-host files.pythonhosted.org pyg-lib==0.4.0 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.2 torch-geometric==2.3.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html
-    - python -m pytest
-
-integration_test:
-  image: pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
-  tags:
-    - dind
-  variables:
-    PYTHONPATH: "${PYTHONPATH}:${CI_PROJECT_DIR}"
-  stage: test
-  script:
-    - pip install --timeout 1000 --trusted-host pypi.org --trusted-host files.pythonhosted.org -r requirements.txt
-    - pip install --timeout 1000 --trusted-host data.pyg.org --trusted-host pypi.org --trusted-host files.pythonhosted.org pyg-lib==0.4.0 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.2 torch-geometric==2.3.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html
-    - python bin/train.py --model halfunet --model_conf config/models/halfunet32.json --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
-    - python bin/train.py --model hilam --dataset dummy --epochs 1 --batch_size 1 --num_pred_steps_train 1 --limit_train_batches 1
-
diff --git a/README.md b/README.md
index dbaa5bd1..b68b0b38 100644
--- a/README.md
+++ b/README.md
@@ -304,23 +304,28 @@ You can find more details about all the `num_X_steps` options [here](doc/num_ste
 
 ### Inference
 
-Inference is done by running the `bin/inference.py` script. This script will load a model and run it on a dataset using the training parameters (dataset name, dataset config, timestep options, ...).
+Inference is done by running the `bin/inference.py` script. This script will load a model and run it on a dataset using the training parameters (dataset config, timestep options, ...).
 
 ```bash
-usage: py4cast Inference script [-h] [--model_path MODEL_PATH] [--dataset DATASET] [--ds_config_file DS_CONFIG_FILE] [--date DATE]
+usage: py4cast Inference script [-h] [--model_path MODEL_PATH] [--dataset DATASET] [--infer_steps INFER_STEPS] [--date DATE]
 
 options:
   -h, --help            show this help message and exit
   --model_path MODEL_PATH
                         Path to the model checkpoint
   --date DATE
+                        Date of the sample to infer on. Format: YYYYMMDDHH
+  --dataset DATASET
+                        Name of the dataset config file to use
+  --infer_steps INFER_STEPS
+                        Number of auto-regressive prediction steps during inference
 ```
 
 A simple example of inference is shown below:
 
 ```bash
- runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/comparison/titan/swinunetr/bert_swinunetr_0/last.ckpt --date 2023123123
+ runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/camp0/poesy/halfunet/sezn_run_dev_9/last.ckpt --date 2021061621 --dataset poesy_infer --infer_steps 2
 ```
diff --git a/bin/inference.py b/bin/inference.py
index d15f77a2..7e36657b 100644
--- a/bin/inference.py
+++ b/bin/inference.py
@@ -12,22 +12,30 @@
 parser = argparse.ArgumentParser("py4cast Inference script")
 parser.add_argument("--model_path", type=str, help="Path to the model checkpoint")
 parser.add_argument("--date", type=str, help="Date for inference", default=None)
+parser.add_argument(
+    "--dataset", type=str, help="Dataset used in inference", default="poesy_infer"
+)
+parser.add_argument(
+    "--infer_steps", type=int, help="Number of inference steps", default=1
+)
 args = parser.parse_args()
 
 # Load checkpoint
 lightning_module = AutoRegressiveLightning.load_from_checkpoint(args.model_path)
+hparams = lightning_module.hparams["hparams"]
 
 if args.date is not None:
-    config_override = {"periods": {"test": {"start": args.date, "end": args.date}}}
+    config_override = {
+        "periods": {"test": {"start": args.date, "end": args.date}},
+        "num_inference_pred_steps": args.infer_steps,
+    }
 else:
-    config_override = None
-
-hparams = lightning_module.hparams["hparams"]
+    config_override = {"num_inference_pred_steps": args.infer_steps}
 
 # Get dataset for inference
-_, _, test_ds = get_datasets(
-    hparams.dataset_name,
+_, _, infer_ds = get_datasets(
+    args.dataset,
     hparams.num_input_steps,
     hparams.num_pred_steps_train,
     hparams.num_pred_steps_val_test,
@@ -36,7 +44,8 @@
 )
 
 # Transform in dataloader
-dl_settings = TorchDataloaderSettings(batch_size=2)
-infer_loader = test_ds.torch_dataloader(dl_settings)
+
+dl_settings = TorchDataloaderSettings(batch_size=1)
+infer_loader = infer_ds.torch_dataloader(dl_settings)
 trainer = Trainer(devices="auto")
 preds = trainer.predict(lightning_module, infer_loader)
diff --git a/config/datasets/poesy_infer.json b/config/datasets/poesy_infer.json
new file mode 100644
index 00000000..96e56120
--- /dev/null
+++ b/config/datasets/poesy_infer.json
@@ -0,0 +1,65 @@
+{
+    "periods": {
+        "train": {
+            "start": 2020061521,
+            "end": 2021061521,
+            "step": 24
+        },
+        "test": {
+            "start": 2021061621,
+            "end": 2021071621,
+            "step": 24
+        },
+        "valid": {
+            "start": 2021071721,
+            "end": 2021111521,
+            "step": 24
+        }
+    },
+    "grid": {
+        "geometry": "EURW1S40",
+        "border_size": 10,
+        "domain": "france",
+        "model": "arome",
+        "subgrid": [50, 178, 50, 178]
+    },
+    "settings": {
+        "step_duration": 1,
+        "standardize": true,
+        "file_format": "npy"
+    },
+    "dataset": {
+        "arome": {
+            "members": [
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                13,
+                14,
+                15,
+                16
+            ],
+            "term": {
+                "start": 0,
+                "end": 44,
+                "timestep": 1
+            },
+            "var": {
+                "TEMPERATURE_2m": {"shortname": "t2m", "level": [1]},
+                "PRECIPITATIONS_DECUM": {"shortname": "rrdecum", "level": [1]},
+                "WIND.U.PHYS": {"shortname": "u", "level": [1]},
+                "WIND.V.PHYS": {"shortname": "v", "level": [1]}
+            }
+        }
+    }
+}
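Note: the `--date`/`--infer_steps` flags above take effect by deep-merging `config_override` into this JSON before the dataset is built. A minimal sketch of that behaviour, with a local `merge` standing in for `py4cast.utils.merge_dicts` (whose exact semantics are assumed here to be a recursive merge):

```python
# Local stand-in for py4cast.utils.merge_dicts (assumed recursive merge).
def merge(base: dict, override: dict) -> dict:
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

# "test" period from poesy_infer.json, and the override built by
# bin/inference.py for `--date 2021061621 --infer_steps 2`.
conf = {"periods": {"test": {"start": 2021061621, "end": 2021071621, "step": 24}}}
override = {
    "periods": {"test": {"start": "2021061621", "end": "2021061621"}},
    "num_inference_pred_steps": 2,
}

merged = merge(conf, override)
print(merged["periods"]["test"])           # start == end == the requested date
print(merged["num_inference_pred_steps"])  # 2
```

The test period thus collapses to the single requested date, while untouched keys such as `step` survive the merge.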
diff --git a/config/models/unetrpp8512_linear_up_d2.json b/config/models/unetrpp8512_linear_up_d2.json
new file mode 100644
index 00000000..95c25dcc
--- /dev/null
+++ b/config/models/unetrpp8512_linear_up_d2.json
@@ -0,0 +1,7 @@
+{
+"num_heads": 8,
+"hidden_size": 512,
+"linear_upsampling": true,
+"downsampling_rate": 2
+}
+
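For reference, a sketch of how such a model config presumably maps onto the `UNETRPPSettings` dataclass extended later in this diff. The dataclass below is a hypothetical, trimmed mirror (field defaults other than `downsampling_rate` are assumptions):

```python
import json
from dataclasses import dataclass


@dataclass
class UNETRPPSettings:  # trimmed, hypothetical mirror of the real settings class
    num_heads: int = 4
    hidden_size: int = 256
    linear_upsampling: bool = False
    downsampling_rate: int = 4


# Loading the config above overrides the defaults field by field.
with open("config/models/unetrpp8512_linear_up_d2.json") as fp:
    settings = UNETRPPSettings(**json.load(fp))

print(settings.hidden_size, settings.downsampling_rate)  # 512 2
```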
{traceback.format_exc()}") + try: from .dummy import DummyDataset diff --git a/py4cast/datasets/base.py b/py4cast/datasets/base.py index d171c29f..fcdc5319 100644 --- a/py4cast/datasets/base.py +++ b/py4cast/datasets/base.py @@ -305,7 +305,7 @@ class Item: inputs: NamedTensor forcing: NamedTensor - outputs: Union[None, NamedTensor] = None + outputs: NamedTensor def __post_init__(self): """ diff --git a/py4cast/datasets/poesy.py b/py4cast/datasets/poesy.py index 9eeea7a1..7c549266 100644 --- a/py4cast/datasets/poesy.py +++ b/py4cast/datasets/poesy.py @@ -26,6 +26,7 @@ ) from py4cast.plots import DomainInfo from py4cast.settings import CACHE_DIR +from py4cast.utils import merge_dicts SCRATCH_PATH = Path("/scratch/shared/poesy/poesy_crop") OROGRAPHY_FNAME = "PEARO_EURW1S40_Orography_crop.npy" @@ -175,7 +176,7 @@ class Param: # It is not necessarly the same as the model grid. # Function which can return the filenames. # It should accept member and date as argument (as well as term). - fnamer: Callable[[], [str]] + fnamer: Callable[[], [str]] # VSCode doesn't like this, is it ok ? level_type: str = "hPa" # To be read in nc file ? kind: Literal["input", "output", "input_output"] = "input_output" unit: str = "FakeUnit" # To be read in nc FIle ? @@ -240,7 +241,8 @@ def exist(self, date: dt.datetime) -> bool: class PoesySettings: term: dict num_input_steps: int # = 2 # Number of input timesteps - num_pred_steps: int # = 1 # Number of output timesteps + num_output_steps: int # = 1 # Number of output timesteps (= 0 for inference) + num_inference_pred_steps: int = 0 # 0 in training config ; else used to provide future information about forcings standardize: bool = False members: Tuple[int] = (0,) @@ -251,7 +253,7 @@ def num_total_steps(self) -> int: for one sample. """ # Nb of step in one sample - return self.num_input_steps + self.num_pred_steps + return self.num_input_steps + self.num_output_steps @dataclass(slots=True) @@ -312,8 +314,17 @@ def is_valid(self, param_list: List) -> bool: if not param.exist(self.date): return False - else: - return True + + return True + + +class InferSample(Sample): + """ + Sample dedicated to inference. No outputs terms, only inputs. 
+ """ + + def __post_init__(self): + self.terms = self.input_terms class PoesyDataset(DatasetABC, Dataset): @@ -403,10 +414,10 @@ def sample_list(self): + self.settings.num_input_steps : sample * self.settings.num_total_steps + self.settings.num_input_steps - + self.settings.num_pred_steps + + self.settings.num_output_steps ] samp = Sample( - settings=PoesySettings, + settings=self.settings, date=date, member=member, input_terms=input_terms, @@ -495,7 +506,12 @@ def diagnostic_dim(self): return res def get_param_tensor( - self, param: Param, date: dt.datetime, terms: List, member: int = 1 + self, + param: Param, + date: dt.datetime, + terms: List, + member: int = 1, + inference_steps: int = 0, ) -> torch.tensor: if self.settings.standardize: @@ -514,6 +530,10 @@ def get_param_tensor( # Define which value is considered invalid tensor_data = torch.from_numpy(array) + + if inference_steps: + empty_data = torch.empty((inference_steps, *array.shape[1:])) + tensor_data = torch.cat((tensor_data, empty_data), dim=0) return tensor_data def __getitem__(self, index): @@ -553,12 +573,15 @@ def __getitem__(self, index): if param.kind == "input_output": # Search data for date sample.date and terms sample.terms tensor = self.get_param_tensor( - param, sample.date, terms=sample.terms + param, + sample.date, + terms=sample.terms, + inference_steps=self.settings.num_inference_pred_steps, ) state_kwargs["names"][0] = "timestep" # Save outputs tmp_state = NamedTensor( - tensor=tensor[-self.settings.num_pred_steps :], + tensor=tensor[self.settings.num_input_steps :], **deepcopy(state_kwargs), ) loutputs.append(tmp_state) @@ -627,7 +650,7 @@ def from_json( PoesySettings( members=members, term=term, - num_pred_steps=num_pred_steps_train, + num_output_steps=num_pred_steps_train, num_input_steps=num_input_steps, ), ) @@ -638,7 +661,7 @@ def from_json( PoesySettings( members=members, term=term, - num_pred_steps=num_pred_steps_val_test, + num_output_steps=num_pred_steps_val_test, num_input_steps=num_input_steps, ), ) @@ -649,7 +672,7 @@ def from_json( PoesySettings( members=members, term=term, - num_pred_steps=num_pred_steps_val_test, + num_output_steps=num_pred_steps_val_test, num_input_steps=num_input_steps, ), ) @@ -811,6 +834,128 @@ def prepare(cls, path_config: Path): return train_ds +class InferPoesyDataset(PoesyDataset): + """ + Inherite from the PoesyDataset class. + This class is used for inference, the class overrides methods sample_list and from_json. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @cached_property + def sample_list(self): + """ + Create a list of sample from information. + Outputs terms are computed from the number of prediction steps wanted by the user. 
+ """ + print("Start forming samples") + terms = list( + np.arange( + self.settings.term["start"], + self.settings.term["end"], + self.settings.term["timestep"], + ) + ) + + sample_by_date = len(terms) // self.settings.num_total_steps + + samples = [] + number = 0 + + for date in self.period.date_list: + for member in self.settings.members: + for sample in range(0, sample_by_date): + + input_terms = terms[ + sample + * self.settings.num_total_steps : sample + * self.settings.num_total_steps + + self.settings.num_input_steps + ] + + output_terms = [ + input_terms[-1] + self.settings.term["timestep"] * (step + 1) + for step in range(self.settings.num_inference_pred_steps) + ] + + samp = InferSample( + settings=self.settings, + date=date, + member=member, + input_terms=input_terms, + output_terms=output_terms, + ) + + if samp.is_valid(self.params): + + samples.append(samp) + number += 1 + print("All samples are now defined") + + return samples + + @classmethod + def from_json( + cls, + fname: Path, + num_input_steps: int, + num_pred_steps_train: int, + num_pred_steps_val_tests: int, + config_override: Union[Dict, None] = None, + ) -> Tuple[None, None, "InferPoesyDataset"]: + """ + Return 1 InferPoesyDataset. + Override configuration file if needed. + """ + with open(fname, "r") as fp: + conf = json.load(fp) + if config_override is not None: + conf = merge_dicts(conf, config_override) + + grid = Grid(**conf["grid"]) + param_list = [] + for data_source in conf["dataset"]: + data = conf["dataset"][data_source] + members = conf["dataset"][data_source].get("members", [0]) + term = conf["dataset"][data_source]["term"] + for var in data["var"]: + vard = data["var"][var] + # Change grid definition + if "level" in vard: + level_type = "hPa" + else: + level_type = "m" + param = Param( + name=var, + shortname=vard.pop("shortname", "t2m"), + levels=vard.pop("level", [0]), + grid=grid, + level_type=level_type, + fnamer=poesy_forecast_namer, + **vard, + ) + param_list.append(param) + inference_period = ( + Period(**conf["periods"]["test"], name="infer") + if not config_override + else Period(**config_override["periods"]["test"], name="infer") + ) + ds = InferPoesyDataset( + grid, + inference_period, + param_list, + PoesySettings( + members=members, + term=term, + num_input_steps=num_input_steps, + num_output_steps=0, + num_inference_pred_steps=config_override["num_inference_pred_steps"], + ), + ) + return None, None, ds + + if __name__ == "__main__": path_config = "config/datasets/poesy.json" diff --git a/py4cast/lightning.py b/py4cast/lightning.py index 92486ca0..0ed55037 100644 --- a/py4cast/lightning.py +++ b/py4cast/lightning.py @@ -454,6 +454,8 @@ def on_validation_start(self): l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info) metrics = {"mae": l1_loss} save_path = self.hparams["hparams"].save_path + self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info) + self.valid_plotters = [ StateErrorPlot(metrics, prefix="Validation"), PredictionTimestepPlot( @@ -559,6 +561,7 @@ def on_test_start(self): self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step) self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step) self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info) + self.test_plotters = [ StateErrorPlot(metrics, save_path=save_path), SpatialErrorPlot(), @@ -579,9 +582,10 @@ def test_step(self, batch: ItemBatch, batch_idx): # Notify plotters & metrics for plotter in self.test_plotters: plotter.update(self, prediction=prediction, 
diff --git a/py4cast/lightning.py b/py4cast/lightning.py
index 92486ca0..0ed55037 100644
--- a/py4cast/lightning.py
+++ b/py4cast/lightning.py
@@ -454,6 +454,8 @@ def on_validation_start(self):
         l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info)
         metrics = {"mae": l1_loss}
         save_path = self.hparams["hparams"].save_path
+        self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
+
         self.valid_plotters = [
             StateErrorPlot(metrics, prefix="Validation"),
             PredictionTimestepPlot(
@@ -559,6 +561,7 @@ def on_test_start(self):
         self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
         self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
         self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)
+
         self.test_plotters = [
             StateErrorPlot(metrics, save_path=save_path),
             SpatialErrorPlot(),
@@ -579,9 +582,10 @@ def test_step(self, batch: ItemBatch, batch_idx):
         # Notify plotters & metrics
         for plotter in self.test_plotters:
             plotter.update(self, prediction=prediction, target=target)
-        self.acc_metric.update(prediction, target)
+
         self.psd_plot_metric.update(prediction, target, self.original_shape)
         self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
+        self.acc_metric.update(prediction, target)
 
     @cached_property
     def interior_2d(self) -> torch.Tensor:
@@ -606,6 +610,7 @@ def on_test_epoch_end(self):
         self.psd_plot_metric.compute()
         self.rmse_psd_plot_metric.compute()
         self.acc_metric.compute()
+
         # Notify plotters that the test epoch end
         for plotter in self.test_plotters:
             plotter.on_step_end(self, label="Test")
diff --git a/py4cast/models/vision/unetrpp.py b/py4cast/models/vision/unetrpp.py
index 7328e067..4492af62 100644
--- a/py4cast/models/vision/unetrpp.py
+++ b/py4cast/models/vision/unetrpp.py
@@ -304,7 +304,7 @@ def __init__(
         in_channels=4,
         dropout=0.0,
         transformer_dropout_rate=0.1,
-        **kwargs,
+        downsampling_rate: int = 4,
     ):
         super().__init__()
 
@@ -316,8 +316,8 @@
                 spatial_dims,
                 in_channels,
                 dims[0],
-                kernel_size=4,
-                stride=4,
+                kernel_size=downsampling_rate,
+                stride=downsampling_rate,
                 dropout=dropout,
                 conv_only=True,
             ),
@@ -530,6 +530,7 @@ class UNETRPPSettings:
     do_ds = False
     spatial_dims = 2
     linear_upsampling: bool = False
+    downsampling_rate: int = 4
 
 
 class UNETRPP(ModelABC, nn.Module):
@@ -578,19 +579,27 @@ def __init__(
             raise KeyError(
                 f"Position embedding layer of type {settings.pos_embed} is not supported."
             )
-
+        # We have first a stem layer with stride=downsampling_rate and kernel_size=downsampling_rate,
+        # followed by 3 successive downsampling layers (kernel_size=2, stride=2).
+        dim_divider = (2**3) * settings.downsampling_rate
         if settings.spatial_dims == 2:
-            self.feat_size = (input_shape[0] // 32, input_shape[1] // 32)
+            self.feat_size = (
+                input_shape[0] // dim_divider,
+                input_shape[1] // dim_divider,
+            )
         else:
             self.feat_size = (
-                input_shape[0] // 32,
-                input_shape[1] // 32,
-                input_shape[2] // 32,
+                input_shape[0] // dim_divider,
+                input_shape[1] // dim_divider,
+                input_shape[2] // dim_divider,
             )
         self.hidden_size = settings.hidden_size
         self.spatial_dims = settings.spatial_dims
 
-        no_pixels = (input_shape[0] * input_shape[1]) // 16
+        # Number of pixels after the stem layer
+        no_pixels = (input_shape[0] * input_shape[1]) // (
+            settings.downsampling_rate**2
+        )
         encoder_input_size = [
             no_pixels,
             no_pixels // 4,
@@ -598,13 +607,21 @@
             no_pixels // 64,
         ]
         h_size = settings.hidden_size
+
         self.unetr_pp_encoder = UnetrPPEncoder(
             input_size=encoder_input_size,
-            dims=(h_size // 8, h_size // 4, h_size // 2, h_size),
+            dims=(
+                h_size // 8,
+                h_size // 4,
+                h_size // 2,
+                h_size,
+            ),
+            proj_size=[64, 64, 64, 32],
             depths=settings.depths,
             num_heads=settings.num_heads,
             spatial_dims=settings.spatial_dims,
             in_channels=num_input_features,
+            downsampling_rate=settings.downsampling_rate,
         )
 
         self.encoder1 = UnetResBlock(
@@ -650,9 +667,9 @@
             in_channels=settings.hidden_size // 8,
             out_channels=settings.hidden_size // 16,
             kernel_size=3,
-            upsample_kernel_size=4,
+            upsample_kernel_size=settings.downsampling_rate,
             norm_name=settings.norm_name,
-            out_size=no_pixels * 16,
+            out_size=no_pixels * (settings.downsampling_rate**2),
             conv_decoder=True,
             linear_upsampling=settings.linear_upsampling,
         )
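A quick sanity check of the new feature-map arithmetic, assuming a hypothetical 512×512 input with the `downsampling_rate=2` from `config/models/unetrpp8512_linear_up_d2.json`:

```python
input_shape = (512, 512)  # hypothetical input grid
downsampling_rate = 2     # from unetrpp8512_linear_up_d2.json

# Stem divides each spatial dim by the rate; the 3 following downsampling
# layers divide by 2 each, hence the overall divider.
dim_divider = (2**3) * downsampling_rate
feat_size = (input_shape[0] // dim_divider, input_shape[1] // dim_divider)
print(feat_size)  # (32, 32)

# Token counts per encoder stage: pixels after the stem, then /4 per stage.
no_pixels = (input_shape[0] * input_shape[1]) // downsampling_rate**2
encoder_input_size = [no_pixels, no_pixels // 4, no_pixels // 16, no_pixels // 64]
print(encoder_input_size)  # [65536, 16384, 4096, 1024]
```

With the old hard-coded stem (rate 4) the divider was 32; halving the stem stride keeps twice the spatial resolution in every encoder stage, at the cost of 4× more tokens per stage.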
""" if isinstance(obj.trainer.strategy, SingleDeviceStrategy): - loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather) + loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather).cpu() elif isinstance(obj.trainer.strategy, ParallelStrategy): - loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather).flatten(0, 1) + loss_tensor = ( + obj.trainer.strategy.all_gather(tensor_to_gather).flatten(0, 1).cpu() + ) else: raise TypeError( f"Unknwon type {obj.trainer.strategy}. Know {SingleDeviceStrategy} and {ParallelStrategy}" ) + return loss_tensor + +def reduce( + obj: "AutoRegressiveLightning", tensor_to_reduce: torch.Tensor +) -> torch.Tensor: + """ + Send a tensor with the same dimension whether we are in Paralll or SingleDevice strategy. + Be careful if you are doing something else than plotting results. + """ + if isinstance(obj.trainer.strategy, SingleDeviceStrategy): + loss_tensor = obj.trainer.strategy.reduce( + tensor_to_reduce, reduce_op="mean" + ).cpu() + elif isinstance(obj.trainer.strategy, ParallelStrategy): + loss_tensor = ( + obj.trainer.strategy.reduce(tensor_to_reduce, reduce_op="mean") + .flatten(0, 1) + .cpu() + ) + else: + raise TypeError( + f"Unknwon type {obj.trainer.strategy}. Know {SingleDeviceStrategy} and {ParallelStrategy}" + ) return loss_tensor @@ -327,7 +352,9 @@ def update( Compute the metric. Append to a dictionnary """ for name in self.metrics: - self.losses[name].append(self.metrics[name](prediction, target).cpu()) + self.losses[name].append( + reduce(obj, self.metrics[name](prediction, target).unsqueeze(0)) + ) if not self.initialized: self.shortnames = prediction.feature_names self.units = [ @@ -341,9 +368,9 @@ def on_step_end(self, obj: "AutoRegressiveLightning", label: str = "") -> None: Make the summary figure """ tensorboard = obj.logger.experiment - for name in self.metrics: - loss_tensor = gather(obj, torch.cat(self.losses[name], dim=0)) - if obj.trainer.is_global_zero: + if obj.trainer.is_global_zero: + for name in self.metrics: + loss_tensor = torch.cat(self.losses[name], dim=0) loss = torch.mean(loss_tensor, dim=0) # Log metrics in tensorboard, with x axis as forecast timestep @@ -369,8 +396,8 @@ def on_step_end(self, obj: "AutoRegressiveLightning", label: str = "") -> None: dest_file.parent.mkdir(exist_ok=True) fig.savefig(dest_file) plt.close(fig) - # Free memory - [self.losses[name].clear() for name in self.metrics] + # Free memory + [self.losses[name].clear() for name in self.metrics] class SpatialErrorPlot(ErrorObserver): @@ -388,23 +415,23 @@ def update( prediction: NamedTensor, target: NamedTensor, ) -> None: - spatial_loss = obj.loss(prediction, target, reduce_spatial_dim=False).cpu() + spatial_loss = obj.loss(prediction, target, reduce_spatial_dim=False) # Getting only spatial loss for the required val_step_errors if prediction.num_spatial_dims == 1: spatial_loss = einops.rearrange( spatial_loss, "b t (x y) -> b t x y ", x=obj.grid_shape[0] ) - self.spatial_loss_maps.append(spatial_loss) # (B, N_log, N_lat, N_lon) + self.spatial_loss_maps.append( + reduce(obj, spatial_loss).unsqueeze(0) + ) # (B, N_log, N_lat, N_lon) def on_step_end(self, obj: "AutoRegressiveLightning", label: str = "") -> None: """ Make the summary figure """ - spatial_loss_tensor = gather( - obj, torch.cat(self.spatial_loss_maps, dim=0) - ) # (N_test, N_log, N_lat, N_lon) - + # (N_test, N_log, N_lat, N_lon) if obj.trainer.is_global_zero: + spatial_loss_tensor = torch.cat(self.spatial_loss_maps, dim=0) mean_spatial_loss = torch.mean( 
@@ -327,7 +352,9 @@ def update(
         Compute the metric. Append to a dictionnary
         """
         for name in self.metrics:
-            self.losses[name].append(self.metrics[name](prediction, target).cpu())
+            self.losses[name].append(
+                reduce(obj, self.metrics[name](prediction, target).unsqueeze(0))
+            )
         if not self.initialized:
             self.shortnames = prediction.feature_names
             self.units = [
@@ -341,9 +368,9 @@ def on_step_end(self, obj: "AutoRegressiveLightning", label: str = "") -> None:
         Make the summary figure
         """
         tensorboard = obj.logger.experiment
-        for name in self.metrics:
-            loss_tensor = gather(obj, torch.cat(self.losses[name], dim=0))
-            if obj.trainer.is_global_zero:
+        if obj.trainer.is_global_zero:
+            for name in self.metrics:
+                loss_tensor = torch.cat(self.losses[name], dim=0)
                 loss = torch.mean(loss_tensor, dim=0)
 
                 # Log metrics in tensorboard, with x axis as forecast timestep
@@ -369,8 +396,8 @@
                     dest_file.parent.mkdir(exist_ok=True)
                     fig.savefig(dest_file)
                     plt.close(fig)
-            # Free memory
-            [self.losses[name].clear() for name in self.metrics]
+                # Free memory
+                [self.losses[name].clear() for name in self.metrics]
 
 
 class SpatialErrorPlot(ErrorObserver):
@@ -388,23 +415,23 @@ def update(
         prediction: NamedTensor,
         target: NamedTensor,
     ) -> None:
-        spatial_loss = obj.loss(prediction, target, reduce_spatial_dim=False).cpu()
+        spatial_loss = obj.loss(prediction, target, reduce_spatial_dim=False)
         # Getting only spatial loss for the required val_step_errors
         if prediction.num_spatial_dims == 1:
             spatial_loss = einops.rearrange(
                 spatial_loss, "b t (x y) -> b t x y ", x=obj.grid_shape[0]
             )
-        self.spatial_loss_maps.append(spatial_loss)  # (B, N_log, N_lat, N_lon)
+        self.spatial_loss_maps.append(
+            reduce(obj, spatial_loss).unsqueeze(0)
+        )  # (B, N_log, N_lat, N_lon)
 
     def on_step_end(self, obj: "AutoRegressiveLightning", label: str = "") -> None:
         """
         Make the summary figure
         """
-        spatial_loss_tensor = gather(
-            obj, torch.cat(self.spatial_loss_maps, dim=0)
-        )  # (N_test, N_log, N_lat, N_lon)
-
+        # (N_test, N_log, N_lat, N_lon)
         if obj.trainer.is_global_zero:
+            spatial_loss_tensor = torch.cat(self.spatial_loss_maps, dim=0)
             mean_spatial_loss = torch.mean(
                 spatial_loss_tensor, dim=0
             )  # (N_log, N_lat, N_lon)