Commit 09d61e2

Merge branch 'inference_subclass' into inference_subclass

CorentinSeznec authored Sep 17, 2024
2 parents 0266ead + 0f92f63 commit 09d61e2
Showing 9 changed files with 56 additions and 132 deletions.
39 changes: 0 additions & 39 deletions .gitlab-ci.yml

This file was deleted.

2 changes: 2 additions & 0 deletions bin/inference.py
@@ -30,6 +30,7 @@
else:
config_override = {"num_inference_pred_steps": args.infer_steps}


# Get dataset for inference
_, _, infer_ds = get_datasets(
args.dataset,
@@ -41,6 +42,7 @@
)

# Transform into a dataloader

dl_settings = TorchDataloaderSettings(batch_size=1)
infer_loader = infer_ds.torch_dataloader(dl_settings)
trainer = Trainer(devices="auto")
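For context, a minimal sketch of the inference entrypoint these hunks sit in. Only the `config_override` logic and the dataloader lines come from the diff; the argparse wiring and import paths are assumptions, and the real `get_datasets` call passes more arguments (elided in the diff view).

```python
# Minimal sketch of the inference flow around the hunks above.
# Import paths and argparse wiring are assumptions; the get_datasets
# signature is simplified compared to the real call.
from argparse import ArgumentParser

from pytorch_lightning import Trainer

from py4cast.datasets import get_datasets  # assumed import path
from py4cast.datasets.base import TorchDataloaderSettings  # assumed import path

parser = ArgumentParser(description="py4cast inference")
parser.add_argument("--dataset", type=str, default="poesy_infer")
parser.add_argument("--infer_steps", type=int, default=None)
args = parser.parse_args()

# Override the number of prediction steps only when explicitly requested.
if args.infer_steps is None:
    config_override = None
else:
    config_override = {"num_inference_pred_steps": args.infer_steps}

# Only the inference split is needed; train and validation slots are ignored.
_, _, infer_ds = get_datasets(args.dataset, config_override=config_override)

# Inference runs one sample at a time, hence batch_size=1.
dl_settings = TorchDataloaderSettings(batch_size=1)
infer_loader = infer_ds.torch_dataloader(dl_settings)
trainer = Trainer(devices="auto")
```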
7 changes: 7 additions & 0 deletions config/models/unetrpp8512_linear_up_d2.json
@@ -0,0 +1,7 @@
{
"num_heads": 8,
"hidden_size": 512,
"linear_upsampling": true,
"downsampling_rate": 2
}
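The keys in this new config mirror the `UNETRPPSettings` fields added later in this commit (see the unetrpp.py diff below). A sketch of how such a file could be loaded, assuming `UNETRPPSettings` is a dataclass accepting these keys; this is illustrative, not necessarily py4cast's actual config machinery.

```python
# Sketch: loading the config above into the UNETRPPSettings shown later
# in this commit. Assumes UNETRPPSettings is a dataclass; the loader is
# illustrative only.
import json
from pathlib import Path

from py4cast.models.vision.unetrpp import UNETRPPSettings

raw = json.loads(Path("config/models/unetrpp8512_linear_up_d2.json").read_text())
settings = UNETRPPSettings(**raw)

assert settings.downsampling_rate == 2  # stem now downsamples by 2, not 4
assert settings.linear_upsampling is True
```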

2 changes: 1 addition & 1 deletion doc/features.md
@@ -20,7 +20,7 @@ Currently we support the following neural network architectures:
| [segformer](../py4cast/models/vision/transformers.py#L216) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, Height, Width, features) | On par with u-net like on Deepsyg (MF internal), added an upsampling stage. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Frank Guibert |
| [swinunetr](../py4cast/models/vision/transformers.py#L335) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, Height, Width, features) | 2D Swin Unet transformer (Pangu and archweather use customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders have been modified to use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Frank Guibert |
| [hilam](../py4cast/models/nlam/models.py#L754), graphlam | [arxiv link](https://arxiv.org/abs/2309.17370) | (Batch, graph_node_id, features) | Imported and adapted from [Joel's github](https://github.com/joeloskarsson/neural-lam) | Vincent Chabot/Frank Guibert |
| [unetrpp](../py4cast/models/vision/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs | Frank Guibert |
| [unetrpp](../py4cast/models/vision/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs. Changed Upsampling to use linear upsampling. Made stem layer downsampling rate a parameter. | Frank Guibert |

## Available datasets

6 changes: 3 additions & 3 deletions py4cast/datasets/poesy.py
@@ -326,7 +326,6 @@ class InferSample(Sample):
def __post_init__(self):
self.terms = self.input_terms


class PoesyDataset(DatasetABC, Dataset):
def __init__(
self, grid: Grid, period: Period, params: List[Param], settings: PoesySettings
@@ -839,7 +838,7 @@ class InferPoesyDataset(PoesyDataset):
Inherits from the PoesyDataset class.
This class is used for inference; it overrides the sample_list and from_json methods.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@@ -903,6 +902,7 @@ def from_json(
num_pred_steps_train: int,
num_pred_steps_val_tests: int,
config_override: Union[Dict, None] = None,

) -> Tuple[None, None, "InferPoesyDataset"]:
"""
Return 1 InferPoesyDataset.
@@ -950,13 +950,13 @@
term=term,
num_input_steps=num_input_steps,
num_output_steps=0,

num_inference_pred_steps=config_override["num_inference_pred_steps"],
),
)
return None, None, ds



if __name__ == "__main__":

path_config = "config/datasets/poesy.json"
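A hedged sketch of how `InferPoesyDataset.from_json` would be called, based on the signature in the hunk above. The leading config-path argument and all numeric values are illustrative assumptions; only the keyword names come from the diff.

```python
# Illustrative call matching the from_json signature shown above.
# Path and numeric values are assumptions.
_, _, infer_ds = InferPoesyDataset.from_json(
    "config/datasets/poesy.json",
    num_input_steps=2,
    num_pred_steps_train=0,
    num_pred_steps_val_tests=0,
    config_override={"num_inference_pred_steps": 4},
)
# Returning (None, None, dataset) keeps the signature aligned with the
# (train, valid, infer) triple used elsewhere: inference has no splits.
```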
7 changes: 6 additions & 1 deletion py4cast/lightning.py
@@ -213,6 +213,7 @@ def __init__(self, hparams: ArLightningHyperParam, *args, **kwargs):
max_pred_step = self.hparams["hparams"].num_pred_steps_val_test - 1
self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step)
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)

@rank_zero_only
def log_hparams_tb(self):
@@ -453,9 +454,9 @@ def on_validation_start(self):
l1_loss.prepare(self, self.interior_mask, self.hparams["hparams"].dataset_info)
metrics = {"mae": l1_loss}
save_path = self.hparams["hparams"].save_path

self.rmse_metric = MetricRMSE()
self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)

self.valid_plotters = [
StateErrorPlot(metrics, prefix="Validation"),
PredictionTimestepPlot(
@@ -564,6 +565,7 @@ def on_test_start(self):
self.psd_plot_metric = MetricPSDK(save_path, pred_step=max_pred_step)
self.rmse_metric = MetricRMSE()
self.acc_metric = MetricACC(self.hparams["hparams"].dataset_info)

self.test_plotters = [
StateErrorPlot(metrics, save_path=save_path),
SpatialErrorPlot(),
@@ -584,11 +586,13 @@ def test_step(self, batch: ItemBatch, batch_idx):
# Notify plotters & metrics
for plotter in self.test_plotters:
plotter.update(self, prediction=prediction, target=target)

self.psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_psd_plot_metric.update(prediction, target, self.original_shape)
self.rmse_metric.update(prediction, target)
self.acc_metric.update(prediction, target)


@cached_property
def interior_2d(self) -> torch.Tensor:
"""
@@ -613,6 +617,7 @@ def on_test_epoch_end(self):
self.rmse_psd_plot_metric.compute()
self.rmse_metric.compute()
self.acc_metric.compute()

# Notify plotters that the test epoch end
for plotter in self.test_plotters:
plotter.on_step_end(self, label="Test")
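The metric changes above follow the standard torchmetrics lifecycle: declare state once, `update()` every step, `compute()` at epoch end. A self-contained illustration of that pattern (not py4cast code):

```python
# Standalone illustration of the torchmetrics lifecycle used above:
# state declared once, accumulated per step, reduced at epoch end.
import torch
from torchmetrics import Metric


class MeanAbsError(Metric):
    def __init__(self):
        super().__init__()
        # States are synced across ranks with the given reduction on compute().
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.total += (preds - target).abs().mean()
        self.count += 1

    def compute(self) -> torch.Tensor:
        return self.total / self.count


metric = MeanAbsError()          # instantiate once (__init__ / on_test_start)
for _ in range(3):               # one update per test_step
    metric.update(torch.rand(4), torch.rand(4))
print(metric.compute())          # reduce once, in on_test_epoch_end
```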
78 changes: 4 additions & 74 deletions py4cast/metrics.py
@@ -342,79 +342,6 @@ def power_spectral_density(x: np.ndarray) -> np.ndarray:
return out


class MetricRMSE(Metric):
"""
Compute the spatially averaged, per pred_step RMSE between the target and the prediction for each channels/features.
"""

def __init__(self):
super().__init__()

# Declaration of state, states are reset when self.reset() is called
# Sum MSE at each epoch
self.add_state(
"sum_rmse",
default=torch.tensor(0.0),
dist_reduce_fx="sum",
)

# Step counter, needed to compute RMSE at each epoch.
self.add_state("step_count", default=torch.tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: NamedTensor, target: NamedTensor):
"""
compute the RMSE between target and pred.
prediction/target: (B, pred_steps, N_grid, d_f) or (B, pred_steps, W, H, d_f)
called at each end of step
"""
if self.step_count == 0:
self.feature_names = preds.feature_names
self.pred_steps = preds.tensor.shape[1]

features = preds.tensor.shape[-1]
# a priori unknown number of spatial dims
# but they are all after pred_steps and before features
spatial_dims = tuple(preds.spatial_dim_idx)

if preds.tensor.shape != target.tensor.shape:
raise ValueError("preds and target must have the same shape")

res = torch.sqrt((preds.tensor - target.tensor) ** 2).mean(
dim=(0, *spatial_dims)
)

# Initialize sum_rmse as a tensor of dim (nb_features)
if not self.sum_rmse.ndim: # self.sum_rmse not yet initalized as a tensor
self.sum_rmse = torch.zeros(self.pred_steps, features, device=self.device)

# Add RMSE for this channel
self.sum_rmse += res.to(device=self.device)

# Increment step_count
self.step_count += 1

def compute(self) -> dict:
"""
Compute RMSE mean for each channels/features, return a dict.
Should be called at each epoch's end
"""
# Compute mean RMSE over an epoch
mean_rmse = self.sum_rmse / self.step_count
feature_names = self.feature_names

# dict {"feature_name" : RMSE mean}
metric_log_dict = {
f"val_rmse/{name}_step_{j}": mean_rmse[j, i]
for i, name in enumerate(feature_names)
for j in range(self.pred_steps)
}

# Reset metric's state
self.reset()

return metric_log_dict


class MetricACC(Metric):
"""
Compute the spatially averaged, per pred step Anomaly Correlation Coefficient between the target and the prediction
@@ -436,7 +363,7 @@ def __init__(self, dataset_info: DatasetInfo):
dataset_info.shortnames["input_output"]
+ dataset_info.shortnames["diagnostic"]
)
self.climate_means = dataset_info.stats.to_list("mean", names)
self.climate_means = dataset_info.stats.to_list("mean", names).to(self.device)

# adding sum of acc coefficient (to compute mean on each epoch)
self.add_state(
@@ -454,6 +381,9 @@ def update(self, preds: NamedTensor, target: NamedTensor):
prediction/target: (B, pred_steps, N_grid, d_f) or (B, pred_steps, W, H, d_f)
called at the end of each step
"""

if self.step_count == 0 and self.climate_means.device != preds.tensor.device:
self.climate_means = self.climate_means.to(preds.tensor.device)
if self.step_count == 0:
self.feature_names = preds.feature_names
self.pred_steps = preds.tensor.shape[1]
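The `.to(self.device)` and lazy re-homing added above address a common torchmetrics pitfall: `climate_means` is a plain attribute rather than state registered via `add_state`, so torchmetrics never moves it between devices automatically. A standalone sketch of the pattern, with placeholder shapes and values:

```python
# Sketch of the device pitfall fixed above: a plain attribute is not a
# registered torchmetrics state, so it stays on the device it was created on.
import torch


class AnomalyScore:
    def __init__(self, num_features: int = 10):
        self.climate_means = torch.zeros(num_features)  # lives on CPU

    def update(self, preds: torch.Tensor) -> torch.Tensor:
        # Re-home the buffer once if the incoming batch lives elsewhere (GPU).
        if self.climate_means.device != preds.device:
            self.climate_means = self.climate_means.to(preds.device)
        return preds - self.climate_means  # anomaly w.r.t. climate mean
```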
41 changes: 29 additions & 12 deletions py4cast/models/vision/unetrpp.py
@@ -304,7 +304,7 @@ def __init__(
in_channels=4,
dropout=0.0,
transformer_dropout_rate=0.1,
**kwargs,
downsampling_rate: int = 4,
):
super().__init__()

@@ -316,8 +316,8 @@
spatial_dims,
in_channels,
dims[0],
kernel_size=4,
stride=4,
kernel_size=downsampling_rate,
stride=downsampling_rate,
dropout=dropout,
conv_only=True,
),
@@ -530,6 +530,7 @@ class UNETRPPSettings:
do_ds = False
spatial_dims = 2
linear_upsampling: bool = False
downsampling_rate: int = 4


class UNETRPP(ModelABC, nn.Module):
@@ -578,33 +579,49 @@ def __init__(
raise KeyError(
f"Position embedding layer of type {settings.pos_embed} is not supported."
)

# The stem layer downsamples with stride=downsampling_rate and kernel_size=downsampling_rate,
# followed by 3 successive downsampling layers (k=2, stride=2)
dim_divider = (2**3) * settings.downsampling_rate
if settings.spatial_dims == 2:
self.feat_size = (input_shape[0] // 32, input_shape[1] // 32)
self.feat_size = (
input_shape[0] // dim_divider,
input_shape[1] // dim_divider,
)
else:
self.feat_size = (
input_shape[0] // 32,
input_shape[1] // 32,
input_shape[2] // 32,
input_shape[0] // dim_divider,
input_shape[1] // dim_divider,
input_shape[2] // dim_divider,
)

self.hidden_size = settings.hidden_size
self.spatial_dims = settings.spatial_dims
no_pixels = (input_shape[0] * input_shape[1]) // 16
# Number of pixels after stem layer
no_pixels = (input_shape[0] * input_shape[1]) // (
settings.downsampling_rate**2
)
encoder_input_size = [
no_pixels,
no_pixels // 4,
no_pixels // 16,
no_pixels // 64,
]
h_size = settings.hidden_size

self.unetr_pp_encoder = UnetrPPEncoder(
input_size=encoder_input_size,
dims=(h_size // 8, h_size // 4, h_size // 2, h_size),
dims=(
h_size // 8,
h_size // 4,
h_size // 2,
h_size,
),
proj_size=[64, 64, 64, 32],
depths=settings.depths,
num_heads=settings.num_heads,
spatial_dims=settings.spatial_dims,
in_channels=num_input_features,
downsampling_rate=settings.downsampling_rate,
)

self.encoder1 = UnetResBlock(
@@ -650,9 +667,9 @@ def __init__(
in_channels=settings.hidden_size // 8,
out_channels=settings.hidden_size // 16,
kernel_size=3,
upsample_kernel_size=4,
upsample_kernel_size=settings.downsampling_rate,
norm_name=settings.norm_name,
out_size=no_pixels * 16,
out_size=no_pixels * (settings.downsampling_rate**2),
conv_decoder=True,
linear_upsampling=settings.linear_upsampling,
)
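A short worked example of the shape bookkeeping introduced above, assuming a 512x512 input: the stem divides each spatial dim by `downsampling_rate`, and the three encoder stages halve them again each, hence `dim_divider = (2**3) * downsampling_rate`.

```python
# Worked example of the feat_size / no_pixels arithmetic in the hunks above.
def feat_size_2d(input_shape, downsampling_rate):
    # Stem: /downsampling_rate per dim; 3 encoder stages: /2 per dim each.
    dim_divider = (2**3) * downsampling_rate
    return (input_shape[0] // dim_divider, input_shape[1] // dim_divider)

print(feat_size_2d((512, 512), 4))  # (16, 16) -- matches the old // 32
print(feat_size_2d((512, 512), 2))  # (32, 32) -- with the d2 config above

# Token counts feeding the encoder: the stem leaves (H*W) // rate**2 pixels,
# and each later stage divides the count by 4 (2 per spatial dim).
no_pixels = (512 * 512) // 2**2
print([no_pixels, no_pixels // 4, no_pixels // 16, no_pixels // 64])
```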
6 changes: 4 additions & 2 deletions py4cast/observer.py
@@ -29,9 +29,11 @@ def gather(
Be careful if you are doing something else than plotting results.
"""
if isinstance(obj.trainer.strategy, SingleDeviceStrategy):
loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather)
loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather).cpu()
elif isinstance(obj.trainer.strategy, ParallelStrategy):
loss_tensor = obj.trainer.strategy.all_gather(tensor_to_gather).flatten(0, 1)
loss_tensor = (
obj.trainer.strategy.all_gather(tensor_to_gather).flatten(0, 1).cpu()
)
else:
raise TypeError(
f"Unknwon type {obj.trainer.strategy}. Know {SingleDeviceStrategy} and {ParallelStrategy}"
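The added `.cpu()` calls keep gathered tensors, which are only used for plotting, out of GPU memory. The shape logic of the parallel branch, reproduced with plain tensors; the leading world-size dimension on `all_gather` output is standard Lightning behaviour:

```python
# Shape/device logic of the gather above, reproduced with plain tensors.
# all_gather returns (world_size, batch, ...); flatten(0, 1) merges the two
# leading dims, and .cpu() keeps plot-only buffers off the GPU.
import torch

world_size, batch = 2, 4
gathered = torch.rand(world_size, batch, 3)  # stand-in for all_gather output
loss_tensor = gathered.flatten(0, 1).cpu()
print(loss_tensor.shape)  # torch.Size([8, 3])
```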
