From e5c245f21d01ca7379ee1a47a3de448cc02e8a59 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 5 May 2024 20:21:36 +0200
Subject: [PATCH 001/273] yaml_config for cosmo data
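The config groups zarr sources, variable lists, vertical levels, data splits and
the map projection in one file. A minimal sketch of reading it with PyYAML,
assuming the working directory is the repository root:

    # Minimal sketch: load the new config and inspect a few of its fields.
    import yaml

    with open("neural_lam/data_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Zarr archive holding the forecast state and its dimension names
    print(config["zarrs"]["state"]["path"], config["zarrs"]["state"]["dims"])
    # Surface and atmosphere variables that make up the model state
    print(config["state"]["surface"], config["state"]["atmosphere"])
    # Projection class name and kwargs, meant to be passed to cartopy.crs
    print(config["projection"]["class"], config["projection"]["kwargs"])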
---
neural_lam/data_config.yaml | 130 ++++++++++++++++++++++++++++++++++++
1 file changed, 130 insertions(+)
create mode 100644 neural_lam/data_config.yaml
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
new file mode 100644
index 00000000..0aedd7fe
--- /dev/null
+++ b/neural_lam/data_config.yaml
@@ -0,0 +1,130 @@
+zarrs: # List of zarrs containing fields related to state
+ state:
+ path: /scratch/sadamov/template.zarr # Path to zarr
+ dims: # Name of dimensions in zarr, to be used for indexing
+ time: time
+ level: z
+ x: x # Either give "grid" (flattened) dimension or "x" and "y"
+ y: y
+ static:
+ path: /scratch/sadamov/template.zarr
+ dims:
+ level: z
+ x: x
+ y: y
+ forcing:
+ path: /scratch/sadamov/template.zarr
+ dims:
+ time: time
+ level: z
+ x: x
+ y: y
+ boundary:
+ zarrs:
+ mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.
+state: # Variables forecasted by the model
+ surface: # Single-field variables
+ - CLCT
+ - PMSL
+ - PS
+ - T_2M
+ - TOT_PREC
+ - U_10M
+ - V_10M
+ surface_units:
+ - "%"
+ - Pa
+ - Pa
+ - K
+ - kg/m^2
+ - m/s
+ - m/s
+ atmosphere: # Variables with vertical levels
+ - PP
+ - QV
+ - RELHUM
+ - T
+ - U
+ - V
+ - W
+ atmosphere_units:
+ - Pa
+ - kg/kg
+ - "%"
+ - K
+ - m/s
+ - m/s
+ - Pa/s
+ levels: # Levels to use for atmosphere variables
+ - 0
+ - 5
+ - 8
+ - 11
+ - 13
+ - 15
+ - 19
+ - 22
+ - 26
+ - 30
+ - 38
+ - 44
+ - 59
+static: # Static inputs
+ surface:
+ - HSURF
+ surface_units:
+ - m
+ atmosphere:
+ - FI
+ atmosphere_units:
+ - m^2/s^2
+ levels:
+ - 0
+ - 5
+ - 8
+ - 11
+ - 13
+ - 15
+ - 19
+ - 22
+ - 26
+ - 30
+ - 38
+ - 44
+ - 59
+forcing: # Forcing variables, dynamic inputs to the model
+ surface:
+ - ASOB_S
+ surface_units:
+ - W/m^2
+ atmosphere:
+ atmosphere_units:
+ levels:
+boundary: # Boundary conditions
+ surface:
+ surface_units:
+ atmosphere:
+ atmosphere_units:
+ levels:
+lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells
+ lat: lat
+ lon: lon
+grid_shape:
+ x: 582
+ y: 390
+splits:
+ train:
+ start: 2015-01-01T00
+ end: 2024-12-31T23
+ val:
+ start: 2015-01-01T00
+ end: 2024-12-31T23
+ test:
+ start: 2015-01-01T00
+ end: 2024-12-31T23
+projection:
+ class: RotatedPole # Name of class in cartopy.crs
+ kwargs: # Parsed and used directly as kwargs to projection-class above
+ pole_longitude: 10.0
+ pole_latitude: -43.0
+normalization_zarr: data/meps_example/norm.zarr
From 33e7ecf5fdf1949e5f055f465a337a40076ad91c Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 5 May 2024 20:21:55 +0200
Subject: [PATCH 002/273] initial version of single zarr dataset
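ConfigLoader wraps nested dicts so values chain through attribute and item
access, and WeatherDataset slices the concatenated state/static array into
initial states, target states and forcings. A sketch of the intended usage,
assuming the zarr paths in the default config are reachable:

    # Sketch of the intended usage; assumes the zarrs referenced in
    # neural_lam/data_config.yaml are reachable from this machine.
    from neural_lam.weather_dataset import ConfigLoader, WeatherDataset

    config = ConfigLoader("neural_lam/data_config.yaml")
    print(config.zarrs["state"].path)    # nested access via __getattr__/__getitem__
    print(config.splits["train"].start)  # split boundaries used by the dataset

    ds = WeatherDataset(split="train", ar_steps=3)
    init_states, target_states, forcings, batch_times = ds[0]
    # init_states:   (2, N_grid, d_features)        two initial states
    # target_states: (ar_steps - 2, N_grid, d_features)
    # forcings:      (ar_steps, N_grid, d_forcing)
    print(init_states.shape, target_states.shape, batch_times)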
---
neural_lam/weather_dataset.py | 437 ++++++++++++++++------------------
1 file changed, 205 insertions(+), 232 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index eeefc313..8c9e0072 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,260 +1,233 @@
# Standard library
-import datetime as dt
-import glob
import os
# Third-party
-import numpy as np
+import pytorch_lightning as pl
import torch
+import xarray as xr
+import yaml
-# First-party
-from neural_lam import constants, utils
+
+class ConfigLoader:
+ """
+ Class for loading configuration files.
+
+ This class loads a YAML configuration file and provides a way to access
+ its values as attributes.
+ """
+
+ def __init__(self, config_path, values=None):
+ self.config_path = config_path
+ if values is None:
+ self.values = self.load_config()
+ else:
+ self.values = values
+
+ def load_config(self):
+ with open(self.config_path, "r") as file:
+ return yaml.safe_load(file)
+
+ def __getattr__(self, name):
+ keys = name.split(".")
+ value = self.values
+ for key in keys:
+ if key in value:
+ value = value[key]
+ else:
+ None
+ if isinstance(value, dict):
+ return ConfigLoader(None, values=value)
+ return value
+
+ def __getitem__(self, key):
+ value = self.values[key]
+ if isinstance(value, dict):
+ return ConfigLoader(None, values=value)
+ return value
+
+ def __contains__(self, key):
+ return key in self.values
class WeatherDataset(torch.utils.data.Dataset):
"""
- For our dataset:
- N_t' = 65
- N_t = 65//subsample_step (= 21 for 3h steps)
- dim_x = 268
- dim_y = 238
- N_grid = 268x238 = 63784
- d_features = 17 (d_features' = 18)
- d_forcing = 5
+ Dataset class for weather data.
+
+ This class loads and processes weather data from zarr files based on the
+ provided configuration. It supports splitting the data into train,
+ validation, and test sets.
"""
+ def process_dataset(self, dataset_name):
+ """
+ Process a single dataset specified by the dataset name.
+
+ Args:
+ dataset_name (str): Name of the dataset to process.
+
+ Returns:
+ xarray.Dataset: Processed dataset.
+ """
+
+ dataset_path = os.path.join(self.config_loader.zarrs[dataset_name].path)
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+
+ start, end = self.config_loader.splits[self.split].start, self.config_loader.splits[self.split].end
+ dataset = dataset.sel(time=slice(start, end))
+ dataset = dataset.rename_dims(
+ {v: k for k, v in self.config_loader.zarrs[dataset_name].dims.values.items()
+ if k not in dataset.dims})
+ if 'grid' not in dataset.dims:
+ dataset = dataset.stack(grid=('x', 'y'))
+
+ vars_surface = []
+ if self.config_loader[dataset_name].surface:
+ vars_surface = dataset[self.config_loader[dataset_name].surface]
+
+ vars_atmosphere = []
+ if self.config_loader[dataset_name].atmosphere:
+ vars_atmosphere = xr.merge(
+ [dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
+ for var in self.config_loader[dataset_name].atmosphere
+ for level in self.config_loader[dataset_name].levels])
+
+ if vars_surface and vars_atmosphere:
+ dataset = xr.merge([vars_surface, vars_atmosphere])
+ elif vars_surface:
+ dataset = vars_surface
+ elif vars_atmosphere:
+ dataset = vars_atmosphere
+ else:
+ raise ValueError(f"No variables specified for dataset: {dataset_name}")
+
+ dataset = dataset.squeeze(drop=True).to_array()
+ if "time" in dataset.dims:
+ dataset = dataset.transpose("time", "grid", "variable")
+ else:
+ dataset = dataset.transpose("grid", "variable")
+ return dataset
+
def __init__(
self,
- dataset_name,
- pred_length=19,
split="train",
- subsample_step=3,
- standardize=True,
- subset=False,
+ batch_size=4,
+ ar_steps=3,
control_only=False,
+ yaml_path="neural_lam/data_config.yaml",
):
super().__init__()
- assert split in ("train", "val", "test"), "Unknown dataset split"
- self.sample_dir_path = os.path.join(
- "data", dataset_name, "samples", split
- )
+ assert split in (
+ "train",
+ "val",
+ "test",
+ ), "Unknown dataset split"
- member_file_regexp = (
- "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy"
- )
- sample_paths = glob.glob(
- os.path.join(self.sample_dir_path, member_file_regexp)
- )
- self.sample_names = [path.split("/")[-1][4:-4] for path in sample_paths]
- # Now on form "yyymmddhh_mbrXXX"
-
- if subset:
- self.sample_names = self.sample_names[:50] # Limit to 50 samples
-
- self.sample_length = pred_length + 2 # 2 init states
- self.subsample_step = subsample_step
- self.original_sample_length = (
- 65 // self.subsample_step
- ) # 21 for 3h steps
- assert (
- self.sample_length <= self.original_sample_length
- ), "Requesting too long time series samples"
-
- # Set up for standardization
- self.standardize = standardize
- if standardize:
- ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
- self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
- ds_stats["data_mean"],
- ds_stats["data_std"],
- ds_stats["flux_mean"],
- ds_stats["flux_std"],
- )
+ self.split = split
+ self.batch_size = batch_size
+ self.ar_steps = ar_steps
+ self.control_only = control_only
+ self.config_loader = ConfigLoader(yaml_path)
- # If subsample index should be sampled (only duing training)
- self.random_subsample = split == "train"
+ self.state = self.process_dataset("state")
+ self.static = self.process_dataset("static")
+ self.forcings = self.process_dataset("forcing")
+ # self.boundary = self.process_dataset("boundary")
+
+ self.static = self.static.expand_dims({"time": self.state.time}, axis=0)
+ self.ds = xr.concat([self.state, self.static], dim="variable")
def __len__(self):
- return len(self.sample_names)
+ return len(self.ds.time) - self.ar_steps
def __getitem__(self, idx):
- # === Sample ===
- sample_name = self.sample_names[idx]
- sample_path = os.path.join(
- self.sample_dir_path, f"nwp_{sample_name}.npy"
- )
- try:
- full_sample = torch.tensor(
- np.load(sample_path), dtype=torch.float32
- ) # (N_t', dim_x, dim_y, d_features')
- except ValueError:
- print(f"Failed to load {sample_path}")
-
- # Only use every ss_step:th time step, sample which of ss_step
- # possible such time series
- if self.random_subsample:
- subsample_index = torch.randint(0, self.subsample_step, ()).item()
- else:
- subsample_index = 0
- subsample_end_index = self.original_sample_length * self.subsample_step
- sample = full_sample[
- subsample_index : subsample_end_index : self.subsample_step
- ]
- # (N_t, dim_x, dim_y, d_features')
-
- # Remove feature 15, "z_height_above_ground"
- sample = torch.cat(
- (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3
- ) # (N_t, dim_x, dim_y, d_features)
-
- # Accumulate solar radiation instead of just subsampling
- rad_features = full_sample[:, :, :, 2:4] # (N_t', dim_x, dim_y, 2)
- # Accumulate for first time step
- init_accum_rad = torch.sum(
- rad_features[: (subsample_index + 1)], dim=0, keepdim=True
- ) # (1, dim_x, dim_y, 2)
- # Accumulate for rest of subsampled sequence
- in_subsample_len = (
- subsample_end_index - self.subsample_step + subsample_index + 1
- )
- rad_features_in_subsample = rad_features[
- (subsample_index + 1) : in_subsample_len
- ] # (N_t*, dim_x, dim_y, 2), N_t* = (N_t-1)*ss_step
- _, dim_x, dim_y, _ = sample.shape
- rest_accum_rad = torch.sum(
- rad_features_in_subsample.view(
- self.original_sample_length - 1,
- self.subsample_step,
- dim_x,
- dim_y,
- 2,
- ),
- dim=1,
- ) # (N_t-1, dim_x, dim_y, 2)
- accum_rad = torch.cat(
- (init_accum_rad, rest_accum_rad), dim=0
- ) # (N_t, dim_x, dim_y, 2)
- # Replace in sample
- sample[:, :, :, 2:4] = accum_rad
-
- # Flatten spatial dim
- sample = sample.flatten(1, 2) # (N_t, N_grid, d_features)
-
- # Uniformly sample time id to start sample from
- init_id = torch.randint(
- 0, 1 + self.original_sample_length - self.sample_length, ()
+ sample = self.ds.isel(time=slice(idx, idx + self.ar_steps))
+ forcings = self.forcings.isel(time=slice(idx, idx + self.ar_steps))
+ sample = torch.tensor(sample.values, dtype=torch.float32)
+ forcings = torch.tensor(forcings.values, dtype=torch.float32)
+
+ init_states = sample[:2]
+ target_states = sample[2:]
+
+ batch_times = self.ds.isel(
+ time=slice(
+ idx,
+ idx +
+ self.ar_steps)).time.values.astype(str).tolist()
+
+ # init_states: (2, N_grid, d_features)
+ # target_states: (ar_steps-2, N_grid, d_features)
+ # forcings: (ar_steps, N_grid, d_windowed_forcings)
+ # batch_times: (ar_steps,)
+ return init_states, target_states, forcings, batch_times
+
+
+class WeatherDataModule(pl.LightningDataModule):
+ """DataModule for weather data."""
+
+ def __init__(
+ self,
+ batch_size=4,
+ num_workers=16,
+ ):
+ super().__init__()
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ def setup(self, stage=None):
+ if stage == "fit" or stage is None:
+ self.train_dataset = WeatherDataset(
+ split="train",
+ batch_size=self.batch_size,
+ )
+ self.val_dataset = WeatherDataset(
+ split="val",
+ batch_size=self.batch_size,
+ )
+
+ if stage == "test" or stage is None:
+ self.test_dataset = WeatherDataset(
+ split="test",
+ batch_size=self.batch_size,
+ )
+
+ def train_dataloader(self):
+ """Load train dataset."""
+ return torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
)
- sample = sample[init_id : (init_id + self.sample_length)]
- # (sample_length, N_grid, d_features)
-
- if self.standardize:
- # Standardize sample
- sample = (sample - self.data_mean) / self.data_std
-
- # Split up sample in init. states and target states
- init_states = sample[:2] # (2, N_grid, d_features)
- target_states = sample[2:] # (sample_length-2, N_grid, d_features)
-
- # === Forcing features ===
- # Now batch-static features are just part of forcing,
- # repeated over temporal dimension
- # Load water coverage
- sample_datetime = sample_name[:10]
- water_path = os.path.join(
- self.sample_dir_path, f"wtr_{sample_datetime}.npy"
+
+ def val_dataloader(self):
+ """Load validation dataset."""
+ return torch.utils.data.DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
)
- water_cover_features = torch.tensor(
- np.load(water_path), dtype=torch.float32
- ).unsqueeze(
- -1
- ) # (dim_x, dim_y, 1)
- # Flatten
- water_cover_features = water_cover_features.flatten(0, 1) # (N_grid, 1)
- # Expand over temporal dimension
- water_cover_expanded = water_cover_features.unsqueeze(0).expand(
- self.sample_length - 2, -1, -1 # -2 as added on after windowing
- ) # (sample_len, N_grid, 1)
-
- # TOA flux
- flux_path = os.path.join(
- self.sample_dir_path,
- f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy",
+
+ def test_dataloader(self):
+ """Load test dataset."""
+ return torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
)
- flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze(
- -1
- ) # (N_t', dim_x, dim_y, 1)
-
- if self.standardize:
- flux = (flux - self.flux_mean) / self.flux_std
-
- # Flatten and subsample flux forcing
- flux = flux.flatten(1, 2) # (N_t, N_grid, 1)
- flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1)
- flux = flux[
- init_id : (init_id + self.sample_length)
- ] # (sample_len, N_grid, 1)
-
- # Time of day and year
- dt_obj = dt.datetime.strptime(sample_datetime, "%Y%m%d%H")
- dt_obj = dt_obj + dt.timedelta(
- hours=2 + subsample_index
- ) # Offset for first index
- # Extract for initial step
- init_hour_in_day = dt_obj.hour
- start_of_year = dt.datetime(dt_obj.year, 1, 1)
- init_seconds_into_year = (dt_obj - start_of_year).total_seconds()
-
- # Add increments for all steps
- hour_inc = (
- torch.arange(self.sample_length) * self.subsample_step
- ) # (sample_len,)
- hour_of_day = (
- init_hour_in_day + hour_inc
- ) # (sample_len,), Can be > 24 but ok
- second_into_year = (
- init_seconds_into_year + hour_inc * 3600
- ) # (sample_len,)
- # can roll over to next year, ok because periodicity
-
- # Encode as sin/cos
- hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,)
- year_angle = (
- (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
- ) # (sample_len,)
- datetime_forcing = torch.stack(
- (
- torch.sin(hour_angle),
- torch.cos(hour_angle),
- torch.sin(year_angle),
- torch.cos(year_angle),
- ),
- dim=1,
- ) # (N_t, 4)
- datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1]
- datetime_forcing = datetime_forcing.unsqueeze(1).expand(
- -1, flux.shape[1], -1
- ) # (sample_len, N_grid, 4)
-
- # Put forcing features together
- forcing_features = torch.cat(
- (flux, datetime_forcing), dim=-1
- ) # (sample_len, N_grid, d_forcing)
-
- # Combine forcing over each window of 3 time steps
- forcing_windowed = torch.cat(
- (
- forcing_features[:-2],
- forcing_features[1:-1],
- forcing_features[2:],
- ),
- dim=2,
- ) # (sample_len-2, N_grid, 3*d_forcing)
- # Now index 0 of ^ corresponds to forcing at index 0-2 of sample
-
- # batch-static water cover is added after windowing,
- # as it is static over time
- forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2)
- # (sample_len-2, N_grid, forcing_dim)
-
- return init_states, target_states, forcing
+
+
+data_module = WeatherDataModule(batch_size=4, num_workers=0)
+data_module.setup()
+train_dataloader = data_module.train_dataloader()
+for batch in train_dataloader:
+ print(batch[0].shape)
+ print(batch[1].shape)
+ print(batch[2].shape)
+ print(batch[3])
+ break
From 9936e3bcb9b32a81908af3cdde196b89a5e3b5ee Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Mon, 6 May 2024 23:12:15 +0200
Subject: [PATCH 003/273] handling None zarrs
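With this change a zarr that is unset or missing on disk no longer aborts
dataset construction: process_dataset returns None and __getitem__ substitutes
an empty tensor for that input. A sketch of the resulting behaviour, assuming
the boundary path stays empty as in the updated config and the other zarr
paths are reachable:

    # Sketch: an unset or missing zarr path yields None from process_dataset,
    # and __getitem__ falls back to an empty tensor for that entry.
    from neural_lam.weather_dataset import WeatherDataset

    ds = WeatherDataset(split="train")
    print(ds.boundary)  # None when no boundary zarr is configured

    init_states, target_states, forcings, boundary, batch_times = ds[0]
    print(boundary.numel())  # 0, an empty placeholder instead of real data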
---
neural_lam/data_config.yaml | 4 +-
neural_lam/weather_dataset.py | 80 ++++++++++++++++++++++++-----------
2 files changed, 57 insertions(+), 27 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 0aedd7fe..af2672af 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -20,7 +20,7 @@ zarrs: # List of zarrs containing fields related to state
x: x
y: y
boundary:
- zarrs:
+ path:
mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.
state: # Variables forecasted by the model
surface: # Single-field variables
@@ -127,4 +127,4 @@ projection:
kwargs: # Parsed and used directly as kwargs to projection-class above
pole_longitude: 10.0
pole_latitude: -43.0
-normalization_zarr: data/meps_example/norm.zarr
+normalization_zarr: /scratch/sadamov/norm.zarr
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 8c9e0072..968a985f 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,5 +1,6 @@
# Standard library
import os
+from functools import lru_cache
# Third-party
import pytorch_lightning as pl
@@ -27,6 +28,7 @@ def load_config(self):
with open(self.config_path, "r") as file:
return yaml.safe_load(file)
+ @lru_cache(maxsize=None)
def __getattr__(self, name):
keys = name.split(".")
value = self.values
@@ -69,16 +71,26 @@ def process_dataset(self, dataset_name):
xarray.Dataset: Processed dataset.
"""
- dataset_path = os.path.join(self.config_loader.zarrs[dataset_name].path)
+ dataset_path = self.config_loader.zarrs[dataset_name].path
+ if dataset_path is None or not os.path.exists(dataset_path):
+ print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
+ return None
dataset = xr.open_zarr(dataset_path, consolidated=True)
- start, end = self.config_loader.splits[self.split].start, self.config_loader.splits[self.split].end
+ start, end = (
+ self.config_loader.splits[self.split].start,
+ self.config_loader.splits[self.split].end,
+ )
dataset = dataset.sel(time=slice(start, end))
dataset = dataset.rename_dims(
- {v: k for k, v in self.config_loader.zarrs[dataset_name].dims.values.items()
- if k not in dataset.dims})
- if 'grid' not in dataset.dims:
- dataset = dataset.stack(grid=('x', 'y'))
+ {
+ v: k
+ for k, v in self.config_loader.zarrs[dataset_name].dims.values.items()
+ if k not in dataset.dims
+ }
+ )
+ if "grid" not in dataset.dims:
+ dataset = dataset.stack(grid=("x", "y"))
vars_surface = []
if self.config_loader[dataset_name].surface:
@@ -87,9 +99,12 @@ def process_dataset(self, dataset_name):
vars_atmosphere = []
if self.config_loader[dataset_name].atmosphere:
vars_atmosphere = xr.merge(
- [dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
- for var in self.config_loader[dataset_name].atmosphere
- for level in self.config_loader[dataset_name].levels])
+ [
+ dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
+ for var in self.config_loader[dataset_name].atmosphere
+ for level in self.config_loader[dataset_name].levels
+ ]
+ )
if vars_surface and vars_atmosphere:
dataset = xr.merge([vars_surface, vars_atmosphere])
@@ -98,7 +113,8 @@ def process_dataset(self, dataset_name):
elif vars_atmosphere:
dataset = vars_atmosphere
else:
- raise ValueError(f"No variables specified for dataset: {dataset_name}")
+ print("No variables found in dataset {dataset_name}")
+ return None
dataset = dataset.squeeze(drop=True).to_array()
if "time" in dataset.dims:
@@ -130,36 +146,49 @@ def __init__(
self.config_loader = ConfigLoader(yaml_path)
self.state = self.process_dataset("state")
+ assert self.state is not None, "State dataset not found"
self.static = self.process_dataset("static")
self.forcings = self.process_dataset("forcing")
- # self.boundary = self.process_dataset("boundary")
+ self.boundary = self.process_dataset("boundary")
- self.static = self.static.expand_dims({"time": self.state.time}, axis=0)
- self.ds = xr.concat([self.state, self.static], dim="variable")
+ if self.static is not None:
+ self.static = self.static.expand_dims({"time": self.state.time}, axis=0)
+ self.state = xr.concat([self.state, self.static], dim="variable")
def __len__(self):
- return len(self.ds.time) - self.ar_steps
+ return len(self.state.time) - self.ar_steps
def __getitem__(self, idx):
- sample = self.ds.isel(time=slice(idx, idx + self.ar_steps))
- forcings = self.forcings.isel(time=slice(idx, idx + self.ar_steps))
- sample = torch.tensor(sample.values, dtype=torch.float32)
- forcings = torch.tensor(forcings.values, dtype=torch.float32)
+ sample = torch.tensor(
+ self.state.isel(time=slice(idx, idx + self.ar_steps)).values,
+ dtype=torch.float32,
+ )
+
+ forcings = torch.tensor(
+ self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values,
+ dtype=torch.float32,
+ ) if self.forcings is not None else torch.tensor([])
+
+ boundary = torch.tensor(
+ self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values,
+ dtype=torch.float32,
+ ) if self.boundary is not None else torch.tensor([])
init_states = sample[:2]
target_states = sample[2:]
- batch_times = self.ds.isel(
- time=slice(
- idx,
- idx +
- self.ar_steps)).time.values.astype(str).tolist()
+ batch_times = (
+ self.state.isel(time=slice(idx, idx + self.ar_steps))
+ .time.values.astype(str)
+ .tolist()
+ )
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
# forcings: (ar_steps, N_grid, d_windowed_forcings)
+ # boundary: (ar_steps, N_grid, d_windowed_boundary)
# batch_times: (ar_steps,)
- return init_states, target_states, forcings, batch_times
+ return init_states, target_states, forcings, boundary, batch_times
class WeatherDataModule(pl.LightningDataModule):
@@ -229,5 +258,6 @@ def test_dataloader(self):
print(batch[0].shape)
print(batch[1].shape)
print(batch[2].shape)
- print(batch[3])
+ print(batch[3].shape)
+ print(batch[4])
break
From 774d16a7685f912a74f5a19cc0faba4d4ab09eb7 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 11:08:18 +0200
Subject: [PATCH 004/273] removed all dependencies on constants.py

User configs are retrieved either from data_config.yaml or they are set as
flags to train_model.py. Capabilities of the ConfigLoader class were extended.
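For reference, a small worked example of the counting that num_data_vars does
for the shipped config, which replaces the old hard-coded GRID_STATE_DIM:

    # Worked example of ConfigLoader.num_data_vars("state") with the values in
    # data_config.yaml: one feature per surface variable plus one per
    # (atmosphere variable, level) pair.
    surface = ["CLCT", "PMSL", "PS", "T_2M", "TOT_PREC", "U_10M", "V_10M"]  # 7
    atmosphere = ["PP", "QV", "RELHUM", "T", "U", "V", "W"]                 # 7
    levels = [0, 5, 8, 11, 13, 15, 19, 22, 26, 30, 38, 44, 59]              # 13

    num_state_vars = len(surface) + len(atmosphere) * len(levels)
    print(num_state_vars)  # 7 + 7 * 13 = 98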
---
create_parameter_weights.py | 13 +---
neural_lam/constants.py | 120 ----------------------------------
neural_lam/models/ar_model.py | 53 ++++++++-------
neural_lam/utils.py | 6 +-
neural_lam/vis.py | 24 ++++---
neural_lam/weather_dataset.py | 35 ++++++----
train_model.py | 37 ++++++++++-
7 files changed, 102 insertions(+), 186 deletions(-)
delete mode 100644 neural_lam/constants.py
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 494a5e81..6956d4ca 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -8,7 +8,6 @@
from tqdm import tqdm
# First-party
-from neural_lam import constants
from neural_lam.weather_dataset import WeatherDataset
@@ -45,6 +44,7 @@ def main():
static_dir_path = os.path.join("data", args.dataset, "static")
+ ds = WeatherDataset()
# Create parameter weights based on height
# based on fig A.1 in graph cast paper
w_dict = {
@@ -56,7 +56,7 @@ def main():
"500": 0.03,
}
w_list = np.array(
- [w_dict[par.split("_")[-2]] for par in constants.PARAM_NAMES]
+ [w_dict[par.split("_")[-2]] for par in ds.config_loader.param_names()]
)
print("Saving parameter weights...")
np.save(
@@ -65,13 +65,6 @@ def main():
)
# Load dataset without any subsampling
- ds = WeatherDataset(
- args.dataset,
- split="train",
- subsample_step=1,
- pred_length=63,
- standardize=False,
- ) # Without standardization
loader = torch.utils.data.DataLoader(
ds, args.batch_size, shuffle=False, num_workers=args.n_workers
)
@@ -133,7 +126,7 @@ def main():
# Note: batch contains only 1h-steps
stepped_batch = torch.cat(
[
- batch[:, ss_i : used_subsample_len : args.step_length]
+ batch[:, ss_i: used_subsample_len: args.step_length]
for ss_i in range(args.step_length)
],
dim=0,
diff --git a/neural_lam/constants.py b/neural_lam/constants.py
deleted file mode 100644
index 527c31d8..00000000
--- a/neural_lam/constants.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Third-party
-import cartopy
-import numpy as np
-
-WANDB_PROJECT = "neural-lam"
-
-SECONDS_IN_YEAR = (
- 365 * 24 * 60 * 60
-) # Assuming no leap years in dataset (2024 is next)
-
-# Log prediction error for these lead times
-VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19])
-
-# Log these metrics to wandb as scalar values for
-# specific variables and lead times
-# List of metrics to watch, including any prefix (e.g. val_rmse)
-METRICS_WATCH = []
-# Dict with variables and lead times to log watched metrics for
-# Format is a dictionary that maps from a variable index to
-# a list of lead time steps
-VAR_LEADS_METRICS_WATCH = {
- 6: [2, 19], # t_2
- 14: [2, 19], # wvint_0
- 15: [2, 19], # z_1000
-}
-
-# Variable names
-PARAM_NAMES = [
- "pres_heightAboveGround_0_instant",
- "pres_heightAboveSea_0_instant",
- "nlwrs_heightAboveGround_0_accum",
- "nswrs_heightAboveGround_0_accum",
- "r_heightAboveGround_2_instant",
- "r_hybrid_65_instant",
- "t_heightAboveGround_2_instant",
- "t_hybrid_65_instant",
- "t_isobaricInhPa_500_instant",
- "t_isobaricInhPa_850_instant",
- "u_hybrid_65_instant",
- "u_isobaricInhPa_850_instant",
- "v_hybrid_65_instant",
- "v_isobaricInhPa_850_instant",
- "wvint_entireAtmosphere_0_instant",
- "z_isobaricInhPa_1000_instant",
- "z_isobaricInhPa_500_instant",
-]
-
-PARAM_NAMES_SHORT = [
- "pres_0g",
- "pres_0s",
- "nlwrs_0",
- "nswrs_0",
- "r_2",
- "r_65",
- "t_2",
- "t_65",
- "t_500",
- "t_850",
- "u_65",
- "u_850",
- "v_65",
- "v_850",
- "wvint_0",
- "z_1000",
- "z_500",
-]
-PARAM_UNITS = [
- "Pa",
- "Pa",
- "W/m\\textsuperscript{2}",
- "W/m\\textsuperscript{2}",
- "-", # unitless
- "-",
- "K",
- "K",
- "K",
- "K",
- "m/s",
- "m/s",
- "m/s",
- "m/s",
- "kg/m\\textsuperscript{2}",
- "m\\textsuperscript{2}/s\\textsuperscript{2}",
- "m\\textsuperscript{2}/s\\textsuperscript{2}",
-]
-
-# Projection and grid
-# Hard coded for now, but should eventually be part of dataset desc. files
-GRID_SHAPE = (268, 238) # (y, x)
-
-LAMBERT_PROJ_PARAMS = {
- "a": 6367470,
- "b": 6367470,
- "lat_0": 63.3,
- "lat_1": 63.3,
- "lat_2": 63.3,
- "lon_0": 15.0,
- "proj": "lcc",
-}
-
-GRID_LIMITS = [ # In projection
- -1059506.5523409774, # min x
- 1310493.4476590226, # max x
- -1331732.4471934352, # min y
- 1338267.5528065648, # max y
-]
-
-# Create projection
-LAMBERT_PROJ = cartopy.crs.LambertConformal(
- central_longitude=LAMBERT_PROJ_PARAMS["lon_0"],
- central_latitude=LAMBERT_PROJ_PARAMS["lat_0"],
- standard_parallels=(
- LAMBERT_PROJ_PARAMS["lat_1"],
- LAMBERT_PROJ_PARAMS["lat_2"],
- ),
-)
-
-# Data dimensions
-GRID_FORCING_DIM = 5 * 3 + 1 # 5 feat. for 3 time-step window + 1 batch-static
-GRID_STATE_DIM = 17
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 7d0a8320..902a89e4 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,10 +6,12 @@
import numpy as np
import pytorch_lightning as pl
import torch
+
import wandb
# First-party
-from neural_lam import constants, metrics, utils, vis
+from neural_lam import metrics, utils, vis
+from neural_lam.weather_dataset import ConfigLoader
class ARModel(pl.LightningModule):
@@ -25,6 +27,7 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.lr = args.lr
+ self.config_loader = ConfigLoader(args.data_config)
# Load static features for grid/data
static_data_dict = utils.load_static_data(args.dataset)
@@ -37,11 +40,11 @@ def __init__(self, args):
self.output_std = bool(args.output_std)
if self.output_std:
self.grid_output_dim = (
- 2 * constants.GRID_STATE_DIM
+ 2 * self.config_loader.num_data_vars("state")
) # Pred. dim. in grid cell
else:
self.grid_output_dim = (
- constants.GRID_STATE_DIM
+ self.config_loader.num_data_vars("state")
) # Pred. dim. in grid cell
# Store constant per-variable std.-dev. weighting
@@ -59,9 +62,9 @@ def __init__(self, args):
grid_static_dim,
) = self.grid_static_features.shape # 63784 = 268x238
self.grid_dim = (
- 2 * constants.GRID_STATE_DIM
+ 2 * self.config_loader.num_data_vars("state")
+ grid_static_dim
- + constants.GRID_FORCING_DIM
+ + self.config_loader.num_data_vars("forcing")
)
# Instantiate loss function
@@ -246,7 +249,7 @@ def validation_step(self, batch, batch_idx):
# Log loss per time step forward and mean
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
- for step in constants.VAL_STEP_LOG_ERRORS
+ for step in self.args.val_steps_log
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
@@ -294,7 +297,7 @@ def test_step(self, batch, batch_idx):
# Log loss per time step forward and mean
test_log_dict = {
f"test_loss_unroll{step}": time_step_loss[step - 1]
- for step in constants.VAL_STEP_LOG_ERRORS
+ for step in self.args.val_steps_log
}
test_log_dict["test_mean_loss"] = mean_loss
@@ -328,7 +331,7 @@ def test_step(self, batch, batch_idx):
spatial_loss = self.loss(
prediction, target, pred_std, average_grid=False
) # (B, pred_steps, num_grid_nodes)
- log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1]
+ log_spatial_losses = spatial_loss[:, self.args.val_steps_log - 1]
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)
@@ -399,14 +402,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
pred_t[:, var_i],
target_t[:, var_i],
self.interior_mask[:, 0],
+ self.config_loader,
title=f"{var_name} ({var_unit}), "
- f"t={t_i} ({self.step_length*t_i} h)",
+ f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
- constants.PARAM_NAMES_SHORT,
- constants.PARAM_UNITS,
+ self.config_loader.param_names(),
+ self.config_loader.param_units(),
var_vranges,
)
)
@@ -417,7 +421,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
- constants.PARAM_NAMES_SHORT, var_figs
+ self.config_loader.param_names(), var_figs
)
}
)
@@ -453,7 +457,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
log_dict = {}
metric_fig = vis.plot_error_map(
- metric_tensor, step_length=self.step_length
+ metric_tensor, self.config_loader, step_length=self.step_length
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
@@ -471,14 +475,14 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
)
# Check if metrics are watched, log exact values for specific vars
- if full_log_name in constants.METRICS_WATCH:
- for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items():
- var = constants.PARAM_NAMES_SHORT[var_i]
+ if full_log_name in self.args.metrics_watch:
+ for var_i, timesteps in self.args.var_leads_metrics_watch.items():
+ var = self.config_loader.param_names()[var_i]
log_dict.update(
{
f"{full_log_name}_{var}_step_{step}": metric_tensor[
step - 1, var_i
- ] # 1-indexed in constants
+ ] # 1-indexed in data_config
for step in timesteps
}
)
@@ -542,10 +546,11 @@ def on_test_epoch_end(self):
vis.plot_spatial_error(
loss_map,
self.interior_mask[:, 0],
- title=f"Test loss, t={t_i} ({self.step_length*t_i} h)",
+ self.config_loader,
+ title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
for t_i, loss_map in zip(
- constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss
+ self.args.val_steps_log, mean_spatial_loss
)
]
@@ -554,14 +559,14 @@ def on_test_epoch_end(self):
wandb.log({"test_loss": wandb.Image(fig)})
# also make without title and save as pdf
- pdf_loss_map_figs = [
- vis.plot_spatial_error(loss_map, self.interior_mask[:, 0])
- for loss_map in mean_spatial_loss
- ]
+ pdf_loss_map_figs = [vis.plot_spatial_error(
+ loss_map, self.interior_mask[:, 0],
+ self.config_loader)
+ for loss_map in mean_spatial_loss]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(
- constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs
+ self.args.val_steps_log, pdf_loss_map_figs
):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 31715502..8b9e250b 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -7,8 +7,6 @@
from torch import nn
from tueplots import bundles, figsizes
-# First-party
-from neural_lam import constants
def load_dataset_stats(dataset_name, device="cpu"):
@@ -263,11 +261,11 @@ def fractional_plot_bundle(fraction):
return bundle
-def init_wandb_metrics(wandb_logger):
+def init_wandb_metrics(wandb_logger, val_steps):
"""
Set up wandb metrics to track
"""
experiment = wandb_logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
- for step in constants.VAL_STEP_LOG_ERRORS:
+ for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index cef34a84..81adb935 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -4,11 +4,11 @@
import numpy as np
# First-party
-from neural_lam import constants, utils
+from neural_lam import utils
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, title=None, step_length=3):
+def plot_error_map(errors, data_config, title=None, step_length=3):
"""
Plot a heatmap of errors of different variables at different
prediction horizons
@@ -51,7 +51,7 @@ def plot_error_map(errors, title=None, step_length=3):
y_ticklabels = [
f"{name} ({unit})"
for name, unit in zip(
- constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS
+ data_config.param_names(), data_config.param_units()
)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -63,7 +63,7 @@ def plot_error_map(errors, title=None, step_length=3):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
+def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None):
"""
Plot example prediction and ground truth.
Each has shape (N_grid,)
@@ -76,23 +76,22 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+ mask_reshaped = obs_mask.reshape(*data_config.grid_shape)
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
fig, axes = plt.subplots(
- 1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ}
+ 1, 2, figsize=(13, 7), subplot_kw={"projection": data_config.projection()}
)
# Plot pred and target
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
- data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy()
+ data_grid = data.reshape(*data_config.grid_shape).cpu().numpy()
im = ax.imshow(
data_grid,
origin="lower",
- extent=constants.GRID_LIMITS,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
@@ -112,7 +111,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_spatial_error(error, obs_mask, title=None, vrange=None):
+def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
"""
Plot errors over spatial map
Error and obs_mask has shape (N_grid,)
@@ -125,22 +124,21 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None):
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+ mask_reshaped = obs_mask.reshape(*data_config.grid_shape)
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
fig, ax = plt.subplots(
- figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ}
+ figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()}
)
ax.coastlines() # Add coastline outlines
- error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy()
+ error_grid = error.reshape(*data_config.grid_shape).cpu().numpy()
im = ax.imshow(
error_grid,
origin="lower",
- extent=constants.GRID_LIMITS,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 968a985f..b6181602 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -7,6 +7,7 @@
import torch
import xarray as xr
import yaml
+import cartopy.crs as ccrs
class ConfigLoader:
@@ -36,7 +37,7 @@ def __getattr__(self, name):
if key in value:
value = value[key]
else:
- None
+ return None
if isinstance(value, dict):
return ConfigLoader(None, values=value)
return value
@@ -50,6 +51,24 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.values
+ def param_names(self):
+ return self.values['state']['surface'] + self.values['state']['atmosphere']
+
+ def param_units(self):
+ return self.values['state']['surface_units'] + self.values['state']['atmosphere_units']
+
+ def num_data_vars(self, key):
+ surface_vars = len(self.values[key]['surface'])
+ atmosphere_vars = len(self.values[key]['atmosphere'])
+ levels = len(self.values[key]['levels'])
+ return surface_vars + atmosphere_vars * levels
+
+ def projection(self):
+        proj_config = self.values["projection"]
+        proj_class = getattr(ccrs, proj_config["class"])
+        proj_params = proj_config["kwargs"]
+ return proj_class(**proj_params)
+
class WeatherDataset(torch.utils.data.Dataset):
"""
@@ -61,15 +80,7 @@ class WeatherDataset(torch.utils.data.Dataset):
"""
def process_dataset(self, dataset_name):
- """
- Process a single dataset specified by the dataset name.
-
- Args:
- dataset_name (str): Name of the dataset to process.
-
- Returns:
- xarray.Dataset: Processed dataset.
- """
+ """Process a single dataset specified by the dataset name."""
dataset_path = self.config_loader.zarrs[dataset_name].path
if dataset_path is None or not os.path.exists(dataset_path):
@@ -129,7 +140,7 @@ def __init__(
batch_size=4,
ar_steps=3,
control_only=False,
- yaml_path="neural_lam/data_config.yaml",
+ data_config="neural_lam/data_config.yaml",
):
super().__init__()
@@ -143,7 +154,7 @@ def __init__(
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.config_loader = ConfigLoader(yaml_path)
+ self.config_loader = ConfigLoader(data_config)
self.state = self.process_dataset("state")
assert self.state is not None, "State dataset not found"
diff --git a/train_model.py b/train_model.py
index 96d21a3f..767d575a 100644
--- a/train_model.py
+++ b/train_model.py
@@ -9,7 +9,7 @@
from lightning_fabric.utilities import seed
# First-party
-from neural_lam import constants, utils
+from neural_lam import utils
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.models.hi_lam import HiLAM
from neural_lam.models.hi_lam_parallel import HiLAMParallel
@@ -44,6 +44,12 @@ def main():
default="graph_lam",
help="Model architecture to train/evaluate (default: graph_lam)",
)
+ parser. add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data configuration file (default: neural_lam/data_config.yaml)",
+ )
parser.add_argument(
"--subset_ds",
type=int,
@@ -183,6 +189,30 @@ def main():
help="Number of example predictions to plot during evaluation "
"(default: 1)",
)
+ parser.add_argument(
+ "--wandb_project",
+ type=str,
+ default="neural-lam",
+ help="Wandb project to log to (default: neural-lam)",
+ )
+ parser.add_argument(
+ "--val_steps_log",
+ type=list,
+ default=[1, 2, 3, 5, 10, 15, 19],
+ help="Steps to log validation loss for (default: [1, 2, 3, 5, 10, 15, 19])",
+ )
+ parser.add_argument(
+ "--metrics_watch",
+ type=list,
+ default=[],
+ help="List of metrics to watch, including any prefix (e.g. val_rmse)",
+ )
+ parser.add_argument(
+ "--var_leads_metrics_watch",
+ type=dict,
+ default={},
+ help="Dict with variables and lead times to log watched metrics for",
+ )
args = parser.parse_args()
# Asserts for arguments
@@ -264,7 +294,7 @@ def main():
save_last=True,
)
logger = pl.loggers.WandbLogger(
- project=constants.WANDB_PROJECT, name=run_name, config=args
+ project=args.wandb_project, name=run_name, config=args
)
trainer = pl.Trainer(
max_epochs=args.epochs,
@@ -280,7 +310,8 @@ def main():
# Only init once, on rank 0 only
if trainer.global_rank == 0:
- utils.init_wandb_metrics(logger) # Do after wandb.init
+ utils.init_wandb_metrics(
+ logger, val_steps=args.val_steps_log) # Do after wandb.init
if args.eval:
if args.eval == "val":
From 7bb139b2b54a64ca9e8ae5d73deea2b2579767c4 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 11:22:37 +0200
Subject: [PATCH 005/273] fix linter
---
create_parameter_weights.py | 10 ++----
neural_lam/data_config.yaml | 8 ++---
neural_lam/models/ar_model.py | 23 ++++++-------
neural_lam/utils.py | 1 -
neural_lam/vis.py | 9 +++--
neural_lam/weather_dataset.py | 64 +++++++++++++++++++++++------------
train_model.py | 36 +++++---------------
7 files changed, 75 insertions(+), 76 deletions(-)
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 6956d4ca..926d7741 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -105,13 +105,7 @@ def main():
# Compute mean and std.-dev. of one-step differences across the dataset
print("Computing mean and std.-dev. for one-step differences...")
- ds_standard = WeatherDataset(
- args.dataset,
- split="train",
- subsample_step=1,
- pred_length=63,
- standardize=True,
- ) # Re-load with standardization
+ ds_standard = WeatherDataset() # Re-load with standardization
loader_standard = torch.utils.data.DataLoader(
ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
)
@@ -126,7 +120,7 @@ def main():
# Note: batch contains only 1h-steps
stepped_batch = torch.cat(
[
- batch[:, ss_i: used_subsample_len: args.step_length]
+ batch[:, ss_i : used_subsample_len : args.step_length]
for ss_i in range(args.step_length)
],
dim=0,
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index af2672af..8d936154 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -12,7 +12,7 @@ zarrs: # List of zarrs containing fields related to state
level: z
x: x
y: y
- forcing:
+ forcing:
path: /scratch/sadamov/template.zarr
dims:
time: time
@@ -55,7 +55,7 @@ state: # Variables forecasted by the model
- m/s
- m/s
- Pa/s
- levels: # Levels to use for atmosphere variables
+ levels: # Levels to use for atmosphere variables
- 0
- 5
- 8
@@ -71,7 +71,7 @@ state: # Variables forecasted by the model
- 59
static: # Static inputs
surface:
- - HSURF
+ - HSURF
surface_units:
- m
atmosphere:
@@ -122,7 +122,7 @@ splits:
test:
start: 2015-01-01T00
end: 2024-12-31T23
-projection:
+projection:
class: RotatedPole # Name of class in cartopy.crs
kwargs: # Parsed and used directly as kwargs to projection-class above
pole_longitude: 10.0
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 902a89e4..93f2b569 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,7 +6,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
-
import wandb
# First-party
@@ -39,12 +38,12 @@ def __init__(self, args):
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
if self.output_std:
- self.grid_output_dim = (
- 2 * self.config_loader.num_data_vars("state")
+ self.grid_output_dim = 2 * self.config_loader.num_data_vars(
+ "state"
) # Pred. dim. in grid cell
else:
- self.grid_output_dim = (
- self.config_loader.num_data_vars("state")
+ self.grid_output_dim = self.config_loader.num_data_vars(
+ "state"
) # Pred. dim. in grid cell
# Store constant per-variable std.-dev. weighting
@@ -559,15 +558,15 @@ def on_test_epoch_end(self):
wandb.log({"test_loss": wandb.Image(fig)})
# also make without title and save as pdf
- pdf_loss_map_figs = [vis.plot_spatial_error(
- loss_map, self.interior_mask[:, 0],
- self.config_loader)
- for loss_map in mean_spatial_loss]
+ pdf_loss_map_figs = [
+ vis.plot_spatial_error(
+ loss_map, self.interior_mask[:, 0], self.config_loader
+ )
+ for loss_map in mean_spatial_loss
+ ]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
- for t_i, fig in zip(
- self.args.val_steps_log, pdf_loss_map_figs
- ):
+ for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 8b9e250b..836b04ed 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -8,7 +8,6 @@
from tueplots import bundles, figsizes
-
def load_dataset_stats(dataset_name, device="cpu"):
"""
Load arrays with stored dataset statistics from pre-processing
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 81adb935..02b8dd35 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -63,7 +63,9 @@ def plot_error_map(errors, data_config, title=None, step_length=3):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None):
+def plot_prediction(
+ pred, target, obs_mask, data_config, title=None, vrange=None
+):
"""
Plot example prediction and ground truth.
Each has shape (N_grid,)
@@ -82,7 +84,10 @@ def plot_prediction(pred, target, obs_mask, data_config, title=None, vrange=None
) # Faded border region
fig, axes = plt.subplots(
- 1, 2, figsize=(13, 7), subplot_kw={"projection": data_config.projection()}
+ 1,
+ 2,
+ figsize=(13, 7),
+ subplot_kw={"projection": data_config.projection()},
)
# Plot pred and target
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index b6181602..28c29db6 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,13 +1,12 @@
# Standard library
import os
-from functools import lru_cache
# Third-party
+import cartopy.crs as ccrs
import pytorch_lightning as pl
import torch
import xarray as xr
import yaml
-import cartopy.crs as ccrs
class ConfigLoader:
@@ -26,10 +25,10 @@ def __init__(self, config_path, values=None):
self.values = values
def load_config(self):
- with open(self.config_path, "r") as file:
+ """Load configuration file."""
+ with open(self.config_path, encoding="utf-8", mode="r") as file:
return yaml.safe_load(file)
- @lru_cache(maxsize=None)
def __getattr__(self, name):
keys = name.split(".")
value = self.values
@@ -52,18 +51,27 @@ def __contains__(self, key):
return key in self.values
def param_names(self):
- return self.values['state']['surface'] + self.values['state']['atmosphere']
+ """Return parameter names."""
+ return (
+ self.values["state"]["surface"] + self.values["state"]["atmosphere"]
+ )
def param_units(self):
- return self.values['state']['surface_units'] + self.values['state']['atmosphere_units']
+ """Return parameter units."""
+ return (
+ self.values["state"]["surface_units"]
+ + self.values["state"]["atmosphere_units"]
+ )
def num_data_vars(self, key):
- surface_vars = len(self.values[key]['surface'])
- atmosphere_vars = len(self.values[key]['atmosphere'])
- levels = len(self.values[key]['levels'])
+ """Return the number of data variables for a given key."""
+ surface_vars = len(self.values[key]["surface"])
+ atmosphere_vars = len(self.values[key]["atmosphere"])
+ levels = len(self.values[key]["levels"])
return surface_vars + atmosphere_vars * levels
-
+
def projection(self):
+ """Return the projection."""
proj_config = self.values["projections"]["class"]
proj_class = getattr(ccrs, proj_config["proj_class"])
proj_params = proj_config["proj_params"]
@@ -96,7 +104,9 @@ def process_dataset(self, dataset_name):
dataset = dataset.rename_dims(
{
v: k
- for k, v in self.config_loader.zarrs[dataset_name].dims.values.items()
+ for k, v in self.config_loader.zarrs[
+ dataset_name
+ ].dims.values.items()
if k not in dataset.dims
}
)
@@ -111,7 +121,9 @@ def process_dataset(self, dataset_name):
if self.config_loader[dataset_name].atmosphere:
vars_atmosphere = xr.merge(
[
- dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
+ dataset[var]
+ .sel(level=level, drop=True)
+ .rename(f"{var}_{level}")
for var in self.config_loader[dataset_name].atmosphere
for level in self.config_loader[dataset_name].levels
]
@@ -163,7 +175,9 @@ def __init__(
self.boundary = self.process_dataset("boundary")
if self.static is not None:
- self.static = self.static.expand_dims({"time": self.state.time}, axis=0)
+ self.static = self.static.expand_dims(
+ {"time": self.state.time}, axis=0
+ )
self.state = xr.concat([self.state, self.static], dim="variable")
def __len__(self):
@@ -175,15 +189,23 @@ def __getitem__(self, idx):
dtype=torch.float32,
)
- forcings = torch.tensor(
- self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values,
- dtype=torch.float32,
- ) if self.forcings is not None else torch.tensor([])
+ forcings = (
+ torch.tensor(
+ self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values,
+ dtype=torch.float32,
+ )
+ if self.forcings is not None
+ else torch.tensor([])
+ )
- boundary = torch.tensor(
- self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values,
- dtype=torch.float32,
- ) if self.boundary is not None else torch.tensor([])
+ boundary = (
+ torch.tensor(
+ self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values,
+ dtype=torch.float32,
+ )
+ if self.boundary is not None
+ else torch.tensor([])
+ )
init_states = sample[:2]
target_states = sample[2:]
diff --git a/train_model.py b/train_model.py
index 767d575a..23a0330c 100644
--- a/train_model.py
+++ b/train_model.py
@@ -44,11 +44,11 @@ def main():
default="graph_lam",
help="Model architecture to train/evaluate (default: graph_lam)",
)
- parser. add_argument(
+ parser.add_argument(
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
- help="Path to data configuration file (default: neural_lam/data_config.yaml)",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
)
parser.add_argument(
"--subset_ds",
@@ -199,7 +199,7 @@ def main():
"--val_steps_log",
type=list,
default=[1, 2, 3, 5, 10, 15, 19],
- help="Steps to log validation loss for (default: [1, 2, 3, 5, 10, 15, 19])",
+ help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])",
)
parser.add_argument(
"--metrics_watch",
@@ -232,28 +232,13 @@ def main():
# Load data
train_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- args.dataset,
- pred_length=args.ar_steps,
- split="train",
- subsample_step=args.step_length,
- subset=bool(args.subset_ds),
- control_only=args.control_only,
- ),
+ WeatherDataset(control_only=args.control_only),
args.batch_size,
shuffle=True,
num_workers=args.n_workers,
)
- max_pred_length = (65 // args.step_length) - 2 # 19
val_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- args.dataset,
- pred_length=max_pred_length,
- split="val",
- subsample_step=args.step_length,
- subset=bool(args.subset_ds),
- control_only=args.control_only,
- ),
+ WeatherDataset(control_only=args.control_only),
args.batch_size,
shuffle=False,
num_workers=args.n_workers,
@@ -311,20 +296,15 @@ def main():
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(
- logger, val_steps=args.val_steps_log) # Do after wandb.init
+ logger, val_steps=args.val_steps_log
+ ) # Do after wandb.init
if args.eval:
if args.eval == "val":
eval_loader = val_loader
else: # Test
eval_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- args.dataset,
- pred_length=max_pred_length,
- split="test",
- subsample_step=args.step_length,
- subset=bool(args.subset_ds),
- ),
+ WeatherDataset(),
args.batch_size,
shuffle=False,
num_workers=args.n_workers,
From af076feb613165c7596a79e47d987867015cfd4f Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 14:08:46 +0200
Subject: [PATCH 006/273] Fixed calls to new WeatherDataModule Class
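train_model.py now hands a single LightningDataModule to the Trainer instead
of building the DataLoaders itself. A sketch of the flow, shown here by
calling the hooks manually and assuming the zarrs referenced by the default
config are reachable:

    # Sketch: the Trainer drives these datamodule hooks internally during fit/test.
    from neural_lam.weather_dataset import WeatherDataModule

    data_module = WeatherDataModule(batch_size=4, num_workers=8)
    data_module.setup("fit")                        # builds train and val datasets
    batch = next(iter(data_module.train_dataloader()))
    print(batch[0].shape)                           # init_states

    # With a model instance, pytorch_lightning calls the same hooks itself:
    #   trainer.fit(model=model, datamodule=data_module)
    #   trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)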
---
train_model.py | 39 ++++++++-------------------------------
1 file changed, 8 insertions(+), 31 deletions(-)
diff --git a/train_model.py b/train_model.py
index 23a0330c..a303132f 100644
--- a/train_model.py
+++ b/train_model.py
@@ -13,7 +13,7 @@
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.models.hi_lam import HiLAM
from neural_lam.models.hi_lam_parallel import HiLAMParallel
-from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.weather_dataset import WeatherDataModule
MODELS = {
"graph_lam": GraphLAM,
@@ -189,6 +189,8 @@ def main():
help="Number of example predictions to plot during evaluation "
"(default: 1)",
)
+
+ # Logging Options
parser.add_argument(
"--wandb_project",
type=str,
@@ -229,18 +231,9 @@ def main():
# Set seed
seed.seed_everything(args.seed)
-
- # Load data
- train_loader = torch.utils.data.DataLoader(
- WeatherDataset(control_only=args.control_only),
- args.batch_size,
- shuffle=True,
- num_workers=args.n_workers,
- )
- val_loader = torch.utils.data.DataLoader(
- WeatherDataset(control_only=args.control_only),
- args.batch_size,
- shuffle=False,
+ # Create datamodule
+ data_module = WeatherDataModule(
+ batch_size=args.batch_size,
num_workers=args.n_workers,
)
@@ -300,25 +293,9 @@ def main():
) # Do after wandb.init
if args.eval:
- if args.eval == "val":
- eval_loader = val_loader
- else: # Test
- eval_loader = torch.utils.data.DataLoader(
- WeatherDataset(),
- args.batch_size,
- shuffle=False,
- num_workers=args.n_workers,
- )
-
- print(f"Running evaluation on {args.eval}")
- trainer.test(model=model, dataloaders=eval_loader)
+ trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
- # Train model
- trainer.fit(
- model=model,
- train_dataloaders=train_loader,
- val_dataloaders=val_loader,
- )
+ trainer.fit(model=model, datamodule=data_module)
if __name__ == "__main__":
From 147caec913d8d490bdb5c155f1c61ef060d3a3d9 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 14:12:43 +0200
Subject: [PATCH 007/273] fix linter
---
requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/requirements.txt b/requirements.txt
index 5a2111b2..0a921225 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ Cartopy>=0.22.0
pyproj>=3.4.1
tueplots>=0.0.8
plotly>=5.15.0
+xarray>=0.20.1
# for dev
codespell>=2.0.0
black>=21.9b0
From 2b65416336b29d6f3347bd4b89284b0b867659df Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 14:48:27 +0200
Subject: [PATCH 008/273] upload data config to wandb for history logs
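wandb.save copies the file into the run's files, so the exact data_config.yaml
used for a run is preserved alongside its history. A minimal sketch of what the
added line does; the WandbLogger already calls wandb.init during training, here
it is done explicitly only for illustration:

    import wandb

    wandb.init(project="neural-lam")
    wandb.save("neural_lam/data_config.yaml")  # uploaded with the run's files
    wandb.finish()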
---
train_model.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/train_model.py b/train_model.py
index a303132f..1839474b 100644
--- a/train_model.py
+++ b/train_model.py
@@ -8,6 +8,8 @@
import torch
from lightning_fabric.utilities import seed
+import wandb
+
# First-party
from neural_lam import utils
from neural_lam.models.graph_lam import GraphLAM
@@ -291,7 +293,7 @@ def main():
utils.init_wandb_metrics(
logger, val_steps=args.val_steps_log
) # Do after wandb.init
-
+ wandb.save(args.data_config)
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
From ed9ed696a6256b63bd0283c5057d0d12424dc91a Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 16:59:23 +0200
Subject: [PATCH 009/273] Improved handling of static data
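The flat normalization_zarr entry becomes a normalization block that names both
the zarr and the variables inside it. A sketch of reading those statistics; the
variable names come from data_config.yaml, while the contents of the zarr
itself are an assumption of this example:

    import xarray as xr
    import yaml

    with open("neural_lam/data_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    norm = xr.open_zarr(config["normalization"]["zarr"], consolidated=True)
    stats = {
        name: norm[var_name]
        for name, var_name in config["normalization"]["vars"].items()
    }
    print(list(stats))  # data_mean, data_std, ..., grid_static_features, param_weights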
---
neural_lam/data_config.yaml | 19 ++-
neural_lam/models/ar_model.py | 13 +--
neural_lam/utils.py | 211 ++++++++++++++++++++++------------
neural_lam/weather_dataset.py | 150 +-----------------------
4 files changed, 164 insertions(+), 229 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 8d936154..6c4536f5 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -72,8 +72,12 @@ state: # Variables forecasted by the model
static: # Static inputs
surface:
- HSURF
+ - lat
+ - lon
surface_units:
- m
+ - °N
+ - °E
atmosphere:
- FI
atmosphere_units:
@@ -127,4 +131,17 @@ projection:
kwargs: # Parsed and used directly as kwargs to projection-class above
pole_longitude: 10.0
pole_latitude: -43.0
-normalization_zarr: /scratch/sadamov/norm.zarr
+normalization:
+ zarr: /scratch/sadamov/norm.zarr
+ vars:
+ data_mean: data_mean
+ data_std: data_std
+ forcing_mean: forcing_mean
+ forcing_std: forcing_std
+ boundary_mean: boundary_mean
+ boundary_std: boundary_std
+ diff_mean: diff_mean
+ diff_std: diff_std
+ grid_static_features: grid_static_features
+ param_weights: param_weights
+
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 93f2b569..04679022 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,11 +6,11 @@
import numpy as np
import pytorch_lightning as pl
import torch
+
import wandb
# First-party
from neural_lam import metrics, utils, vis
-from neural_lam.weather_dataset import ConfigLoader
class ARModel(pl.LightningModule):
@@ -26,14 +26,11 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.lr = args.lr
- self.config_loader = ConfigLoader(args.data_config)
+ self.config_loader = utils.ConfigLoader(args.data_config)
# Load static features for grid/data
- static_data_dict = utils.load_static_data(args.dataset)
- for static_data_name, static_data_tensor in static_data_dict.items():
- self.register_buffer(
- static_data_name, static_data_tensor, persistent=False
- )
+ static = self.config_loader.process_dataset("static")
+ self.register_buffer("grid_static_features", torch.tensor(static.values))
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
@@ -59,7 +56,7 @@ def __init__(self, args):
(
self.num_grid_nodes,
grid_static_dim,
- ) = self.grid_static_features.shape # 63784 = 268x238
+ ) = self.grid_static_features.shape
self.grid_dim = (
2 * self.config_loader.num_data_vars("state")
+ grid_static_dim
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 836b04ed..f4c34141 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -1,87 +1,16 @@
# Standard library
import os
+import cartopy.crs as ccrs
+
# Third-party
-import numpy as np
import torch
+import xarray as xr
+import yaml
from torch import nn
from tueplots import bundles, figsizes
-def load_dataset_stats(dataset_name, device="cpu"):
- """
- Load arrays with stored dataset statistics from pre-processing
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn), map_location=device
- )
-
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- flux_stats = loads_file("flux_stats.pt") # (2,)
- flux_mean, flux_std = flux_stats
-
- return {
- "data_mean": data_mean,
- "data_std": data_std,
- "flux_mean": flux_mean,
- "flux_std": flux_std,
- }
-
-
-def load_static_data(dataset_name, device="cpu"):
- """
- Load static files related to dataset
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn), map_location=device
- )
-
- # Load border mask, 1. if node is part of border, else 0.
- border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
- border_mask = (
- torch.tensor(border_mask_np, dtype=torch.float32, device=device)
- .flatten(0, 1)
- .unsqueeze(1)
- ) # (N_grid, 1)
-
- grid_static_features = loads_file(
- "grid_features.pt"
- ) # (N_grid, d_grid_static)
-
- # Load step diff stats
- step_diff_mean = loads_file("diff_mean.pt") # (d_f,)
- step_diff_std = loads_file("diff_std.pt") # (d_f,)
-
- # Load parameter std for computing validation errors in original data scale
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- # Load loss weighting vectors
- param_weights = torch.tensor(
- np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
- dtype=torch.float32,
- device=device,
- ) # (d_f,)
-
- return {
- "border_mask": border_mask,
- "grid_static_features": grid_static_features,
- "step_diff_mean": step_diff_mean,
- "step_diff_std": step_diff_std,
- "data_mean": data_mean,
- "data_std": data_std,
- "param_weights": param_weights,
- }
-
-
class BufferList(nn.Module):
"""
A list of torch buffer tensors that sit together as a Module with no
@@ -268,3 +197,135 @@ def init_wandb_metrics(wandb_logger, val_steps):
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+
+
+class ConfigLoader:
+ """
+ Class for loading configuration files.
+
+ This class loads a YAML configuration file and provides a way to access
+ its values as attributes.
+ """
+
+ def __init__(self, config_path, values=None):
+ self.config_path = config_path
+ if values is None:
+ self.values = self.load_config()
+ else:
+ self.values = values
+
+ def load_config(self):
+ """Load configuration file."""
+ with open(self.config_path, encoding="utf-8", mode="r") as file:
+ return yaml.safe_load(file)
+
+ def __getattr__(self, name):
+ keys = name.split(".")
+ value = self.values
+ for key in keys:
+ if key in value:
+ value = value[key]
+ else:
+ return None
+ if isinstance(value, dict):
+ return ConfigLoader(None, values=value)
+ return value
+
+ def __getitem__(self, key):
+ value = self.values[key]
+ if isinstance(value, dict):
+ return ConfigLoader(None, values=value)
+ return value
+
+ def __contains__(self, key):
+ return key in self.values
+
+ def param_names(self):
+ """Return parameter names."""
+ return self.values["state"]["surface"] + self.values["state"]["atmosphere"]
+
+ def param_units(self):
+ """Return parameter units."""
+ return (
+ self.values["state"]["surface_units"]
+ + self.values["state"]["atmosphere_units"]
+ )
+
+ def num_data_vars(self, key):
+ """Return the number of data variables for a given key."""
+ surface_vars = len(self.values[key]["surface"])
+ atmosphere_vars = len(self.values[key]["atmosphere"])
+ levels = len(self.values[key]["levels"])
+ return surface_vars + atmosphere_vars * levels
+
+ def projection(self):
+ """Return the projection."""
+ proj_config = self.values["projections"]["class"]
+ proj_class = getattr(ccrs, proj_config["proj_class"])
+ proj_params = proj_config["proj_params"]
+ return proj_class(**proj_params)
+
+ def open_zarr(self, dataset_name, split):
+ """Open a dataset specified by the dataset name."""
+ dataset_path = self.zarrs[dataset_name].path
+ if dataset_path is None or not os.path.exists(dataset_path):
+ print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
+ return None
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+ return dataset
+
+ def process_dataset(self, dataset_name, split):
+ """Process a single dataset specified by the dataset name."""
+
+ dataset = self.open_zarr(dataset_name, split)
+
+ start, end = (
+ self.splits[split].start,
+ self.splits[split].end,
+ )
+ dataset = dataset.sel(time=slice(start, end))
+ dataset = dataset.rename_dims(
+ {
+ v: k
+ for k, v in self.zarrs[
+ dataset_name
+ ].dims.values.items()
+ if k not in dataset.dims
+ }
+ )
+ if "grid" not in dataset.dims:
+ dataset = dataset.stack(grid=("x", "y"))
+
+ vars_surface = []
+ if self[dataset_name].surface:
+ vars_surface = dataset[self[dataset_name].surface]
+
+ vars_atmosphere = []
+ if self[dataset_name].atmosphere:
+ vars_atmosphere = xr.merge(
+ [
+ dataset[var]
+ .sel(level=level, drop=True)
+ .rename(f"{var}_{level}")
+ for var in self[dataset_name].atmosphere
+ for level in self[dataset_name].levels
+ ]
+ )
+
+ if vars_surface and vars_atmosphere:
+ dataset = xr.merge([vars_surface, vars_atmosphere])
+ elif vars_surface:
+ dataset = vars_surface
+ elif vars_atmosphere:
+ dataset = vars_atmosphere
+ else:
+ print("No variables found in dataset {dataset_name}")
+ return None
+
+ if "time" in dataset.dims:
+ dataset = dataset.squeeze(
+ drop=True).to_array().transpose(
+ "time", "grid", "variable")
+ else:
+ dataset = dataset.to_array().transpose("grid", "variable")
+ return dataset
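For orientation, a minimal sketch of how the ConfigLoader added to utils above can be used; the key names follow the data_config.yaml in this series, while the variable names and the comments on return values are only indicative:

    # Illustrative use of the dotted/keyed access provided by __getattr__ and __getitem__
    from neural_lam import utils

    config = utils.ConfigLoader("neural_lam/data_config.yaml")
    state_path = config.zarrs["state"].path       # nested dicts come back wrapped, leaves as plain values
    missing = config.zarrs["state"].not_a_key     # unknown keys resolve to None instead of raising
    n_state_vars = config.num_data_vars("state")  # surface vars + atmosphere vars * levels
    state_da = config.process_dataset("state", "train")  # xarray DataArray with dims (time, grid, variable)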
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 28c29db6..87aa3f56 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,81 +1,7 @@
-# Standard library
-import os
-
-# Third-party
-import cartopy.crs as ccrs
import pytorch_lightning as pl
import torch
-import xarray as xr
-import yaml
-
-
-class ConfigLoader:
- """
- Class for loading configuration files.
-
- This class loads a YAML configuration file and provides a way to access
- its values as attributes.
- """
-
- def __init__(self, config_path, values=None):
- self.config_path = config_path
- if values is None:
- self.values = self.load_config()
- else:
- self.values = values
-
- def load_config(self):
- """Load configuration file."""
- with open(self.config_path, encoding="utf-8", mode="r") as file:
- return yaml.safe_load(file)
-
- def __getattr__(self, name):
- keys = name.split(".")
- value = self.values
- for key in keys:
- if key in value:
- value = value[key]
- else:
- return None
- if isinstance(value, dict):
- return ConfigLoader(None, values=value)
- return value
-
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return ConfigLoader(None, values=value)
- return value
-
- def __contains__(self, key):
- return key in self.values
-
- def param_names(self):
- """Return parameter names."""
- return (
- self.values["state"]["surface"] + self.values["state"]["atmosphere"]
- )
-
- def param_units(self):
- """Return parameter units."""
- return (
- self.values["state"]["surface_units"]
- + self.values["state"]["atmosphere_units"]
- )
-
- def num_data_vars(self, key):
- """Return the number of data variables for a given key."""
- surface_vars = len(self.values[key]["surface"])
- atmosphere_vars = len(self.values[key]["atmosphere"])
- levels = len(self.values[key]["levels"])
- return surface_vars + atmosphere_vars * levels
- def projection(self):
- """Return the projection."""
- proj_config = self.values["projections"]["class"]
- proj_class = getattr(ccrs, proj_config["proj_class"])
- proj_params = proj_config["proj_params"]
- return proj_class(**proj_params)
+from neural_lam import utils
class WeatherDataset(torch.utils.data.Dataset):
@@ -87,65 +13,6 @@ class WeatherDataset(torch.utils.data.Dataset):
validation, and test sets.
"""
- def process_dataset(self, dataset_name):
- """Process a single dataset specified by the dataset name."""
-
- dataset_path = self.config_loader.zarrs[dataset_name].path
- if dataset_path is None or not os.path.exists(dataset_path):
- print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
- return None
- dataset = xr.open_zarr(dataset_path, consolidated=True)
-
- start, end = (
- self.config_loader.splits[self.split].start,
- self.config_loader.splits[self.split].end,
- )
- dataset = dataset.sel(time=slice(start, end))
- dataset = dataset.rename_dims(
- {
- v: k
- for k, v in self.config_loader.zarrs[
- dataset_name
- ].dims.values.items()
- if k not in dataset.dims
- }
- )
- if "grid" not in dataset.dims:
- dataset = dataset.stack(grid=("x", "y"))
-
- vars_surface = []
- if self.config_loader[dataset_name].surface:
- vars_surface = dataset[self.config_loader[dataset_name].surface]
-
- vars_atmosphere = []
- if self.config_loader[dataset_name].atmosphere:
- vars_atmosphere = xr.merge(
- [
- dataset[var]
- .sel(level=level, drop=True)
- .rename(f"{var}_{level}")
- for var in self.config_loader[dataset_name].atmosphere
- for level in self.config_loader[dataset_name].levels
- ]
- )
-
- if vars_surface and vars_atmosphere:
- dataset = xr.merge([vars_surface, vars_atmosphere])
- elif vars_surface:
- dataset = vars_surface
- elif vars_atmosphere:
- dataset = vars_atmosphere
- else:
- print("No variables found in dataset {dataset_name}")
- return None
-
- dataset = dataset.squeeze(drop=True).to_array()
- if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid", "variable")
- else:
- dataset = dataset.transpose("grid", "variable")
- return dataset
-
def __init__(
self,
split="train",
@@ -166,19 +33,12 @@ def __init__(
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.config_loader = ConfigLoader(data_config)
+ self.config_loader = utils.ConfigLoader(data_config)
- self.state = self.process_dataset("state")
+ self.state = self.config_loader("state", self.split)
assert self.state is not None, "State dataset not found"
- self.static = self.process_dataset("static")
- self.forcings = self.process_dataset("forcing")
- self.boundary = self.process_dataset("boundary")
-
- if self.static is not None:
- self.static = self.static.expand_dims(
- {"time": self.state.time}, axis=0
- )
- self.state = xr.concat([self.state, self.static], dim="variable")
+ self.forcings = self.config_loader("forcing", self.split)
+ self.boundary = self.config_loader("boundary", self.split)
def __len__(self):
return len(self.state.time) - self.ar_steps
From 0b69f4e04466d52e9ee6e36df7ed5accc93e915e Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Wed, 8 May 2024 23:04:36 +0200
Subject: [PATCH 010/273] dask and zarr are required backends for xarray
---
requirements.txt | 2 ++
1 file changed, 2 insertions(+)
diff --git a/requirements.txt b/requirements.txt
index 0a921225..cb9bd425 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,8 @@ pyproj>=3.4.1
tueplots>=0.0.8
plotly>=5.15.0
xarray>=0.20.1
+zarr>=2.10.0
+dask>=2022.0.0
# for dev
codespell>=2.0.0
black>=21.9b0
From b76d078a66ab5635e82abfcacdb0befd80336181 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 07:00:41 +0200
Subject: [PATCH 011/273] Implements windowed forcing and boundary
---
create_grid_features.py | 59 -----------------------------------
neural_lam/data_config.yaml | 13 +++++++-
neural_lam/models/ar_model.py | 17 +++++-----
neural_lam/utils.py | 29 +++++++++--------
neural_lam/weather_dataset.py | 57 ++++++++++++++++++++++++++-------
plot_graph.py | 18 +++++------
train_model.py | 3 +-
7 files changed, 88 insertions(+), 108 deletions(-)
delete mode 100644 create_grid_features.py
diff --git a/create_grid_features.py b/create_grid_features.py
deleted file mode 100644
index c9038103..00000000
--- a/create_grid_features.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Standard library
-import os
-from argparse import ArgumentParser
-
-# Third-party
-import numpy as np
-import torch
-
-
-def main():
- """
- Pre-compute all static features related to the grid nodes
- """
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--dataset",
- type=str,
- default="meps_example",
- help="Dataset to compute weights for (default: meps_example)",
- )
- args = parser.parse_args()
-
- static_dir_path = os.path.join("data", args.dataset, "static")
-
- # -- Static grid node features --
- grid_xy = torch.tensor(
- np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
- ) # (2, N_x, N_y)
- grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
- pos_max = torch.max(torch.abs(grid_xy))
- grid_xy = grid_xy / pos_max # Divide by maximum coordinate
-
- geopotential = torch.tensor(
- np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
- ) # (N_x, N_y)
- geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1)
- gp_min = torch.min(geopotential)
- gp_max = torch.max(geopotential)
- # Rescale geopotential to [0,1]
- geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1)
-
- grid_border_mask = torch.tensor(
- np.load(os.path.join(static_dir_path, "border_mask.npy")),
- dtype=torch.int64,
- ) # (N_x, N_y)
- grid_border_mask = (
- grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
- ) # (N_grid, 1)
-
- # Concatenate grid features
- grid_features = torch.cat(
- (grid_xy, geopotential, grid_border_mask), dim=1
- ) # (N_grid, 4)
-
- torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))
-
-
-if __name__ == "__main__":
- main()
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 6c4536f5..cdfb57dc 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -102,14 +102,26 @@ forcing: # Forcing variables, dynamic inputs to the model
surface_units:
- W/m^2
atmosphere:
+ - T
atmosphere_units:
+ - K
levels:
+ - 0
+ - 5
+ - 8
+ - 11
+ - 13
+ - 38
+ - 44
+ - 59
+ window: 3 # Number of time steps to use for forcing (odd)
boundary: # Boundary conditions
surface:
surface_units:
atmosphere:
atmosphere_units:
levels:
+ window: 3 # Number of time steps to use for boundary (odd)
lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells
lat: lat
lon: lon
@@ -144,4 +156,3 @@ normalization:
diff_std: diff_std
grid_static_features: grid_static_features
param_weights: param_weights
-
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 04679022..8353327d 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,7 +6,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
-
import wandb
# First-party
@@ -29,19 +28,19 @@ def __init__(self, args):
self.config_loader = utils.ConfigLoader(args.data_config)
# Load static features for grid/data
- static = self.config_loader.process_dataset("static")
- self.register_buffer("grid_static_features", torch.tensor(static.values))
+ static = self.config_loader.process_dataset("static", self.split)
+ self.register_buffer(
+ "grid_static_features", torch.tensor(static.values)
+ )
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
if self.output_std:
- self.grid_output_dim = 2 * self.config_loader.num_data_vars(
- "state"
- ) # Pred. dim. in grid cell
+ # Pred. dim. in grid cell
+ self.grid_output_dim = 2 * self.config_loader.num_data_vars("state")
else:
- self.grid_output_dim = self.config_loader.num_data_vars(
- "state"
- ) # Pred. dim. in grid cell
+ # Pred. dim. in grid cell
+ self.grid_output_dim = self.config_loader.num_data_vars("state")
# Store constant per-variable std.-dev. weighting
# Note that this is the inverse of the multiplicative weighting
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index f4c34141..3992bc6c 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -1,9 +1,8 @@
# Standard library
import os
-import cartopy.crs as ccrs
-
# Third-party
+import cartopy.crs as ccrs
import torch
import xarray as xr
import yaml
@@ -242,7 +241,9 @@ def __contains__(self, key):
def param_names(self):
"""Return parameter names."""
- return self.values["state"]["surface"] + self.values["state"]["atmosphere"]
+ return (
+ self.values["state"]["surface"] + self.values["state"]["atmosphere"]
+ )
def param_units(self):
"""Return parameter units."""
@@ -265,7 +266,7 @@ def projection(self):
proj_params = proj_config["proj_params"]
return proj_class(**proj_params)
- def open_zarr(self, dataset_name, split):
+ def open_zarr(self, dataset_name):
"""Open a dataset specified by the dataset name."""
dataset_path = self.zarrs[dataset_name].path
if dataset_path is None or not os.path.exists(dataset_path):
@@ -274,10 +275,12 @@ def open_zarr(self, dataset_name, split):
dataset = xr.open_zarr(dataset_path, consolidated=True)
return dataset
- def process_dataset(self, dataset_name, split):
+ def process_dataset(self, dataset_name, split="train"):
"""Process a single dataset specified by the dataset name."""
- dataset = self.open_zarr(dataset_name, split)
+ dataset = self.open_zarr(dataset_name)
+ if dataset is None:
+ return None
start, end = (
self.splits[split].start,
@@ -287,14 +290,10 @@ def process_dataset(self, dataset_name, split):
dataset = dataset.rename_dims(
{
v: k
- for k, v in self.zarrs[
- dataset_name
- ].dims.values.items()
+ for k, v in self.zarrs[dataset_name].dims.values.items()
if k not in dataset.dims
}
)
- if "grid" not in dataset.dims:
- dataset = dataset.stack(grid=("x", "y"))
vars_surface = []
if self[dataset_name].surface:
@@ -322,10 +321,10 @@ def process_dataset(self, dataset_name, split):
print("No variables found in dataset {dataset_name}")
return None
+ dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+
if "time" in dataset.dims:
- dataset = dataset.squeeze(
- drop=True).to_array().transpose(
- "time", "grid", "variable")
+ dataset = dataset.transpose("time", "grid", "variable")
else:
- dataset = dataset.to_array().transpose("grid", "variable")
+ dataset = dataset.transpose("grid", "variable")
return dataset
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 87aa3f56..03c8114f 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,6 +1,8 @@
+# Third-party
import pytorch_lightning as pl
import torch
+# First-party
from neural_lam import utils
@@ -35,34 +37,64 @@ def __init__(
self.control_only = control_only
self.config_loader = utils.ConfigLoader(data_config)
- self.state = self.config_loader("state", self.split)
+ self.state = self.config_loader.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
- self.forcings = self.config_loader("forcing", self.split)
- self.boundary = self.config_loader("boundary", self.split)
+ self.forcings = self.config_loader.process_dataset(
+ "forcing", self.split
+ )
+ self.boundary = self.config_loader.process_dataset(
+ "boundary", self.split
+ )
+
+ self.state_times = self.state.time.values
+ self.forcing_window = self.config_loader.forcing.window
+ self.boundary_window = self.config_loader.boundary.window
+ self.idx_max = max(
+ (self.boundary_window - 1), (self.forcing_window - 1)
+ )
+
+ if self.forcings is not None:
+ self.forcings_windowed = (
+ self.forcings.sel(
+ time=self.forcings.time.isin(self.state.time),
+ method="nearest",
+ )
+ .rolling(time=self.forcing_window, center=True)
+ .construct("window")
+ )
+ if self.boundary is not None:
+ self.boundary_windowed = (
+ self.boundary.sel(
+ time=self.forcings.time.isin(self.state.time),
+ method="nearest",
+ )
+ .rolling(time=self.boundary_window, center=True)
+ .construct("window")
+ )
def __len__(self):
- return len(self.state.time) - self.ar_steps
+ # Skip first and last time step
+ return len(self.state.time) - self.ar_steps - self.idx_max
def __getitem__(self, idx):
+ idx += self.idx_max / 2 # Skip first time step
sample = torch.tensor(
self.state.isel(time=slice(idx, idx + self.ar_steps)).values,
dtype=torch.float32,
)
forcings = (
- torch.tensor(
- self.forcings.isel(time=slice(idx, idx + self.ar_steps)).values,
- dtype=torch.float32,
- )
+ self.forcings_windowed.isel(time=slice(idx, idx + self.ar_steps))
+ .stack(variable_window=("variable", "window"))
+ .values
if self.forcings is not None
else torch.tensor([])
)
boundary = (
- torch.tensor(
- self.boundary.isel(time=slice(idx, idx + self.ar_steps)).values,
- dtype=torch.float32,
- )
+ self.boundary_windowed.isel(time=slice(idx, idx + self.ar_steps))
+ .stack(variable_window=("variable", "window"))
+ .values
if self.boundary is not None
else torch.tensor([])
)
@@ -153,4 +185,5 @@ def test_dataloader(self):
print(batch[2].shape)
print(batch[3].shape)
print(batch[4])
+ print(batch[2][0, 0, 0, :])
break
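A toy illustration of the rolling-window construction used for forcing and boundary above; the array size and values are made up:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims=["time"])
    windowed = da.rolling(time=3, center=True).construct("window")
    # windowed has dims ("time", "window"); e.g. windowed.isel(time=2).values
    # is [1., 2., 3.], while the first and last windows contain NaN padding,
    # which is why the dataset skips half a window of indices (idx_max) at each end.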
diff --git a/plot_graph.py b/plot_graph.py
index 48427d5c..c82b4e04 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -19,12 +19,6 @@ def main():
Plot graph structure in 3D using plotly
"""
parser = ArgumentParser(description="Plot graph")
- parser.add_argument(
- "--dataset",
- type=str,
- default="meps_example",
- help="Datast to load grid coordinates from (default: meps_example)",
- )
parser.add_argument(
"--graph",
type=str,
@@ -42,6 +36,12 @@ def main():
default=0,
help="If the axis should be displayed (default: 0 (No))",
)
+ parser.add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
+ )
args = parser.parse_args()
@@ -62,10 +62,8 @@ def main():
)
mesh_static_features = graph_ldict["mesh_static_features"]
- grid_static_features = utils.load_static_data(args.dataset)[
- "grid_static_features"
- ]
-
+ config_loader = utils.ConfigLoader(args.data_config)
+ grid_static_features = config_loader.process_dataset("static")
# Extract values needed, turn to numpy
grid_pos = grid_static_features[:, :2].numpy()
# Add in z-dimension
diff --git a/train_model.py b/train_model.py
index 1839474b..4f57ca24 100644
--- a/train_model.py
+++ b/train_model.py
@@ -6,9 +6,8 @@
# Third-party
import pytorch_lightning as pl
import torch
-from lightning_fabric.utilities import seed
-
import wandb
+from lightning_fabric.utilities import seed
# First-party
from neural_lam import utils
From 5d27a4ce21c8894a69711f324bc14087f95fae8f Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 07:01:09 +0200
Subject: [PATCH 012/273] Project packaging setup (enable pip install -e .)
---
.gitignore | 1 +
pyproject.toml | 21 +++++++++++++--------
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/.gitignore b/.gitignore
index 7bb826a2..590c7e12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,3 +72,4 @@ tags
# Coc configuration directory
.vim
+.vscode
diff --git a/pyproject.toml b/pyproject.toml
index b513a258..619f444f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,10 @@
+[project]
+name = "neural_lam"
+version = "0.1.0"
+
+[tool.setuptools]
+packages = ["neural_lam"]
+
[tool.black]
line-length = 80
@@ -42,12 +49,9 @@ ignore = [
"create_mesh.py", # Disable linting for now, as major rework is planned/expected
]
# Temporary fix for import neural_lam statements until set up as proper package
-init-hook='import sys; sys.path.append(".")'
+init-hook = 'import sys; sys.path.append(".")'
[tool.pylint.TYPECHECK]
-generated-members = [
- "numpy.*",
- "torch.*",
-]
+generated-members = ["numpy.*", "torch.*"]
[tool.pylint.'MESSAGES CONTROL']
disable = [
"C0114", # 'missing-module-docstring', Do not require module docstrings
@@ -56,10 +60,11 @@ disable = [
"R0913", # 'too-many-arguments', Allow many function arguments
"R0914", # 'too-many-locals', Allow many local variables
"W0223", # 'abstract-method', Subclasses do not have to override all abstract methods
+ "C0411", # 'wrong-import-order', Allow for isort to handle import order
]
[tool.pylint.DESIGN]
-max-statements=100 # Allow for some more involved functions
+max-statements = 100 # Allow for some more involved functions
[tool.pylint.IMPORTS]
-allow-any-import-level="neural_lam"
+allow-any-import-level = "neural_lam"
[tool.pylint.SIMILARITIES]
-min-similarity-lines=10
+min-similarity-lines = 10
From 4dadf2985591f7cacc536f906f8bad2eab98878a Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 08:22:37 +0200
Subject: [PATCH 013/273] introducing realistic boundaries
---
neural_lam/data_config.yaml | 75 ++++++++++++++++++++++++++++++++++-
neural_lam/utils.py | 2 +-
neural_lam/weather_dataset.py | 15 +++++--
3 files changed, 87 insertions(+), 5 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index cdfb57dc..e6a0d506 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -20,7 +20,12 @@ zarrs: # List of zarrs containing fields related to state
x: x
y: y
boundary:
- path:
+ path: /scratch/sadamov/era5.zarr
+ dims:
+ time: time
+ level: level
+ x: longitude
+ y: latitude
mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.
state: # Variables forecasted by the model
surface: # Single-field variables
@@ -111,16 +116,84 @@ forcing: # Forcing variables, dynamic inputs to the model
- 8
- 11
- 13
+ - 15
+ - 19
+ - 22
+ - 26
+ - 30
- 38
- 44
- 59
window: 3 # Number of time steps to use for forcing (odd)
boundary: # Boundary conditions
surface:
+ - 10m_u_component_of_wind
+ # - 10m_v_component_of_wind
+ # - 2m_dewpoint_temperature
+ # - 2m_temperature
+ # - mean_sea_level_pressure
+ # - mean_surface_latent_heat_flux
+ # - mean_surface_net_long_wave_radiation_flux
+ # - mean_surface_net_short_wave_radiation_flux
+ # - mean_surface_sensible_heat_flux
+ # - surface_pressure
+ # - total_cloud_cover
+ # - total_column_water_vapour
+ # - total_precipitation_12hr
+ # - total_precipitation_24hr
+ # - total_precipitation_6hr
+ # - geopotential_at_surface
surface_units:
+ - m/s
+ # - m/s
+ # - K
+ # - K
+ # - Pa
+ # - W/m^2
+ # - W/m^2
+ # - W/m^2
+ # - W/m^2
+ # - Pa
+ # - "%"
+ # - kg/m^2
+ # - kg/m^2
+ # - kg/m^2
+ # - kg/m^2
+ # - m^2/s^2
atmosphere:
+ - divergence
+ # - geopotential
+ # - relative_humidity
+ # - specific_humidity
+ # - temperature
+ # - u_component_of_wind
+ # - v_component_of_wind
+ # - vertical_velocity
+ # - vorticity
atmosphere_units:
+ - 1/s
+ # - m^2/s^2
+ # - "%"
+ # - kg/kg
+ # - K
+ # - m/s
+ # - m/s
+ # - m/s
+ # - 1/s
levels:
+ - 50
+ - 100
+ - 150
+ - 200
+ - 250
+ - 300
+ - 400
+ - 500
+ - 600
+ - 700
+ - 850
+ - 925
+ - 1000
window: 3 # Number of time steps to use for boundary (odd)
lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells
lat: lat
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 3992bc6c..b4855eff 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -318,7 +318,7 @@ def process_dataset(self, dataset_name, split="train"):
elif vars_atmosphere:
dataset = vars_atmosphere
else:
- print("No variables found in dataset {dataset_name}")
+ print(f"No variables found in dataset {dataset_name}")
return None
dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 03c8114f..d6662cfb 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -56,18 +56,27 @@ def __init__(
if self.forcings is not None:
self.forcings_windowed = (
self.forcings.sel(
- time=self.forcings.time.isin(self.state.time),
+ time=self.state.time,
method="nearest",
)
+ .pad(
+ time=(self.forcing_window // 2, self.forcing_window // 2),
+ mode="edge",
+ )
.rolling(time=self.forcing_window, center=True)
.construct("window")
)
+
if self.boundary is not None:
self.boundary_windowed = (
self.boundary.sel(
- time=self.forcings.time.isin(self.state.time),
+ time=self.state.time,
method="nearest",
)
+ .pad(
+ time=(self.boundary_window // 2, self.boundary_window // 2),
+ mode="edge",
+ )
.rolling(time=self.boundary_window, center=True)
.construct("window")
)
@@ -77,7 +86,7 @@ def __len__(self):
return len(self.state.time) - self.ar_steps - self.idx_max
def __getitem__(self, idx):
- idx += self.idx_max / 2 # Skip first time step
+ idx += self.idx_max // 2 # Skip first time step
sample = torch.tensor(
self.state.isel(time=slice(idx, idx + self.ar_steps)).values,
dtype=torch.float32,
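A toy illustration of the edge padding added in this patch: padding with mode="edge" before the rolling window means every original time step ends up with a fully populated window, so no NaNs leak into the forcing or boundary tensors (sizes and values are made up):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims=["time"])
    windowed = (
        da.pad(time=(1, 1), mode="edge")     # repeat the first/last value once
        .rolling(time=3, center=True)
        .construct("window")
    )
    # windowed.isel(time=1).values -> [0., 0., 1.]  (first original step)
    # windowed.isel(time=5).values -> [3., 4., 4.]  (last original step)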
From 7524c4ddad149a16b4bdc8cfa8e4f41e967b0223 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 12:48:46 +0200
Subject: [PATCH 014/273] Adapted nwp_xy related code to new data loading
procedure
---
.gitignore | 1 +
create_mesh.py | 24 ++++++++++++++----------
neural_lam/utils.py | 35 ++++++++++++++++++++++++++++++++++-
neural_lam/vis.py | 8 ++++----
plot_graph.py | 8 +++++---
5 files changed, 58 insertions(+), 18 deletions(-)
diff --git a/.gitignore b/.gitignore
index 590c7e12..1ecd1dfe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,3 +73,4 @@ tags
# Coc configuration directory
.vim
.vscode
+cosmo_hilam.html
diff --git a/create_mesh.py b/create_mesh.py
index cb524cd6..2b6af9fd 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -12,6 +12,11 @@
import torch_geometric as pyg
from torch_geometric.utils.convert import from_networkx
+# First-party
+from neural_lam import utils
+
+# matplotlib.use('TkAgg')
+
def plot_graph(graph, title=None):
fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H
@@ -152,13 +157,6 @@ def prepend_node_index(graph, new_index):
def main():
parser = ArgumentParser(description="Graph generation arguments")
- parser.add_argument(
- "--dataset",
- type=str,
- default="meps_example",
- help="Dataset to load grid point coordinates from "
- "(default: meps_example)",
- )
parser.add_argument(
"--graph",
type=str,
@@ -184,15 +182,21 @@ def main():
default=0,
help="Generate hierarchical mesh graph (default: 0, no)",
)
+ parser.add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
+ )
+
args = parser.parse_args()
# Load grid positions
- static_dir_path = os.path.join("data", args.dataset, "static")
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
- xy = np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
-
+ config_loader = utils.ConfigLoader(args.data_config)
+ xy = config_loader.get_nwp_xy()
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index b4855eff..172cef95 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -3,6 +3,7 @@
# Third-party
import cartopy.crs as ccrs
+import numpy as np
import torch
import xarray as xr
import yaml
@@ -275,7 +276,7 @@ def open_zarr(self, dataset_name):
dataset = xr.open_zarr(dataset_path, consolidated=True)
return dataset
- def process_dataset(self, dataset_name, split="train"):
+ def process_dataset(self, dataset_name, split="train", stack=True):
"""Process a single dataset specified by the dataset name."""
dataset = self.open_zarr(dataset_name)
@@ -321,6 +322,29 @@ def process_dataset(self, dataset_name, split="train"):
print(f"No variables found in dataset {dataset_name}")
return None
+ if not all(
+ lat_lon in self.zarrs[dataset_name].dims.values.values()
+ for lat_lon in self.zarrs[
+ dataset_name
+ ].lat_lon_names.values.values()
+ ):
+ lat_name = self.zarrs[dataset_name].lat_lon_names.lat
+ lon_name = self.zarrs[dataset_name].lat_lon_names.lon
+ if dataset[lat_name].ndim == 2:
+ dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
+ if dataset[lon_name].ndim == 2:
+ dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True)
+ dataset = dataset.assign_coords(
+ x=dataset[lon_name], y=dataset[lat_name]
+ )
+
+ if stack:
+ dataset = self.stack_grid(dataset)
+
+ return dataset
+
+ def stack_grid(self, dataset):
+ """Stack grid dimensions."""
dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
if "time" in dataset.dims:
@@ -328,3 +352,12 @@ def process_dataset(self, dataset_name, split="train"):
else:
dataset = dataset.transpose("grid", "variable")
return dataset
+
+ def get_nwp_xy(self):
+ """Get the x and y coordinates for the NWP grid."""
+ x = self.process_dataset("static", stack=False).x.values
+ y = self.process_dataset("static", stack=False).y.values
+ xx, yy = np.meshgrid(y, x)
+ xy = np.stack((xx, yy), axis=0)
+
+ return xy
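A small shape check for the meshgrid/stack combination in get_nwp_xy above; the grid sizes are made up:

    import numpy as np

    x = np.arange(4)            # pretend 4 grid columns
    y = np.arange(3)            # pretend 3 grid rows
    xx, yy = np.meshgrid(y, x)  # default 'xy' indexing: both have shape (len(x), len(y)) = (4, 3)
    xy = np.stack((xx, yy), axis=0)
    print(xy.shape)             # (2, 4, 3)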
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 02b8dd35..8c36a9a7 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -78,7 +78,7 @@ def plot_prediction(
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape)
+ mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
@@ -93,7 +93,7 @@ def plot_prediction(
# Plot pred and target
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
- data_grid = data.reshape(*data_config.grid_shape).cpu().numpy()
+ data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
im = ax.imshow(
data_grid,
origin="lower",
@@ -129,7 +129,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape)
+ mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
@@ -139,7 +139,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
)
ax.coastlines() # Add coastline outlines
- error_grid = error.reshape(*data_config.grid_shape).cpu().numpy()
+ error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy()
im = ax.imshow(
error_grid,
diff --git a/plot_graph.py b/plot_graph.py
index c82b4e04..e246200d 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -63,9 +63,11 @@ def main():
mesh_static_features = graph_ldict["mesh_static_features"]
config_loader = utils.ConfigLoader(args.data_config)
- grid_static_features = config_loader.process_dataset("static")
- # Extract values needed, turn to numpy
- grid_pos = grid_static_features[:, :2].numpy()
+ xy = config_loader.get_nwp_xy()
+ grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) # (N_grid, 2)
+ pos_max = np.max(np.abs(grid_xy))
+ grid_pos = grid_xy / pos_max # Divide by maximum coordinate
+
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_pos = np.concatenate(
From 45fd375b3ff3c2760906b866c3c26e631162dbbf Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 12:49:53 +0200
Subject: [PATCH 015/273] Only state requires units for plotting; lat/lon
 specifications make the code more flexible

---
neural_lam/data_config.yaml | 56 +++++++++----------------------------
1 file changed, 13 insertions(+), 43 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index e6a0d506..cce477ed 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -6,12 +6,18 @@ zarrs: # List of zarrs containing fields related to state
level: z
x: x # Either give "grid" (flattened) dimension or "x" and "y"
y: y
+ lat_lon_names:
+ lon: lon
+ lat: lat
static:
path: /scratch/sadamov/template.zarr
dims:
level: z
x: x
y: y
+ lat_lon_names:
+ lon: lon
+ lat: lat
forcing:
path: /scratch/sadamov/template.zarr
dims:
@@ -19,6 +25,9 @@ zarrs: # List of zarrs containing fields related to state
level: z
x: x
y: y
+ lat_lon_names:
+ lon: lon
+ lat: lat
boundary:
path: /scratch/sadamov/era5.zarr
dims:
@@ -26,6 +35,9 @@ zarrs: # List of zarrs containing fields related to state
level: level
x: longitude
y: latitude
+ lat_lon_names:
+ lon: longitude
+ lat: latitude
mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.
state: # Variables forecasted by the model
surface: # Single-field variables
@@ -77,16 +89,8 @@ state: # Variables forecasted by the model
static: # Static inputs
surface:
- HSURF
- - lat
- - lon
- surface_units:
- - m
- - °N
- - °E
atmosphere:
- FI
- atmosphere_units:
- - m^2/s^2
levels:
- 0
- 5
@@ -104,12 +108,8 @@ static: # Static inputs
forcing: # Forcing variables, dynamic inputs to the model
surface:
- ASOB_S
- surface_units:
- - W/m^2
atmosphere:
- T
- atmosphere_units:
- - K
levels:
- 0
- 5
@@ -143,23 +143,6 @@ boundary: # Boundary conditions
# - total_precipitation_24hr
# - total_precipitation_6hr
# - geopotential_at_surface
- surface_units:
- - m/s
- # - m/s
- # - K
- # - K
- # - Pa
- # - W/m^2
- # - W/m^2
- # - W/m^2
- # - W/m^2
- # - Pa
- # - "%"
- # - kg/m^2
- # - kg/m^2
- # - kg/m^2
- # - kg/m^2
- # - m^2/s^2
atmosphere:
- divergence
# - geopotential
@@ -170,16 +153,6 @@ boundary: # Boundary conditions
# - v_component_of_wind
# - vertical_velocity
# - vorticity
- atmosphere_units:
- - 1/s
- # - m^2/s^2
- # - "%"
- # - kg/kg
- # - K
- # - m/s
- # - m/s
- # - m/s
- # - 1/s
levels:
- 50
- 100
@@ -195,10 +168,7 @@ boundary: # Boundary conditions
- 925
- 1000
window: 3 # Number of time steps to use for boundary (odd)
-lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells
- lat: lat
- lon: lon
-grid_shape:
+grid_shape_state:
x: 582
y: 390
splits:
From 812323ddce5a86cc1fef0f1305ad27d5dfce629f Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 19:23:04 +0200
Subject: [PATCH 016/273] small bugfixes and improvements
---
.gitignore | 3 ++-
neural_lam/data_config.yaml | 4 +---
neural_lam/weather_dataset.py | 33 +++++++++------------------------
train_model.py | 4 ++--
4 files changed, 14 insertions(+), 30 deletions(-)
diff --git a/.gitignore b/.gitignore
index 1ecd1dfe..08cc014e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,8 @@ graphs
*.sif
sweeps
test_*.sh
+cosmo_hilam.html
+normalization.zarr
### Python ###
# Byte-compiled / optimized / DLL files
@@ -73,4 +75,3 @@ tags
# Coc configuration directory
.vim
.vscode
-cosmo_hilam.html
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index cce477ed..faaabd32 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -29,7 +29,7 @@ zarrs: # List of zarrs containing fields related to state
lon: lon
lat: lat
boundary:
- path: /scratch/sadamov/era5.zarr
+ path: /scratch/sadamov/era5_template.zarr
dims:
time: time
level: level
@@ -197,5 +197,3 @@ normalization:
boundary_std: boundary_std
diff_mean: diff_mean
diff_std: diff_std
- grid_static_features: grid_static_features
- param_weights: param_weights
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index d6662cfb..d51fb896 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -39,9 +39,7 @@ def __init__(
self.state = self.config_loader.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
- self.forcings = self.config_loader.process_dataset(
- "forcing", self.split
- )
+ self.forcing = self.config_loader.process_dataset("forcing", self.split)
self.boundary = self.config_loader.process_dataset(
"boundary", self.split
)
@@ -53,9 +51,9 @@ def __init__(
(self.boundary_window - 1), (self.forcing_window - 1)
)
- if self.forcings is not None:
- self.forcings_windowed = (
- self.forcings.sel(
+ if self.forcing is not None:
+ self.forcing_windowed = (
+ self.forcing.sel(
time=self.state.time,
method="nearest",
)
@@ -92,11 +90,11 @@ def __getitem__(self, idx):
dtype=torch.float32,
)
- forcings = (
- self.forcings_windowed.isel(time=slice(idx, idx + self.ar_steps))
+ forcing = (
+ self.forcing_windowed.isel(time=slice(idx, idx + self.ar_steps))
.stack(variable_window=("variable", "window"))
.values
- if self.forcings is not None
+ if self.forcing is not None
else torch.tensor([])
)
@@ -119,10 +117,10 @@ def __getitem__(self, idx):
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
- # forcings: (ar_steps, N_grid, d_windowed_forcings)
+ # forcing: (ar_steps, N_grid, d_windowed_forcing)
# boundary: (ar_steps, N_grid, d_windowed_boundary)
# batch_times: (ar_steps,)
- return init_states, target_states, forcings, boundary, batch_times
+ return init_states, target_states, forcing, boundary, batch_times
class WeatherDataModule(pl.LightningDataModule):
@@ -183,16 +181,3 @@ def test_dataloader(self):
num_workers=self.num_workers,
shuffle=False,
)
-
-
-data_module = WeatherDataModule(batch_size=4, num_workers=0)
-data_module.setup()
-train_dataloader = data_module.train_dataloader()
-for batch in train_dataloader:
- print(batch[0].shape)
- print(batch[1].shape)
- print(batch[2].shape)
- print(batch[3].shape)
- print(batch[4])
- print(batch[2][0, 0, 0, :])
- break
diff --git a/train_model.py b/train_model.py
index 4f57ca24..e5dfd528 100644
--- a/train_model.py
+++ b/train_model.py
@@ -62,7 +62,7 @@ def main():
"--seed", type=int, default=42, help="random seed (default: 42)"
)
parser.add_argument(
- "--n_workers",
+ "--num_workers",
type=int,
default=4,
help="Number of workers in data loader (default: 4)",
@@ -235,7 +235,7 @@ def main():
# Create datamodule
data_module = WeatherDataModule(
batch_size=args.batch_size,
- num_workers=args.n_workers,
+ num_workers=args.num_workers,
)
# Instantiate model + trainer
From 500f2fbeafd386844c813a7bc20a3bf65aed86f0 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 19:24:09 +0200
Subject: [PATCH 017/273] Calculate stats and store in zarr archive; stats are
 registered as model buffers; normalization happens on device in
 on_after_batch_transfer
---
create_parameter_weights.py | 144 ++++++++++++++--------------------
neural_lam/models/ar_model.py | 33 ++++++++
neural_lam/utils.py | 18 ++++-
3 files changed, 108 insertions(+), 87 deletions(-)
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 926d7741..1eda7a24 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -1,14 +1,13 @@
# Standard library
-import os
from argparse import ArgumentParser
# Third-party
-import numpy as np
import torch
+import xarray as xr
from tqdm import tqdm
# First-party
-from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.weather_dataset import WeatherDataModule
def main():
@@ -16,12 +15,6 @@ def main():
Pre-compute parameter weights to be used in loss function
"""
parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--dataset",
- type=str,
- default="meps_example",
- help="Dataset to compute weights for (default: meps_example)",
- )
parser.add_argument(
"--batch_size",
type=int,
@@ -29,107 +22,77 @@ def main():
help="Batch size when iterating over the dataset",
)
parser.add_argument(
- "--step_length",
- type=int,
- default=3,
- help="Step length in hours to consider single time step (default: 3)",
- )
- parser.add_argument(
- "--n_workers",
+ "--num_workers",
type=int,
default=4,
help="Number of workers in data loader (default: 4)",
)
+ parser.add_argument(
+ "--zarr_path",
+ type=str,
+ default="normalization.zarr",
+ help="Directory where data is stored",
+ )
+
args = parser.parse_args()
- static_dir_path = os.path.join("data", args.dataset, "static")
-
- ds = WeatherDataset()
- # Create parameter weights based on height
- # based on fig A.1 in graph cast paper
- w_dict = {
- "2": 1.0,
- "0": 0.1,
- "65": 0.065,
- "1000": 0.1,
- "850": 0.05,
- "500": 0.03,
- }
- w_list = np.array(
- [w_dict[par.split("_")[-2]] for par in ds.config_loader.param_names()]
- )
- print("Saving parameter weights...")
- np.save(
- os.path.join(static_dir_path, "parameter_weights.npy"),
- w_list.astype("float32"),
+ data_module = WeatherDataModule(
+ batch_size=args.batch_size, num_workers=args.num_workers
)
+ data_module.setup()
+ loader = data_module.train_dataloader()
# Load dataset without any subsampling
- loader = torch.utils.data.DataLoader(
- ds, args.batch_size, shuffle=False, num_workers=args.n_workers
- )
- # Compute mean and std.-dev. of each parameter (+ flux forcing)
+ # Compute mean and std.-dev. of each parameter (+ forcing forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")
means = []
squares = []
- flux_means = []
- flux_squares = []
- for init_batch, target_batch, forcing_batch in tqdm(loader):
+ fb_means = {"forcing": [], "boundary": []}
+ fb_squares = {"forcing": [], "boundary": []}
+
+ for init_batch, target_batch, forcing_batch, boundary_batch, _ in tqdm(
+ loader
+ ):
batch = torch.cat(
(init_batch, target_batch), dim=1
) # (N_batch, N_t, N_grid, d_features)
means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,)
- squares.append(
- torch.mean(batch**2, dim=(1, 2))
- ) # (N_batch, d_features,)
+ squares.append(torch.mean(batch**2, dim=(1, 2)))
- # Flux at 1st windowed position is index 1 in forcing
- flux_batch = forcing_batch[:, :, :, 1]
- flux_means.append(torch.mean(flux_batch)) # (,)
- flux_squares.append(torch.mean(flux_batch**2)) # (,)
+ for fb_type, fb_batch in zip(
+ ["forcing", "boundary"], [forcing_batch, boundary_batch]
+ ):
+ fb_batch = fb_batch[:, :, :, 1]
+ fb_means[fb_type].append(torch.mean(fb_batch)) # (,)
+ fb_squares[fb_type].append(torch.mean(fb_batch**2)) # (,)
mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features)
second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
std = torch.sqrt(second_moment - mean**2) # (d_features)
- flux_mean = torch.mean(torch.stack(flux_means)) # (,)
- flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,)
- flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,)
- flux_stats = torch.stack((flux_mean, flux_std))
-
- print("Saving mean, std.-dev, flux_stats...")
- torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
- torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
- torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
+ fb_stats = {}
+ for fb_type in ["forcing", "boundary"]:
+ fb_stats[f"{fb_type}_mean"] = torch.mean(
+ torch.stack(fb_means[fb_type])
+ ) # (,)
+ fb_second_moment = torch.mean(torch.stack(fb_squares[fb_type])) # (,)
+ fb_stats[f"{fb_type}_std"] = torch.sqrt(
+ fb_second_moment - fb_stats[f"{fb_type}_mean"] ** 2
+ ) # (,)
# Compute mean and std.-dev. of one-step differences across the dataset
print("Computing mean and std.-dev. for one-step differences...")
- ds_standard = WeatherDataset() # Re-load with standardization
- loader_standard = torch.utils.data.DataLoader(
- ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers
- )
- used_subsample_len = (65 // args.step_length) * args.step_length
-
diff_means = []
diff_squares = []
- for init_batch, target_batch, _ in tqdm(loader_standard):
- batch = torch.cat(
- (init_batch, target_batch), dim=1
- ) # (N_batch, N_t', N_grid, d_features)
- # Note: batch contains only 1h-steps
- stepped_batch = torch.cat(
- [
- batch[:, ss_i : used_subsample_len : args.step_length]
- for ss_i in range(args.step_length)
- ],
- dim=0,
- )
- # (N_batch', N_t, N_grid, d_features),
- # N_batch' = args.step_length*N_batch
-
- batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
- # (N_batch', N_t-1, N_grid, d_features)
+ for init_batch, target_batch, _, _, _ in tqdm(loader):
+ # normalize the batch
+ init_batch = (init_batch - mean) / std
+ target_batch = (target_batch - mean) / std
+
+ batch = torch.cat((init_batch, target_batch), dim=1)
+ batch_diffs = batch[:, 1:] - batch[:, :-1]
+ # (N_batch, N_t-1, N_grid, d_features)
diff_means.append(
torch.mean(batch_diffs, dim=(1, 2))
@@ -142,9 +105,20 @@ def main():
diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0)
diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features)
- print("Saving one-step difference mean and std.-dev...")
- torch.save(diff_mean, os.path.join(static_dir_path, "diff_mean.pt"))
- torch.save(diff_std, os.path.join(static_dir_path, "diff_std.pt"))
+ # Create xarray dataset
+ ds = xr.Dataset(
+ {
+ "mean": (["d_features"], mean),
+ "std": (["d_features"], std),
+ "diff_mean": (["d_features"], diff_mean),
+ "diff_std": (["d_features"], diff_std),
+ **fb_stats,
+ }
+ )
+
+ # Save dataset as Zarr
+ print("Saving dataset as Zarr...")
+ ds.to_zarr(args.zarr_path, mode="w")
if __name__ == "__main__":
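The statistics above use the identity Var[x] = E[x^2] - (E[x])^2, accumulated batch-wise before taking the square root. A toy check of that identity (values are made up):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    mean = x.mean()
    second_moment = (x ** 2).mean()
    std = torch.sqrt(second_moment - mean ** 2)
    # matches the population standard deviation, sqrt(1.25) ~ 1.118
    assert torch.isclose(std, x.std(unbiased=False))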
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 8353327d..8976990b 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -91,6 +91,17 @@ def __init__(self, args):
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []
+ # Load normalization statistics
+ self.normalization_stats = self.config_loader.load_normalization_stats()
+ if self.normalization_stats is not None:
+ for (
+ var_name,
+ var_data,
+ ) in self.normalization_stats.data_vars.items():
+ self.register_buffer(
+ f"data_{var_name}", torch.tensor(var_data.values)
+ )
+
def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.lr, betas=(0.9, 0.95)
@@ -195,6 +206,28 @@ def common_step(self, batch):
return prediction, target_states, pred_std
+ def on_after_batch_transfer(self, batch, dataloader_idx):
+ """Normalize Batch data after transferring to the device."""
+ if self.normalization_stats is not None:
+ init_states, target_states, forcing_features, boundary_features = (
+ batch
+ )
+ init_states = (init_states - self.data_mean) / self.data_std
+ target_states = (target_states - self.data_mean) / self.data_std
+ forcing_features = (
+ forcing_features - self.forcing_mean
+ ) / self.forcing_std
+ boundary_features = (
+ boundary_features - self.boundary_mean
+ ) / self.boundary_std
+ batch = (
+ init_states,
+ target_states,
+ forcing_features,
+ boundary_features,
+ )
+ return batch
+
def training_step(self, batch):
"""
Train on single batch
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 172cef95..c86418c8 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -276,6 +276,20 @@ def open_zarr(self, dataset_name):
dataset = xr.open_zarr(dataset_path, consolidated=True)
return dataset
+ def load_normalization_stats(self):
+ """Load normalization statistics from Zarr archive."""
+ normalization_path = "normalization.zarr"
+ if not os.path.exists(normalization_path):
+ print(
+ f"Normalization statistics not found at "
+ f"path: {normalization_path}"
+ )
+ return None
+ normalization_stats = xr.open_zarr(
+ normalization_path, consolidated=True
+ )
+ return normalization_stats
+
def process_dataset(self, dataset_name, split="train", stack=True):
"""Process a single dataset specified by the dataset name."""
@@ -338,8 +352,8 @@ def process_dataset(self, dataset_name, split="train", stack=True):
x=dataset[lon_name], y=dataset[lat_name]
)
- if stack:
- dataset = self.stack_grid(dataset)
+ if stack:
+ dataset = self.stack_grid(dataset)
return dataset
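For reference, a minimal sketch of the buffer-based normalization pattern this patch sets up: register the statistics as buffers so they follow the module to the device, then normalize by broadcasting over the feature dimension. The class, tensors and shapes below are made-up stand-ins:

    import torch
    from torch import nn

    class Normalizer(nn.Module):
        def __init__(self, mean, std):
            super().__init__()
            # Buffers move with .to(device) but are not trained
            self.register_buffer("data_mean", mean)
            self.register_buffer("data_std", std)

        def forward(self, x):
            # x: (..., d_features); per-feature stats broadcast over the last dim
            return (x - self.data_mean) / self.data_std

    norm = Normalizer(torch.tensor([0.0, 10.0]), torch.tensor([1.0, 5.0]))
    out = norm(torch.randn(2, 4, 2))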
From 9293fe1b6e69f7960e4174e9d14a18e01cfa6521 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 21:55:42 +0200
Subject: [PATCH 018/273] LaTeX support
---
neural_lam/data_config.yaml | 24 ++++++++++++------------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index faaabd32..55e59a72 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -50,12 +50,12 @@ state: # Variables forecasted by the model
- V_10M
surface_units:
- "%"
- - Pa
- - Pa
- - K
- - kg/m^2
- - m/s
- - m/s
+ - r"$\mathrm{Pa}$"
+ - r"$\mathrm{Pa}$"
+ - r"$\mathrm{K}$"
+ - r"$\mathrm{kg}/\mathrm{m}^2$"
+ - r"$\mathrm{m}/\mathrm{s}$"
+ - r"$\mathrm{m}/\mathrm{s}$"
atmosphere: # Variables with vertical levels
- PP
- QV
@@ -65,13 +65,13 @@ state: # Variables forecasted by the model
- V
- W
atmosphere_units:
- - Pa
- - kg/kg
+ - r"$\mathrm{Pa}$"
+ - r"$\mathrm{kg}/\mathrm{kg}$"
- "%"
- - K
- - m/s
- - m/s
- - Pa/s
+ - r"$\mathrm{K}$"
+ - r"$\mathrm{m}/\mathrm{s}$"
+ - r"$\mathrm{m}/\mathrm{s}$"
+ - r"$\mathrm{Pa}/\mathrm{s}$"
levels: # Levels to use for atmosphere variables
- 0
- 5
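One caveat worth noting: YAML parses an entry like r"$\mathrm{Pa}$" as the literal string including the r" prefix and the closing quote, so whatever renders the plot labels has to strip that wrapping. A tiny sketch of one way to do it; the helper name is hypothetical and not part of the codebase:

    def strip_raw_wrapper(unit: str) -> str:
        # Drop a leading r" and trailing " if present
        if unit.startswith('r"') and unit.endswith('"'):
            return unit[2:-1]
        return unit

    print(strip_raw_wrapper('r"$\\mathrm{Pa}$"'))  # $\mathrm{Pa}$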
From e80aa5899c3ff9cad63f09f364d48bf4780e1dfe Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 21:58:18 +0200
Subject: [PATCH 019/273] ar_steps for training and eval
---
neural_lam/weather_dataset.py | 29 +++++++++++++++----------
train_model.py | 41 +++++++++++++++--------------------
2 files changed, 35 insertions(+), 35 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index d51fb896..4b5da0a8 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -18,8 +18,8 @@ class WeatherDataset(torch.utils.data.Dataset):
def __init__(
self,
split="train",
- batch_size=4,
ar_steps=3,
+ batch_size=4,
control_only=False,
data_config="neural_lam/data_config.yaml",
):
@@ -47,9 +47,6 @@ def __init__(
self.state_times = self.state.time.values
self.forcing_window = self.config_loader.forcing.window
self.boundary_window = self.config_loader.boundary.window
- self.idx_max = max(
- (self.boundary_window - 1), (self.forcing_window - 1)
- )
if self.forcing is not None:
self.forcing_windowed = (
@@ -81,17 +78,16 @@ def __init__(
def __len__(self):
# Skip first and last time step
- return len(self.state.time) - self.ar_steps - self.idx_max
+ return len(self.state.time) - self.ar_steps
def __getitem__(self, idx):
- idx += self.idx_max // 2 # Skip first time step
sample = torch.tensor(
self.state.isel(time=slice(idx, idx + self.ar_steps)).values,
dtype=torch.float32,
)
forcing = (
- self.forcing_windowed.isel(time=slice(idx, idx + self.ar_steps))
+ self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps))
.stack(variable_window=("variable", "window"))
.values
if self.forcing is not None
@@ -99,7 +95,9 @@ def __getitem__(self, idx):
)
boundary = (
- self.boundary_windowed.isel(time=slice(idx, idx + self.ar_steps))
+ self.boundary_windowed.isel(
+ time=slice(idx + 2, idx + self.ar_steps)
+ )
.stack(variable_window=("variable", "window"))
.values
if self.boundary is not None
@@ -110,16 +108,16 @@ def __getitem__(self, idx):
target_states = sample[2:]
batch_times = (
- self.state.isel(time=slice(idx, idx + self.ar_steps))
+ self.state.isel(time=slice(idx + 2, idx + self.ar_steps))
.time.values.astype(str)
.tolist()
)
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
- # forcing: (ar_steps, N_grid, d_windowed_forcing)
- # boundary: (ar_steps, N_grid, d_windowed_boundary)
- # batch_times: (ar_steps,)
+ # forcing: (ar_steps-2, N_grid, d_windowed_forcing)
+ # boundary: (ar_steps-2, N_grid, d_windowed_boundary)
+ # batch_times: (ar_steps-2,)
return init_states, target_states, forcing, boundary, batch_times
@@ -128,10 +126,14 @@ class WeatherDataModule(pl.LightningDataModule):
def __init__(
self,
+ ar_steps_train=3,
+ ar_steps_eval=25,
batch_size=4,
num_workers=16,
):
super().__init__()
+ self.ar_steps_train = ar_steps_train
+ self.ar_steps_eval = ar_steps_eval
self.batch_size = batch_size
self.num_workers = num_workers
self.train_dataset = None
@@ -142,16 +144,19 @@ def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = WeatherDataset(
split="train",
+ ar_steps=self.ar_steps_train,
batch_size=self.batch_size,
)
self.val_dataset = WeatherDataset(
split="val",
+ ar_steps=self.ar_steps_eval,
batch_size=self.batch_size,
)
if stage == "test" or stage is None:
self.test_dataset = WeatherDataset(
split="test",
+ ar_steps=self.ar_steps_eval,
batch_size=self.batch_size,
)
diff --git a/train_model.py b/train_model.py
index e5dfd528..a8b02f58 100644
--- a/train_model.py
+++ b/train_model.py
@@ -31,14 +31,6 @@ def main():
description="Train or evaluate NeurWP models for LAM"
)
- # General options
- parser.add_argument(
- "--dataset",
- type=str,
- default="meps_example",
- help="Dataset, corresponding to name in data directory "
- "(default: meps_example)",
- )
parser.add_argument(
"--model",
type=str,
@@ -51,13 +43,6 @@ def main():
default="neural_lam/data_config.yaml",
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
- parser.add_argument(
- "--subset_ds",
- type=int,
- default=0,
- help="Use only a small subset of the dataset, for debugging"
- "(default: 0=false)",
- )
parser.add_argument(
"--seed", type=int, default=42, help="random seed (default: 42)"
)
@@ -139,11 +124,11 @@ def main():
# Training options
parser.add_argument(
- "--ar_steps",
+ "--ar_steps_train",
type=int,
- default=1,
- help="Number of steps to unroll prediction for in loss (1-19) "
- "(default: 1)",
+ default=3,
+ help="Number of steps to unroll prediction for in loss function "
+ "(default: 3)",
)
parser.add_argument(
"--control_only",
@@ -161,9 +146,9 @@ def main():
parser.add_argument(
"--step_length",
type=int,
- default=3,
+ default=1,
help="Step length in hours to consider single time step 1-3 "
- "(default: 3)",
+ "(default: 1)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="learning rate (default: 0.001)"
@@ -183,6 +168,13 @@ def main():
help="Eval model on given data split (val/test) "
"(default: None (train model))",
)
+ parser.add_argument(
+ "--ar_steps_eval",
+ type=int,
+ default=25,
+ help="Number of steps to unroll prediction for in loss function "
+ "(default: 25)",
+ )
parser.add_argument(
"--n_example_pred",
type=int,
@@ -234,6 +226,8 @@ def main():
seed.seed_everything(args.seed)
# Create datamodule
data_module = WeatherDataModule(
+ ar_steps_train=args.ar_steps_train,
+ ar_steps_eval=args.ar_steps_eval,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
@@ -258,9 +252,10 @@ def main():
else:
model = model_class(args)
- prefix = "subset-" if args.subset_ds else ""
if args.eval:
- prefix = prefix + f"eval-{args.eval}-"
+ prefix = f"eval-{args.eval}-"
+ else:
+ prefix = "train-"
run_name = (
f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
From a86fc0788c30a1f5f364f26ba4c68816b4af23f3 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 21:58:26 +0200
Subject: [PATCH 020/273] smaller amendments
---
neural_lam/models/ar_model.py | 50 ++++++++++++---------------
neural_lam/models/base_graph_model.py | 4 +--
neural_lam/utils.py | 21 +++++++----
3 files changed, 38 insertions(+), 37 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 8976990b..0c0e5a55 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -24,13 +24,15 @@ class ARModel(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.save_hyperparameters()
- self.lr = args.lr
+ self.args = args
self.config_loader = utils.ConfigLoader(args.data_config)
# Load static features for grid/data
- static = self.config_loader.process_dataset("static", self.split)
+ static = self.config_loader.process_dataset("static")
self.register_buffer(
- "grid_static_features", torch.tensor(static.values)
+ "grid_static_features",
+ torch.tensor(static.values),
+ persistent=False,
)
# Double grid output dim. to also output std.-dev.
@@ -42,15 +44,6 @@ def __init__(self, args):
# Pred. dim. in grid cell
self.grid_output_dim = self.config_loader.num_data_vars("state")
- # Store constant per-variable std.-dev. weighting
- # Note that this is the inverse of the multiplicative weighting
- # in wMSE/wMAE
- self.register_buffer(
- "per_var_std",
- self.step_diff_std / torch.sqrt(self.param_weights),
- persistent=False,
- )
-
# grid_dim from data + static
(
self.num_grid_nodes,
@@ -60,11 +53,14 @@ def __init__(self, args):
2 * self.config_loader.num_data_vars("state")
+ grid_static_dim
+ self.config_loader.num_data_vars("forcing")
+ * self.config_loader.forcing.window
)
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
+ border_mask = torch.ones(self.num_grid_nodes, 1)
+ self.register_buffer("border_mask", border_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
"interior_mask", 1.0 - self.border_mask, persistent=False
@@ -99,12 +95,14 @@ def __init__(self, args):
var_data,
) in self.normalization_stats.data_vars.items():
self.register_buffer(
- f"data_{var_name}", torch.tensor(var_data.values)
+ f"{var_name}",
+ torch.tensor(var_data.values),
+ persistent=False,
)
def configure_optimizers(self):
opt = torch.optim.AdamW(
- self.parameters(), lr=self.lr, betas=(0.9, 0.95)
+ self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
)
if self.opt_state:
opt.load_state_dict(self.opt_state)
@@ -179,7 +177,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_std_list, dim=1
) # (B, pred_steps, num_grid_nodes, d_f)
else:
- pred_std = self.per_var_std # (d_f,)
+ pred_std = self.diff_std # (d_f,)
return prediction, pred_std
@@ -209,22 +207,20 @@ def common_step(self, batch):
def on_after_batch_transfer(self, batch, dataloader_idx):
"""Normalize Batch data after transferring to the device."""
if self.normalization_stats is not None:
- init_states, target_states, forcing_features, boundary_features = (
- batch
- )
- init_states = (init_states - self.data_mean) / self.data_std
- target_states = (target_states - self.data_mean) / self.data_std
+ init_states, target_states, forcing_features, _, _ = batch
+ init_states = (init_states - self.mean) / self.std
+ target_states = (target_states - self.mean) / self.std
forcing_features = (
forcing_features - self.forcing_mean
) / self.forcing_std
- boundary_features = (
- boundary_features - self.boundary_mean
- ) / self.boundary_std
+ # boundary_features = (
+ # boundary_features - self.boundary_mean
+ # ) / self.boundary_std
batch = (
init_states,
target_states,
forcing_features,
- boundary_features,
+ # boundary_features,
)
return batch
@@ -392,8 +388,8 @@ def plot_examples(self, batch, n_examples, prediction=None):
target = batch[1]
# Rescale to original data scale
- prediction_rescaled = prediction * self.data_std + self.data_mean
- target_rescaled = target * self.data_std + self.data_mean
+ prediction_rescaled = prediction * self.std + self.mean
+ target_rescaled = target * self.std + self.mean
# Iterate over the examples
for pred_slice, target_slice in zip(
@@ -541,7 +537,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_name = metric_name.replace("mse", "rmse")
# Note: we here assume rescaling for all metrics is linear
- metric_rescaled = metric_tensor_averaged * self.data_std
+ metric_rescaled = metric_tensor_averaged * self.std
# (pred_steps, d_f)
log_dict.update(
self.create_metric_log_dict(
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 256d4adc..fb5df62d 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -166,9 +166,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
pred_std = None
# Rescale with one-step difference statistics
- rescaled_delta_mean = (
- pred_delta_mean * self.step_diff_std + self.step_diff_mean
- )
+ rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean
# Residual connection for full state
return prev_state + rescaled_delta_mean, pred_std
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index c86418c8..71ef9512 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -242,16 +242,23 @@ def __contains__(self, key):
def param_names(self):
"""Return parameter names."""
- return (
- self.values["state"]["surface"] + self.values["state"]["atmosphere"]
- )
+ surface_names = self.values["state"]["surface"]
+ atmosphere_names = [
+ f"{var}_{level}"
+ for var in self.values["state"]["atmosphere"]
+ for level in self.values["state"]["levels"]
+ ]
+ return surface_names + atmosphere_names
def param_units(self):
"""Return parameter units."""
- return (
- self.values["state"]["surface_units"]
- + self.values["state"]["atmosphere_units"]
- )
+ surface_units = self.values["state"]["surface_units"]
+ atmosphere_units = [
+ unit
+ for unit in self.values["state"]["atmosphere_units"]
+ for _ in self.values["state"]["levels"]
+ ]
+ return surface_units + atmosphere_units
def num_data_vars(self, key):
"""Return the number of data variables for a given key."""
From 7ae9c872359b94283cb278f24f73bb7e050ae5bf Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 22:42:32 +0200
Subject: [PATCH 021/273] Dummy mask was inverted - fixed
---
neural_lam/models/ar_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 0c0e5a55..f49eb094 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -59,7 +59,7 @@ def __init__(self, args):
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
- border_mask = torch.ones(self.num_grid_nodes, 1)
+ border_mask = torch.zeros(self.num_grid_nodes, 1)
self.register_buffer("border_mask", border_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
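
For clarity, a minimal sketch (shapes assumed) of what the fix changes: an all-zero placeholder border mask makes the derived `interior_mask` all ones, so every grid node contributes to the loss until a real boundary mask is wired in.

```python
# Effect of the inverted-mask fix on the derived interior mask.
import torch

num_grid_nodes = 6
border_mask = torch.zeros(num_grid_nodes, 1)  # after the fix: no border nodes
interior_mask = 1.0 - border_mask             # all ones -> all nodes in the loss

assert torch.all(interior_mask == 1.0)
```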
From 93674a2b437cf46681ddced3859f8978e5e03200 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 9 May 2024 22:54:54 +0200
Subject: [PATCH 022/273] replace hardcoded normalization path
---
neural_lam/data_config.yaml | 2 +-
neural_lam/utils.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 55e59a72..140eb9b7 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -187,7 +187,7 @@ projection:
pole_longitude: 10.0
pole_latitude: -43.0
normalization:
- zarr: /scratch/sadamov/norm.zarr
+ zarr: normalization.zarr
vars:
data_mean: data_mean
data_std: data_std
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 71ef9512..96e1549e 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -285,7 +285,7 @@ def open_zarr(self, dataset_name):
def load_normalization_stats(self):
"""Load normalization statistics from Zarr archive."""
- normalization_path = "normalization.zarr"
+ normalization_path = self.normalization.zarr
if not os.path.exists(normalization_path):
print(
f"Normalization statistics not found at "
From 0afdfee060a4c884797baea4779d9a4cc128c8c6 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 18:48:06 +0200
Subject: [PATCH 023/273] wip on simplifying pre-commit setup
---
.flake8 | 3 ++
.github/workflows/pre-commit.yml | 5 ++-
.pre-commit-config.yaml | 58 ++++++++++++++++----------------
requirements.txt | 6 +---
4 files changed, 35 insertions(+), 37 deletions(-)
create mode 100644 .flake8
diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..b02dd545
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = E203, F811, I002, W503
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index a6ad84f1..4828afc9 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -28,6 +28,5 @@ jobs:
pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \
torch-cluster==1.6.1 torch-geometric==2.3.1 \
-f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
- - name: Run pre-commit hooks
- run: |
- pre-commit run --all-files
+ - uses: pre-commit/action@v2.0.3
+
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f48eca67..106cb64b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,51 +1,51 @@
repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
+ - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- - id: check-ast
- - id: check-case-conflict
- - id: check-docstring-first
- - id: check-symlinks
- - id: check-toml
- - id: check-yaml
- - id: debug-statements
- - id: end-of-file-fixer
- - id: trailing-whitespace
-- repo: local
+ - id: check-ast
+ - id: check-case-conflict
+ - id: check-docstring-first
+ - id: check-symlinks
+ - id: check-toml
+ - id: check-yaml
+ - id: debug-statements
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+
+ - repo: https://github.com/codespell-project/codespell
+ rev: 2.0.0
hooks:
- - id: codespell
- name: codespell
+ - id: codespell
description: Check for spelling errors
language: system
- entry: codespell
-- repo: local
+
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
hooks:
- - id: black
- name: black
+ - id: black
description: Format Python code
language: system
- entry: black
types_or: [python, pyi]
-- repo: local
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
hooks:
- - id: isort
- name: isort
+ - id: isort
description: Group and sort Python imports
language: system
- entry: isort
types_or: [python, pyi, cython]
-- repo: local
+
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.1.0
hooks:
- - id: flake8
- name: flake8
+ - id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
language: system
- entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503
types: [python]
-- repo: local
+
+ - repo: https://github.com/pylint-dev/pylint
hooks:
- - id: pylint
- name: pylint
+ - id: pylint
entry: pylint -rn -sn
language: system
types: [python]
diff --git a/requirements.txt b/requirements.txt
index 5a2111b2..f381d54f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,10 +10,6 @@ Cartopy>=0.22.0
pyproj>=3.4.1
tueplots>=0.0.8
plotly>=5.15.0
+
# for dev
-codespell>=2.0.0
-black>=21.9b0
-isort>=5.9.3
-flake8>=4.0.1
-pylint>=3.0.3
pre-commit>=2.15.0
From 28118a6e82775a4cb6d0d3d7e89e50c86541d320 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 18:51:46 +0200
Subject: [PATCH 024/273] setup pylint version
---
.pre-commit-config.yaml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 106cb64b..28cd91b9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,6 +44,7 @@ repos:
types: [python]
- repo: https://github.com/pylint-dev/pylint
+ rev: 2.0.0
hooks:
- id: pylint
entry: pylint -rn -sn
From 3da310860f209bcaec6387fbebc0468505aa6331 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 19:09:09 +0200
Subject: [PATCH 025/273] remove external deps install in cicd linting
---
.github/workflows/pre-commit.yml | 11 -----------
1 file changed, 11 deletions(-)
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 4828afc9..0ff792a3 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -11,22 +11,11 @@ on:
jobs:
pre-commit-job:
runs-on: ubuntu-latest
- defaults:
- run:
- shell: bash -l {0}
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- - name: Install pre-commit hooks
- run: |
- pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \
- --index-url https://download.pytorch.org/whl/cpu
- pip install -r requirements.txt
- pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \
- torch-cluster==1.6.1 torch-geometric==2.3.1 \
- -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
- uses: pre-commit/action@v2.0.3
From ea64309bf9215ba900fb3bcfdf4ea17219be7ef0 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 19:31:34 +0200
Subject: [PATCH 026/273] create project
---
.gitignore | 3 ++
README.md | 38 +++++++++----------
.../create_grid_features.py | 0
create_mesh.py => neural_lam/create_mesh.py | 0
.../create_parameter_weights.py | 0
train_model.py => neural_lam/train_model.py | 0
pyproject.toml | 36 ++++++++++++++++++
requirements.txt | 15 --------
tests/__init__.py | 5 +++
9 files changed, 63 insertions(+), 34 deletions(-)
rename create_grid_features.py => neural_lam/create_grid_features.py (100%)
rename create_mesh.py => neural_lam/create_mesh.py (100%)
rename create_parameter_weights.py => neural_lam/create_parameter_weights.py (100%)
rename train_model.py => neural_lam/train_model.py (100%)
delete mode 100644 requirements.txt
create mode 100644 tests/__init__.py
diff --git a/.gitignore b/.gitignore
index 7bb826a2..cd447c4e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,3 +72,6 @@ tags
# Coc configuration directory
.vim
+
+# pdm
+.pdm-python
diff --git a/README.md b/README.md
index 67d9d9b1..e59bf3f5 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:
## A note on the limited area setting
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
 There are still some parts of the code that are quite specific to the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
+This is in particular true for the mesh graph creation (`neural_lam.create_mesh`) and some of the constants used (`neural_lam.constants`).
If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -77,7 +77,7 @@ See the [repository format section](#format-of-data-directory) for details on th
The full MEPS dataset can be shared with other researchers on request, contact us for this.
A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
Download the file and unzip in the neural-lam directory.
-All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`).
+All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `python -m neural_lam.create_mesh`).
 Note that this is far too little data to train any useful models, but all scripts can be run with it.
 It should thus be useful to make sure that your Python environment is set up correctly and that all the code can be run without any issues.
@@ -86,31 +86,31 @@ An overview of how the different scripts and files depend on each other is given
-In order to start training models at least three pre-processing scripts have to be ran:
+In order to start training models at least three pre-processing scripts have to be run:
-* `create_mesh.py`
-* `create_grid_features.py`
-* `create_parameter_weights.py`
+* `python -m neural_lam.create_mesh`
+* `python -m neural_lam.create_grid_features`
+* `python -m neural_lam.create_parameter_weights`
### Create graph
-Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options).
+Run `python -m neural_lam.create_mesh` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_mesh --help` for a list of options).
The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as:
-* **GC-LAM**: `python create_mesh.py --graph multiscale`
-* **Hi-LAM**: `python create_mesh.py --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel)
-* **L1-LAM**: `python create_mesh.py --graph 1level --levels 1`
+* **GC-LAM**: `python -m neural_lam.create_mesh --graph multiscale`
+* **Hi-LAM**: `python -m neural_lam.create_mesh --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel)
+* **L1-LAM**: `python -m neural_lam.create_mesh --graph 1level --levels 1`
The graph-related files are stored in a directory called `graphs`.
### Create remaining static features
-To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
+To create the remaining static files run the scripts `python -m neural_lam.create_grid_features` and `python -m neural_lam.create_parameter_weights`.
The main option to set for these is just which dataset to use.
## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`.
+The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam.constants`.
See the [W&B documentation](https://docs.wandb.ai/) for details.
If you would like to login and use W&B, run:
@@ -123,8 +123,8 @@ wandb off
```
## Train Models
-Models can be trained using `train_model.py`.
-Run `python train_model.py --help` for a full list of training options.
+Models can be trained using `python -m neural_lam.train_model`.
+Run `python -m neural_lam.train_model --help` for a full list of training options.
A few of the key ones are outlined below:
* `--dataset`: Which data to train on
@@ -143,12 +143,12 @@ This model class is used both for the L1-LAM and GC-LAM models from the [paper](
To train 1L-LAM use
```
-python train_model.py --model graph_lam --graph 1level ...
+python -m neural_lam.train_model --model graph_lam --graph 1level ...
```
To train GC-LAM use
```
-python train_model.py --model graph_lam --graph multiscale ...
+python -m neural_lam.train_model --model graph_lam --graph multiscale ...
```
### Hi-LAM
@@ -156,7 +156,7 @@ A version of Graph-LAM that uses a hierarchical mesh graph and performs sequenti
To train Hi-LAM use
```
-python train_model.py --model hi_lam --graph hierarchical ...
+python -m neural_lam.train_model --model hi_lam --graph hierarchical ...
```
### Hi-LAM-Parallel
@@ -165,13 +165,13 @@ Not included in the paper as initial experiments showed worse results than Hi-LA
To train Hi-LAM-Parallel use
```
-python train_model.py --model hi_lam_parallel --graph hierarchical ...
+python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ...
```
Checkpoint files for our models trained on the MEPS data are available upon request.
## Evaluate Models
-Evaluation is also done using `train_model.py`, but using the `--eval` option.
+Evaluation is also done using `python -m neural_lam.train_model`, but using the `--eval` option.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
Most of the training options are also relevant for evaluation (not `ar_steps`, evaluation always unrolls full forecasts).
Some options specifically important for evaluation are:
diff --git a/create_grid_features.py b/neural_lam/create_grid_features.py
similarity index 100%
rename from create_grid_features.py
rename to neural_lam/create_grid_features.py
diff --git a/create_mesh.py b/neural_lam/create_mesh.py
similarity index 100%
rename from create_mesh.py
rename to neural_lam/create_mesh.py
diff --git a/create_parameter_weights.py b/neural_lam/create_parameter_weights.py
similarity index 100%
rename from create_parameter_weights.py
rename to neural_lam/create_parameter_weights.py
diff --git a/train_model.py b/neural_lam/train_model.py
similarity index 100%
rename from train_model.py
rename to neural_lam/train_model.py
diff --git a/pyproject.toml b/pyproject.toml
index b513a258..ccc8953f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,32 @@
+[project]
+# PEP 621 project metadata
+# See https://www.python.org/dev/peps/pep-0621/
+dependencies = [
+ "numpy>=1.24.2",
+ "wandb>=0.13.10",
+ "matplotlib>=3.7.0",
+ "scipy>=1.10.0",
+ "pytorch-lightning>=2.0.3",
+ "shapely>=2.0.1",
+ "networkx>=3.0",
+ "Cartopy>=0.22.0",
+ "pyproj>=3.4.1",
+ "tueplots>=0.0.8",
+ "plotly>=5.15.0",
+ "pre-commit>=2.15.0",
+]
+requires-python = ">=3.10"
+name = "neural-lam"
+version = "0.1.0"
+description = "Neural Weather Prediction for Limited Area Modeling"
+authors = [
+ {name = "Joel Oskarsson", email = "joel.oskarsson@liu.se"},
+ {name = "Simon Adamov", email = "simon.adamov@meteoswiss.ch"},
+ {name = "Leif Denby", email = "lcd@dmi.dk"},
+]
+readme = "README.md"
+license = {text = "MIT"}
+
[tool.black]
line-length = 80
@@ -63,3 +92,10 @@ max-statements=100 # Allow for some more involved functions
allow-any-import-level="neural_lam"
[tool.pylint.SIMILARITIES]
min-similarity-lines=10
+
+
+[tool.pdm]
+distribution = true
+[build-system]
+requires = ["pdm-backend"]
+build-backend = "pdm.backend"
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index f381d54f..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-# for all
-numpy>=1.24.2
-wandb>=0.13.10
-matplotlib>=3.7.0
-scipy>=1.10.0
-pytorch-lightning>=2.0.3
-shapely>=2.0.1
-networkx>=3.0
-Cartopy>=0.22.0
-pyproj>=3.4.1
-tueplots>=0.0.8
-plotly>=5.15.0
-
-# for dev
-pre-commit>=2.15.0
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..2f88fa16
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,5 @@
+import neural_lam
+
+
+def test_import():
+ assert neural_lam is not None
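
A quick smoke check (assuming the package and its manually installed torch/PyG dependencies are available, e.g. after `pdm install`) that the relocated scripts remain importable as modules after the move into the package:

```python
# Verify the renamed scripts are importable under the neural_lam package.
import importlib

for mod in (
    "neural_lam.create_mesh",
    "neural_lam.create_grid_features",
    "neural_lam.create_parameter_weights",
    "neural_lam.train_model",
):
    importlib.import_module(mod)  # raises ImportError if the move broke anything
print("all relocated modules import cleanly")
```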
From 0c68537b1bb8f2e73334fe87e3a4cb220091250f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 19:34:48 +0200
Subject: [PATCH 027/273] replace absolute imports with relative
---
neural_lam/create_parameter_weights.py | 4 ++--
neural_lam/interaction_net.py | 2 +-
neural_lam/models/ar_model.py | 2 +-
neural_lam/models/base_graph_model.py | 6 +++---
neural_lam/models/base_hi_graph_model.py | 6 +++---
neural_lam/models/graph_lam.py | 6 +++---
neural_lam/models/hi_lam.py | 4 ++--
neural_lam/models/hi_lam_parallel.py | 4 ++--
neural_lam/train_model.py | 10 +++++-----
neural_lam/utils.py | 2 +-
neural_lam/vis.py | 2 +-
neural_lam/weather_dataset.py | 2 +-
12 files changed, 25 insertions(+), 25 deletions(-)
diff --git a/neural_lam/create_parameter_weights.py b/neural_lam/create_parameter_weights.py
index 494a5e81..ff420adf 100644
--- a/neural_lam/create_parameter_weights.py
+++ b/neural_lam/create_parameter_weights.py
@@ -8,8 +8,8 @@
from tqdm import tqdm
# First-party
-from neural_lam import constants
-from neural_lam.weather_dataset import WeatherDataset
+from . import constants
+from .weather_dataset import WeatherDataset
def main():
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index 663f27e4..a81a3ab4 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -4,7 +4,7 @@
from torch import nn
# First-party
-from neural_lam import utils
+from . import utils
class InteractionNet(pyg.nn.MessagePassing):
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 7d0a8320..3da353a8 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,7 @@
import wandb
# First-party
-from neural_lam import constants, metrics, utils, vis
+from .. import constants, metrics, utils, vis
class ARModel(pl.LightningModule):
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 256d4adc..77be82eb 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -2,9 +2,9 @@
import torch
# First-party
-from neural_lam import utils
-from neural_lam.interaction_net import InteractionNet
-from neural_lam.models.ar_model import ARModel
+from .. import utils
+from ..interaction_net import InteractionNet
+from .ar_model import ARModel
class BaseGraphModel(ARModel):
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 8ce87030..92533f86 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -2,9 +2,9 @@
from torch import nn
# First-party
-from neural_lam import utils
-from neural_lam.interaction_net import InteractionNet
-from neural_lam.models.base_graph_model import BaseGraphModel
+from .. import utils
+from ..interaction_net import InteractionNet
+from .base_graph_model import BaseGraphModel
class BaseHiGraphModel(BaseGraphModel):
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index f767fba0..6bbe83cc 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -2,9 +2,9 @@
import torch_geometric as pyg
# First-party
-from neural_lam import utils
-from neural_lam.interaction_net import InteractionNet
-from neural_lam.models.base_graph_model import BaseGraphModel
+from .. import utils
+from ..interaction_net import InteractionNet
+from .base_graph_model import BaseGraphModel
class GraphLAM(BaseGraphModel):
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 4d7eb94c..df9d3cbb 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -2,8 +2,8 @@
from torch import nn
# First-party
-from neural_lam.interaction_net import InteractionNet
-from neural_lam.models.base_hi_graph_model import BaseHiGraphModel
+from ..interaction_net import InteractionNet
+from .base_hi_graph_model import BaseHiGraphModel
class HiLAM(BaseHiGraphModel):
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 740824e1..d6dc27ee 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -3,8 +3,8 @@
import torch_geometric as pyg
# First-party
-from neural_lam.interaction_net import InteractionNet
-from neural_lam.models.base_hi_graph_model import BaseHiGraphModel
+from ..interaction_net import InteractionNet
+from .base_hi_graph_model import BaseHiGraphModel
class HiLAMParallel(BaseHiGraphModel):
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 96d21a3f..9120cd4b 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -9,11 +9,11 @@
from lightning_fabric.utilities import seed
# First-party
-from neural_lam import constants, utils
-from neural_lam.models.graph_lam import GraphLAM
-from neural_lam.models.hi_lam import HiLAM
-from neural_lam.models.hi_lam_parallel import HiLAMParallel
-from neural_lam.weather_dataset import WeatherDataset
+from . import constants, utils
+from .models.graph_lam import GraphLAM
+from .models.hi_lam import HiLAM
+from .models.hi_lam_parallel import HiLAMParallel
+from .weather_dataset import WeatherDataset
MODELS = {
"graph_lam": GraphLAM,
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 31715502..29d638fa 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -8,7 +8,7 @@
from tueplots import bundles, figsizes
# First-party
-from neural_lam import constants
+from . import constants
def load_dataset_stats(dataset_name, device="cpu"):
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index cef34a84..80616e3b 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -4,7 +4,7 @@
import numpy as np
# First-party
-from neural_lam import constants, utils
+from . import constants, utils
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index eeefc313..1c2f8fde 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -8,7 +8,7 @@
import torch
# First-party
-from neural_lam import constants, utils
+from . import constants, utils
class WeatherDataset(torch.utils.data.Dataset):
From 4b77be6cf0a74084c6f6f76d1e168ee421240292 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 19:35:45 +0200
Subject: [PATCH 028/273] simplify black config
---
pyproject.toml | 18 ------------------
1 file changed, 18 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index ccc8953f..0224dbeb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,25 +31,7 @@ license = {text = "MIT"}
line-length = 80
[tool.isort]
-default_section = "THIRDPARTY"
profile = "black"
-# Headings
-import_heading_stdlib = "Standard library"
-import_heading_thirdparty = "Third-party"
-import_heading_firstparty = "First-party"
-import_heading_localfolder = "Local"
-# Known modules to avoid misclassification
-known_standard_library = [
- # Add standard library modules that may be misclassified by isort
-]
-known_third_party = [
- # Add third-party modules that may be misclassified by isort
- "wandb",
-]
-known_first_party = [
- # Add first-party modules that may be misclassified by isort
- "neural_lam",
-]
[tool.flake8]
max-line-length = 80
From f2bae03e3a1ffb537aad45aa1cc678b86573daea Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 19:39:32 +0200
Subject: [PATCH 029/273] headers for import sections no longer needed
---
pyproject.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 0224dbeb..02213b0c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,13 +52,13 @@ skip = "requirements/*"
ignore = [
"create_mesh.py", # Disable linting for now, as major rework is planned/expected
]
-# Temporary fix for import neural_lam statements until set up as proper package
-init-hook='import sys; sys.path.append(".")'
+
[tool.pylint.TYPECHECK]
generated-members = [
"numpy.*",
"torch.*",
]
+
[tool.pylint.'MESSAGES CONTROL']
disable = [
"C0114", # 'missing-module-docstring', Do not require module docstrings
From 1d12b0d917e87cf0d834dc603da00a727f72c332 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 21:38:49 +0200
Subject: [PATCH 030/273] minor fixes
---
.github/workflows/pre-commit.yml | 1 -
.pre-commit-config.yaml | 17 +----------------
neural_lam/models/base_hi_graph_model.py | 12 ++++++------
plot_graph.py | 6 +-----
4 files changed, 8 insertions(+), 28 deletions(-)
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 0ff792a3..29203557 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -18,4 +18,3 @@ jobs:
with:
python-version: 3.9
- uses: pre-commit/action@v2.0.3
-
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 28cd91b9..0547c6b9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,40 +13,25 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/codespell-project/codespell
- rev: 2.0.0
+ rev: v2.2.6
hooks:
- id: codespell
description: Check for spelling errors
- language: system
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
description: Format Python code
- language: system
- types_or: [python, pyi]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
description: Group and sort Python imports
- language: system
- types_or: [python, pyi, cython]
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
- language: system
- types: [python]
-
- - repo: https://github.com/pylint-dev/pylint
- rev: 2.0.0
- hooks:
- - id: pylint
- entry: pylint -rn -sn
- language: system
- types: [python]
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 8ce87030..d9a4c676 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -179,9 +179,9 @@ def process_step(self, mesh_rep):
)
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = (
- new_node_rep # (B, num_mesh_nodes[l], d_h)
- )
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h)
# - PROCESSOR -
@@ -207,9 +207,9 @@ def process_step(self, mesh_rep):
new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = (
- new_node_rep # (B, num_mesh_nodes[l], d_h)
- )
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
# Return only bottom level representation
return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h)
diff --git a/plot_graph.py b/plot_graph.py
index 48427d5c..27e230e7 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -47,11 +47,7 @@ def main():
# Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
- (
- g2m_edge_index,
- m2g_edge_index,
- m2m_edge_index,
- ) = (
+ (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
graph_ldict["m2m_edge_index"],
From ad0accc6dafc1bcdbbb7833654b707d997553071 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 21:41:49 +0200
Subject: [PATCH 031/273] run on all branch pushes
---
.github/workflows/pre-commit.yml | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 29203557..2f665631 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -1,15 +1,17 @@
name: Run pre-commit job
on:
- push:
+ # trigger on pushes to any branch, but not main
+ push:
+ branches-ignore:
+ - master
+ # and also on PRs to main
+ pull_request:
branches:
- - main
- pull_request:
- branches:
- - main
+ - main
jobs:
- pre-commit-job:
+ pre-commit-job:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
From 3e69502eb3c5b890b61552c9a6c1e8fe49f5149b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 21:42:54 +0200
Subject: [PATCH 032/273] rename action to "lint"
---
.github/workflows/pre-commit.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 2f665631..14ab3c3b 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -1,4 +1,4 @@
-name: Run pre-commit job
+name: lint
on:
# trigger on pushes to any branch, but not main
From 681c7b1caf4345cf850736534154311c0ea0c661 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:32:27 +0200
Subject: [PATCH 033/273] add ci/cd test for imports
---
.github/workflows/ci-tests.yml | 37 ++++++++++++++++++++++++++++++++++
tests/test_base.py | 7 +++++++
2 files changed, 44 insertions(+)
create mode 100644 .github/workflows/ci-tests.yml
create mode 100644 tests/test_base.py
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
new file mode 100644
index 00000000..4b43f60d
--- /dev/null
+++ b/.github/workflows/ci-tests.yml
@@ -0,0 +1,37 @@
+# CI/CD workflow for running the test suite with pytest.
+# It first installs pdm, then installs torch (CPU) manually, then installs the
+# package, and finally runs the tests.
+
+name: tests
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.10
+
+ - name: Install pdm
+ run: |
+ python -m pip install pdm
+ pdm --version
+
+ - name: Install torch
+ run: |
+ python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+
+ - name: Install package (including dev dependencies)
+ run: |
+ pdm install
+ pdm install --dev
+
+ - name: Run tests
+ run: |
+ pdm run pytest
diff --git a/tests/test_base.py b/tests/test_base.py
new file mode 100644
index 00000000..c3a89171
--- /dev/null
+++ b/tests/test_base.py
@@ -0,0 +1,7 @@
+import neural_lam
+import neural_lam.train_model
+
+
+def test_import():
+ assert neural_lam is not None
+ assert neural_lam.train_model is not None
From 5ad02302de98267950aac271c91bf2b46a0f9ce5 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:36:34 +0200
Subject: [PATCH 034/273] py version must be quoted
---
.github/workflows/ci-tests.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 4b43f60d..09f049c3 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -16,7 +16,7 @@ jobs:
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
- python-version: 3.10
+ python-version: "3.10"
- name: Install pdm
run: |
From 35987e577be70107192c5c0da4f483d64acb13c3 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:38:43 +0200
Subject: [PATCH 035/273] fix torch install url
---
.github/workflows/ci-tests.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 09f049c3..47d15f5e 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -23,9 +23,9 @@ jobs:
python -m pip install pdm
pdm --version
- - name: Install torch
+ - name: Install torch (CPU)
run: |
- python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install package (including dev dependencies)
run: |
From 148d7f6a6421a27e0e3865aa1886a318634ff4eb Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:40:45 +0200
Subject: [PATCH 036/273] use pdm in ci/cd
---
.github/workflows/ci-tests.yml | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 47d15f5e..34ae6d93 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -13,15 +13,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v2
- - name: Set up Python 3.10
- uses: actions/setup-python@v2
+ - name: Install pdm
+ uses: pdm-project/setup-pdm@v4
with:
python-version: "3.10"
-
- - name: Install pdm
- run: |
- python -m pip install pdm
- pdm --version
+ cache: true
- name: Install torch (CPU)
run: |
From b912d1a62fbbd8c4b7fc08be8d690176aa047e6d Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:42:51 +0200
Subject: [PATCH 037/273] disable cache for now
---
.github/workflows/ci-tests.yml | 1 -
1 file changed, 1 deletion(-)
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 34ae6d93..0fff9f06 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -17,7 +17,6 @@ jobs:
uses: pdm-project/setup-pdm@v4
with:
python-version: "3.10"
- cache: true
- name: Install torch (CPU)
run: |
From b656445b4fea9ccf6c777da7599ed0b6dae5e006 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:45:57 +0200
Subject: [PATCH 038/273] check in lock file
---
pdm.lock | 1739 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 1739 insertions(+)
create mode 100644 pdm.lock
diff --git a/pdm.lock b/pdm.lock
new file mode 100644
index 00000000..65b43d2e
--- /dev/null
+++ b/pdm.lock
@@ -0,0 +1,1739 @@
+# This file is @generated by PDM.
+# It is not intended for manual editing.
+
+[metadata]
+groups = ["default"]
+strategy = ["cross_platform", "inherit_metadata"]
+lock_version = "4.4.1"
+content_hash = "sha256:1f465dd9fc7cac951a6e5f120e295be967d1be97e98a75f750adeae06040a64f"
+
+[[package]]
+name = "aiohttp"
+version = "3.9.5"
+requires_python = ">=3.8"
+summary = "Async http client/server framework (asyncio)"
+groups = ["default"]
+dependencies = [
+ "aiosignal>=1.1.2",
+ "async-timeout<5.0,>=4.0; python_version < \"3.11\"",
+ "attrs>=17.3.0",
+ "frozenlist>=1.1.1",
+ "multidict<7.0,>=4.5",
+ "yarl<2.0,>=1.0",
+]
+files = [
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"},
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"},
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"},
+ {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"},
+ {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"},
+ {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"},
+ {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
+ {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
+ {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
+ {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
+]
+
+[[package]]
+name = "aiosignal"
+version = "1.3.1"
+requires_python = ">=3.7"
+summary = "aiosignal: a list of registered asynchronous callbacks"
+groups = ["default"]
+dependencies = [
+ "frozenlist>=1.1.0",
+]
+files = [
+ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
+ {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
+]
+
+[[package]]
+name = "async-timeout"
+version = "4.0.3"
+requires_python = ">=3.7"
+summary = "Timeout context manager for asyncio programs"
+groups = ["default"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
+ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
+]
+
+[[package]]
+name = "attrs"
+version = "23.2.0"
+requires_python = ">=3.7"
+summary = "Classes Without Boilerplate"
+groups = ["default"]
+files = [
+ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
+ {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
+]
+
+[[package]]
+name = "cartopy"
+version = "0.23.0"
+requires_python = ">=3.9"
+summary = "A Python library for cartographic visualizations with Matplotlib"
+groups = ["default"]
+dependencies = [
+ "matplotlib>=3.5",
+ "numpy>=1.21",
+ "packaging>=20",
+ "pyproj>=3.3.1",
+ "pyshp>=2.3",
+ "shapely>=1.7",
+]
+files = [
+ {file = "Cartopy-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:374e66f816c3bafa48ffdbf6abaefa67063b405fac5f425f9be241cdf3498352"},
+ {file = "Cartopy-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2bae450c4c913796cad0b7ce05aa2fa78d1788de47989f0a03183397648e24be"},
+ {file = "Cartopy-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40437596e8ac5e74575eab822c661f4e725bd995cfd9e445069695fe9086b42"},
+ {file = "Cartopy-0.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:3292d6d403137eed80d32014c2f28de6282bed8824213f4b4c2170f388b24a1b"},
+ {file = "Cartopy-0.23.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86b07b6794b616674e4e485b8574e9197bca54a4467d28dd01ae0bf178f8dc2b"},
+ {file = "Cartopy-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8dece2aa8d5ff7bf989ded6b5f07c980fb5bb772952bc7cdeab469738abdecee"},
+ {file = "Cartopy-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9dfd28352dc83d6b4e4cf85d84cb50fc4886d4c1510d61f4c7cf22477d1156f"},
+ {file = "Cartopy-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:b2671b5354e43220f8e1074e7fe30a8b9f71cb38407c78e51db9c97772f0320b"},
+ {file = "Cartopy-0.23.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:80b9fd666fd47f6370d29f7ad4e352828d54aaf688a03d0b83b51e141cfd77fa"},
+ {file = "Cartopy-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:43e36b8b7e7e373a5698757458fd28fafbbbf5f3ebbe2d378f6a5ec3993d6dc0"},
+ {file = "Cartopy-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550173b91155d4d81cd14b4892cb6cabe3dd32bd34feacaa1ec78c0e56287832"},
+ {file = "Cartopy-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:55219ee0fb069cc3254426e87382cde03546e86c3f7c6759f076823b1e3a44d9"},
+ {file = "Cartopy-0.23.0.tar.gz", hash = "sha256:231f37b35701f2ba31d94959cca75e6da04c2eea3a7f14ce1c75ee3b0eae7676"},
+]
+
+[[package]]
+name = "certifi"
+version = "2024.2.2"
+requires_python = ">=3.6"
+summary = "Python package for providing Mozilla's CA Bundle."
+groups = ["default"]
+files = [
+ {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
+ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
+]
+
+[[package]]
+name = "cfgv"
+version = "3.4.0"
+requires_python = ">=3.8"
+summary = "Validate configuration and produce human readable error messages."
+groups = ["default"]
+files = [
+ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
+ {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.3.2"
+requires_python = ">=3.7.0"
+summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+groups = ["default"]
+files = [
+ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
+ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
+]
+
+[[package]]
+name = "click"
+version = "8.1.7"
+requires_python = ">=3.7"
+summary = "Composable command line interface toolkit"
+groups = ["default"]
+dependencies = [
+ "colorama; platform_system == \"Windows\"",
+]
+files = [
+ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
+ {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+summary = "Cross-platform colored terminal text."
+groups = ["default"]
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "contourpy"
+version = "1.2.1"
+requires_python = ">=3.9"
+summary = "Python library for calculating contours of 2D quadrilateral grids"
+groups = ["default"]
+dependencies = [
+ "numpy>=1.20",
+]
+files = [
+ {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"},
+ {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"},
+ {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"},
+ {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"},
+ {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"},
+ {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"},
+ {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"},
+ {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"},
+ {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"},
+ {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"},
+ {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"},
+ {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"},
+ {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"},
+ {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"},
+ {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"},
+ {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"},
+ {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"},
+ {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"},
+ {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"},
+]
+
+[[package]]
+name = "cycler"
+version = "0.12.1"
+requires_python = ">=3.8"
+summary = "Composable style cycles"
+groups = ["default"]
+files = [
+ {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"},
+ {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"},
+]
+
+[[package]]
+name = "distlib"
+version = "0.3.8"
+summary = "Distribution utilities"
+groups = ["default"]
+files = [
+ {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
+ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
+]
+
+[[package]]
+name = "docker-pycreds"
+version = "0.4.0"
+summary = "Python bindings for the docker credentials store API"
+groups = ["default"]
+dependencies = [
+ "six>=1.4.0",
+]
+files = [
+ {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
+ {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
+]
+
+[[package]]
+name = "filelock"
+version = "3.14.0"
+requires_python = ">=3.8"
+summary = "A platform independent file lock."
+groups = ["default"]
+files = [
+ {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
+ {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
+]
+
+[[package]]
+name = "fonttools"
+version = "4.51.0"
+requires_python = ">=3.8"
+summary = "Tools to manipulate font files"
+groups = ["default"]
+files = [
+ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"},
+ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"},
+ {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037"},
+ {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716"},
+ {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438"},
+ {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039"},
+ {file = "fonttools-4.51.0-cp310-cp310-win32.whl", hash = "sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77"},
+ {file = "fonttools-4.51.0-cp310-cp310-win_amd64.whl", hash = "sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b"},
+ {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74"},
+ {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2"},
+ {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f"},
+ {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097"},
+ {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0"},
+ {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1"},
+ {file = "fonttools-4.51.0-cp311-cp311-win32.whl", hash = "sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034"},
+ {file = "fonttools-4.51.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1"},
+ {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba"},
+ {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc"},
+ {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a"},
+ {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2"},
+ {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671"},
+ {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5"},
+ {file = "fonttools-4.51.0-cp312-cp312-win32.whl", hash = "sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15"},
+ {file = "fonttools-4.51.0-cp312-cp312-win_amd64.whl", hash = "sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e"},
+ {file = "fonttools-4.51.0-py3-none-any.whl", hash = "sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f"},
+ {file = "fonttools-4.51.0.tar.gz", hash = "sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68"},
+]
+
+[[package]]
+name = "frozenlist"
+version = "1.4.1"
+requires_python = ">=3.8"
+summary = "A list-like structure which implements collections.abc.MutableSequence"
+groups = ["default"]
+files = [
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
+ {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
+ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
+]
+
+[[package]]
+name = "fsspec"
+version = "2024.3.1"
+requires_python = ">=3.8"
+summary = "File-system specification"
+groups = ["default"]
+files = [
+ {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
+ {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
+]
+
+[[package]]
+name = "fsspec"
+version = "2024.3.1"
+extras = ["http"]
+requires_python = ">=3.8"
+summary = "File-system specification"
+groups = ["default"]
+dependencies = [
+ "aiohttp!=4.0.0a0,!=4.0.0a1",
+ "fsspec==2024.3.1",
+]
+files = [
+ {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
+ {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
+]
+
+[[package]]
+name = "gitdb"
+version = "4.0.11"
+requires_python = ">=3.7"
+summary = "Git Object Database"
+groups = ["default"]
+dependencies = [
+ "smmap<6,>=3.0.1",
+]
+files = [
+ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
+ {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
+]
+
+[[package]]
+name = "gitpython"
+version = "3.1.43"
+requires_python = ">=3.7"
+summary = "GitPython is a Python library used to interact with Git repositories"
+groups = ["default"]
+dependencies = [
+ "gitdb<5,>=4.0.1",
+]
+files = [
+ {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+ {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
+]
+
+[[package]]
+name = "identify"
+version = "2.5.36"
+requires_python = ">=3.8"
+summary = "File identification library for Python"
+groups = ["default"]
+files = [
+ {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"},
+ {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"},
+]
+
+[[package]]
+name = "idna"
+version = "3.7"
+requires_python = ">=3.5"
+summary = "Internationalized Domain Names in Applications (IDNA)"
+groups = ["default"]
+files = [
+ {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"},
+ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
+]
+
+[[package]]
+name = "intel-openmp"
+version = "2021.4.0"
+summary = "Intel® OpenMP* Runtime Library"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+files = [
+ {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.4"
+requires_python = ">=3.7"
+summary = "A very fast and expressive template engine."
+groups = ["default"]
+dependencies = [
+ "MarkupSafe>=2.0",
+]
+files = [
+ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
+ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
+]
+
+[[package]]
+name = "kiwisolver"
+version = "1.4.5"
+requires_python = ">=3.7"
+summary = "A fast implementation of the Cassowary constraint solver"
+groups = ["default"]
+files = [
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"},
+ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
+]
+
+[[package]]
+name = "lightning-utilities"
+version = "0.11.2"
+requires_python = ">=3.8"
+summary = "Lightning toolbox for across the our ecosystem."
+groups = ["default"]
+dependencies = [
+ "packaging>=17.1",
+ "setuptools",
+ "typing-extensions",
+]
+files = [
+ {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"},
+ {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"},
+]
+
+[[package]]
+name = "markupsafe"
+version = "2.1.5"
+requires_python = ">=3.7"
+summary = "Safely add untrusted strings to HTML/XML markup."
+groups = ["default"]
+files = [
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
+ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
+]
+
+[[package]]
+name = "matplotlib"
+version = "3.8.4"
+requires_python = ">=3.9"
+summary = "Python plotting package"
+groups = ["default"]
+dependencies = [
+ "contourpy>=1.0.1",
+ "cycler>=0.10",
+ "fonttools>=4.22.0",
+ "kiwisolver>=1.3.1",
+ "numpy>=1.21",
+ "packaging>=20.0",
+ "pillow>=8",
+ "pyparsing>=2.3.1",
+ "python-dateutil>=2.7",
+]
+files = [
+ {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"},
+ {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"},
+ {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10"},
+ {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0"},
+ {file = "matplotlib-3.8.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef"},
+ {file = "matplotlib-3.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338"},
+ {file = "matplotlib-3.8.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661"},
+ {file = "matplotlib-3.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c"},
+ {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa"},
+ {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71"},
+ {file = "matplotlib-3.8.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b"},
+ {file = "matplotlib-3.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae"},
+ {file = "matplotlib-3.8.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616"},
+ {file = "matplotlib-3.8.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732"},
+ {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb"},
+ {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30"},
+ {file = "matplotlib-3.8.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25"},
+ {file = "matplotlib-3.8.4-cp312-cp312-win_amd64.whl", hash = "sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a"},
+ {file = "matplotlib-3.8.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35"},
+ {file = "matplotlib-3.8.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f"},
+ {file = "matplotlib-3.8.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94"},
+ {file = "matplotlib-3.8.4.tar.gz", hash = "sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea"},
+]
+
+[[package]]
+name = "mkl"
+version = "2021.4.0"
+summary = "Intel® oneAPI Math Kernel Library"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+dependencies = [
+ "intel-openmp==2021.*",
+ "tbb==2021.*",
+]
+files = [
+ {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
+ {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
+ {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
+ {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
+ {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+summary = "Python library for arbitrary-precision floating-point arithmetic"
+groups = ["default"]
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[[package]]
+name = "multidict"
+version = "6.0.5"
+requires_python = ">=3.7"
+summary = "multidict implementation"
+groups = ["default"]
+files = [
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
+ {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
+ {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
+ {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
+ {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
+ {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
+ {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
+ {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
+ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
+]
+
+[[package]]
+name = "networkx"
+version = "3.3"
+requires_python = ">=3.10"
+summary = "Python package for creating and manipulating graphs and networks"
+groups = ["default"]
+files = [
+ {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
+ {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
+]
+
+[[package]]
+name = "nodeenv"
+version = "1.8.0"
+requires_python = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
+summary = "Node.js virtual environment builder"
+groups = ["default"]
+dependencies = [
+ "setuptools",
+]
+files = [
+ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
+ {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
+]
+
+[[package]]
+name = "numpy"
+version = "1.26.4"
+requires_python = ">=3.9"
+summary = "Fundamental package for array computing in Python"
+groups = ["default"]
+files = [
+ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
+ {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
+ {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
+ {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
+ {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
+ {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
+ {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
+ {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
+ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
+]
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.1.3.1"
+requires_python = ">=3"
+summary = "CUBLAS native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "CUDA profiling tools runtime libs."
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "NVRTC native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "CUDA Runtime native Libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "8.9.2.26"
+requires_python = ">=3"
+summary = "cuDNN runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-cublas-cu12",
+]
+files = [
+ {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
+]
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.0.2.54"
+requires_python = ">=3"
+summary = "CUFFT native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
+]
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.2.106"
+requires_python = ">=3"
+summary = "CURAND native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.4.5.107"
+requires_python = ">=3"
+summary = "CUDA solver native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-cublas-cu12",
+ "nvidia-cusparse-cu12",
+ "nvidia-nvjitlink-cu12",
+]
+files = [
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
+]
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.1.0.106"
+requires_python = ">=3"
+summary = "CUSPARSE native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-nvjitlink-cu12",
+]
+files = [
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.20.5"
+requires_python = ">=3"
+summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.4.127"
+requires_python = ">=3"
+summary = "Nvidia JIT LTO Library"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "NVIDIA Tools Extension"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
+]
+
+[[package]]
+name = "packaging"
+version = "24.0"
+requires_python = ">=3.7"
+summary = "Core utilities for Python packages"
+groups = ["default"]
+files = [
+ {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
+]
+
+[[package]]
+name = "pillow"
+version = "10.3.0"
+requires_python = ">=3.8"
+summary = "Python Imaging Library (Fork)"
+groups = ["default"]
+files = [
+ {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
+ {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
+ {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
+ {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
+ {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
+ {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
+ {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
+ {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
+ {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
+ {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
+ {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
+ {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
+]
+
+[[package]]
+name = "platformdirs"
+version = "4.2.1"
+requires_python = ">=3.8"
+summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
+groups = ["default"]
+files = [
+ {file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"},
+ {file = "platformdirs-4.2.1.tar.gz", hash = "sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf"},
+]
+
+[[package]]
+name = "plotly"
+version = "5.22.0"
+requires_python = ">=3.8"
+summary = "An open-source, interactive data visualization library for Python"
+groups = ["default"]
+dependencies = [
+ "packaging",
+ "tenacity>=6.2.0",
+]
+files = [
+ {file = "plotly-5.22.0-py3-none-any.whl", hash = "sha256:68fc1901f098daeb233cc3dd44ec9dc31fb3ca4f4e53189344199c43496ed006"},
+ {file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"},
+]
+
+[[package]]
+name = "pre-commit"
+version = "3.7.1"
+requires_python = ">=3.9"
+summary = "A framework for managing and maintaining multi-language pre-commit hooks."
+groups = ["default"]
+dependencies = [
+ "cfgv>=2.0.0",
+ "identify>=1.0.0",
+ "nodeenv>=0.11.1",
+ "pyyaml>=5.1",
+ "virtualenv>=20.10.0",
+]
+files = [
+ {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"},
+ {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"},
+]
+
+[[package]]
+name = "pretty-errors"
+version = "1.2.25"
+summary = "Prettifies Python exception output to make it legible."
+groups = ["default"]
+dependencies = [
+ "colorama",
+]
+files = [
+ {file = "pretty_errors-1.2.25-py3-none-any.whl", hash = "sha256:8ce68ccd99e0f2a099265c8c1f1c23b7c60a15d69bb08816cb336e237d5dc983"},
+ {file = "pretty_errors-1.2.25.tar.gz", hash = "sha256:a16ba5c752c87c263bf92f8b4b58624e3b1e29271a9391f564f12b86e93c6755"},
+]
+
+[[package]]
+name = "protobuf"
+version = "4.25.3"
+requires_python = ">=3.8"
+summary = ""
+groups = ["default"]
+marker = "python_version > \"3.9\" or sys_platform != \"linux\""
+files = [
+ {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
+ {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
+ {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
+ {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
+ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
+]
+
+[[package]]
+name = "psutil"
+version = "5.9.8"
+requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+summary = "Cross-platform lib for process and system monitoring in Python."
+groups = ["default"]
+files = [
+ {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"},
+ {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"},
+ {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"},
+ {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"},
+ {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"},
+]
+
+[[package]]
+name = "pyparsing"
+version = "3.1.2"
+requires_python = ">=3.6.8"
+summary = "pyparsing module - Classes and methods to define and execute parsing grammars"
+groups = ["default"]
+files = [
+ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
+ {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
+]
+
+[[package]]
+name = "pyproj"
+version = "3.6.1"
+requires_python = ">=3.9"
+summary = "Python interface to PROJ (cartographic projections and coordinate transformations library)"
+groups = ["default"]
+dependencies = [
+ "certifi",
+]
+files = [
+ {file = "pyproj-3.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ab7aa4d9ff3c3acf60d4b285ccec134167a948df02347585fdd934ebad8811b4"},
+ {file = "pyproj-3.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bc0472302919e59114aa140fd7213c2370d848a7249d09704f10f5b062031fe"},
+ {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5279586013b8d6582e22b6f9e30c49796966770389a9d5b85e25a4223286cd3f"},
+ {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80fafd1f3eb421694857f254a9bdbacd1eb22fc6c24ca74b136679f376f97d35"},
+ {file = "pyproj-3.6.1-cp310-cp310-win32.whl", hash = "sha256:c41e80ddee130450dcb8829af7118f1ab69eaf8169c4bf0ee8d52b72f098dc2f"},
+ {file = "pyproj-3.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:db3aedd458e7f7f21d8176f0a1d924f1ae06d725228302b872885a1c34f3119e"},
+ {file = "pyproj-3.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ebfbdbd0936e178091309f6cd4fcb4decd9eab12aa513cdd9add89efa3ec2882"},
+ {file = "pyproj-3.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:447db19c7efad70ff161e5e46a54ab9cc2399acebb656b6ccf63e4bc4a04b97a"},
+ {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7e13c40183884ec7f94eb8e0f622f08f1d5716150b8d7a134de48c6110fee85"},
+ {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65ad699e0c830e2b8565afe42bd58cc972b47d829b2e0e48ad9638386d994915"},
+ {file = "pyproj-3.6.1-cp311-cp311-win32.whl", hash = "sha256:8b8acc31fb8702c54625f4d5a2a6543557bec3c28a0ef638778b7ab1d1772132"},
+ {file = "pyproj-3.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:38a3361941eb72b82bd9a18f60c78b0df8408416f9340521df442cebfc4306e2"},
+ {file = "pyproj-3.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1e9fbaf920f0f9b4ee62aab832be3ae3968f33f24e2e3f7fbb8c6728ef1d9746"},
+ {file = "pyproj-3.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d227a865356f225591b6732430b1d1781e946893789a609bb34f59d09b8b0f8"},
+ {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83039e5ae04e5afc974f7d25ee0870a80a6bd6b7957c3aca5613ccbe0d3e72bf"},
+ {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb059ba3bced6f6725961ba758649261d85ed6ce670d3e3b0a26e81cf1aa8d"},
+ {file = "pyproj-3.6.1-cp312-cp312-win32.whl", hash = "sha256:2d6ff73cc6dbbce3766b6c0bce70ce070193105d8de17aa2470009463682a8eb"},
+ {file = "pyproj-3.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:7a27151ddad8e1439ba70c9b4b2b617b290c39395fa9ddb7411ebb0eb86d6fb0"},
+ {file = "pyproj-3.6.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd93c1a0c6c4aedc77c0fe275a9f2aba4d59b8acf88cebfc19fe3c430cfabf4f"},
+ {file = "pyproj-3.6.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6420ea8e7d2a88cb148b124429fba8cd2e0fae700a2d96eab7083c0928a85110"},
+ {file = "pyproj-3.6.1.tar.gz", hash = "sha256:44aa7c704c2b7d8fb3d483bbf75af6cb2350d30a63b144279a09b75fead501bf"},
+]
+
+[[package]]
+name = "pyshp"
+version = "2.3.1"
+requires_python = ">=2.7"
+summary = "Pure Python read/write support for ESRI Shapefile format"
+groups = ["default"]
+files = [
+ {file = "pyshp-2.3.1-py2.py3-none-any.whl", hash = "sha256:67024c0ccdc352ba5db777c4e968483782dfa78f8e200672a90d2d30fd8b7b49"},
+ {file = "pyshp-2.3.1.tar.gz", hash = "sha256:4caec82fd8dd096feba8217858068bacb2a3b5950f43c048c6dc32a3489d5af1"},
+]
+
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+summary = "Extensions to the standard Python datetime module"
+groups = ["default"]
+dependencies = [
+ "six>=1.5",
+]
+files = [
+ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
+]
+
+[[package]]
+name = "pytorch-lightning"
+version = "2.2.4"
+requires_python = ">=3.8"
+summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
+groups = ["default"]
+dependencies = [
+ "PyYAML>=5.4",
+ "fsspec[http]>=2022.5.0",
+ "lightning-utilities>=0.8.0",
+ "numpy>=1.17.2",
+ "packaging>=20.0",
+ "torch>=1.13.0",
+ "torchmetrics>=0.7.0",
+ "tqdm>=4.57.0",
+ "typing-extensions>=4.4.0",
+]
+files = [
+ {file = "pytorch-lightning-2.2.4.tar.gz", hash = "sha256:525b04ebad9900c3e3c2a12b3b462fe4f61ebe11fdb694716c3209f05b9b0fa8"},
+ {file = "pytorch_lightning-2.2.4-py3-none-any.whl", hash = "sha256:fd91d47e983a2cd743c5c8c3c3795bbd0f3b69d24be2172a2f9012d930701ff2"},
+]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.1"
+requires_python = ">=3.6"
+summary = "YAML parser and emitter for Python"
+groups = ["default"]
+files = [
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
+[[package]]
+name = "requests"
+version = "2.31.0"
+requires_python = ">=3.7"
+summary = "Python HTTP for Humans."
+groups = ["default"]
+dependencies = [
+ "certifi>=2017.4.17",
+ "charset-normalizer<4,>=2",
+ "idna<4,>=2.5",
+ "urllib3<3,>=1.21.1",
+]
+files = [
+ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
+ {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+]
+
+[[package]]
+name = "scipy"
+version = "1.13.0"
+requires_python = ">=3.9"
+summary = "Fundamental algorithms for scientific computing in Python"
+groups = ["default"]
+dependencies = [
+ "numpy<2.3,>=1.22.4",
+]
+files = [
+ {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
+ {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
+ {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
+ {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
+ {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
+ {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
+ {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
+ {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
+ {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
+]
+
+[[package]]
+name = "sentry-sdk"
+version = "2.1.1"
+requires_python = ">=3.6"
+summary = "Python client for Sentry (https://sentry.io)"
+groups = ["default"]
+dependencies = [
+ "certifi",
+ "urllib3>=1.26.11",
+]
+files = [
+ {file = "sentry_sdk-2.1.1-py2.py3-none-any.whl", hash = "sha256:99aeb78fb76771513bd3b2829d12613130152620768d00cd3e45ac00cb17950f"},
+ {file = "sentry_sdk-2.1.1.tar.gz", hash = "sha256:95d8c0bb41c8b0bc37ab202c2c4a295bb84398ee05f4cdce55051cd75b926ec1"},
+]
+
+[[package]]
+name = "setproctitle"
+version = "1.3.3"
+requires_python = ">=3.7"
+summary = "A Python module to customize the process title"
+groups = ["default"]
+files = [
+ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"},
+ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d"},
+ {file = "setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb"},
+ {file = "setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086"},
+ {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8"},
+ {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18"},
+ {file = "setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476"},
+ {file = "setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085"},
+ {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc"},
+ {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3"},
+ {file = "setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4"},
+ {file = "setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d3a953c50776751e80fe755a380a64cb14d61e8762bd43041ab3f8cc436092f"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5e08e232b78ba3ac6bc0d23ce9e2bee8fad2be391b7e2da834fc9a45129eb87"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1da82c3e11284da4fcbf54957dafbf0655d2389cd3d54e4eaba636faf6d117a"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:aeaa71fb9568ebe9b911ddb490c644fbd2006e8c940f21cb9a1e9425bd709574"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:59335d000c6250c35989394661eb6287187854e94ac79ea22315469ee4f4c244"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3ba57029c9c50ecaf0c92bb127224cc2ea9fda057b5d99d3f348c9ec2855ad3"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d876d355c53d975c2ef9c4f2487c8f83dad6aeaaee1b6571453cb0ee992f55f6"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:224602f0939e6fb9d5dd881be1229d485f3257b540f8a900d4271a2c2aa4e5f4"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d7f27e0268af2d7503386e0e6be87fb9b6657afd96f5726b733837121146750d"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5e7266498cd31a4572378c61920af9f6b4676a73c299fce8ba93afd694f8ae7"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33c5609ad51cd99d388e55651b19148ea99727516132fb44680e1f28dd0d1de9"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:eae8988e78192fd1a3245a6f4f382390b61bce6cfcc93f3809726e4c885fa68d"},
+ {file = "setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae"},
+]
+
+[[package]]
+name = "setuptools"
+version = "69.5.1"
+requires_python = ">=3.8"
+summary = "Easily download, build, install, upgrade, and uninstall Python packages"
+groups = ["default"]
+files = [
+ {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"},
+ {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"},
+]
+
+[[package]]
+name = "shapely"
+version = "2.0.4"
+requires_python = ">=3.7"
+summary = "Manipulation and analysis of geometric objects"
+groups = ["default"]
+dependencies = [
+ "numpy<3,>=1.14",
+]
+files = [
+ {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"},
+ {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"},
+ {file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"},
+ {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"},
+ {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"},
+ {file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"},
+ {file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"},
+ {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"},
+ {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"},
+ {file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"},
+ {file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5bbd974193e2cc274312da16b189b38f5f128410f3377721cadb76b1e8ca5328"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:41388321a73ba1a84edd90d86ecc8bfed55e6a1e51882eafb019f45895ec0f65"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0776c92d584f72f1e584d2e43cfc5542c2f3dd19d53f70df0900fda643f4bae6"},
+ {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c75c98380b1ede1cae9a252c6dc247e6279403fae38c77060a5e6186c95073ac"},
+ {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3e700abf4a37b7b8b90532fa6ed5c38a9bfc777098bc9fbae5ec8e618ac8f30"},
+ {file = "shapely-2.0.4-cp312-cp312-win32.whl", hash = "sha256:4f2ab0faf8188b9f99e6a273b24b97662194160cc8ca17cf9d1fb6f18d7fb93f"},
+ {file = "shapely-2.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:03152442d311a5e85ac73b39680dd64a9892fa42bb08fd83b3bab4fe6999bfa0"},
+ {file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"},
+]
+
+[[package]]
+name = "six"
+version = "1.16.0"
+requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+summary = "Python 2 and 3 compatibility utilities"
+groups = ["default"]
+files = [
+ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
+ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
+]
+
+[[package]]
+name = "smmap"
+version = "5.0.1"
+requires_python = ">=3.7"
+summary = "A pure Python implementation of a sliding window memory map manager"
+groups = ["default"]
+files = [
+ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
+ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
+]
+
+[[package]]
+name = "sympy"
+version = "1.12"
+requires_python = ">=3.8"
+summary = "Computer algebra system (CAS) in Python"
+groups = ["default"]
+dependencies = [
+ "mpmath>=0.19",
+]
+files = [
+ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
+ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+]
+
+[[package]]
+name = "tbb"
+version = "2021.12.0"
+summary = "Intel® oneAPI Threading Building Blocks (oneTBB)"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+files = [
+ {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"},
+ {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"},
+ {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"},
+ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
+]
+
+[[package]]
+name = "tenacity"
+version = "8.3.0"
+requires_python = ">=3.8"
+summary = "Retry code until it succeeds"
+groups = ["default"]
+files = [
+ {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
+ {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
+]
+
+[[package]]
+name = "torch"
+version = "2.3.0"
+requires_python = ">=3.8.0"
+summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+groups = ["default"]
+dependencies = [
+ "filelock",
+ "fsspec",
+ "jinja2",
+ "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"",
+ "networkx",
+ "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-runtime-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cudnn-cu12==8.9.2.26; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cufft-cu12==11.0.2.54; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "sympy",
+ "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"",
+ "typing-extensions>=4.8.0",
+]
+files = [
+ {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
+ {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
+ {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
+ {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
+ {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
+ {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
+ {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
+ {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
+ {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
+ {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
+ {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
+ {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
+]
+
+[[package]]
+name = "torchmetrics"
+version = "1.4.0"
+requires_python = ">=3.8"
+summary = "PyTorch native Metrics"
+groups = ["default"]
+dependencies = [
+ "lightning-utilities>=0.8.0",
+ "numpy>1.20.0",
+ "packaging>17.1",
+ "pretty-errors==1.2.25",
+ "torch>=1.10.0",
+]
+files = [
+ {file = "torchmetrics-1.4.0-py3-none-any.whl", hash = "sha256:18599929a0fff7d4b840a3f9a7700054121850c378caaf7206f4161c0a5dc93c"},
+ {file = "torchmetrics-1.4.0.tar.gz", hash = "sha256:0b1e5acdcc9beb05bfe369d3d56cfa5b143f060ebfd6079d19ccc59ba46465b3"},
+]
+
+[[package]]
+name = "tqdm"
+version = "4.66.4"
+requires_python = ">=3.7"
+summary = "Fast, Extensible Progress Meter"
+groups = ["default"]
+dependencies = [
+ "colorama; platform_system == \"Windows\"",
+]
+files = [
+ {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
+ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
+]
+
+[[package]]
+name = "triton"
+version = "2.3.0"
+summary = "A language and compiler for custom Deep Learning operations"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""
+dependencies = [
+ "filelock",
+]
+files = [
+ {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
+ {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
+ {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
+]
+
+[[package]]
+name = "tueplots"
+version = "0.0.15"
+requires_python = ">=3.9"
+summary = "Scientific plotting made easy."
+groups = ["default"]
+dependencies = [
+ "matplotlib",
+ "numpy",
+]
+files = [
+ {file = "tueplots-0.0.15-py3-none-any.whl", hash = "sha256:f63e020af88328c78618f3d912612c75c3c91d21004a88fd12cf79dbd9b6d78a"},
+]
+
+[[package]]
+name = "typing-extensions"
+version = "4.11.0"
+requires_python = ">=3.8"
+summary = "Backported and Experimental Type Hints for Python 3.8+"
+groups = ["default"]
+files = [
+ {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
+ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
+]
+
+[[package]]
+name = "urllib3"
+version = "2.2.1"
+requires_python = ">=3.8"
+summary = "HTTP library with thread-safe connection pooling, file post, and more."
+groups = ["default"]
+files = [
+ {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
+ {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
+]
+
+[[package]]
+name = "virtualenv"
+version = "20.26.1"
+requires_python = ">=3.7"
+summary = "Virtual Python Environment builder"
+groups = ["default"]
+dependencies = [
+ "distlib<1,>=0.3.7",
+ "filelock<4,>=3.12.2",
+ "platformdirs<5,>=3.9.1",
+]
+files = [
+ {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"},
+ {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"},
+]
+
+[[package]]
+name = "wandb"
+version = "0.17.0"
+requires_python = ">=3.7"
+summary = "A CLI and library for interacting with the Weights & Biases API."
+groups = ["default"]
+dependencies = [
+ "click!=8.0.0,>=7.1",
+ "docker-pycreds>=0.4.0",
+ "gitpython!=3.1.29,>=1.0.0",
+ "platformdirs",
+ "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"",
+ "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"",
+ "psutil>=5.0.0",
+ "pyyaml",
+ "requests<3,>=2.0.0",
+ "sentry-sdk>=1.0.0",
+ "setproctitle",
+ "setuptools",
+]
+files = [
+ {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
+ {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
+ {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
+ {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
+ {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
+]
+
+[[package]]
+name = "yarl"
+version = "1.9.4"
+requires_python = ">=3.7"
+summary = "Yet another URL library"
+groups = ["default"]
+dependencies = [
+ "idna>=2.0",
+ "multidict>=4.0",
+]
+files = [
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
+ {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
+ {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
+ {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
+ {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
+ {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
+ {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
+ {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
+ {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
+]
From 248196f41ce627926624e43af954db64deb24e3f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:47:26 +0200
Subject: [PATCH 039/273] add pytest
---
pdm.lock | 73 +++++++++++++++++++++++++++++++++++++++++++++++---
pyproject.toml | 5 ++++
2 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/pdm.lock b/pdm.lock
index 65b43d2e..6f5e16bf 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -2,10 +2,10 @@
# It is not intended for manual editing.
[metadata]
-groups = ["default"]
+groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1"
-content_hash = "sha256:1f465dd9fc7cac951a6e5f120e295be967d1be97e98a75f750adeae06040a64f"
+content_hash = "sha256:042137fa24b1870b761e5e8242a9d8fcad3d4b1a95b3b89d5e280667b9ca2069"
[[package]]
name = "aiohttp"
@@ -234,7 +234,7 @@ name = "colorama"
version = "0.4.6"
requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
summary = "Cross-platform colored terminal text."
-groups = ["default"]
+groups = ["default", "dev"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
@@ -320,6 +320,18 @@ files = [
{file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
]
+[[package]]
+name = "exceptiongroup"
+version = "1.2.1"
+requires_python = ">=3.7"
+summary = "Backport of PEP 654 (exception groups)"
+groups = ["dev"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"},
+ {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"},
+]
+
[[package]]
name = "filelock"
version = "3.14.0"
@@ -499,6 +511,17 @@ files = [
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
]
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+requires_python = ">=3.7"
+summary = "brain-dead simple config-ini parsing"
+groups = ["dev"]
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
[[package]]
name = "intel-openmp"
version = "2021.4.0"
@@ -1000,7 +1023,7 @@ name = "packaging"
version = "24.0"
requires_python = ">=3.7"
summary = "Core utilities for Python packages"
-groups = ["default"]
+groups = ["default", "dev"]
files = [
{file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
@@ -1089,6 +1112,17 @@ files = [
{file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"},
]
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+requires_python = ">=3.8"
+summary = "plugin and hook calling mechanisms for python"
+groups = ["dev"]
+files = [
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
+]
+
[[package]]
name = "pre-commit"
version = "3.7.1"
@@ -1208,6 +1242,25 @@ files = [
{file = "pyshp-2.3.1.tar.gz", hash = "sha256:4caec82fd8dd096feba8217858068bacb2a3b5950f43c048c6dc32a3489d5af1"},
]
+[[package]]
+name = "pytest"
+version = "8.2.0"
+requires_python = ">=3.8"
+summary = "pytest: simple powerful testing with Python"
+groups = ["dev"]
+dependencies = [
+ "colorama; sys_platform == \"win32\"",
+ "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
+ "iniconfig",
+ "packaging",
+ "pluggy<2.0,>=1.5",
+ "tomli>=1; python_version < \"3.11\"",
+]
+files = [
+ {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"},
+ {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"},
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -1507,6 +1560,18 @@ files = [
{file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
]
+[[package]]
+name = "tomli"
+version = "2.0.1"
+requires_python = ">=3.7"
+summary = "A lil' TOML parser"
+groups = ["dev"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+
[[package]]
name = "torch"
version = "2.3.0"
diff --git a/pyproject.toml b/pyproject.toml
index 02213b0c..d9eb14dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,11 @@ min-similarity-lines=10
[tool.pdm]
distribution = true
+
+[tool.pdm.dev-dependencies]
+dev = [
+ "pytest>=8.2.0",
+]
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
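With pytest wired in as a pdm dev dependency above, the test suite can be run inside the project environment with `pdm run pytest`. Below is a minimal, hedged sketch of the kind of smoke test it collects; the file name is purely illustrative, and the real tests/test_base.py added in the following patches follows the same pattern.

    # tests/test_smoke.py -- hypothetical name, for illustration only.
    # pytest collects functions whose names start with `test_`.
    import neural_lam


    def test_package_importable():
        # A bare import failure (missing dependency, syntax error) fails the run.
        assert neural_lam is not None
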
From 1af15760b01ad37b52494ebae0bf2e0c63b1f8ae Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:47:33 +0200
Subject: [PATCH 040/273] cache in CI/CD
---
.github/workflows/ci-tests.yml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 0fff9f06..9b73f298 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -2,7 +2,7 @@
# needs to first install pdm, then install torch cpu manually and then install the package
# then run the tests
-name: tests
+name: tests (cpu)
on: [push, pull_request]
@@ -17,6 +17,7 @@ jobs:
uses: pdm-project/setup-pdm@v4
with:
python-version: "3.10"
+ cache: true
- name: Install torch (CPU)
run: |
From 7ed7c9765f451f7645cb201ecc2b4afd89c8c1be Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 13 May 2024 23:50:22 +0200
Subject: [PATCH 041/273] add torch-geometric to deps
---
pdm.lock | 78 +++++++++++++++++++++++++++++++++++++++++++++++++-
pyproject.toml | 1 +
2 files changed, 78 insertions(+), 1 deletion(-)
diff --git a/pdm.lock b/pdm.lock
index 6f5e16bf..21467c0d 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1"
-content_hash = "sha256:042137fa24b1870b761e5e8242a9d8fcad3d4b1a95b3b89d5e280667b9ca2069"
+content_hash = "sha256:c6c346f14a001266b5cc8a2eafb2081b9bcba755c41eb0f44525436548a09fde"
[[package]]
name = "aiohttp"
@@ -550,6 +550,17 @@ files = [
{file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
]
+[[package]]
+name = "joblib"
+version = "1.4.2"
+requires_python = ">=3.8"
+summary = "Lightweight pipelining with Python functions"
+groups = ["default"]
+files = [
+ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
[[package]]
name = "kiwisolver"
version = "1.4.5"
@@ -1347,6 +1358,37 @@ files = [
{file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
]
+[[package]]
+name = "scikit-learn"
+version = "1.4.2"
+requires_python = ">=3.9"
+summary = "A set of python modules for machine learning and data mining"
+groups = ["default"]
+dependencies = [
+ "joblib>=1.2.0",
+ "numpy>=1.19.5",
+ "scipy>=1.6.0",
+ "threadpoolctl>=2.0.0",
+]
+files = [
+ {file = "scikit-learn-1.4.2.tar.gz", hash = "sha256:daa1c471d95bad080c6e44b4946c9390a4842adc3082572c20e4f8884e39e959"},
+ {file = "scikit_learn-1.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8539a41b3d6d1af82eb629f9c57f37428ff1481c1e34dddb3b9d7af8ede67ac5"},
+ {file = "scikit_learn-1.4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:68b8404841f944a4a1459b07198fa2edd41a82f189b44f3e1d55c104dbc2e40c"},
+ {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81bf5d8bbe87643103334032dd82f7419bc8c8d02a763643a6b9a5c7288c5054"},
+ {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f0ea5d0f693cb247a073d21a4123bdf4172e470e6d163c12b74cbb1536cf38"},
+ {file = "scikit_learn-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:87440e2e188c87db80ea4023440923dccbd56fbc2d557b18ced00fef79da0727"},
+ {file = "scikit_learn-1.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45dee87ac5309bb82e3ea633955030df9bbcb8d2cdb30383c6cd483691c546cc"},
+ {file = "scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1d0b25d9c651fd050555aadd57431b53d4cf664e749069da77f3d52c5ad14b3b"},
+ {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0203c368058ab92efc6168a1507d388d41469c873e96ec220ca8e74079bf62e"},
+ {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44c62f2b124848a28fd695db5bc4da019287abf390bfce602ddc8aa1ec186aae"},
+ {file = "scikit_learn-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cd7b524115499b18b63f0c96f4224eb885564937a0b3477531b2b63ce331904"},
+ {file = "scikit_learn-1.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90378e1747949f90c8f385898fff35d73193dfcaec3dd75d6b542f90c4e89755"},
+ {file = "scikit_learn-1.4.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ff4effe5a1d4e8fed260a83a163f7dbf4f6087b54528d8880bab1d1377bd78be"},
+ {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:671e2f0c3f2c15409dae4f282a3a619601fa824d2c820e5b608d9d775f91780c"},
+ {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36d0bc983336bbc1be22f9b686b50c964f593c8a9a913a792442af9bf4f5e68"},
+ {file = "scikit_learn-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:d762070980c17ba3e9a4a1e043ba0518ce4c55152032f1af0ca6f39b376b5928"},
+]
+
[[package]]
name = "scipy"
version = "1.13.0"
@@ -1560,6 +1602,17 @@ files = [
{file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
]
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+requires_python = ">=3.8"
+summary = "threadpoolctl"
+groups = ["default"]
+files = [
+ {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
+ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
+]
+
[[package]]
name = "tomli"
version = "2.0.1"
@@ -1614,6 +1667,29 @@ files = [
{file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
]
+[[package]]
+name = "torch-geometric"
+version = "2.5.3"
+requires_python = ">=3.8"
+summary = "Graph Neural Network Library for PyTorch"
+groups = ["default"]
+dependencies = [
+ "aiohttp",
+ "fsspec",
+ "jinja2",
+ "numpy",
+ "psutil>=5.8.0",
+ "pyparsing",
+ "requests",
+ "scikit-learn",
+ "scipy",
+ "tqdm",
+]
+files = [
+ {file = "torch_geometric-2.5.3-py3-none-any.whl", hash = "sha256:8277abfc12600b0e8047e0c3ea2d55cc43f08c1448e73e924de827c15d0b5f85"},
+ {file = "torch_geometric-2.5.3.tar.gz", hash = "sha256:ad0761650c8fa56cdc46ee61c564fd4995f07f079965fe732b3a76d109fd3edc"},
+]
+
[[package]]
name = "torchmetrics"
version = "1.4.0"
diff --git a/pyproject.toml b/pyproject.toml
index d9eb14dd..b2461c42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
"tueplots>=0.0.8",
"plotly>=5.15.0",
"pre-commit>=2.15.0",
+ "torch-geometric>=2.5.3",
]
requires-python = ">=3.10"
name = "neural-lam"
From 286995202ac9534eb2ef3bbb621ed8c798fa8e9a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 14 May 2024 00:08:58 +0200
Subject: [PATCH 042/273] fix import and more tests
---
neural_lam/models/graph_lam.py | 2 +-
tests/test_base.py | 6 ++++++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index 6bbe83cc..ff641c20 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -2,7 +2,7 @@
import torch_geometric as pyg
# First-party
-from . import utils
+from .. import utils
from ..interaction_net import InteractionNet
from .base_graph_model import BaseGraphModel
diff --git a/tests/test_base.py b/tests/test_base.py
index c3a89171..df442614 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -1,7 +1,13 @@
import neural_lam
+import neural_lam.create_mesh
+import neural_lam.create_grid_features
+import neural_lam.create_parameter_weights
import neural_lam.train_model
def test_import():
assert neural_lam is not None
+ assert neural_lam.create_mesh is not None
+ assert neural_lam.create_grid_features is not None
+ assert neural_lam.create_parameter_weights is not None
assert neural_lam.train_model is not None
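A short note on the one-line import fix above: graph_lam.py sits in the neural_lam.models subpackage, so a single-dot relative import resolves against neural_lam.models rather than the package root, while utils lives at the top level, hence the double dot. A minimal sketch (assuming the package is importable) of what each form resolves to:

    # Purely illustrative: which module each relative-import form in
    # neural_lam/models/graph_lam.py points at.
    import importlib

    utils = importlib.import_module("neural_lam.utils")  # target of `from .. import utils`
    print(utils.__name__)                                 # -> "neural_lam.utils"

    # `from . import utils` would instead look for "neural_lam.models.utils",
    # which does not exist, which is why the original import failed.
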
From 358c8d6965a6b1093c4d44244ff652f5835d733b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 14 May 2024 00:12:11 +0200
Subject: [PATCH 043/273] pdm pre-commit hook to sync requirements.txt
---
.pre-commit-config.yaml | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0547c6b9..47754c21 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,3 +35,12 @@ repos:
hooks:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
+
+ # export python requirements
+ - repo: https://github.com/pdm-project/pdm
+ rev: 2.12.4 # a PDM release exposing the hook
+ hooks:
+ - id: pdm-export
+ # command arguments, e.g.:
+ args: ['-o', 'requirements.txt', '--without-hashes']
+ files: ^pdm.lock$
From 6c3bdce0ca2e645220db4c19be6f20f742074079 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 14 May 2024 00:13:31 +0200
Subject: [PATCH 044/273] update requirements.txt
---
requirements.txt | 93 ++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 93 insertions(+)
create mode 100644 requirements.txt
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..1a87e6be
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,93 @@
+# This file is @generated by PDM.
+# Please do not edit it manually.
+
+aiohttp==3.9.5
+aiosignal==1.3.1
+async-timeout==4.0.3; python_version < "3.11"
+attrs==23.2.0
+cartopy==0.23.0
+certifi==2024.2.2
+cfgv==3.4.0
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+contourpy==1.2.1
+cycler==0.12.1
+distlib==0.3.8
+docker-pycreds==0.4.0
+exceptiongroup==1.2.1; python_version < "3.11"
+filelock==3.14.0
+fonttools==4.51.0
+frozenlist==1.4.1
+fsspec==2024.3.1
+gitdb==4.0.11
+gitpython==3.1.43
+identify==2.5.36
+idna==3.7
+iniconfig==2.0.0
+intel-openmp==2021.4.0; platform_system == "Windows"
+jinja2==3.1.4
+joblib==1.4.2
+kiwisolver==1.4.5
+lightning-utilities==0.11.2
+markupsafe==2.1.5
+matplotlib==3.8.4
+mkl==2021.4.0; platform_system == "Windows"
+mpmath==1.3.0
+multidict==6.0.5
+networkx==3.3
+nodeenv==1.8.0
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-runtime-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cufft-cu12==11.0.2.54; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-curand-cu12==10.3.2.106; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cusolver-cu12==11.4.5.107; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cusparse-cu12==12.1.0.106; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nccl-cu12==2.20.5; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nvjitlink-cu12==12.4.127; platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nvtx-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
+packaging==24.0
+pillow==10.3.0
+platformdirs==4.2.1
+plotly==5.22.0
+pluggy==1.5.0
+pre-commit==3.7.1
+pretty-errors==1.2.25
+protobuf==4.25.3; python_version > "3.9" or sys_platform != "linux"
+psutil==5.9.8
+pyparsing==3.1.2
+pyproj==3.6.1
+pyshp==2.3.1
+pytest==8.2.0
+python-dateutil==2.9.0.post0
+pytorch-lightning==2.2.4
+pyyaml==6.0.1
+requests==2.31.0
+scikit-learn==1.4.2
+scipy==1.13.0
+sentry-sdk==2.1.1
+setproctitle==1.3.3
+setuptools==69.5.1
+shapely==2.0.4
+six==1.16.0
+smmap==5.0.1
+sympy==1.12
+tbb==2021.12.0; platform_system == "Windows"
+tenacity==8.3.0
+threadpoolctl==3.5.0
+tomli==2.0.1; python_version < "3.11"
+torch==2.3.0
+torch-geometric==2.5.3
+torchmetrics==1.4.0
+tqdm==4.66.4
+triton==2.3.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"
+tueplots==0.0.15
+typing-extensions==4.11.0
+urllib3==2.2.1
+virtualenv==20.26.1
+wandb==0.17.0
+yarl==1.9.4
From fbd6a2b6351d1fc7a8f00a6990a6a1bb9e20c2a1 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 14 May 2024 00:13:43 +0200
Subject: [PATCH 045/273] more import tests
---
tests/test_base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_base.py b/tests/test_base.py
index df442614..27228cfb 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -1,6 +1,6 @@
import neural_lam
-import neural_lam.create_mesh
import neural_lam.create_grid_features
+import neural_lam.create_mesh
import neural_lam.create_parameter_weights
import neural_lam.train_model
From 93190deba0cfde97d31752ebf58c38417d713e3c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 22 May 2024 14:09:44 +0200
Subject: [PATCH 046/273] move deps to pyproject and add import tests
---
.gitignore | 3 +
pdm.lock | 2015 +++++++++++++++++++++++++++++++++++++++++
pyproject.toml | 37 +
tests/test_imports.py | 8 +
4 files changed, 2063 insertions(+)
create mode 100644 pdm.lock
create mode 100644 tests/test_imports.py
diff --git a/.gitignore b/.gitignore
index c9d914c2..ede00bca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,3 +73,6 @@ tags
# Coc configuration directory
.vim
+
+# pdm (https://pdm-project.org/en/stable/)
+.pdm-python
diff --git a/pdm.lock b/pdm.lock
new file mode 100644
index 00000000..6ea24bcf
--- /dev/null
+++ b/pdm.lock
@@ -0,0 +1,2015 @@
+# This file is @generated by PDM.
+# It is not intended for manual editing.
+
+[metadata]
+groups = ["default", "dev"]
+strategy = ["cross_platform", "inherit_metadata"]
+lock_version = "4.4.1"
+content_hash = "sha256:c4f5df1487409a1cd6d45a6155c3aff846c7deca9787b9e0003e2d850a4f27c8"
+
+[[package]]
+name = "aiohttp"
+version = "3.9.5"
+requires_python = ">=3.8"
+summary = "Async http client/server framework (asyncio)"
+groups = ["default"]
+dependencies = [
+ "aiosignal>=1.1.2",
+ "async-timeout<5.0,>=4.0; python_version < \"3.11\"",
+ "attrs>=17.3.0",
+ "frozenlist>=1.1.1",
+ "multidict<7.0,>=4.5",
+ "yarl<2.0,>=1.0",
+]
+files = [
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"},
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"},
+ {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"},
+ {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"},
+ {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"},
+ {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"},
+ {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"},
+ {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"},
+ {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"},
+ {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"},
+ {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"},
+ {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
+ {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
+ {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
+ {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
+ {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
+ {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
+ {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"},
+ {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"},
+ {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"},
+ {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"},
+ {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"},
+ {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"},
+ {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"},
+ {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"},
+ {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"},
+ {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"},
+ {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"},
+ {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"},
+ {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"},
+ {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"},
+ {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"},
+ {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
+]
+
+[[package]]
+name = "aiosignal"
+version = "1.3.1"
+requires_python = ">=3.7"
+summary = "aiosignal: a list of registered asynchronous callbacks"
+groups = ["default"]
+dependencies = [
+ "frozenlist>=1.1.0",
+]
+files = [
+ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
+ {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
+]
+
+[[package]]
+name = "async-timeout"
+version = "4.0.3"
+requires_python = ">=3.7"
+summary = "Timeout context manager for asyncio programs"
+groups = ["default"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
+ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
+]
+
+[[package]]
+name = "attrs"
+version = "23.2.0"
+requires_python = ">=3.7"
+summary = "Classes Without Boilerplate"
+groups = ["default"]
+files = [
+ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
+ {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
+]
+
+[[package]]
+name = "cartopy"
+version = "0.23.0"
+requires_python = ">=3.9"
+summary = "A Python library for cartographic visualizations with Matplotlib"
+groups = ["default"]
+dependencies = [
+ "matplotlib>=3.5",
+ "numpy>=1.21",
+ "packaging>=20",
+ "pyproj>=3.3.1",
+ "pyshp>=2.3",
+ "shapely>=1.7",
+]
+files = [
+ {file = "Cartopy-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:374e66f816c3bafa48ffdbf6abaefa67063b405fac5f425f9be241cdf3498352"},
+ {file = "Cartopy-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2bae450c4c913796cad0b7ce05aa2fa78d1788de47989f0a03183397648e24be"},
+ {file = "Cartopy-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40437596e8ac5e74575eab822c661f4e725bd995cfd9e445069695fe9086b42"},
+ {file = "Cartopy-0.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:3292d6d403137eed80d32014c2f28de6282bed8824213f4b4c2170f388b24a1b"},
+ {file = "Cartopy-0.23.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86b07b6794b616674e4e485b8574e9197bca54a4467d28dd01ae0bf178f8dc2b"},
+ {file = "Cartopy-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8dece2aa8d5ff7bf989ded6b5f07c980fb5bb772952bc7cdeab469738abdecee"},
+ {file = "Cartopy-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9dfd28352dc83d6b4e4cf85d84cb50fc4886d4c1510d61f4c7cf22477d1156f"},
+ {file = "Cartopy-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:b2671b5354e43220f8e1074e7fe30a8b9f71cb38407c78e51db9c97772f0320b"},
+ {file = "Cartopy-0.23.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:80b9fd666fd47f6370d29f7ad4e352828d54aaf688a03d0b83b51e141cfd77fa"},
+ {file = "Cartopy-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:43e36b8b7e7e373a5698757458fd28fafbbbf5f3ebbe2d378f6a5ec3993d6dc0"},
+ {file = "Cartopy-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550173b91155d4d81cd14b4892cb6cabe3dd32bd34feacaa1ec78c0e56287832"},
+ {file = "Cartopy-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:55219ee0fb069cc3254426e87382cde03546e86c3f7c6759f076823b1e3a44d9"},
+ {file = "Cartopy-0.23.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6279af846bf77d9817ab8792a8e38ca561878f048bba1afdae3e3a30c5432bfd"},
+ {file = "Cartopy-0.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843bf9dc0a18e1a8eed872c49e8092e8a8109e4dce285ad96752841e21e8161e"},
+ {file = "Cartopy-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:350ff8802e2bc617c09bd6148aeb46e841775a846bfaa6e635a212d1eaf5ab66"},
+ {file = "Cartopy-0.23.0-cp39-cp39-win_amd64.whl", hash = "sha256:b52ab2274ad7504955854ef8d6f603e41f5d7163d02b29d369cecdbd29c2fda1"},
+ {file = "Cartopy-0.23.0.tar.gz", hash = "sha256:231f37b35701f2ba31d94959cca75e6da04c2eea3a7f14ce1c75ee3b0eae7676"},
+]
+
+[[package]]
+name = "certifi"
+version = "2024.2.2"
+requires_python = ">=3.6"
+summary = "Python package for providing Mozilla's CA Bundle."
+groups = ["default"]
+files = [
+ {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
+ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
+]
+
+[[package]]
+name = "cfgv"
+version = "3.4.0"
+requires_python = ">=3.8"
+summary = "Validate configuration and produce human readable error messages."
+groups = ["dev"]
+files = [
+ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
+ {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.3.2"
+requires_python = ">=3.7.0"
+summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+groups = ["default"]
+files = [
+ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
+ {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
+ {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
+ {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
+ {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
+ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
+]
+
+[[package]]
+name = "click"
+version = "8.1.7"
+requires_python = ">=3.7"
+summary = "Composable command line interface toolkit"
+groups = ["default"]
+dependencies = [
+ "colorama; platform_system == \"Windows\"",
+]
+files = [
+ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
+ {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+summary = "Cross-platform colored terminal text."
+groups = ["default", "dev"]
+marker = "sys_platform == \"win32\" or platform_system == \"Windows\""
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[[package]]
+name = "contourpy"
+version = "1.2.1"
+requires_python = ">=3.9"
+summary = "Python library for calculating contours of 2D quadrilateral grids"
+groups = ["default"]
+dependencies = [
+ "numpy>=1.20",
+]
+files = [
+ {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"},
+ {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"},
+ {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"},
+ {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"},
+ {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"},
+ {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"},
+ {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"},
+ {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"},
+ {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"},
+ {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"},
+ {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"},
+ {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"},
+ {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"},
+ {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"},
+ {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"},
+ {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"},
+ {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"},
+ {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"},
+ {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"},
+ {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"},
+ {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"},
+ {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"},
+ {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"},
+ {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"},
+ {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"},
+ {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"},
+ {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"},
+ {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"},
+ {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"},
+ {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"},
+ {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"},
+ {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"},
+ {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"},
+]
+
+[[package]]
+name = "cycler"
+version = "0.12.1"
+requires_python = ">=3.8"
+summary = "Composable style cycles"
+groups = ["default"]
+files = [
+ {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"},
+ {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"},
+]
+
+[[package]]
+name = "distlib"
+version = "0.3.8"
+summary = "Distribution utilities"
+groups = ["dev"]
+files = [
+ {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
+ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
+]
+
+[[package]]
+name = "docker-pycreds"
+version = "0.4.0"
+summary = "Python bindings for the docker credentials store API"
+groups = ["default"]
+dependencies = [
+ "six>=1.4.0",
+]
+files = [
+ {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
+ {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.1"
+requires_python = ">=3.7"
+summary = "Backport of PEP 654 (exception groups)"
+groups = ["dev"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"},
+ {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"},
+]
+
+[[package]]
+name = "filelock"
+version = "3.14.0"
+requires_python = ">=3.8"
+summary = "A platform independent file lock."
+groups = ["default", "dev"]
+files = [
+ {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
+ {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
+]
+
+[[package]]
+name = "fonttools"
+version = "4.51.0"
+requires_python = ">=3.8"
+summary = "Tools to manipulate font files"
+groups = ["default"]
+files = [
+ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"},
+ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"},
+ {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037"},
+ {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716"},
+ {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438"},
+ {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039"},
+ {file = "fonttools-4.51.0-cp310-cp310-win32.whl", hash = "sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77"},
+ {file = "fonttools-4.51.0-cp310-cp310-win_amd64.whl", hash = "sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b"},
+ {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74"},
+ {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2"},
+ {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f"},
+ {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097"},
+ {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0"},
+ {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1"},
+ {file = "fonttools-4.51.0-cp311-cp311-win32.whl", hash = "sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034"},
+ {file = "fonttools-4.51.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1"},
+ {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba"},
+ {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc"},
+ {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a"},
+ {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2"},
+ {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671"},
+ {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5"},
+ {file = "fonttools-4.51.0-cp312-cp312-win32.whl", hash = "sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15"},
+ {file = "fonttools-4.51.0-cp312-cp312-win_amd64.whl", hash = "sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e"},
+ {file = "fonttools-4.51.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b"},
+ {file = "fonttools-4.51.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936"},
+ {file = "fonttools-4.51.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55"},
+ {file = "fonttools-4.51.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce"},
+ {file = "fonttools-4.51.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051"},
+ {file = "fonttools-4.51.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7"},
+ {file = "fonttools-4.51.0-cp39-cp39-win32.whl", hash = "sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636"},
+ {file = "fonttools-4.51.0-cp39-cp39-win_amd64.whl", hash = "sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a"},
+ {file = "fonttools-4.51.0-py3-none-any.whl", hash = "sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f"},
+ {file = "fonttools-4.51.0.tar.gz", hash = "sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68"},
+]
+
+[[package]]
+name = "frozenlist"
+version = "1.4.1"
+requires_python = ">=3.8"
+summary = "A list-like structure which implements collections.abc.MutableSequence"
+groups = ["default"]
+files = [
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
+ {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
+ {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
+ {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
+ {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
+ {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
+ {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
+ {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
+ {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
+ {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
+ {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
+ {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"},
+ {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"},
+ {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"},
+ {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"},
+ {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"},
+ {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"},
+ {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
+ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
+]
+
+[[package]]
+name = "fsspec"
+version = "2024.5.0"
+requires_python = ">=3.8"
+summary = "File-system specification"
+groups = ["default"]
+files = [
+ {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"},
+ {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"},
+]
+
+[[package]]
+name = "fsspec"
+version = "2024.5.0"
+extras = ["http"]
+requires_python = ">=3.8"
+summary = "File-system specification"
+groups = ["default"]
+dependencies = [
+ "aiohttp!=4.0.0a0,!=4.0.0a1",
+ "fsspec==2024.5.0",
+]
+files = [
+ {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"},
+ {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"},
+]
+
+[[package]]
+name = "gitdb"
+version = "4.0.11"
+requires_python = ">=3.7"
+summary = "Git Object Database"
+groups = ["default"]
+dependencies = [
+ "smmap<6,>=3.0.1",
+]
+files = [
+ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
+ {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
+]
+
+[[package]]
+name = "gitpython"
+version = "3.1.43"
+requires_python = ">=3.7"
+summary = "GitPython is a Python library used to interact with Git repositories"
+groups = ["default"]
+dependencies = [
+ "gitdb<5,>=4.0.1",
+]
+files = [
+ {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+ {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
+]
+
+[[package]]
+name = "identify"
+version = "2.5.36"
+requires_python = ">=3.8"
+summary = "File identification library for Python"
+groups = ["dev"]
+files = [
+ {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"},
+ {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"},
+]
+
+[[package]]
+name = "idna"
+version = "3.7"
+requires_python = ">=3.5"
+summary = "Internationalized Domain Names in Applications (IDNA)"
+groups = ["default"]
+files = [
+ {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"},
+ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
+]
+
+[[package]]
+name = "importlib-resources"
+version = "6.4.0"
+requires_python = ">=3.8"
+summary = "Read resources from Python packages"
+groups = ["default"]
+marker = "python_version < \"3.10\""
+dependencies = [
+ "zipp>=3.1.0; python_version < \"3.10\"",
+]
+files = [
+ {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"},
+ {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"},
+]
+
+[[package]]
+name = "iniconfig"
+version = "2.0.0"
+requires_python = ">=3.7"
+summary = "brain-dead simple config-ini parsing"
+groups = ["dev"]
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
+[[package]]
+name = "intel-openmp"
+version = "2021.4.0"
+summary = "Intel® OpenMP* Runtime Library"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+files = [
+ {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
+ {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.4"
+requires_python = ">=3.7"
+summary = "A very fast and expressive template engine."
+groups = ["default"]
+dependencies = [
+ "MarkupSafe>=2.0",
+]
+files = [
+ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
+ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
+]
+
+[[package]]
+name = "kiwisolver"
+version = "1.4.5"
+requires_python = ">=3.7"
+summary = "A fast implementation of the Cassowary constraint solver"
+groups = ["default"]
+files = [
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"},
+ {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"},
+ {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"},
+ {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"},
+ {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"},
+ {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"},
+ {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"},
+ {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"},
+ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
+]
+
+[[package]]
+name = "lightning-utilities"
+version = "0.11.2"
+requires_python = ">=3.8"
+summary = "Lightning toolbox for across the our ecosystem."
+groups = ["default"]
+dependencies = [
+ "packaging>=17.1",
+ "setuptools",
+ "typing-extensions",
+]
+files = [
+ {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"},
+ {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"},
+]
+
+[[package]]
+name = "markupsafe"
+version = "2.1.5"
+requires_python = ">=3.7"
+summary = "Safely add untrusted strings to HTML/XML markup."
+groups = ["default"]
+files = [
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
+ {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
+ {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
+ {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"},
+ {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"},
+ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
+]
+
+[[package]]
+name = "matplotlib"
+version = "3.9.0"
+requires_python = ">=3.9"
+summary = "Python plotting package"
+groups = ["default"]
+dependencies = [
+ "contourpy>=1.0.1",
+ "cycler>=0.10",
+ "fonttools>=4.22.0",
+ "importlib-resources>=3.2.0; python_version < \"3.10\"",
+ "kiwisolver>=1.3.1",
+ "numpy>=1.23",
+ "packaging>=20.0",
+ "pillow>=8",
+ "pyparsing>=2.3.1",
+ "python-dateutil>=2.7",
+]
+files = [
+ {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"},
+ {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"},
+ {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"},
+ {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"},
+ {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"},
+ {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"},
+ {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"},
+ {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"},
+ {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"},
+ {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"},
+ {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"},
+ {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"},
+ {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"},
+ {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"},
+ {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"},
+ {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"},
+ {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"},
+ {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"},
+ {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"},
+ {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"},
+ {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"},
+ {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"},
+ {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"},
+ {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"},
+ {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"},
+ {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"},
+ {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"},
+ {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"},
+ {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"},
+]
+
+[[package]]
+name = "mkl"
+version = "2021.4.0"
+summary = "Intel® oneAPI Math Kernel Library"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+dependencies = [
+ "intel-openmp==2021.*",
+ "tbb==2021.*",
+]
+files = [
+ {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
+ {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
+ {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
+ {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
+ {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+summary = "Python library for arbitrary-precision floating-point arithmetic"
+groups = ["default"]
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[[package]]
+name = "multidict"
+version = "6.0.5"
+requires_python = ">=3.7"
+summary = "multidict implementation"
+groups = ["default"]
+files = [
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
+ {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
+ {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
+ {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
+ {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
+ {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
+ {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
+ {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
+ {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
+ {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
+ {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
+ {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
+ {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
+ {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
+ {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
+ {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"},
+ {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"},
+ {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"},
+ {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"},
+ {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"},
+ {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"},
+ {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
+ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
+]
+
+[[package]]
+name = "networkx"
+version = "3.2.1"
+requires_python = ">=3.9"
+summary = "Python package for creating and manipulating graphs and networks"
+groups = ["default"]
+files = [
+ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
+ {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+]
+
+[[package]]
+name = "nodeenv"
+version = "1.8.0"
+requires_python = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
+summary = "Node.js virtual environment builder"
+groups = ["dev"]
+dependencies = [
+ "setuptools",
+]
+files = [
+ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
+ {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
+]
+
+[[package]]
+name = "numpy"
+version = "1.26.4"
+requires_python = ">=3.9"
+summary = "Fundamental package for array computing in Python"
+groups = ["default"]
+files = [
+ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
+ {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
+ {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
+ {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
+ {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
+ {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
+ {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
+ {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
+ {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
+ {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
+ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
+]
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.1.3.1"
+requires_python = ">=3"
+summary = "CUBLAS native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "CUDA profiling tools runtime libs."
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "NVRTC native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "CUDA Runtime native Libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "8.9.2.26"
+requires_python = ">=3"
+summary = "cuDNN runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-cublas-cu12",
+]
+files = [
+ {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
+]
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.0.2.54"
+requires_python = ">=3"
+summary = "CUFFT native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
+]
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.2.106"
+requires_python = ">=3"
+summary = "CURAND native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.4.5.107"
+requires_python = ">=3"
+summary = "CUDA solver native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-cublas-cu12",
+ "nvidia-cusparse-cu12",
+ "nvidia-nvjitlink-cu12",
+]
+files = [
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
+]
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.1.0.106"
+requires_python = ">=3"
+summary = "CUSPARSE native runtime libraries"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+dependencies = [
+ "nvidia-nvjitlink-cu12",
+]
+files = [
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.20.5"
+requires_python = ">=3"
+summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.5.40"
+requires_python = ">=3"
+summary = "Nvidia JIT LTO Library"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
+ {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.1.105"
+requires_python = ">=3"
+summary = "NVIDIA Tools Extension"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
+]
+
+[[package]]
+name = "packaging"
+version = "24.0"
+requires_python = ">=3.7"
+summary = "Core utilities for Python packages"
+groups = ["default", "dev"]
+files = [
+ {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
+]
+
+[[package]]
+name = "pillow"
+version = "10.3.0"
+requires_python = ">=3.8"
+summary = "Python Imaging Library (Fork)"
+groups = ["default"]
+files = [
+ {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
+ {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
+ {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
+ {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
+ {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
+ {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
+ {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
+ {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
+ {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
+ {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
+ {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
+ {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
+ {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
+ {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
+ {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
+ {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
+ {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
+ {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
+ {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
+ {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"},
+ {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"},
+ {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"},
+ {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"},
+ {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"},
+ {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"},
+ {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"},
+ {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
+ {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
+ {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
+ {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
+]
+
+[[package]]
+name = "platformdirs"
+version = "4.2.2"
+requires_python = ">=3.8"
+summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
+groups = ["default", "dev"]
+files = [
+ {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"},
+ {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"},
+]
+
+[[package]]
+name = "plotly"
+version = "5.22.0"
+requires_python = ">=3.8"
+summary = "An open-source, interactive data visualization library for Python"
+groups = ["default"]
+dependencies = [
+ "packaging",
+ "tenacity>=6.2.0",
+]
+files = [
+ {file = "plotly-5.22.0-py3-none-any.whl", hash = "sha256:68fc1901f098daeb233cc3dd44ec9dc31fb3ca4f4e53189344199c43496ed006"},
+ {file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"},
+]
+
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+requires_python = ">=3.8"
+summary = "plugin and hook calling mechanisms for python"
+groups = ["dev"]
+files = [
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
+]
+
+[[package]]
+name = "pre-commit"
+version = "3.7.1"
+requires_python = ">=3.9"
+summary = "A framework for managing and maintaining multi-language pre-commit hooks."
+groups = ["dev"]
+dependencies = [
+ "cfgv>=2.0.0",
+ "identify>=1.0.0",
+ "nodeenv>=0.11.1",
+ "pyyaml>=5.1",
+ "virtualenv>=20.10.0",
+]
+files = [
+ {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"},
+ {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"},
+]
+
+[[package]]
+name = "protobuf"
+version = "4.25.3"
+requires_python = ">=3.8"
+summary = ""
+groups = ["default"]
+marker = "python_version >= \"3.9\" or sys_platform != \"linux\""
+files = [
+ {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
+ {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
+ {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
+ {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
+ {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"},
+ {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"},
+ {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
+ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
+]
+
+[[package]]
+name = "psutil"
+version = "5.9.8"
+requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+summary = "Cross-platform lib for process and system monitoring in Python."
+groups = ["default"]
+files = [
+ {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"},
+ {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"},
+ {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"},
+ {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"},
+ {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"},
+ {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"},
+]
+
+[[package]]
+name = "pyparsing"
+version = "3.1.2"
+requires_python = ">=3.6.8"
+summary = "pyparsing module - Classes and methods to define and execute parsing grammars"
+groups = ["default"]
+files = [
+ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
+ {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
+]
+
+[[package]]
+name = "pyproj"
+version = "3.6.1"
+requires_python = ">=3.9"
+summary = "Python interface to PROJ (cartographic projections and coordinate transformations library)"
+groups = ["default"]
+dependencies = [
+ "certifi",
+]
+files = [
+ {file = "pyproj-3.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ab7aa4d9ff3c3acf60d4b285ccec134167a948df02347585fdd934ebad8811b4"},
+ {file = "pyproj-3.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bc0472302919e59114aa140fd7213c2370d848a7249d09704f10f5b062031fe"},
+ {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5279586013b8d6582e22b6f9e30c49796966770389a9d5b85e25a4223286cd3f"},
+ {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80fafd1f3eb421694857f254a9bdbacd1eb22fc6c24ca74b136679f376f97d35"},
+ {file = "pyproj-3.6.1-cp310-cp310-win32.whl", hash = "sha256:c41e80ddee130450dcb8829af7118f1ab69eaf8169c4bf0ee8d52b72f098dc2f"},
+ {file = "pyproj-3.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:db3aedd458e7f7f21d8176f0a1d924f1ae06d725228302b872885a1c34f3119e"},
+ {file = "pyproj-3.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ebfbdbd0936e178091309f6cd4fcb4decd9eab12aa513cdd9add89efa3ec2882"},
+ {file = "pyproj-3.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:447db19c7efad70ff161e5e46a54ab9cc2399acebb656b6ccf63e4bc4a04b97a"},
+ {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7e13c40183884ec7f94eb8e0f622f08f1d5716150b8d7a134de48c6110fee85"},
+ {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65ad699e0c830e2b8565afe42bd58cc972b47d829b2e0e48ad9638386d994915"},
+ {file = "pyproj-3.6.1-cp311-cp311-win32.whl", hash = "sha256:8b8acc31fb8702c54625f4d5a2a6543557bec3c28a0ef638778b7ab1d1772132"},
+ {file = "pyproj-3.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:38a3361941eb72b82bd9a18f60c78b0df8408416f9340521df442cebfc4306e2"},
+ {file = "pyproj-3.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1e9fbaf920f0f9b4ee62aab832be3ae3968f33f24e2e3f7fbb8c6728ef1d9746"},
+ {file = "pyproj-3.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d227a865356f225591b6732430b1d1781e946893789a609bb34f59d09b8b0f8"},
+ {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83039e5ae04e5afc974f7d25ee0870a80a6bd6b7957c3aca5613ccbe0d3e72bf"},
+ {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb059ba3bced6f6725961ba758649261d85ed6ce670d3e3b0a26e81cf1aa8d"},
+ {file = "pyproj-3.6.1-cp312-cp312-win32.whl", hash = "sha256:2d6ff73cc6dbbce3766b6c0bce70ce070193105d8de17aa2470009463682a8eb"},
+ {file = "pyproj-3.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:7a27151ddad8e1439ba70c9b4b2b617b290c39395fa9ddb7411ebb0eb86d6fb0"},
+ {file = "pyproj-3.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4ba1f9b03d04d8cab24d6375609070580a26ce76eaed54631f03bab00a9c737b"},
+ {file = "pyproj-3.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18faa54a3ca475bfe6255156f2f2874e9a1c8917b0004eee9f664b86ccc513d3"},
+ {file = "pyproj-3.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd43bd9a9b9239805f406fd82ba6b106bf4838d9ef37c167d3ed70383943ade1"},
+ {file = "pyproj-3.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50100b2726a3ca946906cbaa789dd0749f213abf0cbb877e6de72ca7aa50e1ae"},
+ {file = "pyproj-3.6.1-cp39-cp39-win32.whl", hash = "sha256:9274880263256f6292ff644ca92c46d96aa7e57a75c6df3f11d636ce845a1877"},
+ {file = "pyproj-3.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:36b64c2cb6ea1cc091f329c5bd34f9c01bb5da8c8e4492c709bda6a09f96808f"},
+ {file = "pyproj-3.6.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd93c1a0c6c4aedc77c0fe275a9f2aba4d59b8acf88cebfc19fe3c430cfabf4f"},
+ {file = "pyproj-3.6.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6420ea8e7d2a88cb148b124429fba8cd2e0fae700a2d96eab7083c0928a85110"},
+ {file = "pyproj-3.6.1.tar.gz", hash = "sha256:44aa7c704c2b7d8fb3d483bbf75af6cb2350d30a63b144279a09b75fead501bf"},
+]
+
+[[package]]
+name = "pyshp"
+version = "2.3.1"
+requires_python = ">=2.7"
+summary = "Pure Python read/write support for ESRI Shapefile format"
+groups = ["default"]
+files = [
+ {file = "pyshp-2.3.1-py2.py3-none-any.whl", hash = "sha256:67024c0ccdc352ba5db777c4e968483782dfa78f8e200672a90d2d30fd8b7b49"},
+ {file = "pyshp-2.3.1.tar.gz", hash = "sha256:4caec82fd8dd096feba8217858068bacb2a3b5950f43c048c6dc32a3489d5af1"},
+]
+
+[[package]]
+name = "pytest"
+version = "8.2.1"
+requires_python = ">=3.8"
+summary = "pytest: simple powerful testing with Python"
+groups = ["dev"]
+dependencies = [
+ "colorama; sys_platform == \"win32\"",
+ "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
+ "iniconfig",
+ "packaging",
+ "pluggy<2.0,>=1.5",
+ "tomli>=1; python_version < \"3.11\"",
+]
+files = [
+ {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"},
+ {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"},
+]
+
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+summary = "Extensions to the standard Python datetime module"
+groups = ["default"]
+dependencies = [
+ "six>=1.5",
+]
+files = [
+ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
+]
+
+[[package]]
+name = "pytorch-lightning"
+version = "2.2.4"
+requires_python = ">=3.8"
+summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
+groups = ["default"]
+dependencies = [
+ "PyYAML>=5.4",
+ "fsspec[http]>=2022.5.0",
+ "lightning-utilities>=0.8.0",
+ "numpy>=1.17.2",
+ "packaging>=20.0",
+ "torch>=1.13.0",
+ "torchmetrics>=0.7.0",
+ "tqdm>=4.57.0",
+ "typing-extensions>=4.4.0",
+]
+files = [
+ {file = "pytorch-lightning-2.2.4.tar.gz", hash = "sha256:525b04ebad9900c3e3c2a12b3b462fe4f61ebe11fdb694716c3209f05b9b0fa8"},
+ {file = "pytorch_lightning-2.2.4-py3-none-any.whl", hash = "sha256:fd91d47e983a2cd743c5c8c3c3795bbd0f3b69d24be2172a2f9012d930701ff2"},
+]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.1"
+requires_python = ">=3.6"
+summary = "YAML parser and emitter for Python"
+groups = ["default", "dev"]
+files = [
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
+ {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
+ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
+ {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
+ {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
+ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
+ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
+ {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
+ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
+ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+]
+
+[[package]]
+name = "requests"
+version = "2.32.2"
+requires_python = ">=3.8"
+summary = "Python HTTP for Humans."
+groups = ["default"]
+dependencies = [
+ "certifi>=2017.4.17",
+ "charset-normalizer<4,>=2",
+ "idna<4,>=2.5",
+ "urllib3<3,>=1.21.1",
+]
+files = [
+ {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+ {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
+]
+
+[[package]]
+name = "scipy"
+version = "1.13.0"
+requires_python = ">=3.9"
+summary = "Fundamental algorithms for scientific computing in Python"
+groups = ["default"]
+dependencies = [
+ "numpy<2.3,>=1.22.4",
+]
+files = [
+ {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
+ {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
+ {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
+ {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
+ {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
+ {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
+ {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
+ {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
+ {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
+ {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
+ {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
+ {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
+ {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
+ {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
+ {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
+ {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
+ {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
+ {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
+ {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
+ {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
+]
+
+[[package]]
+name = "sentry-sdk"
+version = "2.2.1"
+requires_python = ">=3.6"
+summary = "Python client for Sentry (https://sentry.io)"
+groups = ["default"]
+dependencies = [
+ "certifi",
+ "urllib3>=1.26.11",
+]
+files = [
+ {file = "sentry_sdk-2.2.1-py2.py3-none-any.whl", hash = "sha256:7d617a1b30e80c41f3b542347651fcf90bb0a36f3a398be58b4f06b79c8d85bc"},
+ {file = "sentry_sdk-2.2.1.tar.gz", hash = "sha256:8aa2ec825724d8d9d645cab68e6034928b1a6a148503af3e361db3fa6401183f"},
+]
+
+[[package]]
+name = "setproctitle"
+version = "1.3.3"
+requires_python = ">=3.7"
+summary = "A Python module to customize the process title"
+groups = ["default"]
+files = [
+ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"},
+ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f"},
+ {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5"},
+ {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d"},
+ {file = "setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb"},
+ {file = "setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086"},
+ {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8"},
+ {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5"},
+ {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0"},
+ {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18"},
+ {file = "setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476"},
+ {file = "setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085"},
+ {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc"},
+ {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120"},
+ {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8"},
+ {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3"},
+ {file = "setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4"},
+ {file = "setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7"},
+ {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c7951820b77abe03d88b114b998867c0f99da03859e5ab2623d94690848d3e45"},
+ {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5bc94cf128676e8fac6503b37763adb378e2b6be1249d207630f83fc325d9b11"},
+ {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f5d9027eeda64d353cf21a3ceb74bb1760bd534526c9214e19f052424b37e42"},
+ {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e4a8104db15d3462e29d9946f26bed817a5b1d7a47eabca2d9dc2b995991503"},
+ {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c32c41ace41f344d317399efff4cffb133e709cec2ef09c99e7a13e9f3b9483c"},
+ {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf16381c7bf7f963b58fb4daaa65684e10966ee14d26f5cc90f07049bfd8c1e"},
+ {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e18b7bd0898398cc97ce2dfc83bb192a13a087ef6b2d5a8a36460311cb09e775"},
+ {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:69d565d20efe527bd8a9b92e7f299ae5e73b6c0470f3719bd66f3cd821e0d5bd"},
+ {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ddedd300cd690a3b06e7eac90ed4452348b1348635777ce23d460d913b5b63c3"},
+ {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:415bfcfd01d1fbf5cbd75004599ef167a533395955305f42220a585f64036081"},
+ {file = "setproctitle-1.3.3-cp39-cp39-win32.whl", hash = "sha256:21112fcd2195d48f25760f0eafa7a76510871bbb3b750219310cf88b04456ae3"},
+ {file = "setproctitle-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:5a740f05d0968a5a17da3d676ce6afefebeeeb5ce137510901bf6306ba8ee002"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5"},
+ {file = "setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d3a953c50776751e80fe755a380a64cb14d61e8762bd43041ab3f8cc436092f"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5e08e232b78ba3ac6bc0d23ce9e2bee8fad2be391b7e2da834fc9a45129eb87"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1da82c3e11284da4fcbf54957dafbf0655d2389cd3d54e4eaba636faf6d117a"},
+ {file = "setproctitle-1.3.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:aeaa71fb9568ebe9b911ddb490c644fbd2006e8c940f21cb9a1e9425bd709574"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:59335d000c6250c35989394661eb6287187854e94ac79ea22315469ee4f4c244"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3ba57029c9c50ecaf0c92bb127224cc2ea9fda057b5d99d3f348c9ec2855ad3"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d876d355c53d975c2ef9c4f2487c8f83dad6aeaaee1b6571453cb0ee992f55f6"},
+ {file = "setproctitle-1.3.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:224602f0939e6fb9d5dd881be1229d485f3257b540f8a900d4271a2c2aa4e5f4"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d7f27e0268af2d7503386e0e6be87fb9b6657afd96f5726b733837121146750d"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5e7266498cd31a4572378c61920af9f6b4676a73c299fce8ba93afd694f8ae7"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33c5609ad51cd99d388e55651b19148ea99727516132fb44680e1f28dd0d1de9"},
+ {file = "setproctitle-1.3.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:eae8988e78192fd1a3245a6f4f382390b61bce6cfcc93f3809726e4c885fa68d"},
+ {file = "setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae"},
+]
+
+[[package]]
+name = "setuptools"
+version = "70.0.0"
+requires_python = ">=3.8"
+summary = "Easily download, build, install, upgrade, and uninstall Python packages"
+groups = ["default", "dev"]
+files = [
+ {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
+ {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
+]
+
+[[package]]
+name = "shapely"
+version = "2.0.4"
+requires_python = ">=3.7"
+summary = "Manipulation and analysis of geometric objects"
+groups = ["default"]
+dependencies = [
+ "numpy<3,>=1.14",
+]
+files = [
+ {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"},
+ {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"},
+ {file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"},
+ {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"},
+ {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"},
+ {file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"},
+ {file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"},
+ {file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"},
+ {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"},
+ {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"},
+ {file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"},
+ {file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5bbd974193e2cc274312da16b189b38f5f128410f3377721cadb76b1e8ca5328"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:41388321a73ba1a84edd90d86ecc8bfed55e6a1e51882eafb019f45895ec0f65"},
+ {file = "shapely-2.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0776c92d584f72f1e584d2e43cfc5542c2f3dd19d53f70df0900fda643f4bae6"},
+ {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c75c98380b1ede1cae9a252c6dc247e6279403fae38c77060a5e6186c95073ac"},
+ {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3e700abf4a37b7b8b90532fa6ed5c38a9bfc777098bc9fbae5ec8e618ac8f30"},
+ {file = "shapely-2.0.4-cp312-cp312-win32.whl", hash = "sha256:4f2ab0faf8188b9f99e6a273b24b97662194160cc8ca17cf9d1fb6f18d7fb93f"},
+ {file = "shapely-2.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:03152442d311a5e85ac73b39680dd64a9892fa42bb08fd83b3bab4fe6999bfa0"},
+ {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9103abd1678cb1b5f7e8e1af565a652e036844166c91ec031eeb25c5ca8af0"},
+ {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:263bcf0c24d7a57c80991e64ab57cba7a3906e31d2e21b455f493d4aab534aaa"},
+ {file = "shapely-2.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddf4a9bfaac643e62702ed662afc36f6abed2a88a21270e891038f9a19bc08fc"},
+ {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:485246fcdb93336105c29a5cfbff8a226949db37b7473c89caa26c9bae52a242"},
+ {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de4578e838a9409b5b134a18ee820730e507b2d21700c14b71a2b0757396acc"},
+ {file = "shapely-2.0.4-cp39-cp39-win32.whl", hash = "sha256:9dab4c98acfb5fb85f5a20548b5c0abe9b163ad3525ee28822ffecb5c40e724c"},
+ {file = "shapely-2.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:31c19a668b5a1eadab82ff070b5a260478ac6ddad3a5b62295095174a8d26398"},
+ {file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"},
+]
+
+[[package]]
+name = "six"
+version = "1.16.0"
+requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+summary = "Python 2 and 3 compatibility utilities"
+groups = ["default"]
+files = [
+ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
+ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
+]
+
+[[package]]
+name = "smmap"
+version = "5.0.1"
+requires_python = ">=3.7"
+summary = "A pure Python implementation of a sliding window memory map manager"
+groups = ["default"]
+files = [
+ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
+ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
+]
+
+[[package]]
+name = "sympy"
+version = "1.12"
+requires_python = ">=3.8"
+summary = "Computer algebra system (CAS) in Python"
+groups = ["default"]
+dependencies = [
+ "mpmath>=0.19",
+]
+files = [
+ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
+ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
+]
+
+[[package]]
+name = "tbb"
+version = "2021.12.0"
+summary = "Intel® oneAPI Threading Building Blocks (oneTBB)"
+groups = ["default"]
+marker = "platform_system == \"Windows\""
+files = [
+ {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"},
+ {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"},
+ {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"},
+ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
+]
+
+[[package]]
+name = "tenacity"
+version = "8.3.0"
+requires_python = ">=3.8"
+summary = "Retry code until it succeeds"
+groups = ["default"]
+files = [
+ {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
+ {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
+]
+
+[[package]]
+name = "tomli"
+version = "2.0.1"
+requires_python = ">=3.7"
+summary = "A lil' TOML parser"
+groups = ["dev"]
+marker = "python_version < \"3.11\""
+files = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+
+[[package]]
+name = "torch"
+version = "2.3.0"
+requires_python = ">=3.8.0"
+summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+groups = ["default"]
+dependencies = [
+ "filelock",
+ "fsspec",
+ "jinja2",
+ "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"",
+ "networkx",
+ "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cuda-runtime-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cudnn-cu12==8.9.2.26; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cufft-cu12==11.0.2.54; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
+ "sympy",
+ "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"",
+ "typing-extensions>=4.8.0",
+]
+files = [
+ {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
+ {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
+ {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
+ {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
+ {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
+ {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
+ {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
+ {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
+ {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
+ {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
+ {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
+ {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
+ {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"},
+ {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"},
+ {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"},
+ {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"},
+]
+
+[[package]]
+name = "torchmetrics"
+version = "1.4.0.post0"
+requires_python = ">=3.8"
+summary = "PyTorch native Metrics"
+groups = ["default"]
+dependencies = [
+ "lightning-utilities>=0.8.0",
+ "numpy>1.20.0",
+ "packaging>17.1",
+ "torch>=1.10.0",
+]
+files = [
+ {file = "torchmetrics-1.4.0.post0-py3-none-any.whl", hash = "sha256:ab234216598e3fbd8d62ee4541a0e74e7e8fc935d099683af5b8da50f745b3c8"},
+ {file = "torchmetrics-1.4.0.post0.tar.gz", hash = "sha256:ab9bcfe80e65dbabbddb6cecd9be21f1f1d5207bb74051ef95260740f2762358"},
+]
+
+[[package]]
+name = "tqdm"
+version = "4.66.4"
+requires_python = ">=3.7"
+summary = "Fast, Extensible Progress Meter"
+groups = ["default"]
+dependencies = [
+ "colorama; platform_system == \"Windows\"",
+]
+files = [
+ {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
+ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
+]
+
+[[package]]
+name = "triton"
+version = "2.3.0"
+summary = "A language and compiler for custom Deep Learning operations"
+groups = ["default"]
+marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""
+dependencies = [
+ "filelock",
+]
+files = [
+ {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
+ {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
+ {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
+ {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"},
+]
+
+[[package]]
+name = "tueplots"
+version = "0.0.15"
+requires_python = ">=3.9"
+summary = "Scientific plotting made easy."
+groups = ["default"]
+dependencies = [
+ "matplotlib",
+ "numpy",
+]
+files = [
+ {file = "tueplots-0.0.15-py3-none-any.whl", hash = "sha256:f63e020af88328c78618f3d912612c75c3c91d21004a88fd12cf79dbd9b6d78a"},
+]
+
+[[package]]
+name = "typing-extensions"
+version = "4.11.0"
+requires_python = ">=3.8"
+summary = "Backported and Experimental Type Hints for Python 3.8+"
+groups = ["default"]
+files = [
+ {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
+ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
+]
+
+[[package]]
+name = "urllib3"
+version = "2.2.1"
+requires_python = ">=3.8"
+summary = "HTTP library with thread-safe connection pooling, file post, and more."
+groups = ["default"]
+files = [
+ {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
+ {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
+]
+
+[[package]]
+name = "virtualenv"
+version = "20.26.2"
+requires_python = ">=3.7"
+summary = "Virtual Python Environment builder"
+groups = ["dev"]
+dependencies = [
+ "distlib<1,>=0.3.7",
+ "filelock<4,>=3.12.2",
+ "platformdirs<5,>=3.9.1",
+]
+files = [
+ {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"},
+ {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"},
+]
+
+[[package]]
+name = "wandb"
+version = "0.17.0"
+requires_python = ">=3.7"
+summary = "A CLI and library for interacting with the Weights & Biases API."
+groups = ["default"]
+dependencies = [
+ "click!=8.0.0,>=7.1",
+ "docker-pycreds>=0.4.0",
+ "gitpython!=3.1.29,>=1.0.0",
+ "platformdirs",
+ "protobuf!=4.21.0,<5,>=3.15.0; python_version == \"3.9\" and sys_platform == \"linux\"",
+ "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"",
+ "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"",
+ "psutil>=5.0.0",
+ "pyyaml",
+ "requests<3,>=2.0.0",
+ "sentry-sdk>=1.0.0",
+ "setproctitle",
+ "setuptools",
+ "typing-extensions; python_version < \"3.10\"",
+]
+files = [
+ {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
+ {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
+ {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
+ {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
+ {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
+]
+
+[[package]]
+name = "yarl"
+version = "1.9.4"
+requires_python = ">=3.7"
+summary = "Yet another URL library"
+groups = ["default"]
+dependencies = [
+ "idna>=2.0",
+ "multidict>=4.0",
+]
+files = [
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
+ {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
+ {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
+ {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
+ {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
+ {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
+ {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
+ {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
+ {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
+ {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
+ {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
+ {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
+ {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
+ {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
+ {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
+ {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"},
+ {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"},
+ {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"},
+ {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"},
+ {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"},
+ {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"},
+ {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
+ {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
+]
+
+[[package]]
+name = "zipp"
+version = "3.18.2"
+requires_python = ">=3.8"
+summary = "Backport of pathlib-compatible object wrapper for zip files"
+groups = ["default"]
+marker = "python_version < \"3.10\""
+files = [
+ {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
+ {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
+]
diff --git a/pyproject.toml b/pyproject.toml
index b513a258..0a25868c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,36 @@
+[project]
+name = "neural-lam"
+version = "0.1.0"
+description = "LAM-based data-driven forecasting"
+authors = [
+ {name = "Joel Oskarsson", email = "joel.oskarsson@liu.se"},
+ {name = "Simon Adamov", email = "Simon.Adamov@meteoswiss.ch"},
+ {name = "Leif Denby", email = "lcd@dmi.dk"},
+]
+
+# PEP 621 project metadata
+# See https://www.python.org/dev/peps/pep-0621/
+dependencies = [
+ "numpy>=1.24.2",
+ "wandb>=0.13.10",
+ "scipy>=1.10.0",
+ "pytorch-lightning>=2.0.3",
+ "shapely>=2.0.1",
+ "networkx>=3.0",
+ "Cartopy>=0.22.0",
+ "pyproj>=3.4.1",
+ "tueplots>=0.0.8",
+ "matplotlib>=3.7.0",
+ "plotly>=5.15.0",
+]
+requires-python = ">=3.9"
+
+[tool.pdm.dev-dependencies]
+dev = [
+ "pre-commit>=2.15.0",
+ "pytest>=8.2.1",
+]
+
[tool.black]
line-length = 80
@@ -63,3 +96,7 @@ max-statements=100 # Allow for some more involved functions
allow-any-import-level="neural_lam"
[tool.pylint.SIMILARITIES]
min-similarity-lines=10
+
+[build-system]
+requires = ["pdm-backend"]
+build-backend = "pdm.backend"
diff --git a/tests/test_imports.py b/tests/test_imports.py
new file mode 100644
index 00000000..e7bbd356
--- /dev/null
+++ b/tests/test_imports.py
@@ -0,0 +1,8 @@
+# First-party
+import neural_lam
+import neural_lam.vis
+
+
+def test_import():
+ assert neural_lam is not None
+ assert neural_lam.vis is not None
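The `[project]` metadata and `pdm-backend` build backend above make the package installable with either pip or pdm, and `tests/test_imports.py` acts as a minimal import smoke test. A natural follow-up check (hypothetical, not part of this patch) would compare the installed distribution version against the `0.1.0` declared in `pyproject.toml`; it assumes the package has already been installed, e.g. with `python -m pip install .`:

    from importlib.metadata import version  # stdlib, Python >= 3.8

    def test_version():
        # "neural-lam" is the distribution name declared in pyproject.toml
        assert version("neural-lam") == "0.1.0"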
From afd6012731a3f30abaaf29f97cc4710c6bf1ba1a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 22 May 2024 14:11:21 +0200
Subject: [PATCH 047/273] add cicd testing workflow
---
.github/workflows/ci-tests.yml | 33 +++++++++++++++++++++++++++++++++
1 file changed, 33 insertions(+)
create mode 100644 .github/workflows/ci-tests.yml
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
new file mode 100644
index 00000000..9b73f298
--- /dev/null
+++ b/.github/workflows/ci-tests.yml
@@ -0,0 +1,33 @@
+# cicd workflow for running tests with pytest
+# needs to first install pdm, then install torch cpu manually and then install the package
+# then run the tests
+
+name: tests (cpu)
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Install pdm
+ uses: pdm-project/setup-pdm@v4
+ with:
+ python-version: "3.10"
+ cache: true
+
+ - name: Install torch (CPU)
+ run: |
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+ - name: Install package (including dev dependencies)
+ run: |
+ pdm install
+ pdm install --dev
+
+ - name: Run tests
+ run: |
+ pdm run pytest
From 4013796b2fc4a7e3ad1eec21f122d1c1af1170c4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 22 May 2024 14:19:53 +0200
Subject: [PATCH 048/273] test both with pdm and pip install
---
.github/workflows/ci-pdm-install-and-test.yml | 33 +++++++++++++++++++
.github/workflows/ci-pip-install-and-test.yml | 27 +++++++++++++++
2 files changed, 60 insertions(+)
create mode 100644 .github/workflows/ci-pdm-install-and-test.yml
create mode 100644 .github/workflows/ci-pip-install-and-test.yml
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
new file mode 100644
index 00000000..a85d7fae
--- /dev/null
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -0,0 +1,33 @@
+# cicd workflow for running tests with pytest
+# needs to first install pdm, then install torch cpu manually and then install the package
+# then run the tests
+
+name: tests (pdm install, cpu)
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Install pdm
+ uses: pdm-project/setup-pdm@v4
+ with:
+ python-version: "3.10"
+ cache: true
+
+ - name: Install torch (CPU)
+ run: |
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+ - name: Install package (including dev dependencies)
+ run: |
+ pdm install
+ pdm install --dev
+
+ - name: Run tests
+ run: |
+ pdm run pytest
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
new file mode 100644
index 00000000..66ac95ac
--- /dev/null
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -0,0 +1,27 @@
+# cicd workflow for running tests with pytest
+# needs to first install torch cpu manually and then install the package
+# then run the tests
+
+name: test (pip install, cpu)
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Install torch (CPU)
+ run: |
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+ - name: Install package (including dev dependencies)
+ run: |
+ python -m pip install .
+ python -m pip install pytest
+
+ - name: Run tests
+ run: |
+ python -m pytest
From de72b95ad88bf952f3711685f2183c9f2f41ed60 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 22 May 2024 14:20:48 +0200
Subject: [PATCH 049/273] clean up test cicd
---
.github/workflows/ci-pdm-install-and-test.yml | 2 +-
.github/workflows/ci-tests.yml | 33 -------------------
2 files changed, 1 insertion(+), 34 deletions(-)
delete mode 100644 .github/workflows/ci-tests.yml
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index a85d7fae..20b5fc14 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -2,7 +2,7 @@
# needs to first install pdm, then install torch cpu manually and then install the package
# then run the tests
-name: tests (pdm install, cpu)
+name: test (pdm install, cpu)
on: [push, pull_request]
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
deleted file mode 100644
index 9b73f298..00000000
--- a/.github/workflows/ci-tests.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-# cicd workflow for running tests with pytest
-# needs to first install pdm, then install torch cpu manually and then install the package
-# then run the tests
-
-name: tests (cpu)
-
-on: [push, pull_request]
-
-jobs:
- tests:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Install pdm
- uses: pdm-project/setup-pdm@v4
- with:
- python-version: "3.10"
- cache: true
-
- - name: Install torch (CPU)
- run: |
- python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
- - name: Install package (including dev dependencies)
- run: |
- pdm install
- pdm install --dev
-
- - name: Run tests
- run: |
- pdm run pytest
From 4d78c681c9e9bc7c32e0a12c4c5cf862d865a667 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 23 May 2024 11:40:57 +0200
Subject: [PATCH 050/273] remove requirements.txt
---
requirements.txt | 15 ---------------
1 file changed, 15 deletions(-)
delete mode 100644 requirements.txt
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index f381d54f..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-# for all
-numpy>=1.24.2
-wandb>=0.13.10
-matplotlib>=3.7.0
-scipy>=1.10.0
-pytorch-lightning>=2.0.3
-shapely>=2.0.1
-networkx>=3.0
-Cartopy>=0.22.0
-pyproj>=3.4.1
-tueplots>=0.0.8
-plotly>=5.15.0
-
-# for dev
-pre-commit>=2.15.0
From f2cbc44e9e7661fe7fa43afba27204721ac38368 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 23 May 2024 14:23:27 +0200
Subject: [PATCH 051/273] create 3D mesh objects for Schol-AR
---
plot_graph.py | 112 ++++++++++++++++++++++++++++++++++----------------
1 file changed, 77 insertions(+), 35 deletions(-)
diff --git a/plot_graph.py b/plot_graph.py
index e246200d..50c54e06 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -5,19 +5,78 @@
import numpy as np
import plotly.graph_objects as go
import torch_geometric as pyg
+import trimesh
+from tqdm import tqdm
+from trimesh.primitives import Box
# First-party
from neural_lam import utils
MESH_HEIGHT = 0.1
-MESH_LEVEL_DIST = 0.2
+MESH_LEVEL_DIST = 0.05
GRID_HEIGHT = 0
+def create_cubes_for_nodes(nodes, size=0.002):
+ """Create cubes for each node in the graph."""
+ cube_meshes = []
+ for node in tqdm(nodes, desc="Creating cubes"):
+ cube = Box(extents=[size, size, size])
+ cube.apply_translation(node)
+ cube_meshes.append(cube)
+ return cube_meshes
+
+
+def export_to_3d_model(node_pos, edge_plot_list, filename):
+ """Export the graph to a 3D model."""
+ paths = []
+ filtered_edge_plot_list = [
+ item for item in edge_plot_list if item[3] not in ["M2G", "G2M"]
+ ]
+
+ unique_node_indices = set()
+ for ei, _, _, _ in filtered_edge_plot_list:
+ unique_node_indices.update(ei.flatten())
+
+ filtered_node_positions = node_pos[np.array(list(unique_node_indices))]
+
+ for ei, _, _, _ in filtered_edge_plot_list:
+ edge_start = filtered_node_positions[ei[0]]
+ edge_end = filtered_node_positions[ei[1]]
+ for start, end in zip(edge_start, edge_end):
+ if not (np.isnan(start).any() or np.isnan(end).any()):
+ paths.append((start, end))
+
+ meshes = []
+ for start, end in tqdm(paths, desc="Creating meshes"):
+ offset_xyz = np.array([2e-4, 2e-4, 2e-4])
+ dummy_vertex = end + offset_xyz
+ vertices = [start, end, dummy_vertex]
+ faces = [[0, 1, 2]]
+ color_vertices = [[255, 179, 71], [255, 179, 71], [255, 179, 71]]
+ # color_faces = [[0, 0, 0]]
+
+ mesh = trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ # face_colors=color_faces,
+ vertex_colors=color_vertices,
+ )
+ meshes.append(mesh)
+
+ node_spheres = create_cubes_for_nodes(filtered_node_positions)
+
+ scene = trimesh.Scene()
+ for mesh in meshes:
+ scene.add_geometry(mesh)
+ for sphere in node_spheres:
+ scene.add_geometry(sphere)
+
+ scene.export(filename, file_type="ply")
+
+
def main():
- """
- Plot graph structure in 3D using plotly
- """
+ """Plot the graph."""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
"--graph",
@@ -42,16 +101,16 @@ def main():
default="neural_lam/data_config.yaml",
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
+ parser.add_argument(
+ "--export",
+ type=str,
+ help="Name of .obj file to export 3D model to (default: None)",
+ )
args = parser.parse_args()
- # Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
- (
- g2m_edge_index,
- m2g_edge_index,
- m2m_edge_index,
- ) = (
+ g2m_edge_index, m2g_edge_index, m2m_edge_index = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
graph_ldict["m2m_edge_index"],
@@ -64,23 +123,20 @@ def main():
config_loader = utils.ConfigLoader(args.data_config)
xy = config_loader.get_nwp_xy()
- grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) # (N_grid, 2)
+ grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2)
pos_max = np.max(np.abs(grid_xy))
- grid_pos = grid_xy / pos_max # Divide by maximum coordinate
+ grid_pos = grid_xy / pos_max
- # Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_pos = np.concatenate(
(grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
)
- # List of edges to plot, (edge_index, color, line_width, label)
edge_plot_list = [
(m2g_edge_index.numpy(), "black", 0.4, "M2G"),
(g2m_edge_index.numpy(), "black", 0.4, "G2M"),
]
- # Mesh positioning and edges to plot differ if we have a hierarchical graph
if hierarchical:
mesh_level_pos = [
np.concatenate(
@@ -99,13 +155,11 @@ def main():
]
mesh_pos = np.concatenate(mesh_level_pos, axis=0)
- # Add inter-level mesh edges
edge_plot_list += [
(level_ei.numpy(), "blue", 1, f"M2M Level {level}")
for level, level_ei in enumerate(m2m_edge_index)
]
- # Add intra-level mesh edges
up_edges_ei = np.concatenate(
[level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1
)
@@ -119,30 +173,20 @@ def main():
mesh_node_size = 2.5
else:
mesh_pos = mesh_static_features.numpy()
-
mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy()
z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees
mesh_node_size = mesh_degrees / 2
-
mesh_pos = np.concatenate(
(mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1
)
-
edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M"))
- # All node positions in one array
node_pos = np.concatenate((mesh_pos, grid_pos), axis=0)
- # Add edges
data_objs = []
- for (
- ei,
- col,
- width,
- label,
- ) in edge_plot_list:
- edge_start = node_pos[ei[0]] # (M, 2)
- edge_end = node_pos[ei[1]] # (M, 2)
+ for ei, col, width, label in edge_plot_list:
+ edge_start = node_pos[ei[0]]
+ edge_end = node_pos[ei[1]]
n_edges = edge_start.shape[0]
x_edges = np.stack(
@@ -165,8 +209,6 @@ def main():
)
data_objs.append(scatter_obj)
- # Add node objects
-
data_objs.append(
go.Scatter3d(
x=grid_pos[:, 0],
@@ -194,7 +236,6 @@ def main():
fig.update_traces(connectgaps=False)
if not args.show_axis:
- # Hide axis
fig.update_layout(
scene={
"xaxis": {"visible": False},
@@ -205,8 +246,9 @@ def main():
if args.save:
fig.write_html(args.save, include_plotlyjs="cdn")
- else:
- fig.show()
+
+ if args.export:
+ export_to_3d_model(node_pos, edge_plot_list, args.export)
if __name__ == "__main__":
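The export path added here works around the fact that mesh formats such as PLY store triangles rather than line segments: each graph edge becomes a near-degenerate triangle (start point, end point, and the end point nudged by a tiny offset), and every node becomes a small cube primitive. A standalone sketch of the same idea, with a single hard-coded edge and a hypothetical output file name:

    import numpy as np
    import trimesh
    from trimesh.primitives import Box

    start, end = np.array([0.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0])

    # near-degenerate triangle standing in for the edge
    edge_tri = trimesh.Trimesh(vertices=[start, end, end + 2e-4], faces=[[0, 1, 2]])

    # small cubes standing in for the two end nodes
    cubes = []
    for node in (start, end):
        cube = Box(extents=[0.002, 0.002, 0.002])
        cube.apply_translation(node)
        cubes.append(cube)

    scene = trimesh.Scene()
    scene.add_geometry(edge_tri)
    for cube in cubes:
        scene.add_geometry(cube)
    scene.export("toy_graph.ply", file_type="ply")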
From c0e7529d64ed5ebf46c1c9205c05037fb90e22da Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 28 May 2024 14:22:22 +0200
Subject: [PATCH 052/273] fixed math writing
---
neural_lam/data_config.yaml | 24 ++++++++++++------------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 140eb9b7..a4417a65 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -50,12 +50,12 @@ state: # Variables forecasted by the model
- V_10M
surface_units:
- "%"
- - r"$\mathrm{Pa}$"
- - r"$\mathrm{Pa}$"
- - r"$\mathrm{K}$"
- - r"$\mathrm{kg}/\mathrm{m}^2$"
- - r"$\mathrm{m}/\mathrm{s}$"
- - r"$\mathrm{m}/\mathrm{s}$"
+ - $\mathrm{Pa}$
+ - $\mathrm{Pa}$
+ - $\mathrm{K}$
+ - $\mathrm{kg}/\mathrm{m}^2$
+ - $\mathrm{m}/\mathrm{s}$
+ - $\mathrm{m}/\mathrm{s}$
atmosphere: # Variables with vertical levels
- PP
- QV
@@ -65,13 +65,13 @@ state: # Variables forecasted by the model
- V
- W
atmosphere_units:
- - r"$\mathrm{Pa}$"
- - r"$\mathrm{kg}/\mathrm{kg}$"
+ - $\mathrm{Pa}$
+ - $\mathrm{kg}/\mathrm{kg}$
- "%"
- - r"$\mathrm{K}$"
- - r"$\mathrm{m}/\mathrm{s}$"
- - r"$\mathrm{m}/\mathrm{s}$"
- - r"$\mathrm{Pa}/\mathrm{s}$"
+ - $\mathrm{K}$
+ - $\mathrm{m}/\mathrm{s}$
+ - $\mathrm{m}/\mathrm{s}$
+ - $\mathrm{Pa}/\mathrm{s}$
levels: # Levels to use for atmosphere variables
- 0
- 5
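The change matters because YAML reads the old entries as plain scalars, so the stored value was the literal text `r"$\mathrm{Pa}$"`, quotes and `r` prefix included, and those extra characters would leak into any plot label built from the config. With the wrapper removed, the value is pure matplotlib mathtext. A small sketch of the difference, assuming the unit strings end up as axis labels (the label text here is hypothetical):

    import matplotlib.pyplot as plt

    unit = r"$\mathrm{Pa}$"          # value as now stored in data_config.yaml
    old_unit = 'r"$\\mathrm{Pa}$"'   # literal value produced by the old wrapper

    fig, ax = plt.subplots()
    ax.set_ylabel(f"Pressure ({unit})")  # mathtext renders the unit as "Pa"
    # with old_unit, the stray r" and " would appear verbatim in the label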
From 5f538f909854b61e332a5938905f686a72a88349 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 28 May 2024 15:43:10 +0200
Subject: [PATCH 053/273] cherry-pick with main
---
.github/workflows/pre-commit.yml | 36 +++---
.pre-commit-config.yaml | 66 +++++------
CHANGELOG.md | 72 ++++++++++++
README.md | 5 +-
create_mesh.py | 19 ++-
neural_lam/config.py | 192 +++++++++++++++++++++++++++++++
neural_lam/models/ar_model.py | 5 +-
neural_lam/utils.py | 189 ------------------------------
neural_lam/weather_dataset.py | 4 +-
plot_graph.py | 18 +--
requirements.txt | 5 -
11 files changed, 329 insertions(+), 282 deletions(-)
create mode 100644 CHANGELOG.md
create mode 100644 neural_lam/config.py
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index a6ad84f1..dc519e5b 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -1,33 +1,25 @@
-name: Run pre-commit job
+name: lint
on:
- push:
+ # trigger on pushes to any branch, but not main
+ push:
+ branches-ignore:
+ - main
+ # and also on PRs to main
+ pull_request:
branches:
- - main
- pull_request:
- branches:
- - main
+ - main
jobs:
- pre-commit-job:
+ pre-commit-job:
runs-on: ubuntu-latest
- defaults:
- run:
- shell: bash -l {0}
+ strategy:
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: 3.9
- - name: Install pre-commit hooks
- run: |
- pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \
- --index-url https://download.pytorch.org/whl/cpu
- pip install -r requirements.txt
- pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \
- torch-cluster==1.6.1 torch-geometric==2.3.1 \
- -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
- - name: Run pre-commit hooks
- run: |
- pre-commit run --all-files
+ python-version: ${{ matrix.python-version }}
+ - uses: pre-commit/action@v2.0.3
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f48eca67..815a92e1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,51 +1,37 @@
repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
+ - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- - id: check-ast
- - id: check-case-conflict
- - id: check-docstring-first
- - id: check-symlinks
- - id: check-toml
- - id: check-yaml
- - id: debug-statements
- - id: end-of-file-fixer
- - id: trailing-whitespace
-- repo: local
+ - id: check-ast
+ - id: check-case-conflict
+ - id: check-docstring-first
+ - id: check-symlinks
+ - id: check-toml
+ - id: check-yaml
+ - id: debug-statements
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.2.6
hooks:
- - id: codespell
- name: codespell
+ - id: codespell
description: Check for spelling errors
- language: system
- entry: codespell
-- repo: local
+
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
hooks:
- - id: black
- name: black
+ - id: black
description: Format Python code
- language: system
- entry: black
- types_or: [python, pyi]
-- repo: local
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
hooks:
- - id: isort
- name: isort
+ - id: isort
description: Group and sort Python imports
- language: system
- entry: isort
- types_or: [python, pyi, cython]
-- repo: local
+
+ - repo: https://github.com/PyCQA/flake8
+ rev: 7.0.0
hooks:
- - id: flake8
- name: flake8
+ - id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
- language: system
- entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503
- types: [python]
-- repo: local
- hooks:
- - id: pylint
- name: pylint
- entry: pylint -rn -sn
- language: system
- types: [python]
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 00000000..63feff96
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,72 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)
+
+### Added
+
+- Replaced `constants.py` with `data_config.yaml` for data configuration management
+ [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+ @sadamov
+
+- new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option
+ [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+ @joeloskarsson
+
+- ability to "watch" metrics and log
+ [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+ @joeloskarsson
+
+- pre-commit setup for linting and formatting
+ [\#6](https://github.com/joeloskarsson/neural-lam/pull/6), [\#8](https://github.com/joeloskarsson/neural-lam/pull/8)
+ @sadamov, @joeloskarsson
+
+### Changed
+
+- Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
+ [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+ @sadamov
+
+- Added new flags in `train_model.py` for configuration previously in `constants.py`
+ [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+ @sadamov
+
+- moved batch-static features ("water cover") into forcing component returned by `WeatherDataset`
+ [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
+ @joeloskarsson
+
+- change validation metric from `mae` to `rmse`
+ [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
+ @joeloskarsson
+
+- change RMSE definition to compute sqrt after all averaging
+ [\#10](https://github.com/joeloskarsson/neural-lam/pull/10)
+ @joeloskarsson
+
+### Removed
+
+- `WeatherDataset(torch.Dataset)` no longer returns "batch-static" component of
+ training item (only `prev_state`, `target_state` and `forcing`), the batch static features are
+ instead included in forcing
+ [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
+ @joeloskarsson
+
+### Maintenance
+
+- simplify pre-commit setup by 1) reducing linting to only cover static
+ analysis excluding imports from external dependencies (this will be handled
+ in build/test cicd action introduced later), 2) pinning versions of linting
+ tools in pre-commit config (and remove from `requirements.txt`) and 3) using
+ github action to run pre-commit.
+ [\#29](https://github.com/mllam/neural-lam/pull/29)
+ @leifdenby
+
+
+## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0)
+
+First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication
+()
diff --git a/README.md b/README.md
index 67d9d9b1..ba0bb3fe 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:
## A note on the limited area setting
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
There are still some parts of the code that are quite specific to the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
+This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in the `data_config.yaml` file (path specified via `train_model.py --data_config`).
If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
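The `--data_config` flag mentioned in the hunk above is consumed by the individual scripts; a minimal sketch of that wiring (mirroring the `create_mesh.py` change further down in this patch — the default path comes from this series, the rest is illustrative only):

    # Hedged sketch: how a script consumes --data_config
    from argparse import ArgumentParser

    from neural_lam import config

    parser = ArgumentParser(description="Example script")
    parser.add_argument(
        "--data_config",
        type=str,
        default="neural_lam/data_config.yaml",
        help="Path to data config file",
    )
    args = parser.parse_args()
    config_loader = config.Config(args.data_config)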
@@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`.
### Create remaining static features
To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
-The main option to set for these is just which dataset to use.
## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`.
+The W&B project name is set to `neural-lam`, but this can be changed through the flags of `train_model.py` (using argparse).
See the [W&B documentation](https://docs.wandb.ai/) for details.
If you would like to login and use W&B, run:
diff --git a/create_mesh.py b/create_mesh.py
index 2b6af9fd..da881594 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -13,9 +13,7 @@
from torch_geometric.utils.convert import from_networkx
# First-party
-from neural_lam import utils
-
-# matplotlib.use('TkAgg')
+from neural_lam import config
def plot_graph(graph, title=None):
@@ -157,6 +155,12 @@ def prepend_node_index(graph, new_index):
def main():
parser = ArgumentParser(description="Graph generation arguments")
+ parser.add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
+ )
parser.add_argument(
"--graph",
type=str,
@@ -182,20 +186,13 @@ def main():
default=0,
help="Generate hierarchical mesh graph (default: 0, no)",
)
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
-
args = parser.parse_args()
# Load grid positions
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
- config_loader = utils.ConfigLoader(args.data_config)
+ config_loader = config.Config(args.data_config)
xy = config_loader.get_nwp_xy()
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
diff --git a/neural_lam/config.py b/neural_lam/config.py
new file mode 100644
index 00000000..819ce2aa
--- /dev/null
+++ b/neural_lam/config.py
@@ -0,0 +1,192 @@
+
+import os
+
+import cartopy.crs as ccrs
+import numpy as np
+import xarray as xr
+import yaml
+
+
+class Config:
+ """
+ Class for loading configuration files.
+
+ This class loads a YAML configuration file and provides a way to access
+ its values as attributes.
+ """
+
+ def __init__(self, config_path, values=None):
+ self.config_path = config_path
+ if values is None:
+ self.values = self.load_config()
+ else:
+ self.values = values
+
+ def load_config(self):
+ """Load configuration file."""
+ with open(self.config_path, encoding="utf-8", mode="r") as file:
+ return yaml.safe_load(file)
+
+ def __getattr__(self, name):
+ keys = name.split(".")
+ value = self.values
+ for key in keys:
+ if key in value:
+ value = value[key]
+ else:
+ return None
+ if isinstance(value, dict):
+ return Config(None, values=value)
+ return value
+
+ def __getitem__(self, key):
+ value = self.values[key]
+ if isinstance(value, dict):
+ return Config(None, values=value)
+ return value
+
+ def __contains__(self, key):
+ return key in self.values
+
+ def param_names(self):
+ """Return parameter names."""
+ surface_names = self.values["state"]["surface"]
+ atmosphere_names = [
+ f"{var}_{level}"
+ for var in self.values["state"]["atmosphere"]
+ for level in self.values["state"]["levels"]
+ ]
+ return surface_names + atmosphere_names
+
+ def param_units(self):
+ """Return parameter units."""
+ surface_units = self.values["state"]["surface_units"]
+ atmosphere_units = [
+ unit
+ for unit in self.values["state"]["atmosphere_units"]
+ for _ in self.values["state"]["levels"]
+ ]
+ return surface_units + atmosphere_units
+
+ def num_data_vars(self, key):
+ """Return the number of data variables for a given key."""
+ surface_vars = len(self.values[key]["surface"])
+ atmosphere_vars = len(self.values[key]["atmosphere"])
+ levels = len(self.values[key]["levels"])
+ return surface_vars + atmosphere_vars * levels
+
+ def projection(self):
+ """Return the projection."""
+ proj_config = self.values["projections"]["class"]
+ proj_class = getattr(ccrs, proj_config["proj_class"])
+ proj_params = proj_config["proj_params"]
+ return proj_class(**proj_params)
+
+ def open_zarr(self, dataset_name):
+ """Open a dataset specified by the dataset name."""
+ dataset_path = self.zarrs[dataset_name].path
+ if dataset_path is None or not os.path.exists(dataset_path):
+ print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
+ return None
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+ return dataset
+
+ def load_normalization_stats(self):
+ """Load normalization statistics from Zarr archive."""
+ normalization_path = self.normalization.zarr
+ if not os.path.exists(normalization_path):
+ print(
+ f"Normalization statistics not found at "
+ f"path: {normalization_path}"
+ )
+ return None
+ normalization_stats = xr.open_zarr(
+ normalization_path, consolidated=True
+ )
+ return normalization_stats
+
+ def process_dataset(self, dataset_name, split="train", stack=True):
+ """Process a single dataset specified by the dataset name."""
+
+ dataset = self.open_zarr(dataset_name)
+ if dataset is None:
+ return None
+
+ start, end = (
+ self.splits[split].start,
+ self.splits[split].end,
+ )
+ dataset = dataset.sel(time=slice(start, end))
+ dataset = dataset.rename_dims(
+ {
+ v: k
+ for k, v in self.zarrs[dataset_name].dims.values.items()
+ if k not in dataset.dims
+ }
+ )
+
+ vars_surface = []
+ if self[dataset_name].surface:
+ vars_surface = dataset[self[dataset_name].surface]
+
+ vars_atmosphere = []
+ if self[dataset_name].atmosphere:
+ vars_atmosphere = xr.merge(
+ [
+ dataset[var]
+ .sel(level=level, drop=True)
+ .rename(f"{var}_{level}")
+ for var in self[dataset_name].atmosphere
+ for level in self[dataset_name].levels
+ ]
+ )
+
+ if vars_surface and vars_atmosphere:
+ dataset = xr.merge([vars_surface, vars_atmosphere])
+ elif vars_surface:
+ dataset = vars_surface
+ elif vars_atmosphere:
+ dataset = vars_atmosphere
+ else:
+ print(f"No variables found in dataset {dataset_name}")
+ return None
+
+ if not all(
+ lat_lon in self.zarrs[dataset_name].dims.values.values()
+ for lat_lon in self.zarrs[
+ dataset_name
+ ].lat_lon_names.values.values()
+ ):
+ lat_name = self.zarrs[dataset_name].lat_lon_names.lat
+ lon_name = self.zarrs[dataset_name].lat_lon_names.lon
+ if dataset[lat_name].ndim == 2:
+ dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
+ if dataset[lon_name].ndim == 2:
+ dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True)
+ dataset = dataset.assign_coords(
+ x=dataset[lon_name], y=dataset[lat_name]
+ )
+
+ if stack:
+ dataset = self.stack_grid(dataset)
+
+ return dataset
+
+ def stack_grid(self, dataset):
+ """Stack grid dimensions."""
+ dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+
+ if "time" in dataset.dims:
+ dataset = dataset.transpose("time", "grid", "variable")
+ else:
+ dataset = dataset.transpose("grid", "variable")
+ return dataset
+
+ def get_nwp_xy(self):
+ """Get the x and y coordinates for the NWP grid."""
+ x = self.process_dataset("static", stack=False).x.values
+ y = self.process_dataset("static", stack=False).y.values
+ xx, yy = np.meshgrid(y, x)
+ xy = np.stack((xx, yy), axis=0)
+
+ return xy
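A quick, hedged usage sketch of the `Config` class introduced above (it assumes the zarr paths in `neural_lam/data_config.yaml` point at existing archives; nothing below is added to the repository by this patch):

    from neural_lam import config

    # Load the YAML file that scripts receive via --data_config
    config_loader = config.Config("neural_lam/data_config.yaml")

    # Nested mappings are wrapped in Config again, so values can be reached
    # by attribute access or by indexing
    state_zarr_path = config_loader.zarrs["state"].path

    # Flattened variable names/units (surface + one entry per atmosphere level)
    names = config_loader.param_names()
    units = config_loader.param_units()

    # Total number of state variables: surface + atmosphere * levels
    n_state_vars = config_loader.num_data_vars("state")

    # Open, subset by split, rename dims and stack to (time, grid, variable)
    state = config_loader.process_dataset("state", split="train")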
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index f49eb094..fff28632 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,10 +6,11 @@
import numpy as np
import pytorch_lightning as pl
import torch
+
import wandb
# First-party
-from neural_lam import metrics, utils, vis
+from neural_lam import metrics, vis
class ARModel(pl.LightningModule):
@@ -25,7 +26,7 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.args = args
- self.config_loader = utils.ConfigLoader(args.data_config)
+ self.config_loader = config.Config(args.data_config)
# Load static features for grid/data
static = self.config_loader.process_dataset("static")
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 96e1549e..18584d2e 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -2,11 +2,7 @@
import os
# Third-party
-import cartopy.crs as ccrs
-import numpy as np
import torch
-import xarray as xr
-import yaml
from torch import nn
from tueplots import bundles, figsizes
@@ -197,188 +193,3 @@ def init_wandb_metrics(wandb_logger, val_steps):
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")
-
-
-class ConfigLoader:
- """
- Class for loading configuration files.
-
- This class loads a YAML configuration file and provides a way to access
- its values as attributes.
- """
-
- def __init__(self, config_path, values=None):
- self.config_path = config_path
- if values is None:
- self.values = self.load_config()
- else:
- self.values = values
-
- def load_config(self):
- """Load configuration file."""
- with open(self.config_path, encoding="utf-8", mode="r") as file:
- return yaml.safe_load(file)
-
- def __getattr__(self, name):
- keys = name.split(".")
- value = self.values
- for key in keys:
- if key in value:
- value = value[key]
- else:
- return None
- if isinstance(value, dict):
- return ConfigLoader(None, values=value)
- return value
-
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return ConfigLoader(None, values=value)
- return value
-
- def __contains__(self, key):
- return key in self.values
-
- def param_names(self):
- """Return parameter names."""
- surface_names = self.values["state"]["surface"]
- atmosphere_names = [
- f"{var}_{level}"
- for var in self.values["state"]["atmosphere"]
- for level in self.values["state"]["levels"]
- ]
- return surface_names + atmosphere_names
-
- def param_units(self):
- """Return parameter units."""
- surface_units = self.values["state"]["surface_units"]
- atmosphere_units = [
- unit
- for unit in self.values["state"]["atmosphere_units"]
- for _ in self.values["state"]["levels"]
- ]
- return surface_units + atmosphere_units
-
- def num_data_vars(self, key):
- """Return the number of data variables for a given key."""
- surface_vars = len(self.values[key]["surface"])
- atmosphere_vars = len(self.values[key]["atmosphere"])
- levels = len(self.values[key]["levels"])
- return surface_vars + atmosphere_vars * levels
-
- def projection(self):
- """Return the projection."""
- proj_config = self.values["projections"]["class"]
- proj_class = getattr(ccrs, proj_config["proj_class"])
- proj_params = proj_config["proj_params"]
- return proj_class(**proj_params)
-
- def open_zarr(self, dataset_name):
- """Open a dataset specified by the dataset name."""
- dataset_path = self.zarrs[dataset_name].path
- if dataset_path is None or not os.path.exists(dataset_path):
- print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
- return None
- dataset = xr.open_zarr(dataset_path, consolidated=True)
- return dataset
-
- def load_normalization_stats(self):
- """Load normalization statistics from Zarr archive."""
- normalization_path = self.normalization.zarr
- if not os.path.exists(normalization_path):
- print(
- f"Normalization statistics not found at "
- f"path: {normalization_path}"
- )
- return None
- normalization_stats = xr.open_zarr(
- normalization_path, consolidated=True
- )
- return normalization_stats
-
- def process_dataset(self, dataset_name, split="train", stack=True):
- """Process a single dataset specified by the dataset name."""
-
- dataset = self.open_zarr(dataset_name)
- if dataset is None:
- return None
-
- start, end = (
- self.splits[split].start,
- self.splits[split].end,
- )
- dataset = dataset.sel(time=slice(start, end))
- dataset = dataset.rename_dims(
- {
- v: k
- for k, v in self.zarrs[dataset_name].dims.values.items()
- if k not in dataset.dims
- }
- )
-
- vars_surface = []
- if self[dataset_name].surface:
- vars_surface = dataset[self[dataset_name].surface]
-
- vars_atmosphere = []
- if self[dataset_name].atmosphere:
- vars_atmosphere = xr.merge(
- [
- dataset[var]
- .sel(level=level, drop=True)
- .rename(f"{var}_{level}")
- for var in self[dataset_name].atmosphere
- for level in self[dataset_name].levels
- ]
- )
-
- if vars_surface and vars_atmosphere:
- dataset = xr.merge([vars_surface, vars_atmosphere])
- elif vars_surface:
- dataset = vars_surface
- elif vars_atmosphere:
- dataset = vars_atmosphere
- else:
- print(f"No variables found in dataset {dataset_name}")
- return None
-
- if not all(
- lat_lon in self.zarrs[dataset_name].dims.values.values()
- for lat_lon in self.zarrs[
- dataset_name
- ].lat_lon_names.values.values()
- ):
- lat_name = self.zarrs[dataset_name].lat_lon_names.lat
- lon_name = self.zarrs[dataset_name].lat_lon_names.lon
- if dataset[lat_name].ndim == 2:
- dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
- if dataset[lon_name].ndim == 2:
- dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True)
- dataset = dataset.assign_coords(
- x=dataset[lon_name], y=dataset[lat_name]
- )
-
- if stack:
- dataset = self.stack_grid(dataset)
-
- return dataset
-
- def stack_grid(self, dataset):
- """Stack grid dimensions."""
- dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
-
- if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid", "variable")
- else:
- dataset = dataset.transpose("grid", "variable")
- return dataset
-
- def get_nwp_xy(self):
- """Get the x and y coordinates for the NWP grid."""
- x = self.process_dataset("static", stack=False).x.values
- y = self.process_dataset("static", stack=False).y.values
- xx, yy = np.meshgrid(y, x)
- xy = np.stack((xx, yy), axis=0)
-
- return xy
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 4b5da0a8..6ce630c7 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -3,7 +3,7 @@
import torch
# First-party
-from neural_lam import utils
+from neural_lam import config
class WeatherDataset(torch.utils.data.Dataset):
@@ -35,7 +35,7 @@ def __init__(
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.config_loader = utils.ConfigLoader(data_config)
+ self.config_loader = config.Config(data_config)
self.state = self.config_loader.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
diff --git a/plot_graph.py b/plot_graph.py
index 50c54e06..9b465fd4 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -76,8 +76,16 @@ def export_to_3d_model(node_pos, edge_plot_list, filename):
def main():
- """Plot the graph."""
+ """
+ Plot graph structure in 3D using plotly
+ """
parser = ArgumentParser(description="Plot graph")
+ parser.add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
+ )
parser.add_argument(
"--graph",
type=str,
@@ -95,12 +103,6 @@ def main():
default=0,
help="If the axis should be displayed (default: 0 (No))",
)
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
parser.add_argument(
"--export",
type=str,
@@ -121,7 +123,7 @@ def main():
)
mesh_static_features = graph_ldict["mesh_static_features"]
- config_loader = utils.ConfigLoader(args.data_config)
+ config_loader = config.Config(args.data_config)
xy = config_loader.get_nwp_xy()
grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2)
pos_max = np.max(np.abs(grid_xy))
diff --git a/requirements.txt b/requirements.txt
index cb9bd425..70b97330 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,9 +14,4 @@ xarray>=0.20.1
zarr>=2.10.0
dask>=2022.0.0
# for dev
-codespell>=2.0.0
-black>=21.9b0
-isort>=5.9.3
-flake8>=4.0.1
-pylint>=3.0.3
pre-commit>=2.15.0
From 6685e94fea1fa84631cbadb884d16f4472451437 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 28 May 2024 16:10:57 +0200
Subject: [PATCH 054/273] bugfixes
---
neural_lam/config.py | 8 ++-
neural_lam/models/ar_model.py | 80 +++++++++++-------------
neural_lam/models/base_hi_graph_model.py | 12 ++--
plot_graph.py | 2 +-
4 files changed, 50 insertions(+), 52 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 819ce2aa..048241df 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -1,6 +1,7 @@
-
+# Standard library
import os
+# Third-party
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
@@ -86,7 +87,10 @@ def open_zarr(self, dataset_name):
"""Open a dataset specified by the dataset name."""
dataset_path = self.zarrs[dataset_name].path
if dataset_path is None or not os.path.exists(dataset_path):
- print(f"Dataset '{dataset_name}' not found at path: {dataset_path}")
+ print(
+ f"Dataset '{dataset_name}' "
+ f"not found at path: {dataset_path}"
+ )
return None
dataset = xr.open_zarr(dataset_path, consolidated=True)
return dataset
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index fff28632..fc78e638 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,21 +6,19 @@
import numpy as np
import pytorch_lightning as pl
import torch
-
import wandb
# First-party
-from neural_lam import metrics, vis
+from neural_lam import config, metrics, vis
class ARModel(pl.LightningModule):
"""
- Generic auto-regressive weather model.
- Abstract class that can be extended.
+ Generic auto-regressive weather model. Abstract class that can be extended.
"""
- # pylint: disable=arguments-differ
- # Disable to override args/kwargs from superclass
+ # pylint: disable=arguments-differ Disable to override args/kwargs from
+ # superclass
def __init__(self, args):
super().__init__()
@@ -127,18 +125,18 @@ def expand_to_batch(x, batch_size):
def predict_step(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
- prev_state: (B, num_grid_nodes, feature_dim), X_t
- prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
- forcing: (B, num_grid_nodes, forcing_dim)
+ prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B,
+ num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes,
+ forcing_dim)
"""
raise NotImplementedError("No prediction step implemented")
def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
- init_states: (B, 2, num_grid_nodes, d_f)
- forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
- true_states: (B, pred_steps, num_grid_nodes, d_f)
+ init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
+ pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
+ num_grid_nodes, d_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
@@ -153,8 +151,8 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
)
- # state: (B, num_grid_nodes, d_f)
- # pred_std: (B, num_grid_nodes, d_f) or None
+ # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
+ # d_f) or None
# Overwrite border with true state
new_state = (
@@ -184,11 +182,10 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
def common_step(self, batch):
"""
- Predict on single batch
- batch consists of:
- init_states: (B, 2, num_grid_nodes, d_features)
- target_states: (B, pred_steps, num_grid_nodes, d_features)
- forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
+ Predict on single batch batch consists of: init_states: (B, 2,
+ num_grid_nodes, d_features) target_states: (B, pred_steps,
+ num_grid_nodes, d_features) forcing_features: (B, pred_steps,
+ num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
(
@@ -200,8 +197,8 @@ def common_step(self, batch):
prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
) # (B, pred_steps, num_grid_nodes, d_f)
- # prediction: (B, pred_steps, num_grid_nodes, d_f)
- # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
+ # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
+ # pred_steps, num_grid_nodes, d_f) or (d_f,)
return prediction, target_states, pred_std
@@ -214,9 +211,8 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
forcing_features = (
forcing_features - self.forcing_mean
) / self.forcing_std
- # boundary_features = (
- # boundary_features - self.boundary_mean
- # ) / self.boundary_std
+ # boundary_features = ( boundary_features - self.boundary_mean ) /
+ # self.boundary_std
batch = (
init_states,
target_states,
@@ -246,8 +242,8 @@ def training_step(self, batch):
def all_gather_cat(self, tensor_to_gather):
"""
- Gather tensors across all ranks, and concatenate across dim. 0
- (instead of stacking in new dim. 0)
+ Gather tensors across all ranks, and concatenate across dim. 0 (instead
+ of stacking in new dim. 0)
tensor_to_gather: (d1, d2, ...), distributed over K ranks
@@ -308,8 +304,8 @@ def test_step(self, batch, batch_idx):
Run test on single batch
"""
prediction, target, pred_std = self.common_step(batch)
- # prediction: (B, pred_steps, num_grid_nodes, d_f)
- # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
+ # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
+ # pred_steps, num_grid_nodes, d_f) or (d_f,)
time_step_loss = torch.mean(
self.loss(
@@ -330,10 +326,9 @@ def test_step(self, batch, batch_idx):
test_log_dict, on_step=False, on_epoch=True, sync_dist=True
)
- # Compute all evaluation metrics for error maps
- # Note: explicitly list metrics here, as test_metrics can contain
- # additional ones, computed differently, but that should be aggregated
- # on_test_epoch_end
+ # Compute all evaluation metrics for error maps Note: explicitly list
+ # metrics here, as test_metrics can contain additional ones, computed
+ # differently, but that should be aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
@@ -378,9 +373,9 @@ def plot_examples(self, batch, n_examples, prediction=None):
"""
Plot the first n_examples forecasts from batch
- batch: batch with data to plot corresponding forecasts for
- n_examples: number of forecasts to plot
- prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
+ batch: batch with data to plot corresponding forecasts for n_examples:
+ number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes,
+ d_f), existing prediction.
Generate if None.
"""
if prediction is None:
@@ -470,15 +465,14 @@ def plot_examples(self, batch, n_examples, prediction=None):
def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
- Put together a dict with everything to log for one metric.
- Also saves plots as pdf and csv if using test prefix.
+ Put together a dict with everything to log for one metric. Also saves
+ plots as pdf and csv if using test prefix.
metric_tensor: (pred_steps, d_f), metric values per time and variable
- prefix: string, prefix to use for logging
- metric_name: string, name of the metric
+ prefix: string, prefix to use for logging metric_name: string, name of
+ the metric
- Return:
- log_dict: dict with everything to log for given metric
+ Return: log_dict: dict with everything to log for given metric
"""
log_dict = {}
metric_fig = vis.plot_error_map(
@@ -552,8 +546,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
def on_test_epoch_end(self):
"""
- Compute test metrics and make plots at the end of test epoch.
- Will gather stored tensors and perform plotting and logging on rank 0.
+ Compute test metrics and make plots at the end of test epoch. Will
+ gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 8ce87030..d9a4c676 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -179,9 +179,9 @@ def process_step(self, mesh_rep):
)
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = (
- new_node_rep # (B, num_mesh_nodes[l], d_h)
- )
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h)
# - PROCESSOR -
@@ -207,9 +207,9 @@ def process_step(self, mesh_rep):
new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = (
- new_node_rep # (B, num_mesh_nodes[l], d_h)
- )
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
# Return only bottom level representation
return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h)
diff --git a/plot_graph.py b/plot_graph.py
index 9b465fd4..d6f297d0 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -10,7 +10,7 @@
from trimesh.primitives import Box
# First-party
-from neural_lam import utils
+from neural_lam import config, utils
MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.05
From 6423fdf5d6cd3e2e76def5b1f309f12c0872e4b4 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 28 May 2024 16:31:11 +0200
Subject: [PATCH 055/273] pre_commits
---
create_mesh.py | 10 ++++++++--
neural_lam/models/ar_model.py | 25 +++++++++++++++++++------
neural_lam/models/base_graph_model.py | 9 ++++-----
neural_lam/models/graph_lam.py | 4 +++-
neural_lam/models/hi_lam.py | 19 +++++++++++--------
neural_lam/models/hi_lam_parallel.py | 11 +++++++----
neural_lam/utils.py | 7 +++++--
neural_lam/weather_dataset.py | 13 ++++++++++---
pyproject.toml | 4 ++--
9 files changed, 69 insertions(+), 33 deletions(-)
diff --git a/create_mesh.py b/create_mesh.py
index da881594..04d7468b 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -125,7 +125,11 @@ def mk_2d_graph(xy, nx, ny):
# add diagonal edges
g.add_edges_from(
- [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+ [
+ ((x, y), (x + 1, y + 1))
+ for x in range(nx - 1)
+ for y in range(ny - 1)
+ ]
+ [
((x + 1, y), (x, y + 1))
for x in range(nx - 1)
@@ -343,7 +347,9 @@ def main():
.reshape(int(n / nx) ** 2, 2)
)
ij = [tuple(x) for x in ij]
- G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij)))
+ G[lev] = networkx.relabel_nodes(
+ G[lev], dict(zip(G[lev].nodes, ij))
+ )
G_tot = networkx.compose(G_tot, G[lev])
# Relabel mesh nodes to start with 0
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index fc78e638..d29f84ec 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -38,7 +38,9 @@ def __init__(self, args):
self.output_std = bool(args.output_std)
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.config_loader.num_data_vars("state")
+ self.grid_output_dim = 2 * self.config_loader.num_data_vars(
+ "state"
+ )
else:
# Pred. dim. in grid cell
self.grid_output_dim = self.config_loader.num_data_vars("state")
@@ -87,7 +89,9 @@ def __init__(self, args):
self.spatial_loss_maps = []
# Load normalization statistics
- self.normalization_stats = self.config_loader.load_normalization_stats()
+ self.normalization_stats = (
+ self.config_loader.load_normalization_stats()
+ )
if self.normalization_stats is not None:
for (
var_name,
@@ -236,7 +240,11 @@ def training_step(self, batch):
log_dict = {"train_loss": batch_loss}
self.log_dict(
- log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
+ log_dict,
+ prog_bar=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=True,
)
return batch_loss
@@ -362,7 +370,8 @@ def test_step(self, batch, batch_idx):
):
# Need to plot more example predictions
n_additional_examples = min(
- prediction.shape[0], self.n_example_pred - self.plotted_examples
+ prediction.shape[0],
+ self.n_example_pred - self.plotted_examples,
)
self.plot_examples(
@@ -584,10 +593,14 @@ def on_test_epoch_end(self):
)
for loss_map in mean_spatial_loss
]
- pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
+ pdf_loss_maps_dir = os.path.join(
+ wandb.run.dir, "spatial_loss_maps"
+ )
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
- fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
+ fig.savefig(
+ os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")
+ )
# save mean spatial loss as .pt file also
torch.save(
mean_spatial_loss.cpu(),
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index fb5df62d..723a3f3c 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -118,8 +118,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
dim=-1,
)
- # Embed all features
- grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h)
+ # Embed all features # (B, num_grid_nodes, d_h)
+ grid_emb = self.grid_embedder(grid_features)
g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h)
m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h)
mesh_emb = self.embedd_mesh_nodes()
@@ -149,9 +149,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
) # (B, num_grid_nodes, d_h)
# Map to output dimension, only for grid
- net_output = self.output_map(
- grid_rep
- ) # (B, num_grid_nodes, d_grid_out)
+ # (B, num_grid_nodes, d_grid_out)
+ net_output = self.output_map(grid_rep)
if self.output_std:
pred_delta_mean, pred_std_raw = net_output.chunk(
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index f767fba0..e4dc74ac 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -32,7 +32,9 @@ def __init__(self, args):
# Define sub-models
# Feature embedders for mesh
- self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
+ self.mesh_embedder = utils.make_mlp(
+ [mesh_dim] + self.mlp_blueprint_end
+ )
self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end)
# GNNs
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 4d7eb94c..335ea8c7 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -101,9 +101,8 @@ def mesh_down_step(
reversed(same_gnns[:-1]),
):
# Extract representations
- send_node_rep = mesh_rep_levels[
- level_l + 1
- ] # (B, N_mesh[l+1], d_h)
+ # (B, N_mesh[l+1], d_h)
+ send_node_rep = mesh_rep_levels[level_l + 1]
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
down_edge_rep = mesh_down_rep[level_l]
same_edge_rep = mesh_same_rep[level_l]
@@ -139,9 +138,8 @@ def mesh_up_step(
zip(up_gnns, same_gnns[1:]), start=1
):
# Extract representations
- send_node_rep = mesh_rep_levels[
- level_l - 1
- ] # (B, N_mesh[l-1], d_h)
+ # (B, N_mesh[l-1], d_h)
+ send_node_rep = mesh_rep_levels[level_l - 1]
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
up_edge_rep = mesh_up_rep[level_l - 1]
same_edge_rep = mesh_same_rep[level_l]
@@ -183,7 +181,11 @@ def hi_processor_step(
self.mesh_up_same_gnns,
):
# Down
- mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step(
+ (
+ mesh_rep_levels,
+ mesh_same_rep,
+ mesh_down_rep,
+ ) = self.mesh_down_step(
mesh_rep_levels,
mesh_same_rep,
mesh_down_rep,
@@ -200,5 +202,6 @@ def hi_processor_step(
up_same_gnns,
)
- # Note: We return all, even though only down edges really are used later
+ # Note: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 740824e1..b6f619d1 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -27,7 +27,9 @@ def __init__(self, args):
+ list(self.mesh_down_edge_index)
)
total_edge_index = torch.cat(total_edge_index_list, dim=1)
- self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list]
+ self.edge_split_sections = [
+ ei.shape[1] for ei in total_edge_index_list
+ ]
if args.processor_layers == 0:
self.processor = lambda x, edge_attr: (x, edge_attr)
@@ -86,11 +88,12 @@ def hi_processor_step(
mesh_same_rep = mesh_edge_rep_sections[: self.num_levels]
mesh_up_rep = mesh_edge_rep_sections[
- self.num_levels : self.num_levels + (self.num_levels - 1)
+ self.num_levels : self.num_levels + (self.num_levels - 1) # noqa
]
mesh_down_rep = mesh_edge_rep_sections[
- self.num_levels + (self.num_levels - 1) :
+ self.num_levels + (self.num_levels - 1) : # noqa
] # Last are down edges
- # Note: We return all, even though only down edges really are used later
+ # Note: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 18584d2e..f7ecafb3 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -40,7 +40,9 @@ def load_graph(graph_name, device="cpu"):
graph_dir_path = os.path.join("graphs", graph_name)
def loads_file(fn):
- return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
+ return torch.load(
+ os.path.join(graph_dir_path, fn), map_location=device
+ )
# Load edges (edge_index)
m2m_edge_index = BufferList(
@@ -53,7 +55,8 @@ def loads_file(fn):
hierarchical = n_levels > 1 # Not just single level mesh graph
# Load static edge features
- m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f)
+ # List of (M_m2m[l], d_edge_f)
+ m2m_features = loads_file("m2m_features.pt")
g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f)
m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 6ce630c7..6762a450 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -39,7 +39,9 @@ def __init__(
self.state = self.config_loader.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
- self.forcing = self.config_loader.process_dataset("forcing", self.split)
+ self.forcing = self.config_loader.process_dataset(
+ "forcing", self.split
+ )
self.boundary = self.config_loader.process_dataset(
"boundary", self.split
)
@@ -69,7 +71,10 @@ def __init__(
method="nearest",
)
.pad(
- time=(self.boundary_window // 2, self.boundary_window // 2),
+ time=(
+ self.boundary_window // 2,
+ self.boundary_window // 2,
+ ),
mode="edge",
)
.rolling(time=self.boundary_window, center=True)
@@ -87,7 +92,9 @@ def __getitem__(self, idx):
)
forcing = (
- self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps))
+ self.forcing_windowed.isel(
+ time=slice(idx + 2, idx + self.ar_steps)
+ )
.stack(variable_window=("variable", "window"))
.values
if self.forcing is not None
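The windowing applied to the forcing (and boundary) data above follows a common xarray idiom: pad the time axis at both edges, build a centred rolling window, and later stack `variable` and `window` into one feature axis in `__getitem__`. A self-contained toy sketch of that idiom (shapes and variable names are made up; the hunk above is truncated after `.rolling(...)`, so the final `.construct("window")` step is an assumption about the surrounding code):

    import numpy as np
    import xarray as xr

    window = 3  # e.g. self.boundary_window
    ds = xr.Dataset(
        {
            "f1": (("time", "grid"), np.random.rand(10, 4)),
            "f2": (("time", "grid"), np.random.rand(10, 4)),
        }
    ).to_array()  # dims: (variable, time, grid)

    windowed = (
        ds.pad(time=(window // 2, window // 2), mode="edge")
        .rolling(time=window, center=True)
        .construct("window")
    )  # dims: (variable, time, grid, window)

    # Collapse variable and window into one feature axis, as done when
    # assembling a training sample
    sample = windowed.isel(time=slice(2, 6)).stack(
        variable_window=("variable", "window")
    )  # dims: (time, grid, variable_window)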
diff --git a/pyproject.toml b/pyproject.toml
index 619f444f..192afbc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,10 +6,10 @@ version = "0.1.0"
packages = ["neural_lam"]
[tool.black]
-line-length = 80
+line-length = 79
[tool.isort]
-default_section = "THIRDPARTY"
+default_section = "THIRDPARTY" #codespell:ignore
profile = "black"
# Headings
import_heading_stdlib = "Standard library"
From 4e457ed83ab283472e94a75ce05a8edf3590782c Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Fri, 31 May 2024 22:23:50 +0200
Subject: [PATCH 056/273] config.py is ready for danra
---
neural_lam/config.py | 207 +++++++++++++++++++++------------
neural_lam/data_config.yaml | 225 +++++++++++-------------------------
2 files changed, 200 insertions(+), 232 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index aa20030c..7df993d0 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -60,52 +60,109 @@ def coords_projection(self):
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
+ @functools.cached_property
def param_names(self):
"""Return parameter names."""
- surface_names = self.values["state"]["surface"]
- atmosphere_names = [
+ surface_vars_names = self.values["state"]["surface_vars"]
+ atmosphere_vars_names = [
f"{var}_{level}"
- for var in self.values["state"]["atmosphere"]
+ for var in self.values["state"]["atmosphere_vars"]
for level in self.values["state"]["levels"]
]
- return surface_names + atmosphere_names
+ return surface_vars_names + atmosphere_vars_names
+ @functools.cached_property
def param_units(self):
"""Return parameter units."""
- surface_units = self.values["state"]["surface_units"]
- atmosphere_units = [
+ surface_vars_units = self.values["state"]["surface_vars_units"]
+ atmosphere_vars_units = [
unit
- for unit in self.values["state"]["atmosphere_units"]
+ for unit in self.values["state"]["atmosphere_vars_units"]
for _ in self.values["state"]["levels"]
]
- return surface_units + atmosphere_units
+ return surface_vars_units + atmosphere_vars_units
- def num_data_vars(self, key):
- """Return the number of data variables for a given key."""
- surface_vars = len(self.values[key]["surface"])
- atmosphere_vars = len(self.values[key]["atmosphere"])
- levels = len(self.values[key]["levels"])
- return surface_vars + atmosphere_vars * levels
+ @functools.lru_cache()
+ def num_data_vars(self, category):
+ """Return the number of data variables for a given category."""
+ surface_vars = self.values[category].get("surface_vars", [])
+ atmosphere_vars = self.values[category].get("atmosphere_vars", [])
+ levels = self.values[category].get("levels", [])
- def projection(self):
- """Return the projection."""
- proj_config = self.values["projections"]["class"]
- proj_class = getattr(ccrs, proj_config["proj_class"])
- proj_params = proj_config["proj_params"]
- return proj_class(**proj_params)
+ surface_vars_count = (
+ len(surface_vars) if surface_vars is not None else 0
+ )
+ atmosphere_vars_count = (
+ len(atmosphere_vars) if atmosphere_vars is not None else 0
+ )
+ levels_count = len(levels) if levels is not None else 0
- def open_zarr(self, dataset_name):
- """Open a dataset specified by the dataset name."""
- dataset_path = self.zarrs[dataset_name].path
- if dataset_path is None or not os.path.exists(dataset_path):
- print(
- f"Dataset '{dataset_name}' "
- f"not found at path: {dataset_path}"
- )
- return None
- dataset = xr.open_zarr(dataset_path, consolidated=True)
+ return surface_vars_count + atmosphere_vars_count * levels_count
+
+ @functools.lru_cache(maxsize=None)
+ def open_zarr(self, category):
+ """Open a dataset specified by the category."""
+ zarr_config = self.zarrs[category]
+
+ if isinstance(zarr_config, list):
+ try:
+ datasets = []
+ for config in zarr_config:
+ dataset_path = config["path"]
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+ datasets.append(dataset)
+ return xr.merge(datasets)
+ except Exception:
+ print(f"Invalid zarr configuration for category: {category}")
+ return None
+
+ else:
+ try:
+ dataset_path = zarr_config["path"]
+ return xr.open_zarr(dataset_path, consolidated=True)
+ except Exception:
+ print(f"Invalid zarr configuration for category: {category}")
+ return None
+
+ def stack_grid(self, dataset):
+ """Stack grid dimensions."""
+ dims = dataset.to_array().dims
+
+ if "grid" not in dims and "x" in dims and "y" in dims:
+ dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+ else:
+ try:
+ dataset = dataset.squeeze().to_array()
+ except ValueError:
+ print("Failed to stack grid dimensions.")
+ return None
+
+ if "time" in dataset.dims:
+ dataset = dataset.transpose("time", "grid", "variable")
+ else:
+ dataset = dataset.transpose("grid", "variable")
return dataset
+ @functools.lru_cache()
+ def get_nwp_xy(self, category):
+ """Get the x and y coordinates for the NWP grid."""
+ dataset = self.open_zarr(category)
+ lon_name = self.zarrs[category].lat_lon_names.lon
+ lat_name = self.zarrs[category].lat_lon_names.lat
+ if lon_name in dataset and lat_name in dataset:
+ lon = dataset[lon_name].values
+ lat = dataset[lat_name].values
+ else:
+ raise ValueError(
+ f"Dataset does not contain " f"{lon_name} or {lat_name}"
+ )
+ if lon.ndim == 1:
+ lon, lat = np.meshgrid(lat, lon)
+ lonlat = np.stack((lon, lat), axis=0)
+
+ return lonlat
+
+ @functools.cached_property
def load_normalization_stats(self):
"""Load normalization statistics from Zarr archive."""
normalization_path = self.normalization.zarr
@@ -120,10 +177,11 @@ def load_normalization_stats(self):
)
return normalization_stats
- def process_dataset(self, dataset_name, split="train", stack=True):
+ @functools.lru_cache(maxsize=None)
+ def process_dataset(self, category, split="train"):
"""Process a single dataset specified by the dataset name."""
- dataset = self.open_zarr(dataset_name)
+ dataset = self.open_zarr(category)
if dataset is None:
return None
@@ -132,48 +190,64 @@ def process_dataset(self, dataset_name, split="train", stack=True):
self.splits[split].end,
)
dataset = dataset.sel(time=slice(start, end))
+
+ dims_mapping = {}
+ zarr_configs = self.zarrs[category]
+ if isinstance(zarr_configs, list):
+ for zarr_config in zarr_configs:
+ dims_mapping.update(zarr_config["dims"])
+ else:
+ dims_mapping.update(zarr_configs["dims"].values)
+
dataset = dataset.rename_dims(
{
v: k
- for k, v in self.zarrs[dataset_name].dims.values.items()
- if k not in dataset.dims
+ for k, v in dims_mapping.items()
+ if k not in dataset.dims and v in dataset.dims
}
)
+ dataset = dataset.rename_vars(
+ {v: k for k, v in dims_mapping.items() if v in dataset.coords}
+ )
- vars_surface = []
- if self[dataset_name].surface:
- vars_surface = dataset[self[dataset_name].surface]
+ surface_vars = []
+ if self[category].surface_vars:
+ surface_vars = dataset[self[category].surface_vars]
- vars_atmosphere = []
- if self[dataset_name].atmosphere:
- vars_atmosphere = xr.merge(
+ atmosphere_vars = []
+ if self[category].atmosphere_vars:
+ atmosphere_vars = xr.merge(
[
dataset[var]
.sel(level=level, drop=True)
.rename(f"{var}_{level}")
- for var in self[dataset_name].atmosphere
- for level in self[dataset_name].levels
+ for var in self[category].atmosphere_vars
+ for level in self[category].levels
]
)
- if vars_surface and vars_atmosphere:
- dataset = xr.merge([vars_surface, vars_atmosphere])
- elif vars_surface:
- dataset = vars_surface
- elif vars_atmosphere:
- dataset = vars_atmosphere
+ if surface_vars and atmosphere_vars:
+ dataset = xr.merge([surface_vars, atmosphere_vars])
+ elif surface_vars:
+ dataset = surface_vars
+ elif atmosphere_vars:
+ dataset = atmosphere_vars
else:
- print(f"No variables found in dataset {dataset_name}")
+ print(f"No variables found in dataset {category}")
return None
+ zarr_configs = self.zarrs[category]
+ lat_lon_names = {}
+ if isinstance(self.zarrs[category], list):
+ for zarr_configs in self.zarrs[category]:
+ lat_lon_names.update(zarr_configs["lat_lon_names"])
+ else:
+ lat_lon_names.update(self.zarrs[category]["lat_lon_names"].values)
+
if not all(
- lat_lon in self.zarrs[dataset_name].dims.values.values()
- for lat_lon in self.zarrs[
- dataset_name
- ].lat_lon_names.values.values()
+ lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names
):
- lat_name = self.zarrs[dataset_name].lat_lon_names.lat
- lon_name = self.zarrs[dataset_name].lat_lon_names.lon
+ lat_name, lon_name = lat_lon_names[:2]
if dataset[lat_name].ndim == 2:
dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
if dataset[lon_name].ndim == 2:
@@ -182,26 +256,15 @@ def process_dataset(self, dataset_name, split="train", stack=True):
x=dataset[lon_name], y=dataset[lat_name]
)
- if stack:
- dataset = self.stack_grid(dataset)
-
+ dataset = dataset.rename(
+ {v: k for k, v in dims_mapping.items() if v in dataset.coords}
+ )
+ dataset = self.stack_grid(dataset)
return dataset
- def stack_grid(self, dataset):
- """Stack grid dimensions."""
- dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+ dataset = self.stack_grid(dataset)
- if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid", "variable")
- else:
- dataset = dataset.transpose("grid", "variable")
return dataset
- def get_nwp_xy(self):
- """Get the x and y coordinates for the NWP grid."""
- x = self.process_dataset("static", stack=False).x.values
- y = self.process_dataset("static", stack=False).y.values
- xx, yy = np.meshgrid(y, x)
- xy = np.stack((xx, yy), axis=0)
- return xy
+config = Config.from_file("neural_lam/data_config.yaml")
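To make the counting logic above concrete, a small hedged sketch against the `danra` test configuration introduced in this same patch (3 surface variables, 3 atmosphere variables, 1 level; it assumes the default config path resolves, i.e. running from the repository root):

    from neural_lam import config

    c = config.Config.from_file("neural_lam/data_config.yaml")

    # surface_vars + atmosphere_vars * levels = 3 + 3 * 1
    assert c.num_data_vars("state") == 6

    # param_names is now a cached property (no call), giving e.g.
    # ["u10m", "v10m", "t2m", "u_100", "v_100", "t_100"]
    print(c.param_names)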
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index a4417a65..ff14a231 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -1,191 +1,96 @@
-zarrs: # List of zarrs containing fields related to state
+name: danra
+zarrs:
state:
- path: /scratch/sadamov/template.zarr # Path to zarr
- dims: # Name of dimensions in zarr, to be used for indexing
- time: time
- level: z
- x: x # Either give "grid" (flattened) dimension or "x" and "y"
- y: y
- lat_lon_names:
- lon: lon
- lat: lat
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+ dims:
+ time: time
+ level: altitude
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
static:
- path: /scratch/sadamov/template.zarr
+ path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
- level: z
+ level: null
x: x
y: y
+ grid: null
lat_lon_names:
lon: lon
lat: lat
forcing:
- path: /scratch/sadamov/template.zarr
+ path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
- level: z
+ level: null
x: x
y: y
+ grid: null
lat_lon_names:
lon: lon
lat: lat
- boundary:
- path: /scratch/sadamov/era5_template.zarr
- dims:
- time: time
- level: level
- x: longitude
- y: latitude
- lat_lon_names:
- lon: longitude
- lat: latitude
- mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.
-state: # Variables forecasted by the model
- surface: # Single-field variables
- - CLCT
- - PMSL
- - PS
- - T_2M
- - TOT_PREC
- - U_10M
- - V_10M
+state:
+ surface_vars:
+ - u10m
+ - v10m
+ - t2m
surface_units:
- - "%"
- - $\mathrm{Pa}$
- - $\mathrm{Pa}$
- - $\mathrm{K}$
- - $\mathrm{kg}/\mathrm{m}^2$
- - $\mathrm{m}/\mathrm{s}$
- - $\mathrm{m}/\mathrm{s}$
- atmosphere: # Variables with vertical levels
- - PP
- - QV
- - RELHUM
- - T
- - U
- - V
- - W
+ - m/s
+ - m/s
+ - K
+ atmosphere_vars:
+ - u
+ - v
+ - t
atmosphere_units:
- - $\mathrm{Pa}$
- - $\mathrm{kg}/\mathrm{kg}$
- - "%"
- - $\mathrm{K}$
- - $\mathrm{m}/\mathrm{s}$
- - $\mathrm{m}/\mathrm{s}$
- - $\mathrm{Pa}/\mathrm{s}$
- levels: # Levels to use for atmosphere variables
- - 0
- - 5
- - 8
- - 11
- - 13
- - 15
- - 19
- - 22
- - 26
- - 30
- - 38
- - 44
- - 59
-static: # Static inputs
- surface:
- - HSURF
- atmosphere:
- - FI
- levels:
- - 0
- - 5
- - 8
- - 11
- - 13
- - 15
- - 19
- - 22
- - 26
- - 30
- - 38
- - 44
- - 59
-forcing: # Forcing variables, dynamic inputs to the model
- surface:
- - ASOB_S
- atmosphere:
- - T
+ - m/s
+ - m/s
+ - K
levels:
- - 0
- - 5
- - 8
- - 11
- - 13
- - 15
- - 19
- - 22
- - 26
- - 30
- - 38
- - 44
- - 59
- window: 3 # Number of time steps to use for forcing (odd)
-boundary: # Boundary conditions
- surface:
- - 10m_u_component_of_wind
- # - 10m_v_component_of_wind
- # - 2m_dewpoint_temperature
- # - 2m_temperature
- # - mean_sea_level_pressure
- # - mean_surface_latent_heat_flux
- # - mean_surface_net_long_wave_radiation_flux
- # - mean_surface_net_short_wave_radiation_flux
- # - mean_surface_sensible_heat_flux
- # - surface_pressure
- # - total_cloud_cover
- # - total_column_water_vapour
- # - total_precipitation_12hr
- # - total_precipitation_24hr
- # - total_precipitation_6hr
- # - geopotential_at_surface
- atmosphere:
- - divergence
- # - geopotential
- # - relative_humidity
- # - specific_humidity
- # - temperature
- # - u_component_of_wind
- # - v_component_of_wind
- # - vertical_velocity
- # - vorticity
- levels:
- - 50
- 100
- - 150
- - 200
- - 250
- - 300
- - 400
- - 500
- - 600
- - 700
- - 850
- - 925
- - 1000
- window: 3 # Number of time steps to use for boundary (odd)
+static:
+ surface_vars:
+ - pres0m # just as a technical test
+ atmosphere_vars: null
+ levels: null
+forcing:
+ surface_vars:
+ - cape_column # just as a technical test
+ atmosphere_vars: null
+ levels: null
+ window: 3 # Number of time steps to use for forcing (odd)
grid_shape_state:
x: 582
y: 390
splits:
train:
- start: 2015-01-01T00
- end: 2024-12-31T23
+ start: 1990-09-01T00
+ end: 1990-09-11T00
val:
- start: 2015-01-01T00
- end: 2024-12-31T23
+ start: 1990-09-11T03
+ end: 1990-09-13T09
test:
- start: 2015-01-01T00
- end: 2024-12-31T23
+ start: 1990-09-11T03
+ end: 1990-09-13T09
projection:
- class: RotatedPole # Name of class in cartopy.crs
- kwargs: # Parsed and used directly as kwargs to projection-class above
- pole_longitude: 10.0
- pole_latitude: -43.0
+ class: LambertConformal # Name of class in cartopy.crs
+ kwargs:
+ central_longitude: 6.22
+ central_latitude: 56.0
+ standard_parallels: [47.6, 64.4]
normalization:
zarr: normalization.zarr
vars:
From adc592f3edb9e38105eba4d15b71dd847c6fd07d Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 1 Jun 2024 10:53:04 +0200
Subject: [PATCH 057/273] streamlined multi-zarr workflow
---
neural_lam/config.py | 110 ++++++++++++------------------------
neural_lam/data_config.yaml | 71 +++++++++++------------
2 files changed, 73 insertions(+), 108 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 7df993d0..32dd884f 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -11,19 +11,11 @@
class Config:
- """
- Class for loading configuration files.
-
- This class loads a configuration file and provides a way to access its
- values as attributes.
- """
-
def __init__(self, values):
self.values = values
@classmethod
def from_file(cls, filepath):
- """Load a configuration file."""
if filepath.endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
@@ -53,7 +45,6 @@ def __contains__(self, key):
@functools.cached_property
def coords_projection(self):
- """Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
@@ -62,7 +53,6 @@ def coords_projection(self):
@functools.cached_property
def param_names(self):
- """Return parameter names."""
surface_vars_names = self.values["state"]["surface_vars"]
atmosphere_vars_names = [
f"{var}_{level}"
@@ -73,18 +63,16 @@ def param_names(self):
@functools.cached_property
def param_units(self):
- """Return parameter units."""
- surface_vars_units = self.values["state"]["surface_vars_units"]
+ surface_vars_units = self.values["state"]["surface_units"]
atmosphere_vars_units = [
unit
- for unit in self.values["state"]["atmosphere_vars_units"]
+ for unit in self.values["state"]["atmosphere_units"]
for _ in self.values["state"]["levels"]
]
return surface_vars_units + atmosphere_vars_units
@functools.lru_cache()
def num_data_vars(self, category):
- """Return the number of data variables for a given category."""
surface_vars = self.values[category].get("surface_vars", [])
atmosphere_vars = self.values[category].get("atmosphere_vars", [])
levels = self.values[category].get("levels", [])
@@ -101,31 +89,20 @@ def num_data_vars(self, category):
@functools.lru_cache(maxsize=None)
def open_zarr(self, category):
- """Open a dataset specified by the category."""
- zarr_config = self.zarrs[category]
-
- if isinstance(zarr_config, list):
- try:
- datasets = []
- for config in zarr_config:
- dataset_path = config["path"]
- dataset = xr.open_zarr(dataset_path, consolidated=True)
- datasets.append(dataset)
- return xr.merge(datasets)
- except Exception:
- print(f"Invalid zarr configuration for category: {category}")
- return None
-
- else:
- try:
- dataset_path = zarr_config["path"]
- return xr.open_zarr(dataset_path, consolidated=True)
- except Exception:
- print(f"Invalid zarr configuration for category: {category}")
- return None
+ zarr_configs = self.values[category]["zarrs"]
+
+ try:
+ datasets = []
+ for config in zarr_configs:
+ dataset_path = config["path"]
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+ datasets.append(dataset)
+ return xr.merge(datasets)
+ except Exception:
+ print(f"Invalid zarr configuration for category: {category}")
+ return None
def stack_grid(self, dataset):
- """Stack grid dimensions."""
dims = dataset.to_array().dims
if "grid" not in dims and "x" in dims and "y" in dims:
@@ -145,10 +122,9 @@ def stack_grid(self, dataset):
@functools.lru_cache()
def get_nwp_xy(self, category):
- """Get the x and y coordinates for the NWP grid."""
dataset = self.open_zarr(category)
- lon_name = self.zarrs[category].lat_lon_names.lon
- lat_name = self.zarrs[category].lat_lon_names.lat
+ lon_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lon"]
+ lat_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lat"]
if lon_name in dataset and lat_name in dataset:
lon = dataset[lon_name].values
lat = dataset[lat_name].values
@@ -158,46 +134,42 @@ def get_nwp_xy(self, category):
)
if lon.ndim == 1:
lon, lat = np.meshgrid(lat, lon)
- lonlat = np.stack((lon, lat), axis=0)
+ lonlat = np.stack((lon.T, lat.T), axis=0)
return lonlat
@functools.cached_property
def load_normalization_stats(self):
- """Load normalization statistics from Zarr archive."""
- normalization_path = self.normalization.zarr
- if not os.path.exists(normalization_path):
- print(
- f"Normalization statistics not found at "
- f"path: {normalization_path}"
- )
- return None
- normalization_stats = xr.open_zarr(
- normalization_path, consolidated=True
- )
+ normalization_stats = {}
+ for zarr_config in self.values["normalization"]["zarrs"]:
+ normalization_path = zarr_config["path"]
+ if not os.path.exists(normalization_path):
+ print(
+ f"Normalization statistics not found at path: "
+ f"{normalization_path}"
+ )
+ return None
+ stats = xr.open_zarr(normalization_path, consolidated=True)
+ for var_name, var_path in zarr_config["stats_vars"].items():
+ normalization_stats[var_name] = stats[var_path]
return normalization_stats
@functools.lru_cache(maxsize=None)
def process_dataset(self, category, split="train"):
- """Process a single dataset specified by the dataset name."""
-
dataset = self.open_zarr(category)
if dataset is None:
return None
start, end = (
- self.splits[split].start,
- self.splits[split].end,
+ self.values["splits"][split]["start"],
+ self.values["splits"][split]["end"],
)
dataset = dataset.sel(time=slice(start, end))
dims_mapping = {}
- zarr_configs = self.zarrs[category]
- if isinstance(zarr_configs, list):
- for zarr_config in zarr_configs:
- dims_mapping.update(zarr_config["dims"])
- else:
- dims_mapping.update(zarr_configs["dims"].values)
+ zarr_configs = self.values[category]["zarrs"]
+ for zarr_config in zarr_configs:
+ dims_mapping.update(zarr_config["dims"])
dataset = dataset.rename_dims(
{
@@ -236,18 +208,14 @@ def process_dataset(self, category, split="train"):
print(f"No variables found in dataset {category}")
return None
- zarr_configs = self.zarrs[category]
lat_lon_names = {}
- if isinstance(self.zarrs[category], list):
- for zarr_configs in self.zarrs[category]:
- lat_lon_names.update(zarr_configs["lat_lon_names"])
- else:
- lat_lon_names.update(self.zarrs[category]["lat_lon_names"].values)
+ for zarr_config in self.values[category]["zarrs"]:
+ lat_lon_names.update(zarr_config["lat_lon_names"])
if not all(
lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names
):
- lat_name, lon_name = lat_lon_names[:2]
+ lat_name, lon_name = list(lat_lon_names.values())[:2]
if dataset[lat_name].ndim == 2:
dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
if dataset[lon_name].ndim == 2:
@@ -262,9 +230,5 @@ def process_dataset(self, category, split="train"):
dataset = self.stack_grid(dataset)
return dataset
- dataset = self.stack_grid(dataset)
-
- return dataset
-
config = Config.from_file("neural_lam/data_config.yaml")
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index ff14a231..41ddbcb8 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -1,6 +1,6 @@
name: danra
-zarrs:
- state:
+state:
+ zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
@@ -21,28 +21,6 @@ zarrs:
lat_lon_names:
lon: lon
lat: lat
- static:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
- dims:
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- forcing:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
-state:
surface_vars:
- u10m
- v10m
@@ -62,16 +40,50 @@ state:
levels:
- 100
static:
+ zarrs:
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
surface_vars:
- pres0m # just as a technical test
atmosphere_vars: null
levels: null
forcing:
+ zarrs:
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
surface_vars:
- cape_column # just as a technical test
atmosphere_vars: null
levels: null
window: 3 # Number of time steps to use for forcing (odd)
+normalization:
+ zarrs:
+ - path: "normalization.zarr"
+ stats_vars:
+ data_mean: data_mean
+ data_std: data_std
+ forcing_mean: forcing_mean
+ forcing_std: forcing_std
+ boundary_mean: boundary_mean
+ boundary_std: boundary_std
+ diff_mean: diff_mean
+ diff_std: diff_std
+
grid_shape_state:
x: 582
y: 390
@@ -91,14 +103,3 @@ projection:
central_longitude: 6.22
central_latitude: 56.0
standard_parallels: [47.6, 64.4]
-normalization:
- zarr: normalization.zarr
- vars:
- data_mean: data_mean
- data_std: data_std
- forcing_mean: forcing_mean
- forcing_std: forcing_std
- boundary_mean: boundary_mean
- boundary_std: boundary_std
- diff_mean: diff_mean
- diff_std: diff_std
From a7bea6bbbf605316d4baf5078e7bed3a9fff6eef Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 1 Jun 2024 17:55:37 +0200
Subject: [PATCH 058/273] xarray zarr based data normalization
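For illustration, a minimal sketch of the xarray-based statistics this patch switches to, using synthetic data and hypothetical variable names (not the project's real zarr archives):

    import numpy as np
    import xarray as xr

    # Toy (time, grid, variable) array standing in for the processed state data
    da = xr.DataArray(
        np.random.rand(8, 12, 3),
        dims=("time", "grid", "variable"),
        coords={"variable": ["t2m", "u10m", "v10m"]},
    )

    # One mean/std per variable, reduced over time and grid (as compute_stats below)
    state_mean = da.mean(dim=("time", "grid"))
    state_std = da.std(dim=("time", "grid"))

    # Stored with the same kind of layout as normalization.zarr
    stats = xr.Dataset({"state_mean": state_mean, "state_std": state_std})
    stats.to_zarr("stats_example.zarr", mode="w")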
---
create_parameter_weights.py | 163 +++++++++---------------------------
neural_lam/config.py | 13 ++-
neural_lam/data_config.yaml | 7 +-
3 files changed, 47 insertions(+), 136 deletions(-)
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 9c004535..94267b31 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -1,16 +1,17 @@
# Standard library
-import os
from argparse import ArgumentParser
# Third-party
-import numpy as np
-import torch
import xarray as xr
-from tqdm import tqdm
# First-party
from neural_lam import config
-from neural_lam.weather_dataset import WeatherDataModule, WeatherDataset
+
+
+def compute_stats(data_array):
+ mean = data_array.mean(dim=("time", "grid"))
+ std = data_array.std(dim=("time", "grid"))
+ return mean, std
def main():
@@ -24,144 +25,58 @@ def main():
default="neural_lam/data_config.yaml",
help="Path to data config file (default: neural_lam/data_config.yaml)",
)
- parser.add_argument(
- "--batch_size",
- type=int,
- default=32,
- help="Batch size when iterating over the dataset",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=4,
- help="Number of workers in data loader (default: 4)",
- )
parser.add_argument(
"--zarr_path",
type=str,
default="normalization.zarr",
help="Directory where data is stored",
)
+ parser.add_argument(
+ "--combined_forcings",
+ action="store_true",
+ help="Whether to compute combined stats forcing variables",
+ )
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
- static_dir_path = os.path.join(
- "data", config_loader.dataset.name, "static"
- )
- # Create parameter weights based on height
- # based on fig A.1 in graph cast paper
- w_dict = {
- "2": 1.0,
- "0": 0.1,
- "65": 0.065,
- "1000": 0.1,
- "850": 0.05,
- "500": 0.03,
- }
- w_list = np.array(
- [
- w_dict[par.split("_")[-2]]
- for par in config_loader.dataset.var_longnames
- ]
- )
- print("Saving parameter weights...")
- np.save(
- os.path.join(static_dir_path, "parameter_weights.npy"),
- w_list.astype("float32"),
- )
- data_module = WeatherDataModule(
- batch_size=args.batch_size, num_workers=args.num_workers
- )
- data_module.setup()
- loader = data_module.train_dataloader()
-
- # Load dataset without any subsampling
- ds = WeatherDataset(
- config_loader.dataset.name,
- split="train",
- subsample_step=1,
- pred_length=63,
- standardize=False,
- ) # Without standardization
- loader = torch.utils.data.DataLoader(
- ds, args.batch_size, shuffle=False, num_workers=args.n_workers
- )
- # Compute mean and std.-dev. of each parameter (+ flux forcing)
- # Compute mean and std.-dev. of each parameter (+ forcing forcing)
- # across full dataset
- print("Computing mean and std.-dev. for parameters...")
- means = []
- squares = []
- fb_means = {"forcing": [], "boundary": []}
- fb_squares = {"forcing": [], "boundary": []}
-
- for init_batch, target_batch, forcing_batch, boundary_batch, _ in tqdm(
- loader
- ):
- batch = torch.cat(
- (init_batch, target_batch), dim=1
- ) # (N_batch, N_t, N_grid, d_features)
- means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,)
- squares.append(torch.mean(batch**2, dim=(1, 2)))
-
- for fb_type, fb_batch in zip(
- ["forcing", "boundary"], [forcing_batch, boundary_batch]
- ):
- fb_batch = fb_batch[:, :, :, 1]
- fb_means[fb_type].append(torch.mean(fb_batch)) # (,)
- fb_squares[fb_type].append(torch.mean(fb_batch**2)) # (,)
+ state_data = config_loader.process_dataset("state", split="train")
+ forcing_data = config_loader.process_dataset("forcing", split="train")
- mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features)
- second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
- std = torch.sqrt(second_moment - mean**2) # (d_features)
+ print("Computing mean and std.-dev. for parameters...", flush=True)
+ state_mean, state_std = compute_stats(state_data)
- fb_stats = {}
- for fb_type in ["forcing", "boundary"]:
- fb_stats[f"{fb_type}_mean"] = torch.mean(
- torch.stack(fb_means[fb_type])
- ) # (,)
- fb_second_moment = torch.mean(torch.stack(fb_squares[fb_type])) # (,)
- fb_stats[f"{fb_type}_std"] = torch.sqrt(
- fb_second_moment - fb_stats[f"{fb_type}_mean"] ** 2
- ) # (,)
+ if forcing_data is not None:
+ forcing_mean, forcing_std = compute_stats(forcing_data)
+ if args.combined_forcings:
+ forcing_mean = forcing_mean.mean(dim="variable")
+ forcing_std = forcing_std.mean(dim="variable")
- # Compute mean and std.-dev. of one-step differences across the dataset
- print("Computing mean and std.-dev. for one-step differences...")
- diff_means = []
- diff_squares = []
- for init_batch, target_batch, _, _, _ in tqdm(loader):
- # normalize the batch
- init_batch = (init_batch - mean) / std
- target_batch = (target_batch - mean) / std
-
- batch = torch.cat((init_batch, target_batch), dim=1)
- batch_diffs = batch[:, 1:] - batch[:, :-1]
- # (N_batch, N_t-1, N_grid, d_features)
-
- diff_means.append(
- torch.mean(batch_diffs, dim=(1, 2))
- ) # (N_batch', d_features,)
- diff_squares.append(
- torch.mean(batch_diffs**2, dim=(1, 2))
- ) # (N_batch', d_features,)
-
- diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features)
- diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0)
- diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features)
+ print(
+ "Computing mean and std.-dev. for one-step differences...", flush=True
+ )
+ state_data_diff = state_data.diff(dim="time")
+ diff_mean, diff_std = compute_stats(state_data_diff)
- # Create xarray dataset
ds = xr.Dataset(
{
- "mean": (["d_features"], mean),
- "std": (["d_features"], std),
- "diff_mean": (["d_features"], diff_mean),
- "diff_std": (["d_features"], diff_std),
- **fb_stats,
+ "state_mean": (["d_features"], state_mean.data),
+ "state_std": (["d_features"], state_std.data),
+ "diff_mean": (["d_features"], diff_mean.data),
+ "diff_std": (["d_features"], diff_std.data),
}
)
-
+ if forcing_data is not None:
+ dsf = xr.Dataset(
+ {
+ "forcing_mean": (["d_forcings"], forcing_mean.data),
+ "forcing_std": (["d_forcings"], forcing_std.data),
+ }
+ )
+ ds = xr.merge(
+ [ds, dsf],
+ )
# Save dataset as Zarr
print("Saving dataset as Zarr...")
ds.to_zarr(args.zarr_path, mode="w")
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 32dd884f..8dd76c2d 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -140,8 +140,7 @@ def get_nwp_xy(self, category):
@functools.cached_property
def load_normalization_stats(self):
- normalization_stats = {}
- for zarr_config in self.values["normalization"]["zarrs"]:
+ for i, zarr_config in enumerate(self.values["normalization"]["zarrs"]):
normalization_path = zarr_config["path"]
if not os.path.exists(normalization_path):
print(
@@ -150,8 +149,11 @@ def load_normalization_stats(self):
)
return None
stats = xr.open_zarr(normalization_path, consolidated=True)
- for var_name, var_path in zarr_config["stats_vars"].items():
- normalization_stats[var_name] = stats[var_path]
+ if i == 0:
+ normalization_stats = stats
+ else:
+ stats = xr.merge([stats, normalization_stats])
+ normalization_stats = stats
return normalization_stats
@functools.lru_cache(maxsize=None)
@@ -229,6 +231,3 @@ def process_dataset(self, category, split="train"):
)
dataset = self.stack_grid(dataset)
return dataset
-
-
-config = Config.from_file("neural_lam/data_config.yaml")
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 41ddbcb8..6c5bdbdd 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -75,15 +75,12 @@ normalization:
zarrs:
- path: "normalization.zarr"
stats_vars:
- data_mean: data_mean
- data_std: data_std
+ state_mean: state_mean
+ state_std: state_std
forcing_mean: forcing_mean
forcing_std: forcing_std
- boundary_mean: boundary_mean
- boundary_std: boundary_std
diff_mean: diff_mean
diff_std: diff_std
-
grid_shape_state:
x: 582
y: 390
From 1f7cbe84da2925ba1986404943abf7f8b0b401b1 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 14:35:46 +0200
Subject: [PATCH 059/273] adjusted pre-processing scripts to new data config
workflow
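As a quick sanity check (a sketch on a toy grid, not part of the scripts below), the boundary mask produced by create_boundary_mask.py flags a band of cells along each edge:

    import numpy as np

    n_boundary = 2
    mask = np.zeros((6, 6), dtype=bool)
    mask[:n_boundary, :] = True   # top
    mask[-n_boundary:, :] = True  # bottom
    mask[:, :n_boundary] = True   # left
    mask[:, -n_boundary:] = True  # right
    print(int(mask.sum()))        # 32 of the 36 cells are boundary cells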
---
...eter_weights.py => calculate_statistics.py | 0
create_boundary_mask.py | 50 ++++++++++++++
create_grid_features.py | 67 -------------------
3 files changed, 50 insertions(+), 67 deletions(-)
rename create_parameter_weights.py => calculate_statistics.py (100%)
create mode 100644 create_boundary_mask.py
delete mode 100644 create_grid_features.py
diff --git a/create_parameter_weights.py b/calculate_statistics.py
similarity index 100%
rename from create_parameter_weights.py
rename to calculate_statistics.py
diff --git a/create_boundary_mask.py b/create_boundary_mask.py
new file mode 100644
index 00000000..78443df0
--- /dev/null
+++ b/create_boundary_mask.py
@@ -0,0 +1,50 @@
+# Standard library
+from argparse import ArgumentParser
+
+# Third-party
+import numpy as np
+import xarray as xr
+
+# First-party
+from neural_lam import config
+
+
+def main():
+ parser = ArgumentParser(description="Training arguments")
+ parser.add_argument(
+ "--data_config",
+ type=str,
+ default="neural_lam/data_config.yaml",
+ help="Path to data config file (default: neural_lam/data_config.yaml)",
+ )
+ parser.add_argument(
+ "--zarr_path",
+ type=str,
+ default="boundary_mask.zarr",
+ help="Path to save the Zarr archive "
+ "(default: same directory as border_mask.npy)",
+ )
+ parser.add_argument(
+ "--boundaries",
+ type=int,
+ default=30,
+ help="Number of grid-cells to set to True along each boundary",
+ )
+ args = parser.parse_args()
+ config_loader = config.Config.from_file(args.data_config)
+ mask = np.zeros(list(config_loader.grid_shape_state.values.values()))
+
+ # Set the args.boundaries grid-cells closest to each boundary to True
+ mask[: args.boundaries, :] = True # top boundary
+ mask[-args.boundaries :, :] = True # noqa bottom boundary
+ mask[:, : args.boundaries] = True # left boundary
+ mask[:, -args.boundaries :] = True # noqa right boundary
+
+ mask = xr.Dataset({"mask": (["x", "y"], mask)})
+
+ print(f"Saving mask to {args.zarr_path}...")
+ mask.to_zarr(args.zarr_path, mode="w")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/create_grid_features.py b/create_grid_features.py
deleted file mode 100644
index eabf9de8..00000000
--- a/create_grid_features.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Standard library
-import os
-from argparse import ArgumentParser
-
-# Third-party
-import numpy as np
-import torch
-
-# First-party
-from neural_lam import config
-
-
-def main():
- """
- Pre-compute all static features related to the grid nodes
- """
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
-
- static_dir_path = os.path.join(
- "data", config_loader.dataset.name, "static"
- )
-
- # -- Static grid node features --
- grid_xy = torch.tensor(
- np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
- ) # (2, N_x, N_y)
- grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
- pos_max = torch.max(torch.abs(grid_xy))
- grid_xy = grid_xy / pos_max # Divide by maximum coordinate
-
- geopotential = torch.tensor(
- np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
- ) # (N_x, N_y)
- geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1)
- gp_min = torch.min(geopotential)
- gp_max = torch.max(geopotential)
- # Rescale geopotential to [0,1]
- geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1)
-
- grid_border_mask = torch.tensor(
- np.load(os.path.join(static_dir_path, "border_mask.npy")),
- dtype=torch.int64,
- ) # (N_x, N_y)
- grid_border_mask = (
- grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
- ) # (N_grid, 1)
-
- # Concatenate grid features
- grid_features = torch.cat(
- (grid_xy, geopotential, grid_border_mask), dim=1
- ) # (N_grid, 4)
-
- torch.save(
- grid_features, os.path.join(static_dir_path, "grid_features.pt")
- )
-
-
-if __name__ == "__main__":
- main()
From e3281522f3eee073cc05e15e4260f47593098eee Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 14:36:36 +0200
Subject: [PATCH 060/273] plotting update with latest get_xy() function
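For reference, a toy sketch of the coordinate reshaping plot_graph.py now performs on the (2, N_x, N_y) array returned by get_xy() (hypothetical grid size):

    import numpy as np

    nx, ny = 5, 4
    xy = np.stack(
        np.meshgrid(np.arange(nx), np.arange(ny), indexing="ij"), axis=0
    )  # (2, N_x, N_y)

    grid_xy = xy.reshape(2, -1).T                 # (N_grid, 2)
    grid_pos = grid_xy / np.max(np.abs(grid_xy))  # normalized, as in plot_graph.py
    print(grid_pos.shape)                         # (20, 2)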
---
create_mesh.py | 4 +-
plot_graph.py | 111 +++++++++++++------------------------------------
2 files changed, 31 insertions(+), 84 deletions(-)
diff --git a/create_mesh.py b/create_mesh.py
index 21561644..8b547166 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -160,7 +160,6 @@ def prepend_node_index(graph, new_index):
def main():
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
- "--data_config",
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
@@ -198,8 +197,7 @@ def main():
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
- config_loader = config.Config(args.data_config)
- xy = config_loader.get_nwp_xy()
+ xy = config_loader.get_xy("static")
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
diff --git a/plot_graph.py b/plot_graph.py
index 476b5a68..2c3f6238 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -5,83 +5,21 @@
import numpy as np
import plotly.graph_objects as go
import torch_geometric as pyg
-import trimesh
-from tqdm import tqdm
-from trimesh.primitives import Box
# First-party
from neural_lam import config, utils
MESH_HEIGHT = 0.1
-MESH_LEVEL_DIST = 0.05
+MESH_LEVEL_DIST = 0.2
GRID_HEIGHT = 0
-def create_cubes_for_nodes(nodes, size=0.002):
- """Create cubes for each node in the graph."""
- cube_meshes = []
- for node in tqdm(nodes, desc="Creating cubes"):
- cube = Box(extents=[size, size, size])
- cube.apply_translation(node)
- cube_meshes.append(cube)
- return cube_meshes
-
-
-def export_to_3d_model(node_pos, edge_plot_list, filename):
- """Export the graph to a 3D model."""
- paths = []
- filtered_edge_plot_list = [
- item for item in edge_plot_list if item[3] not in ["M2G", "G2M"]
- ]
-
- unique_node_indices = set()
- for ei, _, _, _ in filtered_edge_plot_list:
- unique_node_indices.update(ei.flatten())
-
- filtered_node_positions = node_pos[np.array(list(unique_node_indices))]
-
- for ei, _, _, _ in filtered_edge_plot_list:
- edge_start = filtered_node_positions[ei[0]]
- edge_end = filtered_node_positions[ei[1]]
- for start, end in zip(edge_start, edge_end):
- if not (np.isnan(start).any() or np.isnan(end).any()):
- paths.append((start, end))
-
- meshes = []
- for start, end in tqdm(paths, desc="Creating meshes"):
- offset_xyz = np.array([2e-4, 2e-4, 2e-4])
- dummy_vertex = end + offset_xyz
- vertices = [start, end, dummy_vertex]
- faces = [[0, 1, 2]]
- color_vertices = [[255, 179, 71], [255, 179, 71], [255, 179, 71]]
- # color_faces = [[0, 0, 0]]
-
- mesh = trimesh.Trimesh(
- vertices=vertices,
- faces=faces,
- # face_colors=color_faces,
- vertex_colors=color_vertices,
- )
- meshes.append(mesh)
-
- node_spheres = create_cubes_for_nodes(filtered_node_positions)
-
- scene = trimesh.Scene()
- for mesh in meshes:
- scene.add_geometry(mesh)
- for sphere in node_spheres:
- scene.add_geometry(sphere)
-
- scene.export(filename, file_type="ply")
-
-
def main():
"""
Plot graph structure in 3D using plotly
"""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
- "--data_config",
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
@@ -104,17 +42,17 @@ def main():
default=0,
help="If the axis should be displayed (default: 0 (No))",
)
- parser.add_argument(
- "--export",
- type=str,
- help="Name of .obj file to export 3D model to (default: None)",
- )
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
+ xy = config_loader.get_xy("state") # (2, N_x, N_y)
+ xy = xy.reshape(2, -1).T # (N_grid, 2)
+ pos_max = np.max(np.abs(xy))
+ grid_pos = xy / pos_max # Divide by maximum coordinate
+ # Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
- g2m_edge_index, m2g_edge_index, m2m_edge_index = (
+ (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
graph_ldict["m2m_edge_index"],
@@ -125,22 +63,19 @@ def main():
)
mesh_static_features = graph_ldict["mesh_static_features"]
- config_loader = config.Config(args.data_config)
- xy = config_loader.get_nwp_xy()
- grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2)
- pos_max = np.max(np.abs(grid_xy))
- grid_pos = grid_xy / pos_max
-
+ # Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_pos = np.concatenate(
(grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
)
+ # List of edges to plot, (edge_index, color, line_width, label)
edge_plot_list = [
(m2g_edge_index.numpy(), "black", 0.4, "M2G"),
(g2m_edge_index.numpy(), "black", 0.4, "G2M"),
]
+ # Mesh positioning and edges to plot differ if we have a hierarchical graph
if hierarchical:
mesh_level_pos = [
np.concatenate(
@@ -159,11 +94,13 @@ def main():
]
mesh_pos = np.concatenate(mesh_level_pos, axis=0)
+ # Add inter-level mesh edges
edge_plot_list += [
(level_ei.numpy(), "blue", 1, f"M2M Level {level}")
for level, level_ei in enumerate(m2m_edge_index)
]
+ # Add intra-level mesh edges
up_edges_ei = np.concatenate(
[level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1
)
@@ -177,20 +114,30 @@ def main():
mesh_node_size = 2.5
else:
mesh_pos = mesh_static_features.numpy()
+
mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy()
z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees
mesh_node_size = mesh_degrees / 2
+
mesh_pos = np.concatenate(
(mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1
)
+
edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M"))
+ # All node positions in one array
node_pos = np.concatenate((mesh_pos, grid_pos), axis=0)
+ # Add edges
data_objs = []
- for ei, col, width, label in edge_plot_list:
- edge_start = node_pos[ei[0]]
- edge_end = node_pos[ei[1]]
+ for (
+ ei,
+ col,
+ width,
+ label,
+ ) in edge_plot_list:
+ edge_start = node_pos[ei[0]] # (M, 2)
+ edge_end = node_pos[ei[1]] # (M, 2)
n_edges = edge_start.shape[0]
x_edges = np.stack(
@@ -213,6 +160,8 @@ def main():
)
data_objs.append(scatter_obj)
+ # Add node objects
+
data_objs.append(
go.Scatter3d(
x=grid_pos[:, 0],
@@ -240,6 +189,7 @@ def main():
fig.update_traces(connectgaps=False)
if not args.show_axis:
+ # Hide axis
fig.update_layout(
scene={
"xaxis": {"visible": False},
@@ -250,9 +200,8 @@ def main():
if args.save:
fig.write_html(args.save, include_plotlyjs="cdn")
-
- if args.export:
- export_to_3d_model(node_pos, edge_plot_list, args.export)
+ else:
+ fig.show()
if __name__ == "__main__":
From cb85cda967c4d108bb18071dd3aec761a02f3d1e Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 14:37:20 +0200
Subject: [PATCH 061/273] making data config more modular
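A minimal round-trip sketch of the grid stacking/unstacking this refactor relies on (toy data; assumes an xarray version providing Coordinates.from_pandas_multiindex, as used below):

    import numpy as np
    import pandas as pd
    import xarray as xr

    nx, ny = 4, 3
    da = xr.DataArray(np.arange(nx * ny).reshape(nx, ny), dims=("x", "y"))
    flat = da.stack(grid=("x", "y"))  # flattened "grid" dimension, as stack_grid does

    # Rebuild the 2D grid from the flat dimension, as reshape_grid_to_2d does
    multi_index = pd.MultiIndex.from_product(
        [np.arange(nx), np.arange(ny)], names=["x", "y"]
    )
    mindex_coords = xr.Coordinates.from_pandas_multiindex(multi_index, "grid")
    flat = flat.drop_vars(["grid", "x", "y"], errors="ignore")
    grid_2d = flat.assign_coords(mindex_coords).unstack("grid")  # back to (x, y)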
---
neural_lam/config.py | 299 ++++++++++++++++++++++++++----------
neural_lam/data_config.yaml | 40 +++--
2 files changed, 251 insertions(+), 88 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 8dd76c2d..c81cb02b 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -6,11 +6,14 @@
# Third-party
import cartopy.crs as ccrs
import numpy as np
+import pandas as pd
import xarray as xr
import yaml
class Config:
+ DIMS_TO_KEEP = {"time", "grid", "variable"}
+
def __init__(self, values):
self.values = values
@@ -45,6 +48,7 @@ def __contains__(self, key):
@functools.cached_property
def coords_projection(self):
+ """Return the projection object for the coordinates."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
@@ -52,7 +56,8 @@ def coords_projection(self):
return proj_class(**proj_params)
@functools.cached_property
- def param_names(self):
+ def vars_names(self):
+ """Return the names of the variables in the dataset."""
surface_vars_names = self.values["state"]["surface_vars"]
atmosphere_vars_names = [
f"{var}_{level}"
@@ -62,7 +67,8 @@ def param_names(self):
return surface_vars_names + atmosphere_vars_names
@functools.cached_property
- def param_units(self):
+ def vars_units(self):
+ """Return the units of the variables in the dataset."""
surface_vars_units = self.values["state"]["surface_units"]
atmosphere_vars_units = [
unit
@@ -73,6 +79,7 @@ def param_units(self):
@functools.lru_cache()
def num_data_vars(self, category):
+ """Return the number of data variables in the dataset."""
surface_vars = self.values[category].get("surface_vars", [])
atmosphere_vars = self.values[category].get("atmosphere_vars", [])
levels = self.values[category].get("levels", [])
@@ -87,8 +94,8 @@ def num_data_vars(self, category):
return surface_vars_count + atmosphere_vars_count * levels_count
- @functools.lru_cache(maxsize=None)
def open_zarr(self, category):
+ """Open the zarr dataset for the given category."""
zarr_configs = self.values[category]["zarrs"]
try:
@@ -97,50 +104,148 @@ def open_zarr(self, category):
dataset_path = config["path"]
dataset = xr.open_zarr(dataset_path, consolidated=True)
datasets.append(dataset)
- return xr.merge(datasets)
+ merged_dataset = xr.merge(datasets)
+ merged_dataset.attrs["category"] = category
+ return merged_dataset
except Exception:
print(f"Invalid zarr configuration for category: {category}")
return None
def stack_grid(self, dataset):
+ """Stack the grid dimensions of the dataset."""
+ if dataset is None:
+ return None
dims = dataset.to_array().dims
- if "grid" not in dims and "x" in dims and "y" in dims:
- dataset = dataset.squeeze().stack(grid=("x", "y")).to_array()
+ if "grid" in dims:
+ print("\033[94mGrid dimensions already stacked.\033[0m")
+ return dataset.squeeze()
else:
- try:
- dataset = dataset.squeeze().to_array()
- except ValueError:
- print("Failed to stack grid dimensions.")
+ if "x" not in dims or "y" not in dims:
+ self.rename_dataset_dims_and_vars(dataset=dataset)
+ dataset = dataset.squeeze().stack(grid=("x", "y"))
+ return dataset
+
+ def convert_dataset_to_dataarray(self, dataset):
+ """Convert the Dataset to a Dataarray."""
+ if isinstance(dataset, xr.Dataset):
+ dataset = dataset.to_array()
+ print(
+ "\033[92mSuccessfully converted Dataset to Dataarray.\033[0m"
+ )
+ return dataset.to_array()
+
+ def filter_dimensions(self, dataset, transpose_array=False):
+ """Filter the dimensions of the dataset."""
+ dims_to_keep = self.DIMS_TO_KEEP
+ dataset_dims = set(dataset.to_array().dims)
+ min_req_dims = dims_to_keep.copy()
+ min_req_dims.discard("time")
+ if not min_req_dims.issubset(dataset_dims):
+ missing_dims = min_req_dims - dataset_dims
+ print(
+ f"\033[91mMissing required dimensions in dataset: "
+ f"{missing_dims}\033[0m"
+ )
+ print(
+ "\033[91mAttempting to update dims and "
+ "vars based on zarr config...\033[0m"
+ )
+ dataset = self.rename_dataset_dims_and_vars(
+ dataset.attrs["category"], dataset=dataset
+ )
+ dataset = self.stack_grid(dataset)
+ dataset_dims = set(dataset.to_array().dims)
+ if min_req_dims.issubset(dataset_dims):
+ print(
+ "\033[92mSuccessfully updated dims and "
+ "vars based on zarr config.\033[0m"
+ )
+ else:
+ print(
+ "\033[91mFailed to update dims and "
+ "vars based on zarr config.\033[0m"
+ )
return None
- if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid", "variable")
- else:
- dataset = dataset.transpose("grid", "variable")
+ dataset_dims = set(dataset.to_array().dims)
+ dims_to_drop = dataset_dims - dims_to_keep
+ dataset = dataset.drop_dims(dims_to_drop)
+ if dims_to_drop:
+ print(
+ "\033[91mDropped dimensions: --",
+ dims_to_drop,
+ "-- from dataset.\033[0m",
+ )
+ print(
+ "\033[91mAny data vars still dependent "
+ "on these variables were dropped!\033[0m"
+ )
+
+ if transpose_array:
+ dataset = self.convert_dataset_to_dataarray(dataset)
+
+ if "time" in dataset.dims:
+ dataset = dataset.transpose("time", "grid", "variable")
+ else:
+ dataset = dataset.transpose("grid", "variable")
+ dataset_vars = (
+ dataset["variable"].values.tolist()
+ if transpose_array
+ else list(dataset.data_vars)
+ )
+
+ print(
+ "\033[94mYour Dataarray has the following dimensions: ",
+ dataset.to_array().dims,
+ "\033[0m",
+ )
+ print(
+ "\033[94mYour Dataarray has the following variables: ",
+ dataset_vars,
+ "\033[0m",
+ )
+
return dataset
+ def reshape_grid_to_2d(self, dataset, grid_shape=None):
+ """Reshape the grid to 2D."""
+ if grid_shape is None:
+ grid_shape = dict(self.grid_shape_state.values.items())
+ x_dim, y_dim = (grid_shape["x"], grid_shape["y"])
+
+ x_coords = np.arange(x_dim)
+ y_coords = np.arange(y_dim)
+ multi_index = pd.MultiIndex.from_product(
+ [x_coords, y_coords], names=["x", "y"]
+ )
+
+ mindex_coords = xr.Coordinates.from_pandas_multiindex(
+ multi_index, "grid"
+ )
+ dataset = dataset.drop_vars(["grid", "x", "y"], errors="ignore")
+ dataset = dataset.assign_coords(mindex_coords)
+ reshaped_data = dataset.unstack("grid")
+
+ return reshaped_data
+
@functools.lru_cache()
- def get_nwp_xy(self, category):
+ def get_xy(self, category):
+ """Return the x, y coordinates of the dataset."""
dataset = self.open_zarr(category)
- lon_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lon"]
- lat_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lat"]
- if lon_name in dataset and lat_name in dataset:
- lon = dataset[lon_name].values
- lat = dataset[lat_name].values
- else:
- raise ValueError(
- f"Dataset does not contain " f"{lon_name} or {lat_name}"
- )
- if lon.ndim == 1:
- lon, lat = np.meshgrid(lat, lon)
- lonlat = np.stack((lon.T, lat.T), axis=0)
+ x, y = dataset.x.values, dataset.y.values
+ if x.ndim == 1:
+ x, y = np.meshgrid(y, x)
+ xy = np.stack((x, y), axis=0)
- return lonlat
+ return xy
@functools.cached_property
def load_normalization_stats(self):
- for i, zarr_config in enumerate(self.values["normalization"]["zarrs"]):
+ """Load the normalization statistics for the dataset."""
+ for i, zarr_config in enumerate(
+ self.values["utilities"]["normalization"]["zarrs"]
+ ):
normalization_path = zarr_config["path"]
if not os.path.exists(normalization_path):
print(
@@ -156,18 +261,69 @@ def load_normalization_stats(self):
normalization_stats = stats
return normalization_stats
- @functools.lru_cache(maxsize=None)
- def process_dataset(self, category, split="train"):
- dataset = self.open_zarr(category)
+ # def assign_lat_lon_coords(self, category, dataset=None):
+ # """Process the latitude and longitude names of the dataset."""
+ # if dataset is None:
+ # dataset = self.open_zarr(category)
+ # lat_lon_names = {}
+ # for zarr_config in self.values[category]["zarrs"]:
+ # lat_lon_names.update(zarr_config["lat_lon_names"])
+ # lat_name, lon_name = (lat_lon_names["lat"], lat_lon_names["lon"])
+
+ # if "x" not in dataset.dims or "y" in dataset.dims:
+ # dataset = self.reshape_grid_to_2d(dataset)
+ # if not set(lat_lon_names).issubset(dataset.to_array().dims):
+ # dataset = dataset.assign_coords(
+ # x=dataset[lon_name], y=dataset[lat_name]
+ # )
+ # return dataset
+
+ def extract_vars(self, category, dataset=None):
+ """Extract the variables from the dataset."""
if dataset is None:
- return None
+ dataset = self.open_zarr(category)
+ surface_vars = (
+ dataset[self[category].surface_vars]
+ if self[category].surface_vars
+ else []
+ )
- start, end = (
- self.values["splits"][split]["start"],
- self.values["splits"][split]["end"],
+ if (
+ "level" not in dataset.to_array().dims
+ and self[category].atmosphere_vars
+ ):
+ dataset = self.rename_dataset_dims_and_vars(
+ dataset.attrs["category"], dataset=dataset
+ )
+
+ atmosphere_vars = (
+ xr.merge(
+ [
+ dataset[var]
+ .sel(level=level, drop=True)
+ .rename(f"{var}_{level}")
+ for var in self[category].atmosphere_vars
+ for level in self[category].levels
+ ]
+ )
+ if self[category].atmosphere_vars
+ else []
)
- dataset = dataset.sel(time=slice(start, end))
+ if surface_vars and atmosphere_vars:
+ return xr.merge([surface_vars, atmosphere_vars])
+ elif surface_vars:
+ return surface_vars
+ elif atmosphere_vars:
+ return atmosphere_vars
+ else:
+ print(f"No variables found in dataset {category}")
+ return None
+
+ def rename_dataset_dims_and_vars(self, category, dataset=None):
+ """Rename the dimensions and variables of the dataset."""
+ if dataset is None:
+ dataset = self.open_zarr(category)
dims_mapping = {}
zarr_configs = self.values[category]["zarrs"]
for zarr_config in zarr_configs:
@@ -183,51 +339,40 @@ def process_dataset(self, category, split="train"):
dataset = dataset.rename_vars(
{v: k for k, v in dims_mapping.items() if v in dataset.coords}
)
+ return dataset
- surface_vars = []
- if self[category].surface_vars:
- surface_vars = dataset[self[category].surface_vars]
+ def filter_dataset_by_time(self, dataset, split="train"):
+ """Filter the dataset by the time split."""
+ start, end = (
+ self.values["splits"][split]["start"],
+ self.values["splits"][split]["end"],
+ )
+ return dataset.sel(time=slice(start, end))
- atmosphere_vars = []
- if self[category].atmosphere_vars:
- atmosphere_vars = xr.merge(
- [
- dataset[var]
- .sel(level=level, drop=True)
- .rename(f"{var}_{level}")
- for var in self[category].atmosphere_vars
- for level in self[category].levels
- ]
- )
+ def process_dataset(self, category, split="train"):
+ """Process the dataset for the given category."""
+ print(f"Opening zarr dataset for category: {category}")
+ dataset = self.open_zarr(category)
- if surface_vars and atmosphere_vars:
- dataset = xr.merge([surface_vars, atmosphere_vars])
- elif surface_vars:
- dataset = surface_vars
- elif atmosphere_vars:
- dataset = atmosphere_vars
- else:
- print(f"No variables found in dataset {category}")
- return None
+ print(f"Extracting variables for category: {category}")
+ dataset = self.extract_vars(category, dataset)
- lat_lon_names = {}
- for zarr_config in self.values[category]["zarrs"]:
- lat_lon_names.update(zarr_config["lat_lon_names"])
+ print(f"Filtering dataset by time for split: {split}")
+ dataset = self.filter_dataset_by_time(dataset, split)
- if not all(
- lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names
- ):
- lat_name, lon_name = list(lat_lon_names.values())[:2]
- if dataset[lat_name].ndim == 2:
- dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
- if dataset[lon_name].ndim == 2:
- dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True)
- dataset = dataset.assign_coords(
- x=dataset[lon_name], y=dataset[lat_name]
- )
+ print("Stacking grid dimensions of the dataset")
+ dataset = self.stack_grid(dataset)
- dataset = dataset.rename(
- {v: k for k, v in dims_mapping.items() if v in dataset.coords}
+ print("Filtering dimensions of the dataset")
+ dataset = self.filter_dimensions(dataset)
+
+ print(
+ "Renaming dataset dimensions and "
+ "variables for category: {category}"
)
- dataset = self.stack_grid(dataset)
+ dataset = self.rename_dataset_dims_and_vars(category, dataset)
+
+ print("Converting dataset to data array")
+ dataset = self.convert_dataset_to_dataarray(dataset)
+
return dataset
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 6c5bdbdd..1b7c5ebe 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -71,19 +71,37 @@ forcing:
atmosphere_vars: null
levels: null
window: 3 # Number of time steps to use for forcing (odd)
-normalization:
+boundary:
zarrs:
- - path: "normalization.zarr"
- stats_vars:
- state_mean: state_mean
- state_std: state_std
- forcing_mean: forcing_mean
- forcing_std: forcing_std
- diff_mean: diff_mean
- diff_std: diff_std
+ - path: /scratch/sadamov/era5_template.zarr
+ dims:
+ time: time
+ level: level
+ x: longitude
+ y: latitude
+ lat_lon_names:
+ lon: longitude
+ lat: latitude
+ mask: boundary_mask
+utilities:
+ normalization:
+ zarrs:
+ - path: "normalization.zarr"
+ stats_vars:
+ state_mean: state_mean
+ state_std: state_std
+ forcing_mean: forcing_mean
+ forcing_std: forcing_std
+ diff_mean: diff_mean
+ diff_std: diff_std
+ boundary_mask:
+ zarrs:
+ - path: "boundary.zarr"
+ boundary_vars:
+ boundary_mask: boundary_mask
grid_shape_state:
- x: 582
- y: 390
+ x: 789
+ y: 589
splits:
train:
start: 1990-09-01T00
From eb8c6fbec717ec8840b4aac5a4afed1353f33162 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 14:37:42 +0200
Subject: [PATCH 062/273] removing boundaries for now
---
neural_lam/weather_dataset.py | 33 +--------------------------------
1 file changed, 1 insertion(+), 32 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 6762a450..23036ebd 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -42,9 +42,6 @@ def __init__(
self.forcing = self.config_loader.process_dataset(
"forcing", self.split
)
- self.boundary = self.config_loader.process_dataset(
- "boundary", self.split
- )
self.state_times = self.state.time.values
self.forcing_window = self.config_loader.forcing.window
@@ -64,23 +61,6 @@ def __init__(
.construct("window")
)
- if self.boundary is not None:
- self.boundary_windowed = (
- self.boundary.sel(
- time=self.state.time,
- method="nearest",
- )
- .pad(
- time=(
- self.boundary_window // 2,
- self.boundary_window // 2,
- ),
- mode="edge",
- )
- .rolling(time=self.boundary_window, center=True)
- .construct("window")
- )
-
def __len__(self):
# Skip first and last time step
return len(self.state.time) - self.ar_steps
@@ -101,16 +81,6 @@ def __getitem__(self, idx):
else torch.tensor([])
)
- boundary = (
- self.boundary_windowed.isel(
- time=slice(idx + 2, idx + self.ar_steps)
- )
- .stack(variable_window=("variable", "window"))
- .values
- if self.boundary is not None
- else torch.tensor([])
- )
-
init_states = sample[:2]
target_states = sample[2:]
@@ -123,9 +93,8 @@ def __getitem__(self, idx):
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
# forcing: (ar_steps-2, N_grid, d_windowed_forcing)
- # boundary: (ar_steps-2, N_grid, d_windowed_boundary)
# batch_times: (ar_steps-2,)
- return init_states, target_states, forcing, boundary, batch_times
+ return init_states, target_states, forcing, batch_times
class WeatherDataModule(pl.LightningDataModule):
From 0cfbb33202ecfc57b98b88f9b2fd09c6dc0b683b Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 14:37:54 +0200
Subject: [PATCH 063/273] small updates
---
.gitignore | 6 ++----
requirements.txt | 1 +
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/.gitignore b/.gitignore
index bbe3571e..98ee425c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,15 +1,13 @@
### Project Specific ###
wandb
-slurm_log*
saved_models
data
graphs
*.sif
sweeps
-test_*.sh
.vscode
-cosmo_hilam.html
-normalization.zarr
+*.html
+*.zarr
### Python ###
# Byte-compiled / optimized / DLL files
diff --git a/requirements.txt b/requirements.txt
index 70b97330..ef86ab96 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,5 +13,6 @@ plotly>=5.15.0
xarray>=0.20.1
zarr>=2.10.0
dask>=2022.0.0
+pandas>=1.4.0
# for dev
pre-commit>=2.15.0
From 59d0c8a01fbac204b19abec39ee78f549e158dcc Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 2 Jun 2024 20:15:04 +0200
Subject: [PATCH 064/273] improved stats and units retrieval
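The grouped forcing statistics below pool each group's standard deviations as the root of the mean variance. A toy sketch with hypothetical values:

    import xarray as xr

    variables = ["icei0m", "vis0m"]
    forcing_mean = xr.DataArray(
        [1.0, 3.0], dims="variable", coords={"variable": variables}
    )
    forcing_std = xr.DataArray(
        [2.0, 4.0], dims="variable", coords={"variable": variables}
    )

    combined_mean = forcing_mean.sel(variable=variables).mean(dim="variable")
    combined_std = (forcing_std.sel(variable=variables) ** 2).mean(dim="variable") ** 0.5

    forcing_mean.loc[dict(variable=variables)] = combined_mean
    forcing_std.loc[dict(variable=variables)] = combined_std
    print(float(combined_mean), float(combined_std))  # 2.0 and ~3.16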
---
calculate_statistics.py | 58 ++++++++++----------
neural_lam/config.py | 106 +++++++++++++++++++-----------------
neural_lam/data_config.yaml | 33 ++++++++---
3 files changed, 110 insertions(+), 87 deletions(-)
diff --git a/calculate_statistics.py b/calculate_statistics.py
index 94267b31..3765df2e 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -15,9 +15,6 @@ def compute_stats(data_array):
def main():
- """
- Pre-compute parameter weights to be used in loss function
- """
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--data_config",
@@ -31,16 +28,9 @@ def main():
default="normalization.zarr",
help="Directory where data is stored",
)
- parser.add_argument(
- "--combined_forcings",
- action="store_true",
- help="Whether to compute combined stats forcing variables",
- )
-
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
-
state_data = config_loader.process_dataset("state", split="train")
forcing_data = config_loader.process_dataset("forcing", split="train")
@@ -49,38 +39,48 @@ def main():
if forcing_data is not None:
forcing_mean, forcing_std = compute_stats(forcing_data)
- if args.combined_forcings:
- forcing_mean = forcing_mean.mean(dim="variable")
- forcing_std = forcing_std.mean(dim="variable")
+ combined_stats = config_loader["utilities"]["normalization"][
+ "combined_stats"
+ ]
+
+ if combined_stats is not None:
+ for group in combined_stats:
+ vars_to_combine = group["vars"]
+ means = forcing_mean.sel(variable=vars_to_combine)
+ stds = forcing_std.sel(variable=vars_to_combine)
+
+ combined_mean = means.mean(dim="variable")
+ combined_std = (stds**2).mean(dim="variable") ** 0.5
+
+ forcing_mean.loc[
+ dict(variable=vars_to_combine)
+ ] = combined_mean
+ forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
print(
"Computing mean and std.-dev. for one-step differences...", flush=True
)
- state_data_diff = state_data.diff(dim="time")
- diff_mean, diff_std = compute_stats(state_data_diff)
+ state_data_normalized = (state_data - state_mean) / state_std
+ state_data_diff_normalized = state_data_normalized.diff(dim="time")
+ diff_mean, diff_std = compute_stats(state_data_diff_normalized)
ds = xr.Dataset(
{
- "state_mean": (["d_features"], state_mean.data),
- "state_std": (["d_features"], state_std.data),
- "diff_mean": (["d_features"], diff_mean.data),
- "diff_std": (["d_features"], diff_std.data),
+ "state_mean": state_mean,
+ "state_std": state_std,
+ "diff_mean": diff_mean,
+ "diff_std": diff_std,
}
)
if forcing_data is not None:
dsf = xr.Dataset(
{
- "forcing_mean": (["d_forcings"], forcing_mean.data),
- "forcing_std": (["d_forcings"], forcing_std.data),
+ "forcing_mean": forcing_mean,
+ "forcing_std": forcing_std,
}
)
- ds = xr.merge(
- [ds, dsf],
- )
- # Save dataset as Zarr
+ ds = xr.merge([ds, dsf])
+
print("Saving dataset as Zarr...")
+ ds = ds.chunk({"variable": -1})
ds.to_zarr(args.zarr_path, mode="w")
-
-
-if __name__ == "__main__":
- main()
diff --git a/neural_lam/config.py b/neural_lam/config.py
index c81cb02b..8f565675 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -55,25 +55,25 @@ def coords_projection(self):
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
- @functools.cached_property
- def vars_names(self):
+ @functools.lru_cache()
+ def vars_names(self, category):
"""Return the names of the variables in the dataset."""
- surface_vars_names = self.values["state"]["surface_vars"]
+ surface_vars_names = self.values[category].get("surface_vars") or []
atmosphere_vars_names = [
f"{var}_{level}"
- for var in self.values["state"]["atmosphere_vars"]
- for level in self.values["state"]["levels"]
+ for var in (self.values[category].get("atmosphere_vars") or [])
+ for level in (self.values[category].get("levels") or [])
]
return surface_vars_names + atmosphere_vars_names
- @functools.cached_property
- def vars_units(self):
+ @functools.lru_cache()
+ def vars_units(self, category):
"""Return the units of the variables in the dataset."""
- surface_vars_units = self.values["state"]["surface_units"]
+ surface_vars_units = self.values[category].get("surface_units") or []
atmosphere_vars_units = [
unit
- for unit in self.values["state"]["atmosphere_units"]
- for _ in self.values["state"]["levels"]
+ for unit in (self.values[category].get("atmosphere_units") or [])
+ for _ in (self.values[category].get("levels") or [])
]
return surface_vars_units + atmosphere_vars_units
@@ -130,12 +130,9 @@ def convert_dataset_to_dataarray(self, dataset):
"""Convert the Dataset to a Dataarray."""
if isinstance(dataset, xr.Dataset):
dataset = dataset.to_array()
- print(
- "\033[92mSuccessfully converted Dataset to Dataarray.\033[0m"
- )
- return dataset.to_array()
+ return dataset
- def filter_dimensions(self, dataset, transpose_array=False):
+ def filter_dimensions(self, dataset, transpose_array=True):
"""Filter the dimensions of the dataset."""
dims_to_keep = self.DIMS_TO_KEEP
dataset_dims = set(dataset.to_array().dims)
@@ -190,15 +187,9 @@ def filter_dimensions(self, dataset, transpose_array=False):
else:
dataset = dataset.transpose("grid", "variable")
dataset_vars = (
- dataset["variable"].values.tolist()
- if transpose_array
- else list(dataset.data_vars)
- )
-
- print(
- "\033[94mYour Dataarray has the following dimensions: ",
- dataset.to_array().dims,
- "\033[0m",
+ list(dataset.data_vars)
+ if isinstance(dataset, xr.Dataset)
+ else dataset["variable"].values.tolist()
)
print(
"\033[94mYour Dataarray has the following variables: ",
@@ -240,26 +231,49 @@ def get_xy(self, category):
return xy
- @functools.cached_property
- def load_normalization_stats(self):
+ @functools.lru_cache()
+ def load_normalization_stats(self, category):
"""Load the normalization statistics for the dataset."""
for i, zarr_config in enumerate(
self.values["utilities"]["normalization"]["zarrs"]
):
- normalization_path = zarr_config["path"]
- if not os.path.exists(normalization_path):
+ stats_path = zarr_config["path"]
+ if not os.path.exists(stats_path):
print(
f"Normalization statistics not found at path: "
- f"{normalization_path}"
+ f"{stats_path}"
)
return None
- stats = xr.open_zarr(normalization_path, consolidated=True)
+ stats = xr.open_zarr(stats_path, consolidated=True)
if i == 0:
- normalization_stats = stats
+ combined_stats = stats
else:
- stats = xr.merge([stats, normalization_stats])
- normalization_stats = stats
- return normalization_stats
+ stats = xr.merge([stats, combined_stats])
+ combined_stats = stats
+
+ # Rename data variables
+ vars_mapping = {}
+ zarr_configs = self.values["utilities"]["normalization"]["zarrs"]
+ for zarr_config in zarr_configs:
+ vars_mapping.update(zarr_config["stats_vars"])
+
+ combined_stats = combined_stats.rename_vars(
+ {
+ v: k
+ for k, v in vars_mapping.items()
+ if v in list(combined_stats.data_vars)
+ }
+ )
+
+ stats = combined_stats.loc[dict(variable=self.vars_names(category))]
+ if category == "state":
+ stats = stats.drop_vars(["forcing_mean", "forcing_std"])
+ elif category == "forcing":
+ stats = stats[["forcing_mean", "forcing_std"]]
+ else:
+ print(f"Invalid category: {category}")
+ return None
+ return stats
# def assign_lat_lon_coords(self, category, dataset=None):
# """Process the latitude and longitude names of the dataset."""
@@ -322,8 +336,12 @@ def extract_vars(self, category, dataset=None):
def rename_dataset_dims_and_vars(self, category, dataset=None):
"""Rename the dimensions and variables of the dataset."""
+ convert = False
if dataset is None:
dataset = self.open_zarr(category)
+ elif isinstance(dataset, xr.DataArray):
+ convert = True
+ dataset = dataset.to_dataset("variable")
dims_mapping = {}
zarr_configs = self.values[category]["zarrs"]
for zarr_config in zarr_configs:
@@ -339,6 +357,8 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
dataset = dataset.rename_vars(
{v: k for k, v in dims_mapping.items() if v in dataset.coords}
)
+ if convert:
+ dataset = dataset.to_array()
return dataset
def filter_dataset_by_time(self, dataset, split="train"):
@@ -351,28 +371,12 @@ def filter_dataset_by_time(self, dataset, split="train"):
def process_dataset(self, category, split="train"):
"""Process the dataset for the given category."""
- print(f"Opening zarr dataset for category: {category}")
dataset = self.open_zarr(category)
-
- print(f"Extracting variables for category: {category}")
dataset = self.extract_vars(category, dataset)
-
- print(f"Filtering dataset by time for split: {split}")
dataset = self.filter_dataset_by_time(dataset, split)
-
- print("Stacking grid dimensions of the dataset")
dataset = self.stack_grid(dataset)
-
- print("Filtering dimensions of the dataset")
- dataset = self.filter_dimensions(dataset)
-
- print(
- "Renaming dataset dimensions and "
- "variables for category: {category}"
- )
dataset = self.rename_dataset_dims_and_vars(category, dataset)
-
- print("Converting dataset to data array")
+ dataset = self.filter_dimensions(dataset)
dataset = self.convert_dataset_to_dataarray(dataset)
return dataset
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 1b7c5ebe..ae4562b7 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -39,10 +39,11 @@ state:
- K
levels:
- 100
-static:
+forcing:
zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
+ time: time
level: null
x: x
y: y
@@ -51,14 +52,23 @@ static:
lon: lon
lat: lat
surface_vars:
- - pres0m # just as a technical test
+ - cape_column # just as a technical test
+ - icei0m
+ - vis0m
+ - xhail0m
+ surface_units:
+ - J/kg
+ - kg/m^2 # just as a technical test :)
+ - m
+ - m
atmosphere_vars: null
+ atmosphere_units: null
levels: null
-forcing:
+ window: 3 # Number of time steps to use for forcing (odd)
+static:
zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
- time: time
level: null
x: x
y: y
@@ -67,13 +77,15 @@ forcing:
lon: lon
lat: lat
surface_vars:
- - cape_column # just as a technical test
+ - pres0m # just as a technical test
+ surface_units:
+ - Pa
atmosphere_vars: null
+ atmosphere_units: null
levels: null
- window: 3 # Number of time steps to use for forcing (odd)
boundary:
zarrs:
- - path: /scratch/sadamov/era5_template.zarr
+ - path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
dims:
time: time
level: level
@@ -94,6 +106,13 @@ utilities:
forcing_std: forcing_std
diff_mean: diff_mean
diff_std: diff_std
+ combined_stats:
+ - vars:
+ - icei0m
+ - vis0m
+ - vars:
+ - cape_column
+ - xhail0m
boundary_mask:
zarrs:
- path: "boundary.zarr"
From 2f6a87a3c4cc38df042d1d0be582be311d021036 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 11:33:40 +0100
Subject: [PATCH 065/273] add GPU-based runner on cirun.io
---
.cirun.yml | 14 ++++++++++
.../workflows/ci-pip-install-and-test-gpu.yml | 27 +++++++++++++++++++
2 files changed, 41 insertions(+)
create mode 100644 .cirun.yml
create mode 100644 .github/workflows/ci-pip-install-and-test-gpu.yml
diff --git a/.cirun.yml b/.cirun.yml
new file mode 100644
index 00000000..79d62f22
--- /dev/null
+++ b/.cirun.yml
@@ -0,0 +1,14 @@
+# setup for using github runners via https://cirun.io/
+runners:
+ - name: "aws-runner"
+ # Cloud Provider: AWS
+ cloud: "aws"
+ # https://aws.amazon.com/ec2/instance-types/g4/
+ instance_type: "g4dn.xlarge"
+ # Ubuntu-20.4, ami image
+ machine_image: "ami-06fd8a495a537da8b"
+ preemptible: false
+ # Add this label in the "runs-on" param in .github/workflows/.yml
+ # So that this runner is created for running the workflow
+ labels:
+ - "cirun-aws-runner"
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
new file mode 100644
index 00000000..2cc168f0
--- /dev/null
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -0,0 +1,27 @@
+# cicd workflow for running tests with pytest
+# needs to first install pdm, then install torch cpu manually and then install the package
+# then run the tests
+
+name: test (pip install, gpu)
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: "cirun-aws-runner--${{ github.run_id }}"
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Install torch (GPU CUDA 12.1)
+ run: |
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cu121
+
+ - name: Install package (including dev dependencies)
+ run: |
+ python -m pip install .
+ python -m pip install pytest
+
+ - name: Run tests
+ run: |
+ python -m pytest
From 668dd8162c64c19b93f2d34bbdb1095dde249d88 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Mon, 3 Jun 2024 16:08:56 +0200
Subject: [PATCH 066/273] improved zarr-based normalization
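A toy sketch of the forcing windowing that apply_window() below performs (synthetic data, hypothetical variable names; it skips the nearest-time selection and uses a time dimension without coordinates for brevity):

    import numpy as np
    import xarray as xr

    forcing = xr.DataArray(
        np.random.rand(6, 4, 2),
        dims=("time", "grid", "variable"),
        coords={"variable": ["cape_column", "xhail0m"]},
    )

    window = 3
    windowed = (
        forcing.pad(time=(window // 2, window // 2), mode="edge")
        .rolling(time=window, center=True)
        .construct("window")
        .stack(variable_window=("variable", "window"))
    )
    # time is padded by window//2 on each side; variables and window are stacked
    print(dict(windowed.sizes))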
---
calculate_statistics.py | 37 ++++++++++++++++++-----
neural_lam/config.py | 40 ++++++++++++++++++++++--
neural_lam/data_config.yaml | 1 +
neural_lam/weather_dataset.py | 57 +++++++++++++++++++++++------------
train_model.py | 1 +
5 files changed, 106 insertions(+), 30 deletions(-)
diff --git a/calculate_statistics.py b/calculate_statistics.py
index 3765df2e..90d3dbc0 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -32,7 +32,9 @@ def main():
config_loader = config.Config.from_file(args.data_config)
state_data = config_loader.process_dataset("state", split="train")
- forcing_data = config_loader.process_dataset("forcing", split="train")
+ forcing_data = config_loader.process_dataset(
+ "forcing", split="train", apply_windowing=False
+ )
print("Computing mean and std.-dev. for parameters...", flush=True)
state_mean, state_std = compute_stats(state_data)
@@ -56,6 +58,16 @@ def main():
dict(variable=vars_to_combine)
] = combined_mean
forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
+ window = config_loader["forcing"]["window"]
+ forcing_mean = xr.concat([forcing_mean] * window, dim="window").stack(
+ forcing_variable=("variable", "window")
+ )
+ forcing_std = xr.concat([forcing_std] * window, dim="window").stack(
+ forcing_variable=("variable", "window")
+ )
+ vars = forcing_data["variable"].values.tolist()
+ window = config_loader["forcing"]["window"]
+ forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
print(
"Computing mean and std.-dev. for one-step differences...", flush=True
@@ -73,14 +85,25 @@ def main():
}
)
if forcing_data is not None:
- dsf = xr.Dataset(
- {
- "forcing_mean": forcing_mean,
- "forcing_std": forcing_std,
- }
+ dsf = (
+ xr.Dataset(
+ {
+ "forcing_mean": forcing_mean,
+ "forcing_std": forcing_std,
+ }
+ )
+ .reset_index(["forcing_variable"])
+ .drop_vars(["variable", "window"])
+ .assign_coords(forcing_variable=forcing_vars)
)
ds = xr.merge([ds, dsf])
+ print(ds)
+
+ ds = ds.chunk({"variable": -1, "forcing_variable": -1})
print("Saving dataset as Zarr...")
- ds = ds.chunk({"variable": -1})
ds.to_zarr(args.zarr_path, mode="w")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 8f565675..d411bd1e 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -7,6 +7,7 @@
import cartopy.crs as ccrs
import numpy as np
import pandas as pd
+import torch
import xarray as xr
import yaml
@@ -232,7 +233,7 @@ def get_xy(self, category):
return xy
@functools.lru_cache()
- def load_normalization_stats(self, category):
+ def load_normalization_stats(self, category, datatype="torch"):
"""Load the normalization statistics for the dataset."""
for i, zarr_config in enumerate(
self.values["utilities"]["normalization"]["zarrs"]
@@ -265,14 +266,30 @@ def load_normalization_stats(self, category):
}
)
- stats = combined_stats.loc[dict(variable=self.vars_names(category))]
if category == "state":
+ stats = combined_stats.loc[
+ dict(variable=self.vars_names(category))
+ ]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
elif category == "forcing":
+ vars = self.vars_names(category)
+ window = self["forcing"]["window"]
+ forcing_vars = [
+ f"{var}_{i}" for var in vars for i in range(window)
+ ]
+ stats = combined_stats.loc[dict(forcing_variable=forcing_vars)]
stats = stats[["forcing_mean", "forcing_std"]]
else:
print(f"Invalid category: {category}")
return None
+
+ if datatype == "torch":
+ stats_dict = {
+ var: torch.tensor(stats[var].values, dtype=torch.float32)
+ for var in stats.data_vars
+ }
+ return stats_dict
+
return stats
# def assign_lat_lon_coords(self, category, dataset=None):
@@ -369,7 +386,7 @@ def filter_dataset_by_time(self, dataset, split="train"):
)
return dataset.sel(time=slice(start, end))
- def process_dataset(self, category, split="train"):
+ def process_dataset(self, category, split="train", apply_windowing=True):
"""Process the dataset for the given category."""
dataset = self.open_zarr(category)
dataset = self.extract_vars(category, dataset)
@@ -378,5 +395,22 @@ def process_dataset(self, category, split="train"):
dataset = self.rename_dataset_dims_and_vars(category, dataset)
dataset = self.filter_dimensions(dataset)
dataset = self.convert_dataset_to_dataarray(dataset)
+ if "window" in self.values[category] and apply_windowing:
+ dataset = self.apply_window(category, dataset)
+
+ return dataset
+ def apply_window(self, category, dataset=None):
+ """Apply the forcing window to the forcing dataset."""
+ if dataset is None:
+ dataset = self.open_zarr(category)
+ state_time = self.open_zarr("state").time.values
+ window = self[category].window
+ dataset = (
+ dataset.sel(time=state_time, method="nearest")
+ .pad(time=(window // 2, window // 2), mode="edge")
+ .rolling(time=window, center=True)
+ .construct("window")
+ .stack(variable_window=("variable", "window"))
+ )
return dataset
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index ae4562b7..2f7261c0 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -95,6 +95,7 @@ boundary:
lon: longitude
lat: latitude
mask: boundary_mask
+ window: 3
utilities:
normalization:
zarrs:
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 23036ebd..c25b0452 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -20,6 +20,7 @@ def __init__(
split="train",
ar_steps=3,
batch_size=4,
+ standardize=True,
control_only=False,
data_config="neural_lam/data_config.yaml",
):
@@ -35,31 +36,35 @@ def __init__(
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.config_loader = config.Config(data_config)
+ self.config_loader = config.Config.from_file(data_config)
self.state = self.config_loader.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
self.forcing = self.config_loader.process_dataset(
"forcing", self.split
)
-
self.state_times = self.state.time.values
- self.forcing_window = self.config_loader.forcing.window
- self.boundary_window = self.config_loader.boundary.window
-
- if self.forcing is not None:
- self.forcing_windowed = (
- self.forcing.sel(
- time=self.state.time,
- method="nearest",
+
+ # Set up for standardization
+ # NOTE: This will become part of ar_model.py soon!
+ self.standardize = standardize
+ if standardize:
+ state_stats = self.config_loader.load_normalization_stats(
+ "state", datatype="torch"
+ )
+ self.state_mean, self.state_std = (
+ state_stats["state_mean"],
+ state_stats["state_std"],
+ )
+
+ if self.forcing is not None:
+ forcing_stats = self.config_loader.load_normalization_stats(
+ "forcing", datatype="torch"
)
- .pad(
- time=(self.forcing_window // 2, self.forcing_window // 2),
- mode="edge",
+ self.forcing_mean, self.forcing_std = (
+ forcing_stats["forcing_mean"],
+ forcing_stats["forcing_std"],
)
- .rolling(time=self.forcing_window, center=True)
- .construct("window")
- )
def __len__(self):
# Skip first and last time step
@@ -72,11 +77,11 @@ def __getitem__(self, idx):
)
forcing = (
- self.forcing_windowed.isel(
- time=slice(idx + 2, idx + self.ar_steps)
+ torch.tensor(
+ self.forcing.isel(
+ time=slice(idx + 2, idx + self.ar_steps)
+ ).values
)
- .stack(variable_window=("variable", "window"))
- .values
if self.forcing is not None
else torch.tensor([])
)
@@ -90,6 +95,13 @@ def __getitem__(self, idx):
.tolist()
)
+ if self.standardize:
+ init_states = (init_states - self.state_mean) / self.state_std
+ target_states = (target_states - self.state_mean) / self.state_std
+
+ if self.forcing is not None:
+ forcing = (forcing - self.forcing_mean) / self.forcing_std
+
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
# forcing: (ar_steps-2, N_grid, d_windowed_forcing)
@@ -104,12 +116,14 @@ def __init__(
self,
ar_steps_train=3,
ar_steps_eval=25,
+ standardize=True,
batch_size=4,
num_workers=16,
):
super().__init__()
self.ar_steps_train = ar_steps_train
self.ar_steps_eval = ar_steps_eval
+ self.standardize = standardize
self.batch_size = batch_size
self.num_workers = num_workers
self.train_dataset = None
@@ -121,11 +135,13 @@ def setup(self, stage=None):
self.train_dataset = WeatherDataset(
split="train",
ar_steps=self.ar_steps_train,
+ standardize=self.standardize,
batch_size=self.batch_size,
)
self.val_dataset = WeatherDataset(
split="val",
ar_steps=self.ar_steps_eval,
+ standardize=self.standardize,
batch_size=self.batch_size,
)
@@ -133,6 +149,7 @@ def setup(self, stage=None):
self.test_dataset = WeatherDataset(
split="test",
ar_steps=self.ar_steps_eval,
+ standardize=self.standardize,
batch_size=self.batch_size,
)
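The standardization added to __getitem__ is a plain per-feature affine transform: the mean and std returned as torch tensors are 1-D over the (windowed) feature dimension and broadcast against the trailing axis of each sample tensor. A minimal sketch with made-up shapes (not the real grid or feature counts):

# Minimal sketch of the (x - mean) / std step from __getitem__ above.
# Shapes are illustrative only: 2 init steps, 10 grid nodes, 4 features.
import torch

state_mean = torch.zeros(4)
state_std = torch.ones(4)

init_states = torch.randn(2, 10, 4)
init_states_std = (init_states - state_mean) / state_std  # broadcasts over last dim
assert init_states_std.shape == (2, 10, 4)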
diff --git a/train_model.py b/train_model.py
index 0cbb2e82..1b985ef0 100644
--- a/train_model.py
+++ b/train_model.py
@@ -238,6 +238,7 @@ def main():
data_module = WeatherDataModule(
ar_steps_train=args.ar_steps_train,
ar_steps_eval=args.ar_steps_eval,
+ standardize=True,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
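For completeness, a hedged sketch of exercising the dataset directly with standardization enabled. This assumes the zarr stores referenced by the data config are available; the per-sample shapes follow the comments in __getitem__ above.

# Sketch only: requires the zarr stores referenced in data_config.yaml.
from neural_lam.weather_dataset import WeatherDataset

dataset = WeatherDataset(
    split="train",
    ar_steps=3,
    standardize=True,
    batch_size=4,
)
sample = dataset[0]
# As documented in __getitem__:
#   init_states:   (2, N_grid, d_features)
#   target_states: (ar_steps-2, N_grid, d_features)
#   forcing:       (ar_steps-2, N_grid, d_windowed_forcing)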
From 143cf2aee531d4ea0baa3fbb0cdf4e2b2b7ffcf7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:02:57 +0200
Subject: [PATCH 067/273] pdm install with cpu torch
---
.github/workflows/ci-pdm-install-and-test.yml | 13 +-
.gitignore | 2 +
pdm.lock | 2015 -----------------
pyproject.toml | 1 +
4 files changed, 12 insertions(+), 2019 deletions(-)
delete mode 100644 pdm.lock
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 20b5fc14..f13552b9 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -14,14 +14,19 @@ jobs:
uses: actions/checkout@v2
- name: Install pdm
- uses: pdm-project/setup-pdm@v4
- with:
- python-version: "3.10"
- cache: true
+ run: |
+ python -m pip install pdm
+
+ - name: Create venv
+ run: |
+ pdm venv create --with-pip
+ pdm use --venv in-project
- name: Install torch (CPU)
run: |
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+ # check that the CPU version is installed
+        python -c "import torch; assert torch.__version__.endswith('+cpu')"
- name: Install package (including dev dependencies)
run: |
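As an aside, the assertion above relies on the '+cpu' local-version tag of the CPU wheels. An equivalent check that does not parse the version string is to assert that the installed build was compiled without CUDA support:

# Alternative CPU-only check; torch.version.cuda is None for CPU-only builds.
import torch

assert torch.version.cuda is None, "expected a CPU-only torch build"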
diff --git a/.gitignore b/.gitignore
index ede00bca..2a12cf57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -76,3 +76,5 @@ tags
# pdm (https://pdm-project.org/en/stable/)
.pdm-python
+# exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm
+pdm.lock
diff --git a/pdm.lock b/pdm.lock
deleted file mode 100644
index 6ea24bcf..00000000
--- a/pdm.lock
+++ /dev/null
@@ -1,2015 +0,0 @@
-# This file is @generated by PDM.
-# It is not intended for manual editing.
-
-[metadata]
-groups = ["default", "dev"]
-strategy = ["cross_platform", "inherit_metadata"]
-lock_version = "4.4.1"
-content_hash = "sha256:c4f5df1487409a1cd6d45a6155c3aff846c7deca9787b9e0003e2d850a4f27c8"
-
-[[package]]
-name = "aiohttp"
-version = "3.9.5"
-requires_python = ">=3.8"
-summary = "Async http client/server framework (asyncio)"
-groups = ["default"]
-dependencies = [
- "aiosignal>=1.1.2",
- "async-timeout<5.0,>=4.0; python_version < \"3.11\"",
- "attrs>=17.3.0",
- "frozenlist>=1.1.1",
- "multidict<7.0,>=4.5",
- "yarl<2.0,>=1.0",
-]
-files = [
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"},
- {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"},
- {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"},
- {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
- {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
- {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"},
- {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"},
- {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"},
- {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
-]
-
-[[package]]
-name = "aiosignal"
-version = "1.3.1"
-requires_python = ">=3.7"
-summary = "aiosignal: a list of registered asynchronous callbacks"
-groups = ["default"]
-dependencies = [
- "frozenlist>=1.1.0",
-]
-files = [
- {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
- {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
-]
-
-[[package]]
-name = "async-timeout"
-version = "4.0.3"
-requires_python = ">=3.7"
-summary = "Timeout context manager for asyncio programs"
-groups = ["default"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
- {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
-]
-
-[[package]]
-name = "attrs"
-version = "23.2.0"
-requires_python = ">=3.7"
-summary = "Classes Without Boilerplate"
-groups = ["default"]
-files = [
- {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
- {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
-]
-
-[[package]]
-name = "cartopy"
-version = "0.23.0"
-requires_python = ">=3.9"
-summary = "A Python library for cartographic visualizations with Matplotlib"
-groups = ["default"]
-dependencies = [
- "matplotlib>=3.5",
- "numpy>=1.21",
- "packaging>=20",
- "pyproj>=3.3.1",
- "pyshp>=2.3",
- "shapely>=1.7",
-]
-files = [
- {file = "Cartopy-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:374e66f816c3bafa48ffdbf6abaefa67063b405fac5f425f9be241cdf3498352"},
- {file = "Cartopy-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2bae450c4c913796cad0b7ce05aa2fa78d1788de47989f0a03183397648e24be"},
- {file = "Cartopy-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40437596e8ac5e74575eab822c661f4e725bd995cfd9e445069695fe9086b42"},
- {file = "Cartopy-0.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:3292d6d403137eed80d32014c2f28de6282bed8824213f4b4c2170f388b24a1b"},
- {file = "Cartopy-0.23.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86b07b6794b616674e4e485b8574e9197bca54a4467d28dd01ae0bf178f8dc2b"},
- {file = "Cartopy-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8dece2aa8d5ff7bf989ded6b5f07c980fb5bb772952bc7cdeab469738abdecee"},
- {file = "Cartopy-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9dfd28352dc83d6b4e4cf85d84cb50fc4886d4c1510d61f4c7cf22477d1156f"},
- {file = "Cartopy-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:b2671b5354e43220f8e1074e7fe30a8b9f71cb38407c78e51db9c97772f0320b"},
- {file = "Cartopy-0.23.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:80b9fd666fd47f6370d29f7ad4e352828d54aaf688a03d0b83b51e141cfd77fa"},
- {file = "Cartopy-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:43e36b8b7e7e373a5698757458fd28fafbbbf5f3ebbe2d378f6a5ec3993d6dc0"},
- {file = "Cartopy-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550173b91155d4d81cd14b4892cb6cabe3dd32bd34feacaa1ec78c0e56287832"},
- {file = "Cartopy-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:55219ee0fb069cc3254426e87382cde03546e86c3f7c6759f076823b1e3a44d9"},
- {file = "Cartopy-0.23.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6279af846bf77d9817ab8792a8e38ca561878f048bba1afdae3e3a30c5432bfd"},
- {file = "Cartopy-0.23.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843bf9dc0a18e1a8eed872c49e8092e8a8109e4dce285ad96752841e21e8161e"},
- {file = "Cartopy-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:350ff8802e2bc617c09bd6148aeb46e841775a846bfaa6e635a212d1eaf5ab66"},
- {file = "Cartopy-0.23.0-cp39-cp39-win_amd64.whl", hash = "sha256:b52ab2274ad7504955854ef8d6f603e41f5d7163d02b29d369cecdbd29c2fda1"},
- {file = "Cartopy-0.23.0.tar.gz", hash = "sha256:231f37b35701f2ba31d94959cca75e6da04c2eea3a7f14ce1c75ee3b0eae7676"},
-]
-
-[[package]]
-name = "certifi"
-version = "2024.2.2"
-requires_python = ">=3.6"
-summary = "Python package for providing Mozilla's CA Bundle."
-groups = ["default"]
-files = [
- {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
- {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
-]
-
-[[package]]
-name = "cfgv"
-version = "3.4.0"
-requires_python = ">=3.8"
-summary = "Validate configuration and produce human readable error messages."
-groups = ["dev"]
-files = [
- {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
- {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
-]
-
-[[package]]
-name = "charset-normalizer"
-version = "3.3.2"
-requires_python = ">=3.7.0"
-summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
-groups = ["default"]
-files = [
- {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"},
- {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"},
- {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
-]
-
-[[package]]
-name = "click"
-version = "8.1.7"
-requires_python = ">=3.7"
-summary = "Composable command line interface toolkit"
-groups = ["default"]
-dependencies = [
- "colorama; platform_system == \"Windows\"",
-]
-files = [
- {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
- {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
-]
-
-[[package]]
-name = "colorama"
-version = "0.4.6"
-requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
-summary = "Cross-platform colored terminal text."
-groups = ["default", "dev"]
-marker = "sys_platform == \"win32\" or platform_system == \"Windows\""
-files = [
- {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
- {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
-]
-
-[[package]]
-name = "contourpy"
-version = "1.2.1"
-requires_python = ">=3.9"
-summary = "Python library for calculating contours of 2D quadrilateral grids"
-groups = ["default"]
-dependencies = [
- "numpy>=1.20",
-]
-files = [
- {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"},
- {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"},
- {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"},
- {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"},
- {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"},
- {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"},
- {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"},
- {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"},
- {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"},
- {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"},
- {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"},
- {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"},
- {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"},
- {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"},
- {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"},
- {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"},
- {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"},
- {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"},
- {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"},
- {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"},
- {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"},
- {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"},
- {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"},
- {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"},
- {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"},
- {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"},
- {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"},
- {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"},
- {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"},
-]
-
-[[package]]
-name = "cycler"
-version = "0.12.1"
-requires_python = ">=3.8"
-summary = "Composable style cycles"
-groups = ["default"]
-files = [
- {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"},
- {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"},
-]
-
-[[package]]
-name = "distlib"
-version = "0.3.8"
-summary = "Distribution utilities"
-groups = ["dev"]
-files = [
- {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
- {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
-]
-
-[[package]]
-name = "docker-pycreds"
-version = "0.4.0"
-summary = "Python bindings for the docker credentials store API"
-groups = ["default"]
-dependencies = [
- "six>=1.4.0",
-]
-files = [
- {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
- {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
-]
-
-[[package]]
-name = "exceptiongroup"
-version = "1.2.1"
-requires_python = ">=3.7"
-summary = "Backport of PEP 654 (exception groups)"
-groups = ["dev"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"},
- {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"},
-]
-
-[[package]]
-name = "filelock"
-version = "3.14.0"
-requires_python = ">=3.8"
-summary = "A platform independent file lock."
-groups = ["default", "dev"]
-files = [
- {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
- {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
-]
-
-[[package]]
-name = "fonttools"
-version = "4.51.0"
-requires_python = ">=3.8"
-summary = "Tools to manipulate font files"
-groups = ["default"]
-files = [
- {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"},
- {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"},
- {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037"},
- {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716"},
- {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438"},
- {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039"},
- {file = "fonttools-4.51.0-cp310-cp310-win32.whl", hash = "sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77"},
- {file = "fonttools-4.51.0-cp310-cp310-win_amd64.whl", hash = "sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b"},
- {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74"},
- {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2"},
- {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f"},
- {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097"},
- {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0"},
- {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1"},
- {file = "fonttools-4.51.0-cp311-cp311-win32.whl", hash = "sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034"},
- {file = "fonttools-4.51.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1"},
- {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba"},
- {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc"},
- {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a"},
- {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2"},
- {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671"},
- {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5"},
- {file = "fonttools-4.51.0-cp312-cp312-win32.whl", hash = "sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15"},
- {file = "fonttools-4.51.0-cp312-cp312-win_amd64.whl", hash = "sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e"},
- {file = "fonttools-4.51.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b"},
- {file = "fonttools-4.51.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936"},
- {file = "fonttools-4.51.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55"},
- {file = "fonttools-4.51.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce"},
- {file = "fonttools-4.51.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051"},
- {file = "fonttools-4.51.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7"},
- {file = "fonttools-4.51.0-cp39-cp39-win32.whl", hash = "sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636"},
- {file = "fonttools-4.51.0-cp39-cp39-win_amd64.whl", hash = "sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a"},
- {file = "fonttools-4.51.0-py3-none-any.whl", hash = "sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f"},
- {file = "fonttools-4.51.0.tar.gz", hash = "sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68"},
-]
-
-[[package]]
-name = "frozenlist"
-version = "1.4.1"
-requires_python = ">=3.8"
-summary = "A list-like structure which implements collections.abc.MutableSequence"
-groups = ["default"]
-files = [
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
- {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
- {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
- {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
- {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
- {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
- {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"},
- {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"},
- {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"},
- {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
- {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
-]
-
-[[package]]
-name = "fsspec"
-version = "2024.5.0"
-requires_python = ">=3.8"
-summary = "File-system specification"
-groups = ["default"]
-files = [
- {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"},
- {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"},
-]
-
-[[package]]
-name = "fsspec"
-version = "2024.5.0"
-extras = ["http"]
-requires_python = ">=3.8"
-summary = "File-system specification"
-groups = ["default"]
-dependencies = [
- "aiohttp!=4.0.0a0,!=4.0.0a1",
- "fsspec==2024.5.0",
-]
-files = [
- {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"},
- {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"},
-]
-
-[[package]]
-name = "gitdb"
-version = "4.0.11"
-requires_python = ">=3.7"
-summary = "Git Object Database"
-groups = ["default"]
-dependencies = [
- "smmap<6,>=3.0.1",
-]
-files = [
- {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
- {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
-]
-
-[[package]]
-name = "gitpython"
-version = "3.1.43"
-requires_python = ">=3.7"
-summary = "GitPython is a Python library used to interact with Git repositories"
-groups = ["default"]
-dependencies = [
- "gitdb<5,>=4.0.1",
-]
-files = [
- {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
- {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
-]
-
-[[package]]
-name = "identify"
-version = "2.5.36"
-requires_python = ">=3.8"
-summary = "File identification library for Python"
-groups = ["dev"]
-files = [
- {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"},
- {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"},
-]
-
-[[package]]
-name = "idna"
-version = "3.7"
-requires_python = ">=3.5"
-summary = "Internationalized Domain Names in Applications (IDNA)"
-groups = ["default"]
-files = [
- {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"},
- {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
-]
-
-[[package]]
-name = "importlib-resources"
-version = "6.4.0"
-requires_python = ">=3.8"
-summary = "Read resources from Python packages"
-groups = ["default"]
-marker = "python_version < \"3.10\""
-dependencies = [
- "zipp>=3.1.0; python_version < \"3.10\"",
-]
-files = [
- {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"},
- {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"},
-]
-
-[[package]]
-name = "iniconfig"
-version = "2.0.0"
-requires_python = ">=3.7"
-summary = "brain-dead simple config-ini parsing"
-groups = ["dev"]
-files = [
- {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
- {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
-]
-
-[[package]]
-name = "intel-openmp"
-version = "2021.4.0"
-summary = "Intel® OpenMP* Runtime Library"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-files = [
- {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
-]
-
-[[package]]
-name = "jinja2"
-version = "3.1.4"
-requires_python = ">=3.7"
-summary = "A very fast and expressive template engine."
-groups = ["default"]
-dependencies = [
- "MarkupSafe>=2.0",
-]
-files = [
- {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
- {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
-]
-
-[[package]]
-name = "kiwisolver"
-version = "1.4.5"
-requires_python = ">=3.7"
-summary = "A fast implementation of the Cassowary constraint solver"
-groups = ["default"]
-files = [
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"},
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"},
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"},
- {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"},
- {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"},
- {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"},
- {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"},
- {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"},
- {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"},
- {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"},
- {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"},
- {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"},
- {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"},
- {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"},
- {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"},
- {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"},
- {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"},
- {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"},
- {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"},
- {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"},
- {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"},
- {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"},
- {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"},
- {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"},
- {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
-]
-
-[[package]]
-name = "lightning-utilities"
-version = "0.11.2"
-requires_python = ">=3.8"
-summary = "Lightning toolbox for across the our ecosystem."
-groups = ["default"]
-dependencies = [
- "packaging>=17.1",
- "setuptools",
- "typing-extensions",
-]
-files = [
- {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"},
- {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"},
-]
-
-[[package]]
-name = "markupsafe"
-version = "2.1.5"
-requires_python = ">=3.7"
-summary = "Safely add untrusted strings to HTML/XML markup."
-groups = ["default"]
-files = [
- {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"},
- {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"},
- {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
-]
-
-[[package]]
-name = "matplotlib"
-version = "3.9.0"
-requires_python = ">=3.9"
-summary = "Python plotting package"
-groups = ["default"]
-dependencies = [
- "contourpy>=1.0.1",
- "cycler>=0.10",
- "fonttools>=4.22.0",
- "importlib-resources>=3.2.0; python_version < \"3.10\"",
- "kiwisolver>=1.3.1",
- "numpy>=1.23",
- "packaging>=20.0",
- "pillow>=8",
- "pyparsing>=2.3.1",
- "python-dateutil>=2.7",
-]
-files = [
- {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"},
- {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"},
- {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"},
- {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"},
- {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"},
- {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"},
- {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"},
- {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"},
- {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"},
- {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"},
- {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"},
- {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"},
- {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"},
- {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"},
- {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"},
- {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"},
- {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"},
- {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"},
- {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"},
- {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"},
- {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"},
- {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"},
- {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"},
- {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"},
- {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"},
- {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"},
- {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"},
- {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"},
- {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"},
-]
-
-[[package]]
-name = "mkl"
-version = "2021.4.0"
-summary = "Intel® oneAPI Math Kernel Library"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-dependencies = [
- "intel-openmp==2021.*",
- "tbb==2021.*",
-]
-files = [
- {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
- {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
- {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
-]
-
-[[package]]
-name = "mpmath"
-version = "1.3.0"
-summary = "Python library for arbitrary-precision floating-point arithmetic"
-groups = ["default"]
-files = [
- {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
- {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
-]
-
-[[package]]
-name = "multidict"
-version = "6.0.5"
-requires_python = ">=3.7"
-summary = "multidict implementation"
-groups = ["default"]
-files = [
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
- {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
- {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
- {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
- {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
- {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
- {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"},
- {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"},
- {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"},
- {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
- {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
-]
-
-[[package]]
-name = "networkx"
-version = "3.2.1"
-requires_python = ">=3.9"
-summary = "Python package for creating and manipulating graphs and networks"
-groups = ["default"]
-files = [
- {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
- {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
-]
-
-[[package]]
-name = "nodeenv"
-version = "1.8.0"
-requires_python = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
-summary = "Node.js virtual environment builder"
-groups = ["dev"]
-dependencies = [
- "setuptools",
-]
-files = [
- {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
- {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
-]
-
-[[package]]
-name = "numpy"
-version = "1.26.4"
-requires_python = ">=3.9"
-summary = "Fundamental package for array computing in Python"
-groups = ["default"]
-files = [
- {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
- {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
- {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
- {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
- {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
- {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
- {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
- {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
- {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
- {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
- {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
- {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
- {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
- {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
- {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
- {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
- {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
- {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
- {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
- {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
- {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
- {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
- {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
- {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
- {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
- {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
- {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
- {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
- {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
- {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
- {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
- {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
- {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
-]
-
-[[package]]
-name = "nvidia-cublas-cu12"
-version = "12.1.3.1"
-requires_python = ">=3"
-summary = "CUBLAS native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
- {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
-]
-
-[[package]]
-name = "nvidia-cuda-cupti-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "CUDA profiling tools runtime libs."
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
- {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
-]
-
-[[package]]
-name = "nvidia-cuda-nvrtc-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "NVRTC native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
- {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
-]
-
-[[package]]
-name = "nvidia-cuda-runtime-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "CUDA Runtime native Libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
- {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
-]
-
-[[package]]
-name = "nvidia-cudnn-cu12"
-version = "8.9.2.26"
-requires_python = ">=3"
-summary = "cuDNN runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-cublas-cu12",
-]
-files = [
- {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
-]
-
-[[package]]
-name = "nvidia-cufft-cu12"
-version = "11.0.2.54"
-requires_python = ">=3"
-summary = "CUFFT native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
- {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
-]
-
-[[package]]
-name = "nvidia-curand-cu12"
-version = "10.3.2.106"
-requires_python = ">=3"
-summary = "CURAND native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
- {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
-]
-
-[[package]]
-name = "nvidia-cusolver-cu12"
-version = "11.4.5.107"
-requires_python = ">=3"
-summary = "CUDA solver native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-cublas-cu12",
- "nvidia-cusparse-cu12",
- "nvidia-nvjitlink-cu12",
-]
-files = [
- {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
- {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
-]
-
-[[package]]
-name = "nvidia-cusparse-cu12"
-version = "12.1.0.106"
-requires_python = ">=3"
-summary = "CUSPARSE native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-nvjitlink-cu12",
-]
-files = [
- {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
- {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
-]
-
-[[package]]
-name = "nvidia-nccl-cu12"
-version = "2.20.5"
-requires_python = ">=3"
-summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
- {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
-]
-
-[[package]]
-name = "nvidia-nvjitlink-cu12"
-version = "12.5.40"
-requires_python = ">=3"
-summary = "Nvidia JIT LTO Library"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
- {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
-]
-
-[[package]]
-name = "nvidia-nvtx-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "NVIDIA Tools Extension"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
- {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
-]
-
-[[package]]
-name = "packaging"
-version = "24.0"
-requires_python = ">=3.7"
-summary = "Core utilities for Python packages"
-groups = ["default", "dev"]
-files = [
- {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
- {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
-]
-
-[[package]]
-name = "pillow"
-version = "10.3.0"
-requires_python = ">=3.8"
-summary = "Python Imaging Library (Fork)"
-groups = ["default"]
-files = [
- {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
- {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
- {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
- {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
- {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
- {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
- {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
- {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
- {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
- {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
- {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
- {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
- {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
- {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
- {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
- {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
- {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
- {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
- {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
- {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
- {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
- {file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"},
- {file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"},
- {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"},
- {file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"},
- {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"},
- {file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"},
- {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"},
- {file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"},
- {file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"},
- {file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"},
- {file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
- {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
-]
-
-[[package]]
-name = "platformdirs"
-version = "4.2.2"
-requires_python = ">=3.8"
-summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
-groups = ["default", "dev"]
-files = [
- {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"},
- {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"},
-]
-
-[[package]]
-name = "plotly"
-version = "5.22.0"
-requires_python = ">=3.8"
-summary = "An open-source, interactive data visualization library for Python"
-groups = ["default"]
-dependencies = [
- "packaging",
- "tenacity>=6.2.0",
-]
-files = [
- {file = "plotly-5.22.0-py3-none-any.whl", hash = "sha256:68fc1901f098daeb233cc3dd44ec9dc31fb3ca4f4e53189344199c43496ed006"},
- {file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"},
-]
-
-[[package]]
-name = "pluggy"
-version = "1.5.0"
-requires_python = ">=3.8"
-summary = "plugin and hook calling mechanisms for python"
-groups = ["dev"]
-files = [
- {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
- {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
-]
-
-[[package]]
-name = "pre-commit"
-version = "3.7.1"
-requires_python = ">=3.9"
-summary = "A framework for managing and maintaining multi-language pre-commit hooks."
-groups = ["dev"]
-dependencies = [
- "cfgv>=2.0.0",
- "identify>=1.0.0",
- "nodeenv>=0.11.1",
- "pyyaml>=5.1",
- "virtualenv>=20.10.0",
-]
-files = [
- {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"},
- {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"},
-]
-
-[[package]]
-name = "protobuf"
-version = "4.25.3"
-requires_python = ">=3.8"
-summary = ""
-groups = ["default"]
-marker = "python_version >= \"3.9\" or sys_platform != \"linux\""
-files = [
- {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
- {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
- {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
- {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
- {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
- {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"},
- {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"},
- {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
- {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
-]
-
-[[package]]
-name = "psutil"
-version = "5.9.8"
-requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
-summary = "Cross-platform lib for process and system monitoring in Python."
-groups = ["default"]
-files = [
- {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"},
- {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"},
- {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"},
- {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"},
- {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"},
- {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"},
- {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"},
-]
-
-[[package]]
-name = "pyparsing"
-version = "3.1.2"
-requires_python = ">=3.6.8"
-summary = "pyparsing module - Classes and methods to define and execute parsing grammars"
-groups = ["default"]
-files = [
- {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
- {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
-]
-
-[[package]]
-name = "pyproj"
-version = "3.6.1"
-requires_python = ">=3.9"
-summary = "Python interface to PROJ (cartographic projections and coordinate transformations library)"
-groups = ["default"]
-dependencies = [
- "certifi",
-]
-files = [
- {file = "pyproj-3.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ab7aa4d9ff3c3acf60d4b285ccec134167a948df02347585fdd934ebad8811b4"},
- {file = "pyproj-3.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bc0472302919e59114aa140fd7213c2370d848a7249d09704f10f5b062031fe"},
- {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5279586013b8d6582e22b6f9e30c49796966770389a9d5b85e25a4223286cd3f"},
- {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80fafd1f3eb421694857f254a9bdbacd1eb22fc6c24ca74b136679f376f97d35"},
- {file = "pyproj-3.6.1-cp310-cp310-win32.whl", hash = "sha256:c41e80ddee130450dcb8829af7118f1ab69eaf8169c4bf0ee8d52b72f098dc2f"},
- {file = "pyproj-3.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:db3aedd458e7f7f21d8176f0a1d924f1ae06d725228302b872885a1c34f3119e"},
- {file = "pyproj-3.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ebfbdbd0936e178091309f6cd4fcb4decd9eab12aa513cdd9add89efa3ec2882"},
- {file = "pyproj-3.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:447db19c7efad70ff161e5e46a54ab9cc2399acebb656b6ccf63e4bc4a04b97a"},
- {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7e13c40183884ec7f94eb8e0f622f08f1d5716150b8d7a134de48c6110fee85"},
- {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65ad699e0c830e2b8565afe42bd58cc972b47d829b2e0e48ad9638386d994915"},
- {file = "pyproj-3.6.1-cp311-cp311-win32.whl", hash = "sha256:8b8acc31fb8702c54625f4d5a2a6543557bec3c28a0ef638778b7ab1d1772132"},
- {file = "pyproj-3.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:38a3361941eb72b82bd9a18f60c78b0df8408416f9340521df442cebfc4306e2"},
- {file = "pyproj-3.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1e9fbaf920f0f9b4ee62aab832be3ae3968f33f24e2e3f7fbb8c6728ef1d9746"},
- {file = "pyproj-3.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d227a865356f225591b6732430b1d1781e946893789a609bb34f59d09b8b0f8"},
- {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83039e5ae04e5afc974f7d25ee0870a80a6bd6b7957c3aca5613ccbe0d3e72bf"},
- {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb059ba3bced6f6725961ba758649261d85ed6ce670d3e3b0a26e81cf1aa8d"},
- {file = "pyproj-3.6.1-cp312-cp312-win32.whl", hash = "sha256:2d6ff73cc6dbbce3766b6c0bce70ce070193105d8de17aa2470009463682a8eb"},
- {file = "pyproj-3.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:7a27151ddad8e1439ba70c9b4b2b617b290c39395fa9ddb7411ebb0eb86d6fb0"},
- {file = "pyproj-3.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4ba1f9b03d04d8cab24d6375609070580a26ce76eaed54631f03bab00a9c737b"},
- {file = "pyproj-3.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18faa54a3ca475bfe6255156f2f2874e9a1c8917b0004eee9f664b86ccc513d3"},
- {file = "pyproj-3.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd43bd9a9b9239805f406fd82ba6b106bf4838d9ef37c167d3ed70383943ade1"},
- {file = "pyproj-3.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50100b2726a3ca946906cbaa789dd0749f213abf0cbb877e6de72ca7aa50e1ae"},
- {file = "pyproj-3.6.1-cp39-cp39-win32.whl", hash = "sha256:9274880263256f6292ff644ca92c46d96aa7e57a75c6df3f11d636ce845a1877"},
- {file = "pyproj-3.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:36b64c2cb6ea1cc091f329c5bd34f9c01bb5da8c8e4492c709bda6a09f96808f"},
- {file = "pyproj-3.6.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd93c1a0c6c4aedc77c0fe275a9f2aba4d59b8acf88cebfc19fe3c430cfabf4f"},
- {file = "pyproj-3.6.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6420ea8e7d2a88cb148b124429fba8cd2e0fae700a2d96eab7083c0928a85110"},
- {file = "pyproj-3.6.1.tar.gz", hash = "sha256:44aa7c704c2b7d8fb3d483bbf75af6cb2350d30a63b144279a09b75fead501bf"},
-]
-
-[[package]]
-name = "pyshp"
-version = "2.3.1"
-requires_python = ">=2.7"
-summary = "Pure Python read/write support for ESRI Shapefile format"
-groups = ["default"]
-files = [
- {file = "pyshp-2.3.1-py2.py3-none-any.whl", hash = "sha256:67024c0ccdc352ba5db777c4e968483782dfa78f8e200672a90d2d30fd8b7b49"},
- {file = "pyshp-2.3.1.tar.gz", hash = "sha256:4caec82fd8dd096feba8217858068bacb2a3b5950f43c048c6dc32a3489d5af1"},
-]
-
-[[package]]
-name = "pytest"
-version = "8.2.1"
-requires_python = ">=3.8"
-summary = "pytest: simple powerful testing with Python"
-groups = ["dev"]
-dependencies = [
- "colorama; sys_platform == \"win32\"",
- "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
- "iniconfig",
- "packaging",
- "pluggy<2.0,>=1.5",
- "tomli>=1; python_version < \"3.11\"",
-]
-files = [
- {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"},
- {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"},
-]
-
-[[package]]
-name = "python-dateutil"
-version = "2.9.0.post0"
-requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
-summary = "Extensions to the standard Python datetime module"
-groups = ["default"]
-dependencies = [
- "six>=1.5",
-]
-files = [
- {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
- {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
-]
-
-[[package]]
-name = "pytorch-lightning"
-version = "2.2.4"
-requires_python = ">=3.8"
-summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
-groups = ["default"]
-dependencies = [
- "PyYAML>=5.4",
- "fsspec[http]>=2022.5.0",
- "lightning-utilities>=0.8.0",
- "numpy>=1.17.2",
- "packaging>=20.0",
- "torch>=1.13.0",
- "torchmetrics>=0.7.0",
- "tqdm>=4.57.0",
- "typing-extensions>=4.4.0",
-]
-files = [
- {file = "pytorch-lightning-2.2.4.tar.gz", hash = "sha256:525b04ebad9900c3e3c2a12b3b462fe4f61ebe11fdb694716c3209f05b9b0fa8"},
- {file = "pytorch_lightning-2.2.4-py3-none-any.whl", hash = "sha256:fd91d47e983a2cd743c5c8c3c3795bbd0f3b69d24be2172a2f9012d930701ff2"},
-]
-
-[[package]]
-name = "pyyaml"
-version = "6.0.1"
-requires_python = ">=3.6"
-summary = "YAML parser and emitter for Python"
-groups = ["default", "dev"]
-files = [
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
- {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
- {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
- {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
- {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
- {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
- {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
- {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
- {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
- {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
- {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
- {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
-]
-
-[[package]]
-name = "requests"
-version = "2.32.2"
-requires_python = ">=3.8"
-summary = "Python HTTP for Humans."
-groups = ["default"]
-dependencies = [
- "certifi>=2017.4.17",
- "charset-normalizer<4,>=2",
- "idna<4,>=2.5",
- "urllib3<3,>=1.21.1",
-]
-files = [
- {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
- {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
-]
-
-[[package]]
-name = "scipy"
-version = "1.13.0"
-requires_python = ">=3.9"
-summary = "Fundamental algorithms for scientific computing in Python"
-groups = ["default"]
-dependencies = [
- "numpy<2.3,>=1.22.4",
-]
-files = [
- {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
- {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
- {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
- {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
- {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
- {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
- {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
- {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
- {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
- {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
- {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
- {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
- {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
- {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
- {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
- {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
- {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
- {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
- {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
- {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
- {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
- {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
- {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
- {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
- {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
-]
-
-[[package]]
-name = "sentry-sdk"
-version = "2.2.1"
-requires_python = ">=3.6"
-summary = "Python client for Sentry (https://sentry.io)"
-groups = ["default"]
-dependencies = [
- "certifi",
- "urllib3>=1.26.11",
-]
-files = [
- {file = "sentry_sdk-2.2.1-py2.py3-none-any.whl", hash = "sha256:7d617a1b30e80c41f3b542347651fcf90bb0a36f3a398be58b4f06b79c8d85bc"},
- {file = "sentry_sdk-2.2.1.tar.gz", hash = "sha256:8aa2ec825724d8d9d645cab68e6034928b1a6a148503af3e361db3fa6401183f"},
-]
-
-[[package]]
-name = "setproctitle"
-version = "1.3.3"
-requires_python = ">=3.7"
-summary = "A Python module to customize the process title"
-groups = ["default"]
-files = [
- {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"},
- {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d"},
- {file = "setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb"},
- {file = "setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086"},
- {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8"},
- {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18"},
- {file = "setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476"},
- {file = "setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085"},
- {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc"},
- {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3"},
- {file = "setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4"},
- {file = "setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7"},
- {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c7951820b77abe03d88b114b998867c0f99da03859e5ab2623d94690848d3e45"},
- {file = "setproctitle-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5bc94cf128676e8fac6503b37763adb378e2b6be1249d207630f83fc325d9b11"},
- {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f5d9027eeda64d353cf21a3ceb74bb1760bd534526c9214e19f052424b37e42"},
- {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e4a8104db15d3462e29d9946f26bed817a5b1d7a47eabca2d9dc2b995991503"},
- {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c32c41ace41f344d317399efff4cffb133e709cec2ef09c99e7a13e9f3b9483c"},
- {file = "setproctitle-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf16381c7bf7f963b58fb4daaa65684e10966ee14d26f5cc90f07049bfd8c1e"},
- {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e18b7bd0898398cc97ce2dfc83bb192a13a087ef6b2d5a8a36460311cb09e775"},
- {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:69d565d20efe527bd8a9b92e7f299ae5e73b6c0470f3719bd66f3cd821e0d5bd"},
- {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ddedd300cd690a3b06e7eac90ed4452348b1348635777ce23d460d913b5b63c3"},
- {file = "setproctitle-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:415bfcfd01d1fbf5cbd75004599ef167a533395955305f42220a585f64036081"},
- {file = "setproctitle-1.3.3-cp39-cp39-win32.whl", hash = "sha256:21112fcd2195d48f25760f0eafa7a76510871bbb3b750219310cf88b04456ae3"},
- {file = "setproctitle-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:5a740f05d0968a5a17da3d676ce6afefebeeeb5ce137510901bf6306ba8ee002"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d3a953c50776751e80fe755a380a64cb14d61e8762bd43041ab3f8cc436092f"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5e08e232b78ba3ac6bc0d23ce9e2bee8fad2be391b7e2da834fc9a45129eb87"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1da82c3e11284da4fcbf54957dafbf0655d2389cd3d54e4eaba636faf6d117a"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:aeaa71fb9568ebe9b911ddb490c644fbd2006e8c940f21cb9a1e9425bd709574"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:59335d000c6250c35989394661eb6287187854e94ac79ea22315469ee4f4c244"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3ba57029c9c50ecaf0c92bb127224cc2ea9fda057b5d99d3f348c9ec2855ad3"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d876d355c53d975c2ef9c4f2487c8f83dad6aeaaee1b6571453cb0ee992f55f6"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:224602f0939e6fb9d5dd881be1229d485f3257b540f8a900d4271a2c2aa4e5f4"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d7f27e0268af2d7503386e0e6be87fb9b6657afd96f5726b733837121146750d"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5e7266498cd31a4572378c61920af9f6b4676a73c299fce8ba93afd694f8ae7"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33c5609ad51cd99d388e55651b19148ea99727516132fb44680e1f28dd0d1de9"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:eae8988e78192fd1a3245a6f4f382390b61bce6cfcc93f3809726e4c885fa68d"},
- {file = "setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae"},
-]
-
-[[package]]
-name = "setuptools"
-version = "70.0.0"
-requires_python = ">=3.8"
-summary = "Easily download, build, install, upgrade, and uninstall Python packages"
-groups = ["default", "dev"]
-files = [
- {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
- {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
-]
-
-[[package]]
-name = "shapely"
-version = "2.0.4"
-requires_python = ">=3.7"
-summary = "Manipulation and analysis of geometric objects"
-groups = ["default"]
-dependencies = [
- "numpy<3,>=1.14",
-]
-files = [
- {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"},
- {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"},
- {file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"},
- {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"},
- {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"},
- {file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"},
- {file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"},
- {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"},
- {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"},
- {file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"},
- {file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5bbd974193e2cc274312da16b189b38f5f128410f3377721cadb76b1e8ca5328"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:41388321a73ba1a84edd90d86ecc8bfed55e6a1e51882eafb019f45895ec0f65"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0776c92d584f72f1e584d2e43cfc5542c2f3dd19d53f70df0900fda643f4bae6"},
- {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c75c98380b1ede1cae9a252c6dc247e6279403fae38c77060a5e6186c95073ac"},
- {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3e700abf4a37b7b8b90532fa6ed5c38a9bfc777098bc9fbae5ec8e618ac8f30"},
- {file = "shapely-2.0.4-cp312-cp312-win32.whl", hash = "sha256:4f2ab0faf8188b9f99e6a273b24b97662194160cc8ca17cf9d1fb6f18d7fb93f"},
- {file = "shapely-2.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:03152442d311a5e85ac73b39680dd64a9892fa42bb08fd83b3bab4fe6999bfa0"},
- {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9103abd1678cb1b5f7e8e1af565a652e036844166c91ec031eeb25c5ca8af0"},
- {file = "shapely-2.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:263bcf0c24d7a57c80991e64ab57cba7a3906e31d2e21b455f493d4aab534aaa"},
- {file = "shapely-2.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddf4a9bfaac643e62702ed662afc36f6abed2a88a21270e891038f9a19bc08fc"},
- {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:485246fcdb93336105c29a5cfbff8a226949db37b7473c89caa26c9bae52a242"},
- {file = "shapely-2.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de4578e838a9409b5b134a18ee820730e507b2d21700c14b71a2b0757396acc"},
- {file = "shapely-2.0.4-cp39-cp39-win32.whl", hash = "sha256:9dab4c98acfb5fb85f5a20548b5c0abe9b163ad3525ee28822ffecb5c40e724c"},
- {file = "shapely-2.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:31c19a668b5a1eadab82ff070b5a260478ac6ddad3a5b62295095174a8d26398"},
- {file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"},
-]
-
-[[package]]
-name = "six"
-version = "1.16.0"
-requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
-summary = "Python 2 and 3 compatibility utilities"
-groups = ["default"]
-files = [
- {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
- {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
-]
-
-[[package]]
-name = "smmap"
-version = "5.0.1"
-requires_python = ">=3.7"
-summary = "A pure Python implementation of a sliding window memory map manager"
-groups = ["default"]
-files = [
- {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
- {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
-]
-
-[[package]]
-name = "sympy"
-version = "1.12"
-requires_python = ">=3.8"
-summary = "Computer algebra system (CAS) in Python"
-groups = ["default"]
-dependencies = [
- "mpmath>=0.19",
-]
-files = [
- {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
- {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
-]
-
-[[package]]
-name = "tbb"
-version = "2021.12.0"
-summary = "Intel® oneAPI Threading Building Blocks (oneTBB)"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-files = [
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"},
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"},
- {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"},
- {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
-]
-
-[[package]]
-name = "tenacity"
-version = "8.3.0"
-requires_python = ">=3.8"
-summary = "Retry code until it succeeds"
-groups = ["default"]
-files = [
- {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
- {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
-]
-
-[[package]]
-name = "tomli"
-version = "2.0.1"
-requires_python = ">=3.7"
-summary = "A lil' TOML parser"
-groups = ["dev"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
- {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
-]
-
-[[package]]
-name = "torch"
-version = "2.3.0"
-requires_python = ">=3.8.0"
-summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
-groups = ["default"]
-dependencies = [
- "filelock",
- "fsspec",
- "jinja2",
- "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"",
- "networkx",
- "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-runtime-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cudnn-cu12==8.9.2.26; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cufft-cu12==11.0.2.54; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "sympy",
- "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"",
- "typing-extensions>=4.8.0",
-]
-files = [
- {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
- {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
- {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
- {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
- {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
- {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
- {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
- {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
- {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
- {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
- {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
- {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
- {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"},
- {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"},
- {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"},
- {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"},
-]
-
-[[package]]
-name = "torchmetrics"
-version = "1.4.0.post0"
-requires_python = ">=3.8"
-summary = "PyTorch native Metrics"
-groups = ["default"]
-dependencies = [
- "lightning-utilities>=0.8.0",
- "numpy>1.20.0",
- "packaging>17.1",
- "torch>=1.10.0",
-]
-files = [
- {file = "torchmetrics-1.4.0.post0-py3-none-any.whl", hash = "sha256:ab234216598e3fbd8d62ee4541a0e74e7e8fc935d099683af5b8da50f745b3c8"},
- {file = "torchmetrics-1.4.0.post0.tar.gz", hash = "sha256:ab9bcfe80e65dbabbddb6cecd9be21f1f1d5207bb74051ef95260740f2762358"},
-]
-
-[[package]]
-name = "tqdm"
-version = "4.66.4"
-requires_python = ">=3.7"
-summary = "Fast, Extensible Progress Meter"
-groups = ["default"]
-dependencies = [
- "colorama; platform_system == \"Windows\"",
-]
-files = [
- {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
- {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
-]
-
-[[package]]
-name = "triton"
-version = "2.3.0"
-summary = "A language and compiler for custom Deep Learning operations"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""
-dependencies = [
- "filelock",
-]
-files = [
- {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
- {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
- {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
- {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"},
-]
-
-[[package]]
-name = "tueplots"
-version = "0.0.15"
-requires_python = ">=3.9"
-summary = "Scientific plotting made easy."
-groups = ["default"]
-dependencies = [
- "matplotlib",
- "numpy",
-]
-files = [
- {file = "tueplots-0.0.15-py3-none-any.whl", hash = "sha256:f63e020af88328c78618f3d912612c75c3c91d21004a88fd12cf79dbd9b6d78a"},
-]
-
-[[package]]
-name = "typing-extensions"
-version = "4.11.0"
-requires_python = ">=3.8"
-summary = "Backported and Experimental Type Hints for Python 3.8+"
-groups = ["default"]
-files = [
- {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
- {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
-]
-
-[[package]]
-name = "urllib3"
-version = "2.2.1"
-requires_python = ">=3.8"
-summary = "HTTP library with thread-safe connection pooling, file post, and more."
-groups = ["default"]
-files = [
- {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
- {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
-]
-
-[[package]]
-name = "virtualenv"
-version = "20.26.2"
-requires_python = ">=3.7"
-summary = "Virtual Python Environment builder"
-groups = ["dev"]
-dependencies = [
- "distlib<1,>=0.3.7",
- "filelock<4,>=3.12.2",
- "platformdirs<5,>=3.9.1",
-]
-files = [
- {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"},
- {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"},
-]
-
-[[package]]
-name = "wandb"
-version = "0.17.0"
-requires_python = ">=3.7"
-summary = "A CLI and library for interacting with the Weights & Biases API."
-groups = ["default"]
-dependencies = [
- "click!=8.0.0,>=7.1",
- "docker-pycreds>=0.4.0",
- "gitpython!=3.1.29,>=1.0.0",
- "platformdirs",
- "protobuf!=4.21.0,<5,>=3.15.0; python_version == \"3.9\" and sys_platform == \"linux\"",
- "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"",
- "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"",
- "psutil>=5.0.0",
- "pyyaml",
- "requests<3,>=2.0.0",
- "sentry-sdk>=1.0.0",
- "setproctitle",
- "setuptools",
- "typing-extensions; python_version < \"3.10\"",
-]
-files = [
- {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
- {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
- {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
- {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
- {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
- {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
- {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
-]
-
-[[package]]
-name = "yarl"
-version = "1.9.4"
-requires_python = ">=3.7"
-summary = "Yet another URL library"
-groups = ["default"]
-dependencies = [
- "idna>=2.0",
- "multidict>=4.0",
-]
-files = [
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
- {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
- {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
- {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
- {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
- {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
- {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"},
- {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"},
- {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"},
- {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
- {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
-]
-
-[[package]]
-name = "zipp"
-version = "3.18.2"
-requires_python = ">=3.8"
-summary = "Backport of pathlib-compatible object wrapper for zip files"
-groups = ["default"]
-marker = "python_version < \"3.10\""
-files = [
- {file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
- {file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
-]
diff --git a/pyproject.toml b/pyproject.toml
index 0a25868c..50ddca04 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
"tueplots>=0.0.8",
"matplotlib>=3.7.0",
"plotly>=5.15.0",
+ "torch>=2.3.0",
]
requires-python = ">=3.9"
From b7609152fe4b12543859557da66831bfb0469c47 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:05:36 +0200
Subject: [PATCH 068/273] ensure exec in pdm venv
---
.github/workflows/ci-pdm-install-and-test.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index f13552b9..bdec19c3 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -24,9 +24,9 @@ jobs:
- name: Install torch (CPU)
run: |
- python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+ pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# check that the CPU version is installed
- python -c "import torch; assert torch.__version__.endswith('+gpu')"
+ pdm run python -c "import torch; assert torch.__version__.endswith('+gpu')"
- name: Install package (including dev dependencies)
run: |
From 7797cef908a90400d29c63d5cf420887ef3e2d07 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:07:21 +0200
Subject: [PATCH 069/273] ensure exec in pdm venv
---
.github/workflows/ci-pdm-install-and-test.yml | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index bdec19c3..92194215 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -26,6 +26,10 @@ jobs:
run: |
pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# check that the CPU version is installed
+
+ - name: Print and check torch version
+ run: |
+ pdm run python -c "import torch; print(torch.__version__)"
pdm run python -c "import torch; assert torch.__version__.endswith('+gpu')"
- name: Install package (including dev dependencies)
From e68965001be4dde01305dfcd0daf6acb5967d837 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:09:05 +0200
Subject: [PATCH 070/273] check version #2
---
.github/workflows/ci-pdm-install-and-test.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 92194215..0e2b1827 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -30,7 +30,7 @@ jobs:
- name: Print and check torch version
run: |
pdm run python -c "import torch; print(torch.__version__)"
- pdm run python -c "import torch; assert torch.__version__.endswith('+gpu')"
+ pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
- name: Install package (including dev dependencies)
run: |
From fb8ef233d9ba78e4a0cdbe585645103cf9e840cd Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:10:57 +0200
Subject: [PATCH 071/273] check version no 3
---
.github/workflows/ci-pdm-install-and-test.yml | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 0e2b1827..308cc49b 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -37,6 +37,11 @@ jobs:
pdm install
pdm install --dev
+ - name: Print and check torch version
+ run: |
+ pdm run python -c "import torch; print(torch.__version__)"
+ pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
+
- name: Run tests
run: |
pdm run pytest
From 51b0a0bcf03d3950021ba31d790c53290058e437 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 18:15:34 +0200
Subject: [PATCH 072/273] check versions
---
.github/workflows/ci-pdm-install-and-test.yml | 5 -----
.github/workflows/ci-pip-install-and-test-gpu.yml | 5 +++++
.github/workflows/ci-pip-install-and-test.yml | 5 +++++
3 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 308cc49b..69fc29d3 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -27,11 +27,6 @@ jobs:
pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# check that the CPU version is installed
- - name: Print and check torch version
- run: |
- pdm run python -c "import torch; print(torch.__version__)"
- pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
-
- name: Install package (including dev dependencies)
run: |
pdm install
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index 2cc168f0..b05b2ecf 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -22,6 +22,11 @@ jobs:
python -m pip install .
python -m pip install pytest
+ - name: Print and check torch version
+ run: |
+ python -c "import torch; print(torch.__version__)"
+ python -c "import torch; assert not torch.__version__.endswith('+cpu')"
+
- name: Run tests
run: |
python -m pytest
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
index 66ac95ac..b7d0afee 100644
--- a/.github/workflows/ci-pip-install-and-test.yml
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -22,6 +22,11 @@ jobs:
python -m pip install .
python -m pip install pytest
+ - name: Print and check torch version
+ run: |
+ pdm run python -c "import torch; print(torch.__version__)"
+ pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
+
- name: Run tests
run: |
python -m pytest
From 8fa3ca70148176a5a46a93eac3945bfa4a6bca94 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Mon, 3 Jun 2024 20:51:27 +0200
Subject: [PATCH 073/273] Introduced datetime forcing calculation as separate
 script
---
create_forcings.py | 78 +++++++++++++++++++++++++++++++++++++
neural_lam/data_config.yaml | 11 ++++++
2 files changed, 89 insertions(+)
create mode 100644 create_forcings.py
diff --git a/create_forcings.py b/create_forcings.py
new file mode 100644
index 00000000..deb36994
--- /dev/null
+++ b/create_forcings.py
@@ -0,0 +1,78 @@
+# Standard library
+import argparse
+
+# Third-party
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+# First-party
+from neural_lam import config
+
+
+def get_seconds_in_year(year):
+ start_of_year = pd.Timestamp(f"{year}-01-01")
+ start_of_next_year = pd.Timestamp(f"{year + 1}-01-01")
+ return (start_of_next_year - start_of_year).total_seconds()
+
+
+def calculate_datetime_forcing(timesteps):
+ hours_of_day = xr.DataArray(timesteps.dt.hour, dims=["time"])
+ seconds_into_year = xr.DataArray(
+ [
+ (
+ pd.Timestamp(dt_obj)
+ - pd.Timestamp(f"{pd.Timestamp(dt_obj).year}-01-01")
+ ).total_seconds()
+ for dt_obj in timesteps.values
+ ],
+ dims=["time"],
+ )
+ year_seconds = xr.DataArray(
+ [
+ get_seconds_in_year(pd.Timestamp(dt_obj).year)
+ for dt_obj in timesteps.values
+ ],
+ dims=["time"],
+ )
+ hour_angle = (hours_of_day / 12) * np.pi
+ year_angle = (seconds_into_year / year_seconds) * 2 * np.pi
+ datetime_forcing = xr.Dataset(
+ {
+ "hour_sin": np.sin(hour_angle),
+ "hour_cos": np.cos(hour_angle),
+ "year_sin": np.sin(year_angle),
+ "year_cos": np.cos(year_angle),
+ },
+ coords={"time": timesteps},
+ )
+ datetime_forcing = (datetime_forcing + 1) / 2
+ return datetime_forcing
+
+
+def main():
+ """Main function for creating the datetime forcing and boundary mask."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_config", type=str, default="neural_lam/data_config.yaml"
+ )
+ parser.add_argument("--zarr_path", type=str, default="forcings.zarr")
+ args = parser.parse_args()
+
+ config_loader = config.Config.from_file(args.data_config)
+ dataset = config_loader.open_zarr("state")
+ datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
+
+ # Expand dimensions to match the target dataset
+ datetime_forcing_expanded = datetime_forcing.expand_dims(
+ {"y": dataset.y, "x": dataset.x}
+ )
+
+ datetime_forcing_expanded.to_zarr(args.zarr_path, mode="w")
+ print(f"Datetime forcing saved to {args.zarr_path}")
+
+ dataset
+
+
+if __name__ == "__main__":
+ main()
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 2f7261c0..4f7de3f4 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -51,11 +51,22 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
+ - path: "forcings.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
surface_vars:
- cape_column # just as a technical test
- icei0m
- vis0m
- xhail0m
+ - hour_cos
+ - hour_sin
+ - year_cos
+ - year_sin
surface_units:
- J/kg
- kg/m^2 # just as a technical test :)
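For reference, the cyclic encoding that `calculate_datetime_forcing` introduces in this patch can be reproduced standalone. The short sketch below (assuming only numpy and pandas; it is not part of the patch itself) shows the hour and year signals rescaled from [-1, 1] to [0, 1], as done by the final `(datetime_forcing + 1) / 2` step.

```
import numpy as np
import pandas as pd

# Four 6-hourly timesteps, mirroring how dataset.time would be used
times = pd.date_range("1990-09-01T00", periods=4, freq="6h")

# Hour-of-day angle with a 24 h period and year angle with a one-year period
hour_angle = (times.hour / 12) * np.pi
seconds_into_year = (times - pd.Timestamp("1990-01-01")).total_seconds()
year_seconds = (
    pd.Timestamp("1991-01-01") - pd.Timestamp("1990-01-01")
).total_seconds()
year_angle = (seconds_into_year / year_seconds) * 2 * np.pi

# Rescale sin/cos features from [-1, 1] to [0, 1]
hour_sin = (np.sin(hour_angle) + 1) / 2
year_sin = (np.sin(year_angle) + 1) / 2
print(hour_sin, year_sin)
```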
From a748903b2c11244ca337a99210e8d3eb0f3eaab3 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Mon, 3 Jun 2024 20:52:57 +0200
Subject: [PATCH 074/273] Fixed order of y and x dims to adhere to #52
---
create_mesh.py | 2 +-
neural_lam/config.py | 11 +++++------
plot_graph.py | 2 +-
3 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/create_mesh.py b/create_mesh.py
index 8b547166..36b9b0b5 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -197,7 +197,7 @@ def main():
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
- xy = config_loader.get_xy("static")
+ xy = config_loader.get_xy("static") # (2, N_y, N_x)
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
diff --git a/neural_lam/config.py b/neural_lam/config.py
index d411bd1e..20185563 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -113,7 +113,6 @@ def open_zarr(self, category):
return None
def stack_grid(self, dataset):
- """Stack the grid dimensions of the dataset."""
if dataset is None:
return None
dims = dataset.to_array().dims
@@ -124,7 +123,7 @@ def stack_grid(self, dataset):
else:
if "x" not in dims or "y" not in dims:
self.rename_dataset_dims_and_vars(dataset=dataset)
- dataset = dataset.squeeze().stack(grid=("x", "y"))
+ dataset = dataset.squeeze().stack(grid=("y", "x"))
return dataset
def convert_dataset_to_dataarray(self, dataset):
@@ -201,7 +200,7 @@ def filter_dimensions(self, dataset, transpose_array=True):
return dataset
def reshape_grid_to_2d(self, dataset, grid_shape=None):
- """Reshape the grid to 2D."""
+ """Reshape the grid to 2D for stacked data without multi-index."""
if grid_shape is None:
grid_shape = dict(self.grid_shape_state.values.items())
x_dim, y_dim = (grid_shape["x"], grid_shape["y"])
@@ -209,7 +208,7 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None):
x_coords = np.arange(x_dim)
y_coords = np.arange(y_dim)
multi_index = pd.MultiIndex.from_product(
- [x_coords, y_coords], names=["x", "y"]
+ [y_coords, x_coords], names=["y", "x"]
)
mindex_coords = xr.Coordinates.from_pandas_multiindex(
@@ -227,8 +226,8 @@ def get_xy(self, category):
dataset = self.open_zarr(category)
x, y = dataset.x.values, dataset.y.values
if x.ndim == 1:
- x, y = np.meshgrid(y, x)
- xy = np.stack((x, y), axis=0)
+ x, y = np.meshgrid(x, y)
+ xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
return xy
diff --git a/plot_graph.py b/plot_graph.py
index 2c3f6238..dc3682ff 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -45,7 +45,7 @@ def main():
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
- xy = config_loader.get_xy("state") # (2, N_x, N_y)
+ xy = config_loader.get_xy("state") # (2, N_y, N_x)
xy = xy.reshape(2, -1).T # (N_grid, 2)
pos_max = np.max(np.abs(xy))
grid_pos = xy / pos_max # Divide by maximum coordinate
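As a quick check of the dimension ordering fixed in this patch, the standalone sketch below (not from the patch) builds the (2, N_y, N_x) coordinate array that `get_xy` now returns from 1D x and y coordinates and flattens it the same way `plot_graph.py` does.

```
import numpy as np

x = np.arange(4)  # N_x grid points
y = np.arange(3)  # N_y grid points

# Default "xy" indexing gives arrays of shape (N_y, N_x), matching issue #52
xg, yg = np.meshgrid(x, y)
xy = np.stack((xg, yg), axis=0)
assert xy.shape == (2, 3, 4)  # (2, N_y, N_x)

# Flattening as in plot_graph.py yields one row of (x, y) per grid node
assert xy.reshape(2, -1).T.shape == (12, 2)  # (N_grid, 2)
```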
From 70425eea56e88ecc3f54d98256c78a476881b548 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 20:51:42 +0100
Subject: [PATCH 075/273] fix for pip install
---
.github/workflows/ci-pip-install-and-test.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
index b7d0afee..307f1829 100644
--- a/.github/workflows/ci-pip-install-and-test.yml
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -24,8 +24,8 @@ jobs:
- name: Print and check torch version
run: |
- pdm run python -c "import torch; print(torch.__version__)"
- pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
+ python -c "import torch; print(torch.__version__)"
+ python -c "import torch; assert torch.__version__.endswith('+cpu')"
- name: Run tests
run: |
From 60110f611ee56895bb570b654b9254100032720c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 20:57:25 +0100
Subject: [PATCH 076/273] switch cirun instance type
---
.cirun.yml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/.cirun.yml b/.cirun.yml
index 79d62f22..734e6786 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -4,7 +4,8 @@ runners:
# Cloud Provider: AWS
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
- instance_type: "g4dn.xlarge"
+ # instance_type: "g4dn.xlarge"
+ instance_type: "t2.nano"
# Ubuntu-20.4, ami image
machine_image: "ami-06fd8a495a537da8b"
preemptible: false
From 6fff3fc90438587f090b21eefdf5844e88e32265 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 3 Jun 2024 21:02:13 +0100
Subject: [PATCH 077/273] install py39 on cirun runner
---
.github/workflows/ci-pip-install-and-test-gpu.yml | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index b05b2ecf..dab7b060 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -13,6 +13,11 @@ jobs:
- name: Checkout
uses: actions/checkout@v2
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
- name: Install torch (GPU CUDA 12.1)
run: |
python -m pip install torch --index-url https://download.pytorch.org/whl/cu121
From 74b4a101c94ad11244f859653605277a23a947cf Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 06:24:34 +0200
Subject: [PATCH 078/273] cleanup: boundary_mask, zarr-opening, utils
---
create_forcings.py | 2 +-
neural_lam/config.py | 25 +++++++-----
neural_lam/data_config.yaml | 8 +++-
neural_lam/models/ar_model.py | 55 +++++--------------------
neural_lam/utils.py | 75 -----------------------------------
neural_lam/vis.py | 2 +-
train_model.py | 8 ----
7 files changed, 32 insertions(+), 143 deletions(-)
diff --git a/create_forcings.py b/create_forcings.py
index deb36994..459a3982 100644
--- a/create_forcings.py
+++ b/create_forcings.py
@@ -60,7 +60,7 @@ def main():
args = parser.parse_args()
config_loader = config.Config.from_file(args.data_config)
- dataset = config_loader.open_zarr("state")
+ dataset = config_loader.open_zarrs("state")
datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
# Expand dimensions to match the target dataset
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 20185563..21a97018 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -95,7 +95,7 @@ def num_data_vars(self, category):
return surface_vars_count + atmosphere_vars_count * levels_count
- def open_zarr(self, category):
+ def open_zarrs(self, category):
"""Open the zarr dataset for the given category."""
zarr_configs = self.values[category]["zarrs"]
@@ -103,7 +103,7 @@ def open_zarr(self, category):
datasets = []
for config in zarr_configs:
dataset_path = config["path"]
- dataset = xr.open_zarr(dataset_path, consolidated=True)
+ dataset = xr.open_zarrs(dataset_path, consolidated=True)
datasets.append(dataset)
merged_dataset = xr.merge(datasets)
merged_dataset.attrs["category"] = category
@@ -223,7 +223,7 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None):
@functools.lru_cache()
def get_xy(self, category):
"""Return the x, y coordinates of the dataset."""
- dataset = self.open_zarr(category)
+ dataset = self.open_zarrs(category)
x, y = dataset.x.values, dataset.y.values
if x.ndim == 1:
x, y = np.meshgrid(x, y)
@@ -244,7 +244,7 @@ def load_normalization_stats(self, category, datatype="torch"):
f"{stats_path}"
)
return None
- stats = xr.open_zarr(stats_path, consolidated=True)
+ stats = xr.open_zarrs(stats_path, consolidated=True)
if i == 0:
combined_stats = stats
else:
@@ -294,7 +294,7 @@ def load_normalization_stats(self, category, datatype="torch"):
# def assign_lat_lon_coords(self, category, dataset=None):
# """Process the latitude and longitude names of the dataset."""
# if dataset is None:
- # dataset = self.open_zarr(category)
+ # dataset = self.open_zarrs(category)
# lat_lon_names = {}
# for zarr_config in self.values[category]["zarrs"]:
# lat_lon_names.update(zarr_config["lat_lon_names"])
@@ -311,7 +311,7 @@ def load_normalization_stats(self, category, datatype="torch"):
def extract_vars(self, category, dataset=None):
"""Extract the variables from the dataset."""
if dataset is None:
- dataset = self.open_zarr(category)
+ dataset = self.open_zarrs(category)
surface_vars = (
dataset[self[category].surface_vars]
if self[category].surface_vars
@@ -354,7 +354,7 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
"""Rename the dimensions and variables of the dataset."""
convert = False
if dataset is None:
- dataset = self.open_zarr(category)
+ dataset = self.open_zarrs(category)
elif isinstance(dataset, xr.DataArray):
convert = True
dataset = dataset.to_dataset("variable")
@@ -387,7 +387,7 @@ def filter_dataset_by_time(self, dataset, split="train"):
def process_dataset(self, category, split="train", apply_windowing=True):
"""Process the dataset for the given category."""
- dataset = self.open_zarr(category)
+ dataset = self.open_zarrs(category)
dataset = self.extract_vars(category, dataset)
dataset = self.filter_dataset_by_time(dataset, split)
dataset = self.stack_grid(dataset)
@@ -402,8 +402,8 @@ def process_dataset(self, category, split="train", apply_windowing=True):
def apply_window(self, category, dataset=None):
"""Apply the forcing window to the forcing dataset."""
if dataset is None:
- dataset = self.open_zarr(category)
- state_time = self.open_zarr("state").time.values
+ dataset = self.open_zarrs(category)
+ state_time = self.open_zarrs("state").time.values
window = self[category].window
dataset = (
dataset.sel(time=state_time, method="nearest")
@@ -413,3 +413,8 @@ def apply_window(self, category, dataset=None):
.stack(variable_window=("variable", "window"))
)
return dataset
+
+ def load_boundary_mask(self):
+ """Load the boundary mask for the dataset."""
+ boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
+ return boundary_mask.to_array().values
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 4f7de3f4..bffedb77 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -95,7 +95,7 @@ static:
atmosphere_units: null
levels: null
boundary:
- zarrs:
+ zarrs: # This is not used currently, but ERA5 boundaries will soon be used
- path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
dims:
time: time
@@ -105,7 +105,11 @@ boundary:
lat_lon_names:
lon: longitude
lat: latitude
- mask: boundary_mask
+ mask:
+ path: "boundary_mask.zarr"
+ dims:
+ x: x
+ y: y
window: 3
utilities:
normalization:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index dd352c22..391f6b02 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -36,14 +36,10 @@ def __init__(self, args):
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
+ self.grid_output_dim = self.config_loader.num_data_vars("state")
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.config_loader.num_data_vars(
- "state"
- )
- else:
- # Pred. dim. in grid cell
- self.grid_output_dim = self.config_loader.num_data_vars("state")
+ self.grid_output_dim = 2 * self.grid_output_dim
# grid_dim from data + static
(
@@ -51,7 +47,7 @@ def __init__(self, args):
grid_static_dim,
) = self.grid_static_features.shape
self.grid_dim = (
- 2 * self.config_loader.num_data_vars("state")
+ 2 * self.grid_output_dim
+ grid_static_dim
+ self.config_loader.num_data_vars("forcing")
* self.config_loader.forcing.window
@@ -60,14 +56,15 @@ def __init__(self, args):
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
- border_mask = torch.zeros(self.num_grid_nodes, 1)
- self.register_buffer("border_mask", border_mask, persistent=False)
+ boundary_mask = self.config_loader.load_boundary_mask()
+ self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
- "interior_mask", 1.0 - self.border_mask, persistent=False
+ "interior_mask", 1.0 - self.boundary_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border
- self.step_length = args.step_length # Number of hours per pred. step
+ # Number of hours per pred. step
+ self.step_length = self.config_loader.step_length
self.val_metrics = {
"mse": [],
}
@@ -88,21 +85,6 @@ def __init__(self, args):
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []
- # Load normalization statistics
- self.normalization_stats = (
- self.config_loader.load_normalization_stats()
- )
- if self.normalization_stats is not None:
- for (
- var_name,
- var_data,
- ) in self.normalization_stats.data_vars.items():
- self.register_buffer(
- f"{var_name}",
- torch.tensor(var_data.values),
- persistent=False,
- )
-
def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
@@ -157,7 +139,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
# Overwrite border with true state
new_state = (
- self.border_mask * border_state
+ self.boundary_mask * border_state
+ self.interior_mask * pred_state
)
@@ -203,25 +185,6 @@ def common_step(self, batch):
return prediction, target_states, pred_std
- def on_after_batch_transfer(self, batch, dataloader_idx):
- """Normalize Batch data after transferring to the device."""
- if self.normalization_stats is not None:
- init_states, target_states, forcing_features, _, _ = batch
- init_states = (init_states - self.mean) / self.std
- target_states = (target_states - self.mean) / self.std
- forcing_features = (
- forcing_features - self.forcing_mean
- ) / self.forcing_std
- # boundary_features = ( boundary_features - self.boundary_mean ) /
- # self.boundary_std
- batch = (
- init_states,
- target_states,
- forcing_features,
- # boundary_features,
- )
- return batch
-
def training_step(self, batch):
"""
Train on single batch
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 59bd31e6..f7ecafb3 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -2,86 +2,11 @@
import os
# Third-party
-import numpy as np
import torch
from torch import nn
from tueplots import bundles, figsizes
-def load_dataset_stats(dataset_name, device="cpu"):
- """
- Load arrays with stored dataset statistics from pre-processing
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn), map_location=device
- )
-
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- flux_stats = loads_file("flux_stats.pt") # (2,)
- flux_mean, flux_std = flux_stats
-
- return {
- "data_mean": data_mean,
- "data_std": data_std,
- "flux_mean": flux_mean,
- "flux_std": flux_std,
- }
-
-
-def load_static_data(dataset_name, device="cpu"):
- """
- Load static files related to dataset
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn), map_location=device
- )
-
- # Load border mask, 1. if node is part of border, else 0.
- border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
- border_mask = (
- torch.tensor(border_mask_np, dtype=torch.float32, device=device)
- .flatten(0, 1)
- .unsqueeze(1)
- ) # (N_grid, 1)
-
- grid_static_features = loads_file(
- "grid_features.pt"
- ) # (N_grid, d_grid_static)
-
- # Load step diff stats
- step_diff_mean = loads_file("diff_mean.pt") # (d_f,)
- step_diff_std = loads_file("diff_std.pt") # (d_f,)
-
- # Load parameter std for computing validation errors in original data scale
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- # Load loss weighting vectors
- param_weights = torch.tensor(
- np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
- dtype=torch.float32,
- device=device,
- ) # (d_f,)
-
- return {
- "border_mask": border_mask,
- "grid_static_features": grid_static_features,
- "step_diff_mean": step_diff_mean,
- "step_diff_std": step_diff_std,
- "data_mean": data_mean,
- "data_std": data_std,
- "param_weights": param_weights,
- }
-
-
class BufferList(nn.Module):
"""
A list of torch buffer tensors that sit together as a Module with no
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 8c9ca77c..ca77e24e 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -8,7 +8,7 @@
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=3):
+def plot_error_map(errors, data_config, title=None, step_length=1):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
diff --git a/train_model.py b/train_model.py
index 1b985ef0..fb4d8a5b 100644
--- a/train_model.py
+++ b/train_model.py
@@ -149,13 +149,6 @@ def main():
default="wmse",
help="Loss function to use, see metric.py (default: wmse)",
)
- parser.add_argument(
- "--step_length",
- type=int,
- default=1,
- help="Step length in hours to consider single time step 1-3 "
- "(default: 1)",
- )
parser.add_argument(
"--lr", type=float, default=1e-3, help="learning rate (default: 0.001)"
)
@@ -222,7 +215,6 @@ def main():
# Asserts for arguments
assert args.model in MODELS, f"Unknown model: {args.model}"
- assert args.step_length <= 3, "Too high step length"
assert args.eval in (
None,
"val",
From 8054e9e8dafd8b4c58b13648df7e079587a74e81 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 4 Jun 2024 16:09:29 +0100
Subject: [PATCH 079/273] change ami image to gpu
---
.cirun.yml | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/.cirun.yml b/.cirun.yml
index 734e6786..79d62f22 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -4,8 +4,7 @@ runners:
# Cloud Provider: AWS
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
- # instance_type: "g4dn.xlarge"
- instance_type: "t2.nano"
+ instance_type: "g4dn.xlarge"
# Ubuntu-20.4, ami image
machine_image: "ami-06fd8a495a537da8b"
preemptible: false
From 97aeb2e67ecefcb6488853938f2195098b671422 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 4 Jun 2024 16:13:39 +0100
Subject: [PATCH 080/273] use cheaper gpu instance
---
.cirun.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.cirun.yml b/.cirun.yml
index 79d62f22..b188d6dc 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -4,7 +4,7 @@ runners:
# Cloud Provider: AWS
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
- instance_type: "g4dn.xlarge"
+ instance_type: "g4ad.xlarge"
# Ubuntu-20.4, ami image
machine_image: "ami-06fd8a495a537da8b"
preemptible: false
From 425123c1938f8cdbf9d68ce7a2398a07f059a756 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 17:41:22 +0200
Subject: [PATCH 081/273] adapted tests for zarr-analysis data
---
tests/test_analysis_dataset.py | 84 +++++++++++++++++++
...m_dataset.py => test_forecast_dataset.py_} | 0
2 files changed, 84 insertions(+)
create mode 100644 tests/test_analysis_dataset.py
rename tests/{test_mllam_dataset.py => test_forecast_dataset.py_} (100%)
diff --git a/tests/test_analysis_dataset.py b/tests/test_analysis_dataset.py
new file mode 100644
index 00000000..546921aa
--- /dev/null
+++ b/tests/test_analysis_dataset.py
@@ -0,0 +1,84 @@
+# Standard library
+import os
+
+# First-party
+from create_mesh import main as create_mesh
+from neural_lam.config import Config
+from neural_lam.weather_dataset import WeatherDataset
+from train_model import main as train_model
+
+# Disable weights and biases to avoid unnecessary logging
+# and to avoid having to deal with authentication
+os.environ["WANDB_DISABLED"] = "true"
+
+
+def test_load_analysis_dataset():
+ # The data_config.yaml file is downloaded and extracted in
+ # test_retrieve_data_ewc together with the dataset itself
+ data_config_file = "tests/data_config.yaml"
+ config = Config.from_file(data_config_file)
+
+ var_state_names = config.vars_names("state")
+ var_state_units = config.vars_units("state")
+ num_state_vars = config.num_data_vars("state")
+
+ assert len(var_state_names) == len(var_state_units) == num_state_vars
+
+ var_forcing_names = config.vars_names("forcing")
+ var_forcing_units = config.vars_units("forcing")
+ num_forcing_vars = config.num_data_vars("forcing")
+
+ assert len(var_forcing_names) == len(var_forcing_units) == num_forcing_vars
+
+ # Assert dataset can be loaded
+ ds = config.open_zarrs("state")
+ grid = ds.sizes["y"] * ds.sizes["x"]
+ dataset = WeatherDataset(split="train", ar_steps=3, standardize=False)
+ batch = dataset[0]
+ # return init_states, target_states, forcing, batch_times
+ # init_states: (2, N_grid, d_features)
+ # target_states: (ar_steps-2, N_grid, d_features)
+ # forcing: (ar_steps-2, N_grid, d_windowed_forcing)
+ # batch_times: (ar_steps-2,)
+ assert list(batch[0].shape) == [2, grid, num_state_vars]
+ assert list(batch[1].shape) == [dataset.ar_steps - 2, grid, num_state_vars]
+ assert list(batch[2].shape) == [
+ dataset.ar_steps - 2,
+ grid,
+ num_forcing_vars * config.forcing.window,
+ ]
+ assert isinstance(batch[3], list)
+
+ # Assert provided grid-shapes
+ assert config.get_xy("static")[0].shape == (
+ config.grid_shape_state.y,
+ config.grid_shape_state.x,
+ )
+ assert config.get_xy("static")[0].shape == (ds.sizes["y"], ds.sizes["x"])
+
+
+def test_create_graph_analysis_dataset():
+ args = [
+ "--graph=hierarchical",
+ "--hierarchical=1",
+ "--data_config=tests/data_config.yaml",
+ "--levels=2",
+ ]
+ create_mesh(args)
+
+
+def test_train_model_analysis_dataset():
+ args = [
+ "--model=hi_lam",
+ "--data_config=tests/data_config.yaml",
+ "--num_workers=4",
+ "--epochs=1",
+ "--graph=hierarchical",
+ "--hidden_dim=16",
+ "--hidden_layers=1",
+ "--processor_layers=1",
+ "--ar_steps_eval=1",
+ "--eval=val",
+ "--n_example_pred=0",
+ ]
+ train_model(args)
diff --git a/tests/test_mllam_dataset.py b/tests/test_forecast_dataset.py_
similarity index 100%
rename from tests/test_mllam_dataset.py
rename to tests/test_forecast_dataset.py_
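The shape assertions in the new test encode how a sample window is split. The sketch below mirrors those shapes with random data, under the assumed layout that an `ar_steps`-long window of consecutive states provides the first two states as initial conditions and the remaining `ar_steps - 2` as autoregressive targets.

```
import numpy as np

ar_steps, n_grid, d_features = 5, 6, 3
window = np.random.rand(ar_steps, n_grid, d_features)

init_states = window[:2]    # (2, N_grid, d_features)
target_states = window[2:]  # (ar_steps - 2, N_grid, d_features)

assert init_states.shape == (2, n_grid, d_features)
assert target_states.shape == (ar_steps - 2, n_grid, d_features)
```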
From 4dcf6718e7a8b0f1cdc6595cd0c9299ba0a39f2a Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 17:41:41 +0200
Subject: [PATCH 082/273] Readme adapted for yaml zarr analysis workflow
---
README.md | 166 ++++++++++++++++++++++++++++++++++--------------------
1 file changed, 105 insertions(+), 61 deletions(-)
diff --git a/README.md b/README.md
index 1bdc6602..272cd8a9 100644
--- a/README.md
+++ b/README.md
@@ -45,14 +45,6 @@ Still, some restrictions are inevitable:
-## A note on the limited area setting
-Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
-There are still some parts of the code that is quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ).
-If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
-We would be happy to support such enhancements.
-See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
-
# Using Neural-LAM
Below follows instructions on how to use Neural-LAM to train and evaluate models.
@@ -74,26 +66,25 @@ pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 torch-clust
You will have to adjust the `CUDA` variable to match the CUDA version on your system or to run on CPU. See the [installation webpage](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) for more information.
## Data
-Datasets should be stored in a directory called `data`.
-See the [repository format section](#format-of-data-directory) for details on the directory structure.
-
-The full MEPS dataset can be shared with other researchers on request, contact us for this.
-A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
-Download the file and unzip in the neural-lam directory.
-All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`).
-Note that this is far too little data to train any useful models, but all scripts can be ran with it.
-It should thus be useful to make sure that your python environment is set up correctly and that all the code can be ran without any issues.
+The repository is set up to work with `yaml` configuration files. These files are used to specify the dataset properties and location. An example of a dataset configuration file is stored in `neural_lam/data_config.yaml` and outlined below.
## Pre-processing
An overview of how the different scripts and files depend on each other is given in this figure:
-In order to start training models at least three pre-processing scripts have to be ran:
+In order to start training models, at least one pre-processing script has to be run:
* `create_mesh.py`
-* `create_grid_features.py`
-* `create_parameter_weights.py`
+
+If not provided directly by the user, the following scripts also have to be run:
+
+* `calculate_statistics.py`
+* `create_boundary_mask.py`
+
+The following script is optional, but can be used to create additional features:
+
+* `create_forcings.py`
### Create graph
Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options).
@@ -105,9 +96,6 @@ The graphs used for the different models in the [paper](https://arxiv.org/abs/23
The graph-related files are stored in a directory called `graphs`.
-### Create remaining static features
-To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
-
## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
@@ -129,7 +117,7 @@ Models can be trained using `train_model.py`.
Run `python train_model.py --help` for a full list of training options.
A few of the key ones are outlined below:
-* `--dataset`: Which data to train on
+* `--data_config`: Path to the data configuration file
* `--model`: Which model to train
* `--graph`: Which graph to use with the model
* `--processor_layers`: Number of GNN layers to use in the processing part of the model
@@ -186,46 +174,102 @@ Some options specifically important for evaluation are:
# Repository Structure
Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory.
Model classes, including abstract base classes, are located in `neural_lam/models`.
+Notebooks for visualization and analysis are located in `docs`.
+
## Format of data directory
-It is possible to store multiple datasets in the `data` directory.
-Each dataset contains a set of files with static features and a set of samples.
-The samples are split into different sub-directories for training, validation and testing.
-The directory structure is shown with examples below.
-Script names within parenthesis denote the script used to generate the file.
+The new workflow uses YAML configuration files to specify dataset properties and locations.
+Below is an example of how to structure your data directory, together with a condensed version of the YAML configuration file. For now, the community has decided that a zarr-based approach is the most flexible and efficient way to store the data. Please make sure that your dataset is stored as zarr, contains the necessary dimensions, and is structured as described below. For optimal performance, chunking the dataset along the time dimension only is recommended.
```
-data
-├── dataset1
-│ ├── samples - Directory with data samples
-│ │ ├── train - Training data
-│ │ │ ├── nwp_2022040100_mbr000.npy - A time series sample
-│ │ │ ├── nwp_2022040100_mbr001.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_2022043012_mbr001.npy
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy - Solar flux forcing
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022043012.npy
-│ │ │ ├── wtr_2022040100.npy - Open water features for one sample
-│ │ │ ├── wtr_2022040112.npy
-│ │ │ ├── ...
-│ │ │ └── wtr_202204012.npy
-│ │ ├── val - Validation data
-│ │ └── test - Test data
-│ └── static - Directory with graph information and static features
-│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset)
-│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset)
-│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset)
-│ ├── grid_features.pt - Static features of grid nodes (create_grid_features.py)
-│ ├── parameter_mean.pt - Means of state parameters (create_parameter_weights.py)
-│ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
-│ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
-│ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
-│ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
-├── dataset2
-├── ...
-└── datasetN
+name: danra
+state: # State variables vary in time and are predicted by the model
+ zarrs:
+ - path: # Path to the zarr file
+ dims: # Only the following dimensions will be mapped: time, level, x, y, grid
+ time: time # Required
+ level: null # Optional
+ x: x # Either x and y or grid must be specified
+ y: y
+ grid: null # Grid has precedence over x and y
+ lat_lon_names: # Required to map the grid projection to lat/lon
+ lon: lon
+ lat: lat
+ - path:
+ ... # Additional zarr files are allowed
+ surface_vars: # Single level variables to include in the state (in this order)
+ - var1
+ - var2
+ surface_units: # Units for the surface variables
+ - unit1
+ - unit2
+ atmosphere_vars: # Multi-level variables to include in the state (in this order)
+ - var1
+ ...
+ atmosphere_units: # Units for the atmosphere variables
+ - unit1
+ ...
+ levels: # Selection of vertical levels to include in the state (pressure/height/model level)
+ - 100
+ - 200
+ ...
+forcing: # Forcing variables vary in time but are not predicted by the model
+ ... # Same structure as state, multiple zarr files allowed
+ window: 3 # Number of time steps to use for forcing (odd number)
+static: # Static variables are not predicted by the model and do not vary in time
+ zarrs:
+ ...
+ dims: # Same structure as state but no "time" dimension
+ level: null
+ x: x
+ y: y
+ grid: null
+ ...
+boundary: # Boundary variables are not predicted by the model; unlike static fields they may vary in time
+ ... # They are used to inform the model about the surrounding weather conditions
+ ... # The boundaries are often taken from a separate model and are specified identically to the state
+ mask: # Boundary mask to indicate where the model should not make predictions
+ path: "boundary_mask.zarr"
+ dims:
+ x: x
+ y: y
+ window: 3 # Windowing of the boundary variables (odd number), may differ from forcing window
+utilities: # Additional utilities to be used in the model
+ normalization: # Normalization statistics for the state, forcing, and one-step differences
+ zarrs: # Zarr files containing the normalization statistics, multiple allowed
+ - path: "normalization.zarr" # Path to the zarr file, default locaton of `calculate_statistics.py`
+ stats_vars: # The variables to use for normalization, predefined and required
+ state_mean: name_in_dataset1
+ state_std: name_in_dataset2
+ forcing_mean: name_in_dataset3
+ forcing_std: name_in_dataset4
+ diff_mean: name_in_dataset5
+ diff_std: name_in_dataset6
+ combined_stats: # For some variables the statistics can be retrieved jointly
+ - vars: # List of variables that should end up with the same statistics
+ - vars1
+ - vars2
+ - vars:
+ ...
+grid_shape_state: # Shape of the state grid, used for reshaping the model output
+ y: 589 # Number of grid points in the y-direction (lat)
+ x: 789 # Number of grid points in the x-direction (lon)
+splits: # Train, validation, and test splits based on time-sampling
+ train:
+ start: 1990-09-01T00
+ end: 1990-09-11T00
+ val:
+ start: 1990-09-11T03
+ end: 1990-09-13T09
+ test:
+ start: 1990-09-11T03
+ end: 1990-09-13T09
+projection: # Projection of the grid (only used for plotting)
+ class: LambertConformal # Name of class in cartopy.crs
+ kwargs:
+ central_longitude: 6.22
+ central_latitude: 56.0
+ standard_parallels: [47.6, 64.4]
+
```
## Format of graph directory
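The condensed YAML above is consumed by the `Config` class in `neural_lam/config.py` (e.g. `Config.from_file`, `open_zarrs`, `get_xy`). Purely as an illustration of the file's structure, the snippet below parses a stripped-down config with plain PyYAML and reads the grid shape and training split; the values are the ones shown in the example, and this is not how the repository itself loads the file.

```
import yaml

example = """
grid_shape_state:
  y: 589
  x: 789
splits:
  train:
    start: 1990-09-01T00
    end: 1990-09-11T00
"""

cfg = yaml.safe_load(example)
n_grid = cfg["grid_shape_state"]["y"] * cfg["grid_shape_state"]["x"]
print(n_grid, cfg["splits"]["train"]["start"])
```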
From 6d384f018ca78f9ccfd45e15154ff62653510ec4 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 17:42:36 +0200
Subject: [PATCH 083/273] smaller bugfixes and improvements
---
calculate_statistics.py | 2 --
neural_lam/config.py | 27 ++++++---------------------
neural_lam/data_config.yaml | 11 +++++------
neural_lam/models/ar_model.py | 23 ++++++++++++++++++++++-
train_model.py | 8 +-------
5 files changed, 34 insertions(+), 37 deletions(-)
diff --git a/calculate_statistics.py b/calculate_statistics.py
index 90d3dbc0..b2469838 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -98,8 +98,6 @@ def main():
)
ds = xr.merge([ds, dsf])
- print(ds)
-
ds = ds.chunk({"variable": -1, "forcing_variable": -1})
print("Saving dataset as Zarr...")
ds.to_zarr(args.zarr_path, mode="w")
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 21a97018..f71d7d8f 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -103,7 +103,7 @@ def open_zarrs(self, category):
datasets = []
for config in zarr_configs:
dataset_path = config["path"]
- dataset = xr.open_zarrs(dataset_path, consolidated=True)
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
datasets.append(dataset)
merged_dataset = xr.merge(datasets)
merged_dataset.attrs["category"] = category
@@ -175,7 +175,7 @@ def filter_dimensions(self, dataset, transpose_array=True):
"-- from dataset.\033[0m",
)
print(
- "\033[91mAny data vars still dependent "
+ "\033[91mAny data vars dependent "
"on these variables were dropped!\033[0m"
)
@@ -244,7 +244,7 @@ def load_normalization_stats(self, category, datatype="torch"):
f"{stats_path}"
)
return None
- stats = xr.open_zarrs(stats_path, consolidated=True)
+ stats = xr.open_zarr(stats_path, consolidated=True)
if i == 0:
combined_stats = stats
else:
@@ -291,23 +291,6 @@ def load_normalization_stats(self, category, datatype="torch"):
return stats
- # def assign_lat_lon_coords(self, category, dataset=None):
- # """Process the latitude and longitude names of the dataset."""
- # if dataset is None:
- # dataset = self.open_zarrs(category)
- # lat_lon_names = {}
- # for zarr_config in self.values[category]["zarrs"]:
- # lat_lon_names.update(zarr_config["lat_lon_names"])
- # lat_name, lon_name = (lat_lon_names["lat"], lat_lon_names["lon"])
-
- # if "x" not in dataset.dims or "y" in dataset.dims:
- # dataset = self.reshape_grid_to_2d(dataset)
- # if not set(lat_lon_names).issubset(dataset.to_array().dims):
- # dataset = dataset.assign_coords(
- # x=dataset[lon_name], y=dataset[lat_name]
- # )
- # return dataset
-
def extract_vars(self, category, dataset=None):
"""Extract the variables from the dataset."""
if dataset is None:
@@ -396,6 +379,8 @@ def process_dataset(self, category, split="train", apply_windowing=True):
dataset = self.convert_dataset_to_dataarray(dataset)
if "window" in self.values[category] and apply_windowing:
dataset = self.apply_window(category, dataset)
+ if category == "static" and "time" in dataset.dims:
+ dataset = dataset.isel(time=0, drop=True)
return dataset
@@ -417,4 +402,4 @@ def apply_window(self, category, dataset=None):
def load_boundary_mask(self):
"""Load the boundary mask for the dataset."""
boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
- return boundary_mask.to_array().values
+ return torch.tensor(boundary_mask.to_array().values)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index bffedb77..8e1e9c12 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -72,6 +72,10 @@ forcing:
- kg/m^2 # just as a technical test :)
- m
- m
+ - ""
+ - ""
+ - ""
+ - ""
atmosphere_vars: null
atmosphere_units: null
levels: null
@@ -129,14 +133,9 @@ utilities:
- vars:
- cape_column
- xhail0m
- boundary_mask:
- zarrs:
- - path: "boundary.zarr"
- boundary_vars:
- boundary_mask: boundary_mask
grid_shape_state:
- x: 789
y: 589
+ x: 789
splits:
train:
start: 1990-09-01T00
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 391f6b02..1dec1d50 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -34,18 +34,39 @@ def __init__(self, args):
persistent=False,
)
+ state_stats = self.config_loader.load_normalization_stats(
+ "state", datatype="torch"
+ )
+ for key, val in state_stats.items():
+ self.register_buffer(key, val, persistent=False)
+
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
self.grid_output_dim = self.config_loader.num_data_vars("state")
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.grid_output_dim
+ self.grid_output_dim = 2 * self.config_loader.num_data_vars(
+ "state"
+ )
+ else:
+ # Pred. dim. in grid cell
+ self.grid_output_dim = self.config_loader.num_data_vars("state")
+ # Store constant per-variable std.-dev. weighting
+ # Note that this is the inverse of the multiplicative weighting
+ # in wMSE/wMAE
+ # TODO: Do we need param_weights for this?
+ self.register_buffer(
+ "per_var_std",
+ self.diff_std,
+ persistent=False,
+ )
# grid_dim from data + static
(
self.num_grid_nodes,
grid_static_dim,
) = self.grid_static_features.shape
+
self.grid_dim = (
2 * self.grid_output_dim
+ grid_static_dim
diff --git a/train_model.py b/train_model.py
index eda536b6..11b386d0 100644
--- a/train_model.py
+++ b/train_model.py
@@ -43,12 +43,6 @@ def main(input_args=None):
default="graph_lam",
help="Model architecture to train/evaluate (default: graph_lam)",
)
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
parser.add_argument(
"--seed", type=int, default=42, help="random seed (default: 42)"
)
@@ -281,7 +275,7 @@ def main(input_args=None):
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(
- logger, val_steps=args.val_steps_log
+ logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
wandb.save(args.data_config)
if args.eval:
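The buffer registration added to ARModel in this patch follows a standard PyTorch pattern. The minimal sketch below (with made-up statistics, not the repository's ARModel) shows why it is convenient: each registered entry becomes an attribute on the module and moves with it between devices, without being saved in checkpoints when `persistent=False`.

```
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, state_stats):
        super().__init__()
        # Register every normalization statistic as a non-persistent buffer
        for key, val in state_stats.items():
            self.register_buffer(key, val, persistent=False)

stats = {"state_mean": torch.zeros(3), "state_std": torch.ones(3)}
model = TinyModel(stats)
print(model.state_mean, model.state_std)  # accessible as attributes
```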
From 12ff4f25e66bb52333fc2dc6d8a1536207604b87 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 17:42:58 +0200
Subject: [PATCH 084/273] Added fixed data config file for testing on Danra
---
tests/data_config.yaml | 154 +++++++++++++++++++++++++++++++++++++++++
1 file changed, 154 insertions(+)
create mode 100644 tests/data_config.yaml
diff --git a/tests/data_config.yaml b/tests/data_config.yaml
new file mode 100644
index 00000000..8e1e9c12
--- /dev/null
+++ b/tests/data_config.yaml
@@ -0,0 +1,154 @@
+name: danra
+state:
+ zarrs:
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+ dims:
+ time: time
+ level: altitude
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
+ surface_vars:
+ - u10m
+ - v10m
+ - t2m
+ surface_units:
+ - m/s
+ - m/s
+ - K
+ atmosphere_vars:
+ - u
+ - v
+ - t
+ atmosphere_units:
+ - m/s
+ - m/s
+ - K
+ levels:
+ - 100
+forcing:
+ zarrs:
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
+ - path: "forcings.zarr"
+ dims:
+ time: time
+ level: null
+ x: x
+ y: y
+ grid: null
+ surface_vars:
+ - cape_column # just as a technical test
+ - icei0m
+ - vis0m
+ - xhail0m
+ - hour_cos
+ - hour_sin
+ - year_cos
+ - year_sin
+ surface_units:
+ - J/kg
+ - kg/m^2 # just as a technical test :)
+ - m
+ - m
+ - ""
+ - ""
+ - ""
+ - ""
+ atmosphere_vars: null
+ atmosphere_units: null
+ levels: null
+ window: 3 # Number of time steps to use for forcing (odd)
+static:
+ zarrs:
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ dims:
+ level: null
+ x: x
+ y: y
+ grid: null
+ lat_lon_names:
+ lon: lon
+ lat: lat
+ surface_vars:
+ - pres0m # just as a technical test
+ surface_units:
+ - Pa
+ atmosphere_vars: null
+ atmosphere_units: null
+ levels: null
+boundary:
+  zarrs: # This is not used currently, but soon ERA5 boundaries will be used
+ - path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
+ dims:
+ time: time
+ level: level
+ x: longitude
+ y: latitude
+ lat_lon_names:
+ lon: longitude
+ lat: latitude
+ mask:
+ path: "boundary_mask.zarr"
+ dims:
+ x: x
+ y: y
+ window: 3
+utilities:
+ normalization:
+ zarrs:
+ - path: "normalization.zarr"
+ stats_vars:
+ state_mean: state_mean
+ state_std: state_std
+ forcing_mean: forcing_mean
+ forcing_std: forcing_std
+ diff_mean: diff_mean
+ diff_std: diff_std
+ combined_stats:
+ - vars:
+ - icei0m
+ - vis0m
+ - vars:
+ - cape_column
+ - xhail0m
+grid_shape_state:
+ y: 589
+ x: 789
+splits:
+ train:
+ start: 1990-09-01T00
+ end: 1990-09-11T00
+ val:
+ start: 1990-09-11T03
+ end: 1990-09-13T09
+ test:
+ start: 1990-09-11T03
+ end: 1990-09-13T09
+projection:
+ class: LambertConformal # Name of class in cartopy.crs
+ kwargs:
+ central_longitude: 6.22
+ central_latitude: 56.0
+ standard_parallels: [47.6, 64.4]
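The projection block follows the same convention as the main data config: "class" names a class in cartopy.crs and "kwargs" are handed to that class unchanged. As a minimal sketch of how such a block can be resolved (assuming cartopy is installed; variable names here are only for illustration):

    import cartopy.crs as ccrs

    proj_cfg = {
        "class": "LambertConformal",
        "kwargs": {
            "central_longitude": 6.22,
            "central_latitude": 56.0,
            "standard_parallels": [47.6, 64.4],
        },
    }
    proj_class = getattr(ccrs, proj_cfg["class"])   # look up the class by name
    projection = proj_class(**proj_cfg["kwargs"])   # usable as a map projection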
From 03f77699ee1daea54014e3c11eff9e16ba7091d6 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Tue, 4 Jun 2024 17:45:24 +0200
Subject: [PATCH 085/273] reducing runtime of tests with smaller sample
---
tests/data_config.yaml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/tests/data_config.yaml b/tests/data_config.yaml
index 8e1e9c12..224c3f4e 100644
--- a/tests/data_config.yaml
+++ b/tests/data_config.yaml
@@ -139,13 +139,13 @@ grid_shape_state:
splits:
train:
start: 1990-09-01T00
- end: 1990-09-11T00
+ end: 1990-09-01T02
val:
- start: 1990-09-11T03
- end: 1990-09-13T09
+ start: 1990-09-11T00
+ end: 1990-09-11T02
test:
- start: 1990-09-11T03
- end: 1990-09-13T09
+ start: 1990-09-11T00
+ end: 1990-09-11T02
projection:
class: LambertConformal # Name of class in cartopy.crs
kwargs:
From 26f069c2581026558bd293a445b47812e19d8327 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 07:41:03 +0200
Subject: [PATCH 086/273] download danra data for test and example (streaming
not possible)
---
docs/download_danra.py | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
create mode 100644 docs/download_danra.py
diff --git a/docs/download_danra.py b/docs/download_danra.py
new file mode 100644
index 00000000..8d7542a2
--- /dev/null
+++ b/docs/download_danra.py
@@ -0,0 +1,25 @@
+import xarray as xr
+
+data_urls = [
+ "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr",
+ "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr",
+]
+
+local_paths = [
+ "data/danra/single_levels.zarr",
+ "data/danra/height_levels.zarr",
+]
+
+for url, path in zip(data_urls, local_paths):
+ print(f"Downloading {url} to {path}")
+ ds = xr.open_zarr(url)
+ chunk_dict = {dim: -1 for dim in ds.dims if dim != "time"}
+ chunk_dict["time"] = 20
+ ds = ds.chunk(chunk_dict)
+
+ for var in ds.variables:
+ if 'chunks' in ds[var].encoding:
+ del ds[var].encoding['chunks']
+
+ ds.to_zarr(path, mode="w")
+ print("DONE")
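Deleting each variable's "chunks" entry from its encoding before the write matters here: xarray remembers the chunk layout of the source store in encoding, and to_zarr would try to reuse it, which conflicts with the new time=20 chunking chosen above. A small, hypothetical sanity check after the download (not part of the script; t2m is one of the surface variables listed in the config):

    import xarray as xr

    ds = xr.open_zarr("data/danra/single_levels.zarr")
    print(ds.sizes)                           # dimension lengths of the local copy
    print(ds["t2m"].encoding.get("chunks"))   # chunking actually written to disk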
From 1f1cbcc01bfbad814d2fbac8fb6dbfe896f2bb79 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 13:14:00 +0200
Subject: [PATCH 087/273] bugfixes after real-life testcase
---
calculate_statistics.py | 12 +++---
create_boundary_mask.py | 6 +--
create_forcings.py | 4 +-
create_mesh.py | 4 +-
docs/download_danra.py | 5 ++-
neural_lam/config.py | 63 +++++++++++++++++----------
neural_lam/data_config.yaml | 16 +++++--
neural_lam/models/ar_model.py | 78 +++++++++++++++++-----------------
neural_lam/vis.py | 22 +++++++---
neural_lam/weather_dataset.py | 17 ++++----
plot_graph.py | 4 +-
tests/data_config.yaml | 16 +++++--
tests/test_analysis_dataset.py | 38 +++++++++--------
train_model.py | 9 ++--
14 files changed, 173 insertions(+), 121 deletions(-)
diff --git a/calculate_statistics.py b/calculate_statistics.py
index b2469838..e142ddfc 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -30,9 +30,9 @@ def main():
)
args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
- state_data = config_loader.process_dataset("state", split="train")
- forcing_data = config_loader.process_dataset(
+ data_config = config.Config.from_file(args.data_config)
+ state_data = data_config.process_dataset("state", split="train")
+ forcing_data = data_config.process_dataset(
"forcing", split="train", apply_windowing=False
)
@@ -41,7 +41,7 @@ def main():
if forcing_data is not None:
forcing_mean, forcing_std = compute_stats(forcing_data)
- combined_stats = config_loader["utilities"]["normalization"][
+ combined_stats = data_config["utilities"]["normalization"][
"combined_stats"
]
@@ -58,7 +58,7 @@ def main():
dict(variable=vars_to_combine)
] = combined_mean
forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
- window = config_loader["forcing"]["window"]
+ window = data_config["forcing"]["window"]
forcing_mean = xr.concat([forcing_mean] * window, dim="window").stack(
forcing_variable=("variable", "window")
)
@@ -66,7 +66,7 @@ def main():
forcing_variable=("variable", "window")
)
vars = forcing_data["variable"].values.tolist()
- window = config_loader["forcing"]["window"]
+ window = data_config["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
print(
diff --git a/create_boundary_mask.py b/create_boundary_mask.py
index 78443df0..1933cfef 100644
--- a/create_boundary_mask.py
+++ b/create_boundary_mask.py
@@ -31,8 +31,8 @@ def main():
help="Number of grid-cells to set to True along each boundary",
)
args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
- mask = np.zeros(list(config_loader.grid_shape_state.values.values()))
+ data_config = config.Config.from_file(args.data_config)
+ mask = np.zeros(list(data_config.grid_shape_state.values.values()))
# Set the args.boundaries grid-cells closest to each boundary to True
mask[: args.boundaries, :] = True # top boundary
@@ -40,7 +40,7 @@ def main():
mask[:, : args.boundaries] = True # left boundary
mask[:, -args.boundaries :] = True # noqa right boundary
- mask = xr.Dataset({"mask": (["x", "y"], mask)})
+ mask = xr.Dataset({"mask": (["y", "x"], mask)})
print(f"Saving mask to {args.zarr_path}...")
mask.to_zarr(args.zarr_path, mode="w")
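The dims tuple switches to ["y", "x"] presumably because the mask array is allocated from grid_shape_state, which lists y before x, so the first array axis is the row (y) axis. A compact sketch of the intended layout, using the grid shape from the config and an arbitrary boundary width (illustrative only, not the repository script):

    import numpy as np
    import xarray as xr

    ny, nx = 589, 789              # grid_shape_state: y first, then x
    boundaries = 10                # stands in for the --boundaries argument
    mask = np.zeros((ny, nx))
    mask[:boundaries, :] = True    # top rows
    mask[-boundaries:, :] = True   # bottom rows
    mask[:, :boundaries] = True    # left columns
    mask[:, -boundaries:] = True   # right columns
    xr.Dataset({"mask": (["y", "x"], mask)}).to_zarr("boundary_mask.zarr", mode="w")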
diff --git a/create_forcings.py b/create_forcings.py
index 459a3982..10dc3c8e 100644
--- a/create_forcings.py
+++ b/create_forcings.py
@@ -59,8 +59,8 @@ def main():
parser.add_argument("--zarr_path", type=str, default="forcings.zarr")
args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
- dataset = config_loader.open_zarrs("state")
+ data_config = config.Config.from_file(args.data_config)
+ dataset = data_config.open_zarrs("state")
datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
# Expand dimensions to match the target dataset
diff --git a/create_mesh.py b/create_mesh.py
index 42e23358..238d075b 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -193,11 +193,11 @@ def main(input_args=None):
args = parser.parse_args(input_args)
# Load grid positions
- config_loader = config.Config.from_file(args.data_config)
+ data_config = config.Config.from_file(args.data_config)
graph_dir_path = os.path.join("graphs", args.graph)
os.makedirs(graph_dir_path, exist_ok=True)
- xy = config_loader.get_xy("static") # (2, N_y, N_x)
+ xy = data_config.get_xy("static") # (2, N_y, N_x)
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
diff --git a/docs/download_danra.py b/docs/download_danra.py
index 8d7542a2..fb70754f 100644
--- a/docs/download_danra.py
+++ b/docs/download_danra.py
@@ -1,3 +1,4 @@
+# Third-party
import xarray as xr
data_urls = [
@@ -18,8 +19,8 @@
ds = ds.chunk(chunk_dict)
for var in ds.variables:
- if 'chunks' in ds[var].encoding:
- del ds[var].encoding['chunks']
+ if "chunks" in ds[var].encoding:
+ del ds[var].encoding["chunks"]
ds.to_zarr(path, mode="w")
print("DONE")
diff --git a/neural_lam/config.py b/neural_lam/config.py
index f71d7d8f..480aaddf 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -56,6 +56,15 @@ def coords_projection(self):
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
+ @functools.cached_property
+ def step_length(self):
+ """Return the step length of the dataset in hours."""
+ dataset = self.open_zarrs("state")
+ time = dataset.time.isel(time=slice(0, 2)).values
+ step_length_ns = time[1] - time[0]
+ step_length_hours = step_length_ns / np.timedelta64(1, "h")
+ return int(step_length_hours)
+
@functools.lru_cache()
def vars_names(self, category):
"""Return the names of the variables in the dataset."""
@@ -191,10 +200,10 @@ def filter_dimensions(self, dataset, transpose_array=True):
if isinstance(dataset, xr.Dataset)
else dataset["variable"].values.tolist()
)
- print(
- "\033[94mYour Dataarray has the following variables: ",
- dataset_vars,
- "\033[0m",
+
+ print( # noqa
+ f"\033[94mYour {dataset.attrs['category']} xr.Dataarray has the "
+ f"following variables: {dataset_vars} \033[0m",
)
return dataset
@@ -366,29 +375,19 @@ def filter_dataset_by_time(self, dataset, split="train"):
self.values["splits"][split]["start"],
self.values["splits"][split]["end"],
)
- return dataset.sel(time=slice(start, end))
-
- def process_dataset(self, category, split="train", apply_windowing=True):
- """Process the dataset for the given category."""
- dataset = self.open_zarrs(category)
- dataset = self.extract_vars(category, dataset)
- dataset = self.filter_dataset_by_time(dataset, split)
- dataset = self.stack_grid(dataset)
- dataset = self.rename_dataset_dims_and_vars(category, dataset)
- dataset = self.filter_dimensions(dataset)
- dataset = self.convert_dataset_to_dataarray(dataset)
- if "window" in self.values[category] and apply_windowing:
- dataset = self.apply_window(category, dataset)
- if category == "static" and "time" in dataset.dims:
- dataset = dataset.isel(time=0, drop=True)
-
+ dataset = dataset.sel(time=slice(start, end))
+ dataset.attrs["split"] = split
return dataset
def apply_window(self, category, dataset=None):
"""Apply the forcing window to the forcing dataset."""
if dataset is None:
dataset = self.open_zarrs(category)
- state_time = self.open_zarrs("state").time.values
+ if isinstance(dataset, xr.Dataset):
+ dataset = self.convert_dataset_to_dataarray(dataset)
+ state = self.open_zarrs("state")
+ state = self.filter_dataset_by_time(state, dataset.attrs["split"])
+ state_time = state.time.values
window = self[category].window
dataset = (
dataset.sel(time=state_time, method="nearest")
@@ -397,9 +396,29 @@ def apply_window(self, category, dataset=None):
.construct("window")
.stack(variable_window=("variable", "window"))
)
+ dataset = dataset.isel(time=slice(window // 2, -window // 2 + 1))
return dataset
def load_boundary_mask(self):
"""Load the boundary mask for the dataset."""
boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
- return torch.tensor(boundary_mask.to_array().values)
+ return torch.tensor(
+ boundary_mask.mask.stack(grid=("y", "x")).values,
+ dtype=torch.float32,
+ ).unsqueeze(1)
+
+ def process_dataset(self, category, split="train", apply_windowing=True):
+ """Process the dataset for the given category."""
+ dataset = self.open_zarrs(category)
+ dataset = self.extract_vars(category, dataset)
+ dataset = self.filter_dataset_by_time(dataset, split)
+ dataset = self.stack_grid(dataset)
+ dataset = self.rename_dataset_dims_and_vars(category, dataset)
+ dataset = self.filter_dimensions(dataset)
+ dataset = self.convert_dataset_to_dataarray(dataset)
+ if "window" in self.values[category] and apply_windowing:
+ dataset = self.apply_window(category, dataset)
+ if category == "static" and "time" in dataset.dims:
+ dataset = dataset.isel(time=0, drop=True)
+
+ return dataset
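The new step_length property reads only the first two timestamps of the state zarr and converts their difference to whole hours. The arithmetic in isolation, as a self-contained sketch with made-up timestamps:

    import numpy as np

    time = np.array(["1990-09-01T00", "1990-09-01T03"], dtype="datetime64[ns]")
    step_length_ns = time[1] - time[0]                        # numpy timedelta64
    step_length_hours = step_length_ns / np.timedelta64(1, "h")
    print(int(step_length_hours))                             # -> 3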
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 8e1e9c12..87c3a354 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -1,7 +1,7 @@
name: danra
state:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
time: time
level: null
@@ -11,7 +11,7 @@ state:
lat_lon_names:
lon: lon
lat: lat
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+ - path: "data/danra/height_levels.zarr"
dims:
time: time
level: altitude
@@ -41,7 +41,7 @@ state:
- 100
forcing:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
time: time
level: null
@@ -82,7 +82,7 @@ forcing:
window: 3 # Number of time steps to use for forcing (odd)
static:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
level: null
x: x
@@ -106,6 +106,7 @@ boundary:
level: level
x: longitude
y: latitude
+ grid: null
lat_lon_names:
lon: longitude
lat: latitude
@@ -114,6 +115,13 @@ boundary:
dims:
x: x
y: y
+ surface_vars:
+ - t2m
+ surface_units:
+ - K
+ atmosphere_vars: null
+ atmosphere_units: null
+ levels: null
window: 3
utilities:
normalization:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 1dec1d50..5b57fb4b 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -24,17 +24,17 @@ def __init__(self, args):
super().__init__()
self.save_hyperparameters()
self.args = args
- self.config_loader = config.Config.from_file(args.data_config)
+ self.data_config = config.Config.from_file(args.data_config)
# Load static features for grid/data
- static = self.config_loader.process_dataset("static")
+ static = self.data_config.process_dataset("static")
self.register_buffer(
"grid_static_features",
- torch.tensor(static.values),
+ torch.tensor(static.values, dtype=torch.float32),
persistent=False,
)
- state_stats = self.config_loader.load_normalization_stats(
+ state_stats = self.data_config.load_normalization_stats(
"state", datatype="torch"
)
for key, val in state_stats.items():
@@ -42,15 +42,13 @@ def __init__(self, args):
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
- self.grid_output_dim = self.config_loader.num_data_vars("state")
+ self.grid_output_dim = self.data_config.num_data_vars("state")
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.config_loader.num_data_vars(
- "state"
- )
+ self.grid_output_dim = 2 * self.data_config.num_data_vars("state")
else:
# Pred. dim. in grid cell
- self.grid_output_dim = self.config_loader.num_data_vars("state")
+ self.grid_output_dim = self.data_config.num_data_vars("state")
# Store constant per-variable std.-dev. weighting
# Note that this is the inverse of the multiplicative weighting
# in wMSE/wMAE
@@ -70,14 +68,14 @@ def __init__(self, args):
self.grid_dim = (
2 * self.grid_output_dim
+ grid_static_dim
- + self.config_loader.num_data_vars("forcing")
- * self.config_loader.forcing.window
+ + self.data_config.num_data_vars("forcing")
+ * self.data_config.forcing.window
)
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
- boundary_mask = self.config_loader.load_boundary_mask()
+ boundary_mask = self.data_config.load_boundary_mask()
self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
@@ -85,7 +83,7 @@ def __init__(self, args):
) # (num_grid_nodes, 1), 1 for non-border
# Number of hours per pred. step
- self.step_length = self.config_loader.step_length
+ self.step_length = self.data_config.step_length
self.val_metrics = {
"mse": [],
}
@@ -192,11 +190,7 @@ def common_step(self, batch):
num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
- (
- init_states,
- target_states,
- forcing_features,
- ) = batch
+ (init_states, target_states, forcing_features, batch_times) = batch
prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
@@ -204,13 +198,13 @@ def common_step(self, batch):
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
- return prediction, target_states, pred_std
+ return prediction, target_states, pred_std, batch_times
def training_step(self, batch):
"""
Train on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
+ prediction, target, pred_std, _ = self.common_step(batch)
# Compute loss
batch_loss = torch.mean(
@@ -226,6 +220,7 @@ def training_step(self, batch):
on_step=True,
on_epoch=True,
sync_dist=True,
+ batch_size=batch[0].shape[0],
)
return batch_loss
@@ -246,7 +241,7 @@ def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
+ prediction, target, pred_std, _ = self.common_step(batch)
time_step_loss = torch.mean(
self.loss(
@@ -263,7 +258,11 @@ def validation_step(self, batch, batch_idx):
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
- val_log_dict, on_step=False, on_epoch=True, sync_dist=True
+ val_log_dict,
+ on_step=False,
+ on_epoch=True,
+ sync_dist=True,
+ batch_size=batch[0].shape[0],
)
# Store MSEs
@@ -292,7 +291,8 @@ def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
+ # NOTE Here batch_times can be used for plotting routines
+ prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
@@ -312,7 +312,11 @@ def test_step(self, batch, batch_idx):
test_log_dict["test_mean_loss"] = mean_loss
self.log_dict(
- test_log_dict, on_step=False, on_epoch=True, sync_dist=True
+ test_log_dict,
+ on_step=False,
+ on_epoch=True,
+ sync_dist=True,
+ batch_size=batch[0].shape[0],
)
# Compute all evaluation metrics for error maps Note: explicitly list
@@ -371,13 +375,13 @@ def plot_examples(self, batch, n_examples, prediction=None):
Generate if None.
"""
if prediction is None:
- prediction, target = self.common_step(batch)
+ prediction, target, _, _ = self.common_step(batch)
target = batch[1]
# Rescale to original data scale
- prediction_rescaled = prediction * self.std + self.mean
- target_rescaled = target * self.std + self.mean
+ prediction_rescaled = prediction * self.state_std + self.state_mean
+ target_rescaled = target * self.state_std + self.state_mean
# Iterate over the examples
for pred_slice, target_slice in zip(
@@ -414,17 +418,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
pred_t[:, var_i],
target_t[:, var_i],
self.interior_mask[:, 0],
- self.config_loader,
+ self.data_config,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
- self.config_loader.dataset.var_names,
- self.config_loader.dataset.var_units,
- self.config_loader.param_names(),
- self.config_loader.param_units(),
+ self.data_config.vars_names("state"),
+ self.data_config.vars_units("state"),
var_vranges,
)
)
@@ -435,7 +437,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
- self.config_loader.param_names(), var_figs
+ self.data_config.vars_names("state"), var_figs
)
}
)
@@ -470,7 +472,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
log_dict = {}
metric_fig = vis.plot_error_map(
- metric_tensor, self.config_loader, step_length=self.step_length
+ metric_tensor, self.data_config, step_length=self.step_length
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
@@ -490,7 +492,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
# Check if metrics are watched, log exact values for specific vars
if full_log_name in self.args.metrics_watch:
for var_i, timesteps in self.args.var_leads_metrics_watch.items():
- var = self.config_loader.param_names()[var_i]
+ var = self.data_config.vars_names("state")[var_i]
log_dict.update(
{
f"{full_log_name}_{var}_step_{step}": metric_tensor[
@@ -526,7 +528,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_name = metric_name.replace("mse", "rmse")
# Note: we here assume rescaling for all metrics is linear
- metric_rescaled = metric_tensor_averaged * self.std
+ metric_rescaled = metric_tensor_averaged * self.state_std
# (pred_steps, d_f)
log_dict.update(
self.create_metric_log_dict(
@@ -559,7 +561,7 @@ def on_test_epoch_end(self):
vis.plot_spatial_error(
loss_map,
self.interior_mask[:, 0],
- self.config_loader,
+ self.data_config,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
for t_i, loss_map in zip(
@@ -574,7 +576,7 @@ def on_test_epoch_end(self):
# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(
- loss_map, self.interior_mask[:, 0], self.config_loader
+ loss_map, self.interior_mask[:, 0], self.data_config
)
for loss_map in mean_spatial_loss
]
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index ca77e24e..c92739f9 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -51,7 +51,7 @@ def plot_error_map(errors, data_config, title=None, step_length=1):
y_ticklabels = [
f"{name} ({unit})"
for name, unit in zip(
- data_config.dataset.var_names, data_config.dataset.var_units
+ data_config.vars_names("state"), data_config.vars_units("state")
)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -78,7 +78,9 @@ def plot_prediction(
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
+ mask_reshaped = obs_mask.reshape(
+ list(data_config.grid_shape_state.values.values())
+ )
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
@@ -93,7 +95,11 @@ def plot_prediction(
# Plot pred and target
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
- data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
+ data_grid = (
+ data.reshape(list(data_config.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
+ )
im = ax.imshow(
data_grid,
origin="lower",
@@ -129,7 +135,9 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
vmin, vmax = vrange
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
+ mask_reshaped = obs_mask.reshape(
+ list(data_config.grid_shape_state.values.values())
+ )
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
@@ -140,7 +148,11 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
)
ax.coastlines() # Add coastline outlines
- error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy()
+ error_grid = (
+ error.reshape(list(data_config.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
+ )
im = ax.imshow(
error_grid,
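grid_shape_state behaves like a mapping here, so list(...values.values()) yields the (y, x) lengths in config order, and reshaping the flattened grid with that list restores the 2D field that imshow expects. A toy version of the reshape, using a plain dict to stand in for the config object:

    import torch

    grid_shape = {"y": 589, "x": 789}        # as in grid_shape_state
    obs_mask = torch.zeros(589 * 789)        # flattened (num_grid_nodes,)
    mask_reshaped = obs_mask.reshape(list(grid_shape.values()))  # -> (589, 789)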
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index c25b0452..5eda343f 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -36,20 +36,18 @@ def __init__(
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.config_loader = config.Config.from_file(data_config)
+ self.data_config = config.Config.from_file(data_config)
- self.state = self.config_loader.process_dataset("state", self.split)
+ self.state = self.data_config.process_dataset("state", self.split)
assert self.state is not None, "State dataset not found"
- self.forcing = self.config_loader.process_dataset(
- "forcing", self.split
- )
+ self.forcing = self.data_config.process_dataset("forcing", self.split)
self.state_times = self.state.time.values
# Set up for standardization
# NOTE: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
- state_stats = self.config_loader.load_normalization_stats(
+ state_stats = self.data_config.load_normalization_stats(
"state", datatype="torch"
)
self.state_mean, self.state_std = (
@@ -58,7 +56,7 @@ def __init__(
)
if self.forcing is not None:
- forcing_stats = self.config_loader.load_normalization_stats(
+ forcing_stats = self.data_config.load_normalization_stats(
"forcing", datatype="torch"
)
self.forcing_mean, self.forcing_std = (
@@ -80,10 +78,11 @@ def __getitem__(self, idx):
torch.tensor(
self.forcing.isel(
time=slice(idx + 2, idx + self.ar_steps)
- ).values
+ ).values,
+ dtype=torch.float32,
)
if self.forcing is not None
- else torch.tensor([])
+ else torch.tensor([], dtype=torch.float32)
)
init_states = sample[:2]
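Passing dtype=torch.float32 explicitly is needed because the underlying xarray .values arrays are usually float64; torch.tensor would otherwise keep that dtype, which then clashes with float32 model weights. A minimal demonstration of the default behaviour:

    import numpy as np
    import torch

    values = np.zeros((4, 3))                               # float64 by default
    print(torch.tensor(values).dtype)                       # torch.float64
    print(torch.tensor(values, dtype=torch.float32).dtype)  # torch.float32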
diff --git a/plot_graph.py b/plot_graph.py
index dc3682ff..73acc801 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -44,8 +44,8 @@ def main():
)
args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
- xy = config_loader.get_xy("state") # (2, N_y, N_x)
+ data_config = config.Config.from_file(args.data_config)
+ xy = data_config.get_xy("state") # (2, N_y, N_x)
xy = xy.reshape(2, -1).T # (N_grid, 2)
pos_max = np.max(np.abs(xy))
grid_pos = xy / pos_max # Divide by maximum coordinate
diff --git a/tests/data_config.yaml b/tests/data_config.yaml
index 224c3f4e..9fb6d2d9 100644
--- a/tests/data_config.yaml
+++ b/tests/data_config.yaml
@@ -1,7 +1,7 @@
name: danra
state:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
time: time
level: null
@@ -11,7 +11,7 @@ state:
lat_lon_names:
lon: lon
lat: lat
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+ - path: "data/danra/height_levels.zarr"
dims:
time: time
level: altitude
@@ -41,7 +41,7 @@ state:
- 100
forcing:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
time: time
level: null
@@ -82,7 +82,7 @@ forcing:
window: 3 # Number of time steps to use for forcing (odd)
static:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "data/danra/single_levels.zarr"
dims:
level: null
x: x
@@ -106,6 +106,7 @@ boundary:
level: level
x: longitude
y: latitude
+ grid: null
lat_lon_names:
lon: longitude
lat: latitude
@@ -114,6 +115,13 @@ boundary:
dims:
x: x
y: y
+ surface_vars:
+ - t2m
+ surface_units:
+ - K
+ atmosphere_vars: null
+ atmosphere_units: null
+ levels: null
window: 3
utilities:
normalization:
diff --git a/tests/test_analysis_dataset.py b/tests/test_analysis_dataset.py
index 546921aa..f5ceb678 100644
--- a/tests/test_analysis_dataset.py
+++ b/tests/test_analysis_dataset.py
@@ -5,7 +5,6 @@
from create_mesh import main as create_mesh
from neural_lam.config import Config
from neural_lam.weather_dataset import WeatherDataset
-from train_model import main as train_model
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
@@ -13,8 +12,10 @@
def test_load_analysis_dataset():
- # The data_config.yaml file is downloaded and extracted in
- # test_retrieve_data_ewc together with the dataset itself
+ # NOTE: Access rights should be fixed for pooch to work
+ if not os.path.exists("data/danra"):
+ print("Please download test data first: python docs/download_danra.py")
+ return
data_config_file = "tests/data_config.yaml"
config = Config.from_file(data_config_file)
@@ -67,18 +68,19 @@ def test_create_graph_analysis_dataset():
create_mesh(args)
-def test_train_model_analysis_dataset():
- args = [
- "--model=hi_lam",
- "--data_config=tests/data_config.yaml",
- "--num_workers=4",
- "--epochs=1",
- "--graph=hierarchical",
- "--hidden_dim=16",
- "--hidden_layers=1",
- "--processor_layers=1",
- "--ar_steps_eval=1",
- "--eval=val",
- "--n_example_pred=0",
- ]
- train_model(args)
+# def test_train_model_analysis_dataset():
+# args = [
+# "--model=hi_lam",
+# "--data_config=tests/data_config.yaml",
+# "--num_workers=4",
+# "--epochs=1",
+# "--graph=hierarchical",
+# "--hidden_dim=16",
+# "--hidden_layers=1",
+# "--processor_layers=1",
+# "--ar_steps_eval=1",
+# "--eval=val",
+# "--n_example_pred=0",
+# "--val_steps_to_log=1",
+# ]
+# train_model(args)
diff --git a/train_model.py b/train_model.py
index 11b386d0..49f0a4c5 100644
--- a/train_model.py
+++ b/train_model.py
@@ -164,9 +164,9 @@ def main(input_args=None):
parser.add_argument(
"--ar_steps_eval",
type=int,
- default=25,
+ default=10,
help="Number of steps to unroll prediction for in loss function "
- "(default: 25)",
+ "(default: 10)",
)
parser.add_argument(
"--n_example_pred",
@@ -185,9 +185,10 @@ def main(input_args=None):
)
parser.add_argument(
"--val_steps_to_log",
- type=list,
+ nargs="+",
+ type=int,
default=[1, 2, 3, 5, 10, 15, 19],
- help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])",
+ help="Steps to log val loss for (default: 1 2 3 5 10 15 19)",
)
parser.add_argument(
"--metrics_watch",
From 0cdc3618acad77d71d86081067a9ac44881e122c Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 14:48:02 +0200
Subject: [PATCH 088/273] organize .zarr in /data
---
README.md | 4 ++--
calculate_statistics.py | 2 +-
create_boundary_mask.py | 4 ++--
create_forcings.py | 2 +-
neural_lam/data_config.yaml | 6 +++---
tests/data_config.yaml | 6 +++---
6 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/README.md b/README.md
index 272cd8a9..16e3257d 100644
--- a/README.md
+++ b/README.md
@@ -228,7 +228,7 @@ boundary: # Boundary variables are not predicted
... # They are used to inform the model about the surrounding weather conditions
... # The boundaries are often used from a separate model, specified identically to the state
mask: # Boundary mask to indicate where the model should not make predictions
- path: "boundary_mask.zarr"
+ path: "data/boundary_mask.zarr"
dims:
x: x
y: y
@@ -236,7 +236,7 @@ boundary: # Boundary variables are not predicted
utilities: # Additional utilities to be used in the model
normalization: # Normalization statistics for the state, forcing, and one-step differences
zarrs: # Zarr files containing the normalization statistics, multiple allowed
- - path: "normalization.zarr" # Path to the zarr file, default locaton of `calculate_statistics.py`
+      - path: "data/normalization.zarr" # Path to the zarr file, default location of `calculate_statistics.py`
stats_vars: # The variables to use for normalization, predefined and required
state_mean: name_in_dataset1
state_std: name_in_dataset2
diff --git a/calculate_statistics.py b/calculate_statistics.py
index e142ddfc..b62dbc1a 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -25,7 +25,7 @@ def main():
parser.add_argument(
"--zarr_path",
type=str,
- default="normalization.zarr",
+ default="data/normalization.zarr",
help="Directory where data is stored",
)
args = parser.parse_args()
diff --git a/create_boundary_mask.py b/create_boundary_mask.py
index 1933cfef..5c0c115f 100644
--- a/create_boundary_mask.py
+++ b/create_boundary_mask.py
@@ -20,9 +20,9 @@ def main():
parser.add_argument(
"--zarr_path",
type=str,
- default="boundary_mask.zarr",
+ default="data/boundary_mask.zarr",
help="Path to save the Zarr archive "
- "(default: same directory as border_mask.npy)",
+        "(default: data/boundary_mask.zarr)",
)
parser.add_argument(
"--boundaries",
diff --git a/create_forcings.py b/create_forcings.py
index 10dc3c8e..f1df2312 100644
--- a/create_forcings.py
+++ b/create_forcings.py
@@ -56,7 +56,7 @@ def main():
parser.add_argument(
"--data_config", type=str, default="neural_lam/data_config.yaml"
)
- parser.add_argument("--zarr_path", type=str, default="forcings.zarr")
+ parser.add_argument("--zarr_path", type=str, default="data/forcings.zarr")
args = parser.parse_args()
data_config = config.Config.from_file(args.data_config)
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 87c3a354..0b9ef1bf 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- - path: "forcings.zarr"
+ - path: "data/forcings.zarr"
dims:
time: time
level: null
@@ -111,7 +111,7 @@ boundary:
lon: longitude
lat: latitude
mask:
- path: "boundary_mask.zarr"
+ path: "data/boundary_mask.zarr"
dims:
x: x
y: y
@@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- - path: "normalization.zarr"
+ - path: "data/normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
diff --git a/tests/data_config.yaml b/tests/data_config.yaml
index 9fb6d2d9..b36098e2 100644
--- a/tests/data_config.yaml
+++ b/tests/data_config.yaml
@@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- - path: "forcings.zarr"
+ - path: "data/forcings.zarr"
dims:
time: time
level: null
@@ -111,7 +111,7 @@ boundary:
lon: longitude
lat: latitude
mask:
- path: "boundary_mask.zarr"
+ path: "data/boundary_mask.zarr"
dims:
x: x
y: y
@@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- - path: "normalization.zarr"
+ - path: "data/normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
From 23ca7b35455a1c5f675606331e3c045e90342bf7 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 14:59:57 +0200
Subject: [PATCH 089/273] cleanup
---
.flake8 | 3 +++
calculate_statistics.py | 4 +---
create_mesh.py | 10 ++--------
neural_lam/config.py | 8 ++------
neural_lam/models/ar_model.py | 8 ++------
neural_lam/models/base_graph_model.py | 9 +++++----
neural_lam/models/graph_lam.py | 4 +---
neural_lam/models/hi_lam.py | 19 ++++++++-----------
neural_lam/models/hi_lam_parallel.py | 11 ++++-------
neural_lam/utils.py | 4 +---
pyproject.toml | 25 ++++++++++---------------
11 files changed, 39 insertions(+), 66 deletions(-)
create mode 100644 .flake8
diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..b02dd545
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = E203, F811, I002, W503
diff --git a/calculate_statistics.py b/calculate_statistics.py
index b62dbc1a..daaf8767 100644
--- a/calculate_statistics.py
+++ b/calculate_statistics.py
@@ -54,9 +54,7 @@ def main():
combined_mean = means.mean(dim="variable")
combined_std = (stds**2).mean(dim="variable") ** 0.5
- forcing_mean.loc[
- dict(variable=vars_to_combine)
- ] = combined_mean
+ forcing_mean.loc[dict(variable=vars_to_combine)] = combined_mean
forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
window = data_config["forcing"]["window"]
forcing_mean = xr.concat([forcing_mean] * window, dim="window").stack(
diff --git a/create_mesh.py b/create_mesh.py
index 238d075b..f827ee56 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -125,11 +125,7 @@ def mk_2d_graph(xy, nx, ny):
# add diagonal edges
g.add_edges_from(
- [
- ((x, y), (x + 1, y + 1))
- for x in range(nx - 1)
- for y in range(ny - 1)
- ]
+ [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+ [
((x + 1, y), (x, y + 1))
for x in range(nx - 1)
@@ -347,9 +343,7 @@ def main(input_args=None):
.reshape(int(n / nx) ** 2, 2)
)
ij = [tuple(x) for x in ij]
- G[lev] = networkx.relabel_nodes(
- G[lev], dict(zip(G[lev].nodes, ij))
- )
+ G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij)))
G_tot = networkx.compose(G_tot, G[lev])
# Relabel mesh nodes to start with 0
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 480aaddf..10653a86 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -275,16 +275,12 @@ def load_normalization_stats(self, category, datatype="torch"):
)
if category == "state":
- stats = combined_stats.loc[
- dict(variable=self.vars_names(category))
- ]
+ stats = combined_stats.loc[dict(variable=self.vars_names(category))]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
elif category == "forcing":
vars = self.vars_names(category)
window = self["forcing"]["window"]
- forcing_vars = [
- f"{var}_{i}" for var in vars for i in range(window)
- ]
+ forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
stats = combined_stats.loc[dict(forcing_variable=forcing_vars)]
stats = stats[["forcing_mean", "forcing_std"]]
else:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 5b57fb4b..9aa4b4e5 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -580,14 +580,10 @@ def on_test_epoch_end(self):
)
for loss_map in mean_spatial_loss
]
- pdf_loss_maps_dir = os.path.join(
- wandb.run.dir, "spatial_loss_maps"
- )
+ pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
- fig.savefig(
- os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")
- )
+ fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
mean_spatial_loss.cpu(),
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 723a3f3c..fb5df62d 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -118,8 +118,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
dim=-1,
)
- # Embed all features # (B, num_grid_nodes, d_h)
- grid_emb = self.grid_embedder(grid_features)
+ # Embed all features
+ grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h)
g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h)
m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h)
mesh_emb = self.embedd_mesh_nodes()
@@ -149,8 +149,9 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
) # (B, num_grid_nodes, d_h)
# Map to output dimension, only for grid
- # (B, num_grid_nodes, d_grid_out)
- net_output = self.output_map(grid_rep)
+ net_output = self.output_map(
+ grid_rep
+ ) # (B, num_grid_nodes, d_grid_out)
if self.output_std:
pred_delta_mean, pred_std_raw = net_output.chunk(
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index e4dc74ac..f767fba0 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -32,9 +32,7 @@ def __init__(self, args):
# Define sub-models
# Feature embedders for mesh
- self.mesh_embedder = utils.make_mlp(
- [mesh_dim] + self.mlp_blueprint_end
- )
+ self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end)
# GNNs
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 335ea8c7..4d7eb94c 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -101,8 +101,9 @@ def mesh_down_step(
reversed(same_gnns[:-1]),
):
# Extract representations
- # (B, N_mesh[l+1], d_h)
- send_node_rep = mesh_rep_levels[level_l + 1]
+ send_node_rep = mesh_rep_levels[
+ level_l + 1
+ ] # (B, N_mesh[l+1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
down_edge_rep = mesh_down_rep[level_l]
same_edge_rep = mesh_same_rep[level_l]
@@ -138,8 +139,9 @@ def mesh_up_step(
zip(up_gnns, same_gnns[1:]), start=1
):
# Extract representations
- # (B, N_mesh[l-1], d_h)
- send_node_rep = mesh_rep_levels[level_l - 1]
+ send_node_rep = mesh_rep_levels[
+ level_l - 1
+ ] # (B, N_mesh[l-1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
up_edge_rep = mesh_up_rep[level_l - 1]
same_edge_rep = mesh_same_rep[level_l]
@@ -181,11 +183,7 @@ def hi_processor_step(
self.mesh_up_same_gnns,
):
# Down
- (
- mesh_rep_levels,
- mesh_same_rep,
- mesh_down_rep,
- ) = self.mesh_down_step(
+ mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step(
mesh_rep_levels,
mesh_same_rep,
mesh_down_rep,
@@ -202,6 +200,5 @@ def hi_processor_step(
up_same_gnns,
)
- # Note: We return all, even though only down edges really are used
- # later
+ # Note: We return all, even though only down edges really are used later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index b6f619d1..740824e1 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -27,9 +27,7 @@ def __init__(self, args):
+ list(self.mesh_down_edge_index)
)
total_edge_index = torch.cat(total_edge_index_list, dim=1)
- self.edge_split_sections = [
- ei.shape[1] for ei in total_edge_index_list
- ]
+ self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list]
if args.processor_layers == 0:
self.processor = lambda x, edge_attr: (x, edge_attr)
@@ -88,12 +86,11 @@ def hi_processor_step(
mesh_same_rep = mesh_edge_rep_sections[: self.num_levels]
mesh_up_rep = mesh_edge_rep_sections[
- self.num_levels : self.num_levels + (self.num_levels - 1) # noqa
+ self.num_levels : self.num_levels + (self.num_levels - 1)
]
mesh_down_rep = mesh_edge_rep_sections[
- self.num_levels + (self.num_levels - 1) : # noqa
+ self.num_levels + (self.num_levels - 1) :
] # Last are down edges
- # Note: We return all, even though only down edges really are used
- # later
+ # Note: We return all, even though only down edges really are used later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 824038eb..682aa2e3 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -41,9 +41,7 @@ def load_graph(graph_name, device="cpu"):
graph_dir_path = os.path.join("graphs", graph_name)
def loads_file(fn):
- return torch.load(
- os.path.join(graph_dir_path, fn), map_location=device
- )
+ return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
# Load edges (edge_index)
m2m_edge_index = BufferList(
diff --git a/pyproject.toml b/pyproject.toml
index 192afbc7..b513a258 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,8 @@
-[project]
-name = "neural_lam"
-version = "0.1.0"
-
-[tool.setuptools]
-packages = ["neural_lam"]
-
[tool.black]
-line-length = 79
+line-length = 80
[tool.isort]
-default_section = "THIRDPARTY" #codespell:ignore
+default_section = "THIRDPARTY"
profile = "black"
# Headings
import_heading_stdlib = "Standard library"
@@ -49,9 +42,12 @@ ignore = [
"create_mesh.py", # Disable linting for now, as major rework is planned/expected
]
# Temporary fix for import neural_lam statements until set up as proper package
-init-hook = 'import sys; sys.path.append(".")'
+init-hook='import sys; sys.path.append(".")'
[tool.pylint.TYPECHECK]
-generated-members = ["numpy.*", "torch.*"]
+generated-members = [
+ "numpy.*",
+ "torch.*",
+]
[tool.pylint.'MESSAGES CONTROL']
disable = [
"C0114", # 'missing-module-docstring', Do not require module docstrings
@@ -60,11 +56,10 @@ disable = [
"R0913", # 'too-many-arguments', Allow many function arguments
"R0914", # 'too-many-locals', Allow many local variables
"W0223", # 'abstract-method', Subclasses do not have to override all abstract methods
- "C0411", # 'wrong-import-order', Allow for isort to handle import order
]
[tool.pylint.DESIGN]
-max-statements = 100 # Allow for some more involved functions
+max-statements=100 # Allow for some more involved functions
[tool.pylint.IMPORTS]
-allow-any-import-level = "neural_lam"
+allow-any-import-level="neural_lam"
[tool.pylint.SIMILARITIES]
-min-similarity-lines = 10
+min-similarity-lines=10
From 81422f1ba19b145372a93f3fae12355a41f1df30 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 6 Jun 2024 15:56:33 +0200
Subject: [PATCH 090/273] linter
---
neural_lam/models/ar_model.py | 13 +++++++------
neural_lam/models/base_graph_model.py | 2 +-
neural_lam/models/hi_lam.py | 3 ++-
neural_lam/models/hi_lam_parallel.py | 3 ++-
neural_lam/weather_dataset.py | 2 +-
tests/test_analysis_dataset.py | 2 +-
6 files changed, 14 insertions(+), 11 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 9aa4b4e5..0c8422f3 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -14,11 +14,12 @@
class ARModel(pl.LightningModule):
"""
- Generic auto-regressive weather model. Abstract class that can be extended.
+ Generic auto-regressive weather model.
+ Abstract class that can be extended.
"""
- # pylint: disable=arguments-differ Disable to override args/kwargs from
- # superclass
+ # pylint: disable=arguments-differ
+ # Disable to override args/kwargs from superclass
def __init__(self, args):
super().__init__()
@@ -50,7 +51,7 @@ def __init__(self, args):
# Pred. dim. in grid cell
self.grid_output_dim = self.data_config.num_data_vars("state")
# Store constant per-variable std.-dev. weighting
- # Note that this is the inverse of the multiplicative weighting
+ # NOTE that this is the inverse of the multiplicative weighting
# in wMSE/wMAE
# TODO: Do we need param_weights for this?
self.register_buffer(
@@ -291,7 +292,7 @@ def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
- # NOTE Here batch_times can be used for plotting routines
+ # TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
@@ -527,7 +528,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_tensor_averaged = torch.sqrt(metric_tensor_averaged)
metric_name = metric_name.replace("mse", "rmse")
- # Note: we here assume rescaling for all metrics is linear
+ # NOTE: we here assume rescaling for all metrics is linear
metric_rescaled = metric_tensor_averaged * self.state_std
# (pred_steps, d_f)
log_dict.update(
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index fb5df62d..f055b782 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -157,7 +157,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
pred_delta_mean, pred_std_raw = net_output.chunk(
2, dim=-1
) # both (B, num_grid_nodes, d_f)
- # Note: The predicted std. is not scaled in any way here
+ # NOTE: The predicted std. is not scaled in any way here
# linter for some reason does not think softplus is callable
# pylint: disable-next=not-callable
pred_std = torch.nn.functional.softplus(pred_std_raw)
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 4d7eb94c..3d6905c7 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -200,5 +200,6 @@ def hi_processor_step(
up_same_gnns,
)
- # Note: We return all, even though only down edges really are used later
+ # NOTE: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 740824e1..ee25b0e9 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -92,5 +92,6 @@ def hi_processor_step(
self.num_levels + (self.num_levels - 1) :
] # Last are down edges
- # Note: We return all, even though only down edges really are used later
+ # TODO: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 5eda343f..d14b2fd8 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -44,7 +44,7 @@ def __init__(
self.state_times = self.state.time.values
# Set up for standardization
- # NOTE: This will become part of ar_model.py soon!
+ # TODO: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
state_stats = self.data_config.load_normalization_stats(
diff --git a/tests/test_analysis_dataset.py b/tests/test_analysis_dataset.py
index f5ceb678..d7191a01 100644
--- a/tests/test_analysis_dataset.py
+++ b/tests/test_analysis_dataset.py
@@ -12,7 +12,7 @@
def test_load_analysis_dataset():
- # NOTE: Access rights should be fixed for pooch to work
+ # TODO: Access rights should be fixed for pooch to work
if not os.path.exists("data/danra"):
print("Please download test data first: python docs/download_danra.py")
return
From 124541b7a5ef280bc855f9f51906c760b75ec4b4 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Fri, 7 Jun 2024 13:13:30 +0200
Subject: [PATCH 091/273] static dataset doesn't have time dim
---
neural_lam/config.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 10653a86..f6e625e1 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -407,7 +407,8 @@ def process_dataset(self, category, split="train", apply_windowing=True):
"""Process the dataset for the given category."""
dataset = self.open_zarrs(category)
dataset = self.extract_vars(category, dataset)
- dataset = self.filter_dataset_by_time(dataset, split)
+ if category != "static":
+ dataset = self.filter_dataset_by_time(dataset, split)
dataset = self.stack_grid(dataset)
dataset = self.rename_dataset_dims_and_vars(category, dataset)
dataset = self.filter_dimensions(dataset)
From 6140fdb6b1eb5c1cc22a300796fc3ec446bcb57d Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Fri, 7 Jun 2024 15:05:57 +0200
Subject: [PATCH 092/273] making two complex functions more modular
---
neural_lam/config.py | 62 ++++++++++++++++++++++++++++++--------------
1 file changed, 42 insertions(+), 20 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index f6e625e1..885600f1 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -243,30 +243,45 @@ def get_xy(self, category):
@functools.lru_cache()
def load_normalization_stats(self, category, datatype="torch"):
"""Load the normalization statistics for the dataset."""
+ combined_stats = self._load_and_merge_stats()
+ if combined_stats is None:
+ return None
+
+ combined_stats = self._rename_data_vars(combined_stats)
+
+ stats = self._select_stats_by_category(combined_stats, category)
+ if stats is None:
+ return None
+
+ if datatype == "torch":
+ return self._convert_stats_to_torch(stats)
+
+ return stats
+
+ def _load_and_merge_stats(self):
+ combined_stats = None
for i, zarr_config in enumerate(
self.values["utilities"]["normalization"]["zarrs"]
):
stats_path = zarr_config["path"]
if not os.path.exists(stats_path):
print(
- f"Normalization statistics not found at path: "
- f"{stats_path}"
+ f"Normalization statistics not found at path: {stats_path}"
)
return None
stats = xr.open_zarr(stats_path, consolidated=True)
if i == 0:
combined_stats = stats
else:
- stats = xr.merge([stats, combined_stats])
- combined_stats = stats
+ combined_stats = xr.merge([stats, combined_stats])
+ return combined_stats
- # Rename data variables
+ def _rename_data_vars(self, combined_stats):
vars_mapping = {}
- zarr_configs = self.values["utilities"]["normalization"]["zarrs"]
- for zarr_config in zarr_configs:
+ for zarr_config in self.values["utilities"]["normalization"]["zarrs"]:
vars_mapping.update(zarr_config["stats_vars"])
- combined_stats = combined_stats.rename_vars(
+ return combined_stats.rename_vars(
{
v: k
for k, v in vars_mapping.items()
@@ -274,38 +289,44 @@ def load_normalization_stats(self, category, datatype="torch"):
}
)
+ def _select_stats_by_category(self, combined_stats, category):
if category == "state":
stats = combined_stats.loc[dict(variable=self.vars_names(category))]
- stats = stats.drop_vars(["forcing_mean", "forcing_std"])
+ return stats.drop_vars(["forcing_mean", "forcing_std"])
elif category == "forcing":
vars = self.vars_names(category)
window = self["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
stats = combined_stats.loc[dict(forcing_variable=forcing_vars)]
- stats = stats[["forcing_mean", "forcing_std"]]
+ return stats[["forcing_mean", "forcing_std"]]
else:
print(f"Invalid category: {category}")
return None
- if datatype == "torch":
- stats_dict = {
- var: torch.tensor(stats[var].values, dtype=torch.float32)
- for var in stats.data_vars
- }
- return stats_dict
-
- return stats
+ def _convert_stats_to_torch(self, stats):
+ return {
+ var: torch.tensor(stats[var].values, dtype=torch.float32)
+ for var in stats.data_vars
+ }
def extract_vars(self, category, dataset=None):
"""Extract the variables from the dataset."""
if dataset is None:
dataset = self.open_zarrs(category)
- surface_vars = (
+
+ surface_vars = self._extract_surface_vars(category, dataset)
+ atmosphere_vars = self._extract_atmosphere_vars(category, dataset)
+
+ return self._merge_vars(surface_vars, atmosphere_vars, category)
+
+ def _extract_surface_vars(self, category, dataset):
+ return (
dataset[self[category].surface_vars]
if self[category].surface_vars
else []
)
+ def _extract_atmosphere_vars(self, category, dataset):
if (
"level" not in dataset.to_array().dims
and self[category].atmosphere_vars
@@ -314,7 +335,7 @@ def extract_vars(self, category, dataset=None):
dataset.attrs["category"], dataset=dataset
)
- atmosphere_vars = (
+ return (
xr.merge(
[
dataset[var]
@@ -328,6 +349,7 @@ def extract_vars(self, category, dataset=None):
else []
)
+ def _merge_vars(self, surface_vars, atmosphere_vars, category):
if surface_vars and atmosphere_vars:
return xr.merge([surface_vars, atmosphere_vars])
elif surface_vars:
From db6a912653319f571d7b3bec9fe09f99130a1a22 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 10:30:44 +0200
Subject: [PATCH 093/273] chunk dataset by time
---
create_forcings.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/create_forcings.py b/create_forcings.py
index f1df2312..5f069d02 100644
--- a/create_forcings.py
+++ b/create_forcings.py
@@ -68,6 +68,10 @@ def main():
{"y": dataset.y, "x": dataset.x}
)
+ datetime_forcing_expanded = datetime_forcing_expanded.chunk(
+ {"time": 1, "y": -1, "x": -1}
+ )
+
datetime_forcing_expanded.to_zarr(args.zarr_path, mode="w")
print(f"Datetime forcing saved to {args.zarr_path}")
From 1aaa8dcbb9ff67bd2162be620829527bbc97cbe8 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 10:31:01 +0200
Subject: [PATCH 094/273] create list first for performance
---
neural_lam/config.py | 23 ++++++++++-------------
1 file changed, 10 insertions(+), 13 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 885600f1..c9a27634 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -335,19 +335,16 @@ def _extract_atmosphere_vars(self, category, dataset):
dataset.attrs["category"], dataset=dataset
)
- return (
- xr.merge(
- [
- dataset[var]
- .sel(level=level, drop=True)
- .rename(f"{var}_{level}")
- for var in self[category].atmosphere_vars
- for level in self[category].levels
- ]
- )
- if self[category].atmosphere_vars
- else []
- )
+ data_arrays = [
+ dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
+ for var in self[category].atmosphere_vars
+ for level in self[category].levels
+ ]
+
+ if self[category].atmosphere_vars:
+ return xr.merge(data_arrays)
+ else:
+ return xr.Dataset()
def _merge_vars(self, surface_vars, atmosphere_vars, category):
if surface_vars and atmosphere_vars:
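The commit's stated aim is performance; functionally, each atmosphere variable is still flattened into a per-level field named var_level (for example u_100) before the merge. The selection and renaming pattern on toy data, assuming a level coordinate as in the config:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"u": (("level", "y", "x"), np.zeros((2, 3, 4)))},
        coords={"level": [100, 250]},
    )
    data_arrays = [
        ds[var].sel(level=level, drop=True).rename(f"{var}_{level}")
        for var in ["u"]
        for level in [100]
    ]
    flat = xr.merge(data_arrays)   # dataset with variable u_100, dims (y, x)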
From 81856b274dec9c8c025c2fd1c6ae13fc5c74e2ee Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 11:03:29 +0200
Subject: [PATCH 095/273] converting to_array is very slow; the behavior of
 xr.Dataset.dims will change soon
---
neural_lam/config.py | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index c9a27634..ee1476f4 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -124,7 +124,7 @@ def open_zarrs(self, category):
def stack_grid(self, dataset):
if dataset is None:
return None
- dims = dataset.to_array().dims
+ dims = list(dataset.dims)
if "grid" in dims:
print("\033[94mGrid dimensions already stacked.\033[0m")
@@ -144,7 +144,7 @@ def convert_dataset_to_dataarray(self, dataset):
def filter_dimensions(self, dataset, transpose_array=True):
"""Filter the dimensions of the dataset."""
dims_to_keep = self.DIMS_TO_KEEP
- dataset_dims = set(dataset.to_array().dims)
+ dataset_dims = set(list(dataset.dims) + ["variable"])
min_req_dims = dims_to_keep.copy()
min_req_dims.discard("time")
if not min_req_dims.issubset(dataset_dims):
@@ -161,7 +161,7 @@ def filter_dimensions(self, dataset, transpose_array=True):
dataset.attrs["category"], dataset=dataset
)
dataset = self.stack_grid(dataset)
- dataset_dims = set(dataset.to_array().dims)
+ dataset_dims = set(list(dataset.dims) + ["variable"])
if min_req_dims.issubset(dataset_dims):
print(
"\033[92mSuccessfully updated dims and "
@@ -174,7 +174,7 @@ def filter_dimensions(self, dataset, transpose_array=True):
)
return None
- dataset_dims = set(dataset.to_array().dims)
+ dataset_dims = set(list(dataset.dims) + ["variable"])
dims_to_drop = dataset_dims - dims_to_keep
dataset = dataset.drop_dims(dims_to_drop)
if dims_to_drop:
@@ -327,10 +327,7 @@ def _extract_surface_vars(self, category, dataset):
)
def _extract_atmosphere_vars(self, category, dataset):
- if (
- "level" not in dataset.to_array().dims
- and self[category].atmosphere_vars
- ):
+ if "level" not in list(dataset.dims) and self[category].atmosphere_vars:
dataset = self.rename_dataset_dims_and_vars(
dataset.attrs["category"], dataset=dataset
)
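A small illustration of the change above (synthetic data, arbitrary sizes): Dataset.dims reads only metadata, while the to_array() call it replaces had to stack every variable into one array just to inspect the dimension names.

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "T_2M": (("time", "y", "x"), np.zeros((4, 3, 3))),
        "PMSL": (("time", "y", "x"), np.zeros((4, 3, 3))),
    }
)

# Cheap: only dimension names are read, no data is touched.
dims = list(ds.dims)         # ['time', 'y', 'x']
has_level = "level" in dims  # False for this example

# The expensive equivalent removed by this patch: materialises all variables.
dims_via_array = ds.to_array().dims  # ('variable', 'time', 'y', 'x')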
From b3da818b09e45d1c2319daa929fede81bbe4fa10 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 11:41:06 +0200
Subject: [PATCH 096/273] allow for forcings to not be normalized
---
neural_lam/config.py | 59 ++++++++++++++++++++++++-------------
neural_lam/data_config.yaml | 5 ++++
2 files changed, 44 insertions(+), 20 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index ee1476f4..ec0e0a38 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -292,13 +292,34 @@ def _rename_data_vars(self, combined_stats):
def _select_stats_by_category(self, combined_stats, category):
if category == "state":
stats = combined_stats.loc[dict(variable=self.vars_names(category))]
- return stats.drop_vars(["forcing_mean", "forcing_std"])
+ stats = stats.drop_vars(["forcing_mean", "forcing_std"])
+ return stats
elif category == "forcing":
+ non_normalized_vars = (
+ self.utilities.normalization.non_normalized_vars
+ )
vars = self.vars_names(category)
window = self["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
- stats = combined_stats.loc[dict(forcing_variable=forcing_vars)]
- return stats[["forcing_mean", "forcing_std"]]
+ normalized_vars = [
+ var for var in forcing_vars if var not in non_normalized_vars
+ ]
+ non_normalized_vars = [
+ var for var in forcing_vars if var in non_normalized_vars
+ ]
+ stats_normalized = combined_stats.loc[
+ dict(forcing_variable=normalized_vars)
+ ]
+ if non_normalized_vars:
+ stats_non_normalized = combined_stats.loc[
+ dict(forcing_variable=non_normalized_vars)
+ ]
+ stats = xr.merge([stats_normalized, stats_non_normalized])
+ else:
+ stats = stats_normalized
+ stats_normalized = stats_normalized[["forcing_mean", "forcing_std"]]
+
+ return stats
else:
print(f"Invalid category: {category}")
return None
@@ -310,14 +331,23 @@ def _convert_stats_to_torch(self, stats):
}
def extract_vars(self, category, dataset=None):
- """Extract the variables from the dataset."""
if dataset is None:
dataset = self.open_zarrs(category)
-
- surface_vars = self._extract_surface_vars(category, dataset)
- atmosphere_vars = self._extract_atmosphere_vars(category, dataset)
-
- return self._merge_vars(surface_vars, atmosphere_vars, category)
+ surface_vars = None
+ atmosphere_vars = None
+ if self[category].surface_vars:
+ surface_vars = self._extract_surface_vars(category, dataset)
+ if self[category].atmosphere_vars:
+ atmosphere_vars = self._extract_atmosphere_vars(category, dataset)
+ if surface_vars and atmosphere_vars:
+ return xr.merge([surface_vars, atmosphere_vars])
+ elif surface_vars:
+ return surface_vars
+ elif atmosphere_vars:
+ return atmosphere_vars
+ else:
+ print(f"No variables found in dataset {category}")
+ return None
def _extract_surface_vars(self, category, dataset):
return (
@@ -343,17 +373,6 @@ def _extract_atmosphere_vars(self, category, dataset):
else:
return xr.Dataset()
- def _merge_vars(self, surface_vars, atmosphere_vars, category):
- if surface_vars and atmosphere_vars:
- return xr.merge([surface_vars, atmosphere_vars])
- elif surface_vars:
- return surface_vars
- elif atmosphere_vars:
- return atmosphere_vars
- else:
- print(f"No variables found in dataset {category}")
- return None
-
def rename_dataset_dims_and_vars(self, category, dataset=None):
"""Rename the dimensions and variables of the dataset."""
convert = False
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 0b9ef1bf..54816637 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -141,6 +141,11 @@ utilities:
- vars:
- cape_column
- xhail0m
+ non_normalized_vars:
+ - hour_cos
+ - hour_sin
+ - year_cos
+ - year_sin
grid_shape_state:
y: 589
x: 789
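Roughly what the new non_normalized_vars option amounts to, sketched with hypothetical windowed variable names: forcing variables are partitioned into those that get standardised and those (such as the cyclic time encodings, already in [0, 1]) that are passed through unchanged.

# Hypothetical variable lists for illustration only
forcing_vars = ["ASOB_S_0", "ASOB_S_1", "hour_cos_0", "hour_sin_0"]
non_normalized_vars = ["hour_cos_0", "hour_sin_0"]

normalized = [v for v in forcing_vars if v not in non_normalized_vars]
passthrough = [v for v in forcing_vars if v in non_normalized_vars]

print(normalized)   # ['ASOB_S_0', 'ASOB_S_1']
print(passthrough)  # ['hour_cos_0', 'hour_sin_0']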
From 7ee5398bd08889b002795cb736f5ef8825df0f0d Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 15:53:26 +0200
Subject: [PATCH 097/273] allow non_normalized_vars to be null
---
neural_lam/config.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index ec0e0a38..a081ecae 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -298,6 +298,8 @@ def _select_stats_by_category(self, combined_stats, category):
non_normalized_vars = (
self.utilities.normalization.non_normalized_vars
)
+ if non_normalized_vars is None:
+ non_normalized_vars = []
vars = self.vars_names(category)
window = self["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
From 47821034b76fe5b582367c7784a9d76cc6d0afbd Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sat, 8 Jun 2024 22:34:15 +0200
Subject: [PATCH 098/273] fixed coastlines using new xy_extent function
---
neural_lam/config.py | 15 +++++++++++----
neural_lam/data_config.yaml | 6 +++---
neural_lam/vis.py | 6 ++++++
3 files changed, 20 insertions(+), 7 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index a081ecae..813659e1 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -230,15 +230,22 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None):
return reshaped_data
@functools.lru_cache()
- def get_xy(self, category):
+ def get_xy(self, category, stacked=True):
"""Return the x, y coordinates of the dataset."""
dataset = self.open_zarrs(category)
x, y = dataset.x.values, dataset.y.values
if x.ndim == 1:
x, y = np.meshgrid(x, y)
- xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
-
- return xy
+ if stacked:
+ xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
+ return xy
+ return x, y
+
+ def get_xy_extent(self, category):
+ """Return the extent of the x, y coordinates."""
+ x, y = self.get_xy(category, stacked=False)
+ extent = [x.min(), x.max(), y.min(), y.max()]
+ return extent
@functools.lru_cache()
def load_normalization_stats(self, category, datatype="torch"):
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 54816637..4290ed53 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -162,6 +162,6 @@ splits:
projection:
class: LambertConformal # Name of class in cartopy.crs
kwargs:
- central_longitude: 6.22
- central_latitude: 56.0
- standard_parallels: [47.6, 64.4]
+ central_longitude: 25
+ central_latitude: 56.4
+ standard_parallels: [50.4, 61.6]
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index c92739f9..8551e74b 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -77,6 +77,8 @@ def plot_prediction(
else:
vmin, vmax = vrange
+ extent = data_config.get_xy_extent("state")
+
# Set up masking of border region
mask_reshaped = obs_mask.reshape(
list(data_config.grid_shape_state.values.values())
@@ -103,6 +105,7 @@ def plot_prediction(
im = ax.imshow(
data_grid,
origin="lower",
+ extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
@@ -134,6 +137,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
else:
vmin, vmax = vrange
+ extent = data_config.get_xy_extent("state")
+
# Set up masking of border region
mask_reshaped = obs_mask.reshape(
list(data_config.grid_shape_state.values.values())
@@ -157,6 +162,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
im = ax.imshow(
error_grid,
origin="lower",
+ extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
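Why the extent matters for the coastline fix, in a minimal matplotlib-only sketch (synthetic coordinates, not the project's grid): without an extent, imshow places pixels by array index, so map features drawn in projection coordinates cannot line up; passing [xmin, xmax, ymin, ymax] anchors the image in the same coordinate space.

import matplotlib.pyplot as plt
import numpy as np

# Synthetic 1D projection coordinates expanded to a 2D mesh, as get_xy() does
x = np.linspace(0.0, 5.0, 60)
y = np.linspace(0.0, 3.0, 40)
xx, yy = np.meshgrid(x, y)

extent = [xx.min(), xx.max(), yy.min(), yy.max()]  # [xmin, xmax, ymin, ymax]

data = np.hypot(xx - 2.5, yy - 1.5)
fig, ax = plt.subplots()
ax.imshow(data, origin="lower", extent=extent)  # pixels now live in x/y units
fig.savefig("extent_demo.png")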
From e0ffc5bdfcb81a55e2651ac89bb1189303f50a5b Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Sun, 9 Jun 2024 11:19:10 +0200
Subject: [PATCH 099/273] Some projections return inverted axes (rotatedPole)
---
neural_lam/config.py | 6 +++++-
neural_lam/data_config.yaml | 1 +
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 813659e1..d77aab16 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -244,7 +244,11 @@ def get_xy(self, category, stacked=True):
def get_xy_extent(self, category):
"""Return the extent of the x, y coordinates."""
x, y = self.get_xy(category, stacked=False)
- extent = [x.min(), x.max(), y.min(), y.max()]
+ if self.projection.inverted:
+ extent = [x.max(), x.min(), y.max(), y.min()]
+ else:
+ extent = [x.min(), x.max(), y.min(), y.max()]
+
return extent
@functools.lru_cache()
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
index 4290ed53..63756002 100644
--- a/neural_lam/data_config.yaml
+++ b/neural_lam/data_config.yaml
@@ -165,3 +165,4 @@ projection:
central_longitude: 25
central_latitude: 56.4
standard_parallels: [50.4, 61.6]
+ inverted: false # Whether the projection is inverted
From c1f43b72a533bc0e4aaed18a2be501ce2021f188 Mon Sep 17 00:00:00 2001
From: Simon Adamov
Date: Thu, 13 Jun 2024 11:42:47 +0200
Subject: [PATCH 100/273] Docstrings added
---
neural_lam/config.py | 213 +++++++++++++++++++++++++++++++++++++++----
1 file changed, 196 insertions(+), 17 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index d77aab16..6ab5868d 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -13,6 +13,9 @@
class Config:
+ """Class to load and access the configuration file.
+ The class also preprocesses the dataset based on the config."""
+
DIMS_TO_KEEP = {"time", "grid", "variable"}
def __init__(self, values):
@@ -20,6 +23,7 @@ def __init__(self, values):
@classmethod
def from_file(cls, filepath):
+ """Load the configuration file from the given path."""
if filepath.endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
@@ -27,6 +31,7 @@ def from_file(cls, filepath):
raise NotImplementedError(Path(filepath).suffix)
def __getattr__(self, name):
+ """Recursively access the values in the configuration."""
keys = name.split(".")
value = self.values
for key in keys:
@@ -49,7 +54,12 @@ def __contains__(self, key):
@functools.cached_property
def coords_projection(self):
- """Return the projection object for the coordinates."""
+ """Return the projection object for the coordinates.
+
+ The projection object is used to plot the coordinates on a map.
+
+ Returns:
+ cartopy.crs.Projection: The projection object."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
@@ -58,7 +68,10 @@ def coords_projection(self):
@functools.cached_property
def step_length(self):
- """Return the step length of the dataset in hours."""
+ """Return the step length of the dataset in hours.
+
+ Returns:
+ int: The step length in hours."""
dataset = self.open_zarrs("state")
time = dataset.time.isel(time=slice(0, 2)).values
step_length_ns = time[1] - time[0]
@@ -67,7 +80,13 @@ def step_length(self):
@functools.lru_cache()
def vars_names(self, category):
- """Return the names of the variables in the dataset."""
+ """Return the names of the variables in the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ list: The names of the variables in the dataset."""
surface_vars_names = self.values[category].get("surface_vars") or []
atmosphere_vars_names = [
f"{var}_{level}"
@@ -78,7 +97,13 @@ def vars_names(self, category):
@functools.lru_cache()
def vars_units(self, category):
- """Return the units of the variables in the dataset."""
+ """Return the units of the variables in the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ list: The units of the variables in the dataset."""
surface_vars_units = self.values[category].get("surface_units") or []
atmosphere_vars_units = [
unit
@@ -89,7 +114,13 @@ def vars_units(self, category):
@functools.lru_cache()
def num_data_vars(self, category):
- """Return the number of data variables in the dataset."""
+ """Return the number of data variables in the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ int: The number of data variables in the dataset."""
surface_vars = self.values[category].get("surface_vars", [])
atmosphere_vars = self.values[category].get("atmosphere_vars", [])
levels = self.values[category].get("levels", [])
@@ -105,7 +136,13 @@ def num_data_vars(self, category):
return surface_vars_count + atmosphere_vars_count * levels_count
def open_zarrs(self, category):
- """Open the zarr dataset for the given category."""
+ """Open the zarr dataset for the given category.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ xr.Dataset: The xarray Dataset object."""
zarr_configs = self.values[category]["zarrs"]
try:
@@ -122,6 +159,13 @@ def open_zarrs(self, category):
return None
def stack_grid(self, dataset):
+ """Stack the grid dimensions of the dataset.
+
+ Args:
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with stacked grid dimensions."""
if dataset is None:
return None
dims = list(dataset.dims)
@@ -136,13 +180,27 @@ def stack_grid(self, dataset):
return dataset
def convert_dataset_to_dataarray(self, dataset):
- """Convert the Dataset to a Dataarray."""
+ """Convert the Dataset to a Dataarray.
+
+ Args:
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.DataArray: The xarray DataArray object."""
if isinstance(dataset, xr.Dataset):
dataset = dataset.to_array()
return dataset
def filter_dimensions(self, dataset, transpose_array=True):
- """Filter the dimensions of the dataset."""
+ """Drop the dimensions and filter the data_vars of the dataset.
+
+ Args:
+ dataset (xr.Dataset): The xarray Dataset object.
+ transpose_array (bool): Whether to transpose the array.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with filtered dimensions.
+ OR xr.DataArray: The xarray DataArray object with filtered dimensions."""
dims_to_keep = self.DIMS_TO_KEEP
dataset_dims = set(list(dataset.dims) + ["variable"])
min_req_dims = dims_to_keep.copy()
@@ -209,7 +267,14 @@ def filter_dimensions(self, dataset, transpose_array=True):
return dataset
def reshape_grid_to_2d(self, dataset, grid_shape=None):
- """Reshape the grid to 2D for stacked data without multi-index."""
+ """Reshape the grid to 2D for stacked data without multi-index.
+
+ Args:
+ dataset (xr.Dataset): The xarray Dataset object.
+ grid_shape (dict): The shape of the grid.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with reshaped grid dimensions."""
if grid_shape is None:
grid_shape = dict(self.grid_shape_state.values.items())
x_dim, y_dim = (grid_shape["x"], grid_shape["y"])
@@ -231,7 +296,17 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None):
@functools.lru_cache()
def get_xy(self, category, stacked=True):
- """Return the x, y coordinates of the dataset."""
+ """Return the x, y coordinates of the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ stacked (bool): Whether to stack the x, y coordinates.
+
+ Returns:
+ np.ndarray: The x, y coordinates of the dataset (if stacked) (2, N_y, N_x)
+
+ OR tuple(np.ndarray, np.ndarray): The x, y coordinates of the dataset
+ (if not stacked) ((N_y, N_x), (N_y, N_x))"""
dataset = self.open_zarrs(category)
x, y = dataset.x.values, dataset.y.values
if x.ndim == 1:
@@ -242,7 +317,13 @@ def get_xy(self, category, stacked=True):
return x, y
def get_xy_extent(self, category):
- """Return the extent of the x, y coordinates."""
+ """Return the extent of the x, y coordinates.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ list(float): The extent of the x, y coordinates."""
x, y = self.get_xy(category, stacked=False)
if self.projection.inverted:
extent = [x.max(), x.min(), y.max(), y.min()]
@@ -253,7 +334,17 @@ def get_xy_extent(self, category):
@functools.lru_cache()
def load_normalization_stats(self, category, datatype="torch"):
- """Load the normalization statistics for the dataset."""
+ """Load the normalization statistics for the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ datatype (str): The datatype of the statistics (torch/"").
+
+ Returns:
+ tensor: The normalization statistics for the dataset.
+ (if datatype="torch")
+ OR xr.Dataset: The normalization statistics for the dataset.
+ (otherwise)"""
combined_stats = self._load_and_merge_stats()
if combined_stats is None:
return None
@@ -270,6 +361,10 @@ def load_normalization_stats(self, category, datatype="torch"):
return stats
def _load_and_merge_stats(self):
+ """Load and merge the normalization statistics for the dataset.
+
+ Returns:
+ xr.Dataset: The merged normalization statistics for the dataset."""
combined_stats = None
for i, zarr_config in enumerate(
self.values["utilities"]["normalization"]["zarrs"]
@@ -288,6 +383,14 @@ def _load_and_merge_stats(self):
return combined_stats
def _rename_data_vars(self, combined_stats):
+ """Rename the data variables of the normalization statistics.
+
+ Args:
+ combined_stats (xr.Dataset): The combined normalization statistics.
+
+ Returns:
+ xr.Dataset: The combined normalization statistics with renamed data
+ variables."""
vars_mapping = {}
for zarr_config in self.values["utilities"]["normalization"]["zarrs"]:
vars_mapping.update(zarr_config["stats_vars"])
@@ -301,6 +404,14 @@ def _rename_data_vars(self, combined_stats):
)
def _select_stats_by_category(self, combined_stats, category):
+ """Select the normalization statistics for the given category.
+
+ Args:
+ combined_stats (xr.Dataset): The combined normalization statistics.
+ category (str): The category of the dataset (state/forcing/static).
+
+ Returns:
+ xr.Dataset: The normalization statistics for the dataset."""
if category == "state":
stats = combined_stats.loc[dict(variable=self.vars_names(category))]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
@@ -338,12 +449,27 @@ def _select_stats_by_category(self, combined_stats, category):
return None
def _convert_stats_to_torch(self, stats):
+ """Convert the normalization statistics to torch tensors.
+
+ Args:
+ stats (xr.Dataset): The normalization statistics.
+
+ Returns:
+ dict(tensor): The normalization statistics as torch tensors."""
return {
var: torch.tensor(stats[var].values, dtype=torch.float32)
for var in stats.data_vars
}
def extract_vars(self, category, dataset=None):
+ """Extract (select) the data variables from the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with extracted variables."""
if dataset is None:
dataset = self.open_zarrs(category)
surface_vars = None
@@ -363,6 +489,15 @@ def extract_vars(self, category, dataset=None):
return None
def _extract_surface_vars(self, category, dataset):
+ """Extract the surface variables from the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with surface variables.
+ """
return (
dataset[self[category].surface_vars]
if self[category].surface_vars
@@ -370,6 +505,14 @@ def _extract_surface_vars(self, category, dataset):
)
def _extract_atmosphere_vars(self, category, dataset):
+ """Extract the atmosphere variables from the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with atmosphere variables."""
if "level" not in list(dataset.dims) and self[category].atmosphere_vars:
dataset = self.rename_dataset_dims_and_vars(
dataset.attrs["category"], dataset=dataset
@@ -387,7 +530,18 @@ def _extract_atmosphere_vars(self, category, dataset):
return xr.Dataset()
def rename_dataset_dims_and_vars(self, category, dataset=None):
- """Rename the dimensions and variables of the dataset."""
+ """Rename the dimensions and variables of the dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ dataset (xr.Dataset): The xarray Dataset object. OR xr.DataArray:
+ The xarray DataArray object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with renamed dimensions and
+ variables.
+ OR xr.DataArray: The xarray DataArray object with renamed
+ dimensions and variables."""
convert = False
if dataset is None:
dataset = self.open_zarrs(category)
@@ -414,7 +568,14 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
return dataset
def filter_dataset_by_time(self, dataset, split="train"):
- """Filter the dataset by the time split."""
+ """Filter the dataset by the time split.
+
+ Args:
+ dataset (xr.Dataset): The xarray Dataset object.
+ split (str): The time split to filter the dataset.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object filtered by the time split."""
start, end = (
self.values["splits"][split]["start"],
self.values["splits"][split]["end"],
@@ -424,7 +585,14 @@ def filter_dataset_by_time(self, dataset, split="train"):
return dataset
def apply_window(self, category, dataset=None):
- """Apply the forcing window to the forcing dataset."""
+ """Apply the forcing window to the forcing dataset.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ dataset (xr.Dataset): The xarray Dataset object.
+
+ Returns:
+ xr.Dataset: The xarray Dataset object with the window applied."""
if dataset is None:
dataset = self.open_zarrs(category)
if isinstance(dataset, xr.Dataset):
@@ -444,7 +612,10 @@ def apply_window(self, category, dataset=None):
return dataset
def load_boundary_mask(self):
- """Load the boundary mask for the dataset."""
+ """Load the boundary mask for the dataset.
+
+ Returns:
+ tensor: The boundary mask for the dataset."""
boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
return torch.tensor(
boundary_mask.mask.stack(grid=("y", "x")).values,
@@ -452,7 +623,15 @@ def load_boundary_mask(self):
).unsqueeze(1)
def process_dataset(self, category, split="train", apply_windowing=True):
- """Process the dataset for the given category."""
+ """Process the dataset for the given category.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
+ split (str): The time split to filter the dataset (train/val/test).
+ apply_windowing (bool): Whether to apply windowing to the forcing dataset.
+
+ Returns:
+ xr.DataArray: The xarray DataArray object with processed dataset."""
dataset = self.open_zarrs(category)
dataset = self.extract_vars(category, dataset)
if category != "static":
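Putting the newly documented methods together, an illustrative usage outline (not a runnable test: it assumes the zarr archives referenced by neural_lam/data_config.yaml exist on disk):

from neural_lam import config

data_config = config.Config.from_file("neural_lam/data_config.yaml")

# Inspect what the configuration describes
print(data_config.vars_names("state"))
print(data_config.num_data_vars("state"))
print(data_config.coords_projection)  # cartopy projection used for plotting

# Full preprocessing: open zarrs, select variables, stack the grid, filter by split
da_state_train = data_config.process_dataset("state", split="train")
stats = data_config.load_normalization_stats("state", datatype="torch")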
From 21fd929e6f4854dd2708829fe7123b306e516024 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 26 Jun 2024 16:32:04 +0200
Subject: [PATCH 101/273] wip
---
create_forcings.py | 82 ----
create_mesh.py | 2 +-
docs/download_danra.py | 26 --
neural_lam/.DS_Store | Bin 0 -> 6148 bytes
neural_lam/datasets/.DS_Store | Bin 0 -> 6148 bytes
neural_lam/datasets/__init__.py | 0
neural_lam/datastore/__init__.py | 2 +
neural_lam/datastore/base.py | 216 ++++++++++
neural_lam/datastore/mllam.py | 80 ++++
neural_lam/datastore/multizarr/__init__.py | 1 +
neural_lam/datastore/multizarr/config.py | 43 ++
.../multizarr/create_auxiliary_forcings.py | 115 ++++++
.../multizarr/create_boundary_mask.py | 2 +-
.../multizarr/create_normalization_stats.py | 67 ++--
.../multizarr/store.py} | 338 +++++++---------
neural_lam/datastore/npyfiles/__init__.py | 1 +
neural_lam/datastore/npyfiles/config.py | 62 +++
neural_lam/datastore/npyfiles/store.py | 373 ++++++++++++++++++
neural_lam/models/ar_model.py | 14 +-
neural_lam/weather_dataset.py | 65 +--
plot_graph.py | 3 +-
requirements.txt | 2 +-
test_ewc.py | 17 +
tests/data_config.yaml | 162 --------
.../mllam.example.danra.yaml | 74 ++++
.../datastore_configs/multizarr.danra.yaml | 12 +-
tests/test_mllam_dataset.py | 8 +
...s_dataset.py => test_multizarr_dataset.py} | 30 +-
...taset.py_ => test_npy_forecast_dataset.py} | 35 +-
29 files changed, 1272 insertions(+), 560 deletions(-)
delete mode 100644 create_forcings.py
delete mode 100644 docs/download_danra.py
create mode 100644 neural_lam/.DS_Store
create mode 100644 neural_lam/datasets/.DS_Store
create mode 100644 neural_lam/datasets/__init__.py
create mode 100644 neural_lam/datastore/__init__.py
create mode 100644 neural_lam/datastore/base.py
create mode 100644 neural_lam/datastore/mllam.py
create mode 100644 neural_lam/datastore/multizarr/__init__.py
create mode 100644 neural_lam/datastore/multizarr/config.py
create mode 100644 neural_lam/datastore/multizarr/create_auxiliary_forcings.py
rename create_boundary_mask.py => neural_lam/datastore/multizarr/create_boundary_mask.py (96%)
rename calculate_statistics.py => neural_lam/datastore/multizarr/create_normalization_stats.py (50%)
rename neural_lam/{config.py => datastore/multizarr/store.py} (67%)
create mode 100644 neural_lam/datastore/npyfiles/__init__.py
create mode 100644 neural_lam/datastore/npyfiles/config.py
create mode 100644 neural_lam/datastore/npyfiles/store.py
create mode 100644 test_ewc.py
delete mode 100644 tests/data_config.yaml
create mode 100644 tests/datastore_configs/mllam.example.danra.yaml
rename neural_lam/data_config.yaml => tests/datastore_configs/multizarr.danra.yaml (86%)
create mode 100644 tests/test_mllam_dataset.py
rename tests/{test_analysis_dataset.py => test_multizarr_dataset.py} (71%)
rename tests/{test_forecast_dataset.py_ => test_npy_forecast_dataset.py} (79%)
diff --git a/create_forcings.py b/create_forcings.py
deleted file mode 100644
index 5f069d02..00000000
--- a/create_forcings.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Standard library
-import argparse
-
-# Third-party
-import numpy as np
-import pandas as pd
-import xarray as xr
-
-# First-party
-from neural_lam import config
-
-
-def get_seconds_in_year(year):
- start_of_year = pd.Timestamp(f"{year}-01-01")
- start_of_next_year = pd.Timestamp(f"{year + 1}-01-01")
- return (start_of_next_year - start_of_year).total_seconds()
-
-
-def calculate_datetime_forcing(timesteps):
- hours_of_day = xr.DataArray(timesteps.dt.hour, dims=["time"])
- seconds_into_year = xr.DataArray(
- [
- (
- pd.Timestamp(dt_obj)
- - pd.Timestamp(f"{pd.Timestamp(dt_obj).year}-01-01")
- ).total_seconds()
- for dt_obj in timesteps.values
- ],
- dims=["time"],
- )
- year_seconds = xr.DataArray(
- [
- get_seconds_in_year(pd.Timestamp(dt_obj).year)
- for dt_obj in timesteps.values
- ],
- dims=["time"],
- )
- hour_angle = (hours_of_day / 12) * np.pi
- year_angle = (seconds_into_year / year_seconds) * 2 * np.pi
- datetime_forcing = xr.Dataset(
- {
- "hour_sin": np.sin(hour_angle),
- "hour_cos": np.cos(hour_angle),
- "year_sin": np.sin(year_angle),
- "year_cos": np.cos(year_angle),
- },
- coords={"time": timesteps},
- )
- datetime_forcing = (datetime_forcing + 1) / 2
- return datetime_forcing
-
-
-def main():
- """Main function for creating the datetime forcing and boundary mask."""
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_config", type=str, default="neural_lam/data_config.yaml"
- )
- parser.add_argument("--zarr_path", type=str, default="data/forcings.zarr")
- args = parser.parse_args()
-
- data_config = config.Config.from_file(args.data_config)
- dataset = data_config.open_zarrs("state")
- datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
-
- # Expand dimensions to match the target dataset
- datetime_forcing_expanded = datetime_forcing.expand_dims(
- {"y": dataset.y, "x": dataset.x}
- )
-
- datetime_forcing_expanded = datetime_forcing_expanded.chunk(
- {"time": 1, "y": -1, "x": -1}
- )
-
- datetime_forcing_expanded.to_zarr(args.zarr_path, mode="w")
- print(f"Datetime forcing saved to {args.zarr_path}")
-
- dataset
-
-
-if __name__ == "__main__":
- main()
diff --git a/create_mesh.py b/create_mesh.py
index f827ee56..c7c1e95c 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -13,7 +13,7 @@
from torch_geometric.utils.convert import from_networkx
# First-party
-from neural_lam import config
+from neural_lam.datastore.multizarr import config
def plot_graph(graph, title=None):
diff --git a/docs/download_danra.py b/docs/download_danra.py
deleted file mode 100644
index fb70754f..00000000
--- a/docs/download_danra.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Third-party
-import xarray as xr
-
-data_urls = [
- "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr",
- "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr",
-]
-
-local_paths = [
- "data/danra/single_levels.zarr",
- "data/danra/height_levels.zarr",
-]
-
-for url, path in zip(data_urls, local_paths):
- print(f"Downloading {url} to {path}")
- ds = xr.open_zarr(url)
- chunk_dict = {dim: -1 for dim in ds.dims if dim != "time"}
- chunk_dict["time"] = 20
- ds = ds.chunk(chunk_dict)
-
- for var in ds.variables:
- if "chunks" in ds[var].encoding:
- del ds[var].encoding["chunks"]
-
- ds.to_zarr(path, mode="w")
- print("DONE")
diff --git a/neural_lam/.DS_Store b/neural_lam/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..d0f319116c039464f072e86a8e9d64de0b1a5a9a
GIT binary patch
literal 6148
zcmeHKO>5gg5S?}0+9H(tkU)+Ly&B@WA7FYBPI~AyD$OA&suYPhP1qA3#e?3
zm~#3~37ykqBU&5}kpbSj9qd<*9joc){*~~>G=vF>l|ZX;qr)Gji7T)^!Ah0bxKGxDN*W
zS=8OWkHuw8VL%vo&DW3`Dt9pi5Q0
z#ZWFCe(mEzkBvi@PRchQ%CD?^hobE2sIP4}snDTVVL%wT%>eg)kRIRvZ-3wa+aNK*
zfH3fXGN9Ur(cu{H=I_>xx8%E4Lq9=TIIeMcoq~bAiV@3K@gCF){F(>A&|~8e9*F!1
MSQ^9#1OJqPUyKWBpa1{>
literal 0
HcmV?d00001
diff --git a/neural_lam/datasets/.DS_Store b/neural_lam/datasets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..f172ab58d31f03adddb2b8b1d35371f1d00616de
GIT binary patch
literal 6148
zcmeHKJ8nWj3>*gvqBN8#_X@ee3Xv0VfgpisB9O?ZepSwuqhj~%HV}muB>vR5b
z|9aT(^5*MQWu<@=kOERb3P^z)6!6|ln>{2dN&zV#1x^b1_o2}pd*P56pALo?0f
+    def step_length(self) -> int:
+ """
+ The step length of the dataset in hours.
+
+ Returns:
+ int: The step length in hours.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_vars_units(self, category: str) -> List[str]:
+ """
+ Get the units of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables.
+
+ Returns
+ -------
+ List[str]
+ The units of the variables.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_vars_names(self, category: str) -> List[str]:
+ """
+ Get the names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables.
+
+ Returns
+ -------
+ List[str]
+ The names of the variables.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_num_data_vars(self, category: str) -> int:
+ """
+ Get the number of data variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables.
+
+ Returns
+ -------
+ int
+ The number of data variables.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+ """
+ Return the processed dataset for the given category and test/train/val-split that covers
+ the entire timeline of the dataset.
+ The returned dataarray is expected to at minimum have dimensions of `(time, grid_index, feature)` so
+ that any spatial dimensions have been stacked into a single dimension and all variables
+ and levels have been stacked into a single feature dimension.
+ Any additional dimensions (for example `ensemble_member` or `analysis_time`) should be kept as separate
+ dimensions in the dataarray, and `WeatherDataset` will handle the sampling of the data.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ split : str
+ The time split to filter the dataset (train/val/test).
+
+ Returns
+ -------
+ xr.DataArray
+ The xarray DataArray object with processed dataset.
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def boundary_mask(self):
+ """
+ Return the boundary mask for the dataset, with spatial dimensions stacked.
+ Where the value is 1, the grid point is a boundary point, and where the value is 0,
+ the grid point is not a boundary point.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions `('grid_index',)`.
+ """
+ pass
+
+
+@dataclasses.dataclass
+class CartesianGridShape:
+ """
+ Dataclass to store the shape of a grid.
+ """
+ x: int
+ y: int
+
+
+class BaseCartesianDatastore(BaseDatastore):
+ """
+    Base class for weather data stored on a Cartesian grid. In addition
+    to the methods and attributes required for weather data in general
+    (see `BaseDatastore`), for Cartesian gridded source data each `grid_index`
+    coordinate value is assumed to have an associated `x`- and `y`-value so
+    that the processed data-arrays can be reshaped back into 2D xy-gridded arrays.
+
+    In addition, the following attributes and methods are required:
+ - `coords_projection` (property): Projection object for the coordinates.
+ - `grid_shape_state` (property): Shape of the grid for the state variables.
+ - `get_xy_extent` (method): Return the extent of the x, y coordinates for a given category of data.
+ - `get_xy` (method): Return the x, y coordinates of the dataset.
+ """
+
+ @property
+ @abc.abstractmethod
+ def coords_projection(self) -> ccrs.Projection:
+ """Return the projection object for the coordinates.
+
+ The projection object is used to plot the coordinates on a map.
+
+ Returns
+ -------
+ cartopy.crs.Projection:
+ The projection object.
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def grid_shape_state(self) -> CartesianGridShape:
+ """
+ The shape of the grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape:
+ The shape of the grid for the state variables, which has `x` and `y` attributes.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
+ """
+ Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray or tuple(np.ndarray, np.ndarray)
+ The x, y coordinates of the dataset with shape `(2, N_y, N_x)` if `stacked=True` or
+ a tuple of two arrays with shape `((N_y, N_x), (N_y, N_x))` if `stacked=False`.
+ """
+ pass
+
+ def get_xy_extent(self, category: str) -> List[float]:
+ """
+ Return the extent of the x, y coordinates for a given category of data.
+ The extent should be returned as a list of 4 floats with `[xmin, xmax, ymin, ymax]`
+ which can then be used to set the extent of a plot.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[float]
+ The extent of the x, y coordinates.
+ """
+ xy = self.get_xy(category, stacked=False)
+ return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
\ No newline at end of file
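To make the intended contract concrete, a self-contained toy version of the same abstract-base-class pattern (it defines a stripped-down interface locally instead of importing the one added here, so the method set is illustrative, not exhaustive):

import abc

import numpy as np
import xarray as xr


class ToyDatastore(abc.ABC):
    """Stripped-down illustration of the BaseDatastore contract."""

    @abc.abstractmethod
    def get_dataarray(self, category: str, split: str) -> xr.DataArray:
        """Return (time, grid_index, feature) data for one category/split."""

    @property
    @abc.abstractmethod
    def boundary_mask(self) -> xr.DataArray:
        """1 where a grid point is a boundary point, 0 otherwise."""


class RandomDatastore(ToyDatastore):
    """Dummy implementation backed by random numbers."""

    def get_dataarray(self, category, split):
        data = np.random.rand(10, 16, 3)
        return xr.DataArray(data, dims=("time", "grid_index", "feature"))

    @property
    def boundary_mask(self):
        mask = np.zeros(16)
        mask[:4] = 1  # first few grid points marked as boundary
        return xr.DataArray(mask, dims=("grid_index",))


store = RandomDatastore()
print(store.get_dataarray("state", "train").shape)  # (10, 16, 3)
print(int(store.boundary_mask.sum()))               # 4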
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
new file mode 100644
index 00000000..ee16fc18
--- /dev/null
+++ b/neural_lam/datastore/mllam.py
@@ -0,0 +1,80 @@
+from typing import List
+
+from numpy import ndarray
+
+from .base import BaseCartesianDatastore, CartesianGridShape
+
+import mllam_data_prep as mdp
+import xarray as xr
+import cartopy.crs as ccrs
+
+
+class MLLAMDatastore(BaseCartesianDatastore):
+ """
+ Datastore class for the MLLAM dataset.
+ """
+
+ def __init__(self, config_path, n_boundary_points=30):
+ self._config_path = config_path
+ self._config = mdp.Config.from_yaml_file(config_path)
+ self._ds = mdp.create_dataset(config=self._config)
+ self._n_boundary_points = n_boundary_points
+
+ def step_length(self) -> int:
+ da_dt = self._ds["time"].diff("time")
+ return da_dt.dt.seconds[0] // 3600
+
+ def get_vars_units(self, category: str) -> List[str]:
+ return self._ds[f"{category}_unit"].values.tolist()
+
+ def get_vars_names(self, category: str) -> List[str]:
+ return self._ds[f"{category}_longname"].values.tolist()
+
+ def get_num_data_vars(self, category: str) -> int:
+ return len(self._ds[category].data_vars)
+
+ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+ # TODO: Implement split handling in mllam-data-prep, for now we hardcode that
+ # train will be the first 80%, then validation 10% and test 10%
+ da_category = self._ds[category]
+ n_samples = len(da_category.time)
+ # compute the split indices
+ if split == "train":
+ i_start, i_end = 0, int(0.8 * n_samples)
+ elif split == "val":
+ i_start, i_end = int(0.8 * n_samples), int(0.9 * n_samples)
+ elif split == "test":
+ i_start, i_end = int(0.9 * n_samples), n_samples
+ else:
+ raise ValueError(f"Unknown split {split}")
+
+ da_split = da_category.isel(time=slice(i_start, i_end))
+ return da_split
+
+ @property
+ def boundary_mask(self) -> xr.DataArray:
+ da_mask = xr.ones_like(self._ds["state"].isel(time=0).isel(variable=0))
+ da_mask.isel(x=slice(0, self._n_boundary_points), y=slice(0, self._n_boundary_points)).values = 0
+ return da_mask
+
+ @property
+ def coords_projection(self) -> ccrs.Projection:
+        # TODO: danra doesn't contain projection information yet, but the next version will;
+        # for now we hardcode the projection
+ # XXX: this is wrong
+ return ccrs.PlateCarree()
+
+ @property
+ def grid_shape_state(self):
+ return CartesianGridShape(
+ x=self._ds["state"].x.size, y=self._ds["state"].y.size
+ )
+
+ def get_xy(self, category: str, stacked: bool) -> ndarray:
+ da_x = self._ds[category].x
+ da_y = self._ds[category].y
+ if stacked:
+ x, y = xr.broadcast(da_x, da_y)
+ return xr.concat([x, y], dim="xy").values
+ else:
+ return da_x.values, da_y.values
\ No newline at end of file
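The hard-coded 80/10/10 split in get_dataarray above reduces to simple index arithmetic; a minimal sketch with a hypothetical sample count:

def split_indices(n_samples: int, split: str) -> tuple[int, int]:
    """Return (start, end) sample indices for a hard-coded 80/10/10 split."""
    if split == "train":
        return 0, int(0.8 * n_samples)
    if split == "val":
        return int(0.8 * n_samples), int(0.9 * n_samples)
    if split == "test":
        return int(0.9 * n_samples), n_samples
    raise ValueError(f"Unknown split {split}")


n = 1000  # hypothetical number of time steps
for name in ("train", "val", "test"):
    i0, i1 = split_indices(n, name)
    print(name, i0, i1)  # train 0 800 / val 800 900 / test 900 1000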
diff --git a/neural_lam/datastore/multizarr/__init__.py b/neural_lam/datastore/multizarr/__init__.py
new file mode 100644
index 00000000..491d4a18
--- /dev/null
+++ b/neural_lam/datastore/multizarr/__init__.py
@@ -0,0 +1 @@
+from .store import MultiZarrDatastore
\ No newline at end of file
diff --git a/neural_lam/datastore/multizarr/config.py b/neural_lam/datastore/multizarr/config.py
new file mode 100644
index 00000000..0d93ab70
--- /dev/null
+++ b/neural_lam/datastore/multizarr/config.py
@@ -0,0 +1,43 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import yaml
+
+
+class Config:
+ """Class to load and access the configuration file."""
+
+ def __init__(self, values):
+ self.values = values
+
+ @classmethod
+ def from_file(cls, filepath):
+ """Load the configuration file from the given path."""
+ if filepath.endswith(".yaml"):
+ with open(filepath, encoding="utf-8", mode="r") as file:
+ return cls(values=yaml.safe_load(file))
+ else:
+ raise NotImplementedError(Path(filepath).suffix)
+
+ def __getattr__(self, name):
+ """Recursively access the values in the configuration."""
+ keys = name.split(".")
+ value = self.values
+ for key in keys:
+ if key in value:
+ value = value[key]
+ else:
+ return None
+ if isinstance(value, dict):
+ return Config(values=value)
+ return value
+
+ def __getitem__(self, key):
+ value = self.values[key]
+ if isinstance(value, dict):
+ return Config(values=value)
+ return value
+
+ def __contains__(self, key):
+ return key in self.values
\ No newline at end of file
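A quick usage sketch of the access patterns this Config class provides (the inline YAML is made up; running it assumes the package is importable at this revision): nested mappings come back wrapped in Config, and missing keys resolve to None instead of raising.

import yaml

from neural_lam.datastore.multizarr.config import Config

yaml_text = """
projection:
  class: LambertConformal
  kwargs:
    central_longitude: 25
"""
cfg = Config(values=yaml.safe_load(yaml_text))

print(cfg.projection["class"])                     # 'LambertConformal'
print(cfg["projection"].kwargs.central_longitude)  # 25
print(cfg.missing_section)                         # None rather than a KeyError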
diff --git a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
new file mode 100644
index 00000000..9ce15a2a
--- /dev/null
+++ b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
@@ -0,0 +1,115 @@
+# Standard library
+import argparse
+from pathlib import Path
+
+# Third-party
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+# First-party
+from neural_lam.datastore.multizarr import MultiZarrDatastore
+
+
+def get_seconds_in_year(year):
+ start_of_year = pd.Timestamp(f"{year}-01-01")
+ start_of_next_year = pd.Timestamp(f"{year + 1}-01-01")
+ return (start_of_next_year - start_of_year).total_seconds()
+
+
+def calculate_datetime_forcing(da_time: xr.DataArray):
+ """
+ Compute the datetime forcing for a given set of timesteps, assuming
+ that timesteps is a DataArray with a type of `np.datetime64`.
+
+ Parameters
+ ----------
+ timesteps : xr.DataArray
+ The timesteps for which to compute the datetime forcing.
+
+ Returns
+ -------
+ xr.Dataset
+ The datetime forcing, with the following variables:
+ - hour_sin: The sine of the hour of the day, normalized to [0, 1].
+ - hour_cos: The cosine of the hour of the day, normalized to [0, 1].
+ - year_sin: The sine of the time of year, normalized to [0, 1].
+ - year_cos: The cosine of the time of year, normalized to [0, 1].
+ """
+ hours_of_day = xr.DataArray(da_time.dt.hour, dims=["time"])
+ seconds_into_year = xr.DataArray(
+ [
+ (
+ pd.Timestamp(dt_obj)
+ - pd.Timestamp(f"{pd.Timestamp(dt_obj).year}-01-01")
+ ).total_seconds()
+ for dt_obj in da_time.values
+ ],
+ dims=["time"],
+ )
+ year_seconds = xr.DataArray(
+ [
+ get_seconds_in_year(pd.Timestamp(dt_obj).year)
+ for dt_obj in da_time.values
+ ],
+ dims=["time"],
+ )
+ hour_angle = (hours_of_day / 12) * np.pi
+ year_angle = (seconds_into_year / year_seconds) * 2 * np.pi
+ datetime_forcing = xr.Dataset(
+ {
+ "hour_sin": np.sin(hour_angle),
+ "hour_cos": np.cos(hour_angle),
+ "year_sin": np.sin(year_angle),
+ "year_cos": np.cos(year_angle),
+ },
+ coords={"time": da_time},
+ )
+ datetime_forcing = (datetime_forcing + 1) / 2
+ return datetime_forcing
+
+
+def main():
+ """Main function for creating the datetime forcing and boundary mask."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-config", type=str, default="tests/datastore_configs/multizarr.danra.yaml")
+ parser.add_argument(
+ "--zarr_path",
+ type=str,
+ default=None,
+ help="Path to save the Zarr archive "
+ "(default: same directory as the data-config)",
+ )
+ args = parser.parse_args()
+
+ zarr_path = args.zarr_path
+ if zarr_path is None:
+ zarr_path = Path(args.data_config).parent / "datetime_forcings.zarr"
+
+ datastore = MultiZarrDatastore(config_path=args.data_config)
+ da_state = datastore.get_dataarray(category="state", split="train")
+
+ da_datetime_forcing = calculate_datetime_forcing(da_time=da_state.time).expand_dims({"grid_index": da_state.grid_index})
+
+ chunking = {"time": 1}
+
+ if "x" in da_state.coords and "y" in da_state.coords:
+ # copy the x and y coordinates to the datetime forcing
+ for aux_coord in ["x", "y"]:
+ da_datetime_forcing.coords[aux_coord] = da_state[aux_coord]
+
+ da_datetime_forcing = da_datetime_forcing.set_index(grid_index=("y", "x")).unstack("grid_index")
+ chunking["x"] = -1
+ chunking["y"] = -1
+ else:
+ chunking["grid_index"] = -1
+
+ da_datetime_forcing = da_datetime_forcing.chunk(chunking)
+
+ da_datetime_forcing.to_zarr(zarr_path, mode="w")
+ print(da_datetime_forcing)
+ print(f"Datetime forcing saved to {zarr_path}")
+
+
+if __name__ == "__main__":
+ main()
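As a quick sanity check of the encoding above, worked by hand for a few example timestamps: at 06:00 the hour angle is pi/2, so hour_sin is 1 and hour_cos is 0 before shifting, which the (x + 1) / 2 rescaling maps to 1.0 and 0.5.

import numpy as np
import pandas as pd
import xarray as xr

da_time = xr.DataArray(
    pd.to_datetime(["2020-06-01T00:00", "2020-06-01T06:00", "2020-06-01T12:00"]),
    dims=["time"],
)

hour_angle = (da_time.dt.hour / 12) * np.pi
hour_sin = (np.sin(hour_angle) + 1) / 2  # rescaled to [0, 1]
hour_cos = (np.cos(hour_angle) + 1) / 2

print(hour_sin.values)  # [0.5, 1.0, 0.5]
print(hour_cos.values)  # [1.0, 0.5, 0.0] (up to float rounding)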
diff --git a/create_boundary_mask.py b/neural_lam/datastore/multizarr/create_boundary_mask.py
similarity index 96%
rename from create_boundary_mask.py
rename to neural_lam/datastore/multizarr/create_boundary_mask.py
index 5c0c115f..038d88be 100644
--- a/create_boundary_mask.py
+++ b/neural_lam/datastore/multizarr/create_boundary_mask.py
@@ -6,7 +6,7 @@
import xarray as xr
# First-party
-from neural_lam import config
+from neural_lam.datastore.multizarr import config
def main():
diff --git a/calculate_statistics.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
similarity index 50%
rename from calculate_statistics.py
rename to neural_lam/datastore/multizarr/create_normalization_stats.py
index daaf8767..a258fb6d 100644
--- a/calculate_statistics.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -5,12 +5,15 @@
import xarray as xr
# First-party
-from neural_lam import config
+from neural_lam.datastore.multizarr import MultiZarrDatastore
-def compute_stats(data_array):
- mean = data_array.mean(dim=("time", "grid"))
- std = data_array.std(dim=("time", "grid"))
+DEFAULT_PATH = "tests/datastore_configs/multizarr.danra.yaml"
+
+
+def compute_stats(da):
+ mean = da.mean(dim=("time", "grid_index"))
+ std = da.std(dim=("time", "grid_index"))
return mean, std
@@ -19,8 +22,8 @@ def main():
parser.add_argument(
"--data_config",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ default=DEFAULT_PATH,
+ help=f"Path to data config file (default: {DEFAULT_PATH})",
)
parser.add_argument(
"--zarr_path",
@@ -29,65 +32,65 @@ def main():
help="Directory where data is stored",
)
args = parser.parse_args()
+
+ datastore = MultiZarrDatastore(config_path=args.data_config)
- data_config = config.Config.from_file(args.data_config)
- state_data = data_config.process_dataset("state", split="train")
- forcing_data = data_config.process_dataset(
- "forcing", split="train", apply_windowing=False
- )
+ da_state = datastore.get_dataarray(category="state", split="train")
+ da_forcing = datastore.get_dataarray(category="forcing", split="train")
print("Computing mean and std.-dev. for parameters...", flush=True)
- state_mean, state_std = compute_stats(state_data)
+ da_state_mean, da_state_std = compute_stats(da_state)
- if forcing_data is not None:
- forcing_mean, forcing_std = compute_stats(forcing_data)
- combined_stats = data_config["utilities"]["normalization"][
- "combined_stats"
- ]
+ if da_forcing is not None:
+ da_forcing_mean, da_forcing_std = compute_stats(da_forcing)
+ combined_stats = datastore._config["utilities"]["normalization"]["combined_stats"]
if combined_stats is not None:
for group in combined_stats:
vars_to_combine = group["vars"]
- means = forcing_mean.sel(variable=vars_to_combine)
- stds = forcing_std.sel(variable=vars_to_combine)
+ import ipdb; ipdb.set_trace()
+ means = da_forcing_mean.sel(variable=vars_to_combine)
+ stds = da_forcing_std.sel(variable=vars_to_combine)
combined_mean = means.mean(dim="variable")
combined_std = (stds**2).mean(dim="variable") ** 0.5
- forcing_mean.loc[dict(variable=vars_to_combine)] = combined_mean
- forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
- window = data_config["forcing"]["window"]
- forcing_mean = xr.concat([forcing_mean] * window, dim="window").stack(
+ da_forcing_mean.loc[dict(variable=vars_to_combine)] = combined_mean
+ da_forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
+
+ window = datastore._config["forcing"]["window"]
+
+ da_forcing_mean = xr.concat([da_forcing_mean] * window, dim="window").stack(
forcing_variable=("variable", "window")
)
- forcing_std = xr.concat([forcing_std] * window, dim="window").stack(
+ da_forcing_std = xr.concat([da_forcing_std] * window, dim="window").stack(
forcing_variable=("variable", "window")
)
- vars = forcing_data["variable"].values.tolist()
- window = data_config["forcing"]["window"]
+ vars = da_forcing["variable"].values.tolist()
+ window = datastore._config["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
print(
"Computing mean and std.-dev. for one-step differences...", flush=True
)
- state_data_normalized = (state_data - state_mean) / state_std
+ state_data_normalized = (da_state - da_state_mean) / da_state_std
state_data_diff_normalized = state_data_normalized.diff(dim="time")
diff_mean, diff_std = compute_stats(state_data_diff_normalized)
ds = xr.Dataset(
{
- "state_mean": state_mean,
- "state_std": state_std,
+ "state_mean": da_state_mean,
+ "state_std": da_state_std,
"diff_mean": diff_mean,
"diff_std": diff_std,
}
)
- if forcing_data is not None:
+ if da_forcing is not None:
dsf = (
xr.Dataset(
{
- "forcing_mean": forcing_mean,
- "forcing_std": forcing_std,
+ "forcing_mean": da_forcing_mean,
+ "forcing_std": da_forcing_std,
}
)
.reset_index(["forcing_variable"])
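For reference, the per-variable reduction in compute_stats boils down to the following (synthetic array, arbitrary sizes): mean and standard deviation are taken jointly over the time and grid_index dimensions, leaving one value per variable.

import numpy as np
import xarray as xr

# Synthetic stand-in for the processed (time, grid_index, variable) state data
da = xr.DataArray(
    np.random.rand(20, 100, 3),
    dims=("time", "grid_index", "variable"),
    coords={"variable": ["T_2M", "PMSL", "U_10M"]},
)

mean = da.mean(dim=("time", "grid_index"))  # one value per variable
std = da.std(dim=("time", "grid_index"))

print(mean.sizes)  # Frozen({'variable': 3})
print(std.values)  # three per-variable standard deviations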
diff --git a/neural_lam/config.py b/neural_lam/datastore/multizarr/store.py
similarity index 67%
rename from neural_lam/config.py
rename to neural_lam/datastore/multizarr/store.py
index 6ab5868d..38617984 100644
--- a/neural_lam/config.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -1,56 +1,57 @@
-# Standard library
-import functools
-import os
-from pathlib import Path
-
-# Third-party
import cartopy.crs as ccrs
import numpy as np
import pandas as pd
-import torch
import xarray as xr
import yaml
+import functools
+import os
-class Config:
- """Class to load and access the configuration file.
- The class also preprocesses the dataset based on the config."""
+from .config import Config
+from ..base import BaseDatastore
- DIMS_TO_KEEP = {"time", "grid", "variable"}
- def __init__(self, values):
- self.values = values
+def convert_stats_to_torch(stats):
+ """Convert the normalization statistics to torch tensors.
- @classmethod
- def from_file(cls, filepath):
- """Load the configuration file from the given path."""
- if filepath.endswith(".yaml"):
- with open(filepath, encoding="utf-8", mode="r") as file:
- return cls(values=yaml.safe_load(file))
- else:
- raise NotImplementedError(Path(filepath).suffix)
-
- def __getattr__(self, name):
- """Recursively access the values in the configuration."""
- keys = name.split(".")
- value = self.values
- for key in keys:
- if key in value:
- value = value[key]
- else:
- return None
- if isinstance(value, dict):
- return Config(values=value)
- return value
+ Args:
+ stats (xr.Dataset): The normalization statistics.
+
+ Returns:
+ dict(tensor): The normalization statistics as torch tensors."""
+ return {
+ var: torch.tensor(stats[var].values, dtype=torch.float32)
+ for var in stats.data_vars
+ }
+
+class MultiZarrDatastore(BaseDatastore):
+ DIMS_TO_KEEP = {"time", "grid_index", "variable"}
+
+ def __init__(self, config_path):
+ with open(config_path, encoding="utf-8", mode="r") as file:
+ self._config = yaml.safe_load(file)
+
+ def open_zarrs(self, category):
+ """Open the zarr dataset for the given category.
+
+ Args:
+ category (str): The category of the dataset (state/forcing/static).
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return Config(values=value)
- return value
+ Returns:
+ xr.Dataset: The xarray Dataset object."""
+ zarr_configs = self._config[category]["zarrs"]
- def __contains__(self, key):
- return key in self.values
+ datasets = []
+ for config in zarr_configs:
+ dataset_path = config["path"]
+ try:
+ dataset = xr.open_zarr(dataset_path, consolidated=True)
+ except Exception as e:
+ raise Exception("Error opening dataset:", dataset_path) from e
+ datasets.append(dataset)
+ merged_dataset = xr.merge(datasets)
+ merged_dataset.attrs["category"] = category
+ return merged_dataset
@functools.cached_property
def coords_projection(self):
@@ -60,7 +61,7 @@ def coords_projection(self):
Returns:
cartopy.crs.Projection: The projection object."""
- proj_config = self.values["projection"]
+ proj_config = self._config["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
proj_params = proj_config.get("kwargs", {})
@@ -79,7 +80,7 @@ def step_length(self):
return int(step_length_hours)
@functools.lru_cache()
- def vars_names(self, category):
+ def get_vars_names(self, category):
"""Return the names of the variables in the dataset.
Args:
@@ -87,16 +88,16 @@ def vars_names(self, category):
Returns:
list: The names of the variables in the dataset."""
- surface_vars_names = self.values[category].get("surface_vars") or []
+ surface_vars_names = self._config[category].get("surface_vars") or []
atmosphere_vars_names = [
f"{var}_{level}"
- for var in (self.values[category].get("atmosphere_vars") or [])
- for level in (self.values[category].get("levels") or [])
+ for var in (self._config[category].get("atmosphere_vars") or [])
+ for level in (self._config[category].get("levels") or [])
]
return surface_vars_names + atmosphere_vars_names
@functools.lru_cache()
- def vars_units(self, category):
+ def get_vars_units(self, category):
"""Return the units of the variables in the dataset.
Args:
@@ -104,16 +105,16 @@ def vars_units(self, category):
Returns:
list: The units of the variables in the dataset."""
- surface_vars_units = self.values[category].get("surface_units") or []
+ surface_vars_units = self._config[category].get("surface_units") or []
atmosphere_vars_units = [
unit
- for unit in (self.values[category].get("atmosphere_units") or [])
- for _ in (self.values[category].get("levels") or [])
+ for unit in (self._config[category].get("atmosphere_units") or [])
+ for _ in (self._config[category].get("levels") or [])
]
return surface_vars_units + atmosphere_vars_units
@functools.lru_cache()
- def num_data_vars(self, category):
+ def get_num_data_vars(self, category):
"""Return the number of data variables in the dataset.
Args:
@@ -121,9 +122,9 @@ def num_data_vars(self, category):
Returns:
int: The number of data variables in the dataset."""
- surface_vars = self.values[category].get("surface_vars", [])
- atmosphere_vars = self.values[category].get("atmosphere_vars", [])
- levels = self.values[category].get("levels", [])
+ surface_vars = self._config[category].get("surface_vars", [])
+ atmosphere_vars = self._config[category].get("atmosphere_vars", [])
+ levels = self._config[category].get("levels", [])
surface_vars_count = (
len(surface_vars) if surface_vars is not None else 0
@@ -135,51 +136,26 @@ def num_data_vars(self, category):
return surface_vars_count + atmosphere_vars_count * levels_count
- def open_zarrs(self, category):
- """Open the zarr dataset for the given category.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- xr.Dataset: The xarray Dataset object."""
- zarr_configs = self.values[category]["zarrs"]
-
- try:
- datasets = []
- for config in zarr_configs:
- dataset_path = config["path"]
- dataset = xr.open_zarr(dataset_path, consolidated=True)
- datasets.append(dataset)
- merged_dataset = xr.merge(datasets)
- merged_dataset.attrs["category"] = category
- return merged_dataset
- except Exception:
- print(f"Invalid zarr configuration for category: {category}")
- return None
-
- def stack_grid(self, dataset):
+ def _stack_grid(self, ds):
"""Stack the grid dimensions of the dataset.
Args:
- dataset (xr.Dataset): The xarray Dataset object.
+ ds (xr.Dataset): The xarray Dataset object.
Returns:
xr.Dataset: The xarray Dataset object with stacked grid dimensions."""
- if dataset is None:
- return None
- dims = list(dataset.dims)
-
- if "grid" in dims:
- print("\033[94mGrid dimensions already stacked.\033[0m")
- return dataset.squeeze()
+ if "grid_index" in ds.dims:
+ raise ValueError("Grid dimensions already stacked.")
else:
- if "x" not in dims or "y" not in dims:
- self.rename_dataset_dims_and_vars(dataset=dataset)
- dataset = dataset.squeeze().stack(grid=("y", "x"))
- return dataset
-
- def convert_dataset_to_dataarray(self, dataset):
+ if "x" not in ds.dims or "y" not in ds.dims:
+ self._rename_dataset_dims_and_vars(dataset=ds)
+ ds = ds.stack(grid_index=("y", "x")).reset_index("grid_index")
+ # reset the grid_index coordinates to have integer values, otherwise
+ # the serialisation to zarr will fail
+ ds["grid_index"] = np.arange(len(ds["grid_index"]))
+ return ds
+
+ def _convert_dataset_to_dataarray(self, dataset):
"""Convert the Dataset to a Dataarray.
Args:
@@ -191,7 +167,7 @@ def convert_dataset_to_dataarray(self, dataset):
dataset = dataset.to_array()
return dataset
- def filter_dimensions(self, dataset, transpose_array=True):
+ def _filter_dimensions(self, dataset, transpose_array=True):
"""Drop the dimensions and filter the data_vars of the dataset.
Args:
@@ -215,10 +191,10 @@ def filter_dimensions(self, dataset, transpose_array=True):
"\033[91mAttempting to update dims and "
"vars based on zarr config...\033[0m"
)
- dataset = self.rename_dataset_dims_and_vars(
+ dataset = self._rename_dataset_dims_and_vars(
dataset.attrs["category"], dataset=dataset
)
- dataset = self.stack_grid(dataset)
+ dataset = self._stack_grid(dataset)
dataset_dims = set(list(dataset.dims) + ["variable"])
if min_req_dims.issubset(dataset_dims):
print(
@@ -247,12 +223,12 @@ def filter_dimensions(self, dataset, transpose_array=True):
)
if transpose_array:
- dataset = self.convert_dataset_to_dataarray(dataset)
+ dataset = self._convert_dataset_to_dataarray(dataset)
if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid", "variable")
+ dataset = dataset.transpose("time", "grid_index", "variable")
else:
- dataset = dataset.transpose("grid", "variable")
+ dataset = dataset.transpose("grid_index", "variable")
dataset_vars = (
list(dataset.data_vars)
if isinstance(dataset, xr.Dataset)
@@ -266,7 +242,7 @@ def filter_dimensions(self, dataset, transpose_array=True):
return dataset
- def reshape_grid_to_2d(self, dataset, grid_shape=None):
+ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
"""Reshape the grid to 2D for stacked data without multi-index.
Args:
@@ -317,7 +293,8 @@ def get_xy(self, category, stacked=True):
return x, y
def get_xy_extent(self, category):
- """Return the extent of the x, y coordinates.
+ """Return the extent of the x, y coordinates. This should be a list
+ of 4 floats with `[xmin, xmax, ymin, ymax]`
Args:
category (str): The category of the dataset (state/forcing/static).
@@ -333,18 +310,15 @@ def get_xy_extent(self, category):
return extent
@functools.lru_cache()
- def load_normalization_stats(self, category, datatype="torch"):
+ def get_normalization_stats(self, category):
"""Load the normalization statistics for the dataset.
Args:
category (str): The category of the dataset (state/forcing/static).
- datatype (str): The datatype of the statistics (torch/"").
Returns:
- tensor: The normalization statistics for the dataset.
- (if datatype="torch")
OR xr.Dataset: The normalization statistics for the dataset.
- (otherwise)"""
+ """
combined_stats = self._load_and_merge_stats()
if combined_stats is None:
return None
@@ -355,9 +329,6 @@ def load_normalization_stats(self, category, datatype="torch"):
if stats is None:
return None
- if datatype == "torch":
- return self._convert_stats_to_torch(stats)
-
return stats
def _load_and_merge_stats(self):
@@ -367,14 +338,13 @@ def _load_and_merge_stats(self):
xr.Dataset: The merged normalization statistics for the dataset."""
combined_stats = None
for i, zarr_config in enumerate(
- self.values["utilities"]["normalization"]["zarrs"]
+ self._config["utilities"]["normalization"]["zarrs"]
):
stats_path = zarr_config["path"]
if not os.path.exists(stats_path):
- print(
+ raise FileNotFoundError(
f"Normalization statistics not found at path: {stats_path}"
)
- return None
stats = xr.open_zarr(stats_path, consolidated=True)
if i == 0:
combined_stats = stats
@@ -392,7 +362,7 @@ def _rename_data_vars(self, combined_stats):
xr.Dataset: The combined normalization statistics with renamed data
variables."""
vars_mapping = {}
- for zarr_config in self.values["utilities"]["normalization"]["zarrs"]:
+ for zarr_config in self._config["utilities"]["normalization"]["zarrs"]:
vars_mapping.update(zarr_config["stats_vars"])
return combined_stats.rename_vars(
@@ -448,20 +418,7 @@ def _select_stats_by_category(self, combined_stats, category):
print(f"Invalid category: {category}")
return None
- def _convert_stats_to_torch(self, stats):
- """Convert the normalization statistics to torch tensors.
-
- Args:
- stats (xr.Dataset): The normalization statistics.
-
- Returns:
- dict(tensor): The normalization statistics as torch tensors."""
- return {
- var: torch.tensor(stats[var].values, dtype=torch.float32)
- for var in stats.data_vars
- }
-
- def extract_vars(self, category, dataset=None):
+ def _extract_vars(self, category, ds=None):
"""Extract (select) the data variables from the dataset.
Args:
@@ -469,67 +426,57 @@ def extract_vars(self, category, dataset=None):
dataset (xr.Dataset): The xarray Dataset object.
Returns:
- xr.Dataset: The xarray Dataset object with extracted variables."""
- if dataset is None:
- dataset = self.open_zarrs(category)
- surface_vars = None
- atmosphere_vars = None
- if self[category].surface_vars:
- surface_vars = self._extract_surface_vars(category, dataset)
- if self[category].atmosphere_vars:
- atmosphere_vars = self._extract_atmosphere_vars(category, dataset)
- if surface_vars and atmosphere_vars:
- return xr.merge([surface_vars, atmosphere_vars])
- elif surface_vars:
- return surface_vars
- elif atmosphere_vars:
- return atmosphere_vars
- else:
- print(f"No variables found in dataset {category}")
- return None
-
- def _extract_surface_vars(self, category, dataset):
- """Extract the surface variables from the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- dataset (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with surface variables.
+ xr.Dataset: The xarray Dataset object with extracted variables.
"""
- return (
- dataset[self[category].surface_vars]
- if self[category].surface_vars
- else []
- )
+ if ds is None:
+ ds = self.open_zarrs(category)
+ surface_vars = self._config[category].get("surface_vars")
+        atmosphere_vars = self._config[category].get("atmosphere_vars")
+
+ ds_surface = None
+ if surface_vars is not None:
+ ds_surface = ds[surface_vars]
+
+ ds_atmosphere = None
+        if atmosphere_vars is not None:
+ ds_atmosphere = self._extract_atmosphere_vars(category=category, ds=ds)
+
+ if ds_surface and ds_atmosphere:
+ return xr.merge([ds_surface, ds_atmosphere])
+ elif ds_surface:
+ return ds_surface
+ elif ds_atmosphere:
+ return ds_atmosphere
+ else:
+ raise ValueError(f"No variables found in dataset {category}")
- def _extract_atmosphere_vars(self, category, dataset):
+ def _extract_atmosphere_vars(self, category, ds):
"""Extract the atmosphere variables from the dataset.
Args:
category (str): The category of the dataset (state/forcing/static).
- dataset (xr.Dataset): The xarray Dataset object.
+ ds (xr.Dataset): The xarray Dataset object.
Returns:
xr.Dataset: The xarray Dataset object with atmosphere variables."""
- if "level" not in list(dataset.dims) and self[category].atmosphere_vars:
- dataset = self.rename_dataset_dims_and_vars(
- dataset.attrs["category"], dataset=dataset
+
+ if "level" not in list(ds.dims) and self._config[category]["atmosphere_vars"]:
+ ds = self._rename_dataset_dims_and_vars(
+ ds.attrs["category"], dataset=ds
)
data_arrays = [
- dataset[var].sel(level=level, drop=True).rename(f"{var}_{level}")
- for var in self[category].atmosphere_vars
- for level in self[category].levels
+ ds[var].sel(level=level, drop=True).rename(f"{var}_{level}")
+ for var in self._config[category]["atmosphere_vars"]
+ for level in self._config[category]["levels"]
]
- if self[category].atmosphere_vars:
+ if self._config[category]["atmosphere_vars"]:
return xr.merge(data_arrays)
else:
return xr.Dataset()
- def rename_dataset_dims_and_vars(self, category, dataset=None):
+ def _rename_dataset_dims_and_vars(self, category, dataset=None):
"""Rename the dimensions and variables of the dataset.
Args:
@@ -549,7 +496,7 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
convert = True
dataset = dataset.to_dataset("variable")
dims_mapping = {}
- zarr_configs = self.values[category]["zarrs"]
+ zarr_configs = self._config[category]["zarrs"]
for zarr_config in zarr_configs:
dims_mapping.update(zarr_config["dims"])
@@ -567,7 +514,7 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
dataset = dataset.to_array()
return dataset
- def filter_dataset_by_time(self, dataset, split="train"):
+ def _apply_time_split(self, dataset, split="train"):
"""Filter the dataset by the time split.
Args:
@@ -577,8 +524,8 @@ def filter_dataset_by_time(self, dataset, split="train"):
Returns:
xr.Dataset: The xarray Dataset object filtered by the time split."""
start, end = (
- self.values["splits"][split]["start"],
- self.values["splits"][split]["end"],
+ self._config["splits"][split]["start"],
+ self._config["splits"][split]["end"],
)
dataset = dataset.sel(time=slice(start, end))
dataset.attrs["split"] = split
@@ -596,11 +543,11 @@ def apply_window(self, category, dataset=None):
if dataset is None:
dataset = self.open_zarrs(category)
if isinstance(dataset, xr.Dataset):
- dataset = self.convert_dataset_to_dataarray(dataset)
+ dataset = self._convert_dataset_to_dataarray(dataset)
state = self.open_zarrs("state")
- state = self.filter_dataset_by_time(state, dataset.attrs["split"])
+ state = self._apply_time_split(state, dataset.attrs["split"])
state_time = state.time.values
- window = self[category].window
+ window = self._config[category]["window"]
dataset = (
dataset.sel(time=state_time, method="nearest")
.pad(time=(window // 2, window // 2), mode="edge")
@@ -611,18 +558,21 @@ def apply_window(self, category, dataset=None):
dataset = dataset.isel(time=slice(window // 2, -window // 2 + 1))
return dataset
- def load_boundary_mask(self):
- """Load the boundary mask for the dataset.
+ @property
+ def boundary_mask(self):
+ """
+ Load the boundary mask for the dataset, with spatial dimensions stacked.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions `('grid_index',)`.
+ """
+ ds_boundary_mask = xr.open_zarr(self._config["boundary"]["mask"]["path"])
+ return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index("grid_index")
+
- Returns:
- tensor: The boundary mask for the dataset."""
- boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
- return torch.tensor(
- boundary_mask.mask.stack(grid=("y", "x")).values,
- dtype=torch.float32,
- ).unsqueeze(1)
-
- def process_dataset(self, category, split="train", apply_windowing=True):
+ def get_dataarray(self, category, split="train", apply_windowing=True):
"""Process the dataset for the given category.
Args:
@@ -633,14 +583,14 @@ def process_dataset(self, category, split="train", apply_windowing=True):
Returns:
xr.DataArray: The xarray DataArray object with processed dataset."""
dataset = self.open_zarrs(category)
- dataset = self.extract_vars(category, dataset)
+ dataset = self._extract_vars(category, dataset)
if category != "static":
- dataset = self.filter_dataset_by_time(dataset, split)
- dataset = self.stack_grid(dataset)
- dataset = self.rename_dataset_dims_and_vars(category, dataset)
- dataset = self.filter_dimensions(dataset)
- dataset = self.convert_dataset_to_dataarray(dataset)
- if "window" in self.values[category] and apply_windowing:
+ dataset = self._apply_time_split(dataset, split)
+ dataset = self._stack_grid(dataset)
+ dataset = self._rename_dataset_dims_and_vars(category, dataset)
+ dataset = self._filter_dimensions(dataset)
+ dataset = self._convert_dataset_to_dataarray(dataset)
+ if "window" in self._config[category] and apply_windowing:
dataset = self.apply_window(category, dataset)
if category == "static" and "time" in dataset.dims:
dataset = dataset.isel(time=0, drop=True)
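A minimal sketch of the grid-stacking step that `_stack_grid` performs above, on a toy dataset; the variable and dimension names below are illustrative, only the stack/reset_index pattern mirrors the patch:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t2m": (("time", "y", "x"), np.random.rand(2, 3, 4))},
        coords={"time": [0, 1], "y": np.arange(3), "x": np.arange(4)},
    )
    ds = ds.stack(grid_index=("y", "x")).reset_index("grid_index")
    # give grid_index plain integer values so a later serialisation to zarr
    # does not have to handle a MultiIndex coordinate
    ds["grid_index"] = np.arange(ds.sizes["grid_index"])
    print(ds["t2m"].dims)  # ('time', 'grid_index')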
diff --git a/neural_lam/datastore/npyfiles/__init__.py b/neural_lam/datastore/npyfiles/__init__.py
new file mode 100644
index 00000000..57b47049
--- /dev/null
+++ b/neural_lam/datastore/npyfiles/__init__.py
@@ -0,0 +1 @@
+from .store import NumpyFilesDatastore
\ No newline at end of file
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
new file mode 100644
index 00000000..842c4b83
--- /dev/null
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -0,0 +1,62 @@
+# Standard library
+import functools
+from pathlib import Path
+
+# Third-party
+import cartopy.crs as ccrs
+import yaml
+
+
+class NpyConfig:
+ """
+ Class for loading configuration files.
+
+ This class loads a configuration file and provides a way to access its
+ values as attributes.
+ """
+
+ def __init__(self, values):
+ self.values = values
+
+ @classmethod
+ def from_file(cls, filepath):
+ """Load a configuration file."""
+ if str(filepath).endswith(".yaml"):
+ with open(filepath, encoding="utf-8", mode="r") as file:
+ return cls(values=yaml.safe_load(file))
+ else:
+ raise NotImplementedError(Path(filepath).suffix)
+
+ def __getattr__(self, name):
+ child, *children = name.split(".")
+
+ value = self.values[child]
+ if len(children) > 0:
+ return self.__class__(values=value).get(".".join(children))
+ else:
+ if isinstance(value, dict):
+ return self.__class__(values=value)
+ else:
+ return value
+
+ def __getitem__(self, key):
+ value = self.values[key]
+ if isinstance(value, dict):
+ return self.__class__(values=value)
+ return value
+
+ def __contains__(self, key):
+ return key in self.values
+
+ def num_data_vars(self):
+ """Return the number of data variables for a given key."""
+ return len(self.dataset.var_names)
+
+ @functools.cached_property
+ def coords_projection(self):
+ """Return the projection."""
+ proj_config = self.values["projection"]
+ proj_class_name = proj_config["class"]
+ proj_class = getattr(ccrs, proj_class_name)
+ proj_params = proj_config.get("kwargs", {})
+ return proj_class(**proj_params)
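A hedged usage sketch of `coords_projection`: the `projection` block of a data_config.yaml is resolved into a cartopy CRS via `getattr`. The values below are taken from the LambertConformal example used elsewhere in this series and are illustrative only:

    import cartopy.crs as ccrs

    proj_config = {
        "class": "LambertConformal",
        "kwargs": {
            "central_longitude": 6.22,
            "central_latitude": 56.0,
            "standard_parallels": [47.6, 64.4],
        },
    }
    proj_class = getattr(ccrs, proj_config["class"])
    projection = proj_class(**proj_config["kwargs"])
    print(type(projection).__name__)  # LambertConformal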
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
new file mode 100644
index 00000000..35e53004
--- /dev/null
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -0,0 +1,373 @@
+# Standard library
+import datetime as dt
+import glob
+import os
+import re
+from pathlib import Path
+
+# Third-party
+import dask.delayed
+import numpy as np
+import torch
+from xarray.core.dataarray import DataArray
+import parse
+import dask
+import dask.array
+import xarray as xr
+
+# First-party
+
+from ..base import BaseCartesianDatastore
+from .config import NpyConfig
+
+STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
+TOA_SW_DOWN_FLUX_FILENAME_FORMAT = "nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy"
+COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
+
+
+
+class NumpyFilesDatastore(BaseCartesianDatastore):
+ __doc__ = f"""
+ Represents a dataset stored as numpy files on disk. The dataset is assumed
+ to be stored in a directory structure where each sample is stored in a
+ separate file. The file-name format is assumed to be '{STATE_FILENAME_FORMAT}'
+
+ The MEPS dataset is organised into three splits: train, val, and test. Each
+ split has a set of files which are:
+
+ - `{STATE_FILENAME_FORMAT}`:
+ The state variables for a forecast started at `analysis_time` with
+ member id `member_id`. The dimensions of the array are
+ `[forecast_timestep, y, x, feature]`.
+
+ - `{TOA_SW_DOWN_FLUX_FILENAME_FORMAT}`:
+ The top-of-atmosphere downwelling shortwave flux at `time`. The
+ dimensions of the array are `[forecast_timestep, y, x]`.
+
+ - `{COLUMN_WATER_FILENAME_FORMAT}`:
+ The column water at `time`. The dimensions of the array are
+ `[y, x]`.
+
+
+ Folder structure:
+
+ meps_example_reduced
+ ├── data_config.yaml
+ ├── samples
+ │ ├── test
+ │ │ ├── nwp_2022090100_mbr000.npy
+ │ │ ├── nwp_2022090100_mbr001.npy
+ │ │ ├── nwp_2022090112_mbr000.npy
+ │ │ ├── nwp_2022090112_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090100.npy
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090112.npy
+ │ │ ├── ...
+ │ │ ├── wtr_2022090100.npy
+ │ │ ├── wtr_2022090112.npy
+ │ │ └── ...
+ │ ├── train
+ │ │ ├── nwp_2022040100_mbr000.npy
+ │ │ ├── nwp_2022040100_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_2022040112_mbr000.npy
+ │ │ ├── nwp_2022040112_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
+ │ │ ├── ...
+ │ │ ├── wtr_2022040100.npy
+ │ │ ├── wtr_2022040112.npy
+ │ │ └── ...
+ │ └── val
+ │ ├── nwp_2022060500_mbr000.npy
+ │ ├── nwp_2022060500_mbr001.npy
+ │ ├── ...
+ │ ├── nwp_2022060512_mbr000.npy
+ │ ├── nwp_2022060512_mbr001.npy
+ │ ├── ...
+ │ ├── nwp_toa_downwelling_shortwave_flux_2022060500.npy
+ │ ├── nwp_toa_downwelling_shortwave_flux_2022060512.npy
+ │ ├── ...
+ │ ├── wtr_2022060500.npy
+ │ ├── wtr_2022060512.npy
+ │ └── ...
+ └── static
+ ├── border_mask.npy
+ ├── diff_mean.pt
+ ├── diff_std.pt
+ ├── flux_stats.pt
+ ├── grid_features.pt
+ ├── nwp_xy.npy
+ ├── parameter_mean.pt
+ ├── parameter_std.pt
+ ├── parameter_weights.npy
+ └── surface_geopotential.npy
+
+ For the MEPS dataset:
+ N_t' = 65
+ N_t = 65//subsample_step (= 21 for 3h steps)
+ dim_y = 268
+ dim_x = 238
+ N_grid = 268x238 = 63784
+ d_features = 17 (d_features' = 18)
+ d_forcing = 5
+
+ For the MEPS reduced dataset:
+ N_t' = 65
+ N_t = 65//subsample_step (= 21 for 3h steps)
+ dim_y = 134
+ dim_x = 119
+ N_grid = 134x119 = 15946
+ d_features = 8
+ d_forcing = 1
+ """
+ is_ensemble = True
+
+ def __init__(
+ self,
+ root_path,
+ ):
+ # XXX: This should really be in the config file, not hard-coded in this class
+ self._num_timesteps = 65
+ self._step_length = 3 # 3 hours
+ self._num_ensemble_members = 2
+
+ self.root_path = Path(root_path)
+ self._config = NpyConfig.from_file(self.root_path / "data_config.yaml")
+ pass
+
+ def get_dataarray(self, category: str, split: str) -> DataArray:
+ """
+ Get the data array for the given category and split of data. If the category
+ is 'state', the data array will be a concatenation of the data arrays for all
+ ensemble members. The data will be loaded as a dask array, so that the data
+ isn't actually loaded until it's needed.
+
+ Parameters
+ ----------
+ category : str
+ The category of the data to load. One of 'state', 'forcing', or 'static'.
+ split : str
+ The dataset split to load the data for. One of 'train', 'val', or 'test'.
+
+ Returns
+ -------
+ xr.DataArray
+ The data array for the given category and split, with dimensions per category:
+ state: `[time, analysis_time, grid_index, feature, ensemble_member]`
+ forcing & static: `[time, analysis_time, grid_index, feature]`
+ """
+ if category == "state":
+ # for the state category, we need to load all ensemble members
+ da = xr.concat(
+ [
+ self._get_single_timeseries_dataarray(category=category, split=split, member=member)
+ for member in range(self._num_ensemble_members)
+ ],
+ dim="ensemble_member"
+ )
+ else:
+ da = self._get_single_timeseries_dataarray(category=category, split=split)
+ return da
+
+ def _get_single_timeseries_dataarray(self, category: str, split: str, member: int = None) -> DataArray:
+ """
+ Get the data array spanning the complete time series for a given category and split
+ of data. If the category is 'state', the member argument should be specified to select
+ the ensemble member to load. The data will be loaded as a dask array, so that the data
+ isn't actually loaded until it's needed.
+
+ Parameters
+ ----------
+ category : str
+ The category of the data to load. One of 'state', 'forcing', or 'static'.
+ split : str
+ The dataset split to load the data for. One of 'train', 'val', or 'test'.
+ member : int, optional
+ The ensemble member to load. Only applicable for the 'state' category.
+
+ Returns
+ -------
+ xr.DataArray
+ The data array for the given category and split, with dimensions
+ `[time, analysis_time, grid_index, feature]` for all categories of data
+ """
+ assert split in ("train", "val", "test"), "Unknown dataset split"
+
+ if member is not None and category != "state":
+ raise ValueError("Member can only be specified for the 'state' category")
+
+ # XXX: we here assume that the grid shape is the same for all categories
+ grid_shape = self.grid_shape_state
+
+ analysis_times = self._get_analysis_times(split=split)
+ fp_split = self.root_path / "samples" / split
+
+ file_dims = ["time", "y", "x", "feature"]
+ elapsed_time = self.step_length * np.arange(self._num_timesteps) * np.timedelta64(1, "h")
+ arr_shape = [len(elapsed_time)] + grid_shape
+ coords = dict(
+ analysis_time=analysis_times,
+ time=elapsed_time,
+ y=np.arange(grid_shape[0]),
+ x=np.arange(grid_shape[1]),
+ )
+
+ extra_kwargs = {}
+ add_feature_dim = False
+ if category == "state":
+ filename_format = STATE_FILENAME_FORMAT
+ # only select one member for now
+ extra_kwargs["member_id"] = member
+ # state has multiple features
+ num_state_variables = self.get_num_data_vars
+ arr_shape += [num_state_variables]
+ coords["feature"] = self.get_vars_names(category="state")
+ elif category == "forcing":
+ filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
+ arr_shape += [1]
+ # XXX: this should really be saved in the data-config
+ coords["feature"] = ["toa_downwelling_shortwave_flux"]
+ add_feature_dim = True
+ elif category == "static":
+ filename_format = COLUMN_WATER_FILENAME_FORMAT
+ arr_shape += [1]
+ # XXX: this should really be saved in the data-config
+ coords["feature"] = ["column_water"]
+ add_feature_dim = True
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ filepaths = [
+ fp_split / filename_format.format(analysis_time=analysis_time, **extra_kwargs)
+ for analysis_time in analysis_times
+ ]
+
+ # use dask.delayed to load the numpy files, so that loading isn't
+ # done until the data is actually needed
+ @dask.delayed
+ def _load_np(fp):
+ arr = np.load(fp)
+ if add_feature_dim:
+ arr = arr[..., np.newaxis]
+ return arr
+
+ arrays = [
+ dask.array.from_delayed(
+ _load_np(fp), shape=arr_shape, dtype=np.float32
+ ) for fp in filepaths
+ ]
+
+ arr_all = dask.array.stack(arrays, axis=0)
+
+ da = xr.DataArray(
+ arr_all,
+ dims=["analysis_time"] + file_dims,
+ coords=coords,
+ name=category
+ )
+
+ # stack the [x, y] dimensions into a `grid_index` dimension
+ da = da.stack(grid_index=["y", "x"])
+
+ if category == "forcing":
+ # add datetime forcing as a feature
+ # to do this we create a forecast time variable which has the dimensions of
+ # (analysis_time, time) with values that are the actual forecast time of each
+            # time step. By calling .chunk({"time": 1}) this time variable is turned into
+ # a dask array and so execution of the calculation is delayed until the feature
+ # values are actually used.
+ da_forecast_time = (da.time + da.analysis_time).chunk({"time": 1})
+ da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
+ da = xr.concat([da, da_datetime_forcing_features], dim="feature")
+
+ return da
+
+ def _get_analysis_times(self, split):
+ """
+ Get the analysis times for the given split by parsing the filenames
+ of all the files found for the given split.
+
+ Parameters
+ ----------
+ split : str
+ The dataset split to get the analysis times for.
+
+ Returns
+ -------
+ List[dt.datetime]
+ The analysis times for the given split.
+ """
+ pattern = re.sub(r'{analysis_time:[^}]*}', '*', STATE_FILENAME_FORMAT)
+ pattern = re.sub(r'{member_id:[^}]*}', '*', pattern)
+
+ sample_dir = self.root_path / "samples" / split
+ sample_files = sample_dir.glob(pattern)
+ times = []
+ for fp in sample_files:
+ name_parts = parse.parse(STATE_FILENAME_FORMAT, fp.name)
+ times.append(name_parts["analysis_time"])
+
+ return times
+
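A small sketch of the filename handling in `_get_analysis_times` above: the glob pattern is derived from STATE_FILENAME_FORMAT with re.sub, and `parse` recovers the analysis time from a matching name, as in the method itself. The example filename is made up:

    import re

    import parse

    STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"

    pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
    pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
    print(pattern)  # nwp_*_mbr*.npy

    name_parts = parse.parse(STATE_FILENAME_FORMAT, "nwp_2022090100_mbr000.npy")
    print(name_parts["analysis_time"])  # 2022-09-01 00:00 (as a datetime)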
+ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
+ da_hour_angle = da_time.dt.hour / 12 * np.pi
+ da_year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi
+
+ da_datetime_forcing = xr.concat(
+ (
+ np.sin(da_hour_angle),
+ np.cos(da_hour_angle),
+ np.sin(da_year_angle),
+ np.cos(da_year_angle),
+ ),
+ dim="feature",
+ )
+ da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1]
+ da_datetime_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
+
+ return da_datetime_forcing
+
+    def get_vars_units(self, category: str) -> list[str]:
+ if category == "state":
+ return self._config["dataset"]["var_units"]
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+    def get_vars_names(self, category: str) -> list[str]:
+ if category == "state":
+ return self._config["dataset"]["var_names"]
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ @property
+ def get_num_data_vars(self) -> int:
+ return len(self.get_vars_names(category="state"))
+
+ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
+ arr = np.load(self.root_path / "static" / "nwp_xy.npy")
+
+        assert arr.shape[0] == 2, "Expected leading dimension of size 2 (x and y)"
+ assert arr.shape[1:] == tuple(self.grid_shape_state), "Unexpected shape"
+
+ if stacked:
+ return arr
+ else:
+ return arr[0], arr[1]
+
+ @property
+ def step_length(self):
+ return self._step_length
+
+ @property
+ def coords_projection(self):
+ return self._config.coords_projection
+
+ @property
+ def grid_shape_state(self):
+ return self._config.grid_shape_state
+
+ @property
+ def boundary_mask(self):
+ return np.load(self.root_path / "static" / "border_mask.npy")
\ No newline at end of file
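A self-contained check of the datetime-forcing encoding in `_calc_datetime_forcing_features` above; the times are made up and only the sin/cos construction and the rescaling to [0, 1] mirror the patch:

    import numpy as np
    import pandas as pd
    import xarray as xr

    da_time = xr.DataArray(
        pd.to_datetime(
            ["2022-09-01T00", "2022-09-01T03", "2022-09-01T06", "2022-09-01T09"]
        ),
        dims=("time",),
    )
    hour_angle = da_time.dt.hour / 12 * np.pi
    year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi
    features = xr.concat(
        [np.sin(hour_angle), np.cos(hour_angle), np.sin(year_angle), np.cos(year_angle)],
        dim="feature",
    )
    features = (features + 1) / 2  # rescale from [-1, 1] to [0, 1]
    features["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
    print(features.shape)  # (4, 4), i.e. (feature, time)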
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 0c8422f3..7c206028 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,8 @@
import wandb
# First-party
-from neural_lam import config, metrics, vis
+from neural_lam import metrics, vis
+from neural_lam.datastore.multizarr import config
class ARModel(pl.LightningModule):
@@ -26,9 +27,13 @@ def __init__(self, args):
self.save_hyperparameters()
self.args = args
self.data_config = config.Config.from_file(args.data_config)
+
+ num_state_vars = self.data_config.num_data_vars("state")
+ num_forcing_vars = self.data_config.num_data_vars("forcing")
+ da_static_features = self.data_config.process_dataset("static")
# Load static features for grid/data
- static = self.data_config.process_dataset("static")
+        static = da_static_features
self.register_buffer(
"grid_static_features",
torch.tensor(static.values, dtype=torch.float32),
@@ -43,13 +48,12 @@ def __init__(self, args):
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
- self.grid_output_dim = self.data_config.num_data_vars("state")
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.data_config.num_data_vars("state")
+ self.grid_output_dim = 2 * num_state_vars
else:
# Pred. dim. in grid cell
- self.grid_output_dim = self.data_config.num_data_vars("state")
+ self.grid_output_dim = num_state_vars
# Store constant per-variable std.-dev. weighting
# NOTE that this is the inverse of the multiplicative weighting
# in wMSE/wMAE
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index d14b2fd8..988c5c9a 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -3,50 +3,45 @@
import torch
# First-party
-from neural_lam import config
+from neural_lam.datastore.multizarr import config
+from neural_lam.datastore.base import BaseDatastore
class WeatherDataset(torch.utils.data.Dataset):
"""
Dataset class for weather data.
- This class loads and processes weather data from zarr files based on the
- provided configuration. It supports splitting the data into train,
- validation, and test sets.
+ This class loads and processes weather data from a given datastore.
"""
def __init__(
self,
+ datastore: BaseDatastore,
split="train",
ar_steps=3,
batch_size=4,
standardize=True,
control_only=False,
- data_config="neural_lam/data_config.yaml",
):
super().__init__()
- assert split in (
- "train",
- "val",
- "test",
- ), "Unknown dataset split"
-
self.split = split
self.batch_size = batch_size
self.ar_steps = ar_steps
self.control_only = control_only
- self.data_config = config.Config.from_file(data_config)
+ self.datastore = datastore
- self.state = self.data_config.process_dataset("state", self.split)
- assert self.state is not None, "State dataset not found"
- self.forcing = self.data_config.process_dataset("forcing", self.split)
- self.state_times = self.state.time.values
+ self.da_state = self.datastore.get_dataarray(category="state", split=self.split)
+ self.da_forcing = self.datastore.get_dataarray(category="forcing", split=self.split)
+ self.state_times = self.da_state.time.values
# Set up for standardization
# TODO: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
+ self.da_state_mean, self.da_state_std = self.datastore.get_normalization_dataarray(category="state")
+ self.da_forcing_mean, self.da_forcing_std = self.datastore.get_normalization_dataarray(category="forcing")
+
state_stats = self.data_config.load_normalization_stats(
"state", datatype="torch"
)
@@ -55,7 +50,7 @@ def __init__(
state_stats["state_std"],
)
- if self.forcing is not None:
+ if self.da_forcing is not None:
forcing_stats = self.data_config.load_normalization_stats(
"forcing", datatype="torch"
)
@@ -66,22 +61,33 @@ def __init__(
def __len__(self):
# Skip first and last time step
- return len(self.state.time) - self.ar_steps
+ return len(self.da_state.time) - self.ar_steps
def __getitem__(self, idx):
+ """
+ Return a single training sample, which consists of the initial states,
+ target states, forcing and batch times.
+ """
+ # TODO: could use xr.DataArrays instead of torch.tensor when normalizing
+        # so that we can make use of xarray's broadcasting capabilities. This would
+ # allow us to easily normalize only with global, grid-wise or some other
+ # normalization statistics. Currently, the implementation below assumes
+ # the normalization statistics are global with one scalar value (mean, std)
+ # for each feature.
+
sample = torch.tensor(
- self.state.isel(time=slice(idx, idx + self.ar_steps)).values,
+ self.da_state.isel(time=slice(idx, idx + self.ar_steps)).values,
dtype=torch.float32,
)
forcing = (
torch.tensor(
- self.forcing.isel(
+ self.da_forcing.isel(
time=slice(idx + 2, idx + self.ar_steps)
).values,
dtype=torch.float32,
)
- if self.forcing is not None
+ if self.da_forcing is not None
else torch.tensor([], dtype=torch.float32)
)
@@ -89,17 +95,22 @@ def __getitem__(self, idx):
target_states = sample[2:]
batch_times = (
- self.state.isel(time=slice(idx + 2, idx + self.ar_steps))
+ self.da_state.isel(time=slice(idx + 2, idx + self.ar_steps))
.time.values.astype(str)
.tolist()
)
if self.standardize:
- init_states = (init_states - self.state_mean) / self.state_std
- target_states = (target_states - self.state_mean) / self.state_std
-
- if self.forcing is not None:
- forcing = (forcing - self.forcing_mean) / self.forcing_std
+ state_mean = self.da_state_mean.values
+ state_std = self.da_state_std.values
+ forcing_mean = self.da_forcing_mean.values
+ forcing_std = self.da_forcing_std.values
+
+ init_states = (init_states - state_mean) / state_std
+ target_states = (target_states - state_mean) / state_std
+
+ if self.da_forcing is not None:
+ forcing = (forcing - forcing_mean) / forcing_std
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps-2, N_grid, d_features)
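A minimal sketch of the standardization step in `__getitem__` above, with made-up shapes and statistics; the per-feature mean/std broadcast over the trailing feature dimension as in the patch:

    import torch

    n_grid, n_features = 6, 3
    init_states = torch.rand(2, n_grid, n_features)
    state_mean = torch.rand(n_features)        # stand-in for da_state_mean.values
    state_std = torch.rand(n_features) + 0.5   # stand-in for da_state_std.values

    init_states = (init_states - state_mean) / state_std  # broadcasts over the last dim
    print(init_states.shape)  # torch.Size([2, 6, 3])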
diff --git a/plot_graph.py b/plot_graph.py
index 73acc801..db4dc536 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -7,7 +7,8 @@
import torch_geometric as pyg
# First-party
-from neural_lam import config, utils
+from neural_lam import utils
+from neural_lam.datastore.multizarr import config
MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.2
diff --git a/requirements.txt b/requirements.txt
index 9e506785..7c851c97 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ tueplots>=0.0.8
plotly>=5.15.0
xarray>=0.20.1
zarr>=2.10.0
-dask>=2022.0.0
+dask
pandas>=1.4.0
# for dev
pre-commit>=2.15.0
diff --git a/test_ewc.py b/test_ewc.py
new file mode 100644
index 00000000..dc4ad6fb
--- /dev/null
+++ b/test_ewc.py
@@ -0,0 +1,17 @@
+import xarray as xr
+
+
+credentials_key = "546V9NGV07UQCBM80Y47"
+credentials_secret = "8n61wiWFojIkxJM4MC5luoZNBDoitIqvHLXkXs9i"
+credentials_endpoint_url = "https://object-store.os-api.cci1.ecmwf.int"
+
+ds = xr.open_zarr(
+ "s3://danra/v0.4.0/single_levels.zarr/",
+ consolidated=True,
+ storage_options={
+ "key": credentials_key,
+ "secret": credentials_secret,
+ "client_kwargs": {"endpoint_url": credentials_endpoint_url},
+ },
+)
+print(ds)
\ No newline at end of file
diff --git a/tests/data_config.yaml b/tests/data_config.yaml
deleted file mode 100644
index b36098e2..00000000
--- a/tests/data_config.yaml
+++ /dev/null
@@ -1,162 +0,0 @@
-name: danra
-state:
- zarrs:
- - path: "data/danra/single_levels.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- - path: "data/danra/height_levels.zarr"
- dims:
- time: time
- level: altitude
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- surface_vars:
- - u10m
- - v10m
- - t2m
- surface_units:
- - m/s
- - m/s
- - K
- atmosphere_vars:
- - u
- - v
- - t
- atmosphere_units:
- - m/s
- - m/s
- - K
- levels:
- - 100
-forcing:
- zarrs:
- - path: "data/danra/single_levels.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- - path: "data/forcings.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- surface_vars:
- - cape_column # just as a technical test
- - icei0m
- - vis0m
- - xhail0m
- - hour_cos
- - hour_sin
- - year_cos
- - year_sin
- surface_units:
- - J/kg
- - kg/m^2 # just as a technical test :)
- - m
- - m
- - ""
- - ""
- - ""
- - ""
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
- window: 3 # Number of time steps to use for forcing (odd)
-static:
- zarrs:
- - path: "data/danra/single_levels.zarr"
- dims:
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- surface_vars:
- - pres0m # just as a technical test
- surface_units:
- - Pa
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
-boundary:
- zarrs: # This is not used currently, but soon ERA% boundaries will be used
- - path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
- dims:
- time: time
- level: level
- x: longitude
- y: latitude
- grid: null
- lat_lon_names:
- lon: longitude
- lat: latitude
- mask:
- path: "data/boundary_mask.zarr"
- dims:
- x: x
- y: y
- surface_vars:
- - t2m
- surface_units:
- - K
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
- window: 3
-utilities:
- normalization:
- zarrs:
- - path: "data/normalization.zarr"
- stats_vars:
- state_mean: state_mean
- state_std: state_std
- forcing_mean: forcing_mean
- forcing_std: forcing_std
- diff_mean: diff_mean
- diff_std: diff_std
- combined_stats:
- - vars:
- - icei0m
- - vis0m
- - vars:
- - cape_column
- - xhail0m
-grid_shape_state:
- y: 589
- x: 789
-splits:
- train:
- start: 1990-09-01T00
- end: 1990-09-01T02
- val:
- start: 1990-09-11T00
- end: 1990-09-11T02
- test:
- start: 1990-09-11T00
- end: 1990-09-11T02
-projection:
- class: LambertConformal # Name of class in cartopy.crs
- kwargs:
- central_longitude: 6.22
- central_latitude: 56.0
- standard_parallels: [47.6, 64.4]
diff --git a/tests/datastore_configs/mllam.example.danra.yaml b/tests/datastore_configs/mllam.example.danra.yaml
new file mode 100644
index 00000000..2f1cfddf
--- /dev/null
+++ b/tests/datastore_configs/mllam.example.danra.yaml
@@ -0,0 +1,74 @@
+schema_version: v0.2.0
+dataset_version: v0.1.0
+
+architecture:
+ input_variables:
+ static: [grid_index, static_feature]
+ state: [time, grid_index, state_feature]
+ forcing: [time, grid_index, forcing_feature]
+ input_coord_ranges:
+ time:
+ start: 1990-09-03T00:00
+ end: 1990-09-09T00:00
+ step: PT3H
+ chunking:
+ time: 6
+
+inputs:
+ danra_height_levels:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr
+ dims: [time, x, y, altitude]
+ variables:
+ u:
+ altitude:
+ values: [100,]
+ units: m
+ v:
+ altitude:
+ values: [100, ]
+ units: m
+ dim_mapping:
+ time:
+ method: rename
+ dim: time
+ state_feature:
+ method: stack_variables_by_var_name
+ dims: [altitude]
+ name_format: f"{var_name}{altitude}m"
+ grid_index:
+ method: stack
+ dims: [x, y]
+ target_architecture_variable: state
+
+ danra_surface:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
+ dims: [time, x, y]
+ variables:
+ # shouldn't really be using sea-surface pressure as "forcing", but don't
+      # have radiation variables in danra yet
+ - pres_seasurface
+ dim_mapping:
+ time:
+ method: rename
+ dim: time
+ grid_index:
+ method: stack
+ dims: [x, y]
+ forcing_feature:
+ method: stack_variables_by_var_name
+ name_format: f"{var_name}"
+ target_architecture_variable: forcing
+
+ danra_lsm:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr
+ dims: [x, y]
+ variables:
+ - lsm
+ dim_mapping:
+ grid_index:
+ method: stack
+ dims: [x, y]
+ static_feature:
+ method: stack_variables_by_var_name
+ name_format: f"{var_name}"
+ target_architecture_variable: static
diff --git a/neural_lam/data_config.yaml b/tests/datastore_configs/multizarr.danra.yaml
similarity index 86%
rename from neural_lam/data_config.yaml
rename to tests/datastore_configs/multizarr.danra.yaml
index 63756002..27d84f33 100644
--- a/neural_lam/data_config.yaml
+++ b/tests/datastore_configs/multizarr.danra.yaml
@@ -1,7 +1,7 @@
name: danra
state:
zarrs:
- - path: "data/danra/single_levels.zarr"
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
level: null
@@ -11,7 +11,7 @@ state:
lat_lon_names:
lon: lon
lat: lat
- - path: "data/danra/height_levels.zarr"
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
dims:
time: time
level: altitude
@@ -41,7 +41,7 @@ state:
- 100
forcing:
zarrs:
- - path: "data/danra/single_levels.zarr"
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
level: null
@@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- - path: "data/forcings.zarr"
+ - path: "data/danra_multizarr/datetime_forcings.zarr"
dims:
time: time
level: null
@@ -82,7 +82,7 @@ forcing:
window: 3 # Number of time steps to use for forcing (odd)
static:
zarrs:
- - path: "data/danra/single_levels.zarr"
+ - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
level: null
x: x
@@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- - path: "data/normalization.zarr"
+ - path: "data/danra_multizarr/normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
new file mode 100644
index 00000000..6d11b7f9
--- /dev/null
+++ b/tests/test_mllam_dataset.py
@@ -0,0 +1,8 @@
+from neural_lam.datastore import MLLAMDatastore
+from neural_lam.weather_dataset import WeatherDataset
+
+
+def test_mllam():
+ config_path = "tests/datastore_configs/mllam.example.danra.yaml"
+ datastore = MLLAMDatastore(config_path=config_path)
+ dataset = WeatherDataset(datastore=datastore)
\ No newline at end of file
diff --git a/tests/test_analysis_dataset.py b/tests/test_multizarr_dataset.py
similarity index 71%
rename from tests/test_analysis_dataset.py
rename to tests/test_multizarr_dataset.py
index d7191a01..064abf3f 100644
--- a/tests/test_analysis_dataset.py
+++ b/tests/test_multizarr_dataset.py
@@ -3,7 +3,8 @@
# First-party
from create_mesh import main as create_mesh
-from neural_lam.config import Config
+from neural_lam.datastore.multizarr import MultiZarrDatastore
+# from neural_lam.datasets.config import Config
from neural_lam.weather_dataset import WeatherDataset
# Disable weights and biases to avoid unnecessary logging
@@ -13,28 +14,29 @@
def test_load_analysis_dataset():
# TODO: Access rights should be fixed for pooch to work
- if not os.path.exists("data/danra"):
- print("Please download test data first: python docs/download_danra.py")
- return
- data_config_file = "tests/data_config.yaml"
- config = Config.from_file(data_config_file)
+ datastore = MultiZarrDatastore(
+ config_path="tests/datastore_configs/multizarr.danra.yaml"
+ )
- var_state_names = config.vars_names("state")
- var_state_units = config.vars_units("state")
- num_state_vars = config.num_data_vars("state")
+ var_state_names = datastore.get_vars_names(category="state")
+ var_state_units = datastore.get_vars_units(category="state")
+ num_state_vars = datastore.get_num_data_vars(category="state")
assert len(var_state_names) == len(var_state_units) == num_state_vars
- var_forcing_names = config.vars_names("forcing")
- var_forcing_units = config.vars_units("forcing")
- num_forcing_vars = config.num_data_vars("forcing")
+ var_forcing_names = datastore.get_vars_names(category="forcing")
+ var_forcing_units = datastore.get_vars_units(category="forcing")
+ num_forcing_vars = datastore.get_num_data_vars(category="forcing")
assert len(var_forcing_names) == len(var_forcing_units) == num_forcing_vars
+
+ stats = datastore.get_normalization_stats(category="state")
+
# Assert dataset can be loaded
- ds = config.open_zarrs("state")
+ ds = datastore.get_dataarray(category="state")
grid = ds.sizes["y"] * ds.sizes["x"]
- dataset = WeatherDataset(split="train", ar_steps=3, standardize=False)
+ dataset = WeatherDataset(datastore=datastore, split="train", ar_steps=3, standardize=True)
batch = dataset[0]
# return init_states, target_states, forcing, batch_times
# init_states: (2, N_grid, d_features)
diff --git a/tests/test_forecast_dataset.py_ b/tests/test_npy_forecast_dataset.py
similarity index 79%
rename from tests/test_forecast_dataset.py_
rename to tests/test_npy_forecast_dataset.py
index f91170c9..230485ec 100644
--- a/tests/test_forecast_dataset.py_
+++ b/tests/test_npy_forecast_dataset.py
@@ -3,12 +3,13 @@
# Third-party
import pooch
+import pytest
# First-party
from create_mesh import main as create_mesh
-from neural_lam.config import Config
-from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.datastore.npyfiles import NumpyFilesDatastore
+from neural_lam.datastore.multizarr import MultiZarrDatastore
from train_model import main as train_model
# Disable weights and biases to avoid unnecessary logging
@@ -25,7 +26,8 @@
)
-def test_retrieve_data_ewc():
+@pytest.fixture(scope="session", autouse=True)
+def ewc_testdata_path():
# Download and unzip test data into data/meps_example_reduced
pooch.retrieve(
url=S3_FULL_PATH,
@@ -34,13 +36,30 @@ def test_retrieve_data_ewc():
path="data",
fname="meps_example_reduced.zip",
)
+
+ return "data/meps_example_reduced"
-def test_load_reduced_meps_dataset():
- # The data_config.yaml file is downloaded and extracted in
- # test_retrieve_data_ewc together with the dataset itself
- data_config_file = "data/meps_example_reduced/data_config.yaml"
- dataset_name = "meps_example_reduced"
+def test_load_reduced_meps_dataset(ewc_testdata_path):
+ datastore = NumpyFilesDatastore(
+ root_path=ewc_testdata_path
+ )
+ datastore = MultiZarrDatastore(
+ config_path="tests/data_config.yaml"
+ )
+
+ datastore.get_xy(category="state", stacked=True)
+
+ import matplotlib.pyplot as plt
+ da = datastore.get_dataarray(category="forcing", split="train").unstack("grid_index")
+ da.isel(analysis_time=0, feature=-1, time=slice(0, 4)).plot(col="time", col_wrap=4)
+ plt.show()
+
+ da = datastore.get_dataarray(category="state", split="train").unstack("grid_index")
+ da.isel(analysis_time=0, feature=0, time=slice(0, 4)).plot(col="time", row="ensemble_member")
+ plt.show()
+
+ import ipdb; ipdb.set_trace()
dataset = WeatherDataset(dataset_name="meps_example_reduced")
config = Config.from_file(data_config_file)
From c52f98ee204ba599699b237e2a610d8e9bd9c4d6 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sat, 6 Jul 2024 13:43:58 +0200
Subject: [PATCH 102/273] npy mllam nearly done
---
.gitignore | 3 +
neural_lam/datastore/base.py | 75 ++++-
neural_lam/datastore/mllam.py | 129 ++++++--
neural_lam/datastore/multizarr/store.py | 4 +-
neural_lam/datastore/npyfiles/store.py | 289 +++++++++++++-----
neural_lam/models/ar_model.py | 40 +--
neural_lam/models/base_graph_model.py | 4 +-
neural_lam/models/graph_lam.py | 4 +-
neural_lam/weather_dataset.py | 245 +++++++++++----
test_ewc.py | 17 --
.../example.danra.yaml} | 28 +-
.../data_config.yaml} | 4 +-
tests/test_mllam_dataset.py | 34 ++-
tests/test_multizarr_dataset.py | 5 +-
tests/test_npy_forecast_dataset.py | 49 ++-
15 files changed, 682 insertions(+), 248 deletions(-)
delete mode 100644 test_ewc.py
rename tests/datastore_configs/{mllam.example.danra.yaml => mllam/example.danra.yaml} (77%)
rename tests/datastore_configs/{multizarr.danra.yaml => multizarr/data_config.yaml} (96%)
diff --git a/.gitignore b/.gitignore
index 43968c74..05de13e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -76,3 +76,6 @@ tags
# Coc configuration directory
.vim
.vscode
+
+# macos
+.DS_Store
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index abcaff95..2c7470fc 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -1,10 +1,8 @@
import cartopy.crs as ccrs
import numpy as np
-import pandas as pd
-import torch
import xarray as xr
-from typing import List, Dict
+from typing import List, Dict, Union
import abc
import dataclasses
@@ -83,15 +81,40 @@ def get_num_data_vars(self, category: str) -> int:
The number of data variables.
"""
pass
+
+
+ @abc.abstractmethod
+ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the normalization dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in the category.
+ For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
+ `state_diff_std` variable for the one-step differences of the state variables. The
+        returned dataarray should at least have dimensions of `({category}_feature)`, but can
+        also include, for example, `grid_index` (if the normalisation is done per grid
+        point).
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The normalization dataarray for the given category, with variables for the mean
+ and standard deviation of the variables (and differences for state variables).
+ """
+ pass
@abc.abstractmethod
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
"""
- Return the processed dataset for the given category and test/train/val-split that covers
- the entire timeline of the dataset.
- The returned dataarray is expected to at minimum have dimensions of `(time, grid_index, feature)` so
+ Return the processed data (as a single `xr.DataArray`) for the given category and
+ test/train/val-split that covers the entire timeline of the dataset.
+ The returned dataarray is expected to at minimum have dimensions of `(time, grid_index, {category}_feature)` so
that any spatial dimensions have been stacked into a single dimension and all variables
- and levels have been stacked into a single feature dimension.
+ and levels have been stacked into a single feature dimension named by the `category` of data being loaded.
Any additional dimensions (for example `ensemble_member` or `analysis_time`) should be kept as separate
dimensions in the dataarray, and `WeatherDataset` will handle the sampling of the data.
@@ -148,6 +171,8 @@ class BaseCartesianDatastore(BaseDatastore):
- `get_xy_extent` (method): Return the extent of the x, y coordinates for a given category of data.
- `get_xy` (method): Return the x, y coordinates of the dataset.
"""
+
+ CARTESIAN_COORDS = ["y", "x"]
@property
@abc.abstractmethod
@@ -213,4 +238,38 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
- return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
\ No newline at end of file
+ return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+
+ def unstack_grid_coords(self, da_or_ds: Union[xr.DataArray, xr.Dataset]) -> Union[xr.DataArray, xr.Dataset]:
+ """
+        Unstack the `grid_index` dimension back into separate `x` and `y` dimensions (the names
+        can be set by the `CARTESIAN_COORDS` attribute) to recover the 2D grid.
+
+ Parameters
+ ----------
+ da_or_ds : xr.DataArray or xr.Dataset
+ The dataarray or dataset to unstack the grid coordinates of.
+
+ Returns
+ -------
+ xr.DataArray or xr.Dataset
+ The dataarray or dataset with the grid coordinates unstacked.
+ """
+ return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack("grid_index")
+
+ def stack_grid_coords(self, da_or_ds: Union[xr.DataArray, xr.Dataset]) -> Union[xr.DataArray, xr.Dataset]:
+ """
+        Stack the spatial grid coordinates (by default `x` and `y`, but this can be set by the
+ `CARTESIAN_COORDS` attribute) into a single `grid_index` dimension.
+
+ Parameters
+ ----------
+ da_or_ds : xr.DataArray or xr.Dataset
+ The dataarray or dataset to stack the grid coordinates of.
+
+ Returns
+ -------
+ xr.DataArray or xr.Dataset
+ The dataarray or dataset with the grid coordinates stacked.
+ """
+ return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
\ No newline at end of file
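A round-trip sketch for the two helpers added above, on a toy DataArray; only the stack / set_index / unstack pattern mirrors the patch, the data is made up:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.random.rand(3, 4),
        dims=("y", "x"),
        coords={"y": np.arange(3), "x": np.arange(4)},
    )
    da_stacked = da.stack(grid_index=("y", "x"))     # what stack_grid_coords does
    da_flat = da_stacked.reset_index("grid_index")   # MultiIndex dropped, y/x kept as coords
    # unstack_grid_coords rebuilds the index from the y/x coords before unstacking
    da_2d = da_flat.set_index(grid_index=["y", "x"]).unstack("grid_index")
    assert da_2d.dims == ("y", "x")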
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index ee16fc18..f822dd03 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -1,4 +1,5 @@
-from typing import List
+from typing import List, Union
+from pathlib import Path
from numpy import ndarray
@@ -14,10 +15,31 @@ class MLLAMDatastore(BaseCartesianDatastore):
Datastore class for the MLLAM dataset.
"""
- def __init__(self, config_path, n_boundary_points=30):
- self._config_path = config_path
+ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+ """
+ Construct a new MLLAMDatastore from the configuration file at `config_path`. A boundary mask
+ is created with `n_boundary_points` boundary points. If `reuse_existing` is True, the dataset
+ is loaded from a zarr file if it exists, otherwise it is created from the configuration file.
+
+ Parameters
+ ----------
+ config_path : str
+ The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file`
+ method to then call `mllam_data_prep.create_dataset` to create the dataset.
+ n_boundary_points : int
+ The number of boundary points to use in the boundary mask.
+ reuse_existing : bool
+ Whether to reuse an existing dataset zarr file if it exists.
+ """
+ self._config_path = Path(config_path)
self._config = mdp.Config.from_yaml_file(config_path)
- self._ds = mdp.create_dataset(config=self._config)
+ fp_ds = self._config_path.parent / self._config_path.name.replace(".yaml", ".zarr")
+ if reuse_existing and fp_ds.exists():
+ self._ds = xr.open_zarr(fp_ds, consolidated=True)
+ else:
+ self._ds = mdp.create_dataset(config=self._config)
+ if reuse_existing:
+ self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points
def step_length(self) -> int:
@@ -28,34 +50,75 @@ def get_vars_units(self, category: str) -> List[str]:
return self._ds[f"{category}_unit"].values.tolist()
def get_vars_names(self, category: str) -> List[str]:
+ import ipdb; ipdb.set_trace()
return self._ds[f"{category}_longname"].values.tolist()
def get_num_data_vars(self, category: str) -> int:
- return len(self._ds[category].data_vars)
+ return self._ds[f"{category}_feature"].count().item()
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
- # TODO: Implement split handling in mllam-data-prep, for now we hardcode that
- # train will be the first 80%, then validation 10% and test 10%
da_category = self._ds[category]
- n_samples = len(da_category.time)
- # compute the split indices
- if split == "train":
- i_start, i_end = 0, int(0.8 * n_samples)
- elif split == "val":
- i_start, i_end = int(0.8 * n_samples), int(0.9 * n_samples)
- elif split == "test":
- i_start, i_end = int(0.9 * n_samples), n_samples
+
+ if "time" not in da_category.dims:
+ return da_category
else:
- raise ValueError(f"Unknown split {split}")
+ t_start = self._ds.splits.sel(split_name=split, split_part="start").load().item()
+ t_end = self._ds.splits.sel(split_name=split, split_part="end").load().item()
+ return da_category.sel(time=slice(t_start, t_end))
+
+ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the normalization dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in the category.
+ For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
+ `state_diff_std` variable for the one-step differences of the state variables.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The normalization dataarray for the given category, with variables for the mean
+ and standard deviation of the variables (and differences for state variables).
+ """
+ ops = ["mean", "std"]
+ split = "train"
+ stats_variables = {
+ f"{category}__{split}__{op}": f"{category}_{op}"
+ for op in ops
+ }
+ if category == "state":
+ stats_variables.update({
+ f"state__{split}__diff_{op}": f"state_diff_{op}"
+ for op in ops
+ })
+
+ ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
+ return ds_stats
- da_split = da_category.isel(time=slice(i_start, i_end))
- return da_split
@property
def boundary_mask(self) -> xr.DataArray:
- da_mask = xr.ones_like(self._ds["state"].isel(time=0).isel(variable=0))
- da_mask.isel(x=slice(0, self._n_boundary_points), y=slice(0, self._n_boundary_points)).values = 0
- return da_mask
+ """
+        Produce a 0/1 mask for the boundary points of the dataset. These sit at the edges of the
+        domain (in x/y extent) and are used to mask out the boundary points from the loss function
+        and to overwrite the boundary points in the prediction. For now the mask is created when it
+        is requested, but in the future it could be saved to the zarr file.
+
+ Returns
+ -------
+ xr.DataArray
+ A 0/1 mask for the boundary points of the dataset, where 1 is a boundary point and 0 is not.
+ """
+ ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
+ da_state_variable = ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ da_domain_allzero = xr.zeros_like(da_state_variable)
+ ds_unstacked["boundary_mask"] = da_domain_allzero.isel(x=slice(self._n_boundary_points, -self._n_boundary_points), y=slice(self._n_boundary_points, -self._n_boundary_points))
+ ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1)
+ return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
@property
def coords_projection(self) -> ccrs.Projection:
@@ -66,11 +129,35 @@ def coords_projection(self) -> ccrs.Projection:
@property
def grid_shape_state(self):
+ """
+ The shape of the cartesian grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape
+ The shape of the cartesian grid for the state variables.
+ """
return CartesianGridShape(
x=self._ds["state"].x.size, y=self._ds["state"].y.size
)
def get_xy(self, category: str, stacked: bool) -> ndarray:
+ """
+ Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray or tuple(np.ndarray, np.ndarray)
+ The x, y coordinates of the dataset with shape `(2, N_y, N_x)` if `stacked=True` or
+ a tuple of two arrays with shape `((N_y, N_x), (N_y, N_x))` if `stacked=False`.
+ """
da_x = self._ds[category].x
da_y = self._ds[category].y
if stacked:
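Stepping back to `get_normalization_dataarray` earlier in this file's diff: the renaming dictionary it builds maps the training-split statistics stored by mllam-data-prep onto per-category names. A worked example of that mapping for category="state" (no dataset needed, this only shows the resulting dict):

    ops = ["mean", "std"]
    category, split = "state", "train"
    stats_variables = {f"{category}__{split}__{op}": f"{category}_{op}" for op in ops}
    stats_variables.update(
        {f"state__{split}__diff_{op}": f"state_diff_{op}" for op in ops}
    )
    print(stats_variables)
    # {'state__train__mean': 'state_mean', 'state__train__std': 'state_std',
    #  'state__train__diff_mean': 'state_diff_mean', 'state__train__diff_std': 'state_diff_std'}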
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 38617984..1abd11af 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -383,7 +383,7 @@ def _select_stats_by_category(self, combined_stats, category):
Returns:
xr.Dataset: The normalization statistics for the dataset."""
if category == "state":
- stats = combined_stats.loc[dict(variable=self.vars_names(category))]
+ stats = combined_stats.loc[dict(variable=self.get_vars_names(category=category))]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
return stats
elif category == "forcing":
@@ -521,7 +521,7 @@ def _apply_time_split(self, dataset, split="train"):
dataset (xr.Dataset): The xarray Dataset object.
split (str): The time split to filter the dataset.
- Returns:
+ Returns:["window"]
xr.Dataset: The xarray Dataset object filtered by the time split."""
start, end = (
self._config["splits"][split]["start"],
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 35e53004..88a7cb83 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -123,6 +123,7 @@ class NumpyFilesDatastore(BaseCartesianDatastore):
d_forcing = 1
"""
is_ensemble = True
+ is_forecast = True
def __init__(
self,
@@ -134,8 +135,7 @@ def __init__(
self._num_ensemble_members = 2
self.root_path = Path(root_path)
- self._config = NpyConfig.from_file(self.root_path / "data_config.yaml")
- pass
+ self.config = NpyConfig.from_file(self.root_path / "data_config.yaml")
def get_dataarray(self, category: str, split: str) -> DataArray:
"""
@@ -155,27 +155,66 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
-------
xr.DataArray
The data array for the given category and split, with dimensions per category:
- state: `[time, analysis_time, grid_index, feature, ensemble_member]`
- forcing & static: `[time, analysis_time, grid_index, feature]`
+ state: `[time, analysis_time, grid_index, feature, ensemble_member]`
+ forcing: `[time, analysis_time, grid_index, feature]`
+ static: `[grid_index, feature]`
"""
if category == "state":
+ das = []
# for the state category, we need to load all ensemble members
- da = xr.concat(
- [
- self._get_single_timeseries_dataarray(category=category, split=split, member=member)
- for member in range(self._num_ensemble_members)
- ],
- dim="ensemble_member"
- )
+ for member in range(self._num_ensemble_members):
+ da_member = self._get_single_timeseries_dataarray(features=self.get_vars_names(category="state"), split=split, member=member)
+ das.append(da_member)
+ da = xr.concat(das, dim="ensemble_member")
+
+ elif category == "forcing":
+ # the forcing features are in separate files, so we need to load them separately
+ features = ["toa_downwelling_shortwave_flux", "column_water"]
+ das = [self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features]
+ da = xr.concat(das, dim="feature")
+
+ elif category == "static":
+ # the static features are collected in three files:
+ # - surface_geopotential
+ # - border_mask
+ # - x, y
+ das = []
+ for features in [["surface_geopotential"], ["border_mask"], ["x", "y"]]:
+ da = self._get_single_timeseries_dataarray(features=features, split=split)
+ das.append(da)
+ da = xr.concat(das, dim="feature").transpose("grid_index", "feature")
+
else:
- da = self._get_single_timeseries_dataarray(category=category, split=split)
+ raise NotImplementedError(category)
+
+ da = da.rename(dict(feature=f"{category}_feature"))
+
+ if category == "forcing":
+ # add datetime forcing as a feature
+ # to do this we create a forecast time variable which has the dimensions of
+ # (analysis_time, elapsed_forecast_time) with values that are the actual forecast time of each
+ # time step. By calling .chunk({"elapsed_forecast_time": 1}) this time variable is turned into
+ # a dask array and so execution of the calculation is delayed until the feature
+ # values are actually used.
+ da_forecast_time = (da.analysis_time + da.elapsed_forecast_time).chunk({"elapsed_forecast_time": 1})
+ da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
+ da = xr.concat([da, da_datetime_forcing_features], dim=f"{category}_feature")
+
+ da.name = category
+
+ # check that we have the right features
+ actual_features = list(da[f"{category}_feature"].values)
+ expected_features = self.get_vars_names(category=category)
+ if actual_features != expected_features:
+ raise ValueError(f"Expected features {expected_features}, got {actual_features}")
+
return da
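
As an aside, the broadcasting trick described in the comment above can be illustrated with a few made-up analysis and lead times (a sketch only; the real code additionally chunks the result along `elapsed_forecast_time` so it stays lazy):

    import numpy as np
    import xarray as xr

    analysis_time = xr.DataArray(
        np.array(["2024-01-01T00", "2024-01-01T12"], dtype="datetime64[ns]"),
        dims=["analysis_time"],
    )
    elapsed_forecast_time = xr.DataArray(
        np.arange(3) * np.timedelta64(3, "h"), dims=["elapsed_forecast_time"]
    )
    # datetime64 + timedelta64 broadcasts to a 2D array of valid times
    da_forecast_time = analysis_time + elapsed_forecast_time
    print(da_forecast_time.shape)  # (2, 3): one valid time per (analysis, lead) pair
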
- def _get_single_timeseries_dataarray(self, category: str, split: str, member: int = None) -> DataArray:
+ def _get_single_timeseries_dataarray(self, features: str, split: str, member: int = None) -> DataArray:
"""
- Get the data array spanning the complete time series for a given category and split
+ Get the data array spanning the complete time series for a given set of features and split
of data. If the category is 'state', the member argument should be specified to select
- the ensemble member to load. The data will be loaded as a dask array, so that the data
+ the ensemble member to load. The data will be loaded using dask.delayed, so that the data
isn't actually loaded until it's needed.
Parameters
@@ -191,58 +230,95 @@ def _get_single_timeseries_dataarray(self, category: str, split: str, member: in
-------
xr.DataArray
The data array for the given category and split, with dimensions
- `[time, analysis_time, grid_index, feature]` for all categories of data
+ `[elapsed_forecast_time, analysis_time, grid_index, feature]` for all categories of data
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
- if member is not None and category != "state":
+ if member is not None and features != self.get_vars_names(category="state"):
raise ValueError("Member can only be specified for the 'state' category")
# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
- analysis_times = self._get_analysis_times(split=split)
- fp_split = self.root_path / "samples" / split
+ fp_samples = self.root_path / "samples" / split
- file_dims = ["time", "y", "x", "feature"]
- elapsed_time = self.step_length * np.arange(self._num_timesteps) * np.timedelta64(1, "h")
- arr_shape = [len(elapsed_time)] + grid_shape
- coords = dict(
- analysis_time=analysis_times,
- time=elapsed_time,
- y=np.arange(grid_shape[0]),
- x=np.arange(grid_shape[1]),
- )
-
- extra_kwargs = {}
+ file_params = {}
add_feature_dim = False
- if category == "state":
+ features_vary_with_analysis_time = True
+ if features == self.get_vars_names(category="state"):
filename_format = STATE_FILENAME_FORMAT
+ file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
# only select one member for now
- extra_kwargs["member_id"] = member
- # state has multiple features
- num_state_variables = self.get_num_data_vars
- arr_shape += [num_state_variables]
- coords["feature"] = self.get_vars_names(category="state")
- elif category == "forcing":
+ file_params["member_id"] = member
+ elif features == ["toa_downwelling_shortwave_flux"]:
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
- arr_shape += [1]
- # XXX: this should really be saved in the data-config
- coords["feature"] = ["toa_downwelling_shortwave_flux"]
+ file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
add_feature_dim = True
- elif category == "static":
+ elif features == ["column_water"]:
filename_format = COLUMN_WATER_FILENAME_FORMAT
- arr_shape += [1]
- # XXX: this should really be saved in the data-config
- coords["feature"] = ["column_water"]
+ file_dims = ["y", "x", "feature"]
+ add_feature_dim = True
+ elif features == ["surface_geopotential"]:
+ filename_format = "surface_geopotential.npy"
+ file_dims = ["y", "x", "feature"]
add_feature_dim = True
+ features_vary_with_analysis_time = False
+ # XXX: surface_geopotential is the same for all splits, and so saved in static/
+ fp_samples = self.root_path / "static"
+ import ipdb; ipdb.set_trace()
+ elif features == ["border_mask"]:
+ filename_format = "border_mask.npy"
+ file_dims = ["y", "x", "feature"]
+ add_feature_dim = True
+ features_vary_with_analysis_time = False
+ # XXX: border_mask is the same for all splits, and so saved in static/
+ fp_samples = self.root_path / "static"
+ elif features == ["x", "y"]:
+ filename_format = "nwp_xy.npy"
+ file_dims = ["y", "x", "feature"]
+ features_vary_with_analysis_time = False
+ # XXX: x, y are the same for all splits, and so saved in static/
+ fp_samples = self.root_path / "static"
else:
- raise NotImplementedError(f"Category {category} not supported")
+ raise NotImplementedError(f"Reading of variables set `{features}` not supported")
+
+ if features_vary_with_analysis_time:
+ dims = ["analysis_time"] + file_dims
+ else:
+ dims = file_dims
+
+ coords = {}
+ arr_shape = []
+ for d in dims:
+ if d == "elapsed_forecast_time":
+ coord_values = self.step_length * np.arange(self._num_timesteps) * np.timedelta64(1, "h")
+ elif d == "analysis_time":
+ coord_values = self._get_analysis_times(split=split)
+ elif d == "y":
+ coord_values = np.arange(grid_shape[0])
+ elif d == "x":
+ coord_values = np.arange(grid_shape[1])
+ elif d == "feature":
+ coord_values = features
+ else:
+ raise NotImplementedError(f"Dimension {d} not supported")
- filepaths = [
- fp_split / filename_format.format(analysis_time=analysis_time, **extra_kwargs)
- for analysis_time in analysis_times
- ]
+ print(f"{d}: {len(coord_values)}")
+
+ coords[d] = coord_values
+ if d != "analysis_time":
+ # analysis_time varies across the different files, but not within a single file
+ arr_shape.append(len(coord_values))
+
+ print(f"{features}: {dims=} {file_dims=} {arr_shape=}")
+
+ if features_vary_with_analysis_time:
+ filepaths = [
+ fp_samples / filename_format.format(analysis_time=analysis_time, **file_params)
+ for analysis_time in coords["analysis_time"]
+ ]
+ else:
+ filepaths = [fp_samples / filename_format.format(**file_params)]
# use dask.delayed to load the numpy files, so that loading isn't
# done until the data is actually needed
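
The lazy-loading pattern referred to in the comment above looks roughly like the following sketch (hypothetical file names and array shape; nothing is read until `.compute()` is called):

    import dask
    import dask.array
    import numpy as np

    @dask.delayed
    def _load_np(fp):
        return np.load(fp)

    filepaths = ["t00.npy", "t01.npy"]  # placeholder paths
    arr_shape = (5, 10, 12, 1)          # per-file array shape
    arrays = [
        dask.array.from_delayed(_load_np(fp), shape=arr_shape, dtype=np.float32)
        for fp in filepaths
    ]
    arr_all = dask.array.stack(arrays, axis=0)  # (2, 5, 10, 12, 1), still lazy
    # arr_all.compute() would trigger the actual np.load calls
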
@@ -259,28 +335,21 @@ def _load_np(fp):
) for fp in filepaths
]
- arr_all = dask.array.stack(arrays, axis=0)
+ if features_vary_with_analysis_time:
+ arr_all = dask.array.stack(arrays, axis=0)
+ else:
+ arr_all = arrays[0]
- da = xr.DataArray(
- arr_all,
- dims=["analysis_time"] + file_dims,
- coords=coords,
- name=category
- )
+ # if features == ["column_water"]:
+ # # for column water, we need to repeat the array for each forecast time
+ # # first insert a new axis for the forecast time
+ # arr_all = np.expand_dims(arr_all, 1)
+ # # and then repeat
+ # arr_all = dask.array.repeat(arr_all, self._num_timesteps, axis=1)
+ da = xr.DataArray(arr_all, dims=dims, coords=coords)
# stack the [x, y] dimensions into a `grid_index` dimension
- da = da.stack(grid_index=["y", "x"])
-
- if category == "forcing":
- # add datetime forcing as a feature
- # to do this we create a forecast time variable which has the dimensions of
- # (analysis_time, time) with values that are the actual forecast time of each
- # time step. But calling .chunk({"time": 1}) this time variable is turned into
- # a dask array and so execution of the calculation is delayed until the feature
- # values are actually used.
- da_forecast_time = (da.time + da.analysis_time).chunk({"time": 1})
- da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
- da = xr.concat([da, da_datetime_forcing_features], dim="feature")
+ da = self.stack_grid_coords(da)
return da
@@ -322,22 +391,27 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
np.sin(da_year_angle),
np.cos(da_year_angle),
),
- dim="feature",
+ dim="forcing_feature",
)
da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1]
- da_datetime_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
+ da_datetime_forcing["forcing_feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
return da_datetime_forcing
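
For orientation, a standalone sketch of the cyclic hour/year encoding and the [0, 1] rescaling performed by `_calc_datetime_forcing_features` (the exact angle formulas here are an assumption for illustration, not copied from the patch):

    import numpy as np
    import pandas as pd
    import xarray as xr

    da_time = xr.DataArray(
        pd.date_range("2024-01-01", periods=4, freq="6h"), dims=["time"]
    )
    da_hour_angle = da_time.dt.hour / 24.0 * 2 * np.pi
    da_year_angle = da_time.dt.dayofyear / 365.0 * 2 * np.pi

    da_datetime_forcing = xr.concat(
        (np.sin(da_hour_angle), np.cos(da_hour_angle),
         np.sin(da_year_angle), np.cos(da_year_angle)),
        dim="feature",
    )
    da_datetime_forcing = (da_datetime_forcing + 1) / 2  # rescale [-1, 1] -> [0, 1]
    da_datetime_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
    print(da_datetime_forcing.shape)  # (4 features, 4 times)
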
def get_vars_units(self, category: str) -> torch.List[str]:
if category == "state":
- return self._config["dataset"]["var_units"]
+ return self.config["dataset"]["var_units"]
else:
raise NotImplementedError(f"Category {category} not supported")
def get_vars_names(self, category: str) -> torch.List[str]:
if category == "state":
- return self._config["dataset"]["var_names"]
+ return self.config["dataset"]["var_names"]
+ elif category == "forcing":
+ # XXX: this really shouldn't be hard-coded here, this should be in the config
+ return ["toa_downwelling_shortwave_flux", "column_water", "sin_hour", "cos_hour", "sin_year", "cos_year"]
+ elif category == "static":
+ return ["surface_geopotential", "border_mask", "x", "y"]
else:
raise NotImplementedError(f"Category {category} not supported")
@@ -362,12 +436,77 @@ def step_length(self):
@property
def coords_projection(self):
- return self._config.coords_projection
+ return self.config.coords_projection
@property
def grid_shape_state(self):
- return self._config.grid_shape_state
+ return self.config.grid_shape_state
@property
def boundary_mask(self):
- return np.load(self.root_path / "static" / "border_mask.npy")
\ No newline at end of file
+ xs, ys = self.get_xy(category="state", stacked=False)
+ assert np.all(xs[0,:] == xs[-1,:])
+ assert np.all(ys[:,0] == ys[:,-1])
+ x = xs[0,:]
+ y = ys[:,0]
+ values = np.load(self.root_path / "static" / "border_mask.npy")
+ da_mask = xr.DataArray(values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask")
+ da_mask_stacked_xy = self.stack_grid_coords(da_mask)
+ return da_mask_stacked_xy
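
As a standalone illustration of what the property above produces, here is a sketch with a made-up 5x6 grid; the `stack_grid_coords` helper is assumed to do essentially the `.stack()` call shown:

    import numpy as np
    import xarray as xr

    values = np.zeros((5, 6))
    values[0, :] = values[-1, :] = values[:, 0] = values[:, -1] = 1  # border = 1
    da_mask = xr.DataArray(values, dims=["y", "x"], name="boundary_mask")
    da_mask_stacked_xy = da_mask.stack(grid_index=("y", "x"))
    print(da_mask_stacked_xy.shape)  # (30,): one value per grid point
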
+
+
+ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the normalization dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in the category.
+ For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
+ `state_diff_std` variable for the one-step differences of the state variables.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The normalization dataarray for the given category, with variables for the mean
+ and standard deviation of the variables (and differences for state variables).
+ """
+ def load_pickled_tensor(fn):
+ return torch.load(self.root_path / "static" / fn).numpy()
+
+ mean_diff_values = None
+ std_diff_values = None
+ if category == "state":
+ mean_values = load_pickled_tensor("parameter_mean.pt")
+ std_values = load_pickled_tensor("parameter_std.pt")
+ mean_diff_values = load_pickled_tensor("diff_mean.pt")
+ std_diff_values = load_pickled_tensor("diff_std.pt")
+ elif category == "forcing":
+ flux_stats = load_pickled_tensor("flux_stats.pt") # (2,)
+ flux_mean, flux_std = flux_stats
+ # manually add hour sin/cos and day-of-year sin/cos stats for now
+ # the mean/std for column_water is hardcoded for now
+ mean_values = np.array([flux_mean, 0.34033957, 0.0, 0.0, 0.0, 0.0])
+ std_values = np.array([flux_std, 0.4661307, 1.0, 1.0, 1.0, 1.0])
+
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ feature_dim_name = f"{category}_feature"
+ variables = {
+ f"{category}_mean": (feature_dim_name, mean_values),
+ f"{category}_std": (feature_dim_name, std_values),
+ }
+
+ if mean_diff_values is not None and std_diff_values is not None:
+ variables["state_diff_mean"] = (feature_dim_name, mean_diff_values)
+ variables["state_diff_std"] = (feature_dim_name, std_diff_values)
+
+ ds_norm = xr.Dataset(
+ variables,
+ coords={ feature_dim_name: self.get_vars_names(category=category) }
+ )
+
+ return ds_norm
\ No newline at end of file
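
To make the returned structure concrete, a hypothetical example of the Dataset layout for `category="state"` with three made-up variables, and how it is typically consumed (illustrative values only):

    import numpy as np
    import xarray as xr

    ds_norm = xr.Dataset(
        {
            "state_mean": ("state_feature", np.array([280.0, 1.2, -0.3])),
            "state_std": ("state_feature", np.array([15.0, 4.0, 4.0])),
            "state_diff_mean": ("state_feature", np.zeros(3)),
            "state_diff_std": ("state_feature", np.array([0.5, 1.0, 1.0])),
        },
        coords={"state_feature": ["t2m", "u10", "v10"]},
    )
    # standardising a (time, grid_index, state_feature) DataArray then
    # broadcasts automatically along the shared feature dimension:
    #   da_standardised = (da_state - ds_norm.state_mean) / ds_norm.state_std
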
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 7c206028..6bc595d0 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -10,7 +10,7 @@
# First-party
from neural_lam import metrics, vis
-from neural_lam.datastore.multizarr import config
+from neural_lam.datastore.base import BaseDatastore
class ARModel(pl.LightningModule):
@@ -22,27 +22,33 @@ class ARModel(pl.LightningModule):
# pylint: disable=arguments-differ
# Disable to override args/kwargs from superclass
- def __init__(self, args):
+ def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int):
super().__init__()
self.save_hyperparameters()
self.args = args
- self.data_config = config.Config.from_file(args.data_config)
-
- num_state_vars = self.data_config.num_data_vars("state")
- num_forcing_vars = self.data_config.num_data_vars("forcing")
- da_static_features = self.data_config.process_dataset("static")
+        # XXX: should this be somewhere else?
+ split = "train"
+ num_state_vars = datastore.get_num_data_vars(category="state")
+ num_forcing_vars = datastore.get_num_data_vars(category="forcing")
+ da_static_features = datastore.get_dataarray(category="static", split=split)
+ da_state_stats = datastore.get_normalization_dataarray(category="state")
+ da_boundary_mask = datastore.boundary_mask
+
# Load static features for grid/data
- static = da_static_features.values
self.register_buffer(
"grid_static_features",
- torch.tensor(static.values, dtype=torch.float32),
+ torch.tensor(da_static_features.values, dtype=torch.float32),
persistent=False,
)
- state_stats = self.data_config.load_normalization_stats(
- "state", datatype="torch"
- )
+ state_stats = {
+ "state_mean": torch.tensor(da_state_stats.state_mean.values, dtype=torch.float32),
+ "state_std": torch.tensor(da_state_stats.state_std.values, dtype=torch.float32),
+ "diff_mean": torch.tensor(da_state_stats.state_diff_mean.values, dtype=torch.float32),
+ "diff_std": torch.tensor(da_state_stats.state_diff_std.values, dtype=torch.float32),
+ }
+
for key, val in state_stats.items():
self.register_buffer(key, val, persistent=False)
@@ -73,14 +79,14 @@ def __init__(self, args):
self.grid_dim = (
2 * self.grid_output_dim
+ grid_static_dim
- + self.data_config.num_data_vars("forcing")
- * self.data_config.forcing.window
+ + num_forcing_vars
+ * forcing_window_size
)
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
- boundary_mask = self.data_config.load_boundary_mask()
+ boundary_mask = torch.tensor(da_boundary_mask.values, dtype=torch.float32)
self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
@@ -88,7 +94,7 @@ def __init__(self, args):
) # (num_grid_nodes, 1), 1 for non-border
# Number of hours per pred. step
- self.step_length = self.data_config.step_length
+ self.step_length = datastore.step_length
self.val_metrics = {
"mse": [],
}
@@ -423,7 +429,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
pred_t[:, var_i],
target_t[:, var_i],
self.interior_mask[:, 0],
- self.data_config,
+ self.datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index f055b782..9f517101 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -13,8 +13,8 @@ class BaseGraphModel(ARModel):
the encode-process-decode idea.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, datastore, forcing_window_size):
+ super().__init__(args, datastore=datastore, forcing_window_size=forcing_window_size)
# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index f767fba0..45e89dff 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -15,8 +15,8 @@ class GraphLAM(BaseGraphModel):
Oskarsson et al. (2023).
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, datastore, forcing_window_size):
+ super().__init__(args, datastore, forcing_window_size)
assert (
not self.hierarchical
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 988c5c9a..d9ffb06e 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,12 +1,66 @@
+import dataclasses
+import warnings
+
# Third-party
import pytorch_lightning as pl
import torch
+import xarray as xr
+import numpy as np
# First-party
from neural_lam.datastore.multizarr import config
from neural_lam.datastore.base import BaseDatastore
+@dataclasses.dataclass
+class TrainingSample:
+ """
+ A dataclass to hold a single training sample of `ar_steps` autoregressive steps,
+ which consists of the initial states, target states, forcing and batch times. The
+    initial and target states should have `d_features` features, and the forcing should
+ have `d_windowed_forcing` features.
+
+ Parameters
+ ----------
+ init_states : torch.Tensor
+ The initial states of the training sample, shape (2, N_grid, d_features).
+ target_states : torch.Tensor
+ The target states of the training sample, shape (ar_steps, N_grid, d_features).
+ forcing : torch.Tensor
+ The forcing of the training sample, shape (ar_steps, N_grid, d_windowed_forcing).
+ batch_times : np.ndarray
+ The times of the batch, shape (ar_steps,).
+ """
+ init_states: torch.Tensor
+ target_states: torch.Tensor
+ forcing: torch.Tensor
+ batch_times: np.ndarray
+
+ def __post_init__(self):
+ """
+        Validate that the tensor shapes match between the different components of the training sample.
+
+ # init_states: (2, N_grid, d_features)
+ # target_states: (ar_steps, N_grid, d_features)
+ # forcing: (ar_steps, N_grid, d_windowed_forcing)
+ # batch_times: (ar_steps,)
+ """
+ assert self.init_states.shape[0] == 2
+ _, N_grid, d_features = self.init_states.shape
+ N_pred_steps = self.target_states.shape[0]
+
+ # check number of grid points
+        if not (self.target_states.shape[1] == self.forcing.shape[1] == N_grid):
+            raise Exception(f"Number of grid points do not match, got {self.target_states.shape[1]=} and {self.forcing.shape[1]=}, expected {N_grid=}")
+
+ # check number of features for init and target states
+ assert self.target_states.shape[2] == d_features
+
+ # check that target, forcing and batch times have the same number of prediction steps
+ if not (self.target_states.shape[0] == self.forcing.shape[0] == self.batch_times.shape[0] == N_pred_steps):
+ raise Exception(f"Number of prediction steps do not match, got {self.target_states.shape[0]=}, {self.forcing.shape[0]=} and {self.batch_times.shape[0]=}, expected {N_pred_steps=}")
+
+
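
A hypothetical usage sketch for the dataclass above (assuming `TrainingSample` is importable from `neural_lam.weather_dataset`), with made-up sizes ar_steps=3, N_grid=100, d_features=8, d_windowed_forcing=6:

    import numpy as np
    import torch

    sample = TrainingSample(
        init_states=torch.zeros(2, 100, 8),
        target_states=torch.zeros(3, 100, 8),
        forcing=torch.zeros(3, 100, 6),
        batch_times=np.zeros(3, dtype="datetime64[ns]"),
    )
    # the shape checks in __post_init__ pass here; a forcing tensor with only
    # two steps would raise "Number of prediction steps do not match"
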
class WeatherDataset(torch.utils.data.Dataset):
"""
Dataset class for weather data.
@@ -19,104 +73,169 @@ def __init__(
datastore: BaseDatastore,
split="train",
ar_steps=3,
+ forcing_window_size=3,
batch_size=4,
standardize=True,
- control_only=False,
):
super().__init__()
self.split = split
self.batch_size = batch_size
self.ar_steps = ar_steps
- self.control_only = control_only
self.datastore = datastore
self.da_state = self.datastore.get_dataarray(category="state", split=self.split)
self.da_forcing = self.datastore.get_dataarray(category="forcing", split=self.split)
- self.state_times = self.da_state.time.values
+ self.forcing_window_size = forcing_window_size
# Set up for standardization
# TODO: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
- self.da_state_mean, self.da_state_std = self.datastore.get_normalization_dataarray(category="state")
- self.da_forcing_mean, self.da_forcing_std = self.datastore.get_normalization_dataarray(category="forcing")
+ self.ds_state_stats = self.datastore.get_normalization_dataarray(category="state")
- state_stats = self.data_config.load_normalization_stats(
- "state", datatype="torch"
- )
- self.state_mean, self.state_std = (
- state_stats["state_mean"],
- state_stats["state_std"],
- )
+ self.da_state_mean = self.ds_state_stats.state_mean
+ self.da_state_std = self.ds_state_stats.state_std
if self.da_forcing is not None:
- forcing_stats = self.data_config.load_normalization_stats(
- "forcing", datatype="torch"
- )
- self.forcing_mean, self.forcing_std = (
- forcing_stats["forcing_mean"],
- forcing_stats["forcing_std"],
- )
+ self.ds_forcing_stats = self.datastore.get_normalization_dataarray(category="forcing")
+ self.da_forcing_mean = self.ds_forcing_stats.forcing_mean
+ self.da_forcing_std = self.ds_forcing_stats.forcing_std
def __len__(self):
- # Skip first and last time step
- return len(self.da_state.time) - self.ar_steps
+ if self.datastore.is_forecast:
+ # for now we simply create a single sample for each analysis time
+ # and then the next ar_steps forecast times
+ if self.datastore.is_ensemble:
+ warnings.warn(
+ "only using first ensemble member, so dataset size is effectively"
+ f" reduced by the number of ensemble members ({self.da_state.ensemble_member.size})", UserWarning
+ )
+ return self.da_state.analysis_time.size * self.da_state.ensemble_member.size
+ return self.da_state.analysis_time.size
+ else:
+ # Skip first and last time step
+ return len(self.da_state.time) - self.ar_steps
+
+ def _sample_time(self, da, idx, n_steps:int, n_timesteps_offset:int=0):
+ """
+ Produce a time slice of the given dataarray `da` (state or forcing) starting at `idx` and
+ with `n_steps` steps. The `n_timesteps_offset` parameter is used to offset the start of the
+ sample, for example to exclude the first two steps when sampling the forcing data (and to
+ produce the windowing samples of forcing data by increasing the offset for each window).
+
+ Parameters
+ ----------
+ da : xr.DataArray
+ The dataarray to sample from. This is expected to have a `time` dimension if the datastore
+            is providing analysis-only data, and `analysis_time` and `elapsed_forecast_time` dimensions
+ if the datastore is providing forecast data.
+ idx : int
+ The index of the time step to start the sample from.
+ n_steps : int
+            The number of time steps to include in the sample.
+        n_timesteps_offset : int, optional
+            The number of time steps to offset the start of the sample by (default 0).
+
+ """
+ # selecting the time slice
+ if self.datastore.is_forecast:
+ # this implies that the data will have both `analysis_time` and `elapsed_forecast_time` dimensions
+            # for forecasts we for now simply select an analysis time and then
+ # the next ar_steps forecast times
+ da = da.isel(analysis_time=idx, elapsed_forecast_time=slice(n_timesteps_offset, n_steps + n_timesteps_offset))
+ # create a new time dimension so that the produced sample has a `time` dimension, similarly
+ # to the analysis only data
+ da["time"] = da.analysis_time + da.elapsed_forecast_time
+ da = da.swap_dims({"elapsed_forecast_time": "time"})
+ else:
+ # only `time` dimension for analysis only data
+ da = da.isel(time=slice(idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset))
+ return da
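
The offset-window slicing this method enables (used below in `__getitem__` to build the forcing windows) can be pictured on a toy analysis-only time axis; this sketch just reproduces the slice arithmetic for ar_steps=3 and a window of size 3:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(10), dims=["time"])
    idx, ar_steps = 0, 3

    # state sample: 2 initial steps + ar_steps target steps
    da_state = da.isel(time=slice(idx, idx + 2 + ar_steps))  # steps 0..4
    # forcing windows: same length as the targets, offset by 2, 3, 4
    windows = [
        da.isel(time=slice(idx + 2 + n, idx + ar_steps + 2 + n)) for n in range(3)
    ]
    print([w.values.tolist() for w in windows])  # [[2, 3, 4], [3, 4, 5], [4, 5, 6]]
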
def __getitem__(self, idx):
"""
Return a single training sample, which consists of the initial states,
- target states, forcing and batch times.
- """
- # TODO: could use xr.DataArrays instead of torch.tensor when normalizing
- # so that we can make use of xrray's broadcasting capabilities. This would
- # allow us to easily normalize only with global, grid-wise or some other
- # normalization statistics. Currently, the implementation below assumes
- # the normalization statistics are global with one scalar value (mean, std)
- # for each feature.
-
- sample = torch.tensor(
- self.da_state.isel(time=slice(idx, idx + self.ar_steps)).values,
- dtype=torch.float32,
- )
+ target states, forcing and batch times.
+
+ The implementation currently uses xarray.DataArray objects for the normalisation
+        so that we can make use of xarray's broadcasting capabilities. This makes it possible
+        to normalise not only with global means, but also, for example, with per-grid-point
+        means. This code will have to be replaced if normalisation is to be done on the GPU
+        to handle different shapes of the normalisation statistics.
+
+ Parameters
+ ----------
+ idx : int
+ The index of the sample to return, this will refer to the time of the initial state.
- forcing = (
- torch.tensor(
- self.da_forcing.isel(
- time=slice(idx + 2, idx + self.ar_steps)
- ).values,
- dtype=torch.float32,
- )
- if self.da_forcing is not None
- else torch.tensor([], dtype=torch.float32)
- )
+ Returns
+ -------
+        sample : TrainingSample
+ A training sample object containing the initial states, target states, forcing and batch times.
+ The batch times are the times of the target steps.
+ """
+ # handling ensemble data
+ if self.datastore.is_ensemble:
+            # for now the strategy is to simply select a random ensemble member
+ # XXX: this could be changed to include all ensemble members by splitting `idx` into
+ # two parts, one for the analysis time and one for the ensemble member and then increasing
+ # self.__len__ to include all ensemble members
+ i_ensemble = np.random.randint(self.da_state.ensemble_member.size)
+ da_state = self.da_state.isel(ensemble_member=i_ensemble)
+ else:
+ da_state = self.da_state
+
+ if self.da_forcing is not None:
+ if "ensemble_member" in self.da_forcing.dims:
+ raise NotImplementedError("Ensemble member not yet supported for forcing data")
+ da_forcing = self.da_forcing
+ else:
+ da_forcing = xr.DataArray()
+
+ # handle time sampling in a way that is compatible with both analysis and forecast data
+ da_state = self._sample_time(da=da_state, idx=idx, n_steps=2+self.ar_steps)
- init_states = sample[:2]
- target_states = sample[2:]
+ das_forcing = []
+ for n in range(self.forcing_window_size):
+ da_ = self._sample_time(da=da_forcing, idx=idx, n_steps=self.ar_steps, n_timesteps_offset=2+n)
+ if n > 0:
+ da_ = da_.drop_vars("time")
+ das_forcing.append(da_)
+ da_forcing_windowed = xr.concat(das_forcing, dim="window_sample")
+
+ # ensure the dimensions are in the correct order
+ da_state = da_state.transpose("time", "grid_index", "state_feature")
+ da_forcing_windowed = da_forcing_windowed.transpose("time", "grid_index", "forcing_feature", "window_sample")
- batch_times = (
- self.da_state.isel(time=slice(idx + 2, idx + self.ar_steps))
- .time.values.astype(str)
- .tolist()
- )
+ da_init_states = da_state.isel(time=slice(None, 2))
+ da_target_states = da_state.isel(time=slice(2, None))
+
+ batch_times = da_forcing_windowed.time
if self.standardize:
- state_mean = self.da_state_mean.values
- state_std = self.da_state_std.values
- forcing_mean = self.da_forcing_mean.values
- forcing_std = self.da_forcing_std.values
-
- init_states = (init_states - state_mean) / state_std
- target_states = (target_states - state_mean) / state_std
+ da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std
+ da_target_states = (da_target_states - self.da_state_mean) / self.da_state_std
if self.da_forcing is not None:
- forcing = (forcing - forcing_mean) / forcing_std
+ da_forcing_windowed = (da_forcing_windowed - self.da_forcing_mean) / self.da_forcing_std
+
+ # stack the `forcing_feature` and `window_sample` dimensions into a single `forcing_feature` dimension
+ da_forcing_windowed = da_forcing_windowed.stack(forcing_feature_windowed=("forcing_feature", "window_sample"))
+
+ init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
+ target_states = torch.tensor(da_target_states.values, dtype=torch.float32)
+ forcing = torch.tensor(da_forcing_windowed.values, dtype=torch.float32)
# init_states: (2, N_grid, d_features)
- # target_states: (ar_steps-2, N_grid, d_features)
- # forcing: (ar_steps-2, N_grid, d_windowed_forcing)
- # batch_times: (ar_steps-2,)
- return init_states, target_states, forcing, batch_times
+ # target_states: (ar_steps, N_grid, d_features)
+ # forcing: (ar_steps, N_grid, d_windowed_forcing)
+ # batch_times: (ar_steps,)
+
+ return TrainingSample(
+ init_states=init_states,
+ target_states=target_states,
+ forcing=forcing,
+ batch_times=batch_times,
+ )
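
The two xarray idioms relied on in `__getitem__` above, broadcasting of normalisation statistics over a named feature dimension and collapsing the window dimension into the feature dimension, look like this in isolation (made-up sizes and names, illustration only):

    import numpy as np
    import xarray as xr

    dims = ["time", "grid_index", "forcing_feature", "window_sample"]
    da = xr.DataArray(
        np.random.rand(3, 100, 2, 4), dims=dims,
        coords={"forcing_feature": ["flux", "column_water"]},
    )
    da_mean = xr.DataArray([0.5, 0.3], dims=["forcing_feature"],
                           coords={"forcing_feature": ["flux", "column_water"]})
    da_std = xr.DataArray([0.1, 0.2], dims=["forcing_feature"],
                          coords={"forcing_feature": ["flux", "column_water"]})

    da_norm = (da - da_mean) / da_std  # broadcasts over time/grid/window
    da_windowed = da_norm.stack(
        forcing_feature_windowed=("forcing_feature", "window_sample")
    )
    print(da_windowed.shape)  # (3, 100, 8): 2 features x 4 window offsets
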
class WeatherDataModule(pl.LightningDataModule):
diff --git a/test_ewc.py b/test_ewc.py
deleted file mode 100644
index dc4ad6fb..00000000
--- a/test_ewc.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import xarray as xr
-
-
-credentials_key = "546V9NGV07UQCBM80Y47"
-credentials_secret = "8n61wiWFojIkxJM4MC5luoZNBDoitIqvHLXkXs9i"
-credentials_endpoint_url = "https://object-store.os-api.cci1.ecmwf.int"
-
-ds = xr.open_zarr(
- "s3://danra/v0.4.0/single_levels.zarr/",
- consolidated=True,
- storage_options={
- "key": credentials_key,
- "secret": credentials_secret,
- "client_kwargs": {"endpoint_url": credentials_endpoint_url},
- },
-)
-print(ds)
\ No newline at end of file
diff --git a/tests/datastore_configs/mllam.example.danra.yaml b/tests/datastore_configs/mllam/example.danra.yaml
similarity index 77%
rename from tests/datastore_configs/mllam.example.danra.yaml
rename to tests/datastore_configs/mllam/example.danra.yaml
index 2f1cfddf..3be8debb 100644
--- a/tests/datastore_configs/mllam.example.danra.yaml
+++ b/tests/datastore_configs/mllam/example.danra.yaml
@@ -1,18 +1,32 @@
schema_version: v0.2.0
dataset_version: v0.1.0
-architecture:
- input_variables:
+output:
+ variables:
static: [grid_index, static_feature]
state: [time, grid_index, state_feature]
forcing: [time, grid_index, forcing_feature]
- input_coord_ranges:
+ coord_ranges:
time:
start: 1990-09-03T00:00
end: 1990-09-09T00:00
step: PT3H
chunking:
- time: 6
+ time: 1
+ splitting_dim: time
+ splits:
+ train:
+ start: 1990-09-03T00:00
+ end: 1990-09-06T00:00
+ compute_statistics:
+ ops: [mean, std]
+ dims: [grid_index, time]
+ validation:
+ start: 1990-09-06T00:00
+ end: 1990-09-07T00:00
+ test:
+ start: 1990-09-07T00:00
+ end: 1990-09-09T00:00
inputs:
danra_height_levels:
@@ -38,7 +52,7 @@ inputs:
grid_index:
method: stack
dims: [x, y]
- target_architecture_variable: state
+ target_output_variable: state
danra_surface:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
@@ -57,7 +71,7 @@ inputs:
forcing_feature:
method: stack_variables_by_var_name
name_format: f"{var_name}"
- target_architecture_variable: forcing
+ target_output_variable: forcing
danra_lsm:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr
@@ -71,4 +85,4 @@ inputs:
static_feature:
method: stack_variables_by_var_name
name_format: f"{var_name}"
- target_architecture_variable: static
+ target_output_variable: static
diff --git a/tests/datastore_configs/multizarr.danra.yaml b/tests/datastore_configs/multizarr/data_config.yaml
similarity index 96%
rename from tests/datastore_configs/multizarr.danra.yaml
rename to tests/datastore_configs/multizarr/data_config.yaml
index 27d84f33..d46afa53 100644
--- a/tests/datastore_configs/multizarr.danra.yaml
+++ b/tests/datastore_configs/multizarr/data_config.yaml
@@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- - path: "data/danra_multizarr/datetime_forcings.zarr"
+ - path: "tests/config_examples/multizarr/datetime_forcings.zarr"
dims:
time: time
level: null
@@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- - path: "data/danra_multizarr/normalization.zarr"
+ - path: "tests/datastore_configs/multizarr/normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index 6d11b7f9..7f3c5b27 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -1,8 +1,36 @@
+import torch
+
from neural_lam.datastore import MLLAMDatastore
-from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.weather_dataset import WeatherDataset, WeatherDataModule
+from neural_lam.models.graph_lam import GraphLAM
+
+class ModelArgs:
+ output_std = True
+ loss = "mse"
+ restore_opt = False
+ n_example_pred = 1
+    graph = "multiscale" # XXX: this should be superfluous when we have already defined the model object
def test_mllam():
- config_path = "tests/datastore_configs/mllam.example.danra.yaml"
+ config_path = "tests/datastore_configs/mllam/example.danra.yaml"
datastore = MLLAMDatastore(config_path=config_path)
- dataset = WeatherDataset(datastore=datastore)
\ No newline at end of file
+ dataset = WeatherDataset(datastore=datastore)
+
+ item = dataset[0]
+
+ data_module = WeatherDataModule(
+ ar_steps_train=3,
+ ar_steps_eval=3,
+ standardize=True,
+ batch_size=2,
+ )
+
+ import ipdb
+ ipdb.set_trace()
+
+ device_name = torch.device("cuda") if torch.cuda.is_available() else "cpu"
+
+ args = ModelArgs()
+
+ model = GraphLAM(args=args, forcing_window_size=dataset.forcing_window_size, datastore=datastore)
\ No newline at end of file
diff --git a/tests/test_multizarr_dataset.py b/tests/test_multizarr_dataset.py
index 064abf3f..05d2e969 100644
--- a/tests/test_multizarr_dataset.py
+++ b/tests/test_multizarr_dataset.py
@@ -15,7 +15,7 @@
def test_load_analysis_dataset():
# TODO: Access rights should be fixed for pooch to work
datastore = MultiZarrDatastore(
- config_path="tests/datastore_configs/multizarr.danra.yaml"
+ config_path="tests/datastore_configs/multizarr/data_config.yaml"
)
var_state_names = datastore.get_vars_names(category="state")
@@ -31,6 +31,9 @@ def test_load_analysis_dataset():
assert len(var_forcing_names) == len(var_forcing_units) == num_forcing_vars
stats = datastore.get_normalization_stats(category="state")
+
+ import ipdb
+ ipdb.set_trace()
# Assert dataset can be loaded
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
index 230485ec..6e5a4dc3 100644
--- a/tests/test_npy_forecast_dataset.py
+++ b/tests/test_npy_forecast_dataset.py
@@ -26,7 +26,7 @@
)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="session")
def ewc_testdata_path():
# Download and unzip test data into data/meps_example_reduced
pooch.retrieve(
@@ -44,29 +44,16 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
datastore = NumpyFilesDatastore(
root_path=ewc_testdata_path
)
- datastore = MultiZarrDatastore(
- config_path="tests/data_config.yaml"
- )
-
datastore.get_xy(category="state", stacked=True)
- import matplotlib.pyplot as plt
- da = datastore.get_dataarray(category="forcing", split="train").unstack("grid_index")
- da.isel(analysis_time=0, feature=-1, time=slice(0, 4)).plot(col="time", col_wrap=4)
- plt.show()
+ datastore.get_dataarray(category="forcing", split="train").unstack("grid_index")
+ datastore.get_dataarray(category="state", split="train").unstack("grid_index")
- da = datastore.get_dataarray(category="state", split="train").unstack("grid_index")
- da.isel(analysis_time=0, feature=0, time=slice(0, 4)).plot(col="time", row="ensemble_member")
- plt.show()
-
- import ipdb; ipdb.set_trace()
+ dataset = WeatherDataset(datastore=datastore)
- dataset = WeatherDataset(dataset_name="meps_example_reduced")
- config = Config.from_file(data_config_file)
-
- var_names = config.values["dataset"]["var_names"]
- var_units = config.values["dataset"]["var_units"]
- var_longnames = config.values["dataset"]["var_longnames"]
+ var_names = datastore.config.values["dataset"]["var_names"]
+ var_units = datastore.config.values["dataset"]["var_units"]
+ var_longnames = datastore.config.values["dataset"]["var_longnames"]
assert len(var_names) == len(var_longnames)
assert len(var_names) == len(var_units)
@@ -77,19 +64,22 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
# Hardcoded in model
n_input_steps = 2
- n_forcing_features = config.values["dataset"]["num_forcing_features"]
+ n_forcing_features = datastore.config.values["dataset"]["num_forcing_features"]
n_state_features = len(var_names)
- n_prediction_timesteps = dataset.sample_length - n_input_steps
+ n_prediction_timesteps = dataset.ar_steps
- nx, ny = config.values["grid_shape_state"]
+ nx, ny = datastore.config.values["grid_shape_state"]
n_grid = nx * ny
# check that the dataset is not empty
assert len(dataset) > 0
# get the first item
- init_states, target_states, forcing = dataset[0]
-
+ item = dataset[0]
+ init_states = item.init_states
+ target_states = item.target_states
+ forcing = item.forcing
+
# check that the shapes of the tensors are correct
assert init_states.shape == (n_input_steps, n_grid, n_state_features)
assert target_states.shape == (
@@ -102,8 +92,11 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
n_grid,
n_forcing_features,
)
-
- static_data = load_static_data(dataset_name=dataset_name)
+
+ static_data = {
+ "border_mask": datastore.boundary_mask.values,
+ "grid_static_features": datastore.get_dataarray(category="static", split="train").values
+ }
required_props = {
"border_mask",
@@ -116,7 +109,7 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
}
# check the sizes of the props
- assert static_data["border_mask"].shape == (n_grid, 1)
+ assert static_data["border_mask"].shape == (n_grid, )
assert static_data["grid_static_features"].shape == (
n_grid,
n_grid_static_features,
From 80f3639c99845e20a35fcd5d82ebbc234c10c935 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 7 Jul 2024 21:46:38 +0200
Subject: [PATCH 103/273] minor adjustment
---
neural_lam/datastore/npyfiles/store.py | 42 +++++++++++++-------------
tests/test_npy_forecast_dataset.py | 8 ++++-
2 files changed, 28 insertions(+), 22 deletions(-)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 88a7cb83..f60cc83e 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -4,6 +4,7 @@
import os
import re
from pathlib import Path
+from typing import List
# Third-party
import dask.delayed
@@ -172,6 +173,16 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
features = ["toa_downwelling_shortwave_flux", "column_water"]
das = [self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features]
da = xr.concat(das, dim="feature")
+
+ # add datetime forcing as a feature
+ # to do this we create a forecast time variable which has the dimensions of
+ # (analysis_time, elapsed_forecast_time) with values that are the actual forecast time of each
+ # time step. By calling .chunk({"elapsed_forecast_time": 1}) this time variable is turned into
+ # a dask array and so execution of the calculation is delayed until the feature
+ # values are actually used.
+ da_forecast_time = (da.analysis_time + da.elapsed_forecast_time).chunk({"elapsed_forecast_time": 1})
+ da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
+ da = xr.concat([da, da_datetime_forcing_features], dim="feature")
elif category == "static":
# the static features are collected in three files:
@@ -189,38 +200,28 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
da = da.rename(dict(feature=f"{category}_feature"))
- if category == "forcing":
- # add datetime forcing as a feature
- # to do this we create a forecast time variable which has the dimensions of
- # (analysis_time, elapsed_forecast_time) with values that are the actual forecast time of each
- # time step. By calling .chunk({"elapsed_forecast_time": 1}) this time variable is turned into
- # a dask array and so execution of the calculation is delayed until the feature
- # values are actually used.
- da_forecast_time = (da.analysis_time + da.elapsed_forecast_time).chunk({"elapsed_forecast_time": 1})
- da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
- da = xr.concat([da, da_datetime_forcing_features], dim=f"{category}_feature")
-
- da.name = category
-
# check that we have the right features
- actual_features = list(da[f"{category}_feature"].values)
+ actual_features = da[f"{category}_feature"].values.tolist()
expected_features = self.get_vars_names(category=category)
if actual_features != expected_features:
raise ValueError(f"Expected features {expected_features}, got {actual_features}")
return da
- def _get_single_timeseries_dataarray(self, features: str, split: str, member: int = None) -> DataArray:
+ def _get_single_timeseries_dataarray(self, features: List[str], split: str, member: int = None) -> DataArray:
"""
Get the data array spanning the complete time series for a given set of features and split
- of data. If the category is 'state', the member argument should be specified to select
+ of data. For state features the `member` argument should be specified to select
the ensemble member to load. The data will be loaded using dask.delayed, so that the data
isn't actually loaded until it's needed.
Parameters
----------
- category : str
- The category of the data to load. One of 'state', 'forcing', or 'static'.
+ features : List[str]
+ The list of features to load the data for. For the 'state' category, this should be
+ the result of `self.get_vars_names(category="state")`, for the 'forcing' category this
+ should be the list of forcing features to load, and for the 'static' category this should
+ be the list of static features to load.
split : str
The dataset split to load the data for. One of 'train', 'val', or 'test'.
member : int, optional
@@ -265,7 +266,6 @@ def _get_single_timeseries_dataarray(self, features: str, split: str, member: in
features_vary_with_analysis_time = False
# XXX: surface_geopotential is the same for all splits, and so saved in static/
fp_samples = self.root_path / "static"
- import ipdb; ipdb.set_trace()
elif features == ["border_mask"]:
filename_format = "border_mask.npy"
file_dims = ["y", "x", "feature"]
@@ -391,10 +391,10 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
np.sin(da_year_angle),
np.cos(da_year_angle),
),
- dim="forcing_feature",
+ dim="feature",
)
da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1]
- da_datetime_forcing["forcing_feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
+ da_datetime_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
return da_datetime_forcing
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
index 6e5a4dc3..67c128ed 100644
--- a/tests/test_npy_forecast_dataset.py
+++ b/tests/test_npy_forecast_dataset.py
@@ -93,9 +93,15 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
n_forcing_features,
)
+ ds_state_norm = datastore.get_normalization_dataarray(category="state")
+
static_data = {
"border_mask": datastore.boundary_mask.values,
- "grid_static_features": datastore.get_dataarray(category="static", split="train").values
+ "grid_static_features": datastore.get_dataarray(category="static", split="train").values,
+ "data_mean": ds_state_norm.state_mean.values,
+ "data_std": ds_state_norm.state_std.values,
+ "step_diff_mean": ds_state_norm.state_diff_mean.values,
+ "step_diff_std": ds_state_norm.state_diff_std.values,
}
required_props = {
From 5aaa2393f2a5aec3cb761ac6e889f4f251e15373 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 09:50:39 +0200
Subject: [PATCH 104/273] add pooch and tweak pip cicd testing
---
.github/workflows/ci-pip-install-and-test.yml | 5 ++++-
pyproject.toml | 1 +
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
index 307f1829..f23ad258 100644
--- a/.github/workflows/ci-pip-install-and-test.yml
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -20,7 +20,10 @@ jobs:
- name: Install package (including dev dependencies)
run: |
python -m pip install .
- python -m pip install pytest
+ # pip can't install from "dev" pdm group in pyproject.toml, should we put these requirements
+ # for running tests in a seperate group? Using "dev" ensures that the these requirements aren't
+ # included in build packages
+ python -m pip install pytest pooch
- name: Print and check torch version
run: |
diff --git a/pyproject.toml b/pyproject.toml
index 50ddca04..9cb78770 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ requires-python = ">=3.9"
dev = [
"pre-commit>=2.15.0",
"pytest>=8.2.1",
+ "pooch>=1.8.1",
]
[tool.black]
From 66c3b03ec79e5a47d9fba33c1066f3e81e795814 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 09:54:48 +0200
Subject: [PATCH 105/273] combine cicd tests with caching
---
.github/workflows/ci-pdm-install-and-test.yml | 14 ++++++
.../workflows/ci-pip-install-and-test-gpu.yml | 14 ++++++
.github/workflows/ci-pip-install-and-test.yml | 14 ++++++
.github/workflows/run_tests.yml | 43 -------------------
4 files changed, 42 insertions(+), 43 deletions(-)
delete mode 100644 .github/workflows/run_tests.yml
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 69fc29d3..0517a5f7 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -37,6 +37,20 @@ jobs:
pdm run python -c "import torch; print(torch.__version__)"
pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')"
+ - name: Load cache data
+ uses: actions/cache/restore@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ restore-keys: |
+ ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+
- name: Run tests
run: |
pdm run pytest
+
+ - name: Save cache data
+ uses: actions/cache/save@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index dab7b060..b1264d38 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -32,6 +32,20 @@ jobs:
python -c "import torch; print(torch.__version__)"
python -c "import torch; assert not torch.__version__.endswith('+cpu')"
+ - name: Load cache data
+ uses: actions/cache/restore@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ restore-keys: |
+ ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+
- name: Run tests
run: |
python -m pytest
+
+ - name: Save cache data
+ uses: actions/cache/save@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
index f23ad258..b6527a7c 100644
--- a/.github/workflows/ci-pip-install-and-test.yml
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -30,6 +30,20 @@ jobs:
python -c "import torch; print(torch.__version__)"
python -c "import torch; assert torch.__version__.endswith('+cpu')"
+ - name: Load cache data
+ uses: actions/cache/restore@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ restore-keys: |
+ ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+
- name: Run tests
run: |
python -m pytest
+
+ - name: Save cache data
+ uses: actions/cache/save@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
deleted file mode 100644
index 4c677908..00000000
--- a/.github/workflows/run_tests.yml
+++ /dev/null
@@ -1,43 +0,0 @@
-name: Unit Tests
-
-on:
- # trigger on pushes to any branch
- push:
- # and also on PRs to main
- pull_request:
- branches:
- - main
-
-jobs:
- build:
- runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: ["3.9", "3.10", "3.11", "3.12"]
-
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- pip install torch-geometric>=2.5.2
- - name: Load cache data
- uses: actions/cache/restore@v4
- with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
- restore-keys: |
- ${{ runner.os }}-meps-reduced-example-data-v0.1.0
- - name: Test with pytest
- run: |
- pytest -v -s
- - name: Save cache data
- uses: actions/cache/save@v4
- with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
From 8566b8f2ab3e732942fc96760d7ba2d6d010d11c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 10:00:42 +0200
Subject: [PATCH 106/273] linting
---
.github/workflows/ci-pdm-install-and-test.yml | 2 +-
.github/workflows/ci-pip-install-and-test-gpu.yml | 2 +-
.github/workflows/ci-pip-install-and-test.yml | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test.yml
index 0517a5f7..7d31f867 100644
--- a/.github/workflows/ci-pdm-install-and-test.yml
+++ b/.github/workflows/ci-pdm-install-and-test.yml
@@ -53,4 +53,4 @@ jobs:
uses: actions/cache/save@v4
with:
path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index b1264d38..c00e92ac 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -48,4 +48,4 @@ jobs:
uses: actions/cache/save@v4
with:
path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test.yml
index b6527a7c..c94e70c2 100644
--- a/.github/workflows/ci-pip-install-and-test.yml
+++ b/.github/workflows/ci-pip-install-and-test.yml
@@ -21,7 +21,7 @@ jobs:
run: |
python -m pip install .
# pip can't install from "dev" pdm group in pyproject.toml, should we put these requirements
- # for running tests in a seperate group? Using "dev" ensures that the these requirements aren't
+ # for running tests in a separate group? Using "dev" ensures that the these requirements aren't
# included in build packages
python -m pip install pytest pooch
@@ -46,4 +46,4 @@ jobs:
uses: actions/cache/save@v4
with:
path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
\ No newline at end of file
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
From 29bd9e585d8e73c8d07f3200dd1e751968602773 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 10:12:12 +0200
Subject: [PATCH 107/273] add pyg dep
---
pyproject.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/pyproject.toml b/pyproject.toml
index 9cb78770..f86cf653 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"matplotlib>=3.7.0",
"plotly>=5.15.0",
"torch>=2.3.0",
+ "torch-geometric==2.3.1",
]
requires-python = ">=3.9"
From bc7f0286e0a06f6ed064d746ba142c8444088747 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 11:16:21 +0200
Subject: [PATCH 108/273] set cirun aws region to frankfurt
---
.cirun.yml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/.cirun.yml b/.cirun.yml
index b188d6dc..50d01f21 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -7,6 +7,8 @@ runners:
instance_type: "g4ad.xlarge"
# Ubuntu-20.4, ami image
machine_image: "ami-06fd8a495a537da8b"
+ # use Frankfurt region
+ region: "eu-central-1"
preemptible: false
# Add this label in the "runs-on" param in .github/workflows/.yml
# So that this runner is created for running the workflow
From 2070166e2d66c7cdd75ae61dc730a1d1aceef3fe Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 11:19:59 +0200
Subject: [PATCH 109/273] adapt image
---
.cirun.yml | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/.cirun.yml b/.cirun.yml
index 50d01f21..1c436e55 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -5,10 +5,9 @@ runners:
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
instance_type: "g4ad.xlarge"
- # Ubuntu-20.4, ami image
- machine_image: "ami-06fd8a495a537da8b"
# use Frankfurt region
region: "eu-central-1"
+ # use ubuntu 20.04
preemptible: false
# Add this label in the "runs-on" param in .github/workflows/.yml
# So that this runner is created for running the workflow
From e4e86e56d2c68aabcc197c898bf10f5e1cb54633 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 11:23:11 +0200
Subject: [PATCH 110/273] set image
---
.cirun.yml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/.cirun.yml b/.cirun.yml
index 1c436e55..4a0f4f38 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -5,9 +5,10 @@ runners:
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
instance_type: "g4ad.xlarge"
+ # Ubuntu-24.04 LTS, ami image, Frankfurt
+ machine_image: "ami-0e872aee57663ae2d"
# use Frankfurt region
region: "eu-central-1"
- # use ubuntu 20.04
preemptible: false
# Add this label in the "runs-on" param in .github/workflows/.yml
# So that this runner is created for running the workflow
From 1fba8fe578d33335f833915896fa4572f7b1f3a5 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 11:27:20 +0200
Subject: [PATCH 111/273] try different image
---
.cirun.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.cirun.yml b/.cirun.yml
index 4a0f4f38..21b03ab4 100644
--- a/.cirun.yml
+++ b/.cirun.yml
@@ -5,8 +5,8 @@ runners:
cloud: "aws"
# https://aws.amazon.com/ec2/instance-types/g4/
instance_type: "g4ad.xlarge"
- # Ubuntu-24.04 LTS, ami image, Frankfurt
- machine_image: "ami-0e872aee57663ae2d"
+ # Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04), Frankfurt region
+ machine_image: "ami-0ba41b554b28d24a4"
# use Frankfurt region
region: "eu-central-1"
preemptible: false
From 02b77cf97e76423756729c01e5acc7f683cf6a21 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 11 Jul 2024 12:47:39 +0200
Subject: [PATCH 112/273] add pooch to cicd
---
.github/workflows/ci-pip-install-and-test-gpu.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index c00e92ac..4dfc98c8 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -25,7 +25,7 @@ jobs:
- name: Install package (including dev dependencies)
run: |
python -m pip install .
- python -m pip install pytest
+ python -m pip install pytest pooch
- name: Print and check torch version
run: |
From b481929f7823f7c2034c2e5a5b0f4c4f2043f0cc Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 22:47:42 +0200
Subject: [PATCH 113/273] add pdm gpu test
---
...st.yml => ci-pdm-install-and-test-cpu.yml} | 0
.../workflows/ci-pdm-install-and-test-gpu.yml | 60 +++++++++++++++++++
...st.yml => ci-pip-install-and-test-cpu.yml} | 0
3 files changed, 60 insertions(+)
rename .github/workflows/{ci-pdm-install-and-test.yml => ci-pdm-install-and-test-cpu.yml} (100%)
create mode 100644 .github/workflows/ci-pdm-install-and-test-gpu.yml
rename .github/workflows/{ci-pip-install-and-test.yml => ci-pip-install-and-test-cpu.yml} (100%)
diff --git a/.github/workflows/ci-pdm-install-and-test.yml b/.github/workflows/ci-pdm-install-and-test-cpu.yml
similarity index 100%
rename from .github/workflows/ci-pdm-install-and-test.yml
rename to .github/workflows/ci-pdm-install-and-test-cpu.yml
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
new file mode 100644
index 00000000..f9060361
--- /dev/null
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -0,0 +1,60 @@
+# cicd workflow for running tests with pytest
+# needs to first install pdm, then install torch (GPU, CUDA 12.1) manually and then install the package

+# then run the tests
+
+name: test (pdm install, gpu)
+
+on: [push, pull_request]
+
+jobs:
+ tests:
+ runs-on: "cirun-aws-runner--${{ github.run_id }}"
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v2
+
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install pdm
+ run: |
+ python -m pip install pdm
+
+ - name: Create venv
+ run: |
+ pdm venv create --with-pip
+ pdm use --venv in-project
+
+ - name: Install torch (GPU CUDA 12.1)
+ run: |
+ python -m pip install torch --index-url https://download.pytorch.org/whl/cu121
+
+ - name: Print and check torch version
+ run: |
+ python -c "import torch; print(torch.__version__)"
+ python -c "import torch; assert not torch.__version__.endswith('+cpu')"
+
+ - name: Install package (including dev dependencies)
+ run: |
+ pdm install
+ pdm install --dev
+
+ - name: Load cache data
+ uses: actions/cache/restore@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ restore-keys: |
+ ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+
+ - name: Run tests
+ run: |
+ pdm run pytest
+
+ - name: Save cache data
+ uses: actions/cache/save@v4
+ with:
+ path: data
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
diff --git a/.github/workflows/ci-pip-install-and-test.yml b/.github/workflows/ci-pip-install-and-test-cpu.yml
similarity index 100%
rename from .github/workflows/ci-pip-install-and-test.yml
rename to .github/workflows/ci-pip-install-and-test-cpu.yml
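The only GPU-related assertion in the workflow above is on the wheel name. A sketch of that check is shown below; the final `torch.cuda.is_available()` assertion is not part of the workflow, just an optional stricter check one could add on an NVIDIA runner.

```
import torch

# What the CI step asserts: the installed wheel is not the CPU-only build.
print(torch.__version__)
assert not torch.__version__.endswith("+cpu")

# Optional, stricter check (not in the workflow above): a CUDA device is
# actually visible to torch on the runner.
assert torch.cuda.is_available(), "no CUDA device visible to torch"
```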
From bcec4727877b2ffefc80daa1a60419910e757dd0 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 22:51:43 +0200
Subject: [PATCH 114/273] start work on readme
---
README.md | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/README.md b/README.md
index 26d844f7..8d5d2c5b 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,13 @@ See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://git
Below follows instructions on how to use Neural-LAM to train and evaluate models.
## Installation
+
+The dependencies in `neural-lam` are handled with [pdm](https://pdm.fming.dev/), but you can still install `neural-lam` directly with pip if you prefer. The benefits of using `pdm` are that [pyproject.toml](pyproject.toml) is automatically updated when you add/remove dependencies (with `pdm add <package>` or `pdm remove <package>`).
Date: Tue, 16 Jul 2024 23:08:32 +0200
Subject: [PATCH 115/273] turn meps testdata download into pytest fixture
---
tests/test_mllam_dataset.py | 23 ++++++++++++++---------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index f91170c9..cb992db5 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -1,15 +1,18 @@
# Standard library
import os
+from pathlib import Path
# Third-party
import pooch
+import pytest
-# First-party
-from create_mesh import main as create_mesh
from neural_lam.config import Config
+
+# First-party
+from neural_lam.create_mesh import main as create_mesh
+from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
-from train_model import main as train_model
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
@@ -25,7 +28,8 @@
)
-def test_retrieve_data_ewc():
+@pytest.fixture
+def meps_example_reduced_filepath():
# Download and unzip test data into data/meps_example_reduced
pooch.retrieve(
url=S3_FULL_PATH,
@@ -34,16 +38,17 @@ def test_retrieve_data_ewc():
path="data",
fname="meps_example_reduced.zip",
)
+ return Path("data/meps_example_reduced")
-def test_load_reduced_meps_dataset():
+def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
# The data_config.yaml file is downloaded and extracted in
# test_retrieve_data_ewc together with the dataset itself
- data_config_file = "data/meps_example_reduced/data_config.yaml"
- dataset_name = "meps_example_reduced"
+ data_config_file = meps_example_reduced_filepath / "data_config.yaml"
+ dataset_name = meps_example_reduced_filepath.name
- dataset = WeatherDataset(dataset_name="meps_example_reduced")
- config = Config.from_file(data_config_file)
+ dataset = WeatherDataset(dataset_name=dataset_name)
+ config = Config.from_file(str(data_config_file))
var_names = config.values["dataset"]["var_names"]
var_units = config.values["dataset"]["var_units"]
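A condensed sketch of the fixture pattern this patch introduces: the fixture prepares (or reuses) the example data and hands the local path to any test that names it as an argument. The download details are elided here; the names match the diff above.

```
from pathlib import Path

import pytest


@pytest.fixture
def meps_example_reduced_filepath() -> Path:
    # In the real test this first calls pooch.retrieve(...) to fetch the data.
    return Path("data/meps_example_reduced")


def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
    # pytest injects the fixture's return value as the argument.
    data_config_file = meps_example_reduced_filepath / "data_config.yaml"
    dataset_name = meps_example_reduced_filepath.name
    assert dataset_name == "meps_example_reduced"
    assert data_config_file.name == "data_config.yaml"
```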
From 49e9bfef051a0cb04a1bf4ff6371a5a117a74af8 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:13:31 +0200
Subject: [PATCH 116/273] adapt README for package
---
README.md | 54 +++++++++++++++++++++++++++---------------------------
1 file changed, 27 insertions(+), 27 deletions(-)
diff --git a/README.md b/README.md
index 50bb0af0..00562f22 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Still, some restrictions are inevitable:
## A note on the limited area setting
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
There are still some parts of the code that are quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ).
+This is particularly true for the mesh graph creation (`python -m neural_lam.create_mesh`) and some of the constants set in a `data_config.yaml` file (path specified in `python -m neural_lam.train_model --data_config`).
If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -81,15 +81,15 @@ The full MEPS dataset can be shared with other researchers on request, contact u
A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
Download the file and unzip in the neural-lam directory.
All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `python -m neural_lam.create_mesh`).
-Note that this is far too little data to train any useful models, but all scripts can be ran with it.
+Note that this is far too little data to train any useful models, but all pre-processing and training steps can be run with it.
It should thus be useful to make sure that your python environment is set up correctly and that all the code can be run without any issues.
## Pre-processing
-An overview of how the different scripts and files depend on each other is given in this figure:
+An overview of how the different pre-processing steps, training and files depend on each other is given in this figure:
-In order to start training models at least three pre-processing scripts have to be run:
+In order to start training models at least three pre-processing steps have to be run:
* `python -m neural_lam.create_mesh`
* `python -m neural_lam.create_grid_features`
@@ -106,13 +106,13 @@ The graphs used for the different models in the [paper](https://arxiv.org/abs/23
The graph-related files are stored in a directory called `graphs`.
### Create remaining static features
-To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
+To create the remaining static files run `python -m neural_lam.create_grid_features` and `python -m neural_lam.create_parameter_weights`.
## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argsparse).
+The W&B project name is set to `neural-lam`, but this can be changed in the flags of `python -m neural_lam.train_model` (using argparse).
See the [W&B documentation](https://docs.wandb.ai/) for details.
If you would like to login and use W&B, run:
@@ -216,13 +216,13 @@ data
│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset)
│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset)
│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset)
-│ ├── grid_features.pt - Static features of grid nodes (create_grid_features.py)
-│ ├── parameter_mean.pt - Means of state parameters (create_parameter_weights.py)
-│ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
-│ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
-│ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
-│ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
+│ ├── grid_features.pt - Static features of grid nodes (neural_lam.create_grid_features)
+│ ├── parameter_mean.pt - Means of state parameters (neural_lam.create_parameter_weights)
+│ ├── parameter_std.pt - Std.-dev. of state parameters (neural_lam.create_parameter_weights)
+│ ├── diff_mean.pt - Means of one-step differences (neural_lam.create_parameter_weights)
+│ ├── diff_std.pt - Std.-dev. of one-step differences (neural_lam.create_parameter_weights)
+│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (neural_lam.create_parameter_weights)
+│ └── parameter_weights.npy - Loss weights for different state parameters (neural_lam.create_parameter_weights)
├── dataset2
├── ...
└── datasetN
@@ -234,13 +234,13 @@ The structure is shown with examples below:
```
graphs
├── graph1 - Directory with a graph definition
-│ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py)
-│ ├── g2m_edge_index.pt - Edges from grid to mesh (create_mesh.py)
-│ ├── m2g_edge_index.pt - Edges from mesh to grid (create_mesh.py)
-│ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py)
-│ ├── g2m_features.pt - Static features of grid to mesh edges (create_mesh.py)
-│ ├── m2g_features.pt - Static features of mesh to grid edges (create_mesh.py)
-│ └── mesh_features.pt - Static features of mesh nodes (create_mesh.py)
+│ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh)
+│ ├── g2m_edge_index.pt - Edges from grid to mesh (neural_lam.create_mesh)
+│ ├── m2g_edge_index.pt - Edges from mesh to grid (neural_lam.create_mesh)
+│ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh)
+│ ├── g2m_features.pt - Static features of grid to mesh edges (neural_lam.create_mesh)
+│ ├── m2g_features.pt - Static features of mesh to grid edges (neural_lam.create_mesh)
+│ └── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh)
├── graph2
├── ...
└── graphN
@@ -250,9 +250,9 @@ graphs
To keep track of levels in the mesh graph, a list format is used for the files with mesh graph information.
In particular, the files
```
-│ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py)
-│ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py)
-│ ├── mesh_features.pt - Static features of mesh nodes (create_mesh.py)
+│ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh)
+│ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh)
+│ ├── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh)
```
all contain lists of length `L`, for a hierarchical mesh graph with `L` layers.
For non-hierarchical graphs `L == 1` and these are all just single-entry lists.
@@ -263,10 +263,10 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w
```
├── graph1
│ ├── ...
-│ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (create_mesh.py)
-│ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (create_mesh.py)
-│ ├── mesh_down_features.pt - Static features of downward mesh edges (create_mesh.py)
-│ ├── mesh_up_features.pt - Static features of upward mesh edges (create_mesh.py)
+│ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (neural_lam.create_mesh)
+│ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (neural_lam.create_mesh)
+│ ├── mesh_down_features.pt - Static features of downward mesh edges (neural_lam.create_mesh)
+│ ├── mesh_up_features.pt - Static features of upward mesh edges (neural_lam.create_mesh)
│ ├── ...
```
These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels).
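For reference, a sketch of inspecting these list-valued graph files from Python; the graph name `graph1` is a placeholder and the paths follow the directory layout described above.

```
from pathlib import Path

import torch

graph_dir = Path("graphs/graph1")

# Always present: one entry per mesh level, so the list has length L.
m2m_edge_index = torch.load(graph_dir / "m2m_edge_index.pt")
L = len(m2m_edge_index)
print(f"mesh graph with L={L} level(s)")

if L > 1:
    # Only hierarchical graphs have up/down edges, stored as lists of length L-1.
    mesh_up_edge_index = torch.load(graph_dir / "mesh_up_edge_index.pt")
    assert len(mesh_up_edge_index) == L - 1
```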
From 12cc02b9a1559c60cd1be495900f2cfbeddbe569 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:14:33 +0200
Subject: [PATCH 117/273] remove pdm cicd test (will be in separate PR)
---
.github/workflows/ci-tests.yml | 33 ---------------------------------
1 file changed, 33 deletions(-)
delete mode 100644 .github/workflows/ci-tests.yml
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
deleted file mode 100644
index 9b73f298..00000000
--- a/.github/workflows/ci-tests.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-# cicd workflow for running tests with pytest
-# needs to first install pdm, then install torch cpu manually and then install the package
-# then run the tests
-
-name: tests (cpu)
-
-on: [push, pull_request]
-
-jobs:
- tests:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Install pdm
- uses: pdm-project/setup-pdm@v4
- with:
- python-version: "3.10"
- cache: true
-
- - name: Install torch (CPU)
- run: |
- python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
- - name: Install package (including dev dependencies)
- run: |
- pdm install
- pdm install --dev
-
- - name: Run tests
- run: |
- pdm run pytest
From b47f50bdb977bf1b63e8f0f5df9ba5a32813cbb2 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:15:02 +0200
Subject: [PATCH 118/273] remove pdm in gitignore
---
.gitignore | 3 ---
1 file changed, 3 deletions(-)
diff --git a/.gitignore b/.gitignore
index ef497608..65e9f6f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -75,6 +75,3 @@ tags
# Coc configuration directory
.vim
-
-# pdm
-.pdm-python
From 90d99ca7d5d816a7ba13e2373c8bf54ab6203b3c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:17:04 +0200
Subject: [PATCH 119/273] remove pdm and pyproject files (will be sep PR)
---
pdm.lock | 1880 ------------------------------------------------
pyproject.toml | 89 ---
2 files changed, 1969 deletions(-)
delete mode 100644 pdm.lock
delete mode 100644 pyproject.toml
diff --git a/pdm.lock b/pdm.lock
deleted file mode 100644
index 21467c0d..00000000
--- a/pdm.lock
+++ /dev/null
@@ -1,1880 +0,0 @@
-# This file is @generated by PDM.
-# It is not intended for manual editing.
-
-[metadata]
-groups = ["default", "dev"]
-strategy = ["cross_platform", "inherit_metadata"]
-lock_version = "4.4.1"
-content_hash = "sha256:c6c346f14a001266b5cc8a2eafb2081b9bcba755c41eb0f44525436548a09fde"
-
-[[package]]
-name = "aiohttp"
-version = "3.9.5"
-requires_python = ">=3.8"
-summary = "Async http client/server framework (asyncio)"
-groups = ["default"]
-dependencies = [
- "aiosignal>=1.1.2",
- "async-timeout<5.0,>=4.0; python_version < \"3.11\"",
- "attrs>=17.3.0",
- "frozenlist>=1.1.1",
- "multidict<7.0,>=4.5",
- "yarl<2.0,>=1.0",
-]
-files = [
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"},
- {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"},
- {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"},
- {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
- {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
- {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
- {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
-]
-
-[[package]]
-name = "aiosignal"
-version = "1.3.1"
-requires_python = ">=3.7"
-summary = "aiosignal: a list of registered asynchronous callbacks"
-groups = ["default"]
-dependencies = [
- "frozenlist>=1.1.0",
-]
-files = [
- {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
- {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
-]
-
-[[package]]
-name = "async-timeout"
-version = "4.0.3"
-requires_python = ">=3.7"
-summary = "Timeout context manager for asyncio programs"
-groups = ["default"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
- {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
-]
-
-[[package]]
-name = "attrs"
-version = "23.2.0"
-requires_python = ">=3.7"
-summary = "Classes Without Boilerplate"
-groups = ["default"]
-files = [
- {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
- {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
-]
-
-[[package]]
-name = "cartopy"
-version = "0.23.0"
-requires_python = ">=3.9"
-summary = "A Python library for cartographic visualizations with Matplotlib"
-groups = ["default"]
-dependencies = [
- "matplotlib>=3.5",
- "numpy>=1.21",
- "packaging>=20",
- "pyproj>=3.3.1",
- "pyshp>=2.3",
- "shapely>=1.7",
-]
-files = [
- {file = "Cartopy-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:374e66f816c3bafa48ffdbf6abaefa67063b405fac5f425f9be241cdf3498352"},
- {file = "Cartopy-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2bae450c4c913796cad0b7ce05aa2fa78d1788de47989f0a03183397648e24be"},
- {file = "Cartopy-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40437596e8ac5e74575eab822c661f4e725bd995cfd9e445069695fe9086b42"},
- {file = "Cartopy-0.23.0-cp310-cp310-win_amd64.whl", hash = "sha256:3292d6d403137eed80d32014c2f28de6282bed8824213f4b4c2170f388b24a1b"},
- {file = "Cartopy-0.23.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86b07b6794b616674e4e485b8574e9197bca54a4467d28dd01ae0bf178f8dc2b"},
- {file = "Cartopy-0.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8dece2aa8d5ff7bf989ded6b5f07c980fb5bb772952bc7cdeab469738abdecee"},
- {file = "Cartopy-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9dfd28352dc83d6b4e4cf85d84cb50fc4886d4c1510d61f4c7cf22477d1156f"},
- {file = "Cartopy-0.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:b2671b5354e43220f8e1074e7fe30a8b9f71cb38407c78e51db9c97772f0320b"},
- {file = "Cartopy-0.23.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:80b9fd666fd47f6370d29f7ad4e352828d54aaf688a03d0b83b51e141cfd77fa"},
- {file = "Cartopy-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:43e36b8b7e7e373a5698757458fd28fafbbbf5f3ebbe2d378f6a5ec3993d6dc0"},
- {file = "Cartopy-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550173b91155d4d81cd14b4892cb6cabe3dd32bd34feacaa1ec78c0e56287832"},
- {file = "Cartopy-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:55219ee0fb069cc3254426e87382cde03546e86c3f7c6759f076823b1e3a44d9"},
- {file = "Cartopy-0.23.0.tar.gz", hash = "sha256:231f37b35701f2ba31d94959cca75e6da04c2eea3a7f14ce1c75ee3b0eae7676"},
-]
-
-[[package]]
-name = "certifi"
-version = "2024.2.2"
-requires_python = ">=3.6"
-summary = "Python package for providing Mozilla's CA Bundle."
-groups = ["default"]
-files = [
- {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"},
- {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"},
-]
-
-[[package]]
-name = "cfgv"
-version = "3.4.0"
-requires_python = ">=3.8"
-summary = "Validate configuration and produce human readable error messages."
-groups = ["default"]
-files = [
- {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"},
- {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"},
-]
-
-[[package]]
-name = "charset-normalizer"
-version = "3.3.2"
-requires_python = ">=3.7.0"
-summary = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
-groups = ["default"]
-files = [
- {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"},
- {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"},
- {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"},
- {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"},
- {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
-]
-
-[[package]]
-name = "click"
-version = "8.1.7"
-requires_python = ">=3.7"
-summary = "Composable command line interface toolkit"
-groups = ["default"]
-dependencies = [
- "colorama; platform_system == \"Windows\"",
-]
-files = [
- {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
- {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
-]
-
-[[package]]
-name = "colorama"
-version = "0.4.6"
-requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
-summary = "Cross-platform colored terminal text."
-groups = ["default", "dev"]
-files = [
- {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
- {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
-]
-
-[[package]]
-name = "contourpy"
-version = "1.2.1"
-requires_python = ">=3.9"
-summary = "Python library for calculating contours of 2D quadrilateral grids"
-groups = ["default"]
-dependencies = [
- "numpy>=1.20",
-]
-files = [
- {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"},
- {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"},
- {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"},
- {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"},
- {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"},
- {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"},
- {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"},
- {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"},
- {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"},
- {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"},
- {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"},
- {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"},
- {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"},
- {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"},
- {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"},
- {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"},
- {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"},
- {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"},
- {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"},
- {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"},
- {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"},
- {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"},
- {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"},
-]
-
-[[package]]
-name = "cycler"
-version = "0.12.1"
-requires_python = ">=3.8"
-summary = "Composable style cycles"
-groups = ["default"]
-files = [
- {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"},
- {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"},
-]
-
-[[package]]
-name = "distlib"
-version = "0.3.8"
-summary = "Distribution utilities"
-groups = ["default"]
-files = [
- {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"},
- {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
-]
-
-[[package]]
-name = "docker-pycreds"
-version = "0.4.0"
-summary = "Python bindings for the docker credentials store API"
-groups = ["default"]
-dependencies = [
- "six>=1.4.0",
-]
-files = [
- {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
- {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
-]
-
-[[package]]
-name = "exceptiongroup"
-version = "1.2.1"
-requires_python = ">=3.7"
-summary = "Backport of PEP 654 (exception groups)"
-groups = ["dev"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"},
- {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"},
-]
-
-[[package]]
-name = "filelock"
-version = "3.14.0"
-requires_python = ">=3.8"
-summary = "A platform independent file lock."
-groups = ["default"]
-files = [
- {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
- {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
-]
-
-[[package]]
-name = "fonttools"
-version = "4.51.0"
-requires_python = ">=3.8"
-summary = "Tools to manipulate font files"
-groups = ["default"]
-files = [
- {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"},
- {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"},
- {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037"},
- {file = "fonttools-4.51.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716"},
- {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438"},
- {file = "fonttools-4.51.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039"},
- {file = "fonttools-4.51.0-cp310-cp310-win32.whl", hash = "sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77"},
- {file = "fonttools-4.51.0-cp310-cp310-win_amd64.whl", hash = "sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b"},
- {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74"},
- {file = "fonttools-4.51.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2"},
- {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f"},
- {file = "fonttools-4.51.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097"},
- {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0"},
- {file = "fonttools-4.51.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1"},
- {file = "fonttools-4.51.0-cp311-cp311-win32.whl", hash = "sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034"},
- {file = "fonttools-4.51.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1"},
- {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba"},
- {file = "fonttools-4.51.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc"},
- {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a"},
- {file = "fonttools-4.51.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2"},
- {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671"},
- {file = "fonttools-4.51.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5"},
- {file = "fonttools-4.51.0-cp312-cp312-win32.whl", hash = "sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15"},
- {file = "fonttools-4.51.0-cp312-cp312-win_amd64.whl", hash = "sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e"},
- {file = "fonttools-4.51.0-py3-none-any.whl", hash = "sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f"},
- {file = "fonttools-4.51.0.tar.gz", hash = "sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68"},
-]
-
-[[package]]
-name = "frozenlist"
-version = "1.4.1"
-requires_python = ">=3.8"
-summary = "A list-like structure which implements collections.abc.MutableSequence"
-groups = ["default"]
-files = [
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
- {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
- {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
- {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
- {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
- {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
- {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
- {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
- {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
-]
-
-[[package]]
-name = "fsspec"
-version = "2024.3.1"
-requires_python = ">=3.8"
-summary = "File-system specification"
-groups = ["default"]
-files = [
- {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
- {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
-]
-
-[[package]]
-name = "fsspec"
-version = "2024.3.1"
-extras = ["http"]
-requires_python = ">=3.8"
-summary = "File-system specification"
-groups = ["default"]
-dependencies = [
- "aiohttp!=4.0.0a0,!=4.0.0a1",
- "fsspec==2024.3.1",
-]
-files = [
- {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
- {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
-]
-
-[[package]]
-name = "gitdb"
-version = "4.0.11"
-requires_python = ">=3.7"
-summary = "Git Object Database"
-groups = ["default"]
-dependencies = [
- "smmap<6,>=3.0.1",
-]
-files = [
- {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
- {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
-]
-
-[[package]]
-name = "gitpython"
-version = "3.1.43"
-requires_python = ">=3.7"
-summary = "GitPython is a Python library used to interact with Git repositories"
-groups = ["default"]
-dependencies = [
- "gitdb<5,>=4.0.1",
-]
-files = [
- {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
- {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
-]
-
-[[package]]
-name = "identify"
-version = "2.5.36"
-requires_python = ">=3.8"
-summary = "File identification library for Python"
-groups = ["default"]
-files = [
- {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"},
- {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"},
-]
-
-[[package]]
-name = "idna"
-version = "3.7"
-requires_python = ">=3.5"
-summary = "Internationalized Domain Names in Applications (IDNA)"
-groups = ["default"]
-files = [
- {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"},
- {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
-]
-
-[[package]]
-name = "iniconfig"
-version = "2.0.0"
-requires_python = ">=3.7"
-summary = "brain-dead simple config-ini parsing"
-groups = ["dev"]
-files = [
- {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
- {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
-]
-
-[[package]]
-name = "intel-openmp"
-version = "2021.4.0"
-summary = "Intel® OpenMP* Runtime Library"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-files = [
- {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
-]
-
-[[package]]
-name = "jinja2"
-version = "3.1.4"
-requires_python = ">=3.7"
-summary = "A very fast and expressive template engine."
-groups = ["default"]
-dependencies = [
- "MarkupSafe>=2.0",
-]
-files = [
- {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
- {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
-]
-
-[[package]]
-name = "joblib"
-version = "1.4.2"
-requires_python = ">=3.8"
-summary = "Lightweight pipelining with Python functions"
-groups = ["default"]
-files = [
- {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
- {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
-]
-
-[[package]]
-name = "kiwisolver"
-version = "1.4.5"
-requires_python = ">=3.7"
-summary = "A fast implementation of the Cassowary constraint solver"
-groups = ["default"]
-files = [
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"},
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"},
- {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"},
- {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"},
- {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"},
- {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"},
- {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"},
- {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"},
- {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"},
- {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"},
- {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"},
- {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"},
- {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"},
- {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"},
- {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"},
- {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"},
- {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"},
- {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"},
- {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"},
- {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"},
- {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"},
-]
-
-[[package]]
-name = "lightning-utilities"
-version = "0.11.2"
-requires_python = ">=3.8"
-summary = "Lightning toolbox for across the our ecosystem."
-groups = ["default"]
-dependencies = [
- "packaging>=17.1",
- "setuptools",
- "typing-extensions",
-]
-files = [
- {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"},
- {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"},
-]
-
-[[package]]
-name = "markupsafe"
-version = "2.1.5"
-requires_python = ">=3.7"
-summary = "Safely add untrusted strings to HTML/XML markup."
-groups = ["default"]
-files = [
- {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"},
- {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"},
- {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"},
- {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"},
- {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
-]
-
-[[package]]
-name = "matplotlib"
-version = "3.8.4"
-requires_python = ">=3.9"
-summary = "Python plotting package"
-groups = ["default"]
-dependencies = [
- "contourpy>=1.0.1",
- "cycler>=0.10",
- "fonttools>=4.22.0",
- "kiwisolver>=1.3.1",
- "numpy>=1.21",
- "packaging>=20.0",
- "pillow>=8",
- "pyparsing>=2.3.1",
- "python-dateutil>=2.7",
-]
-files = [
- {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"},
- {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"},
- {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10"},
- {file = "matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0"},
- {file = "matplotlib-3.8.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef"},
- {file = "matplotlib-3.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338"},
- {file = "matplotlib-3.8.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661"},
- {file = "matplotlib-3.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c"},
- {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa"},
- {file = "matplotlib-3.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71"},
- {file = "matplotlib-3.8.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b"},
- {file = "matplotlib-3.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae"},
- {file = "matplotlib-3.8.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616"},
- {file = "matplotlib-3.8.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732"},
- {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb"},
- {file = "matplotlib-3.8.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30"},
- {file = "matplotlib-3.8.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25"},
- {file = "matplotlib-3.8.4-cp312-cp312-win_amd64.whl", hash = "sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a"},
- {file = "matplotlib-3.8.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35"},
- {file = "matplotlib-3.8.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f"},
- {file = "matplotlib-3.8.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94"},
- {file = "matplotlib-3.8.4.tar.gz", hash = "sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea"},
-]
-
-[[package]]
-name = "mkl"
-version = "2021.4.0"
-summary = "Intel® oneAPI Math Kernel Library"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-dependencies = [
- "intel-openmp==2021.*",
- "tbb==2021.*",
-]
-files = [
- {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
- {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
- {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
-]
-
-[[package]]
-name = "mpmath"
-version = "1.3.0"
-summary = "Python library for arbitrary-precision floating-point arithmetic"
-groups = ["default"]
-files = [
- {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
- {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
-]
-
-[[package]]
-name = "multidict"
-version = "6.0.5"
-requires_python = ">=3.7"
-summary = "multidict implementation"
-groups = ["default"]
-files = [
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
- {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
- {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
- {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
- {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
- {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
- {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
- {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
- {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
-]
-
-[[package]]
-name = "networkx"
-version = "3.3"
-requires_python = ">=3.10"
-summary = "Python package for creating and manipulating graphs and networks"
-groups = ["default"]
-files = [
- {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"},
- {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"},
-]
-
-[[package]]
-name = "nodeenv"
-version = "1.8.0"
-requires_python = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
-summary = "Node.js virtual environment builder"
-groups = ["default"]
-dependencies = [
- "setuptools",
-]
-files = [
- {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
- {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
-]
-
-[[package]]
-name = "numpy"
-version = "1.26.4"
-requires_python = ">=3.9"
-summary = "Fundamental package for array computing in Python"
-groups = ["default"]
-files = [
- {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
- {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
- {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
- {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
- {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
- {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
- {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
- {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
- {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
- {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
- {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
- {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
- {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
- {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
- {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
- {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
- {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
- {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
- {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
- {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
- {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
- {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
- {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
- {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
- {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
- {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
-]
-
-[[package]]
-name = "nvidia-cublas-cu12"
-version = "12.1.3.1"
-requires_python = ">=3"
-summary = "CUBLAS native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
- {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
-]
-
-[[package]]
-name = "nvidia-cuda-cupti-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "CUDA profiling tools runtime libs."
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
- {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
-]
-
-[[package]]
-name = "nvidia-cuda-nvrtc-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "NVRTC native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
- {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
-]
-
-[[package]]
-name = "nvidia-cuda-runtime-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "CUDA Runtime native Libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
- {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
-]
-
-[[package]]
-name = "nvidia-cudnn-cu12"
-version = "8.9.2.26"
-requires_python = ">=3"
-summary = "cuDNN runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-cublas-cu12",
-]
-files = [
- {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
-]
-
-[[package]]
-name = "nvidia-cufft-cu12"
-version = "11.0.2.54"
-requires_python = ">=3"
-summary = "CUFFT native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
- {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
-]
-
-[[package]]
-name = "nvidia-curand-cu12"
-version = "10.3.2.106"
-requires_python = ">=3"
-summary = "CURAND native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
- {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
-]
-
-[[package]]
-name = "nvidia-cusolver-cu12"
-version = "11.4.5.107"
-requires_python = ">=3"
-summary = "CUDA solver native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-cublas-cu12",
- "nvidia-cusparse-cu12",
- "nvidia-nvjitlink-cu12",
-]
-files = [
- {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
- {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
-]
-
-[[package]]
-name = "nvidia-cusparse-cu12"
-version = "12.1.0.106"
-requires_python = ">=3"
-summary = "CUSPARSE native runtime libraries"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-dependencies = [
- "nvidia-nvjitlink-cu12",
-]
-files = [
- {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
- {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
-]
-
-[[package]]
-name = "nvidia-nccl-cu12"
-version = "2.20.5"
-requires_python = ">=3"
-summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
- {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
-]
-
-[[package]]
-name = "nvidia-nvjitlink-cu12"
-version = "12.4.127"
-requires_python = ">=3"
-summary = "Nvidia JIT LTO Library"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
- {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
-]
-
-[[package]]
-name = "nvidia-nvtx-cu12"
-version = "12.1.105"
-requires_python = ">=3"
-summary = "NVIDIA Tools Extension"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
-files = [
- {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
- {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
-]
-
-[[package]]
-name = "packaging"
-version = "24.0"
-requires_python = ">=3.7"
-summary = "Core utilities for Python packages"
-groups = ["default", "dev"]
-files = [
- {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
- {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
-]
-
-[[package]]
-name = "pillow"
-version = "10.3.0"
-requires_python = ">=3.8"
-summary = "Python Imaging Library (Fork)"
-groups = ["default"]
-files = [
- {file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
- {file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
- {file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
- {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
- {file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
- {file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
- {file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
- {file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
- {file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
- {file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
- {file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
- {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
- {file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
- {file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
- {file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
- {file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
- {file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
- {file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
- {file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
- {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
- {file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
- {file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
- {file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
- {file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
- {file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
- {file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
- {file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
-]
-
-[[package]]
-name = "platformdirs"
-version = "4.2.1"
-requires_python = ">=3.8"
-summary = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
-groups = ["default"]
-files = [
- {file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"},
- {file = "platformdirs-4.2.1.tar.gz", hash = "sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf"},
-]
-
-[[package]]
-name = "plotly"
-version = "5.22.0"
-requires_python = ">=3.8"
-summary = "An open-source, interactive data visualization library for Python"
-groups = ["default"]
-dependencies = [
- "packaging",
- "tenacity>=6.2.0",
-]
-files = [
- {file = "plotly-5.22.0-py3-none-any.whl", hash = "sha256:68fc1901f098daeb233cc3dd44ec9dc31fb3ca4f4e53189344199c43496ed006"},
- {file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"},
-]
-
-[[package]]
-name = "pluggy"
-version = "1.5.0"
-requires_python = ">=3.8"
-summary = "plugin and hook calling mechanisms for python"
-groups = ["dev"]
-files = [
- {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
- {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
-]
-
-[[package]]
-name = "pre-commit"
-version = "3.7.1"
-requires_python = ">=3.9"
-summary = "A framework for managing and maintaining multi-language pre-commit hooks."
-groups = ["default"]
-dependencies = [
- "cfgv>=2.0.0",
- "identify>=1.0.0",
- "nodeenv>=0.11.1",
- "pyyaml>=5.1",
- "virtualenv>=20.10.0",
-]
-files = [
- {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"},
- {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"},
-]
-
-[[package]]
-name = "pretty-errors"
-version = "1.2.25"
-summary = "Prettifies Python exception output to make it legible."
-groups = ["default"]
-dependencies = [
- "colorama",
-]
-files = [
- {file = "pretty_errors-1.2.25-py3-none-any.whl", hash = "sha256:8ce68ccd99e0f2a099265c8c1f1c23b7c60a15d69bb08816cb336e237d5dc983"},
- {file = "pretty_errors-1.2.25.tar.gz", hash = "sha256:a16ba5c752c87c263bf92f8b4b58624e3b1e29271a9391f564f12b86e93c6755"},
-]
-
-[[package]]
-name = "protobuf"
-version = "4.25.3"
-requires_python = ">=3.8"
-summary = ""
-groups = ["default"]
-marker = "python_version > \"3.9\" or sys_platform != \"linux\""
-files = [
- {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"},
- {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"},
- {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"},
- {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"},
- {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"},
- {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"},
- {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
-]
-
-[[package]]
-name = "psutil"
-version = "5.9.8"
-requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
-summary = "Cross-platform lib for process and system monitoring in Python."
-groups = ["default"]
-files = [
- {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"},
- {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"},
- {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"},
- {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"},
- {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"},
- {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"},
- {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"},
-]
-
-[[package]]
-name = "pyparsing"
-version = "3.1.2"
-requires_python = ">=3.6.8"
-summary = "pyparsing module - Classes and methods to define and execute parsing grammars"
-groups = ["default"]
-files = [
- {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
- {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
-]
-
-[[package]]
-name = "pyproj"
-version = "3.6.1"
-requires_python = ">=3.9"
-summary = "Python interface to PROJ (cartographic projections and coordinate transformations library)"
-groups = ["default"]
-dependencies = [
- "certifi",
-]
-files = [
- {file = "pyproj-3.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ab7aa4d9ff3c3acf60d4b285ccec134167a948df02347585fdd934ebad8811b4"},
- {file = "pyproj-3.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bc0472302919e59114aa140fd7213c2370d848a7249d09704f10f5b062031fe"},
- {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5279586013b8d6582e22b6f9e30c49796966770389a9d5b85e25a4223286cd3f"},
- {file = "pyproj-3.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80fafd1f3eb421694857f254a9bdbacd1eb22fc6c24ca74b136679f376f97d35"},
- {file = "pyproj-3.6.1-cp310-cp310-win32.whl", hash = "sha256:c41e80ddee130450dcb8829af7118f1ab69eaf8169c4bf0ee8d52b72f098dc2f"},
- {file = "pyproj-3.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:db3aedd458e7f7f21d8176f0a1d924f1ae06d725228302b872885a1c34f3119e"},
- {file = "pyproj-3.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ebfbdbd0936e178091309f6cd4fcb4decd9eab12aa513cdd9add89efa3ec2882"},
- {file = "pyproj-3.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:447db19c7efad70ff161e5e46a54ab9cc2399acebb656b6ccf63e4bc4a04b97a"},
- {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7e13c40183884ec7f94eb8e0f622f08f1d5716150b8d7a134de48c6110fee85"},
- {file = "pyproj-3.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65ad699e0c830e2b8565afe42bd58cc972b47d829b2e0e48ad9638386d994915"},
- {file = "pyproj-3.6.1-cp311-cp311-win32.whl", hash = "sha256:8b8acc31fb8702c54625f4d5a2a6543557bec3c28a0ef638778b7ab1d1772132"},
- {file = "pyproj-3.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:38a3361941eb72b82bd9a18f60c78b0df8408416f9340521df442cebfc4306e2"},
- {file = "pyproj-3.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1e9fbaf920f0f9b4ee62aab832be3ae3968f33f24e2e3f7fbb8c6728ef1d9746"},
- {file = "pyproj-3.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d227a865356f225591b6732430b1d1781e946893789a609bb34f59d09b8b0f8"},
- {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83039e5ae04e5afc974f7d25ee0870a80a6bd6b7957c3aca5613ccbe0d3e72bf"},
- {file = "pyproj-3.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb059ba3bced6f6725961ba758649261d85ed6ce670d3e3b0a26e81cf1aa8d"},
- {file = "pyproj-3.6.1-cp312-cp312-win32.whl", hash = "sha256:2d6ff73cc6dbbce3766b6c0bce70ce070193105d8de17aa2470009463682a8eb"},
- {file = "pyproj-3.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:7a27151ddad8e1439ba70c9b4b2b617b290c39395fa9ddb7411ebb0eb86d6fb0"},
- {file = "pyproj-3.6.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd93c1a0c6c4aedc77c0fe275a9f2aba4d59b8acf88cebfc19fe3c430cfabf4f"},
- {file = "pyproj-3.6.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6420ea8e7d2a88cb148b124429fba8cd2e0fae700a2d96eab7083c0928a85110"},
- {file = "pyproj-3.6.1.tar.gz", hash = "sha256:44aa7c704c2b7d8fb3d483bbf75af6cb2350d30a63b144279a09b75fead501bf"},
-]
-
-[[package]]
-name = "pyshp"
-version = "2.3.1"
-requires_python = ">=2.7"
-summary = "Pure Python read/write support for ESRI Shapefile format"
-groups = ["default"]
-files = [
- {file = "pyshp-2.3.1-py2.py3-none-any.whl", hash = "sha256:67024c0ccdc352ba5db777c4e968483782dfa78f8e200672a90d2d30fd8b7b49"},
- {file = "pyshp-2.3.1.tar.gz", hash = "sha256:4caec82fd8dd096feba8217858068bacb2a3b5950f43c048c6dc32a3489d5af1"},
-]
-
-[[package]]
-name = "pytest"
-version = "8.2.0"
-requires_python = ">=3.8"
-summary = "pytest: simple powerful testing with Python"
-groups = ["dev"]
-dependencies = [
- "colorama; sys_platform == \"win32\"",
- "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
- "iniconfig",
- "packaging",
- "pluggy<2.0,>=1.5",
- "tomli>=1; python_version < \"3.11\"",
-]
-files = [
- {file = "pytest-8.2.0-py3-none-any.whl", hash = "sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233"},
- {file = "pytest-8.2.0.tar.gz", hash = "sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f"},
-]
-
-[[package]]
-name = "python-dateutil"
-version = "2.9.0.post0"
-requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
-summary = "Extensions to the standard Python datetime module"
-groups = ["default"]
-dependencies = [
- "six>=1.5",
-]
-files = [
- {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
- {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
-]
-
-[[package]]
-name = "pytorch-lightning"
-version = "2.2.4"
-requires_python = ">=3.8"
-summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
-groups = ["default"]
-dependencies = [
- "PyYAML>=5.4",
- "fsspec[http]>=2022.5.0",
- "lightning-utilities>=0.8.0",
- "numpy>=1.17.2",
- "packaging>=20.0",
- "torch>=1.13.0",
- "torchmetrics>=0.7.0",
- "tqdm>=4.57.0",
- "typing-extensions>=4.4.0",
-]
-files = [
- {file = "pytorch-lightning-2.2.4.tar.gz", hash = "sha256:525b04ebad9900c3e3c2a12b3b462fe4f61ebe11fdb694716c3209f05b9b0fa8"},
- {file = "pytorch_lightning-2.2.4-py3-none-any.whl", hash = "sha256:fd91d47e983a2cd743c5c8c3c3795bbd0f3b69d24be2172a2f9012d930701ff2"},
-]
-
-[[package]]
-name = "pyyaml"
-version = "6.0.1"
-requires_python = ">=3.6"
-summary = "YAML parser and emitter for Python"
-groups = ["default"]
-files = [
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
- {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
- {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
- {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
- {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
- {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
- {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
- {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
- {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
-]
-
-[[package]]
-name = "requests"
-version = "2.31.0"
-requires_python = ">=3.7"
-summary = "Python HTTP for Humans."
-groups = ["default"]
-dependencies = [
- "certifi>=2017.4.17",
- "charset-normalizer<4,>=2",
- "idna<4,>=2.5",
- "urllib3<3,>=1.21.1",
-]
-files = [
- {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
- {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
-]
-
-[[package]]
-name = "scikit-learn"
-version = "1.4.2"
-requires_python = ">=3.9"
-summary = "A set of python modules for machine learning and data mining"
-groups = ["default"]
-dependencies = [
- "joblib>=1.2.0",
- "numpy>=1.19.5",
- "scipy>=1.6.0",
- "threadpoolctl>=2.0.0",
-]
-files = [
- {file = "scikit-learn-1.4.2.tar.gz", hash = "sha256:daa1c471d95bad080c6e44b4946c9390a4842adc3082572c20e4f8884e39e959"},
- {file = "scikit_learn-1.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8539a41b3d6d1af82eb629f9c57f37428ff1481c1e34dddb3b9d7af8ede67ac5"},
- {file = "scikit_learn-1.4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:68b8404841f944a4a1459b07198fa2edd41a82f189b44f3e1d55c104dbc2e40c"},
- {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81bf5d8bbe87643103334032dd82f7419bc8c8d02a763643a6b9a5c7288c5054"},
- {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f0ea5d0f693cb247a073d21a4123bdf4172e470e6d163c12b74cbb1536cf38"},
- {file = "scikit_learn-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:87440e2e188c87db80ea4023440923dccbd56fbc2d557b18ced00fef79da0727"},
- {file = "scikit_learn-1.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45dee87ac5309bb82e3ea633955030df9bbcb8d2cdb30383c6cd483691c546cc"},
- {file = "scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1d0b25d9c651fd050555aadd57431b53d4cf664e749069da77f3d52c5ad14b3b"},
- {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0203c368058ab92efc6168a1507d388d41469c873e96ec220ca8e74079bf62e"},
- {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44c62f2b124848a28fd695db5bc4da019287abf390bfce602ddc8aa1ec186aae"},
- {file = "scikit_learn-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cd7b524115499b18b63f0c96f4224eb885564937a0b3477531b2b63ce331904"},
- {file = "scikit_learn-1.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90378e1747949f90c8f385898fff35d73193dfcaec3dd75d6b542f90c4e89755"},
- {file = "scikit_learn-1.4.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ff4effe5a1d4e8fed260a83a163f7dbf4f6087b54528d8880bab1d1377bd78be"},
- {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:671e2f0c3f2c15409dae4f282a3a619601fa824d2c820e5b608d9d775f91780c"},
- {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36d0bc983336bbc1be22f9b686b50c964f593c8a9a913a792442af9bf4f5e68"},
- {file = "scikit_learn-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:d762070980c17ba3e9a4a1e043ba0518ce4c55152032f1af0ca6f39b376b5928"},
-]
-
-[[package]]
-name = "scipy"
-version = "1.13.0"
-requires_python = ">=3.9"
-summary = "Fundamental algorithms for scientific computing in Python"
-groups = ["default"]
-dependencies = [
- "numpy<2.3,>=1.22.4",
-]
-files = [
- {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
- {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
- {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
- {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
- {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
- {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
- {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
- {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
- {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
- {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
- {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
- {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
- {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
- {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
- {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
- {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
- {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
- {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
- {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
-]
-
-[[package]]
-name = "sentry-sdk"
-version = "2.1.1"
-requires_python = ">=3.6"
-summary = "Python client for Sentry (https://sentry.io)"
-groups = ["default"]
-dependencies = [
- "certifi",
- "urllib3>=1.26.11",
-]
-files = [
- {file = "sentry_sdk-2.1.1-py2.py3-none-any.whl", hash = "sha256:99aeb78fb76771513bd3b2829d12613130152620768d00cd3e45ac00cb17950f"},
- {file = "sentry_sdk-2.1.1.tar.gz", hash = "sha256:95d8c0bb41c8b0bc37ab202c2c4a295bb84398ee05f4cdce55051cd75b926ec1"},
-]
-
-[[package]]
-name = "setproctitle"
-version = "1.3.3"
-requires_python = ">=3.7"
-summary = "A Python module to customize the process title"
-groups = ["default"]
-files = [
- {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"},
- {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbbd6c7de0771c84b4aa30e70b409565eb1fc13627a723ca6be774ed6b9d9fa3"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c05ac48ef16ee013b8a326c63e4610e2430dbec037ec5c5b58fcced550382b74"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1342f4fdb37f89d3e3c1c0a59d6ddbedbde838fff5c51178a7982993d238fe4f"},
- {file = "setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc74e84fdfa96821580fb5e9c0b0777c1c4779434ce16d3d62a9c4d8c710df39"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9617b676b95adb412bb69645d5b077d664b6882bb0d37bfdafbbb1b999568d85"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6a249415f5bb88b5e9e8c4db47f609e0bf0e20a75e8d744ea787f3092ba1f2d0"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:38da436a0aaace9add67b999eb6abe4b84397edf4a78ec28f264e5b4c9d53cd5"},
- {file = "setproctitle-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:da0d57edd4c95bf221b2ebbaa061e65b1788f1544977288bdf95831b6e44e44d"},
- {file = "setproctitle-1.3.3-cp310-cp310-win32.whl", hash = "sha256:a1fcac43918b836ace25f69b1dca8c9395253ad8152b625064415b1d2f9be4fb"},
- {file = "setproctitle-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:200620c3b15388d7f3f97e0ae26599c0c378fdf07ae9ac5a13616e933cbd2086"},
- {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:334f7ed39895d692f753a443102dd5fed180c571eb6a48b2a5b7f5b3564908c8"},
- {file = "setproctitle-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:950f6476d56ff7817a8fed4ab207727fc5260af83481b2a4b125f32844df513a"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:195c961f54a09eb2acabbfc90c413955cf16c6e2f8caa2adbf2237d1019c7dd8"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f05e66746bf9fe6a3397ec246fe481096664a9c97eb3fea6004735a4daf867fd"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5901a31012a40ec913265b64e48c2a4059278d9f4e6be628441482dd13fb8b5"},
- {file = "setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64286f8a995f2cd934082b398fc63fca7d5ffe31f0e27e75b3ca6b4efda4e353"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:184239903bbc6b813b1a8fc86394dc6ca7d20e2ebe6f69f716bec301e4b0199d"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:664698ae0013f986118064b6676d7dcd28fefd0d7d5a5ae9497cbc10cba48fa5"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e5119a211c2e98ff18b9908ba62a3bd0e3fabb02a29277a7232a6fb4b2560aa0"},
- {file = "setproctitle-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417de6b2e214e837827067048f61841f5d7fc27926f2e43954567094051aff18"},
- {file = "setproctitle-1.3.3-cp311-cp311-win32.whl", hash = "sha256:6a143b31d758296dc2f440175f6c8e0b5301ced3b0f477b84ca43cdcf7f2f476"},
- {file = "setproctitle-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a680d62c399fa4b44899094027ec9a1bdaf6f31c650e44183b50d4c4d0ccc085"},
- {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4460795a8a7a391e3567b902ec5bdf6c60a47d791c3b1d27080fc203d11c9dc"},
- {file = "setproctitle-1.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bdfd7254745bb737ca1384dee57e6523651892f0ea2a7344490e9caefcc35e64"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477d3da48e216d7fc04bddab67b0dcde633e19f484a146fd2a34bb0e9dbb4a1e"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ab2900d111e93aff5df9fddc64cf51ca4ef2c9f98702ce26524f1acc5a786ae7"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:088b9efc62d5aa5d6edf6cba1cf0c81f4488b5ce1c0342a8b67ae39d64001120"},
- {file = "setproctitle-1.3.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6d50252377db62d6a0bb82cc898089916457f2db2041e1d03ce7fadd4a07381"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:87e668f9561fd3a457ba189edfc9e37709261287b52293c115ae3487a24b92f6"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:287490eb90e7a0ddd22e74c89a92cc922389daa95babc833c08cf80c84c4df0a"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:4fe1c49486109f72d502f8be569972e27f385fe632bd8895f4730df3c87d5ac8"},
- {file = "setproctitle-1.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4a6ba2494a6449b1f477bd3e67935c2b7b0274f2f6dcd0f7c6aceae10c6c6ba3"},
- {file = "setproctitle-1.3.3-cp312-cp312-win32.whl", hash = "sha256:2df2b67e4b1d7498632e18c56722851ba4db5d6a0c91aaf0fd395111e51cdcf4"},
- {file = "setproctitle-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f38d48abc121263f3b62943f84cbaede05749047e428409c2c199664feb6abc7"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6b9e62ddb3db4b5205c0321dd69a406d8af9ee1693529d144e86bd43bcb4b6c0"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e3b99b338598de0bd6b2643bf8c343cf5ff70db3627af3ca427a5e1a1a90dd9"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ae9a02766dad331deb06855fb7a6ca15daea333b3967e214de12cfae8f0ef5"},
- {file = "setproctitle-1.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:200ede6fd11233085ba9b764eb055a2a191fb4ffb950c68675ac53c874c22e20"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0d3a953c50776751e80fe755a380a64cb14d61e8762bd43041ab3f8cc436092f"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5e08e232b78ba3ac6bc0d23ce9e2bee8fad2be391b7e2da834fc9a45129eb87"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1da82c3e11284da4fcbf54957dafbf0655d2389cd3d54e4eaba636faf6d117a"},
- {file = "setproctitle-1.3.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:aeaa71fb9568ebe9b911ddb490c644fbd2006e8c940f21cb9a1e9425bd709574"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:59335d000c6250c35989394661eb6287187854e94ac79ea22315469ee4f4c244"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3ba57029c9c50ecaf0c92bb127224cc2ea9fda057b5d99d3f348c9ec2855ad3"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d876d355c53d975c2ef9c4f2487c8f83dad6aeaaee1b6571453cb0ee992f55f6"},
- {file = "setproctitle-1.3.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:224602f0939e6fb9d5dd881be1229d485f3257b540f8a900d4271a2c2aa4e5f4"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d7f27e0268af2d7503386e0e6be87fb9b6657afd96f5726b733837121146750d"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5e7266498cd31a4572378c61920af9f6b4676a73c299fce8ba93afd694f8ae7"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33c5609ad51cd99d388e55651b19148ea99727516132fb44680e1f28dd0d1de9"},
- {file = "setproctitle-1.3.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:eae8988e78192fd1a3245a6f4f382390b61bce6cfcc93f3809726e4c885fa68d"},
- {file = "setproctitle-1.3.3.tar.gz", hash = "sha256:c913e151e7ea01567837ff037a23ca8740192880198b7fbb90b16d181607caae"},
-]
-
-[[package]]
-name = "setuptools"
-version = "69.5.1"
-requires_python = ">=3.8"
-summary = "Easily download, build, install, upgrade, and uninstall Python packages"
-groups = ["default"]
-files = [
- {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"},
- {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"},
-]
-
-[[package]]
-name = "shapely"
-version = "2.0.4"
-requires_python = ">=3.7"
-summary = "Manipulation and analysis of geometric objects"
-groups = ["default"]
-dependencies = [
- "numpy<3,>=1.14",
-]
-files = [
- {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:011b77153906030b795791f2fdfa2d68f1a8d7e40bce78b029782ade3afe4f2f"},
- {file = "shapely-2.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9831816a5d34d5170aa9ed32a64982c3d6f4332e7ecfe62dc97767e163cb0b17"},
- {file = "shapely-2.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5c4849916f71dc44e19ed370421518c0d86cf73b26e8656192fcfcda08218fbd"},
- {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:841f93a0e31e4c64d62ea570d81c35de0f6cea224568b2430d832967536308e6"},
- {file = "shapely-2.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b4431f522b277c79c34b65da128029a9955e4481462cbf7ebec23aab61fc58"},
- {file = "shapely-2.0.4-cp310-cp310-win32.whl", hash = "sha256:92a41d936f7d6743f343be265ace93b7c57f5b231e21b9605716f5a47c2879e7"},
- {file = "shapely-2.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:30982f79f21bb0ff7d7d4a4e531e3fcaa39b778584c2ce81a147f95be1cd58c9"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de0205cb21ad5ddaef607cda9a3191eadd1e7a62a756ea3a356369675230ac35"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7d56ce3e2a6a556b59a288771cf9d091470116867e578bebced8bfc4147fbfd7"},
- {file = "shapely-2.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:58b0ecc505bbe49a99551eea3f2e8a9b3b24b3edd2a4de1ac0dc17bc75c9ec07"},
- {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:790a168a808bd00ee42786b8ba883307c0e3684ebb292e0e20009588c426da47"},
- {file = "shapely-2.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4310b5494271e18580d61022c0857eb85d30510d88606fa3b8314790df7f367d"},
- {file = "shapely-2.0.4-cp311-cp311-win32.whl", hash = "sha256:63f3a80daf4f867bd80f5c97fbe03314348ac1b3b70fb1c0ad255a69e3749879"},
- {file = "shapely-2.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:c52ed79f683f721b69a10fb9e3d940a468203f5054927215586c5d49a072de8d"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5bbd974193e2cc274312da16b189b38f5f128410f3377721cadb76b1e8ca5328"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:41388321a73ba1a84edd90d86ecc8bfed55e6a1e51882eafb019f45895ec0f65"},
- {file = "shapely-2.0.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0776c92d584f72f1e584d2e43cfc5542c2f3dd19d53f70df0900fda643f4bae6"},
- {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c75c98380b1ede1cae9a252c6dc247e6279403fae38c77060a5e6186c95073ac"},
- {file = "shapely-2.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3e700abf4a37b7b8b90532fa6ed5c38a9bfc777098bc9fbae5ec8e618ac8f30"},
- {file = "shapely-2.0.4-cp312-cp312-win32.whl", hash = "sha256:4f2ab0faf8188b9f99e6a273b24b97662194160cc8ca17cf9d1fb6f18d7fb93f"},
- {file = "shapely-2.0.4-cp312-cp312-win_amd64.whl", hash = "sha256:03152442d311a5e85ac73b39680dd64a9892fa42bb08fd83b3bab4fe6999bfa0"},
- {file = "shapely-2.0.4.tar.gz", hash = "sha256:5dc736127fac70009b8d309a0eeb74f3e08979e530cf7017f2f507ef62e6cfb8"},
-]
-
-[[package]]
-name = "six"
-version = "1.16.0"
-requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
-summary = "Python 2 and 3 compatibility utilities"
-groups = ["default"]
-files = [
- {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
- {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
-]
-
-[[package]]
-name = "smmap"
-version = "5.0.1"
-requires_python = ">=3.7"
-summary = "A pure Python implementation of a sliding window memory map manager"
-groups = ["default"]
-files = [
- {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
- {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
-]
-
-[[package]]
-name = "sympy"
-version = "1.12"
-requires_python = ">=3.8"
-summary = "Computer algebra system (CAS) in Python"
-groups = ["default"]
-dependencies = [
- "mpmath>=0.19",
-]
-files = [
- {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
- {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
-]
-
-[[package]]
-name = "tbb"
-version = "2021.12.0"
-summary = "Intel® oneAPI Threading Building Blocks (oneTBB)"
-groups = ["default"]
-marker = "platform_system == \"Windows\""
-files = [
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"},
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"},
- {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"},
- {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
-]
-
-[[package]]
-name = "tenacity"
-version = "8.3.0"
-requires_python = ">=3.8"
-summary = "Retry code until it succeeds"
-groups = ["default"]
-files = [
- {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
- {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
-]
-
-[[package]]
-name = "threadpoolctl"
-version = "3.5.0"
-requires_python = ">=3.8"
-summary = "threadpoolctl"
-groups = ["default"]
-files = [
- {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
- {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
-]
-
-[[package]]
-name = "tomli"
-version = "2.0.1"
-requires_python = ">=3.7"
-summary = "A lil' TOML parser"
-groups = ["dev"]
-marker = "python_version < \"3.11\""
-files = [
- {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
- {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
-]
-
-[[package]]
-name = "torch"
-version = "2.3.0"
-requires_python = ">=3.8.0"
-summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
-groups = ["default"]
-dependencies = [
- "filelock",
- "fsspec",
- "jinja2",
- "mkl<=2021.4.0,>=2021.1.1; platform_system == \"Windows\"",
- "networkx",
- "nvidia-cublas-cu12==12.1.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-cupti-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cuda-runtime-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cudnn-cu12==8.9.2.26; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cufft-cu12==11.0.2.54; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-curand-cu12==10.3.2.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cusolver-cu12==11.4.5.107; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-cusparse-cu12==12.1.0.106; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-nccl-cu12==2.20.5; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "nvidia-nvtx-cu12==12.1.105; platform_system == \"Linux\" and platform_machine == \"x86_64\"",
- "sympy",
- "triton==2.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\"",
- "typing-extensions>=4.8.0",
-]
-files = [
- {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"},
- {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"},
- {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"},
- {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"},
- {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"},
- {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"},
- {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"},
- {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"},
- {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"},
- {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"},
- {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"},
- {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"},
-]
-
-[[package]]
-name = "torch-geometric"
-version = "2.5.3"
-requires_python = ">=3.8"
-summary = "Graph Neural Network Library for PyTorch"
-groups = ["default"]
-dependencies = [
- "aiohttp",
- "fsspec",
- "jinja2",
- "numpy",
- "psutil>=5.8.0",
- "pyparsing",
- "requests",
- "scikit-learn",
- "scipy",
- "tqdm",
-]
-files = [
- {file = "torch_geometric-2.5.3-py3-none-any.whl", hash = "sha256:8277abfc12600b0e8047e0c3ea2d55cc43f08c1448e73e924de827c15d0b5f85"},
- {file = "torch_geometric-2.5.3.tar.gz", hash = "sha256:ad0761650c8fa56cdc46ee61c564fd4995f07f079965fe732b3a76d109fd3edc"},
-]
-
-[[package]]
-name = "torchmetrics"
-version = "1.4.0"
-requires_python = ">=3.8"
-summary = "PyTorch native Metrics"
-groups = ["default"]
-dependencies = [
- "lightning-utilities>=0.8.0",
- "numpy>1.20.0",
- "packaging>17.1",
- "pretty-errors==1.2.25",
- "torch>=1.10.0",
-]
-files = [
- {file = "torchmetrics-1.4.0-py3-none-any.whl", hash = "sha256:18599929a0fff7d4b840a3f9a7700054121850c378caaf7206f4161c0a5dc93c"},
- {file = "torchmetrics-1.4.0.tar.gz", hash = "sha256:0b1e5acdcc9beb05bfe369d3d56cfa5b143f060ebfd6079d19ccc59ba46465b3"},
-]
-
-[[package]]
-name = "tqdm"
-version = "4.66.4"
-requires_python = ">=3.7"
-summary = "Fast, Extensible Progress Meter"
-groups = ["default"]
-dependencies = [
- "colorama; platform_system == \"Windows\"",
-]
-files = [
- {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
- {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
-]
-
-[[package]]
-name = "triton"
-version = "2.3.0"
-summary = "A language and compiler for custom Deep Learning operations"
-groups = ["default"]
-marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""
-dependencies = [
- "filelock",
-]
-files = [
- {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
- {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
- {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
-]
-
-[[package]]
-name = "tueplots"
-version = "0.0.15"
-requires_python = ">=3.9"
-summary = "Scientific plotting made easy."
-groups = ["default"]
-dependencies = [
- "matplotlib",
- "numpy",
-]
-files = [
- {file = "tueplots-0.0.15-py3-none-any.whl", hash = "sha256:f63e020af88328c78618f3d912612c75c3c91d21004a88fd12cf79dbd9b6d78a"},
-]
-
-[[package]]
-name = "typing-extensions"
-version = "4.11.0"
-requires_python = ">=3.8"
-summary = "Backported and Experimental Type Hints for Python 3.8+"
-groups = ["default"]
-files = [
- {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
- {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
-]
-
-[[package]]
-name = "urllib3"
-version = "2.2.1"
-requires_python = ">=3.8"
-summary = "HTTP library with thread-safe connection pooling, file post, and more."
-groups = ["default"]
-files = [
- {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
- {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
-]
-
-[[package]]
-name = "virtualenv"
-version = "20.26.1"
-requires_python = ">=3.7"
-summary = "Virtual Python Environment builder"
-groups = ["default"]
-dependencies = [
- "distlib<1,>=0.3.7",
- "filelock<4,>=3.12.2",
- "platformdirs<5,>=3.9.1",
-]
-files = [
- {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"},
- {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"},
-]
-
-[[package]]
-name = "wandb"
-version = "0.17.0"
-requires_python = ">=3.7"
-summary = "A CLI and library for interacting with the Weights & Biases API."
-groups = ["default"]
-dependencies = [
- "click!=8.0.0,>=7.1",
- "docker-pycreds>=0.4.0",
- "gitpython!=3.1.29,>=1.0.0",
- "platformdirs",
- "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"",
- "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"",
- "psutil>=5.0.0",
- "pyyaml",
- "requests<3,>=2.0.0",
- "sentry-sdk>=1.0.0",
- "setproctitle",
- "setuptools",
-]
-files = [
- {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
- {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
- {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
- {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
- {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
- {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
- {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
-]
-
-[[package]]
-name = "yarl"
-version = "1.9.4"
-requires_python = ">=3.7"
-summary = "Yet another URL library"
-groups = ["default"]
-dependencies = [
- "idna>=2.0",
- "multidict>=4.0",
-]
-files = [
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
- {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
- {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
- {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
- {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
- {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
- {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
- {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
- {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
-]
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index b2461c42..00000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,89 +0,0 @@
-[project]
-# PEP 621 project metadata
-# See https://www.python.org/dev/peps/pep-0621/
-dependencies = [
- "numpy>=1.24.2",
- "wandb>=0.13.10",
- "matplotlib>=3.7.0",
- "scipy>=1.10.0",
- "pytorch-lightning>=2.0.3",
- "shapely>=2.0.1",
- "networkx>=3.0",
- "Cartopy>=0.22.0",
- "pyproj>=3.4.1",
- "tueplots>=0.0.8",
- "plotly>=5.15.0",
- "pre-commit>=2.15.0",
- "torch-geometric>=2.5.3",
-]
-requires-python = ">=3.10"
-name = "neural-lam"
-version = "0.1.0"
-description = "Neural Weather Prediction for Limited Area Modeling"
-authors = [
- {name = "Joel Oskarsson", email = "joel.oskarsson@liu.se"},
- {name = "Simon Adamov", email = "simon.adamov@meteoswiss.ch"},
- {name = "Leif Denby", email = "lcd@dmi.dk"},
-]
-readme = "README.md"
-license = {text = "MIT"}
-
-[tool.black]
-line-length = 80
-
-[tool.isort]
-profile = "black"
-
-[tool.flake8]
-max-line-length = 80
-ignore = [
- "E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373)
- "I002", # Don't check for isort configuration
- "W503", # Allow line break before binary operator (PEP 8-compatible)
-]
-per-file-ignores = [
- "__init__.py: F401", # Allow unused imports
-]
-
-[tool.codespell]
-skip = "requirements/*"
-
-# Pylint config
-[tool.pylint]
-ignore = [
- "create_mesh.py", # Disable linting for now, as major rework is planned/expected
-]
-
-[tool.pylint.TYPECHECK]
-generated-members = [
- "numpy.*",
- "torch.*",
-]
-
-[tool.pylint.'MESSAGES CONTROL']
-disable = [
- "C0114", # 'missing-module-docstring', Do not require module docstrings
- "R0901", # 'too-many-ancestors', Allow many layers of sub-classing
- "R0902", # 'too-many-instance-attribtes', Allow many attributes
- "R0913", # 'too-many-arguments', Allow many function arguments
- "R0914", # 'too-many-locals', Allow many local variables
- "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods
-]
-[tool.pylint.DESIGN]
-max-statements=100 # Allow for some more involved functions
-[tool.pylint.IMPORTS]
-allow-any-import-level="neural_lam"
-[tool.pylint.SIMILARITIES]
-min-similarity-lines=10
-
-
-[tool.pdm]
-distribution = true
-
-[tool.pdm.dev-dependencies]
-dev = [
- "pytest>=8.2.0",
-]
-[build-system]
-requires = ["pdm-backend"]
-build-backend = "pdm.backend"
From a91eaaa32c54cf6d753f6d4011837b141b931ee3 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:17:52 +0200
Subject: [PATCH 120/273] add pyproject.toml from main
---
pyproject.toml | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 65 insertions(+)
create mode 100644 pyproject.toml
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..b513a258
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,65 @@
+[tool.black]
+line-length = 80
+
+[tool.isort]
+default_section = "THIRDPARTY"
+profile = "black"
+# Headings
+import_heading_stdlib = "Standard library"
+import_heading_thirdparty = "Third-party"
+import_heading_firstparty = "First-party"
+import_heading_localfolder = "Local"
+# Known modules to avoid misclassification
+known_standard_library = [
+ # Add standard library modules that may be misclassified by isort
+]
+known_third_party = [
+ # Add third-party modules that may be misclassified by isort
+ "wandb",
+]
+known_first_party = [
+ # Add first-party modules that may be misclassified by isort
+ "neural_lam",
+]
+
+[tool.flake8]
+max-line-length = 80
+ignore = [
+ "E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373)
+ "I002", # Don't check for isort configuration
+ "W503", # Allow line break before binary operator (PEP 8-compatible)
+]
+per-file-ignores = [
+ "__init__.py: F401", # Allow unused imports
+]
+
+[tool.codespell]
+skip = "requirements/*"
+
+# Pylint config
+[tool.pylint]
+ignore = [
+ "create_mesh.py", # Disable linting for now, as major rework is planned/expected
+]
+# Temporary fix for import neural_lam statements until set up as proper package
+init-hook='import sys; sys.path.append(".")'
+[tool.pylint.TYPECHECK]
+generated-members = [
+ "numpy.*",
+ "torch.*",
+]
+[tool.pylint.'MESSAGES CONTROL']
+disable = [
+ "C0114", # 'missing-module-docstring', Do not require module docstrings
+ "R0901", # 'too-many-ancestors', Allow many layers of sub-classing
+ "R0902", # 'too-many-instance-attribtes', Allow many attributes
+ "R0913", # 'too-many-arguments', Allow many function arguments
+ "R0914", # 'too-many-locals', Allow many local variables
+ "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods
+]
+[tool.pylint.DESIGN]
+max-statements=100 # Allow for some more involved functions
+[tool.pylint.IMPORTS]
+allow-any-import-level="neural_lam"
+[tool.pylint.SIMILARITIES]
+min-similarity-lines=10
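As an illustration of the isort settings added above (a sketch only; the module and import names are assumed for illustration and are not taken from any patch in this series), a file inside the package would have its imports grouped and headed like this:

# Standard library
import os

# Third-party
import torch
import wandb

# First-party
from neural_lam import utils

# Local
from . import config

Relative imports fall under the "Local" heading, while anything resolving to the neural_lam package itself is grouped as "First-party" via the known_first_party list.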
From 5508ceabf0bb6c8c86bd51aebee80c30d09f51c4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:19:03 +0200
Subject: [PATCH 121/273] clean out tests
---
tests/__init__.py | 5 -----
tests/test_base.py | 13 -------------
2 files changed, 18 deletions(-)
delete mode 100644 tests/__init__.py
delete mode 100644 tests/test_base.py
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index 2f88fa16..00000000
--- a/tests/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import neural_lam
-
-
-def test_import():
- assert neural_lam is not None
diff --git a/tests/test_base.py b/tests/test_base.py
deleted file mode 100644
index 27228cfb..00000000
--- a/tests/test_base.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import neural_lam
-import neural_lam.create_grid_features
-import neural_lam.create_mesh
-import neural_lam.create_parameter_weights
-import neural_lam.train_model
-
-
-def test_import():
- assert neural_lam is not None
- assert neural_lam.create_mesh is not None
- assert neural_lam.create_grid_features is not None
- assert neural_lam.create_parameter_weights is not None
- assert neural_lam.train_model is not None
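The two files removed above were plain import smoke tests. A minimal pytest sketch of the same pattern, shown only for reference (hypothetical; it does not correspond to any file in these patches), would be:

# Standard library
import importlib

# Third-party
import pytest


@pytest.mark.parametrize(
    "module_name",
    ["neural_lam", "neural_lam.train_model"],
)
def test_module_imports(module_name):
    # Importing each module by name is enough to catch syntax errors and broken imports
    module = importlib.import_module(module_name)
    assert module is not None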
From 5c623c315724cd14a60d7f7d11ef640d5ba21c5e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:20:00 +0200
Subject: [PATCH 122/273] fix linting
---
neural_lam/create_grid_features.py | 2 +-
neural_lam/create_mesh.py | 2 +-
neural_lam/create_parameter_weights.py | 2 +-
neural_lam/interaction_net.py | 2 +-
neural_lam/models/ar_model.py | 2 +-
neural_lam/models/base_graph_model.py | 2 +-
neural_lam/models/base_hi_graph_model.py | 2 +-
neural_lam/models/graph_lam.py | 2 +-
neural_lam/models/hi_lam.py | 2 +-
neural_lam/models/hi_lam_parallel.py | 2 +-
neural_lam/train_model.py | 2 +-
neural_lam/vis.py | 2 +-
neural_lam/weather_dataset.py | 2 +-
tests/test_mllam_dataset.py | 3 +--
14 files changed, 14 insertions(+), 15 deletions(-)
diff --git a/neural_lam/create_grid_features.py b/neural_lam/create_grid_features.py
index 4d62fab2..adabd9dc 100644
--- a/neural_lam/create_grid_features.py
+++ b/neural_lam/create_grid_features.py
@@ -6,7 +6,7 @@
import numpy as np
import torch
-# First-party
+# Local
from . import config
diff --git a/neural_lam/create_mesh.py b/neural_lam/create_mesh.py
index 1dbbf90f..40f7ba0e 100644
--- a/neural_lam/create_mesh.py
+++ b/neural_lam/create_mesh.py
@@ -12,7 +12,7 @@
import torch_geometric as pyg
from torch_geometric.utils.convert import from_networkx
-# First-party
+# Local
from . import config
diff --git a/neural_lam/create_parameter_weights.py b/neural_lam/create_parameter_weights.py
index c109c5b2..a33b56b2 100644
--- a/neural_lam/create_parameter_weights.py
+++ b/neural_lam/create_parameter_weights.py
@@ -10,7 +10,7 @@
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
-# First-party
+# Local
from . import config
from .weather_dataset import WeatherDataset
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index a81a3ab4..2f45b03f 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -3,7 +3,7 @@
import torch_geometric as pyg
from torch import nn
-# First-party
+# Local
from . import utils
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 53949fb6..e94de8c6 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -8,7 +8,7 @@
import torch
import wandb
-# First-party
+# Local
from .. import config, metrics, utils, vis
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 77be82eb..99629073 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -1,7 +1,7 @@
# Third-party
import torch
-# First-party
+# Local
from .. import utils
from ..interaction_net import InteractionNet
from .ar_model import ARModel
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 5529ed4b..a2ebcc1b 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -1,7 +1,7 @@
# Third-party
from torch import nn
-# First-party
+# Local
from .. import utils
from ..interaction_net import InteractionNet
from .base_graph_model import BaseGraphModel
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index ff641c20..d73f7ad8 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -1,7 +1,7 @@
# Third-party
import torch_geometric as pyg
-# First-party
+# Local
from .. import utils
from ..interaction_net import InteractionNet
from .base_graph_model import BaseGraphModel
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index df9d3cbb..4f3aec05 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -1,7 +1,7 @@
# Third-party
from torch import nn
-# First-party
+# Local
from ..interaction_net import InteractionNet
from .base_hi_graph_model import BaseHiGraphModel
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index d6dc27ee..b40a9424 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -2,7 +2,7 @@
import torch
import torch_geometric as pyg
-# First-party
+# Local
from ..interaction_net import InteractionNet
from .base_hi_graph_model import BaseHiGraphModel
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index eaf675ed..dd1ad313 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -9,7 +9,7 @@
import torch
from lightning_fabric.utilities import seed
-# First-party
+# Local
from . import config, utils
from .models.graph_lam import GraphLAM
from .models.hi_lam import HiLAM
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index fb4bba96..2f22bef1 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -3,7 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
-# First-party
+# Local
from . import utils
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 8f30407e..29977789 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -7,7 +7,7 @@
import numpy as np
import torch
-# First-party
+# Local
from . import utils
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index cb992db5..3edf4198 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -6,9 +6,8 @@
import pooch
import pytest
-from neural_lam.config import Config
-
# First-party
+from neural_lam.config import Config
from neural_lam.create_mesh import main as create_mesh
from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
From 08ec168c52d6a62774ef84797db70c018f7642f0 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:22:03 +0200
Subject: [PATCH 123/273] add cli entrypoints import test
---
tests/test_cli.py | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
create mode 100644 tests/test_cli.py
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 00000000..e90daa04
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,18 @@
+# First-party
+import neural_lam
+import neural_lam.create_grid_features
+import neural_lam.create_mesh
+import neural_lam.create_parameter_weights
+import neural_lam.train_model
+
+
+def test_import():
+ """
+    This test just ensures that each cli entry-point can be imported for now;
+    eventually we should also test their execution
+ """
+ assert neural_lam is not None
+ assert neural_lam.create_mesh is not None
+ assert neural_lam.create_grid_features is not None
+ assert neural_lam.create_parameter_weights is not None
+ assert neural_lam.train_model is not None
From 3954f04f4c734fe324bca387a68f49f25d4f919a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 16 Jul 2024 23:51:07 +0200
Subject: [PATCH 124/273] tweak cicd pytest execution
---
.github/workflows/run_tests.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 4c677908..810f2b2c 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -35,7 +35,7 @@ jobs:
${{ runner.os }}-meps-reduced-example-data-v0.1.0
- name: Test with pytest
run: |
- pytest -v -s
+ python -m pytest -v -s tests/
- name: Save cache data
uses: actions/cache/save@v4
with:
From db9d96f8379b8c1c053fbf808c522b17e64b5fef Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Jul 2024 10:37:42 +0200
Subject: [PATCH 125/273] Update tests/test_mllam_dataset.py
Co-authored-by: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com>
---
tests/test_mllam_dataset.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index 3edf4198..e12a57ae 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -27,7 +27,7 @@
)
-@pytest.fixture
+@pytest.fixture(scope="module")
def meps_example_reduced_filepath():
# Download and unzip test data into data/meps_example_reduced
pooch.retrieve(
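For context, `scope="module"` makes pytest create the fixture once per test module instead of once per test function, so the pooch download above only runs a single time. A self-contained sketch (hypothetical test file, not part of this patch) of that behaviour:

    import pytest

    _instantiations = []

    @pytest.fixture(scope="module")
    def dataset_path():
        _instantiations.append(1)  # record each time the fixture body runs
        return "data/meps_example_reduced"

    def test_first(dataset_path):
        assert dataset_path.endswith("meps_example_reduced")

    def test_second(dataset_path):
        # the fixture was reused, not re-created, for this second test
        assert len(_instantiations) == 1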
From 3c864b282881f3add6467171841293389c29f080 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Jul 2024 12:06:21 +0200
Subject: [PATCH 126/273] grid-shape ok
---
.../{create_mesh.py => create_graph.py} | 149 +++++---
neural_lam/datastore/__init__.py | 4 +-
neural_lam/datastore/base.py | 207 +++++-----
neural_lam/datastore/mllam.py | 176 +++++----
neural_lam/datastore/multizarr/__init__.py | 3 +-
neural_lam/datastore/multizarr/config.py | 2 +-
.../multizarr/create_auxiliary_forcings.py | 29 +-
.../multizarr}/create_grid_features.py | 0
.../multizarr/create_normalization_stats.py | 28 +-
neural_lam/datastore/multizarr/store.py | 176 +++++----
neural_lam/datastore/npyfiles/__init__.py | 3 +-
neural_lam/datastore/npyfiles/config.py | 9 +-
neural_lam/datastore/npyfiles/store.py | 352 +++++++++++-------
neural_lam/interaction_net.py | 27 +-
neural_lam/metrics.py | 26 +-
neural_lam/models/base_hi_graph_model.py | 29 +-
neural_lam/train_model.py | 2 +-
neural_lam/utils.py | 17 +-
plot_graph.py | 4 +-
tests/datastore_configs/npy/.gitignore | 2 +
tests/datastore_configs/npy/data_config.yaml | 40 ++
tests/test_cli.py | 16 +-
tests/test_datastores.py | 42 +++
tests/test_mllam_dataset.py | 6 +
tests/test_multizarr_dataset.py | 49 ++-
tests/test_npy_forecast_dataset.py | 39 +-
26 files changed, 876 insertions(+), 561 deletions(-)
rename neural_lam/{create_mesh.py => create_graph.py} (89%)
rename neural_lam/{ => datastore/multizarr}/create_grid_features.py (100%)
create mode 100644 tests/datastore_configs/npy/.gitignore
create mode 100644 tests/datastore_configs/npy/data_config.yaml
create mode 100644 tests/test_datastores.py
diff --git a/neural_lam/create_mesh.py b/neural_lam/create_graph.py
similarity index 89%
rename from neural_lam/create_mesh.py
rename to neural_lam/create_graph.py
index 40f7ba0e..de73a9c8 100644
--- a/neural_lam/create_mesh.py
+++ b/neural_lam/create_graph.py
@@ -13,7 +13,11 @@
from torch_geometric.utils.convert import from_networkx
# Local
-from . import config
+# from . import config
+from .datastore.base import BaseCartesianDatastore
+from .datastore.mllam import MLLAMDatastore
+from .datastore.multizarr import MultiZarrDatastore
+from .datastore.npyfiles import NumpyFilesDatastore
def plot_graph(graph, title=None):
@@ -153,49 +157,15 @@ def prepend_node_index(graph, new_index):
return networkx.relabel_nodes(graph, to_mapping, copy=True)
-def main(input_args=None):
- parser = ArgumentParser(description="Graph generation arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- parser.add_argument(
- "--graph",
- type=str,
- default="multiscale",
- help="Name to save graph as (default: multiscale)",
- )
- parser.add_argument(
- "--plot",
- type=int,
- default=0,
- help="If graphs should be plotted during generation "
- "(default: 0 (false))",
- )
- parser.add_argument(
- "--levels",
- type=int,
- help="Limit multi-scale mesh to given number of levels, "
- "from bottom up (default: None (no limit))",
- )
- parser.add_argument(
- "--hierarchical",
- type=int,
- default=0,
- help="Generate hierarchical mesh graph (default: 0, no)",
- )
- args = parser.parse_args(input_args)
-
- # Load grid positions
- config_loader = config.Config.from_file(args.data_config)
- static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
- graph_dir_path = os.path.join("graphs", args.graph)
+def create_graph(
+ graph_dir_path: str,
+ xy: np.ndarray,
+ n_max_levels: int,
+ hierarchical: bool,
+ create_plot: bool,
+):
os.makedirs(graph_dir_path, exist_ok=True)
- xy = np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
-
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
@@ -209,9 +179,9 @@ def main(input_args=None):
nleaf = nx**nlev # leaves at the bottom = nleaf**2
mesh_levels = nlev - 1
- if args.levels:
+ if n_max_levels:
# Limit the levels in mesh graph
- mesh_levels = min(mesh_levels, args.levels)
+ mesh_levels = min(mesh_levels, n_max_levels)
print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}")
@@ -220,13 +190,13 @@ def main(input_args=None):
for lev in range(1, mesh_levels + 1):
n = int(nleaf / (nx**lev))
g = mk_2d_graph(xy, n, n)
- if args.plot:
+ if create_plot:
plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}")
plt.show()
G.append(g)
- if args.hierarchical:
+ if hierarchical:
# Relabel nodes of each level with level index first
G = [
prepend_node_index(graph, level_i)
@@ -299,7 +269,7 @@ def main(input_args=None):
up_graphs.append(pyg_up)
down_graphs.append(pyg_down)
- if args.plot:
+ if create_plot:
plot_graph(
pyg_down, title=f"Down graph, {from_level} -> {to_level}"
)
@@ -365,7 +335,7 @@ def main(input_args=None):
m2m_graphs = [pyg_m2m]
mesh_pos = [pyg_m2m.pos.to(torch.float32)]
- if args.plot:
+ if create_plot:
plot_graph(pyg_m2m, title="Mesh-to-mesh")
plt.show()
@@ -446,7 +416,7 @@ def main(input_args=None):
pyg_g2m = from_networkx(G_g2m)
- if args.plot:
+ if create_plot:
plot_graph(pyg_g2m, title="Grid-to-mesh")
plt.show()
@@ -485,7 +455,7 @@ def main(input_args=None):
)
pyg_m2g = from_networkx(G_m2g_int)
- if args.plot:
+ if create_plot:
plot_graph(pyg_m2g, title="Mesh-to-grid")
plt.show()
@@ -496,5 +466,82 @@ def main(input_args=None):
save_edges(pyg_m2g, "m2g", graph_dir_path)
+DATASTORES = dict(
+ multizarr=MultiZarrDatastore,
+ mllam=MLLAMDatastore,
+ npyfiles=NumpyFilesDatastore,
+)
+
+
+def create_graph_from_datastore(
+ datastore: BaseCartesianDatastore,
+ graph_dir_path: str,
+ n_max_levels: int = None,
+ hierarchical: bool = False,
+ create_plot: bool = False,
+):
+ xy = datastore.get_xy(category="state", stacked=False)
+ create_graph(
+ graph_dir_path=graph_dir_path,
+ xy=xy,
+ n_max_levels=n_max_levels,
+ hierarchical=hierarchical,
+ create_plot=create_plot,
+ )
+
+
+def cli(input_args=None):
+ parser = ArgumentParser(description="Graph generation arguments")
+ parser.add_argument(
+ "datastore",
+ type=str,
+ default="multizarr",
+ choices=DATASTORES.keys(),
+ help="kind of data store to use (default: multizarr)",
+ )
+ parser.add_argument(
+ "datastore-path",
+ type=str,
+ help="path to the data store",
+ )
+ parser.add_argument(
+ "--graph",
+ type=str,
+ default="multiscale",
+ help="Name to save graph as (default: multiscale)",
+ )
+ parser.add_argument(
+ "--plot",
+ type=int,
+ default=0,
+ help="If graphs should be plotted during generation "
+ "(default: 0 (false))",
+ )
+ parser.add_argument(
+ "--levels",
+ type=int,
+ help="Limit multi-scale mesh to given number of levels, "
+ "from bottom up (default: None (no limit))",
+ )
+ parser.add_argument(
+ "--hierarchical",
+ type=int,
+ default=0,
+ help="Generate hierarchical mesh graph (default: 0, no)",
+ )
+ args = parser.parse_args(input_args)
+
+ DatastoreClass = DATASTORES[args.datastore]
+ datastore = DatastoreClass(args.datastore_path)
+
+ create_graph_from_datastore(
+ datastore=datastore,
+ graph_dir_path=os.path.join("graphs", args.graph),
+ n_max_levels=args.levels,
+ hierarchical=args.hierarchical,
+ create_plot=args.plot,
+ )
+
+
if __name__ == "__main__":
- main()
+ cli()
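The new `create_graph_from_datastore` function above also allows graph generation to be driven from Python rather than the CLI. A minimal usage sketch (the config path below is an invented example, not a file shipped with this patch):

    from neural_lam.create_graph import create_graph_from_datastore
    from neural_lam.datastore.mllam import MLLAMDatastore

    # hypothetical mllam_data_prep config; any valid config path would do
    datastore = MLLAMDatastore(config_path="example.danra.yaml")
    create_graph_from_datastore(
        datastore=datastore,
        graph_dir_path="graphs/multiscale",
        hierarchical=False,
        create_plot=False,
    )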
diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py
index c0a77d00..ef20291a 100644
--- a/neural_lam/datastore/__init__.py
+++ b/neural_lam/datastore/__init__.py
@@ -1,2 +1,2 @@
-from .npyfiles.store import NpyConfig
-from .mllam import MLLAMDatastore
\ No newline at end of file
+# Local
+from .mllam import MLLAMDatastore # noqa
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 2c7470fc..81d4e0b8 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -1,31 +1,40 @@
+# Standard library
+import abc
+import dataclasses
+from typing import List, Union
+
+# Third-party
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
-from typing import List, Dict, Union
-import abc
-import dataclasses
-
class BaseDatastore(abc.ABC):
- """
- Base class for weather data used in the neural-lam package. A datastore
+ """Base class for weather data used in the neural-lam package. A datastore
defines the interface for accessing weather data by providing methods to
access the data in a processed format that can be used for training and
evaluation of neural networks.
-
- If `is_ensemble` is True, the dataset is assumed to have an `ensemble_member` dimension.
- If `is_forecast` is True, the dataset is assumed to have a `analysis_time` dimension.
+
+ # Forecast vs analysis data
+ If the datastore should represent forecast rather than analysis data, then
+ the `is_forecast` attribute should be set to True, and returned data from
+ `get_dataarray` is assumed to have `analysis_time` and `forecast_time` dimensions
+ (rather than just `time`).
+
+ # Ensemble vs deterministic data
+ If the datastore should represent ensemble data, then the `is_ensemble`
+ attribute should be set to True, and returned data from `get_dataarray` is
+ assumed to have an `ensemble_member` dimension.
"""
+
is_ensemble: bool = False
is_forecast: bool = False
@property
@abc.abstractmethod
def step_length(self) -> int:
- """
- The step length of the dataset in hours.
-
+ """The step length of the dataset in hours.
+
Returns:
int: The step length in hours.
"""
@@ -33,9 +42,8 @@ def step_length(self) -> int:
@abc.abstractmethod
def get_vars_units(self, category: str) -> List[str]:
- """
- Get the units of the variables in the given category.
-
+ """Get the units of the variables in the given category.
+
Parameters
----------
category : str
@@ -47,12 +55,11 @@ def get_vars_units(self, category: str) -> List[str]:
The units of the variables.
"""
pass
-
+
@abc.abstractmethod
def get_vars_names(self, category: str) -> List[str]:
- """
- Get the names of the variables in the given category.
-
+ """Get the names of the variables in the given category.
+
Parameters
----------
category : str
@@ -64,12 +71,11 @@ def get_vars_names(self, category: str) -> List[str]:
The names of the variables.
"""
pass
-
+
@abc.abstractmethod
def get_num_data_vars(self, category: str) -> int:
- """
- Get the number of data variables in the given category.
-
+ """Get the number of data variables in the given category.
+
Parameters
----------
category : str
@@ -81,19 +87,18 @@ def get_num_data_vars(self, category: str) -> int:
The number of data variables.
"""
pass
-
@abc.abstractmethod
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """
- Return the normalization dataarray for the given category. This should contain
- a `{category}_mean` and `{category}_std` variable for each variable in the category.
- For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
- `state_diff_std` variable for the one-step differences of the state variables. The
- return dataarray should at least have dimensions of `({category}_feature)`, but can
- also include for example `grid_index` (if the normalisation is done per grid point for
- example).
-
+ """Return the normalization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables. The return
+ dataarray should at least have dimensions of `({category}_feature)`,
+ but can also include for example `grid_index` (if the normalisation is
+ done per grid point for example).
+
Parameters
----------
category : str
@@ -102,22 +107,25 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables for the mean
- and standard deviation of the variables (and differences for state variables).
+ The normalization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
"""
pass
-
+
@abc.abstractmethod
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
- """
- Return the processed data (as a single `xr.DataArray`) for the given category and
- test/train/val-split that covers the entire timeline of the dataset.
- The returned dataarray is expected to at minimum have dimensions of `(time, grid_index, {category}_feature)` so
- that any spatial dimensions have been stacked into a single dimension and all variables
- and levels have been stacked into a single feature dimension named by the `category` of data being loaded.
- Any additional dimensions (for example `ensemble_member` or `analysis_time`) should be kept as separate
- dimensions in the dataarray, and `WeatherDataset` will handle the sampling of the data.
-
+ """Return the processed data (as a single `xr.DataArray`) for the given
+ category and test/train/val-split that covers the entire timeline of
+ the dataset. The returned dataarray is expected to at minimum have
+ dimensions of `(time, grid_index, {category}_feature)` so that any
+ spatial dimensions have been stacked into a single dimension and all
+ variables and levels have been stacked into a single feature dimension
+ named by the `category` of data being loaded. Any additional dimensions
+ (for example `ensemble_member` or `analysis_time`) should be kept as
+ separate dimensions in the dataarray, and `WeatherDataset` will handle
+ the sampling of the data.
+
Parameters
----------
category : str
@@ -131,14 +139,13 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
The xarray DataArray object with processed dataset.
"""
pass
-
+
@property
@abc.abstractmethod
def boundary_mask(self):
- """
- Return the boundary mask for the dataset, with spatial dimensions stacked.
- Where the value is 1, the grid point is a boundary point, and where the value is 0,
- the grid point is not a boundary point.
+ """Return the boundary mask for the dataset, with spatial dimensions
+ stacked. Where the value is 1, the grid point is a boundary point, and
+ where the value is 0, the grid point is not a boundary point.
Returns
-------
@@ -146,32 +153,32 @@ def boundary_mask(self):
The boundary mask for the dataset, with dimensions `('grid_index',)`.
"""
pass
-
+
@dataclasses.dataclass
class CartesianGridShape:
- """
- Dataclass to store the shape of a grid.
- """
+ """Dataclass to store the shape of a grid."""
+
x: int
y: int
-
+
class BaseCartesianDatastore(BaseDatastore):
- """
- Base class for weather data stored on a Cartesian grid. In addition
- to the methods and attributes required for weather data in general
- (see `BaseDatastore`) for Cartesian gridded source data each `grid_index`
- coordinate value is assume to have an associated `x` and `y`-value so
- that the processed data-arrays can be reshaped back into into 2D xy-gridded arrays.
+ """Base class for weather data stored on a Cartesian grid. In addition to
+ the methods and attributes required for weather data in general (see
+    `BaseDatastore`), for Cartesian gridded source data each `grid_index`
+    coordinate value is assumed to have an associated `x` and `y`-value so
+    that the processed data-arrays can be reshaped back into 2D xy-gridded
+    arrays.
In addition the following attributes and methods are required:
- `coords_projection` (property): Projection object for the coordinates.
- `grid_shape_state` (property): Shape of the grid for the state variables.
- - `get_xy_extent` (method): Return the extent of the x, y coordinates for a given category of data.
+ - `get_xy_extent` (method): Return the extent of the x, y coordinates for a
+ given category of data.
- `get_xy` (method): Return the x, y coordinates of the dataset.
"""
-
+
CARTESIAN_COORDS = ["y", "x"]
@property
@@ -187,25 +194,24 @@ def coords_projection(self) -> ccrs.Projection:
The projection object.
"""
pass
-
+
@property
@abc.abstractmethod
def grid_shape_state(self) -> CartesianGridShape:
- """
- The shape of the grid for the state variables.
-
+ """The shape of the grid for the state variables.
+
Returns
-------
CartesianGridShape:
- The shape of the grid for the state variables, which has `x` and `y` attributes.
+ The shape of the grid for the state variables, which has `x` and
+ `y` attributes.
"""
pass
-
+
@abc.abstractmethod
def get_xy(self, category: str, stacked: bool) -> np.ndarray:
- """
- Return the x, y coordinates of the dataset.
-
+ """Return the x, y coordinates of the dataset.
+
Parameters
----------
category : str
@@ -215,23 +221,24 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
Returns
-------
- np.ndarray or tuple(np.ndarray, np.ndarray)
- The x, y coordinates of the dataset with shape `(2, N_y, N_x)` if `stacked=True` or
- a tuple of two arrays with shape `((N_y, N_x), (N_y, N_x))` if `stacked=False`.
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on the
+ value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(2, N_y, N_x)`
"""
pass
-
+
def get_xy_extent(self, category: str) -> List[float]:
- """
- Return the extent of the x, y coordinates for a given category of data.
- The extent should be returned as a list of 4 floats with `[xmin, xmax, ymin, ymax]`
- which can then be used to set the extent of a plot.
-
+ """Return the extent of the x, y coordinates for a given category of
+ data. The extent should be returned as a list of 4 floats with `[xmin,
+ xmax, ymin, ymax]` which can then be used to set the extent of a plot.
+
Parameters
----------
category : str
The category of the dataset (state/forcing/static).
-
+
Returns
-------
List[float]
@@ -239,37 +246,43 @@ def get_xy_extent(self, category: str) -> List[float]:
"""
xy = self.get_xy(category, stacked=False)
return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
-
- def unstack_grid_coords(self, da_or_ds: Union[xr.DataArray, xr.Dataset]) -> Union[xr.DataArray, xr.Dataset]:
- """
- Stack the spatial grid coordinates into separate `x` and `y` dimensions (the names
- can be set by the `CARTESIAN_COORDS` attribute) to create a 2D grid.
-
+
+ def unstack_grid_coords(
+ self, da_or_ds: Union[xr.DataArray, xr.Dataset]
+ ) -> Union[xr.DataArray, xr.Dataset]:
+ """Stack the spatial grid coordinates into separate `x` and `y`
+ dimensions (the names can be set by the `CARTESIAN_COORDS` attribute)
+ to create a 2D grid.
+
Parameters
----------
da_or_ds : xr.DataArray or xr.Dataset
The dataarray or dataset to unstack the grid coordinates of.
-
+
Returns
-------
xr.DataArray or xr.Dataset
The dataarray or dataset with the grid coordinates unstacked.
"""
- return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack("grid_index")
-
- def stack_grid_coords(self, da_or_ds: Union[xr.DataArray, xr.Dataset]) -> Union[xr.DataArray, xr.Dataset]:
- """
- Stack the spatial grid coordinated (by default `x` and `y`, but this can be set by the
- `CARTESIAN_COORDS` attribute) into a single `grid_index` dimension.
+ return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack(
+ "grid_index"
+ )
+
+ def stack_grid_coords(
+ self, da_or_ds: Union[xr.DataArray, xr.Dataset]
+ ) -> Union[xr.DataArray, xr.Dataset]:
+ """Stack the spatial grid coordinated (by default `x` and `y`, but this
+ can be set by the `CARTESIAN_COORDS` attribute) into a single
+ `grid_index` dimension.
Parameters
----------
da_or_ds : xr.DataArray or xr.Dataset
The dataarray or dataset to stack the grid coordinates of.
-
+
Returns
-------
xr.DataArray or xr.Dataset
The dataarray or dataset with the grid coordinates stacked.
"""
- return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
\ No newline at end of file
+ return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
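The `stack_grid_coords`/`unstack_grid_coords` helpers above are thin wrappers around xarray's stacking. A toy round-trip (synthetic data, not from the repo) showing the effect:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t2m": (("y", "x"), np.arange(6.0).reshape(2, 3))},
        coords={"y": [0, 1], "x": [10, 20, 30]},
    )
    stacked = ds.stack(grid_index=("y", "x"))   # roughly what stack_grid_coords does
    assert stacked["t2m"].dims == ("grid_index",)
    unstacked = stacked.unstack("grid_index")   # roughly what unstack_grid_coords does
    assert unstacked["t2m"].shape == (2, 3)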
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index f822dd03..38bd8106 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -1,31 +1,33 @@
-from typing import List, Union
+# Standard library
from pathlib import Path
+from typing import List
+# Third-party
+import cartopy.crs as ccrs
+import mllam_data_prep as mdp
+import xarray as xr
from numpy import ndarray
+# Local
from .base import BaseCartesianDatastore, CartesianGridShape
-import mllam_data_prep as mdp
-import xarray as xr
-import cartopy.crs as ccrs
-
class MLLAMDatastore(BaseCartesianDatastore):
- """
- Datastore class for the MLLAM dataset.
- """
+ """Datastore class for the MLLAM dataset."""
def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
- """
- Construct a new MLLAMDatastore from the configuration file at `config_path`. A boundary mask
- is created with `n_boundary_points` boundary points. If `reuse_existing` is True, the dataset
- is loaded from a zarr file if it exists, otherwise it is created from the configuration file.
+ """Construct a new MLLAMDatastore from the configuration file at
+ `config_path`. A boundary mask is created with `n_boundary_points`
+ boundary points. If `reuse_existing` is True, the dataset is loaded
+ from a zarr file if it exists, otherwise it is created from the
+ configuration file.
Parameters
----------
config_path : str
- The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file`
- method to then call `mllam_data_prep.create_dataset` to create the dataset.
+ The path to the configuration file, this will be fed to the
+ `mllam_data_prep.Config.from_yaml_file` method to then call
+ `mllam_data_prep.create_dataset` to create the dataset.
n_boundary_points : int
The number of boundary points to use in the boundary mask.
reuse_existing : bool
@@ -33,7 +35,9 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
"""
self._config_path = Path(config_path)
self._config = mdp.Config.from_yaml_file(config_path)
- fp_ds = self._config_path.parent / self._config_path.name.replace(".yaml", ".zarr")
+ fp_ds = self._config_path.parent / self._config_path.name.replace(
+ ".yaml", ".zarr"
+ )
if reuse_existing and fp_ds.exists():
self._ds = xr.open_zarr(fp_ds, consolidated=True)
else:
@@ -41,38 +45,45 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
if reuse_existing:
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points
-
+
def step_length(self) -> int:
da_dt = self._ds["time"].diff("time")
return da_dt.dt.seconds[0] // 3600
-
+
def get_vars_units(self, category: str) -> List[str]:
return self._ds[f"{category}_unit"].values.tolist()
-
+
def get_vars_names(self, category: str) -> List[str]:
- import ipdb; ipdb.set_trace()
- return self._ds[f"{category}_longname"].values.tolist()
-
+ return self._ds[f"{category}_longname"].values.tolist()
+
def get_num_data_vars(self, category: str) -> int:
return self._ds[f"{category}_feature"].count().item()
-
+
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
da_category = self._ds[category]
-
+
if "time" not in da_category.dims:
return da_category
else:
- t_start = self._ds.splits.sel(split_name=split, split_part="start").load().item()
- t_end = self._ds.splits.sel(split_name=split, split_part="end").load().item()
+ t_start = (
+ self._ds.splits.sel(split_name=split, split_part="start")
+ .load()
+ .item()
+ )
+ t_end = (
+ self._ds.splits.sel(split_name=split, split_part="end")
+ .load()
+ .item()
+ )
return da_category.sel(time=slice(t_start, t_end))
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """
- Return the normalization dataarray for the given category. This should contain
- a `{category}_mean` and `{category}_std` variable for each variable in the category.
- For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
- `state_diff_std` variable for the one-step differences of the state variables.
-
+ """Return the normalization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables.
+
Parameters
----------
category : str
@@ -81,70 +92,81 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables for the mean
- and standard deviation of the variables (and differences for state variables).
+ The normalization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
"""
ops = ["mean", "std"]
split = "train"
stats_variables = {
- f"{category}__{split}__{op}": f"{category}_{op}"
- for op in ops
+ f"{category}__{split}__{op}": f"{category}_{op}" for op in ops
}
if category == "state":
- stats_variables.update({
- f"state__{split}__diff_{op}": f"state_diff_{op}"
- for op in ops
- })
+ stats_variables.update(
+ {f"state__{split}__diff_{op}": f"state_diff_{op}" for op in ops}
+ )
ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
return ds_stats
-
-
+
@property
def boundary_mask(self) -> xr.DataArray:
- """
- Produce a 0/1 mask for the boundary points of the dataset, these will sit at the edges of the
- domain (in x/y extent) and will be used to mask out the boundary points from the loss function
- and to overwrite the boundary points from the prediction. For now this is created when the mask
- is requested, but in the future this could be saved to the zarr file.
+ """Produce a 0/1 mask for the boundary points of the dataset, these
+ will sit at the edges of the domain (in x/y extent) and will be used to
+ mask out the boundary points from the loss function and to overwrite
+ the boundary points from the prediction. For now this is created when
+ the mask is requested, but in the future this could be saved to the
+ zarr file.
Returns
-------
xr.DataArray
- A 0/1 mask for the boundary points of the dataset, where 1 is a boundary point and 0 is not.
+ A 0/1 mask for the boundary points of the dataset, where 1 is a
+ boundary point and 0 is not.
"""
ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
- da_state_variable = ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ da_state_variable = (
+ ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ )
da_domain_allzero = xr.zeros_like(da_state_variable)
- ds_unstacked["boundary_mask"] = da_domain_allzero.isel(x=slice(self._n_boundary_points, -self._n_boundary_points), y=slice(self._n_boundary_points, -self._n_boundary_points))
+ ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
+ x=slice(self._n_boundary_points, -self._n_boundary_points),
+ y=slice(self._n_boundary_points, -self._n_boundary_points),
+ )
ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
-
+
@property
def coords_projection(self) -> ccrs.Projection:
- # TODO: danra doesn't contain projection information yet, but the next version wil
- # for now we hardcode the projection
+ """Return the projection of the coordinates.
+
+ Returns
+ -------
+ ccrs.Projection
+ The projection of the coordinates.
+ """
+ # TODO: danra doesn't contain projection information yet, but the next
+        # version will; for now we hardcode the projection
# XXX: this is wrong
return ccrs.PlateCarree()
-
+
@property
def grid_shape_state(self):
- """
- The shape of the cartesian grid for the state variables.
+ """The shape of the cartesian grid for the state variables.
Returns
-------
CartesianGridShape
The shape of the cartesian grid for the state variables.
"""
- return CartesianGridShape(
- x=self._ds["state"].x.size, y=self._ds["state"].y.size
- )
-
+ ds_state = self.unstack_grid_coords(self._ds["state"])
+ da_x, da_y = ds_state.x, ds_state.y
+ assert da_x.ndim == da_y.ndim == 1
+ return CartesianGridShape(x=da_x.size, y=da_y.size)
+
def get_xy(self, category: str, stacked: bool) -> ndarray:
- """
- Return the x, y coordinates of the dataset.
-
+ """Return the x, y coordinates of the dataset.
+
Parameters
----------
category : str
@@ -154,14 +176,28 @@ def get_xy(self, category: str, stacked: bool) -> ndarray:
Returns
-------
- np.ndarray or tuple(np.ndarray, np.ndarray)
- The x, y coordinates of the dataset with shape `(2, N_y, N_x)` if `stacked=True` or
- a tuple of two arrays with shape `((N_y, N_x), (N_y, N_x))` if `stacked=False`.
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on the
+ value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(2, N_y, N_x)`
"""
- da_x = self._ds[category].x
- da_y = self._ds[category].y
+ # assume variables are stored in dimensions [grid_index, ...]
+ ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category])
+
+ da_xs = ds_category.x
+ da_ys = ds_category.y
+
+ assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D"
+
+ da_x, da_y = xr.broadcast(da_xs, da_ys)
+ da_xy = xr.concat([da_x, da_y], dim="grid_coord")
+
if stacked:
- x, y = xr.broadcast(da_x, da_y)
- return xr.concat([x, y], dim="xy").values
+ da_xy = da_xy.stack(grid_index=("y", "x")).transpose(
+ "grid_coord", "grid_index"
+ )
else:
- return da_x.values, da_y.values
\ No newline at end of file
+ da_xy = da_xy.transpose("grid_coord", "y", "x")
+
+ return da_xy.values
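The `boundary_mask` property above marks the outer ring of the domain as boundary. The same idea in plain NumPy (synthetic sizes, not the datastore's implementation):

    import numpy as np

    def make_boundary_mask(n_y, n_x, n_boundary_points):
        """1 on the outer ring of width n_boundary_points, 0 in the interior."""
        mask = np.ones((n_y, n_x))
        mask[n_boundary_points:-n_boundary_points,
             n_boundary_points:-n_boundary_points] = 0.0
        return mask

    mask = make_boundary_mask(n_y=10, n_x=12, n_boundary_points=3)
    assert mask.sum() == 10 * 12 - 4 * 6  # only the interior 4x6 block is zero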
diff --git a/neural_lam/datastore/multizarr/__init__.py b/neural_lam/datastore/multizarr/__init__.py
index 491d4a18..c1958905 100644
--- a/neural_lam/datastore/multizarr/__init__.py
+++ b/neural_lam/datastore/multizarr/__init__.py
@@ -1 +1,2 @@
-from .store import MultiZarrDatastore
\ No newline at end of file
+# Local
+from .store import MultiZarrDatastore # noqa
diff --git a/neural_lam/datastore/multizarr/config.py b/neural_lam/datastore/multizarr/config.py
index 0d93ab70..3cbd9787 100644
--- a/neural_lam/datastore/multizarr/config.py
+++ b/neural_lam/datastore/multizarr/config.py
@@ -40,4 +40,4 @@ def __getitem__(self, key):
return value
def __contains__(self, key):
- return key in self.values
\ No newline at end of file
+ return key in self.values
diff --git a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
index 9ce15a2a..eab6cd7b 100644
--- a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
+++ b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
@@ -18,15 +18,14 @@ def get_seconds_in_year(year):
def calculate_datetime_forcing(da_time: xr.DataArray):
- """
- Compute the datetime forcing for a given set of timesteps, assuming
- that timesteps is a DataArray with a type of `np.datetime64`.
-
+ """Compute the datetime forcing for a given set of timesteps, assuming that
+ timesteps is a DataArray with a type of `np.datetime64`.
+
Parameters
----------
timesteps : xr.DataArray
The timesteps for which to compute the datetime forcing.
-
+
Returns
-------
xr.Dataset
@@ -72,7 +71,11 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
def main():
"""Main function for creating the datetime forcing and boundary mask."""
parser = argparse.ArgumentParser()
- parser.add_argument("--data-config", type=str, default="tests/datastore_configs/multizarr.danra.yaml")
+ parser.add_argument(
+ "--data-config",
+ type=str,
+ default="tests/datastore_configs/multizarr.danra.yaml",
+ )
parser.add_argument(
"--zarr_path",
type=str,
@@ -81,7 +84,7 @@ def main():
"(default: same directory as the data-config)",
)
args = parser.parse_args()
-
+
zarr_path = args.zarr_path
if zarr_path is None:
zarr_path = Path(args.data_config).parent / "datetime_forcings.zarr"
@@ -89,23 +92,27 @@ def main():
datastore = MultiZarrDatastore(config_path=args.data_config)
da_state = datastore.get_dataarray(category="state", split="train")
- da_datetime_forcing = calculate_datetime_forcing(da_time=da_state.time).expand_dims({"grid_index": da_state.grid_index})
+ da_datetime_forcing = calculate_datetime_forcing(
+ da_time=da_state.time
+ ).expand_dims({"grid_index": da_state.grid_index})
chunking = {"time": 1}
-
+
if "x" in da_state.coords and "y" in da_state.coords:
# copy the x and y coordinates to the datetime forcing
for aux_coord in ["x", "y"]:
da_datetime_forcing.coords[aux_coord] = da_state[aux_coord]
- da_datetime_forcing = da_datetime_forcing.set_index(grid_index=("y", "x")).unstack("grid_index")
+ da_datetime_forcing = da_datetime_forcing.set_index(
+ grid_index=("y", "x")
+ ).unstack("grid_index")
chunking["x"] = -1
chunking["y"] = -1
else:
chunking["grid_index"] = -1
da_datetime_forcing = da_datetime_forcing.chunk(chunking)
-
+
da_datetime_forcing.to_zarr(zarr_path, mode="w")
print(da_datetime_forcing)
print(f"Datetime forcing saved to {zarr_path}")
diff --git a/neural_lam/create_grid_features.py b/neural_lam/datastore/multizarr/create_grid_features.py
similarity index 100%
rename from neural_lam/create_grid_features.py
rename to neural_lam/datastore/multizarr/create_grid_features.py
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index a258fb6d..abccf333 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -7,7 +7,6 @@
# First-party
from neural_lam.datastore.multizarr import MultiZarrDatastore
-
DEFAULT_PATH = "tests/datastore_configs/multizarr.danra.yaml"
@@ -32,7 +31,7 @@ def main():
help="Directory where data is stored",
)
args = parser.parse_args()
-
+
datastore = MultiZarrDatastore(config_path=args.data_config)
da_state = datastore.get_dataarray(category="state", split="train")
@@ -43,29 +42,34 @@ def main():
if da_forcing is not None:
da_forcing_mean, da_forcing_std = compute_stats(da_forcing)
- combined_stats = datastore._config["utilities"]["normalization"]["combined_stats"]
+ combined_stats = datastore._config["utilities"]["normalization"][
+ "combined_stats"
+ ]
if combined_stats is not None:
for group in combined_stats:
vars_to_combine = group["vars"]
- import ipdb; ipdb.set_trace()
means = da_forcing_mean.sel(variable=vars_to_combine)
stds = da_forcing_std.sel(variable=vars_to_combine)
combined_mean = means.mean(dim="variable")
combined_std = (stds**2).mean(dim="variable") ** 0.5
- da_forcing_mean.loc[dict(variable=vars_to_combine)] = combined_mean
- da_forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
+ da_forcing_mean.loc[
+ dict(variable=vars_to_combine)
+ ] = combined_mean
+ da_forcing_std.loc[
+ dict(variable=vars_to_combine)
+ ] = combined_std
window = datastore._config["forcing"]["window"]
- da_forcing_mean = xr.concat([da_forcing_mean] * window, dim="window").stack(
- forcing_variable=("variable", "window")
- )
- da_forcing_std = xr.concat([da_forcing_std] * window, dim="window").stack(
- forcing_variable=("variable", "window")
- )
+ da_forcing_mean = xr.concat(
+ [da_forcing_mean] * window, dim="window"
+ ).stack(forcing_variable=("variable", "window"))
+ da_forcing_std = xr.concat(
+ [da_forcing_std] * window, dim="window"
+ ).stack(forcing_variable=("variable", "window"))
vars = da_forcing["variable"].values.tolist()
window = datastore._config["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 1abd11af..3b7e1fe9 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -1,30 +1,19 @@
+# Standard library
+import functools
+import os
+
+# Third-party
import cartopy.crs as ccrs
import numpy as np
import pandas as pd
import xarray as xr
import yaml
-import functools
-import os
-
-from .config import Config
-from ..base import BaseDatastore
-
+# Local
+from ..base import BaseCartesianDatastore, CartesianGridShape
-def convert_stats_to_torch(stats):
- """Convert the normalization statistics to torch tensors.
- Args:
- stats (xr.Dataset): The normalization statistics.
-
- Returns:
- dict(tensor): The normalization statistics as torch tensors."""
- return {
- var: torch.tensor(stats[var].values, dtype=torch.float32)
- for var in stats.data_vars
- }
-
-class MultiZarrDatastore(BaseDatastore):
+class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable"}
def __init__(self, config_path):
@@ -38,7 +27,8 @@ def open_zarrs(self, category):
category (str): The category of the dataset (state/forcing/static).
Returns:
- xr.Dataset: The xarray Dataset object."""
+ xr.Dataset: The xarray Dataset object.
+ """
zarr_configs = self._config[category]["zarrs"]
datasets = []
@@ -60,7 +50,8 @@ def coords_projection(self):
The projection object is used to plot the coordinates on a map.
Returns:
- cartopy.crs.Projection: The projection object."""
+ cartopy.crs.Projection: The projection object.
+ """
proj_config = self._config["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
@@ -72,7 +63,8 @@ def step_length(self):
"""Return the step length of the dataset in hours.
Returns:
- int: The step length in hours."""
+ int: The step length in hours.
+ """
dataset = self.open_zarrs("state")
time = dataset.time.isel(time=slice(0, 2)).values
step_length_ns = time[1] - time[0]
@@ -87,7 +79,8 @@ def get_vars_names(self, category):
category (str): The category of the dataset (state/forcing/static).
Returns:
- list: The names of the variables in the dataset."""
+ list: The names of the variables in the dataset.
+ """
surface_vars_names = self._config[category].get("surface_vars") or []
atmosphere_vars_names = [
f"{var}_{level}"
@@ -104,7 +97,8 @@ def get_vars_units(self, category):
category (str): The category of the dataset (state/forcing/static).
Returns:
- list: The units of the variables in the dataset."""
+ list: The units of the variables in the dataset.
+ """
surface_vars_units = self._config[category].get("surface_units") or []
atmosphere_vars_units = [
unit
@@ -121,7 +115,8 @@ def get_num_data_vars(self, category):
category (str): The category of the dataset (state/forcing/static).
Returns:
- int: The number of data variables in the dataset."""
+ int: The number of data variables in the dataset.
+ """
surface_vars = self._config[category].get("surface_vars", [])
atmosphere_vars = self._config[category].get("atmosphere_vars", [])
levels = self._config[category].get("levels", [])
@@ -143,7 +138,8 @@ def _stack_grid(self, ds):
ds (xr.Dataset): The xarray Dataset object.
Returns:
- xr.Dataset: The xarray Dataset object with stacked grid dimensions."""
+ xr.Dataset: The xarray Dataset object with stacked grid dimensions.
+ """
if "grid_index" in ds.dims:
raise ValueError("Grid dimensions already stacked.")
else:
@@ -162,7 +158,8 @@ def _convert_dataset_to_dataarray(self, dataset):
dataset (xr.Dataset): The xarray Dataset object.
Returns:
- xr.DataArray: The xarray DataArray object."""
+ xr.DataArray: The xarray DataArray object.
+ """
if isinstance(dataset, xr.Dataset):
dataset = dataset.to_array()
return dataset
@@ -176,7 +173,8 @@ def _filter_dimensions(self, dataset, transpose_array=True):
Returns:
xr.Dataset: The xarray Dataset object with filtered dimensions.
- OR xr.DataArray: The xarray DataArray object with filtered dimensions."""
+ OR xr.DataArray: The xarray DataArray object with filtered dimensions.
+ """
dims_to_keep = self.DIMS_TO_KEEP
dataset_dims = set(list(dataset.dims) + ["variable"])
min_req_dims = dims_to_keep.copy()
@@ -250,7 +248,8 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
grid_shape (dict): The shape of the grid.
Returns:
- xr.Dataset: The xarray Dataset object with reshaped grid dimensions."""
+ xr.Dataset: The xarray Dataset object with reshaped grid dimensions.
+ """
if grid_shape is None:
grid_shape = dict(self.grid_shape_state.values.items())
x_dim, y_dim = (grid_shape["x"], grid_shape["y"])
@@ -274,33 +273,52 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
def get_xy(self, category, stacked=True):
"""Return the x, y coordinates of the dataset.
- Args:
- category (str): The category of the dataset (state/forcing/static).
- stacked (bool): Whether to stack the x, y coordinates.
-
- Returns:
- np.ndarray: The x, y coordinates of the dataset (if stacked) (2, N_y, N_x)
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
- OR tuple(np.ndarray, np.ndarray): The x, y coordinates of the dataset
- (if not stacked) ((N_y, N_x), (N_y, N_x))"""
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on the
+ value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(2, N_y, N_x)`
+ """
dataset = self.open_zarrs(category)
- x, y = dataset.x.values, dataset.y.values
- if x.ndim == 1:
- x, y = np.meshgrid(x, y)
+ xs, ys = dataset.x.values, dataset.y.values
+
+ assert (
+ xs.ndim == ys.ndim
+ ), "x and y coordinates must have the same dimensions."
+
+ if xs.ndim == 1:
+ x, y = np.meshgrid(xs, ys)
+        elif xs.ndim == 2:
+ x, y = xs, ys
+ else:
+ raise ValueError("Invalid dimensions for x, y coordinates.")
+
+ xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
+
if stacked:
- xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
- return xy
- return x, y
+ xy = xy.reshape(2, -1) # (2, n_grid_points)
+
+ return xy
def get_xy_extent(self, category):
- """Return the extent of the x, y coordinates. This should be a list
- of 4 floats with `[xmin, xmax, ymin, ymax]`
+ """Return the extent of the x, y coordinates. This should be a list of
+ 4 floats with `[xmin, xmax, ymin, ymax]`
Args:
category (str): The category of the dataset (state/forcing/static).
Returns:
- list(float): The extent of the x, y coordinates."""
+ list(float): The extent of the x, y coordinates.
+ """
x, y = self.get_xy(category, stacked=False)
if self.projection.inverted:
extent = [x.max(), x.min(), y.max(), y.min()]
@@ -310,7 +328,7 @@ def get_xy_extent(self, category):
return extent
@functools.lru_cache()
- def get_normalization_stats(self, category):
+ def get_normalization_dataarray(self, category):
"""Load the normalization statistics for the dataset.
Args:
@@ -335,7 +353,8 @@ def _load_and_merge_stats(self):
"""Load and merge the normalization statistics for the dataset.
Returns:
- xr.Dataset: The merged normalization statistics for the dataset."""
+ xr.Dataset: The merged normalization statistics for the dataset.
+ """
combined_stats = None
for i, zarr_config in enumerate(
self._config["utilities"]["normalization"]["zarrs"]
@@ -360,7 +379,8 @@ def _rename_data_vars(self, combined_stats):
Returns:
xr.Dataset: The combined normalization statistics with renamed data
- variables."""
+ variables.
+ """
vars_mapping = {}
for zarr_config in self._config["utilities"]["normalization"]["zarrs"]:
vars_mapping.update(zarr_config["stats_vars"])
@@ -381,9 +401,12 @@ def _select_stats_by_category(self, combined_stats, category):
category (str): The category of the dataset (state/forcing/static).
Returns:
- xr.Dataset: The normalization statistics for the dataset."""
+ xr.Dataset: The normalization statistics for the dataset.
+ """
if category == "state":
- stats = combined_stats.loc[dict(variable=self.get_vars_names(category=category))]
+ stats = combined_stats.loc[
+ dict(variable=self.get_vars_names(category=category))
+ ]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
return stats
elif category == "forcing":
@@ -432,14 +455,16 @@ def _extract_vars(self, category, ds=None):
ds = self.open_zarrs(category)
surface_vars = self._config[category].get("surface_vars")
atmoshere_vars = self._config[category].get("atmosphere_vars")
-
+
ds_surface = None
if surface_vars is not None:
ds_surface = ds[surface_vars]
ds_atmosphere = None
if atmoshere_vars is not None:
- ds_atmosphere = self._extract_atmosphere_vars(category=category, ds=ds)
+ ds_atmosphere = self._extract_atmosphere_vars(
+ category=category, ds=ds
+ )
if ds_surface and ds_atmosphere:
return xr.merge([ds_surface, ds_atmosphere])
@@ -458,9 +483,13 @@ def _extract_atmosphere_vars(self, category, ds):
ds (xr.Dataset): The xarray Dataset object.
Returns:
- xr.Dataset: The xarray Dataset object with atmosphere variables."""
+ xr.Dataset: The xarray Dataset object with atmosphere variables.
+ """
- if "level" not in list(ds.dims) and self._config[category]["atmosphere_vars"]:
+ if (
+ "level" not in list(ds.dims)
+ and self._config[category]["atmosphere_vars"]
+ ):
ds = self._rename_dataset_dims_and_vars(
ds.attrs["category"], dataset=ds
)
@@ -488,7 +517,8 @@ def _rename_dataset_dims_and_vars(self, category, dataset=None):
xr.Dataset: The xarray Dataset object with renamed dimensions and
variables.
OR xr.DataArray: The xarray DataArray object with renamed
- dimensions and variables."""
+ dimensions and variables.
+ """
convert = False
if dataset is None:
dataset = self.open_zarrs(category)
@@ -522,7 +552,8 @@ def _apply_time_split(self, dataset, split="train"):
split (str): The time split to filter the dataset.
Returns:
- xr.Dataset: The xarray Dataset object filtered by the time split."""
+ xr.Dataset: The xarray Dataset object filtered by the time split.
+ """
start, end = (
self._config["splits"][split]["start"],
self._config["splits"][split]["end"],
@@ -539,7 +570,8 @@ def apply_window(self, category, dataset=None):
dataset (xr.Dataset): The xarray Dataset object.
Returns:
- xr.Dataset: The xarray Dataset object with the window applied."""
+ xr.Dataset: The xarray Dataset object with the window applied.
+ """
if dataset is None:
dataset = self.open_zarrs(category)
if isinstance(dataset, xr.Dataset):
@@ -559,18 +591,33 @@ def apply_window(self, category, dataset=None):
return dataset
@property
- def boundary_mask(self):
+ def grid_shape_state(self):
+ """Return the shape of the state grid.
+
+ Returns:
+ CartesianGridShape: The shape of the state grid.
"""
- Load the boundary mask for the dataset, with spatial dimensions stacked.
+ return CartesianGridShape(
+ x=self._config["grid_shape_state"]["x"],
+ y=self._config["grid_shape_state"]["y"],
+ )
+
+ @property
+ def boundary_mask(self):
+ """Load the boundary mask for the dataset, with spatial dimensions
+ stacked.
Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `('grid_index',)`.
"""
- ds_boundary_mask = xr.open_zarr(self._config["boundary"]["mask"]["path"])
- return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index("grid_index")
-
+ ds_boundary_mask = xr.open_zarr(
+ self._config["boundary"]["mask"]["path"]
+ )
+ return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index(
+ "grid_index"
+ )
def get_dataarray(self, category, split="train", apply_windowing=True):
"""Process the dataset for the given category.
@@ -581,7 +628,8 @@ def get_dataarray(self, category, split="train", apply_windowing=True):
apply_windowing (bool): Whether to apply windowing to the forcing dataset.
Returns:
- xr.DataArray: The xarray DataArray object with processed dataset."""
+ xr.DataArray: The xarray DataArray object with processed dataset.
+ """
dataset = self.open_zarrs(category)
dataset = self._extract_vars(category, dataset)
if category != "static":
diff --git a/neural_lam/datastore/npyfiles/__init__.py b/neural_lam/datastore/npyfiles/__init__.py
index 57b47049..573b7070 100644
--- a/neural_lam/datastore/npyfiles/__init__.py
+++ b/neural_lam/datastore/npyfiles/__init__.py
@@ -1 +1,2 @@
-from .store import NumpyFilesDatastore
\ No newline at end of file
+# Local
+from .store import NumpyFilesDatastore # noqa
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index 842c4b83..f3fe25ca 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -8,11 +8,10 @@
class NpyConfig:
- """
- Class for loading configuration files.
+ """Class for loading configuration files.
- This class loads a configuration file and provides a way to access its
- values as attributes.
+ This class loads a configuration file and provides a way to access
+ its values as attributes.
"""
def __init__(self, values):
@@ -29,7 +28,7 @@ def from_file(cls, filepath):
def __getattr__(self, name):
child, *children = name.split(".")
-
+
value = self.values[child]
if len(children) > 0:
return self.__class__(values=value).get(".".join(children))
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index f60cc83e..beb860c1 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -1,55 +1,52 @@
# Standard library
-import datetime as dt
-import glob
-import os
import re
from pathlib import Path
from typing import List
# Third-party
+import dask
+import dask.array
import dask.delayed
import numpy as np
-import torch
-from xarray.core.dataarray import DataArray
import parse
-import dask
-import dask.array
+import torch
import xarray as xr
+from xarray.core.dataarray import DataArray
-# First-party
-
-from ..base import BaseCartesianDatastore
+# Local
+from ..base import BaseCartesianDatastore, CartesianGridShape
from .config import NpyConfig
STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
-TOA_SW_DOWN_FLUX_FILENAME_FORMAT = "nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy"
+TOA_SW_DOWN_FLUX_FILENAME_FORMAT = (
+ "nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy"
+)
COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
-
class NumpyFilesDatastore(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
to be stored in a directory structure where each sample is stored in a
separate file. The file-name format is assumed to be '{STATE_FILENAME_FORMAT}'
-
+
The MEPS dataset is organised into three splits: train, val, and test. Each
split has a set of files which are:
- `{STATE_FILENAME_FORMAT}`:
The state variables for a forecast started at `analysis_time` with
- member id `member_id`. The dimensions of the array are
+ member id `member_id`. The dimensions of the array are
`[forecast_timestep, y, x, feature]`.
-
+
- `{TOA_SW_DOWN_FLUX_FILENAME_FORMAT}`:
The top-of-atmosphere downwelling shortwave flux at `time`. The
dimensions of the array are `[forecast_timestep, y, x]`.
-
+
- `{COLUMN_WATER_FILENAME_FORMAT}`:
The column water at `time`. The dimensions of the array are
`[y, x]`.
-
+
Folder structure:
meps_example_reduced
@@ -104,7 +101,7 @@ class NumpyFilesDatastore(BaseCartesianDatastore):
├── parameter_std.pt
├── parameter_weights.npy
└── surface_geopotential.npy
-
+
For the MEPS dataset:
N_t' = 65
N_t = 65//subsample_step (= 21 for 3h steps)
@@ -113,7 +110,7 @@ class NumpyFilesDatastore(BaseCartesianDatastore):
N_grid = 268x238 = 63784
d_features = 17 (d_features' = 18)
d_forcing = 5
-
+
For the MEPS reduced dataset:
N_t' = 65
N_t = 65//subsample_step (= 21 for 3h steps)
@@ -137,13 +134,12 @@ def __init__(
self.root_path = Path(root_path)
self.config = NpyConfig.from_file(self.root_path / "data_config.yaml")
-
+
def get_dataarray(self, category: str, split: str) -> DataArray:
- """
- Get the data array for the given category and split of data. If the category
- is 'state', the data array will be a concatenation of the data arrays for all
- ensemble members. The data will be loaded as a dask array, so that the data
- isn't actually loaded until it's needed.
+ """Get the data array for the given category and split of data. If the
+ category is 'state', the data array will be a concatenation of the data
+ arrays for all ensemble members. The data will be loaded as a dask
+ array, so that the data isn't actually loaded until it's needed.
Parameters
----------
@@ -151,98 +147,134 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
The category of the data to load. One of 'state', 'forcing', or 'static'.
split : str
The dataset split to load the data for. One of 'train', 'val', or 'test'.
-
+
Returns
-------
xr.DataArray
- The data array for the given category and split, with dimensions per category:
- state: `[time, analysis_time, grid_index, feature, ensemble_member]`
- forcing: `[time, analysis_time, grid_index, feature]`
+ The data array for the given category and split, with dimensions
+ per category:
+ state: `[elapsed_forecast_time, analysis_time, grid_index, feature,
+ ensemble_member]`
+ forcing: `[elapsed_forecast_time, analysis_time, grid_index, feature]`
static: `[grid_index, feature]`
"""
if category == "state":
das = []
# for the state category, we need to load all ensemble members
for member in range(self._num_ensemble_members):
- da_member = self._get_single_timeseries_dataarray(features=self.get_vars_names(category="state"), split=split, member=member)
+ da_member = self._get_single_timeseries_dataarray(
+ features=self.get_vars_names(category="state"),
+ split=split,
+ member=member,
+ )
das.append(da_member)
da = xr.concat(das, dim="ensemble_member")
elif category == "forcing":
- # the forcing features are in separate files, so we need to load them separately
+ # the forcing features are in separate files, so we need to load
+ # them separately
features = ["toa_downwelling_shortwave_flux", "column_water"]
- das = [self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features]
+ das = [
+ self._get_single_timeseries_dataarray(
+ features=[feature], split=split
+ )
+ for feature in features
+ ]
da = xr.concat(das, dim="feature")
# add datetime forcing as a feature
- # to do this we create a forecast time variable which has the dimensions of
- # (analysis_time, elapsed_forecast_time) with values that are the actual forecast time of each
- # time step. By calling .chunk({"elapsed_forecast_time": 1}) this time variable is turned into
- # a dask array and so execution of the calculation is delayed until the feature
- # values are actually used.
- da_forecast_time = (da.analysis_time + da.elapsed_forecast_time).chunk({"elapsed_forecast_time": 1})
- da_datetime_forcing_features = self._calc_datetime_forcing_features(da_time=da_forecast_time)
+ # to do this we create a forecast time variable which has the
+ # dimensions of (analysis_time, elapsed_forecast_time) with values
+ # that are the actual forecast time of each time step. By calling
+ # .chunk({"elapsed_forecast_time": 1}) this time variable is turned
+ # into a dask array and so execution of the calculation is delayed
+ # until the feature values are actually used.
+ da_forecast_time = (
+ da.analysis_time + da.elapsed_forecast_time
+ ).chunk({"elapsed_forecast_time": 1})
+ da_datetime_forcing_features = self._calc_datetime_forcing_features(
+ da_time=da_forecast_time
+ )
da = xr.concat([da, da_datetime_forcing_features], dim="feature")
-
+
elif category == "static":
# the static features are collected in three files:
# - surface_geopotential
# - border_mask
# - x, y
das = []
- for features in [["surface_geopotential"], ["border_mask"], ["x", "y"]]:
- da = self._get_single_timeseries_dataarray(features=features, split=split)
+ for features in [
+ ["surface_geopotential"],
+ ["border_mask"],
+ ["x", "y"],
+ ]:
+ da = self._get_single_timeseries_dataarray(
+ features=features, split=split
+ )
das.append(da)
- da = xr.concat(das, dim="feature").transpose("grid_index", "feature")
+ da = xr.concat(das, dim="feature").transpose(
+ "grid_index", "feature"
+ )
else:
raise NotImplementedError(category)
-
+
da = da.rename(dict(feature=f"{category}_feature"))
-
+
# check that we have the right features
actual_features = da[f"{category}_feature"].values.tolist()
expected_features = self.get_vars_names(category=category)
if actual_features != expected_features:
- raise ValueError(f"Expected features {expected_features}, got {actual_features}")
-
+ raise ValueError(
+ f"Expected features {expected_features}, got {actual_features}"
+ )
+
return da
-
- def _get_single_timeseries_dataarray(self, features: List[str], split: str, member: int = None) -> DataArray:
- """
- Get the data array spanning the complete time series for a given set of features and split
- of data. For state features the `member` argument should be specified to select
- the ensemble member to load. The data will be loaded using dask.delayed, so that the data
- isn't actually loaded until it's needed.
+
+ def _get_single_timeseries_dataarray(
+ self, features: List[str], split: str, member: int = None
+ ) -> DataArray:
+ """Get the data array spanning the complete time series for a given set
+ of features and split of data. For state features the `member` argument
+ should be specified to select the ensemble member to load. The data
+ will be loaded using dask.delayed, so that the data isn't actually
+ loaded until it's needed.
Parameters
----------
features : List[str]
- The list of features to load the data for. For the 'state' category, this should be
- the result of `self.get_vars_names(category="state")`, for the 'forcing' category this
- should be the list of forcing features to load, and for the 'static' category this should
- be the list of static features to load.
+ The list of features to load the data for. For the 'state'
+ category, this should be the result of
+ `self.get_vars_names(category="state")`, for the 'forcing' category
+ this should be the list of forcing features to load, and for the
+ 'static' category this should be the list of static features to
+ load.
split : str
The dataset split to load the data for. One of 'train', 'val', or 'test'.
member : int, optional
The ensemble member to load. Only applicable for the 'state' category.
-
+
Returns
-------
xr.DataArray
The data array for the given category and split, with dimensions
- `[elapsed_forecast_time, analysis_time, grid_index, feature]` for all categories of data
+ `[elapsed_forecast_time, analysis_time, grid_index, feature]` for
+ all categories of data
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
-
- if member is not None and features != self.get_vars_names(category="state"):
- raise ValueError("Member can only be specified for the 'state' category")
-
+
+ if member is not None and features != self.get_vars_names(
+ category="state"
+ ):
+ raise ValueError(
+ "Member can only be specified for the 'state' category"
+ )
+
# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
fp_samples = self.root_path / "samples" / split
-
+
file_params = {}
add_feature_dim = False
features_vary_with_analysis_time = True
@@ -264,7 +296,8 @@ def _get_single_timeseries_dataarray(self, features: List[str], split: str, memb
file_dims = ["y", "x", "feature"]
add_feature_dim = True
features_vary_with_analysis_time = False
- # XXX: surface_geopotential is the same for all splits, and so saved in static/
+ # XXX: surface_geopotential is the same for all splits, and so
+ # saved in static/
fp_samples = self.root_path / "static"
elif features == ["border_mask"]:
filename_format = "border_mask.npy"
@@ -280,18 +313,24 @@ def _get_single_timeseries_dataarray(self, features: List[str], split: str, memb
# XXX: x, y are the same for all splits, and so saved in static/
fp_samples = self.root_path / "static"
else:
- raise NotImplementedError(f"Reading of variables set `{features}` not supported")
-
+ raise NotImplementedError(
+ f"Reading of variables set `{features}` not supported"
+ )
+
if features_vary_with_analysis_time:
dims = ["analysis_time"] + file_dims
else:
dims = file_dims
-
+
coords = {}
arr_shape = []
for d in dims:
if d == "elapsed_forecast_time":
- coord_values = self.step_length * np.arange(self._num_timesteps) * np.timedelta64(1, "h")
+ coord_values = (
+ self.step_length
+ * np.arange(self._num_timesteps)
+ * np.timedelta64(1, "h")
+ )
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
elif d == "y":
@@ -302,24 +341,28 @@ def _get_single_timeseries_dataarray(self, features: List[str], split: str, memb
coord_values = features
else:
raise NotImplementedError(f"Dimension {d} not supported")
-
+
print(f"{d}: {len(coord_values)}")
-
+
coords[d] = coord_values
if d != "analysis_time":
- # analysis_time varies across the different files, but not within a single file
+ # analysis_time varies across the different files, but not
+ # within a single file
arr_shape.append(len(coord_values))
-
+
print(f"{features}: {dims=} {file_dims=} {arr_shape=}")
-
+
if features_vary_with_analysis_time:
filepaths = [
- fp_samples / filename_format.format(analysis_time=analysis_time, **file_params)
+ fp_samples
+ / filename_format.format(
+ analysis_time=analysis_time, **file_params
+ )
for analysis_time in coords["analysis_time"]
]
else:
filepaths = [fp_samples / filename_format.format(**file_params)]
-
+
# use dask.delayed to load the numpy files, so that loading isn't
# done until the data is actually needed
@dask.delayed
@@ -331,15 +374,16 @@ def _load_np(fp):
arrays = [
dask.array.from_delayed(
- _load_np(fp), shape=arr_shape, dtype=np.float32
- ) for fp in filepaths
+ _load_np(fp), shape=arr_shape, dtype=np.float32
+ )
+ for fp in filepaths
]
-
+
if features_vary_with_analysis_time:
arr_all = dask.array.stack(arrays, axis=0)
else:
arr_all = arrays[0]
-
+
# if features == ["column_water"]:
# # for column water, we need to repeat the array for each forecast time
# # first insert a new axis for the forecast time
@@ -347,15 +391,14 @@ def _load_np(fp):
# # and then repeat
# arr_all = dask.array.repeat(arr_all, self._num_timesteps, axis=1)
da = xr.DataArray(arr_all, dims=dims, coords=coords)
-
+
# stack the [x, y] dimensions into a `grid_index` dimension
da = self.stack_grid_coords(da)
-
+
return da
-
+
def _get_analysis_times(self, split):
- """
- Get the analysis times for the given split by parsing the filenames
+ """Get the analysis times for the given split by parsing the filenames
of all the files found for the given split.
Parameters
@@ -368,18 +411,18 @@ def _get_analysis_times(self, split):
List[dt.datetime]
The analysis times for the given split.
"""
- pattern = re.sub(r'{analysis_time:[^}]*}', '*', STATE_FILENAME_FORMAT)
- pattern = re.sub(r'{member_id:[^}]*}', '*', pattern)
-
+ pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
+ pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
+
sample_dir = self.root_path / "samples" / split
sample_files = sample_dir.glob(pattern)
times = []
for fp in sample_files:
name_parts = parse.parse(STATE_FILENAME_FORMAT, fp.name)
times.append(name_parts["analysis_time"])
-
+
return times
-
+
def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
da_hour_angle = da_time.dt.hour / 12 * np.pi
da_year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi
@@ -394,8 +437,13 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
dim="feature",
)
da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1]
- da_datetime_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]
-
+ da_datetime_forcing["feature"] = [
+ "sin_hour",
+ "cos_hour",
+ "sin_year",
+ "cos_year",
+ ]
+
return da_datetime_forcing
def get_vars_units(self, category: str) -> torch.List[str]:
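`_calc_datetime_forcing_features` above encodes hour-of-day and day-of-year as sine/cosine pairs and rescales them from [-1, 1] to [0, 1]. A minimal sketch of the same encoding on a synthetic hourly time axis (the time range is arbitrary):

import numpy as np
import pandas as pd
import xarray as xr

# synthetic hourly time axis
da_time = xr.DataArray(
    pd.date_range("2020-01-01", periods=48, freq="h"), dims="time"
)

da_hour_angle = da_time.dt.hour / 12 * np.pi
da_year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi

# concatenate the four cyclic encodings along a new "feature" dimension
da_forcing = xr.concat(
    [
        np.sin(da_hour_angle),
        np.cos(da_hour_angle),
        np.sin(da_year_angle),
        np.cos(da_year_angle),
    ],
    dim="feature",
)
da_forcing = (da_forcing + 1) / 2  # rescale from [-1, 1] to [0, 1]
da_forcing["feature"] = ["sin_hour", "cos_hour", "sin_year", "cos_year"]

print(da_forcing.sizes)  # feature: 4, time: 48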
@@ -403,65 +451,103 @@ def get_vars_units(self, category: str) -> torch.List[str]:
return self.config["dataset"]["var_units"]
else:
raise NotImplementedError(f"Category {category} not supported")
-
+
def get_vars_names(self, category: str) -> torch.List[str]:
if category == "state":
return self.config["dataset"]["var_names"]
elif category == "forcing":
- # XXX: this really shouldn't be hard-coded here, this should be in the config
- return ["toa_downwelling_shortwave_flux", "column_water", "sin_hour", "cos_hour", "sin_year", "cos_year"]
+ # XXX: this really shouldn't be hard-coded here, this should be in
+ # the config
+ return [
+ "toa_downwelling_shortwave_flux",
+ "column_water",
+ "sin_hour",
+ "cos_hour",
+ "sin_year",
+ "cos_year",
+ ]
elif category == "static":
return ["surface_geopotential", "border_mask", "x", "y"]
else:
raise NotImplementedError(f"Category {category} not supported")
-
+
@property
def get_num_data_vars(self) -> int:
return len(self.get_vars_names(category="state"))
-
+
def get_xy(self, category: str, stacked: bool) -> np.ndarray:
- arr = np.load(self.root_path / "static" / "nwp_xy.npy")
-
+ """Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on the
+ value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(2, N_y, N_x)`
+ """
+
+ # the array on disk has shape [2, N_x, N_y], but we want to return it
+ # as [2, N_y, N_x] so we swap the axes
+ arr = np.load(self.root_path / "static" / "nwp_xy.npy").swapaxes(1, 2)
+
assert arr.shape[0] == 2, "Expected 2D array"
- assert arr.shape[1:] == tuple(self.grid_shape_state), "Unexpected shape"
-
+ grid_shape = self.grid_shape_state
+ assert arr.shape[1:] == (grid_shape.y, grid_shape.x), "Unexpected shape"
+
if stacked:
- return arr
+ return arr.reshape(2, -1)
else:
- return arr[0], arr[1]
-
+ return arr
+
@property
def step_length(self):
return self._step_length
-
+
@property
def coords_projection(self):
return self.config.coords_projection
-
+
@property
def grid_shape_state(self):
- return self.config.grid_shape_state
-
+ """The shape of the cartesian grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape
+ The shape of the cartesian grid for the state variables.
+ """
+ nx, ny = self.config.grid_shape_state
+ return CartesianGridShape(x=nx, y=ny)
+
@property
def boundary_mask(self):
xs, ys = self.get_xy(category="state", stacked=False)
- assert np.all(xs[0,:] == xs[-1,:])
- assert np.all(ys[:,0] == ys[:,-1])
- x = xs[0,:]
- y = ys[:,0]
+ assert np.all(xs[0, :] == xs[-1, :])
+ assert np.all(ys[:, 0] == ys[:, -1])
+ x = xs[0, :]
+ y = ys[:, 0]
values = np.load(self.root_path / "static" / "border_mask.npy")
- da_mask = xr.DataArray(values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask")
+ da_mask = xr.DataArray(
+ values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
+ )
da_mask_stacked_xy = self.stack_grid_coords(da_mask)
return da_mask_stacked_xy
-
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """
- Return the normalization dataarray for the given category. This should contain
- a `{category}_mean` and `{category}_std` variable for each variable in the category.
- For `category=="state"`, the dataarray should also contain a `state_diff_mean` and
- `state_diff_std` variable for the one-step differences of the state variables.
-
+ """Return the normalization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables.
+
Parameters
----------
category : str
@@ -470,12 +556,14 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables for the mean
- and standard deviation of the variables (and differences for state variables).
+ The normalization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
"""
+
def load_pickled_tensor(fn):
return torch.load(self.root_path / "static" / fn).numpy()
-
+
mean_diff_values = None
std_diff_values = None
if category == "state":
@@ -493,20 +581,20 @@ def load_pickled_tensor(fn):
else:
raise NotImplementedError(f"Category {category} not supported")
-
+
feature_dim_name = f"{category}_feature"
variables = {
- f"{category}_mean": (feature_dim_name, mean_values),
- f"{category}_std": (feature_dim_name, std_values),
+ f"{category}_mean": (feature_dim_name, mean_values),
+ f"{category}_std": (feature_dim_name, std_values),
}
-
+
if mean_diff_values is not None and std_diff_values is not None:
variables["state_diff_mean"] = (feature_dim_name, mean_diff_values)
variables["state_diff_std"] = (feature_dim_name, std_diff_values)
-
+
ds_norm = xr.Dataset(
variables,
- coords={ feature_dim_name: self.get_vars_names(category=category) }
+ coords={feature_dim_name: self.get_vars_names(category=category)},
)
-
- return ds_norm
\ No newline at end of file
+
+ return ds_norm
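Earlier in this file, `_get_analysis_times` turns `STATE_FILENAME_FORMAT` into a glob pattern with `re.sub` and recovers the timestamp from each matching filename with `parse.parse`. A small sketch of that round trip, assuming a recent version of the `parse` package and an invented filename:

import re
import parse

STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"

# turn the format fields into glob wildcards: nwp_*_mbr*.npy
pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
print(pattern)  # nwp_*_mbr*.npy

# parse a matching filename back into its fields
filename = "nwp_2022043003_mbr000.npy"
result = parse.parse(STATE_FILENAME_FORMAT, filename)
print(result["analysis_time"])  # 2022-04-30 03:00:00
print(result["member_id"])      # 0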
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index 2f45b03f..4ed3e3eb 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -8,9 +8,9 @@
class InteractionNet(pyg.nn.MessagePassing):
- """
- Implementation of a generic Interaction Network,
- from Battaglia et al. (2016)
+ """Implementation of a generic Interaction Network, from Battaglia et al.
+
+ (2016)
"""
# pylint: disable=arguments-differ
@@ -27,8 +27,7 @@ def __init__(
aggr_chunk_sizes=None,
aggr="sum",
):
- """
- Create a new InteractionNet
+ """Create a new InteractionNet.
edge_index: (2,M), Edges in pyg format
input_dim: Dimensionality of input representations,
@@ -84,8 +83,7 @@ def __init__(
self.update_edges = update_edges
def forward(self, send_rep, rec_rep, edge_rep):
- """
- Apply interaction network to update the representations of receiver
+ """Apply interaction network to update the representations of receiver
nodes, and optionally the edge representations.
send_rep: (N_send, d_h), vector representations of sender nodes
@@ -115,9 +113,7 @@ def forward(self, send_rep, rec_rep, edge_rep):
return rec_rep
def message(self, x_j, x_i, edge_attr):
- """
- Compute messages from node j to node i.
- """
+ """Compute messages from node j to node i."""
return self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1))
# pylint: disable-next=signature-differs
@@ -132,10 +128,10 @@ def aggregate(self, inputs, index, ptr, dim_size):
class SplitMLPs(nn.Module):
- """
- Module that feeds chunks of input through different MLPs.
- Split up input along dim -2 using given chunk sizes and feeds
- each chunk through separate MLPs.
+ """Module that feeds chunks of input through different MLPs.
+
+    Splits up the input along dim -2 using the given chunk sizes and feeds
+    each chunk through a separate MLP.
"""
def __init__(self, mlps, chunk_sizes):
@@ -148,8 +144,7 @@ def __init__(self, mlps, chunk_sizes):
self.chunk_sizes = chunk_sizes
def forward(self, x):
- """
- Chunk up input and feed through MLPs
+ """Chunk up input and feed through MLPs.
x: (..., N, d), where N = sum(chunk_sizes)
diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py
index 7db2cca6..1ed4fb08 100644
--- a/neural_lam/metrics.py
+++ b/neural_lam/metrics.py
@@ -3,8 +3,7 @@
def get_metric(metric_name):
- """
- Get a defined metric with given name
+ """Get a defined metric with given name.
metric_name: str, name of the metric
@@ -19,8 +18,7 @@ def get_metric(metric_name):
def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
- """
- Masks and (optionally) reduces entry-wise metric values
+ """Masks and (optionally) reduces entry-wise metric values.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -54,8 +52,7 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """
- Weighted Mean Squared Error
+ """Weighted Mean Squared Error.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -85,8 +82,7 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """
- (Unweighted) Mean Squared Error
+ """(Unweighted) Mean Squared Error.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -109,8 +105,7 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """
- Weighted Mean Absolute Error
+ """Weighted Mean Absolute Error.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -140,8 +135,7 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """
- (Unweighted) Mean Absolute Error
+ """(Unweighted) Mean Absolute Error.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -164,8 +158,7 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """
- Negative Log Likelihood loss, for isotropic Gaussian likelihood
+ """Negative Log Likelihood loss, for isotropic Gaussian likelihood.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -193,9 +186,8 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def crps_gauss(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
):
- """
- (Negative) Continuous Ranked Probability Score (CRPS)
- Closed-form expression based on Gaussian predictive distribution
+ """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form
+ expression based on Gaussian predictive distribution.
(...,) is any number of batch dimensions, potentially different
but broadcastable
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index a2ebcc1b..14827f25 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -8,9 +8,7 @@
class BaseHiGraphModel(BaseGraphModel):
- """
- Base class for hierarchical graph models.
- """
+ """Base class for hierarchical graph models."""
def __init__(self, args):
super().__init__(args)
@@ -98,10 +96,8 @@ def __init__(self, args):
)
def get_num_mesh(self):
- """
- Compute number of mesh nodes from loaded features,
- and number of mesh nodes that should be ignored in encoding/decoding
- """
+ """Compute number of mesh nodes from loaded features, and number of
+ mesh nodes that should be ignored in encoding/decoding."""
num_mesh_nodes = sum(
node_feat.shape[0] for node_feat in self.mesh_static_features
)
@@ -111,18 +107,14 @@ def get_num_mesh(self):
return num_mesh_nodes, num_mesh_nodes_ignore
def embedd_mesh_nodes(self):
- """
- Embed static mesh features
- This embeds only bottom level, rest is done at beginning of
- processing step
- Returns tensor of shape (num_mesh_nodes[0], d_h)
- """
+ """Embed static mesh features This embeds only bottom level, rest is
+ done at beginning of processing step Returns tensor of shape
+ (num_mesh_nodes[0], d_h)"""
return self.mesh_embedders[0](self.mesh_static_features[0])
def process_step(self, mesh_rep):
- """
- Process step of embedd-process-decode framework
- Processes the representation on the mesh, possible in multiple steps
+ """Process step of embedd-process-decode framework Processes the
+ representation on the mesh, possible in multiple steps.
mesh_rep: has shape (B, num_mesh_nodes, d_h)
Returns mesh_rep: (B, num_mesh_nodes, d_h)
@@ -217,9 +209,8 @@ def process_step(self, mesh_rep):
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """
- Internal processor step of hierarchical graph models.
- Between mesh init and read out.
+ """Internal processor step of hierarchical graph models. Between mesh
+ init and read out.
Each input is list with representations, each with shape
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index ee23bed6..af4f001c 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -12,10 +12,10 @@
# Local
from . import utils
-from .datasets import WeatherDataModule
from .models.graph_lam import GraphLAM
from .models.hi_lam import HiLAM
from .models.hi_lam_parallel import HiLAMParallel
+from .weather_dataset import WeatherDataModule
MODELS = {
"graph_lam": GraphLAM,
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 682aa2e3..ba146355 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -9,8 +9,7 @@
class BufferList(nn.Module):
- """
- A list of torch buffer tensors that sit together as a Module with no
+ """A list of torch buffer tensors that sit together as a Module with no
parameters and only buffers.
This should be replaced by a native torch BufferList once implemented.
@@ -34,9 +33,7 @@ def __iter__(self):
def load_graph(graph_name, device="cpu"):
- """
- Load all tensors representing the graph
- """
+ """Load all tensors representing the graph."""
# Define helper lambda function
graph_dir_path = os.path.join("graphs", graph_name)
@@ -173,10 +170,8 @@ def make_mlp(blueprint, layer_norm=True):
def fractional_plot_bundle(fraction):
- """
- Get the tueplots bundle, but with figure width as a fraction of
- the page width.
- """
+ """Get the tueplots bundle, but with figure width as a fraction of the page
+ width."""
# If latex is not available, some visualizations might not render
# correctly, but will at least not raise an error. Alternatively, use
# unicode raised numbers.
@@ -192,9 +187,7 @@ def fractional_plot_bundle(fraction):
def init_wandb_metrics(wandb_logger, val_steps):
- """
- Set up wandb metrics to track
- """
+ """Set up wandb metrics to track."""
experiment = wandb_logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
diff --git a/plot_graph.py b/plot_graph.py
index db4dc536..b7b710bf 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -16,9 +16,7 @@
def main():
- """
- Plot graph structure in 3D using plotly
- """
+ """Plot graph structure in 3D using plotly."""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
"--data_config",
diff --git a/tests/datastore_configs/npy/.gitignore b/tests/datastore_configs/npy/.gitignore
new file mode 100644
index 00000000..718ecfd8
--- /dev/null
+++ b/tests/datastore_configs/npy/.gitignore
@@ -0,0 +1,2 @@
+samples/
+static/
diff --git a/tests/datastore_configs/npy/data_config.yaml b/tests/datastore_configs/npy/data_config.yaml
new file mode 100644
index 00000000..12386bc8
--- /dev/null
+++ b/tests/datastore_configs/npy/data_config.yaml
@@ -0,0 +1,40 @@
+dataset:
+ name: meps_example_reduced
+ var_names:
+ - pres_0g
+ - pres_0s
+ - nlwrs_0
+ - nswrs_0
+ - r_2
+ - r_65
+ - t_2
+ - t_65
+ var_units:
+ - Pa
+ - Pa
+ - "W/m**2"
+ - "W/m**2"
+ - ""
+ - ""
+ - K
+ - K
+ var_longnames:
+ - pres_heightAboveGround_0_instant
+ - pres_heightAboveSea_0_instant
+ - nlwrs_heightAboveGround_0_accum
+ - nswrs_heightAboveGround_0_accum
+ - r_heightAboveGround_2_instant
+ - r_hybrid_65_instant
+ - t_heightAboveGround_2_instant
+ - t_hybrid_65_instant
+ # increased num_forcing_features from 16 to 18 so that it reflects
+ # ["toa_downwelling_shortwave_flux", "column_water", "sin_hour", "cos_hour", "sin_year", "cos_year"] x forcing_window_size
+ # i.e. 6 x 3 = 18 forcing features
+ num_forcing_features: 18
+grid_shape_state: [134, 119]
+projection:
+ class: LambertConformal
+ kwargs:
+ central_longitude: 15.0
+ central_latitude: 63.3
+ standard_parallels: [63.3, 63.3]
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e90daa04..cae8c9b9 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,18 +1,14 @@
# First-party
import neural_lam
-import neural_lam.create_grid_features
-import neural_lam.create_mesh
-import neural_lam.create_parameter_weights
+import neural_lam.create_graph
+import neural_lam.datastore.multizarr.create_grid_features
import neural_lam.train_model
def test_import():
- """
- This test just ensures that each cli entry-point can be imported for now,
- eventually we should test their execution too
- """
+ """This test just ensures that each cli entry-point can be imported for
+ now, eventually we should test their execution too."""
assert neural_lam is not None
- assert neural_lam.create_mesh is not None
- assert neural_lam.create_grid_features is not None
- assert neural_lam.create_parameter_weights is not None
+ assert neural_lam.create_graph is not None
+ assert neural_lam.datastore.multizarr.create_grid_features is not None
assert neural_lam.train_model is not None
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
new file mode 100644
index 00000000..e12b4caa
--- /dev/null
+++ b/tests/test_datastores.py
@@ -0,0 +1,42 @@
+# Third-party
+import pytest
+
+# First-party
+from neural_lam.datastore.mllam import MLLAMDatastore
+from neural_lam.datastore.multizarr import MultiZarrDatastore
+from neural_lam.datastore.npyfiles import NumpyFilesDatastore
+
+DATASTORES = dict(
+ multizarr=MultiZarrDatastore,
+ mllam=MLLAMDatastore,
+ npyfiles=NumpyFilesDatastore,
+)
+
+
+EXAMPLES = dict(
+ multizarr=dict(
+ config_path="tests/datastore_configs/multizarr/data_config.yaml"
+ ),
+ mllam=dict(config_path="tests/datastore_configs/mllam/example.danra.yaml"),
+ npyfiles=dict(root_path="tests/datastore_configs/npy"),
+)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_datastore(datastore_name):
+ DatastoreClass = DATASTORES[datastore_name]
+ datastore = DatastoreClass(**EXAMPLES[datastore_name])
+
+ # check the shapes of the xy grid
+ grid_shape = datastore.grid_shape_state
+ nx, ny = grid_shape.x, grid_shape.y
+ for stacked in [True, False]:
+ xy = datastore.get_xy("static", stacked=stacked)
+ """
+ - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(2, N_y, N_x)`
+ """
+ if stacked:
+ assert xy.shape == (2, nx * ny)
+ else:
+ assert xy.shape == (2, ny, nx)
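The shape checks above mirror how the npyfiles datastore reshapes the grid: the array on disk is `[2, N_x, N_y]`, it is returned unstacked as `(2, N_y, N_x)` and stacked as `(2, n_grid_points)`. A tiny numpy sketch of those reshapes (the grid size is arbitrary):

import numpy as np

N_x, N_y = 5, 4
arr = np.random.rand(2, N_x, N_y)   # as stored on disk: [2, N_x, N_y]

arr_yx = arr.swapaxes(1, 2)          # -> (2, N_y, N_x), unstacked form
arr_stacked = arr_yx.reshape(2, -1)  # -> (2, N_y * N_x), stacked form

assert arr_yx.shape == (2, N_y, N_x)
assert arr_stacked.shape == (2, N_x * N_y)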
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index 9aa335e7..b76faa31 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -2,6 +2,7 @@
import torch
# First-party
+from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.datastore import MLLAMDatastore
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataModule, WeatherDataset
@@ -36,6 +37,11 @@ def test_mllam():
args = ModelArgs()
+ create_graph_from_datastore(
+ datastore=datastore,
+ graph_dir_path="tests/datastore_configs/mllam/graph",
+ )
+
model = GraphLAM( # noqa
args=args,
forcing_window_size=dataset.forcing_window_size,
diff --git a/tests/test_multizarr_dataset.py b/tests/test_multizarr_dataset.py
index 05d2e969..7b51d9df 100644
--- a/tests/test_multizarr_dataset.py
+++ b/tests/test_multizarr_dataset.py
@@ -1,9 +1,12 @@
# Standard library
import os
+from pathlib import Path
# First-party
-from create_mesh import main as create_mesh
+from neural_lam.create_graph import create_graph as create_graph
+from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.datastore.multizarr import MultiZarrDatastore
+
# from neural_lam.datasets.config import Config
from neural_lam.weather_dataset import WeatherDataset
@@ -11,11 +14,13 @@
# and to avoid having to deal with authentication
os.environ["WANDB_DISABLED"] = "true"
+DATASTORE_PATH = Path("tests/datastore_configs/multizarr")
+
def test_load_analysis_dataset():
# TODO: Access rights should be fixed for pooch to work
datastore = MultiZarrDatastore(
- config_path="tests/datastore_configs/multizarr/data_config.yaml"
+ config_path=DATASTORE_PATH / "data_config.yaml"
)
var_state_names = datastore.get_vars_names(category="state")
@@ -29,17 +34,15 @@ def test_load_analysis_dataset():
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
assert len(var_forcing_names) == len(var_forcing_units) == num_forcing_vars
-
- stats = datastore.get_normalization_stats(category="state")
-
- import ipdb
- ipdb.set_trace()
+ stats = datastore.get_normalization_stats(category="state") # noqa
# Assert dataset can be loaded
ds = datastore.get_dataarray(category="state")
grid = ds.sizes["y"] * ds.sizes["x"]
- dataset = WeatherDataset(datastore=datastore, split="train", ar_steps=3, standardize=True)
+ dataset = WeatherDataset(
+ datastore=datastore, split="train", ar_steps=3, standardize=True
+ )
batch = dataset[0]
# return init_states, target_states, forcing, batch_times
# init_states: (2, N_grid, d_features)
@@ -48,29 +51,37 @@ def test_load_analysis_dataset():
# batch_times: (ar_steps-2,)
assert list(batch[0].shape) == [2, grid, num_state_vars]
assert list(batch[1].shape) == [dataset.ar_steps - 2, grid, num_state_vars]
- assert list(batch[2].shape) == [
- dataset.ar_steps - 2,
- grid,
- num_forcing_vars * config.forcing.window,
- ]
+ # assert list(batch[2].shape) == [
+ # dataset.ar_steps - 2,
+ # grid,
+ # num_forcing_vars * config.forcing.window,
+ # ]
assert isinstance(batch[3], list)
# Assert provided grid-shapes
- assert config.get_xy("static")[0].shape == (
- config.grid_shape_state.y,
- config.grid_shape_state.x,
- )
- assert config.get_xy("static")[0].shape == (ds.sizes["y"], ds.sizes["x"])
+ # assert config.get_xy("static")[0].shape == (
+ # config.grid_shape_state.y,
+ # config.grid_shape_state.x,
+ # )
+ # assert config.get_xy("static")[0].shape == (ds.sizes["y"], ds.sizes["x"])
def test_create_graph_analysis_dataset():
+ datastore = MultiZarrDatastore(
+ config_path=DATASTORE_PATH / "data_config.yaml"
+ )
+ create_graph_from_datastore(
+ datastore=datastore, graph_dir_path=DATASTORE_PATH / "graph"
+ )
+
+ # test cli
args = [
"--graph=hierarchical",
"--hierarchical=1",
"--data_config=tests/data_config.yaml",
"--levels=2",
]
- create_mesh(args)
+ create_graph(args)
# def test_train_model_analysis_dataset():
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
index 67c128ed..ed13e286 100644
--- a/tests/test_npy_forecast_dataset.py
+++ b/tests/test_npy_forecast_dataset.py
@@ -6,11 +6,10 @@
import pytest
# First-party
-from create_mesh import main as create_mesh
-from neural_lam.weather_dataset import WeatherDataset
+from neural_lam.create_graph import create_graph as create_graph
from neural_lam.datastore.npyfiles import NumpyFilesDatastore
-from neural_lam.datastore.multizarr import MultiZarrDatastore
-from train_model import main as train_model
+from neural_lam.train_model import main as train_model
+from neural_lam.weather_dataset import WeatherDataset
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
@@ -36,18 +35,20 @@ def ewc_testdata_path():
path="data",
fname="meps_example_reduced.zip",
)
-
+
return "data/meps_example_reduced"
def test_load_reduced_meps_dataset(ewc_testdata_path):
- datastore = NumpyFilesDatastore(
- root_path=ewc_testdata_path
- )
+ datastore = NumpyFilesDatastore(root_path=ewc_testdata_path)
datastore.get_xy(category="state", stacked=True)
- datastore.get_dataarray(category="forcing", split="train").unstack("grid_index")
- datastore.get_dataarray(category="state", split="train").unstack("grid_index")
+ datastore.get_dataarray(category="forcing", split="train").unstack(
+ "grid_index"
+ )
+ datastore.get_dataarray(category="state", split="train").unstack(
+ "grid_index"
+ )
dataset = WeatherDataset(datastore=datastore)
@@ -64,7 +65,9 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
# Hardcoded in model
n_input_steps = 2
- n_forcing_features = datastore.config.values["dataset"]["num_forcing_features"]
+ n_forcing_features = datastore.config.values["dataset"][
+ "num_forcing_features"
+ ]
n_state_features = len(var_names)
n_prediction_timesteps = dataset.ar_steps
@@ -79,7 +82,7 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
init_states = item.init_states
target_states = item.target_states
forcing = item.forcing
-
+
# check that the shapes of the tensors are correct
assert init_states.shape == (n_input_steps, n_grid, n_state_features)
assert target_states.shape == (
@@ -92,12 +95,14 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
n_grid,
n_forcing_features,
)
-
+
ds_state_norm = datastore.get_normalization_dataarray(category="state")
-
+
static_data = {
"border_mask": datastore.boundary_mask.values,
- "grid_static_features": datastore.get_dataarray(category="static", split="train").values,
+ "grid_static_features": datastore.get_dataarray(
+ category="static", split="train"
+ ).values,
"data_mean": ds_state_norm.state_mean.values,
"data_std": ds_state_norm.state_std.values,
"step_diff_mean": ds_state_norm.state_diff_mean.values,
@@ -115,7 +120,7 @@ def test_load_reduced_meps_dataset(ewc_testdata_path):
}
# check the sizes of the props
- assert static_data["border_mask"].shape == (n_grid, )
+ assert static_data["border_mask"].shape == (n_grid,)
assert static_data["grid_static_features"].shape == (
n_grid,
n_grid_static_features,
@@ -136,7 +141,7 @@ def test_create_graph_reduced_meps_dataset():
"--data_config=data/meps_example_reduced/data_config.yaml",
"--levels=2",
]
- create_mesh(args)
+ create_graph(args)
def test_train_model_reduced_meps_dataset():
From 1f54b0e21c026827040449048ac6b60ea7b448ab Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Jul 2024 23:52:16 +0200
Subject: [PATCH 127/273] get_vars_names and units
---
neural_lam/datastore/mllam.py | 4 +-
neural_lam/datastore/npyfiles/store.py | 16 ++++-
pyproject.toml | 1 +
.../mllam/example.danra.yaml | 31 ++++----
tests/test_datastores.py | 71 ++++++++++++++++++-
5 files changed, 100 insertions(+), 23 deletions(-)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 38bd8106..531709fe 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -51,10 +51,10 @@ def step_length(self) -> int:
return da_dt.dt.seconds[0] // 3600
def get_vars_units(self, category: str) -> List[str]:
- return self._ds[f"{category}_unit"].values.tolist()
+ return self._ds[f"{category}_feature_units"].values.tolist()
def get_vars_names(self, category: str) -> List[str]:
- return self._ds[f"{category}_longname"].values.tolist()
+ return self._ds[f"{category}_feature"].values.tolist()
def get_num_data_vars(self, category: str) -> int:
return self._ds[f"{category}_feature"].count().item()
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index beb860c1..8ca4dd49 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -449,6 +449,17 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
def get_vars_units(self, category: str) -> torch.List[str]:
if category == "state":
return self.config["dataset"]["var_units"]
+ elif category == "forcing":
+ return [
+ "W/m^2",
+ "kg/m^2",
+ "1",
+ "1",
+ "1",
+ "1",
+ ]
+ elif category == "static":
+ return ["m^2/s^2", "1", "m", "m"]
else:
raise NotImplementedError(f"Category {category} not supported")
@@ -471,9 +482,8 @@ def get_vars_names(self, category: str) -> torch.List[str]:
else:
raise NotImplementedError(f"Category {category} not supported")
- @property
- def get_num_data_vars(self) -> int:
- return len(self.get_vars_names(category="state"))
+ def get_num_data_vars(self, category: str) -> int:
+ return len(self.get_vars_names(category=category))
def get_xy(self, category: str, stacked: bool) -> np.ndarray:
"""Return the x, y coordinates of the dataset.
diff --git a/pyproject.toml b/pyproject.toml
index f86cf653..6d2ddf71 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
+ "mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep",
]
requires-python = ">=3.9"
diff --git a/tests/datastore_configs/mllam/example.danra.yaml b/tests/datastore_configs/mllam/example.danra.yaml
index 3be8debb..c04b069b 100644
--- a/tests/datastore_configs/mllam/example.danra.yaml
+++ b/tests/datastore_configs/mllam/example.danra.yaml
@@ -13,20 +13,21 @@ output:
step: PT3H
chunking:
time: 1
- splitting_dim: time
- splits:
- train:
- start: 1990-09-03T00:00
- end: 1990-09-06T00:00
- compute_statistics:
- ops: [mean, std]
- dims: [grid_index, time]
- validation:
- start: 1990-09-06T00:00
- end: 1990-09-07T00:00
- test:
- start: 1990-09-07T00:00
- end: 1990-09-09T00:00
+ splitting:
+ dim: time
+ splits:
+ train:
+ start: 1990-09-03T00:00
+ end: 1990-09-06T00:00
+ compute_statistics:
+ ops: [mean, std]
+ dims: [grid_index, time]
+ validation:
+ start: 1990-09-06T00:00
+ end: 1990-09-07T00:00
+ test:
+ start: 1990-09-07T00:00
+ end: 1990-09-09T00:00
inputs:
danra_height_levels:
@@ -59,7 +60,7 @@ inputs:
dims: [time, x, y]
variables:
# shouldn't really be using sea-surface pressure as "forcing", but don't
- # have radiation varibles in danra yet
+ # have radiation variables in danra yet
- pres_seasurface
dim_mapping:
time:
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index e12b4caa..faafe7c8 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -1,4 +1,30 @@
+"""List of methods and attributes that should be implemented in a subclass of
+`BaseCartesianDatastore` (these are all decorated with `@abc.abstractmethod`):
+
+- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
+- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
+- [x] `coords_projection` (property): Projection object for the coordinates.
+- [ ] `get_vars_units` (method): Get the units of the variables in the given category.
+- [ ] `get_vars_names` (method): Get the names of the variables in the given category.
+- [ ] `get_num_data_vars` (method): Get the number of data variables in the
+ given category.
+- [ ] `get_normalization_dataarray` (method): Return the normalization
+ dataarray for the given category.
+- [ ] `get_dataarray` (method): Return the processed data (as a single
+ `xr.DataArray`) for the given category and test/train/val-split.
+- [ ] `boundary_mask` (property): Return the boundary mask for the dataset,
+ with spatial dimensions stacked.
+
+In addition, `BaseCartesianDatastore` must have the following methods and attributes:
+- [ ] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
+ given category of data.
+- [ ] `get_xy` (method): Return the x, y coordinates of the dataset.
+- [ ] `coords_projection` (property): Projection object for the coordinates.
+- [ ] `grid_shape_state` (property): Shape of the grid for the state variables.
+"""
+
# Third-party
+import cartopy.crs as ccrs
import pytest
# First-party
@@ -22,10 +48,17 @@
)
-@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_datastore(datastore_name):
+def _init_datastore(datastore_name):
DatastoreClass = DATASTORES[datastore_name]
- datastore = DatastoreClass(**EXAMPLES[datastore_name])
+ return DatastoreClass(**EXAMPLES[datastore_name])
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_datastore_grid_xy(datastore_name):
+ """Use the `datastore.get_xy` method to get the x, y coordinates of the
+ dataset and check that the shape is correct against the
+ `datastore.grid_shape_state` property."""
+ datastore = _init_datastore(datastore_name)
# check the shapes of the xy grid
grid_shape = datastore.grid_shape_state
@@ -40,3 +73,35 @@ def test_datastore(datastore_name):
assert xy.shape == (2, nx * ny)
else:
assert xy.shape == (2, ny, nx)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_projection(datastore_name):
+ """Check that the `datastore.coords_projection` property is implemented."""
+ datastore = _init_datastore(datastore_name)
+
+ assert isinstance(datastore.coords_projection, ccrs.Projection)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_vars(datastore_name):
+ """Check that results of.
+
+ - `datastore.get_vars_units`
+ - `datastore.get_vars_names`
+ - `datastore.get_num_data_vars`
+
+    are consistent (as in, the number of variables is the same) and that the
+    return types of each are correct.
+ """
+ datastore = _init_datastore(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ units = datastore.get_vars_units(category)
+ names = datastore.get_vars_names(category)
+ num_vars = datastore.get_num_data_vars(category)
+
+ assert len(units) == len(names) == num_vars
+ assert isinstance(units, list)
+ assert isinstance(names, list)
+ assert isinstance(num_vars, int)
From 9b8816096532e1501217afbcbb9704df17c34f4b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 17 Jul 2024 23:53:15 +0200
Subject: [PATCH 128/273] get_vars_names and units 2
---
tests/test_datastores.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index faafe7c8..2685721d 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -4,9 +4,9 @@
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
-- [ ] `get_vars_units` (method): Get the units of the variables in the given category.
-- [ ] `get_vars_names` (method): Get the names of the variables in the given category.
-- [ ] `get_num_data_vars` (method): Get the number of data variables in the
+- [x] `get_vars_units` (method): Get the units of the variables in the given category.
+- [x] `get_vars_names` (method): Get the names of the variables in the given category.
+- [x] `get_num_data_vars` (method): Get the number of data variables in the
given category.
- [ ] `get_normalization_dataarray` (method): Return the normalization
dataarray for the given category.
From a9fdad544ec2cabb2c65cc183dfa6cfb032df7b8 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 23 Jul 2024 17:48:43 +0200
Subject: [PATCH 129/273] test for stats
---
neural_lam/datastore/multizarr/config.py | 6 +-
.../multizarr/create_auxiliary_forcings.py | 9 ++-
.../multizarr/create_grid_features.py | 61 -------------------
.../multizarr/create_normalization_stats.py | 21 ++++---
neural_lam/datastore/multizarr/store.py | 18 +++++-
tests/datastore_configs/mllam/.gitignore | 2 +
.../mllam/example.danra.yaml | 2 +-
.../multizarr/data_config.yaml | 6 +-
tests/test_cli.py | 2 -
tests/test_datastores.py | 28 +++++++++
10 files changed, 69 insertions(+), 86 deletions(-)
delete mode 100644 neural_lam/datastore/multizarr/create_grid_features.py
create mode 100644 tests/datastore_configs/mllam/.gitignore
diff --git a/neural_lam/datastore/multizarr/config.py b/neural_lam/datastore/multizarr/config.py
index 3cbd9787..1f0a1def 100644
--- a/neural_lam/datastore/multizarr/config.py
+++ b/neural_lam/datastore/multizarr/config.py
@@ -25,10 +25,10 @@ def __getattr__(self, name):
keys = name.split(".")
value = self.values
for key in keys:
- if key in value:
+ try:
value = value[key]
- else:
- return None
+ except KeyError:
+ raise AttributeError(f"Key '{key}' not found in {value}")
if isinstance(value, dict):
return Config(values=value)
return value
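With this change, a missing key raises `AttributeError` instead of silently returning `None`. A minimal, standalone sketch of the dotted-key access pattern the `Config` class implements (this toy class and the example values are illustrative, not the repository's):

class Config:
    """Minimal dotted-key accessor over a nested dict."""

    def __init__(self, values):
        self.values = values

    def __getattr__(self, name):
        keys = name.split(".")
        value = self.values
        for key in keys:
            try:
                value = value[key]
            except KeyError:
                raise AttributeError(f"Key '{key}' not found in {value}")
        if isinstance(value, dict):
            # wrap nested dicts so attribute access can be chained
            return Config(values=value)
        return value


cfg = Config({"dataset": {"name": "danra", "num_forcing_features": 18}})
print(cfg.dataset.name)                  # danra
print(cfg.dataset.num_forcing_features)  # 18
# cfg.dataset.missing_key  -> raises AttributeError instead of returning None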
diff --git a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
index eab6cd7b..c4839be3 100644
--- a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
+++ b/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
@@ -70,11 +70,14 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
def main():
"""Main function for creating the datetime forcing and boundary mask."""
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(
+ description="Create the datetime forcing for neural LAM.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
parser.add_argument(
- "--data-config",
+ "data_config",
type=str,
- default="tests/datastore_configs/multizarr.danra.yaml",
+ help="Path to data config file",
)
parser.add_argument(
"--zarr_path",
diff --git a/neural_lam/datastore/multizarr/create_grid_features.py b/neural_lam/datastore/multizarr/create_grid_features.py
deleted file mode 100644
index 69fea730..00000000
--- a/neural_lam/datastore/multizarr/create_grid_features.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Standard library
-import os
-from argparse import ArgumentParser
-
-# Third-party
-import numpy as np
-import torch
-
-# Local
-from . import config
-
-
-def main():
- """Pre-compute all static features related to the grid nodes."""
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
-
- static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
-
- # -- Static grid node features --
- grid_xy = torch.tensor(
- np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
- ) # (2, N_y, N_x)
- grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
- pos_max = torch.max(torch.abs(grid_xy))
- grid_xy = grid_xy / pos_max # Divide by maximum coordinate
-
- geopotential = torch.tensor(
- np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
- ) # (N_y, N_x)
- geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1)
- gp_min = torch.min(geopotential)
- gp_max = torch.max(geopotential)
- # Rescale geopotential to [0,1]
- geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1)
-
- grid_border_mask = torch.tensor(
- np.load(os.path.join(static_dir_path, "border_mask.npy")),
- dtype=torch.int64,
- ) # (N_y, N_x)
- grid_border_mask = (
- grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
- ) # (N_grid, 1)
-
- # Concatenate grid features
- grid_features = torch.cat(
- (grid_xy, geopotential, grid_border_mask), dim=1
- ) # (N_grid, 4)
-
- torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))
-
-
-if __name__ == "__main__":
- main()
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index abccf333..2298e191 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -1,5 +1,5 @@
# Standard library
-from argparse import ArgumentParser
+import argparse
# Third-party
import xarray as xr
@@ -7,8 +7,6 @@
# First-party
from neural_lam.datastore.multizarr import MultiZarrDatastore
-DEFAULT_PATH = "tests/datastore_configs/multizarr.danra.yaml"
-
def compute_stats(da):
mean = da.mean(dim=("time", "grid_index"))
@@ -17,17 +15,19 @@ def compute_stats(da):
def main():
- parser = ArgumentParser(description="Training arguments")
+ parser = argparse.ArgumentParser(
+ description="Training arguments",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
parser.add_argument(
- "--data_config",
+ "data_config",
type=str,
- default=DEFAULT_PATH,
- help=f"Path to data config file (default: {DEFAULT_PATH})",
+ help="Path to data config file",
)
parser.add_argument(
"--zarr_path",
type=str,
- default="data/normalization.zarr",
+ default="normalization.zarr",
help="Directory where data is stored",
)
args = parser.parse_args()
@@ -49,6 +49,7 @@ def main():
if combined_stats is not None:
for group in combined_stats:
vars_to_combine = group["vars"]
+
means = da_forcing_mean.sel(variable=vars_to_combine)
stds = da_forcing_std.sel(variable=vars_to_combine)
@@ -85,8 +86,8 @@ def main():
{
"state_mean": da_state_mean,
"state_std": da_state_std,
- "diff_mean": diff_mean,
- "diff_std": diff_std,
+ "state_diff_mean": diff_mean,
+ "state_diff_std": diff_std,
}
)
if da_forcing is not None:
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 3b7e1fe9..37993be5 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -17,9 +17,19 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable"}
def __init__(self, config_path):
+ self.config_path = config_path
with open(config_path, encoding="utf-8", mode="r") as file:
self._config = yaml.safe_load(file)
+ def _normalize_path(self, path):
+ # try to parse path to see if it defines a protocol, e.g. s3://
+ if "://" in path or path.startswith("/"):
+ pass
+ else:
+ # assume path is relative to config file
+ path = os.path.join(os.path.dirname(self.config_path), path)
+ return path
+
def open_zarrs(self, category):
"""Open the zarr dataset for the given category.
@@ -33,7 +43,8 @@ def open_zarrs(self, category):
datasets = []
for config in zarr_configs:
- dataset_path = config["path"]
+ dataset_path = self._normalize_path(config["path"])
+
try:
dataset = xr.open_zarr(dataset_path, consolidated=True)
except Exception as e:
@@ -359,7 +370,7 @@ def _load_and_merge_stats(self):
for i, zarr_config in enumerate(
self._config["utilities"]["normalization"]["zarrs"]
):
- stats_path = zarr_config["path"]
+ stats_path = self._normalize_path(zarr_config["path"])
if not os.path.exists(stats_path):
raise FileNotFoundError(
f"Normalization statistics not found at path: {stats_path}"
@@ -612,9 +623,10 @@ def boundary_mask(self):
xr.DataArray
The boundary mask for the dataset, with dimensions `('grid_index',)`.
"""
- ds_boundary_mask = xr.open_zarr(
+ boundary_mask_path = self._normalize_path(
self._config["boundary"]["mask"]["path"]
)
+ ds_boundary_mask = xr.open_zarr(boundary_mask_path)
return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index(
"grid_index"
)
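The `_normalize_path` helper added above keeps absolute paths and paths with a protocol prefix (e.g. `s3://`) unchanged, and resolves everything else relative to the config file's directory. A small sketch of that logic as a free function (the function name and example paths are made up):

import os

def normalize_path(config_path, path):
    # keep paths with a protocol (e.g. s3://) or absolute paths as-is,
    # otherwise interpret them relative to the config file's directory
    if "://" in path or path.startswith("/"):
        return path
    return os.path.join(os.path.dirname(config_path), path)

config_path = "tests/datastore_configs/multizarr/data_config.yaml"
print(normalize_path(config_path, "normalization.zarr"))
# tests/datastore_configs/multizarr/normalization.zarr
print(normalize_path(config_path, "s3://bucket/boundary_mask.zarr"))
# s3://bucket/boundary_mask.zarr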
diff --git a/tests/datastore_configs/mllam/.gitignore b/tests/datastore_configs/mllam/.gitignore
new file mode 100644
index 00000000..f2828f46
--- /dev/null
+++ b/tests/datastore_configs/mllam/.gitignore
@@ -0,0 +1,2 @@
+*.zarr/
+graph/
diff --git a/tests/datastore_configs/mllam/example.danra.yaml b/tests/datastore_configs/mllam/example.danra.yaml
index c04b069b..1ba08865 100644
--- a/tests/datastore_configs/mllam/example.danra.yaml
+++ b/tests/datastore_configs/mllam/example.danra.yaml
@@ -20,7 +20,7 @@ output:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
- ops: [mean, std]
+ ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
validation:
start: 1990-09-06T00:00
diff --git a/tests/datastore_configs/multizarr/data_config.yaml b/tests/datastore_configs/multizarr/data_config.yaml
index d46afa53..0b857761 100644
--- a/tests/datastore_configs/multizarr/data_config.yaml
+++ b/tests/datastore_configs/multizarr/data_config.yaml
@@ -51,7 +51,7 @@ forcing:
lat_lon_names:
lon: lon
lat: lat
- - path: "tests/config_examples/multizarr/datetime_forcings.zarr"
+ - path: "datetime_forcings.zarr"
dims:
time: time
level: null
@@ -111,7 +111,7 @@ boundary:
lon: longitude
lat: latitude
mask:
- path: "data/boundary_mask.zarr"
+ path: "boundary_mask.zarr"
dims:
x: x
y: y
@@ -126,7 +126,7 @@ boundary:
utilities:
normalization:
zarrs:
- - path: "tests/datastore_configs/multizarr/normalization.zarr"
+ - path: "normalization.zarr"
stats_vars:
state_mean: state_mean
state_std: state_std
diff --git a/tests/test_cli.py b/tests/test_cli.py
index cae8c9b9..19ca1ed8 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,7 +1,6 @@
# First-party
import neural_lam
import neural_lam.create_graph
-import neural_lam.datastore.multizarr.create_grid_features
import neural_lam.train_model
@@ -10,5 +9,4 @@ def test_import():
now, eventually we should test their execution too."""
assert neural_lam is not None
assert neural_lam.create_graph is not None
- assert neural_lam.datastore.multizarr.create_grid_features is not None
assert neural_lam.train_model is not None
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 2685721d..2e59eae1 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -26,6 +26,7 @@
# Third-party
import cartopy.crs as ccrs
import pytest
+import xarray as xr
# First-party
from neural_lam.datastore.mllam import MLLAMDatastore
@@ -105,3 +106,30 @@ def test_get_vars(datastore_name):
assert isinstance(units, list)
assert isinstance(names, list)
assert isinstance(num_vars, int)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_normalization_dataarray(datastore_name):
+ """Check that the `datastore.get_normalization_dataarray` method is
+ implemented."""
+ datastore = _init_datastore(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ ds_stats = datastore.get_normalization_dataarray(category=category)
+
+ # check that the returned object is an xarray Dataset
+ # and that it has the correct variables
+ assert isinstance(ds_stats, xr.Dataset)
+
+ if category == "state":
+ ops = ["mean", "std", "diff_mean", "diff_std"]
+ elif category == "forcing":
+ ops = ["mean", "std"]
+ elif category == "static":
+ ops = []
+ else:
+ raise NotImplementedError(category)
+
+ for op in ops:
+ var_name = f"{category}_{op}"
+ assert var_name in ds_stats.data_vars
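
A minimal sketch of a normalization-statistics dataset that would satisfy the new test for the "state" category; the variable names and the `state_feature` coordinate here are illustrative assumptions, the test only requires data variables named `f"{category}_{op}"`:

```python
import numpy as np
import xarray as xr

state_vars = ["u10m", "v10m", "t2m"]  # illustrative variable names
n = len(state_vars)

ds_stats = xr.Dataset(
    {
        "state_mean": ("state_feature", np.zeros(n)),
        "state_std": ("state_feature", np.ones(n)),
        "state_diff_mean": ("state_feature", np.zeros(n)),
        "state_diff_std": ("state_feature", np.ones(n)),
    },
    coords={"state_feature": state_vars},
)

for op in ["mean", "std", "diff_mean", "diff_std"]:
    assert f"state_{op}" in ds_stats.data_vars
```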
From 555154fd3af9ad1b7c4cf8235f6b2a88fb7343e4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 09:16:37 +0200
Subject: [PATCH 130/273] get_dataarray test
---
neural_lam/datastore/mllam.py | 6 ++-
.../mllam/example.danra.yaml | 2 +-
tests/test_datastores.py | 40 ++++++++++++++++++-
3 files changed, 43 insertions(+), 5 deletions(-)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 531709fe..18757fe9 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -66,12 +66,14 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
return da_category
else:
t_start = (
- self._ds.splits.sel(split_name=split, split_part="start")
+ self._ds.splits.sel(split_name=split)
+ .sel(split_part="start")
.load()
.item()
)
t_end = (
- self._ds.splits.sel(split_name=split, split_part="end")
+ self._ds.splits.sel(split_name=split)
+ .sel(split_part="end")
.load()
.item()
)
diff --git a/tests/datastore_configs/mllam/example.danra.yaml b/tests/datastore_configs/mllam/example.danra.yaml
index 1ba08865..5c2d02d7 100644
--- a/tests/datastore_configs/mllam/example.danra.yaml
+++ b/tests/datastore_configs/mllam/example.danra.yaml
@@ -22,7 +22,7 @@ output:
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
- validation:
+ val:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
test:
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 2e59eae1..026f8a38 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -8,9 +8,9 @@
- [x] `get_vars_names` (method): Get the names of the variables in the given category.
- [x] `get_num_data_vars` (method): Get the number of data variables in the
given category.
-- [ ] `get_normalization_dataarray` (method): Return the normalization
+- [x] `get_normalization_dataarray` (method): Return the normalization
dataarray for the given category.
-- [ ] `get_dataarray` (method): Return the processed data (as a single
+- [x] `get_dataarray` (method): Return the processed data (as a single
`xr.DataArray`) for the given category and test/train/val-split.
- [ ] `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
@@ -29,6 +29,7 @@
import xarray as xr
# First-party
+from neural_lam.datastore.base import BaseCartesianDatastore
from neural_lam.datastore.mllam import MLLAMDatastore
from neural_lam.datastore.multizarr import MultiZarrDatastore
from neural_lam.datastore.npyfiles import NumpyFilesDatastore
@@ -133,3 +134,38 @@ def test_get_normalization_dataarray(datastore_name):
for op in ops:
var_name = f"{category}_{op}"
assert var_name in ds_stats.data_vars
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_dataarray(datastore_name):
+ """Check that the `datastore.get_dataarray` method is implemented.
+
+ And that it returns an xarray DataArray with the correct dimensions.
+ """
+
+ datastore = _init_datastore(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ # TODO: should we expect there to be a "test" split too?
+ for split in ["train", "val"]:
+ expected_dims = ["grid_index", f"{category}_feature"]
+ if category != "static":
+ if not datastore.is_forecast:
+ expected_dims.append("time")
+ else:
+ expected_dims += [
+ "analysis_time",
+ "elapsed_forecast_duration",
+ ]
+
+ # XXX: for now we only have a single attribute to get the shape of
+ # the grid which uses the shape from the "state" category, maybe
+ # this should change?
+ grid_shape = datastore.grid_shape_state
+
+ da = datastore.get_dataarray(category=category, split=split)
+
+ assert isinstance(da, xr.DataArray)
+ assert set(da.dims) == set(expected_dims)
+ if isinstance(datastore, BaseCartesianDatastore):
+ assert da.grid_index.size == grid_shape.x * grid_shape.y
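
As a small illustration of the dimension check in the test above, an analysis-type (non-forecast) state dataarray is expected to look roughly like the following (sizes are arbitrary); a forecast datastore would instead carry `analysis_time` and `elapsed_forecast_duration` in place of `time`:

```python
import numpy as np
import xarray as xr

n_time, n_grid, n_features = 4, 6, 3  # arbitrary sizes

da_state = xr.DataArray(
    np.zeros((n_time, n_grid, n_features)),
    dims=("time", "grid_index", "state_feature"),
)

expected_dims = {"grid_index", "state_feature", "time"}
assert set(da_state.dims) == expected_dims
```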
From 8b8a77e99b9eae238963c4d714e0531622e1d76b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 09:25:11 +0200
Subject: [PATCH 131/273] get_dataarray test
---
neural_lam/datastore/base.py | 25 ++++++++++++++++---------
tests/test_datastores.py | 3 +--
2 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 81d4e0b8..bd9febc5 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -116,15 +116,22 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
@abc.abstractmethod
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
"""Return the processed data (as a single `xr.DataArray`) for the given
- category and test/train/val-split that covers the entire timeline of
- the dataset. The returned dataarray is expected to at minimum have
- dimensions of `(time, grid_index, {category}_feature)` so that any
- spatial dimensions have been stacked into a single dimension and all
- variables and levels have been stacked into a single feature dimension
- named by the `category` of data being loaded. Any additional dimensions
- (for example `ensemble_member` or `analysis_time`) should be kept as
- separate dimensions in the dataarray, and `WeatherDataset` will handle
- the sampling of the data.
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category.
+
+ The returned dataarray is expected to at minimum have dimensions of
+ `(grid_index, {category}_feature)` so that any spatial dimensions have
+ been stacked into a single dimension and all variables and levels have
+ been stacked into a single feature dimension named by the `category` of
+ data being loaded.
+
+ For categories of data that have a time dimension (i.e. not static
+ data), the dataarray is expected additionally have `(analysis_time,
+ elapsed_forecast_duration)` dimensions if `is_forecast` is True, or
+ `(time)` if `is_forecast` is False.
+
+ If the data is ensemble data, the dataarray is expected to have an
+ additional `ensemble_member` dimension.
Parameters
----------
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 026f8a38..fc8afac7 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -146,8 +146,7 @@ def test_get_dataarray(datastore_name):
datastore = _init_datastore(datastore_name)
for category in ["state", "forcing", "static"]:
- # TODO: should we expect there to be a "test" split too?
- for split in ["train", "val"]:
+ for split in ["train", "val", "test"]:
expected_dims = ["grid_index", f"{category}_feature"]
if category != "static":
if not datastore.is_forecast:
From 41f11cdfeb3bd5a002e65207c64e336e9d70c36d Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 09:33:27 +0200
Subject: [PATCH 132/273] boundary_mask
---
neural_lam/datastore/mllam.py | 4 +++-
tests/test_datastores.py | 19 +++++++++++++++++++
2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 18757fe9..362573e0 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -135,7 +135,9 @@ def boundary_mask(self) -> xr.DataArray:
x=slice(self._n_boundary_points, -self._n_boundary_points),
y=slice(self._n_boundary_points, -self._n_boundary_points),
)
- ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1)
+ ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
+ 1
+ ).astype(int)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
@property
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index fc8afac7..297b287d 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -168,3 +168,22 @@ def test_get_dataarray(datastore_name):
assert set(da.dims) == set(expected_dims)
if isinstance(datastore, BaseCartesianDatastore):
assert da.grid_index.size == grid_shape.x * grid_shape.y
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_boundary_mask(datastore_name):
+ """Check that the `datastore.boundary_mask` property is implemented and
+ that the returned object is an xarray DataArray with the correct shape."""
+ datastore = _init_datastore(datastore_name)
+ da_mask = datastore.boundary_mask
+
+ assert isinstance(da_mask, xr.DataArray)
+ assert set(da_mask.dims) == {"grid_index"}
+ assert da_mask.dtype == "int"
+ assert set(da_mask.values) == {0, 1}
+ assert da_mask.sum() > 0
+ assert da_mask.sum() < da_mask.size
+
+ if isinstance(datastore, BaseCartesianDatastore):
+ grid_shape = datastore.grid_shape_state
+ assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y
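
A minimal sketch of the kind of boundary mask the new test expects, assuming a rectangular grid where the outermost `n_boundary` points are boundary (value 1) and the interior is 0; grid sizes are illustrative:

```python
import numpy as np
import xarray as xr

ny, nx, n_boundary = 6, 5, 1  # illustrative sizes

mask2d = np.ones((ny, nx), dtype=int)
mask2d[n_boundary:-n_boundary, n_boundary:-n_boundary] = 0  # interior points

da_mask = (
    xr.DataArray(mask2d, dims=("y", "x"))
    .stack(grid_index=("y", "x"))
    .reset_index("grid_index")
)

assert set(da_mask.dims) == {"grid_index"}
assert set(np.unique(da_mask.values)) == {0, 1}
assert 0 < int(da_mask.sum()) < da_mask.size
```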
From a17de0f6dbd72d4c198861f31f81263cf6e0b42e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 09:40:21 +0200
Subject: [PATCH 133/273] get_xy
---
tests/test_datastores.py | 57 ++++++++++++++++++++++++++++++++++++++--
1 file changed, 55 insertions(+), 2 deletions(-)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 297b287d..84c27b05 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -12,11 +12,11 @@
dataarray for the given category.
- [x] `get_dataarray` (method): Return the processed data (as a single
`xr.DataArray`) for the given category and test/train/val-split.
-- [ ] `boundary_mask` (property): Return the boundary mask for the dataset,
+- [x] `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
In addition BaseCartesianDatastore must have the following methods and attributes:
-- [ ] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
+- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- [ ] `get_xy` (method): Return the x, y coordinates of the dataset.
- [ ] `coords_projection` (property): Projection object for the coordinates.
@@ -25,6 +25,7 @@
# Third-party
import cartopy.crs as ccrs
+import numpy as np
import pytest
import xarray as xr
@@ -187,3 +188,55 @@ def test_boundary_mask(datastore_name):
if isinstance(datastore, BaseCartesianDatastore):
grid_shape = datastore.grid_shape_state
assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_xy_extent(datastore_name):
+ """Check that the `datastore.get_xy_extent` method is implemented and that
+ the returned object is a list of the correct length."""
+ datastore = _init_datastore(datastore_name)
+
+ if not isinstance(datastore, BaseCartesianDatastore):
+ pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+ extents = {}
+ # get the extents for each category, and finally check they are all the same
+ for category in ["state", "forcing", "static"]:
+ extent = datastore.get_xy_extent(category)
+ assert isinstance(extent, list)
+ assert len(extent) == 4
+ assert all(isinstance(e, (int, float)) for e in extent)
+ extents[category] = extent
+
+ # check that the extents are the same for all categories
+ for category in ["forcing", "static"]:
+ assert extents["state"] == extents[category]
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_xy(datastore_name):
+ """Check that the `datastore.get_xy` method is implemented."""
+ datastore = _init_datastore(datastore_name)
+
+ if not isinstance(datastore, BaseCartesianDatastore):
+ pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+ for category in ["state", "forcing", "static"]:
+ xy_stacked = datastore.get_xy(category=category, stacked=True)
+ xy_unstacked = datastore.get_xy(category=category, stacked=False)
+
+ assert isinstance(xy_stacked, np.ndarray)
+ assert isinstance(xy_unstacked, np.ndarray)
+
+ nx, ny = datastore.grid_shape_state.x, datastore.grid_shape_state.y
+
+ # for stacked=True, the shape should be (2, n_grid_points)
+ assert xy_stacked.ndim == 2
+ assert xy_stacked.shape[0] == 2
+ assert xy_stacked.shape[1] == nx * ny
+
+ # for stacked=False, the shape should be (2, ny, nx)
+ assert xy_unstacked.ndim == 3
+ assert xy_unstacked.shape[0] == 2
+ assert xy_unstacked.shape[1] == ny
+ assert xy_unstacked.shape[2] == nx
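
A small numpy sketch of the two shapes checked by `test_get_xy`; the flattening order used here (row-major, x varying fastest) is only an assumption for illustration, the test itself only constrains the shapes:

```python
import numpy as np

nx, ny = 4, 3  # illustrative grid size
x, y = np.meshgrid(np.arange(nx), np.arange(ny))  # each of shape (ny, nx)

xy_unstacked = np.stack([x, y], axis=0)   # stacked=False: (2, ny, nx)
xy_stacked = xy_unstacked.reshape(2, -1)  # stacked=True: (2, nx * ny)

assert xy_unstacked.shape == (2, ny, nx)
assert xy_stacked.shape == (2, nx * ny)
```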
From 0a38a7d453d0a2cdb73f38d19b7b6af8adf32b34 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 10:45:20 +0200
Subject: [PATCH 134/273] remove TrainingSample dataclass
---
neural_lam/create_graph.py | 8 +--
neural_lam/datastore/base.py | 14 +++++
neural_lam/datastore/mllam.py | 4 ++
neural_lam/models/base_graph_model.py | 10 +++-
neural_lam/utils.py | 4 +-
neural_lam/weather_dataset.py | 78 +--------------------------
tests/test_datastores.py | 51 +++++++++++-------
tests/test_mllam_dataset.py | 2 +-
tests/test_multizarr_dataset.py | 2 +-
9 files changed, 68 insertions(+), 105 deletions(-)
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index de73a9c8..91bbe8ed 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -475,14 +475,14 @@ def create_graph(
def create_graph_from_datastore(
datastore: BaseCartesianDatastore,
- graph_dir_path: str,
+ output_root_path: str,
n_max_levels: int = None,
hierarchical: bool = False,
create_plot: bool = False,
):
xy = datastore.get_xy(category="state", stacked=False)
create_graph(
- graph_dir_path=graph_dir_path,
+ graph_dir_path=output_root_path,
xy=xy,
n_max_levels=n_max_levels,
hierarchical=hierarchical,
@@ -505,7 +505,7 @@ def cli(input_args=None):
help="path to the data store",
)
parser.add_argument(
- "--graph",
+ "--name",
type=str,
default="multiscale",
help="Name to save graph as (default: multiscale)",
@@ -536,7 +536,7 @@ def cli(input_args=None):
create_graph_from_datastore(
datastore=datastore,
- graph_dir_path=os.path.join("graphs", args.graph),
+ output_root_path=os.path.join(datastore.root_path, "graphs", args.name),
n_max_levels=args.levels,
hierarchical=args.hierarchical,
create_plot=args.plot,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index bd9febc5..2a472cbf 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -1,6 +1,7 @@
# Standard library
import abc
import dataclasses
+from pathlib import Path
from typing import List, Union
# Third-party
@@ -30,6 +31,19 @@ class BaseDatastore(abc.ABC):
is_ensemble: bool = False
is_forecast: bool = False
+ @property
+ @abc.abstractmethod
+ def root_path(self) -> Path:
+ """The root path to the datastore. It is relative to this that any
+ derived files (for example the graph components) are stored.
+
+ Returns
+ -------
+ pathlib.Path
+ The root path to the datastore.
+ """
+ pass
+
@property
@abc.abstractmethod
def step_length(self) -> int:
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 362573e0..70f913ae 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -46,6 +46,10 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points
+ @property
+ def root_path(self) -> Path:
+ return Path(self._config_path.parent)
+
def step_length(self) -> int:
da_dt = self._ds["time"].diff("time")
return da_dt.dt.seconds[0] // 3600
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index d20a2d24..4175e2d1 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -19,7 +19,10 @@ def __init__(self, args, datastore, forcing_window_size):
# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
- self.hierarchical, graph_ldict = utils.load_graph(args.graph)
+ graph_dir_path = datastore.root_path / "graph" / args.graph
+ self.hierarchical, graph_ldict = utils.load_graph(
+ graph_dir_path=graph_dir_path
+ )
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
if isinstance(attr_value, torch.Tensor):
@@ -102,6 +105,11 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
batch_size = prev_state.shape[0]
+ print(f"prev_state.shape: {prev_state.shape}")
+ print(f"prev_prev_state.shape: {prev_prev_state.shape}")
+ print(f"forcing.shape: {forcing.shape}")
+ print(f"grid_static_features.shape: {self.grid_static_features.shape}")
+
# Create full grid node features of shape (B, num_grid_nodes, grid_dim)
grid_features = torch.cat(
(
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index ba146355..a97dcc8f 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -32,10 +32,8 @@ def __iter__(self):
return (self[i] for i in range(len(self)))
-def load_graph(graph_name, device="cpu"):
+def load_graph(graph_dir_path, device="cpu"):
"""Load all tensors representing the graph."""
- # Define helper lambda function
- graph_dir_path = os.path.join("graphs", graph_name)
def loads_file(fn):
return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 84d4a2bc..4e38dbd5 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,5 +1,4 @@
# Standard library
-import dataclasses
import warnings
# Third-party
@@ -12,74 +11,6 @@
from neural_lam.datastore.base import BaseDatastore
-@dataclasses.dataclass
-class TrainingSample:
- """A dataclass to hold a single training sample of `ar_steps`
- autoregressive steps, which consists of the initial states, target states,
- forcing and batch times. The initial and target states should have
- `d_features` features, and the forcing should have `d_windowed_forcing`
- features.
-
- Parameters
- ----------
- init_states : torch.Tensor
- The initial states of the training sample,
- shape (2, N_grid, d_features).
- target_states : torch.Tensor
- The target states of the training sample,
- shape (ar_steps, N_grid, d_features).
- forcing : torch.Tensor
- The forcing of the training sample,
- shape (ar_steps, N_grid, d_windowed_forcing).
- batch_times : np.ndarray
- The times of the batch, shape (ar_steps,).
- """
-
- init_states: torch.Tensor
- target_states: torch.Tensor
- forcing: torch.Tensor
- batch_times: np.ndarray
-
- def __post_init__(self):
- """Validate the shapes of the tensors match between the different
- components of the training sample.
-
- init_states: (2, N_grid, d_features)
- target_states: (ar_steps, N_grid, d_features)
- forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,)
- """
- assert self.init_states.shape[0] == 2
- _, N_grid, d_features = self.init_states.shape
- N_pred_steps = self.target_states.shape[0]
-
- # check number of grid points
- if not (
- self.target_states.shape[1] == self.target_states.shape[1] == N_grid
- ):
- raise Exception(
- "Number of grid points do not match, got "
- f"{self.target_states.shape[1]=} and "
- f"{self.target_states.shape[2]=}, expected {N_grid=}"
- )
-
- # check number of features for init and target states
- assert self.target_states.shape[2] == d_features
-
- # check that target, forcing and batch times have the same number of
- # prediction steps
- if not (
- self.target_states.shape[0]
- == self.forcing.shape[0]
- == self.batch_times.shape[0]
- == N_pred_steps
- ):
- raise Exception(
- "Number of prediction steps do not match, got "
- f"{self.target_states.shape[0]=}, {self.forcing.shape[0]=} and "
- f"{self.batch_times.shape[0]=}, expected {N_pred_steps=}"
- )
-
-
class WeatherDataset(torch.utils.data.Dataset):
"""Dataset class for weather data.
@@ -268,7 +199,7 @@ def __getitem__(self, idx):
da_init_states = da_state.isel(time=slice(None, 2))
da_target_states = da_state.isel(time=slice(2, None))
- batch_times = da_forcing_windowed.time
+ batch_times = da_forcing_windowed.time.values.astype(float)
if self.standardize:
da_init_states = (
@@ -300,12 +231,7 @@ def __getitem__(self, idx):
# forcing: (ar_steps, N_grid, d_windowed_forcing)
# batch_times: (ar_steps,)
- return TrainingSample(
- init_states=init_states,
- target_states=target_states,
- forcing=forcing,
- batch_times=batch_times,
- )
+ return init_states, target_states, forcing, batch_times
class WeatherDataModule(pl.LightningDataModule):
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 84c27b05..44c75a48 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -18,11 +18,14 @@
In addition BaseCartesianDatastore must have the following methods and attributes:
- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
-- [ ] `get_xy` (method): Return the x, y coordinates of the dataset.
-- [ ] `coords_projection` (property): Projection object for the coordinates.
-- [ ] `grid_shape_state` (property): Shape of the grid for the state variables.
+- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
+- [x] `coords_projection` (property): Projection object for the coordinates.
+- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
"""
+# Standard library
+from pathlib import Path
+
# Third-party
import cartopy.crs as ccrs
import numpy as np
@@ -51,17 +54,24 @@
)
-def _init_datastore(datastore_name):
+def init_datastore(datastore_name):
DatastoreClass = DATASTORES[datastore_name]
return DatastoreClass(**EXAMPLES[datastore_name])
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_root_path(datastore_name):
+ """Check that the `datastore.root_path` property is implemented."""
+ datastore = init_datastore(datastore_name)
+ assert isinstance(datastore.root_path, Path)
+
+
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
"""Use the `datastore.get_xy` method to get the x, y coordinates of the
dataset and check that the shape is correct against the
`datastore.grid_shape_state` property."""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
# check the shapes of the xy grid
grid_shape = datastore.grid_shape_state
@@ -78,14 +88,6 @@ def test_datastore_grid_xy(datastore_name):
assert xy.shape == (2, ny, nx)
-@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_projection(datastore_name):
- """Check that the `datastore.coords_projection` property is implemented."""
- datastore = _init_datastore(datastore_name)
-
- assert isinstance(datastore.coords_projection, ccrs.Projection)
-
-
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_vars(datastore_name):
"""Check that results of.
@@ -97,7 +99,7 @@ def test_get_vars(datastore_name):
are consistent (as in the number of variables are the same) and that the
return types of each are correct.
"""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
for category in ["state", "forcing", "static"]:
units = datastore.get_vars_units(category)
@@ -114,7 +116,7 @@ def test_get_vars(datastore_name):
def test_get_normalization_dataarray(datastore_name):
"""Check that the `datastore.get_normalization_dataarray` method is
implemented."""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
for category in ["state", "forcing", "static"]:
ds_stats = datastore.get_normalization_dataarray(category=category)
@@ -144,7 +146,7 @@ def test_get_dataarray(datastore_name):
And that it returns an xarray DataArray with the correct dimensions.
"""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
for category in ["state", "forcing", "static"]:
for split in ["train", "val", "test"]:
@@ -175,7 +177,7 @@ def test_get_dataarray(datastore_name):
def test_boundary_mask(datastore_name):
"""Check that the `datastore.boundary_mask` property is implemented and
that the returned object is an xarray DataArray with the correct shape."""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
da_mask = datastore.boundary_mask
assert isinstance(da_mask, xr.DataArray)
@@ -194,7 +196,7 @@ def test_boundary_mask(datastore_name):
def test_get_xy_extent(datastore_name):
"""Check that the `datastore.get_xy_extent` method is implemented and that
the returned object is a list of the correct length."""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -216,7 +218,7 @@ def test_get_xy_extent(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy(datastore_name):
"""Check that the `datastore.get_xy` method is implemented."""
- datastore = _init_datastore(datastore_name)
+ datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -240,3 +242,14 @@ def test_get_xy(datastore_name):
assert xy_unstacked.shape[0] == 2
assert xy_unstacked.shape[1] == ny
assert xy_unstacked.shape[2] == nx
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_projection(datastore_name):
+ """Check that the `datastore.coords_projection` property is implemented."""
+ datastore = init_datastore(datastore_name)
+
+ if not isinstance(datastore, BaseCartesianDatastore):
+ pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+ assert isinstance(datastore.coords_projection, ccrs.Projection)
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
index b76faa31..565aebaf 100644
--- a/tests/test_mllam_dataset.py
+++ b/tests/test_mllam_dataset.py
@@ -39,7 +39,7 @@ def test_mllam():
create_graph_from_datastore(
datastore=datastore,
- graph_dir_path="tests/datastore_configs/mllam/graph",
+ output_root_path="tests/datastore_configs/mllam/graph",
)
model = GraphLAM( # noqa
diff --git a/tests/test_multizarr_dataset.py b/tests/test_multizarr_dataset.py
index 7b51d9df..4a780fcb 100644
--- a/tests/test_multizarr_dataset.py
+++ b/tests/test_multizarr_dataset.py
@@ -71,7 +71,7 @@ def test_create_graph_analysis_dataset():
config_path=DATASTORE_PATH / "data_config.yaml"
)
create_graph_from_datastore(
- datastore=datastore, graph_dir_path=DATASTORE_PATH / "graph"
+ datastore=datastore, output_root_path=DATASTORE_PATH / "graph"
)
# test cli
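
With the `TrainingSample` dataclass and its `__post_init__` validation removed, the same shape relationships now only hold by convention on the returned tuple. A minimal sketch, with illustrative sizes, of the checks the dataclass used to perform:

```python
import torch

ar_steps, n_grid, d_features, d_windowed_forcing = 3, 10, 5, 6  # illustrative

init_states = torch.zeros(2, n_grid, d_features)
target_states = torch.zeros(ar_steps, n_grid, d_features)
forcing = torch.zeros(ar_steps, n_grid, d_windowed_forcing)
batch_times = torch.zeros(ar_steps)

# the checks previously performed by TrainingSample.__post_init__
assert init_states.shape[0] == 2
assert target_states.shape[1:] == (n_grid, d_features)
assert forcing.shape[0] == target_states.shape[0] == batch_times.shape[0]
```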
From f65f6b52e22249bb0909b1021789f1a1a5a950a8 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 10:48:01 +0200
Subject: [PATCH 135/273] test for WeatherDataset.__getitem__
---
tests/test_datasets.py | 59 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 59 insertions(+)
create mode 100644 tests/test_datasets.py
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 00000000..40ca7398
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,59 @@
+# Third-party
+import pytest
+from test_datastores import DATASTORES, init_datastore
+
+# First-party
+from neural_lam.weather_dataset import WeatherDataset
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataset_item(datastore_name):
+ """Check that the `WeatherDataset.__getitem__` method is implemented.
+
+ Validate the shapes of the tensors match between the different
+ components of the training sample.
+
+ init_states: (2, N_grid, d_features)
+ target_states: (ar_steps, N_grid, d_features)
+ forcing: (ar_steps, N_grid, d_windowed_forcing)
+ batch_times: (ar_steps,)
+ """
+ datastore = init_datastore(datastore_name)
+ N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y
+
+ N_pred_steps = 4
+ forcing_window_size = 3
+ dataset = WeatherDataset(
+ datastore=datastore,
+ batch_size=1,
+ split="train",
+ ar_steps=N_pred_steps,
+ forcing_window_size=forcing_window_size,
+ )
+
+ item = dataset[0]
+
+ # unpack the item, this is the current return signature for
+ # WeatherDataset.__getitem__
+ init_states, target_states, forcing, batch_times = item
+
+ # initial states
+ assert init_states.shape[0] == 2 # two time steps go into the input
+ assert init_states.shape[1] == N_gridpoints
+ assert init_states.shape[2] == datastore.get_num_data_vars("state")
+
+ # output states
+ assert target_states.shape[0] == N_pred_steps
+ assert target_states.shape[1] == N_gridpoints
+ assert target_states.shape[2] == datastore.get_num_data_vars("state")
+
+ # forcing
+ assert forcing.shape[0] == N_pred_steps # number of prediction steps
+ assert forcing.shape[1] == N_gridpoints # number of grid points
+ # number of features x window size
+ assert (
+ forcing.shape[2]
+ == datastore.get_num_data_vars("forcing") * forcing_window_size
+ )
+
+ # batch times
+ assert batch_times.shape[0] == N_pred_steps
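
The forcing-feature assertion in the new test relies on the window offsets being concatenated along the feature axis; a small numpy sketch of that shape relationship (sizes illustrative, the concatenation mechanism is assumed to match the dataset implementation):

```python
import numpy as np

ar_steps, n_grid, n_forcing_vars, window_size = 4, 10, 2, 3  # illustrative

# one slice of forcing per window offset, each (ar_steps, n_grid, n_forcing_vars)
windows = [np.zeros((ar_steps, n_grid, n_forcing_vars)) for _ in range(window_size)]
forcing_windowed = np.concatenate(windows, axis=-1)

assert forcing_windowed.shape == (ar_steps, n_grid, n_forcing_vars * window_size)
```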
From a35100e830252bc7b18caec0b61225b0c7ee6e8f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 11:46:30 +0200
Subject: [PATCH 136/273] test for graph creation
---
neural_lam/create_graph.py | 62 ++++++++++++++++++++++++++
neural_lam/utils.py | 43 +++++++++++++++++-
tests/test_graph_creation.py | 86 ++++++++++++++++++++++++++++++++++++
3 files changed, 190 insertions(+), 1 deletion(-)
create mode 100644 tests/test_graph_creation.py
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 91bbe8ed..872c5aff 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -164,6 +164,68 @@ def create_graph(
hierarchical: bool,
create_plot: bool,
):
+ """Create graph components from `xy` grid coordinates and store in
+ `graph_dir_path`.
+
+ Creates the following files for all graphs:
+ - g2m_edge_index.pt [2, N_g2m_edges]
+ - g2m_features.pt [N_g2m_edges, d_features]
+ - m2g_edge_index.pt [2, N_m2g_edges]
+ - m2g_features.pt [N_m2g_edges, d_features]
+ - m2m_edge_index.pt list of [2, N_m2m_edges_level], length==n_levels
+ - m2m_features.pt list of [N_m2m_edges_level, d_features], length==n_levels
+ - mesh_features.pt list of [N_mesh_nodes_level, d_mesh_static], length==n_levels
+
+ where
+ d_features:
+ number of features per edge (currently d_features==3, for
+ edge-length, x and y)
+ N_g2m_edges:
+ number of edges in the graph from grid-to-mesh
+ N_m2g_edges:
+ number of edges in the graph from mesh-to-grid
+ N_m2m_edges_level:
+ number of edges in the graph from mesh-to-mesh at a given level
+ (list index corresponds to the level)
+ d_mesh_static:
+ number of static features per mesh node (currently
+ d_mesh_static==2, for x and y)
+ N_mesh_nodes_level:
+ number of nodes in the mesh at a given level
+
+ And in addition for hierarchical graphs:
+ - mesh_up_edge_index.pt
+ list of [2, N_mesh_updown_edges_level], length==n_levels-1
+ - mesh_up_features.pt
+ list of [N_mesh_updown_edges_level, d_features], length==n_levels-1
+ - mesh_down_edge_index.pt
+ list of [2, N_mesh_updown_edges_level], length==n_levels-1
+ - mesh_down_features.pt
+ list of [N_mesh_updown_edges_level, d_features], length==n_levels-1
+
+ where N_mesh_updown_edges_level is the number of edges in the graph from
+ mesh-to-mesh between two consecutive levels (list index corresponds index
+ of lower level)
+
+
+ Parameters
+ ----------
+ graph_dir_path : str
+ Path to store the graph components.
+ xy : np.ndarray
+ Grid coordinates, expected to be of shape (2, Ny, Nx).
+ n_max_levels : int
+ Limit multi-scale mesh to given number of levels, from bottom up
+ (default: None (no limit)).
+ hierarchical : bool
+ Generate hierarchical mesh graph (default: False).
+ create_plot : bool
+ If graphs should be plotted during generation (default: False).
+
+ Returns
+ -------
+ None
+ """
os.makedirs(graph_dir_path, exist_ok=True)
grid_xy = torch.tensor(xy)
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index a97dcc8f..79de3193 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -33,7 +33,48 @@ def __iter__(self):
def load_graph(graph_dir_path, device="cpu"):
- """Load all tensors representing the graph."""
+ """Load all tensors representing the graph from `graph_dir_path`.
+
+ Needs the following files for all graphs:
+ - m2m_edge_index.pt
+ - g2m_edge_index.pt
+ - m2g_edge_index.pt
+ - m2m_features.pt
+ - g2m_features.pt
+ - m2g_features.pt
+ - mesh_features.pt
+
+ And in addition for hierarchical graphs:
+ - mesh_up_edge_index.pt
+ - mesh_down_edge_index.pt
+ - mesh_up_features.pt
+ - mesh_down_features.pt
+
+ Parameters
+ ----------
+ graph_dir_path : str
+ Path to directory containing the graph files.
+ device : str
+ Device to load tensors to.
+
+ Returns
+ -------
+ hierarchical : bool
+ Whether the graph is hierarchical.
+ graph : dict
+ Dictionary containing the graph tensors, with keys as follows:
+ - g2m_edge_index
+ - m2g_edge_index
+ - m2m_edge_index
+ - mesh_up_edge_index
+ - mesh_down_edge_index
+ - g2m_features
+ - m2g_features
+ - m2m_features
+ - mesh_up_features
+ - mesh_down_features
+ - mesh_static_features
+ """
def loads_file(fn):
return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
new file mode 100644
index 00000000..01c69426
--- /dev/null
+++ b/tests/test_graph_creation.py
@@ -0,0 +1,86 @@
+# Standard library
+import tempfile
+from pathlib import Path
+
+# Third-party
+import pytest
+import torch
+from test_datastores import DATASTORES, init_datastore
+
+# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_graph_creation(datastore_name):
+ """Check that the `create_graph_from_datastore` function is implemented.
+
+ And that the graph is created in the correct location.
+ """
+ datastore = init_datastore(datastore_name)
+ graph_name = "multiscale"
+ hierarchical = False
+
+ required_graph_files = [
+ "m2m_edge_index.pt",
+ "g2m_edge_index.pt",
+ "m2g_edge_index.pt",
+ "m2m_features.pt",
+ "g2m_features.pt",
+ "m2g_features.pt",
+ "mesh_features.pt",
+ ]
+ if hierarchical:
+ required_graph_files.extend(
+ [
+ "mesh_up_edge_index.pt",
+ "mesh_down_edge_index.pt",
+ "mesh_up_features.pt",
+ "mesh_down_features.pt",
+ ]
+ )
+
+ # TODO: check that the number of edges is consistent over the files, for
+ # now we just check the number of features
+ d_features = 3
+ d_mesh_static = 2
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ graph_dir_path = Path(tmpdir) / "graph" / graph_name
+
+ create_graph_from_datastore(
+ datastore=datastore, output_root_path=str(graph_dir_path)
+ )
+
+ assert graph_dir_path.exists()
+
+ # check that all the required files are present
+ for file_name in required_graph_files:
+ assert (graph_dir_path / file_name).exists()
+
+ # try to load each and ensure they have the right shape
+ for file_name in required_graph_files:
+ file_id = Path(file_name).stem # remove the extension
+ result = torch.load(graph_dir_path / file_name)
+
+ if file_id.startswith("g2m") or file_id.startswith("m2g"):
+ assert isinstance(result, torch.Tensor)
+
+ if file_id.endswith("_index"):
+ assert (
+ result.shape[0] == 2
+ ) # adjacency matrix uses two rows
+ elif file_id.endswith("_features"):
+ assert result.shape[1] == d_features
+
+ elif file_id.startswith("m2m") or file_id.startswith("mesh"):
+ assert isinstance(result, list)
+ for r in result:
+ assert isinstance(r, torch.Tensor)
+
+ if file_id == "mesh_features":
+ assert r.shape[1] == d_mesh_static
+ elif file_id.endswith("_index"):
+ assert r.shape[0] == 2 # adjacency matrix uses two rows
+ elif file_id.endswith("_features"):
+ assert r.shape[1] == d_features
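
For readers unfamiliar with the `*_edge_index.pt` layout the test checks, a tiny PyTorch sketch of the "two rows" convention (a COO-style edge list, one edge per column):

```python
import torch

# three directed edges: 0->1, 1->2, 2->0
edge_index = torch.tensor(
    [
        [0, 1, 2],  # source node indices
        [1, 2, 0],  # target node indices
    ]
)

assert edge_index.shape[0] == 2  # what the test asserts for *_edge_index files
n_edges = edge_index.shape[1]
assert n_edges == 3
```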
From cfb061887007d0ee6f9e2801dffa63717e404ed4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 11:57:00 +0200
Subject: [PATCH 137/273] more graph creation tests
---
neural_lam/create_graph.py | 2 +-
tests/test_graph_creation.py | 31 +++++++++++++++++++++++++++----
2 files changed, 28 insertions(+), 5 deletions(-)
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 872c5aff..c13d0f93 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -245,7 +245,7 @@ def create_graph(
# Limit the levels in mesh graph
mesh_levels = min(mesh_levels, n_max_levels)
- print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}")
+ # print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}")
# multi resolution tree levels
G = []
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
index 01c69426..6384c46f 100644
--- a/tests/test_graph_creation.py
+++ b/tests/test_graph_creation.py
@@ -11,15 +11,25 @@
from neural_lam.create_graph import create_graph_from_datastore
+@pytest.mark.parametrize("graph_name", ["1level", "multiscale", "hierarchical"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_graph_creation(datastore_name):
+def test_graph_creation(datastore_name, graph_name):
"""Check that the `create_graph_from_datastore` function is implemented.
And that the graph is created in the correct location.
"""
datastore = init_datastore(datastore_name)
- graph_name = "multiscale"
- hierarchical = False
+ if graph_name == "hierarchical":
+ hierarchical = True
+ n_max_levels = 3
+ elif graph_name == "multiscale":
+ hierarchical = False
+ n_max_levels = 3
+ elif graph_name == "1level":
+ hierarchical = False
+ n_max_levels = 1
+ else:
+ raise ValueError(f"Unknown graph_name: {graph_name}")
required_graph_files = [
"m2m_edge_index.pt",
@@ -49,7 +59,10 @@ def test_graph_creation(datastore_name):
graph_dir_path = Path(tmpdir) / "graph" / graph_name
create_graph_from_datastore(
- datastore=datastore, output_root_path=str(graph_dir_path)
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ hierarchical=hierarchical,
+ n_max_levels=n_max_levels,
)
assert graph_dir_path.exists()
@@ -75,6 +88,16 @@ def test_graph_creation(datastore_name):
elif file_id.startswith("m2m") or file_id.startswith("mesh"):
assert isinstance(result, list)
+ if not hierarchical:
+ assert len(result) == 1
+ else:
+ if file_id.startswith("mesh_up") or file_id.startswith(
+ "mesh_down"
+ ):
+ assert len(result) == n_max_levels - 1
+ else:
+ assert len(result) == n_max_levels
+
for r in result:
assert isinstance(r, torch.Tensor)
From 86987198302ef3e92cccf807844126b42ae286b6 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 12:06:38 +0200
Subject: [PATCH 138/273] check for consistency of num features across splits
---
tests/test_datastores.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 44c75a48..e791ac86 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -149,6 +149,7 @@ def test_get_dataarray(datastore_name):
datastore = init_datastore(datastore_name)
for category in ["state", "forcing", "static"]:
+ n_features = {}
for split in ["train", "val", "test"]:
expected_dims = ["grid_index", f"{category}_feature"]
if category != "static":
@@ -172,6 +173,11 @@ def test_get_dataarray(datastore_name):
if isinstance(datastore, BaseCartesianDatastore):
assert da.grid_index.size == grid_shape.x * grid_shape.y
+ n_features[split] = da[category + "_feature"].size
+
+ # check that the number of features is the same for all splits
+ assert n_features["train"] == n_features["val"] == n_features["test"]
+
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
From 3381404853c8158877660b4fac907c2f1bdce688 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 24 Jul 2024 12:54:30 +0200
Subject: [PATCH 139/273] test for single batch from mllam through model
---
neural_lam/models/ar_model.py | 13 ++++--
neural_lam/models/base_graph_model.py | 5 --
neural_lam/weather_dataset.py | 2 -
tests/test_datasets.py | 66 +++++++++++++++++++++++++--
4 files changed, 73 insertions(+), 13 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 81e26d22..59ca1fdc 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -38,10 +38,14 @@ def __init__(
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask
- # Load static features for grid/data
+ # Load static features for grid/data, NB: self.predict_step assumes dimension
+ # order to be (grid_index, static_feature)
+ arr_static = da_static_features.transpose(
+ "grid_index", "static_feature"
+ ).values
self.register_buffer(
"grid_static_features",
- torch.tensor(da_static_features.values, dtype=torch.float32),
+ torch.tensor(arr_static, dtype=torch.float32),
persistent=False,
)
@@ -98,7 +102,10 @@ def __init__(
boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
- )
+ ).unsqueeze(
+ 1
+ ) # add feature dim
+
self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 4175e2d1..a76fc518 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -105,11 +105,6 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
batch_size = prev_state.shape[0]
- print(f"prev_state.shape: {prev_state.shape}")
- print(f"prev_prev_state.shape: {prev_prev_state.shape}")
- print(f"forcing.shape: {forcing.shape}")
- print(f"grid_static_features.shape: {self.grid_static_features.shape}")
-
# Create full grid node features of shape (B, num_grid_nodes, grid_dim)
grid_features = torch.cat(
(
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 4e38dbd5..de5067b3 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -23,13 +23,11 @@ def __init__(
split="train",
ar_steps=3,
forcing_window_size=3,
- batch_size=4,
standardize=True,
):
super().__init__()
self.split = split
- self.batch_size = batch_size
self.ar_steps = ar_steps
self.datastore = datastore
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 40ca7398..f6802f5b 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,8 +1,15 @@
+# Standard library
+from pathlib import Path
+
# Third-party
import pytest
+import torch
from test_datastores import DATASTORES, init_datastore
+from torch.utils.data import DataLoader
# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataset
@@ -47,9 +54,8 @@ def test_dataset_item(datastore_name):
assert target_states.shape[2] == datastore.get_num_data_vars("state")
# forcing
- assert forcing.shape[0] == N_pred_steps # number of prediction steps
- assert forcing.shape[1] == N_gridpoints # number of grid points
- # number of features x window size
+ assert forcing.shape[0] == N_pred_steps
+ assert forcing.shape[1] == N_gridpoints
assert (
forcing.shape[2]
== datastore.get_num_data_vars("forcing") * forcing_window_size
@@ -57,3 +63,57 @@ def test_dataset_item(datastore_name):
# batch times
assert batch_times.shape[0] == N_pred_steps
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_single_batch(datastore_name, split="train"):
+ """Check that a single batch from the datastore can be passed through the model.
+
+ This builds a small GraphLAM model for the datastore and runs `common_step` on one batch.
+ """
+ datastore = init_datastore(datastore_name)
+
+ device_name = ( # noqa
+ torch.device("cuda") if torch.cuda.is_available() else "cpu"
+ )
+
+ graph_name = "1level"
+
+ class ModelArgs:
+ output_std = False
+ loss = "mse"
+ restore_opt = False
+ n_example_pred = 1
+ # XXX: this should be superfluous when we have already defined the
+ # model object no?
+ graph = graph_name
+ hidden_dim = 64
+ hidden_layers = 1
+ processor_layers = 4
+ mesh_aggr = "sum"
+
+ args = ModelArgs()
+
+ graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+ if not graph_dir_path.exists():
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ n_max_levels=1,
+ )
+
+ dataset = WeatherDataset(datastore=datastore, split=split)
+
+ model = GraphLAM( # noqa
+ args=args,
+ forcing_window_size=dataset.forcing_window_size,
+ datastore=datastore,
+ )
+
+ model_device = model.to(device_name)
+ data_loader = DataLoader(dataset, batch_size=2)
+ batch = next(iter(data_loader))
+ model_device.common_step(batch)
+
+ assert False
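
The two shape adjustments made in `ar_model.py` in this patch can be illustrated in isolation; the sizes below are arbitrary stand-ins:

```python
import torch

n_grid, n_static = 12, 3  # arbitrary sizes

# static features are reordered to (grid_index, static_feature) before being
# registered as a buffer, matching what predict_step expects
wrong_order = torch.zeros(n_static, n_grid)  # stand-in for the original ordering
arr_static = wrong_order.T
assert arr_static.shape == (n_grid, n_static)

# the boundary mask gains a trailing feature dimension so it broadcasts
# against (num_grid_nodes, d_features) state tensors
boundary_mask = torch.zeros(n_grid).unsqueeze(1)
assert boundary_mask.shape == (n_grid, 1)
```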
From 2a6796c2900017a4ee4f69c372669a1719d58543 Mon Sep 17 00:00:00 2001
From: joeloskarsson
Date: Wed, 24 Jul 2024 14:51:41 +0200
Subject: [PATCH 140/273] Add init files to expose classes in editable package
---
neural_lam/__init__.py | 8 ++++++++
neural_lam/models/__init__.py | 5 +++++
neural_lam/train_model.py | 9 ++++-----
pyproject.toml | 7 +++++++
4 files changed, 24 insertions(+), 5 deletions(-)
create mode 100644 neural_lam/__init__.py
create mode 100644 neural_lam/models/__init__.py
diff --git a/neural_lam/__init__.py b/neural_lam/__init__.py
new file mode 100644
index 00000000..9aff809f
--- /dev/null
+++ b/neural_lam/__init__.py
@@ -0,0 +1,8 @@
+import neural_lam.config
+import neural_lam.interaction_net
+import neural_lam.metrics
+import neural_lam.models
+import neural_lam.utils
+import neural_lam.vis
+from .weather_dataset import WeatherDataset
+
diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py
new file mode 100644
index 00000000..f7c1b94f
--- /dev/null
+++ b/neural_lam/models/__init__.py
@@ -0,0 +1,5 @@
+from .graph_lam import GraphLAM
+from .hi_lam import HiLAM
+from .hi_lam_parallel import HiLAMParallel
+from .base_graph_model import BaseGraphModel
+from .base_hi_graph_model import BaseHiGraphModel
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index dd1ad313..6c81213d 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -10,11 +10,10 @@
from lightning_fabric.utilities import seed
# Local
-from . import config, utils
-from .models.graph_lam import GraphLAM
-from .models.hi_lam import HiLAM
-from .models.hi_lam_parallel import HiLAMParallel
-from .weather_dataset import WeatherDataset
+from . import config, utils, WeatherDataset
+from .models import GraphLAM
+from .models import HiLAM
+from .models import HiLAMParallel
MODELS = {
"graph_lam": GraphLAM,
diff --git a/pyproject.toml b/pyproject.toml
index b513a258..c482abc9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,10 @@
+[project]
+name = "neural-lam"
+version = "0.1.0"
+
+[tool.setuptools]
+py-modules = ["neural_lam"]
+
[tool.black]
line-length = 80
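
With the new `__init__.py` files, an editable install exposes the main classes at package level; a short usage sketch (assuming the package has been installed, e.g. with `pip install -e .`):

```python
# assumes an editable install of the repository
from neural_lam import WeatherDataset
from neural_lam.models import GraphLAM, HiLAM, HiLAMParallel

print(WeatherDataset, GraphLAM, HiLAM, HiLAMParallel)
```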
From 8f4e0e05015e8186d03c04295ae7cf4a155497f1 Mon Sep 17 00:00:00 2001
From: joeloskarsson
Date: Wed, 24 Jul 2024 15:27:05 +0200
Subject: [PATCH 141/273] Linting
---
neural_lam/__init__.py | 4 +++-
neural_lam/create_parameter_weights.py | 3 +--
neural_lam/models/__init__.py | 5 +++--
neural_lam/train_model.py | 6 ++----
4 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/neural_lam/__init__.py b/neural_lam/__init__.py
index 9aff809f..dd565a26 100644
--- a/neural_lam/__init__.py
+++ b/neural_lam/__init__.py
@@ -1,8 +1,10 @@
+# First-party
import neural_lam.config
import neural_lam.interaction_net
import neural_lam.metrics
import neural_lam.models
import neural_lam.utils
import neural_lam.vis
-from .weather_dataset import WeatherDataset
+# Local
+from .weather_dataset import WeatherDataset
diff --git a/neural_lam/create_parameter_weights.py b/neural_lam/create_parameter_weights.py
index a33b56b2..74058d38 100644
--- a/neural_lam/create_parameter_weights.py
+++ b/neural_lam/create_parameter_weights.py
@@ -11,8 +11,7 @@
from tqdm import tqdm
# Local
-from . import config
-from .weather_dataset import WeatherDataset
+from . import WeatherDataset, config
class PaddedWeatherDataset(torch.utils.data.Dataset):
diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py
index f7c1b94f..f65387ab 100644
--- a/neural_lam/models/__init__.py
+++ b/neural_lam/models/__init__.py
@@ -1,5 +1,6 @@
+# Local
+from .base_graph_model import BaseGraphModel
+from .base_hi_graph_model import BaseHiGraphModel
from .graph_lam import GraphLAM
from .hi_lam import HiLAM
from .hi_lam_parallel import HiLAMParallel
-from .base_graph_model import BaseGraphModel
-from .base_hi_graph_model import BaseHiGraphModel
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 6c81213d..39f7aecd 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -10,10 +10,8 @@
from lightning_fabric.utilities import seed
# Local
-from . import config, utils, WeatherDataset
-from .models import GraphLAM
-from .models import HiLAM
-from .models import HiLAMParallel
+from . import WeatherDataset, config, utils
+from .models import GraphLAM, HiLAM, HiLAMParallel
MODELS = {
"graph_lam": GraphLAM,
From e657abbe3edba29484a16233a16b6eefc805526f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 08:06:35 +0000
Subject: [PATCH 142/273] working training_step with datastores!
---
pyproject.toml | 3 +++
tests/test_datasets.py | 8 ++++----
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6d2ddf71..739913ff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"torch>=2.3.0",
"torch-geometric==2.3.1",
"mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep",
+ "parse>=1.20.2",
]
requires-python = ">=3.9"
@@ -33,6 +34,8 @@ dev = [
"pre-commit>=2.15.0",
"pytest>=8.2.1",
"pooch>=1.8.1",
+ "ipdb>=0.13.13",
+ "gpustat>=1.1.1",
]
[tool.black]
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index f6802f5b..263820a6 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -87,7 +87,7 @@ class ModelArgs:
# XXX: this should be superfluous when we have already defined the
# model object no?
graph = graph_name
- hidden_dim = 64
+ hidden_dim = 8
hidden_layers = 1
processor_layers = 4
mesh_aggr = "sum"
@@ -114,6 +114,6 @@ class ModelArgs:
model_device = model.to(device_name)
data_loader = DataLoader(dataset, batch_size=2)
batch = next(iter(data_loader))
- model_device.common_step(batch)
-
- assert False
+ batch_device = [part.to(device_name) for part in batch]
+ model_device.common_step(batch_device)
+ model_device.training_step(batch_device)
From effc99b981d4b94c155af32744ceaa22bebedf9c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 08:15:55 +0000
Subject: [PATCH 143/273] remove superfluous tests
---
tests/test_datasets.py | 1 -
tests/test_mllam_dataset.py | 49 ---------------
tests/test_multizarr_dataset.py | 102 --------------------------------
3 files changed, 152 deletions(-)
delete mode 100644 tests/test_mllam_dataset.py
delete mode 100644 tests/test_multizarr_dataset.py
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 263820a6..ecdae7b2 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -31,7 +31,6 @@ def test_dataset_item(datastore_name):
forcing_window_size = 3
dataset = WeatherDataset(
datastore=datastore,
- batch_size=1,
split="train",
ar_steps=N_pred_steps,
forcing_window_size=forcing_window_size,
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
deleted file mode 100644
index 565aebaf..00000000
--- a/tests/test_mllam_dataset.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Third-party
-import torch
-
-# First-party
-from neural_lam.create_graph import create_graph_from_datastore
-from neural_lam.datastore import MLLAMDatastore
-from neural_lam.models.graph_lam import GraphLAM
-from neural_lam.weather_dataset import WeatherDataModule, WeatherDataset
-
-
-class ModelArgs:
- output_std = True
- loss = "mse"
- restore_opt = False
- n_example_pred = 1
- # XXX: this should be superfluous when we have already defined the model object
- graph = "multiscale"
-
-
-def test_mllam():
- config_path = "tests/datastore_configs/mllam/example.danra.yaml"
- datastore = MLLAMDatastore(config_path=config_path)
- dataset = WeatherDataset(datastore=datastore)
-
- item = dataset[0] # noqa
-
- data_module = WeatherDataModule( # noqa
- ar_steps_train=3,
- ar_steps_eval=3,
- standardize=True,
- batch_size=2,
- )
-
- device_name = ( # noqa
- torch.device("cuda") if torch.cuda.is_available() else "cpu"
- )
-
- args = ModelArgs()
-
- create_graph_from_datastore(
- datastore=datastore,
- output_root_path="tests/datastore_configs/mllam/graph",
- )
-
- model = GraphLAM( # noqa
- args=args,
- forcing_window_size=dataset.forcing_window_size,
- datastore=datastore,
- )
diff --git a/tests/test_multizarr_dataset.py b/tests/test_multizarr_dataset.py
deleted file mode 100644
index 4a780fcb..00000000
--- a/tests/test_multizarr_dataset.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Standard library
-import os
-from pathlib import Path
-
-# First-party
-from neural_lam.create_graph import create_graph as create_graph
-from neural_lam.create_graph import create_graph_from_datastore
-from neural_lam.datastore.multizarr import MultiZarrDatastore
-
-# from neural_lam.datasets.config import Config
-from neural_lam.weather_dataset import WeatherDataset
-
-# Disable weights and biases to avoid unnecessary logging
-# and to avoid having to deal with authentication
-os.environ["WANDB_DISABLED"] = "true"
-
-DATASTORE_PATH = Path("tests/datastore_configs/multizarr")
-
-
-def test_load_analysis_dataset():
- # TODO: Access rights should be fixed for pooch to work
- datastore = MultiZarrDatastore(
- config_path=DATASTORE_PATH / "data_config.yaml"
- )
-
- var_state_names = datastore.get_vars_names(category="state")
- var_state_units = datastore.get_vars_units(category="state")
- num_state_vars = datastore.get_num_data_vars(category="state")
-
- assert len(var_state_names) == len(var_state_units) == num_state_vars
-
- var_forcing_names = datastore.get_vars_names(category="forcing")
- var_forcing_units = datastore.get_vars_units(category="forcing")
- num_forcing_vars = datastore.get_num_data_vars(category="forcing")
-
- assert len(var_forcing_names) == len(var_forcing_units) == num_forcing_vars
-
- stats = datastore.get_normalization_stats(category="state") # noqa
-
- # Assert dataset can be loaded
- ds = datastore.get_dataarray(category="state")
- grid = ds.sizes["y"] * ds.sizes["x"]
- dataset = WeatherDataset(
- datastore=datastore, split="train", ar_steps=3, standardize=True
- )
- batch = dataset[0]
- # return init_states, target_states, forcing, batch_times
- # init_states: (2, N_grid, d_features)
- # target_states: (ar_steps-2, N_grid, d_features)
- # forcing: (ar_steps-2, N_grid, d_windowed_forcing)
- # batch_times: (ar_steps-2,)
- assert list(batch[0].shape) == [2, grid, num_state_vars]
- assert list(batch[1].shape) == [dataset.ar_steps - 2, grid, num_state_vars]
- # assert list(batch[2].shape) == [
- # dataset.ar_steps - 2,
- # grid,
- # num_forcing_vars * config.forcing.window,
- # ]
- assert isinstance(batch[3], list)
-
- # Assert provided grid-shapes
- # assert config.get_xy("static")[0].shape == (
- # config.grid_shape_state.y,
- # config.grid_shape_state.x,
- # )
- # assert config.get_xy("static")[0].shape == (ds.sizes["y"], ds.sizes["x"])
-
-
-def test_create_graph_analysis_dataset():
- datastore = MultiZarrDatastore(
- config_path=DATASTORE_PATH / "data_config.yaml"
- )
- create_graph_from_datastore(
- datastore=datastore, output_root_path=DATASTORE_PATH / "graph"
- )
-
- # test cli
- args = [
- "--graph=hierarchical",
- "--hierarchical=1",
- "--data_config=tests/data_config.yaml",
- "--levels=2",
- ]
- create_graph(args)
-
-
-# def test_train_model_analysis_dataset():
-# args = [
-# "--model=hi_lam",
-# "--data_config=tests/data_config.yaml",
-# "--num_workers=4",
-# "--epochs=1",
-# "--graph=hierarchical",
-# "--hidden_dim=16",
-# "--hidden_layers=1",
-# "--processor_layers=1",
-# "--ar_steps_eval=1",
-# "--eval=val",
-# "--n_example_pred=0",
-# "--val_steps_to_log=1",
-# ]
-# train_model(args)
From a047026ce63ee08f75f247081b2029313b69a5f3 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 09:15:23 +0000
Subject: [PATCH 144/273] fix for dataset length
---
neural_lam/weather_dataset.py | 31 +++++++++++++++++++++++++------
tests/test_datasets.py | 8 +++++++-
2 files changed, 32 insertions(+), 7 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index de5067b3..4355b1ea 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -76,8 +76,11 @@ def __len__(self):
)
return self.da_state.analysis_time.size
else:
- # Skip first and last time step
- return len(self.da_state.time) - self.ar_steps
+ # sample_len = 2 + ar_steps <-- 2 initial states + ar_steps target states
+ # n_samples = len(self.da_state.time) - sample_len + 1
+ # = len(self.da_state.time) - 2 - ar_steps + 1
+ # = len(self.da_state.time) - ar_steps - 1
+ return len(self.da_state.time) - self.ar_steps - 1
def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
"""Produce a time slice of the given dataarray `da` (state or forcing)
@@ -181,7 +184,7 @@ def __getitem__(self, idx):
da=da_forcing,
idx=idx,
n_steps=self.ar_steps,
- n_timesteps_offset=2 + n,
+ n_timesteps_offset=n,
)
if n > 0:
da_ = da_.drop_vars("time")
@@ -231,19 +234,32 @@ def __getitem__(self, idx):
return init_states, target_states, forcing, batch_times
+ def __iter__(self):
+ """Convenience method to iterate over the dataset.
+
+ This isn't used by pytorch DataLoader which itself implements an
+ iterator that uses Dataset.__getitem__ and Dataset.__len__.
+ """
+ for i in range(len(self)):
+ yield self[i]
+
class WeatherDataModule(pl.LightningDataModule):
"""DataModule for weather data."""
def __init__(
self,
+ datastore: BaseDatastore,
ar_steps_train=3,
ar_steps_eval=25,
standardize=True,
+ forcing_window_size=3,
batch_size=4,
num_workers=16,
):
super().__init__()
+ self._datastore = datastore
+ self.forcing_window_size = forcing_window_size
self.ar_steps_train = ar_steps_train
self.ar_steps_eval = ar_steps_eval
self.standardize = standardize
@@ -256,24 +272,27 @@ def __init__(
def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = WeatherDataset(
+ datastore=self._datastore,
split="train",
ar_steps=self.ar_steps_train,
standardize=self.standardize,
- batch_size=self.batch_size,
+ forcing_window_size=self.forcing_window_size,
)
self.val_dataset = WeatherDataset(
+ datastore=self._datastore,
split="val",
ar_steps=self.ar_steps_eval,
standardize=self.standardize,
- batch_size=self.batch_size,
+ forcing_window_size=self.forcing_window_size,
)
if stage == "test" or stage is None:
self.test_dataset = WeatherDataset(
+ datastore=self._datastore,
split="test",
ar_steps=self.ar_steps_eval,
standardize=self.standardize,
- batch_size=self.batch_size,
+ forcing_window_size=self.forcing_window_size,
)
def train_dataloader(self):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index ecdae7b2..72518887 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -27,7 +27,7 @@ def test_dataset_item(datastore_name):
datastore = init_datastore(datastore_name)
N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y
- N_pred_steps = 4
+ N_pred_steps = 1
forcing_window_size = 3
dataset = WeatherDataset(
datastore=datastore,
@@ -63,6 +63,12 @@ def test_dataset_item(datastore_name):
# batch times
assert batch_times.shape[0] == N_pred_steps
+ # try to run through the whole dataset to ensure slicing and stacking
+ # operations are working as expected and are consistent with the dataset
+ # length
+ for item in iter(dataset):
+ pass
+
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_single_batch(datastore_name, split="train"):
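
A quick way to sanity-check the sample-count arithmetic in the new __len__ above: with T time steps, each sample consumes 2 initial states plus ar_steps target states, so the number of valid start indices is T - (2 + ar_steps) + 1 = T - ar_steps - 1. A minimal sketch (illustrative helper, not part of the patch):

def n_samples(n_time_steps: int, ar_steps: int) -> int:
    # each sample needs 2 initial states followed by `ar_steps` target states
    sample_len = 2 + ar_steps
    # number of valid start indices in a series of n_time_steps steps
    return n_time_steps - sample_len + 1

# e.g. 10 time steps with ar_steps=3 leave 10 - 3 - 1 = 6 samples
assert n_samples(10, 3) == 6
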
From d2c62ed174af6281b4db3e6dac0c99ef5f52cb88 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 09:30:41 +0000
Subject: [PATCH 145/273] step length should be int
---
neural_lam/datastore/mllam.py | 2 +-
tests/test_datastores.py | 11 +++++++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 70f913ae..d97b9b4a 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -52,7 +52,7 @@ def root_path(self) -> Path:
def step_length(self) -> int:
da_dt = self._ds["time"].diff("time")
- return da_dt.dt.seconds[0] // 3600
+ return (da_dt.dt.seconds[0] // 3600).item()
def get_vars_units(self, category: str) -> List[str]:
return self._ds[f"{category}_feature_units"].values.tolist()
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index e791ac86..15f3e281 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -1,6 +1,8 @@
"""List of methods and attributes that should be implemented in a subclass of
`BaseCartesianDatastore` (these are all decorated with `@abc.abstractmethod`):
+- [x] `root_path` (property): Root path of the datastore.
+- [ ] `step_length` (property): Length of the time step in hours.
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
@@ -66,6 +68,15 @@ def test_root_path(datastore_name):
assert isinstance(datastore.root_path, Path)
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_step_length(datastore_name):
+ """Check that the `datastore.step_length` property is implemented."""
+ datastore = init_datastore(datastore_name)
+ step_length = datastore.step_length()
+ assert isinstance(step_length, int)
+ assert step_length > 0
+
+
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
"""Use the `datastore.get_xy` method to get the x, y coordinates of the
From 58f5d99c309a8bd2ccc6a5d26e840f266cbe29e6 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 09:32:30 +0000
Subject: [PATCH 146/273] step length should be int
---
neural_lam/datastore/mllam.py | 1 +
tests/test_datastores.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index d97b9b4a..f91faad9 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -50,6 +50,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
def root_path(self) -> Path:
return Path(self._config_path.parent)
+ @property
def step_length(self) -> int:
da_dt = self._ds["time"].diff("time")
return (da_dt.dt.seconds[0] // 3600).item()
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 15f3e281..861f1d54 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -72,7 +72,7 @@ def test_root_path(datastore_name):
def test_step_length(datastore_name):
"""Check that the `datastore.step_length` property is implemented."""
datastore = init_datastore(datastore_name)
- step_length = datastore.step_length()
+ step_length = datastore.step_length
assert isinstance(step_length, int)
assert step_length > 0
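
For context on why the test changes from datastore.step_length() to datastore.step_length: once the method is decorated with @property, attribute access already returns the value, and calling the result would fail. A toy illustration (not repository code):

class Toy:
    @property
    def step_length(self) -> int:
        return 3

toy = Toy()
assert toy.step_length == 3
# toy.step_length() would raise TypeError: 'int' object is not callable
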
From 64d43a61288acbf16765d50d7d99ed2a6f299983 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 10:01:20 +0000
Subject: [PATCH 147/273] training working with mllam datastore!
---
.gitignore | 1 +
neural_lam/models/ar_model.py | 19 ++++----
neural_lam/vis.py | 30 ++++++++-----
tests/conftest.py | 6 +++
tests/test_training.py | 82 +++++++++++++++++++++++++++++++++++
5 files changed, 118 insertions(+), 20 deletions(-)
create mode 100644 tests/conftest.py
create mode 100644 tests/test_training.py
diff --git a/.gitignore b/.gitignore
index f5faeb52..8cd4e45d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,6 +82,7 @@ tags
# pdm (https://pdm-project.org/en/stable/)
.pdm-python
+.venv
# exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm
pdm.lock
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 59ca1fdc..d18c89ab 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -28,6 +28,7 @@ def __init__(
super().__init__()
self.save_hyperparameters()
self.args = args
+ self._datastore = datastore
         # XXX: should this be somewhere else?
split = "train"
num_state_vars = datastore.get_num_data_vars(category="state")
@@ -429,18 +430,18 @@ def plot_examples(self, batch, n_examples, prediction=None):
# Create one figure per variable at this time step
var_figs = [
vis.plot_prediction(
- pred_t[:, var_i],
- target_t[:, var_i],
- self.interior_mask[:, 0],
- self.datastore,
+ pred=pred_t[:, var_i],
+ target=target_t[:, var_i],
+ obs_mask=self.interior_mask[:, 0],
+ datastore=self.datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
- self.data_config.vars_names("state"),
- self.data_config.vars_units("state"),
+ self._datastore.get_vars_names("state"),
+ self._datastore.get_vars_units("state"),
var_vranges,
)
)
@@ -451,7 +452,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
- self.data_config.vars_names("state"), var_figs
+ self._datastore.get_vars_names("state"), var_figs
)
}
)
@@ -485,7 +486,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
log_dict = {}
metric_fig = vis.plot_error_map(
- metric_tensor, self.data_config, step_length=self.step_length
+ errors=metric_tensor,
+ datastore=self._datastore,
+ step_length=self.step_length,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 1edf71e9..98e066c4 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -5,10 +5,13 @@
# Local
from . import utils
+from .datastore.base import BaseCartesianDatastore
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=1):
+def plot_error_map(
+ errors, datastore: BaseCartesianDatastore, title=None, step_length=1
+):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
@@ -48,11 +51,10 @@ def plot_error_map(errors, data_config, title=None, step_length=1):
ax.set_xlabel("Lead time (h)", size=label_size)
ax.set_yticks(np.arange(d_f))
+ var_names = datastore.get_vars_names(category="state")
+ var_units = datastore.get_vars_units(category="state")
y_ticklabels = [
- f"{name} ({unit})"
- for name, unit in zip(
- data_config.vars_names("state"), data_config.vars_units("state")
- )
+ f"{name} ({unit})" for name, unit in zip(var_names, var_units)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -64,7 +66,12 @@ def plot_error_map(errors, data_config, title=None, step_length=1):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
- pred, target, obs_mask, data_config, title=None, vrange=None
+ pred,
+ target,
+ obs_mask,
+ datastore: BaseCartesianDatastore,
+ title=None,
+ vrange=None,
):
"""Plot example prediction and grond truth.
@@ -77,12 +84,11 @@ def plot_prediction(
else:
vmin, vmax = vrange
- extent = data_config.get_xy_extent("state")
+ extent = datastore.get_xy_extent("state")
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(
- list(data_config.grid_shape_state.values.values())
- )
+ da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+ mask_reshaped = da_mask.values
pixel_alpha = (
         np.clip(mask_reshaped, 0.7, 1)
) # Faded border region
@@ -91,14 +97,14 @@ def plot_prediction(
1,
2,
figsize=(13, 7),
- subplot_kw={"projection": data_config.coords_projection},
+ subplot_kw={"projection": datastore.coords_projection},
)
# Plot pred and target
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(list(data_config.grid_shape_state.values.values()))
+ data.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
.numpy()
)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..0ec7f4b0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,6 @@
+# Standard library
+import os
+
+# Disable weights and biases to avoid unnecessary logging
+# and to avoid having to deal with authentication
+os.environ["WANDB_DISABLED"] = "true"
diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 00000000..3767fbc0
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,82 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import pytest
+import pytorch_lightning as pl
+import torch
+import wandb
+from test_datastores import DATASTORES, init_datastore
+
+# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.weather_dataset import WeatherDataModule
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_training(datastore_name):
+ datastore = init_datastore(datastore_name)
+
+ if torch.cuda.is_available():
+ device_name = "cuda"
+ torch.set_float32_matmul_precision(
+ "high"
+ ) # Allows using Tensor Cores on A100s
+ else:
+ device_name = "cpu"
+
+ trainer = pl.Trainer(
+ max_epochs=3,
+ deterministic=True,
+ strategy="ddp",
+ accelerator=device_name,
+ log_every_n_steps=1,
+ )
+
+ graph_name = "1level"
+
+ graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+ if not graph_dir_path.exists():
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ n_max_levels=1,
+ )
+
+ data_module = WeatherDataModule(
+ datastore=datastore,
+ ar_steps_train=3,
+ ar_steps_eval=5,
+ standardize=True,
+ batch_size=2,
+ num_workers=1,
+ forcing_window_size=3,
+ )
+
+ class ModelArgs:
+ output_std = False
+ loss = "mse"
+ restore_opt = False
+ n_example_pred = 1
+ # XXX: this should be superfluous when we have already defined the
+ # model object no?
+ graph = graph_name
+ hidden_dim = 8
+ hidden_layers = 1
+ processor_layers = 4
+ mesh_aggr = "sum"
+ lr = 1.0e-3
+ val_steps_to_log = [1]
+ metrics_watch = []
+
+ model_args = ModelArgs()
+
+ model = GraphLAM( # noqa
+ args=model_args,
+ forcing_window_size=data_module.forcing_window_size,
+ datastore=datastore,
+ )
+ wandb.init()
+ trainer.fit(model=model, datamodule=data_module)
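
The ModelArgs class in the test above mimics the argparse namespace the models normally receive from train_model. An equivalent stand-in can be built with argparse.Namespace; this is only a sketch of the pattern, with the values copied from the test:

from argparse import Namespace

model_args = Namespace(
    output_std=False,
    loss="mse",
    restore_opt=False,
    n_example_pred=1,
    graph="1level",
    hidden_dim=8,
    hidden_layers=1,
    processor_layers=4,
    mesh_aggr="sum",
    lr=1.0e-3,
    val_steps_to_log=[1],
    metrics_watch=[],
)

A Namespace keeps the stand-in close to what ArgumentParser.parse_args actually returns, at the cost of losing the class-level documentation of the fields.
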
From 07444f8fb158f952855b28731b2923c1d391b57f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 10:38:22 +0000
Subject: [PATCH 148/273] adapt neural_lam.train_model for datastores
---
neural_lam/train_model.py | 46 ++++++++++++++++++++++++++++++++++-----
1 file changed, 40 insertions(+), 6 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index af4f001c..7c489058 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -12,6 +12,9 @@
# Local
from . import utils
+from .datastore.mllam import MLLAMDatastore
+from .datastore.multizarr import MultiZarrDatastore
+from .datastore.npyfiles import NumpyFilesDatastore
from .models.graph_lam import GraphLAM
from .models.hi_lam import HiLAM
from .models.hi_lam_parallel import HiLAMParallel
@@ -24,16 +27,35 @@
}
+def _init_datastore(datastore_kind, data_config):
+ if datastore_kind == "multizarr":
+ datastore = MultiZarrDatastore(data_config)
+ elif datastore_kind == "npyfiles":
+ datastore = NumpyFilesDatastore(data_config)
+ elif datastore_kind == "mllam":
+ datastore = MLLAMDatastore(data_config)
+ else:
+ raise ValueError(f"Unknown datastore kind: {datastore_kind}")
+ return datastore
+
+
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
- "--data_config",
+ "--datastore-kind",
+ type=str,
+ choices=["multizarr", "npyfiles", "mllam"],
+ default="multizarr",
+ help="Kind of datastore to use (default: multizarr)",
+ )
+ parser.add_argument(
+ "--datastore-config",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ default="tests/datastore_configs/multizarr/data_config.yaml",
+ help="Path to data config file",
)
parser.add_argument(
"--model",
@@ -201,6 +223,12 @@ def main(input_args=None):
help="""JSON string with variable-IDs and lead times to log watched
metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""",
)
+ parser.add_argument(
+ "--forcing-window-size",
+ type=int,
+ default=3,
+ help="Number of time steps to use as input for forcing data",
+ )
args = parser.parse_args(input_args)
args.var_leads_metrics_watch = {
int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items()
@@ -219,11 +247,15 @@ def main(input_args=None):
# Set seed
seed.seed_everything(args.seed)
+ # Create datastore
+ datastore = _init_datastore(args.datastore_kind, args.datastore_config)
# Create datamodule
data_module = WeatherDataModule(
+ datastore=datastore,
ar_steps_train=args.ar_steps_train,
ar_steps_eval=args.ar_steps_eval,
standardize=True,
+ forcing_window_size=args.forcing_window_size,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
@@ -238,8 +270,10 @@ def main(input_args=None):
device_name = "cpu"
# Load model parameters Use new args for model
- model_class = MODELS[args.model]
- model = model_class(args)
+ ModelClass = MODELS[args.model]
+ model = ModelClass(
+ args, datastore=datastore, forcing_window_size=args.forcing_window_size
+ )
if args.eval:
prefix = f"eval-{args.eval}-"
@@ -276,7 +310,7 @@ def main(input_args=None):
utils.init_wandb_metrics(
logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
- wandb.save(args.data_config)
+ wandb.save(args.datastore_config)
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
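
The new _init_datastore helper dispatches with an if/elif chain; a dictionary-based dispatch is an equivalent alternative. This is a sketch only, using the class names as imported in this patch (the npyfiles class is renamed in a later patch):

from neural_lam.datastore.mllam import MLLAMDatastore
from neural_lam.datastore.multizarr import MultiZarrDatastore
from neural_lam.datastore.npyfiles import NumpyFilesDatastore

DATASTORE_CLASSES = {
    "multizarr": MultiZarrDatastore,
    "npyfiles": NumpyFilesDatastore,
    "mllam": MLLAMDatastore,
}

def init_datastore(datastore_kind: str, data_config: str):
    try:
        datastore_class = DATASTORE_CLASSES[datastore_kind]
    except KeyError:
        raise ValueError(f"Unknown datastore kind: {datastore_kind}") from None
    return datastore_class(data_config)
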
From d1b6fc17bcdebac22bec3e482603fff8f2090c2c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 25 Jul 2024 14:36:22 +0000
Subject: [PATCH 149/273] fixes for npy
---
neural_lam/create_graph.py | 4 +-
neural_lam/datastore/npyfiles/__init__.py | 2 +-
neural_lam/datastore/npyfiles/store.py | 31 +++++++++---
neural_lam/models/ar_model.py | 2 +-
neural_lam/train_model.py | 4 +-
neural_lam/weather_dataset.py | 6 +--
pyproject.toml | 1 +
tests/conftest.py | 51 ++++++++++++++++++++
tests/datastore_configs/.gitignore | 1 +
tests/datastore_configs/npy/.gitignore | 2 -
tests/datastore_configs/npy/data_config.yaml | 40 ---------------
tests/test_datastores.py | 41 +++++++---------
tests/test_npy_forecast_dataset.py | 4 +-
13 files changed, 103 insertions(+), 86 deletions(-)
create mode 100644 tests/datastore_configs/.gitignore
delete mode 100644 tests/datastore_configs/npy/.gitignore
delete mode 100644 tests/datastore_configs/npy/data_config.yaml
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index c13d0f93..c281887d 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -17,7 +17,7 @@
from .datastore.base import BaseCartesianDatastore
from .datastore.mllam import MLLAMDatastore
from .datastore.multizarr import MultiZarrDatastore
-from .datastore.npyfiles import NumpyFilesDatastore
+from .datastore.npyfiles import NpyFilesDatastore
def plot_graph(graph, title=None):
@@ -531,7 +531,7 @@ def create_graph(
DATASTORES = dict(
multizarr=MultiZarrDatastore,
mllam=MLLAMDatastore,
- npyfiles=NumpyFilesDatastore,
+ npyfiles=NpyFilesDatastore,
)
diff --git a/neural_lam/datastore/npyfiles/__init__.py b/neural_lam/datastore/npyfiles/__init__.py
index 573b7070..3bf6fadb 100644
--- a/neural_lam/datastore/npyfiles/__init__.py
+++ b/neural_lam/datastore/npyfiles/__init__.py
@@ -1,2 +1,2 @@
# Local
-from .store import NumpyFilesDatastore # noqa
+from .store import NpyFilesDatastore # noqa
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 8ca4dd49..295ef882 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -24,7 +24,7 @@
COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
-class NumpyFilesDatastore(BaseCartesianDatastore):
+class NpyFilesDatastore(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
to be stored in a directory structure where each sample is stored in a
@@ -132,9 +132,13 @@ def __init__(
self._step_length = 3 # 3 hours
self._num_ensemble_members = 2
- self.root_path = Path(root_path)
+ self._root_path = Path(root_path)
self.config = NpyConfig.from_file(self.root_path / "data_config.yaml")
+ @property
+ def root_path(self):
+ return self._root_path
+
def get_dataarray(self, category: str, split: str) -> DataArray:
"""Get the data array for the given category and split of data. If the
category is 'state', the data array will be a concatenation of the data
@@ -334,9 +338,9 @@ def _get_single_timeseries_dataarray(
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
elif d == "y":
- coord_values = np.arange(grid_shape[0])
+ coord_values = np.arange(grid_shape.y)
elif d == "x":
- coord_values = np.arange(grid_shape[1])
+ coord_values = np.arange(grid_shape.x)
elif d == "feature":
coord_values = features
else:
@@ -421,6 +425,11 @@ def _get_analysis_times(self, split):
name_parts = parse.parse(STATE_FILENAME_FORMAT, fp.name)
times.append(name_parts["analysis_time"])
+ if len(times) == 0:
+ raise ValueError(
+ f"No files found in {sample_dir} with pattern {pattern}"
+ )
+
return times
def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
@@ -540,10 +549,10 @@ def grid_shape_state(self):
@property
def boundary_mask(self):
xs, ys = self.get_xy(category="state", stacked=False)
- assert np.all(xs[0, :] == xs[-1, :])
- assert np.all(ys[:, 0] == ys[:, -1])
- x = xs[0, :]
- y = ys[:, 0]
+ assert np.all(xs[:, 0] == xs[:, -1])
+ assert np.all(ys[0, :] == ys[-1, :])
+ x = xs[:, 0]
+ y = ys[0, :]
values = np.load(self.root_path / "static" / "border_mask.npy")
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
@@ -589,6 +598,12 @@ def load_pickled_tensor(fn):
mean_values = np.array([flux_mean, 0.34033957, 0.0, 0.0, 0.0, 0.0])
std_values = np.array([flux_std, 0.4661307, 1.0, 1.0, 1.0, 1.0])
+ elif category == "static":
+ ds_static = self.get_dataarray(category="static", split="train")
+ ds_static_mean = ds_static.mean(dim=["grid_index"])
+ ds_static_std = ds_static.std(dim=["grid_index"])
+ mean_values = ds_static_mean["static_feature"].values
+ std_values = ds_static_std["static_feature"].values
else:
raise NotImplementedError(f"Category {category} not supported")
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index d18c89ab..fec31e5b 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -26,7 +26,7 @@ def __init__(
self, args, datastore: BaseDatastore, forcing_window_size: int
):
super().__init__()
- self.save_hyperparameters()
+ self.save_hyperparameters(ignore=["datastore"])
self.args = args
self._datastore = datastore
         # XXX: should this be somewhere else?
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 7c489058..3ea86716 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -14,7 +14,7 @@
from . import utils
from .datastore.mllam import MLLAMDatastore
from .datastore.multizarr import MultiZarrDatastore
-from .datastore.npyfiles import NumpyFilesDatastore
+from .datastore.npyfiles import NpyFilesDatastore
from .models.graph_lam import GraphLAM
from .models.hi_lam import HiLAM
from .models.hi_lam_parallel import HiLAMParallel
@@ -31,7 +31,7 @@ def _init_datastore(datastore_kind, data_config):
if datastore_kind == "multizarr":
datastore = MultiZarrDatastore(data_config)
elif datastore_kind == "npyfiles":
- datastore = NumpyFilesDatastore(data_config)
+ datastore = NpyFilesDatastore(data_config)
elif datastore_kind == "mllam":
datastore = MLLAMDatastore(data_config)
else:
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 4355b1ea..05607f8f 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -70,10 +70,8 @@ def __len__(self):
f"({self.da_state.ensemble_member.size})",
UserWarning,
)
- return (
- self.da_state.analysis_time.size
- * self.da_state.ensemble_member.size
- )
+ # XXX: we should maybe check that the 2+ar_steps actually fits
+ # in the elapsed_forecast_time dimension, should that be checked here?
return self.da_state.analysis_time.size
else:
# sample_len = 2 + ar_steps <-- 2 initial states + ar_steps target states
diff --git a/pyproject.toml b/pyproject.toml
index 739913ff..2681831d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dev = [
"pooch>=1.8.1",
"ipdb>=0.13.13",
"gpustat>=1.1.1",
+ "zarrdump>=0.4.1",
]
[tool.black]
diff --git a/tests/conftest.py b/tests/conftest.py
index 0ec7f4b0..1c1cdd3e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,57 @@
# Standard library
import os
+from pathlib import Path
+
+# Third-party
+import pooch
+
+# First-party
+from neural_lam.datastore.mllam import MLLAMDatastore
+from neural_lam.datastore.multizarr import MultiZarrDatastore
+from neural_lam.datastore.npyfiles import NpyFilesDatastore
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
os.environ["WANDB_DISABLED"] = "true"
+
+DATASTORES = dict(
+ multizarr=MultiZarrDatastore,
+ mllam=MLLAMDatastore,
+ npyfiles=NpyFilesDatastore,
+)
+
+# Initializing variables for the s3 client
+S3_BUCKET_NAME = "mllam-testdata"
+S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
+S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
+S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
+TEST_DATA_KNOWN_HASH = (
+ "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7"
+)
+
+
+def download_meps_example_reduced_dataset():
+ # Download and unzip test data into data/meps_example_reduced
+ root_path = Path("tests/datastores_examples/npy")
+ pooch.retrieve(
+ url=S3_FULL_PATH,
+ known_hash=TEST_DATA_KNOWN_HASH,
+ processor=pooch.Unzip(extract_dir=""),
+ path=root_path,
+ fname="meps_example_reduced.zip",
+ )
+ return root_path / "meps_example_reduced"
+
+
+DATASTORES_EXAMPLES = dict(
+ multizarr=dict(
+ config_path="tests/datastore_configs/multizarr/data_config.yaml"
+ ),
+ mllam=dict(config_path="tests/datastore_configs/mllam/example.danra.yaml"),
+ npyfiles=dict(root_path=download_meps_example_reduced_dataset()),
+)
+
+
+def init_datastore(datastore_name):
+ DatastoreClass = DATASTORES[datastore_name]
+ return DatastoreClass(**DATASTORES_EXAMPLES[datastore_name])
diff --git a/tests/datastore_configs/.gitignore b/tests/datastore_configs/.gitignore
new file mode 100644
index 00000000..2d0a57fd
--- /dev/null
+++ b/tests/datastore_configs/.gitignore
@@ -0,0 +1 @@
+npy/
diff --git a/tests/datastore_configs/npy/.gitignore b/tests/datastore_configs/npy/.gitignore
deleted file mode 100644
index 718ecfd8..00000000
--- a/tests/datastore_configs/npy/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-samples/
-static/
diff --git a/tests/datastore_configs/npy/data_config.yaml b/tests/datastore_configs/npy/data_config.yaml
deleted file mode 100644
index 12386bc8..00000000
--- a/tests/datastore_configs/npy/data_config.yaml
+++ /dev/null
@@ -1,40 +0,0 @@
-dataset:
- name: meps_example_reduced
- var_names:
- - pres_0g
- - pres_0s
- - nlwrs_0
- - nswrs_0
- - r_2
- - r_65
- - t_2
- - t_65
- var_units:
- - Pa
- - Pa
- - "W/m**2"
- - "W/m**2"
- - ""
- - ""
- - K
- - K
- var_longnames:
- - pres_heightAboveGround_0_instant
- - pres_heightAboveSea_0_instant
- - nlwrs_heightAboveGround_0_accum
- - nswrs_heightAboveGround_0_accum
- - r_heightAboveGround_2_instant
- - r_hybrid_65_instant
- - t_heightAboveGround_2_instant
- - t_hybrid_65_instant
- # increased num_forcing_features from 16 to 18 so that it reflects
- # ["toa_downwelling_shortwave_flux", "column_water", "sin_hour", "cos_hour", "sin_year", "cos_year"] x forcing_window_size
- # i.e. 6 x 3 = 18 forcing features
- num_forcing_features: 18
-grid_shape_state: [134, 119]
-projection:
- class: LambertConformal
- kwargs:
- central_longitude: 15.0
- central_latitude: 63.3
- standard_parallels: [63.3, 63.3]
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 861f1d54..abb41e92 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -2,7 +2,7 @@
`BaseCartesianDatastore` (these are all decorated with `@abc.abstractmethod`):
- [x] `root_path` (property): Root path of the datastore.
-- [ ] `step_length` (property): Length of the time step in hours.
+- [x] `step_length` (property): Length of the time step in hours.
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
@@ -33,32 +33,10 @@
import numpy as np
import pytest
import xarray as xr
+from conftest import DATASTORES, init_datastore
# First-party
from neural_lam.datastore.base import BaseCartesianDatastore
-from neural_lam.datastore.mllam import MLLAMDatastore
-from neural_lam.datastore.multizarr import MultiZarrDatastore
-from neural_lam.datastore.npyfiles import NumpyFilesDatastore
-
-DATASTORES = dict(
- multizarr=MultiZarrDatastore,
- mllam=MLLAMDatastore,
- npyfiles=NumpyFilesDatastore,
-)
-
-
-EXAMPLES = dict(
- multizarr=dict(
- config_path="tests/datastore_configs/multizarr/data_config.yaml"
- ),
- mllam=dict(config_path="tests/datastore_configs/mllam/example.danra.yaml"),
- npyfiles=dict(root_path="tests/datastore_configs/npy"),
-)
-
-
-def init_datastore(datastore_name):
- DatastoreClass = DATASTORES[datastore_name]
- return DatastoreClass(**EXAMPLES[datastore_name])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
@@ -270,3 +248,18 @@ def test_get_projection(datastore_name):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
assert isinstance(datastore.coords_projection, ccrs.Projection)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def get_grid_shape_state(datastore_name):
+ """Check that the `datastore.grid_shape_state` property is implemented."""
+ datastore = init_datastore(datastore_name)
+
+ if not isinstance(datastore, BaseCartesianDatastore):
+ pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+ grid_shape = datastore.grid_shape_state
+ assert isinstance(grid_shape, tuple)
+ assert len(grid_shape) == 2
+ assert all(isinstance(e, int) for e in grid_shape)
+ assert all(e > 0 for e in grid_shape)
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
index ed13e286..571565dd 100644
--- a/tests/test_npy_forecast_dataset.py
+++ b/tests/test_npy_forecast_dataset.py
@@ -7,7 +7,7 @@
# First-party
from neural_lam.create_graph import create_graph as create_graph
-from neural_lam.datastore.npyfiles import NumpyFilesDatastore
+from neural_lam.datastore.npyfiles import NpyFilesDatastore
from neural_lam.train_model import main as train_model
from neural_lam.weather_dataset import WeatherDataset
@@ -40,7 +40,7 @@ def ewc_testdata_path():
def test_load_reduced_meps_dataset(ewc_testdata_path):
- datastore = NumpyFilesDatastore(root_path=ewc_testdata_path)
+ datastore = NpyFilesDatastore(root_path=ewc_testdata_path)
datastore.get_xy(category="state", stacked=True)
datastore.get_dataarray(category="forcing", split="train").unstack(
From 6fe19ac71e1023b036a7c41d400e07a0d40fdd46 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 26 Jul 2024 10:36:14 +0200
Subject: [PATCH 150/273] npyfiles datastore complete
---
neural_lam/create_graph.py | 8 +-
neural_lam/datastore/base.py | 3 +-
neural_lam/datastore/mllam.py | 8 +-
neural_lam/datastore/multizarr/store.py | 10 +-
neural_lam/datastore/npyfiles/config.py | 101 ++++++++-------
neural_lam/datastore/npyfiles/store.py | 64 ++++++----
neural_lam/train_model.py | 23 ++--
neural_lam/weather_dataset.py | 12 +-
pyproject.toml | 1 +
tests/conftest.py | 25 +++-
tests/test_datastores.py | 4 +
tests/test_npy_forecast_dataset.py | 161 ------------------------
tests/test_training.py | 17 +++
13 files changed, 172 insertions(+), 265 deletions(-)
delete mode 100644 tests/test_npy_forecast_dataset.py
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index c281887d..6b062e3d 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -228,6 +228,8 @@ def create_graph(
"""
os.makedirs(graph_dir_path, exist_ok=True)
+ print(f"Writing graph components to {graph_dir_path}")
+
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
@@ -562,7 +564,7 @@ def cli(input_args=None):
help="kind of data store to use (default: multizarr)",
)
parser.add_argument(
- "datastore-path",
+ "datastore_path",
type=str,
help="path to the data store",
)
@@ -594,11 +596,11 @@ def cli(input_args=None):
args = parser.parse_args(input_args)
DatastoreClass = DATASTORES[args.datastore]
- datastore = DatastoreClass(args.datastore_path)
+ datastore = DatastoreClass(root_path=args.datastore_path)
create_graph_from_datastore(
datastore=datastore,
- output_root_path=os.path.join(datastore.root_path, "graphs", args.name),
+ output_root_path=os.path.join(datastore.root_path, "graph", args.name),
n_max_levels=args.levels,
hierarchical=args.hierarchical,
create_plot=args.plot,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 2a472cbf..c2c2d798 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -266,7 +266,8 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
- return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+ extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+ return [float(v) for v in extent]
def unstack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index f91faad9..a83cf31c 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -15,7 +15,7 @@
class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""
- def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+ def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
"""Construct a new MLLAMDatastore from the configuration file at
`config_path`. A boundary mask is created with `n_boundary_points`
boundary points. If `reuse_existing` is True, the dataset is loaded
@@ -33,7 +33,9 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
reuse_existing : bool
Whether to reuse an existing dataset zarr file if it exists.
"""
- self._config_path = Path(config_path)
+ config_filename = "data_config.yaml"
+ self._root_path = Path(root_path)
+ config_path = self._root_path / config_filename
self._config = mdp.Config.from_yaml_file(config_path)
fp_ds = self._config_path.parent / self._config_path.name.replace(
".yaml", ".zarr"
@@ -48,7 +50,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
@property
def root_path(self) -> Path:
- return Path(self._config_path.parent)
+ return self._root_path
@property
def step_length(self) -> int:
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 37993be5..1f874d6e 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -1,6 +1,7 @@
# Standard library
import functools
import os
+from pathlib import Path
# Third-party
import cartopy.crs as ccrs
@@ -16,11 +17,16 @@
class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable"}
- def __init__(self, config_path):
- self.config_path = config_path
+ def __init__(self, root_path):
+ self._root_path = Path(root_path)
+ config_path = self._root_path / "data_config.yaml"
with open(config_path, encoding="utf-8", mode="r") as file:
self._config = yaml.safe_load(file)
+ @property
+ def root_path(self):
+ return self._root_path
+
def _normalize_path(self, path):
# try to parse path to see if it defines a protocol, e.g. s3://
if "://" in path or path.startswith("/"):
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index f3fe25ca..545b4b8b 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -1,10 +1,61 @@
# Standard library
-import functools
-from pathlib import Path
+from dataclasses import dataclass
+from typing import Any, Dict, List
# Third-party
-import cartopy.crs as ccrs
-import yaml
+import dataclass_wizard
+
+
+@dataclass
+class Projection:
+ """Represents the projection information for a dataset, including the type
+ of projection and its parameters. Capable of creating a cartopy.crs
+ projection object.
+
+ Attributes:
+ class_name: The class name of the projection, this should be a valid
+ cartopy.crs class.
+ kwargs: A dictionary of keyword arguments specific to the projection type.
+ """
+
+ class_name: str # = field(metadata={'data_key': 'class'})
+ kwargs: Dict[str, Any]
+
+
+@dataclass
+class Dataset:
+ """Contains information about the dataset, including variable names, units,
+ and descriptions.
+
+ Attributes:
+ name: The name of the dataset.
+ var_names: A list of variable names in the dataset.
+ var_units: A list of units for each variable.
+ var_longnames: A list of long, descriptive names for each variable.
+ num_forcing_features: The number of forcing features in the dataset.
+ """
+
+ name: str
+ var_names: List[str]
+ var_units: List[str]
+ var_longnames: List[str]
+ num_forcing_features: int
+
+
+@dataclass
+class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
+ """Configuration for loading and processing a dataset, including dataset
+ details, grid shape, and projection information.
+
+ Attributes:
+ dataset: An instance of Dataset containing details about the dataset.
+ grid_shape_state: A list representing the shape of the grid state.
+ projection: An instance of Projection containing projection details.
+ """
+
+ dataset: Dataset
+ grid_shape_state: List[int]
+ projection: Projection
class NpyConfig:
@@ -14,48 +65,6 @@ class NpyConfig:
its values as attributes.
"""
- def __init__(self, values):
- self.values = values
-
- @classmethod
- def from_file(cls, filepath):
- """Load a configuration file."""
- if str(filepath).endswith(".yaml"):
- with open(filepath, encoding="utf-8", mode="r") as file:
- return cls(values=yaml.safe_load(file))
- else:
- raise NotImplementedError(Path(filepath).suffix)
-
- def __getattr__(self, name):
- child, *children = name.split(".")
-
- value = self.values[child]
- if len(children) > 0:
- return self.__class__(values=value).get(".".join(children))
- else:
- if isinstance(value, dict):
- return self.__class__(values=value)
- else:
- return value
-
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return self.__class__(values=value)
- return value
-
- def __contains__(self, key):
- return key in self.values
-
def num_data_vars(self):
"""Return the number of data variables for a given key."""
return len(self.dataset.var_names)
-
- @functools.cached_property
- def coords_projection(self):
- """Return the projection."""
- proj_config = self.values["projection"]
- proj_class_name = proj_config["class"]
- proj_class = getattr(ccrs, proj_class_name)
- proj_params = proj_config.get("kwargs", {})
- return proj_class(**proj_params)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 295ef882..02365a46 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -1,9 +1,11 @@
# Standard library
+import functools
import re
from pathlib import Path
from typing import List
# Third-party
+import cartopy.crs as ccrs
import dask
import dask.array
import dask.delayed
@@ -15,7 +17,7 @@
# Local
from ..base import BaseCartesianDatastore, CartesianGridShape
-from .config import NpyConfig
+from .config import NpyDatastoreConfig
STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
TOA_SW_DOWN_FLUX_FILENAME_FORMAT = (
@@ -24,6 +26,13 @@
COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
+def _load_np(fp, add_feature_dim):
+ arr = np.load(fp)
+ if add_feature_dim:
+ arr = arr[..., np.newaxis]
+ return arr
+
+
class NpyFilesDatastore(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
@@ -133,7 +142,9 @@ def __init__(
self._num_ensemble_members = 2
self._root_path = Path(root_path)
- self.config = NpyConfig.from_file(self.root_path / "data_config.yaml")
+ self.config = NpyDatastoreConfig.from_yaml_file(
+ self.root_path / "data_config.yaml"
+ )
@property
def root_path(self):
@@ -157,9 +168,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
xr.DataArray
The data array for the given category and split, with dimensions
per category:
- state: `[elapsed_forecast_time, analysis_time, grid_index, feature,
+ state: `[elapsed_forecast_duration, analysis_time, grid_index, feature,
ensemble_member]`
- forcing: `[elapsed_forecast_time, analysis_time, grid_index, feature]`
+ forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]`
static: `[grid_index, feature]`
"""
if category == "state":
@@ -188,14 +199,14 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# add datetime forcing as a feature
# to do this we create a forecast time variable which has the
- # dimensions of (analysis_time, elapsed_forecast_time) with values
+ # dimensions of (analysis_time, elapsed_forecast_duration) with values
# that are the actual forecast time of each time step. By calling
- # .chunk({"elapsed_forecast_time": 1}) this time variable is turned
+ # .chunk({"elapsed_forecast_duration": 1}) this time variable is turned
# into a dask array and so execution of the calculation is delayed
# until the feature values are actually used.
da_forecast_time = (
- da.analysis_time + da.elapsed_forecast_time
- ).chunk({"elapsed_forecast_time": 1})
+ da.analysis_time + da.elapsed_forecast_duration
+ ).chunk({"elapsed_forecast_duration": 1})
da_datetime_forcing_features = self._calc_datetime_forcing_features(
da_time=da_forecast_time
)
@@ -262,7 +273,7 @@ def _get_single_timeseries_dataarray(
-------
xr.DataArray
The data array for the given category and split, with dimensions
- `[elapsed_forecast_time, analysis_time, grid_index, feature]` for
+ `[elapsed_forecast_duration, analysis_time, grid_index, feature]` for
all categories of data
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
@@ -284,12 +295,12 @@ def _get_single_timeseries_dataarray(
features_vary_with_analysis_time = True
if features == self.get_vars_names(category="state"):
filename_format = STATE_FILENAME_FORMAT
- file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
+ file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
# only select one member for now
file_params["member_id"] = member
elif features == ["toa_downwelling_shortwave_flux"]:
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
- file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
+ file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
add_feature_dim = True
elif features == ["column_water"]:
filename_format = COLUMN_WATER_FILENAME_FORMAT
@@ -329,7 +340,7 @@ def _get_single_timeseries_dataarray(
coords = {}
arr_shape = []
for d in dims:
- if d == "elapsed_forecast_time":
+ if d == "elapsed_forecast_duration":
coord_values = (
self.step_length
* np.arange(self._num_timesteps)
@@ -346,16 +357,12 @@ def _get_single_timeseries_dataarray(
else:
raise NotImplementedError(f"Dimension {d} not supported")
- print(f"{d}: {len(coord_values)}")
-
coords[d] = coord_values
if d != "analysis_time":
# analysis_time varies across the different files, but not
# within a single file
arr_shape.append(len(coord_values))
- print(f"{features}: {dims=} {file_dims=} {arr_shape=}")
-
if features_vary_with_analysis_time:
filepaths = [
fp_samples
@@ -369,16 +376,11 @@ def _get_single_timeseries_dataarray(
# use dask.delayed to load the numpy files, so that loading isn't
# done until the data is actually needed
- @dask.delayed
- def _load_np(fp):
- arr = np.load(fp)
- if add_feature_dim:
- arr = arr[..., np.newaxis]
- return arr
-
arrays = [
dask.array.from_delayed(
- _load_np(fp), shape=arr_shape, dtype=np.float32
+ dask.delayed(_load_np)(fp=fp, add_feature_dim=add_feature_dim),
+ shape=arr_shape,
+ dtype=np.float32,
)
for fp in filepaths
]
@@ -457,7 +459,7 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
def get_vars_units(self, category: str) -> torch.List[str]:
if category == "state":
- return self.config["dataset"]["var_units"]
+ return self.config.dataset.var_units
elif category == "forcing":
return [
"W/m^2",
@@ -474,7 +476,7 @@ def get_vars_units(self, category: str) -> torch.List[str]:
def get_vars_names(self, category: str) -> torch.List[str]:
if category == "state":
- return self.config["dataset"]["var_names"]
+ return self.config.dataset.var_names
elif category == "forcing":
# XXX: this really shouldn't be hard-coded here, this should be in
# the config
@@ -557,7 +559,7 @@ def boundary_mask(self):
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
)
- da_mask_stacked_xy = self.stack_grid_coords(da_mask)
+ da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
return da_mask_stacked_xy
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
@@ -623,3 +625,11 @@ def load_pickled_tensor(fn):
)
return ds_norm
+
+ @functools.cached_property
+ def coords_projection(self):
+ """Return the projection."""
+ proj_class_name = self.config.projection.class_name
+ ProjectionClass = getattr(ccrs, proj_class_name)
+ proj_params = self.config.projection.kwargs
+ return ProjectionClass(**proj_params)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 3ea86716..39f0cbdf 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -7,7 +7,6 @@
# Third-party
import pytorch_lightning as pl
import torch
-import wandb
from lightning_fabric.utilities import seed
# Local
@@ -27,13 +26,13 @@
}
-def _init_datastore(datastore_kind, data_config):
+def _init_datastore(datastore_kind, path):
if datastore_kind == "multizarr":
- datastore = MultiZarrDatastore(data_config)
+ datastore = MultiZarrDatastore(root_path=path)
elif datastore_kind == "npyfiles":
- datastore = NpyFilesDatastore(data_config)
+ datastore = NpyFilesDatastore(root_path=path)
elif datastore_kind == "mllam":
- datastore = MLLAMDatastore(data_config)
+ datastore = MLLAMDatastore(root_path=path)
else:
raise ValueError(f"Unknown datastore kind: {datastore_kind}")
return datastore
@@ -52,10 +51,10 @@ def main(input_args=None):
help="Kind of datastore to use (default: multizarr)",
)
parser.add_argument(
- "--datastore-config",
+ "--datastore-path",
type=str,
- default="tests/datastore_configs/multizarr/data_config.yaml",
- help="Path to data config file",
+ default="tests/datastore_configs/multizarr",
+ help="The root path for the datastore",
)
parser.add_argument(
"--model",
@@ -248,7 +247,9 @@ def main(input_args=None):
# Set seed
seed.seed_everything(args.seed)
# Create datastore
- datastore = _init_datastore(args.datastore_kind, args.datastore_config)
+ datastore = _init_datastore(
+ datastore_kind=args.datastore_kind, path=args.datastore_path
+ )
# Create datamodule
data_module = WeatherDataModule(
datastore=datastore,
@@ -303,6 +304,7 @@ def main(input_args=None):
callbacks=[checkpoint_callback],
check_val_every_n_epoch=args.val_interval,
precision=args.precision,
+ devices=1,
)
# Only init once, on rank 0 only
@@ -310,7 +312,8 @@ def main(input_args=None):
utils.init_wandb_metrics(
logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
- wandb.save(args.datastore_config)
+ # TODO: should we save the datastore config here?
+ # wandb.save()
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 05607f8f..ceae3663 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -71,7 +71,7 @@ def __len__(self):
UserWarning,
)
# XXX: we should maybe check that the 2+ar_steps actually fits
- # in the elapsed_forecast_time dimension, should that be checked here?
+ # in the elapsed_forecast_duration dimension, should that be checked here?
return self.da_state.analysis_time.size
else:
# sample_len = 2 + ar_steps <-- 2 initial states + ar_steps target states
@@ -93,7 +93,7 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
da : xr.DataArray
The dataarray to sample from. This is expected to have a `time`
dimension if the datastore is providing analysis only data, and a
- `analysis_time` and `elapsed_forecast_time` dimensions if the
+ `analysis_time` and `elapsed_forecast_duration` dimensions if the
datastore is providing forecast data.
idx : int
The index of the time step to start the sample from.
@@ -103,19 +103,19 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
# selecting the time slice
if self.datastore.is_forecast:
# this implies that the data will have both `analysis_time` and
- # `elapsed_forecast_time` dimensions for forecasts we for now
+ # `elapsed_forecast_duration` dimensions for forecasts we for now
         # simply select an analysis time and then the next ar_steps forecast
# times
da = da.isel(
analysis_time=idx,
- elapsed_forecast_time=slice(
+ elapsed_forecast_duration=slice(
n_timesteps_offset, n_steps + n_timesteps_offset
),
)
# create a new time dimension so that the produced sample has a
# `time` dimension, similarly to the analysis only data
- da["time"] = da.analysis_time + da.elapsed_forecast_time
- da = da.swap_dims({"elapsed_forecast_time": "time"})
+ da["time"] = da.analysis_time + da.elapsed_forecast_duration
+ da = da.swap_dims({"elapsed_forecast_duration": "time"})
else:
# only `time` dimension for analysis only data
da = da.isel(
diff --git a/pyproject.toml b/pyproject.toml
index 2681831d..075b0146 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"torch-geometric==2.3.1",
"mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep",
"parse>=1.20.2",
+ "dataclass-wizard>=0.22.3",
]
requires-python = ">=3.9"
diff --git a/tests/conftest.py b/tests/conftest.py
index 1c1cdd3e..9ff25a91 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,7 @@
# Third-party
import pooch
+import yaml
# First-party
from neural_lam.datastore.mllam import MLLAMDatastore
@@ -32,7 +33,10 @@
def download_meps_example_reduced_dataset():
# Download and unzip test data into data/meps_example_reduced
- root_path = Path("tests/datastores_examples/npy")
+ root_path = Path("tests/datastore_configs/npy")
+ dataset_path = root_path / "meps_example_reduced"
+ will_download = not dataset_path.exists()
+
pooch.retrieve(
url=S3_FULL_PATH,
known_hash=TEST_DATA_KNOWN_HASH,
@@ -40,14 +44,23 @@ def download_meps_example_reduced_dataset():
path=root_path,
fname="meps_example_reduced.zip",
)
- return root_path / "meps_example_reduced"
+
+ if will_download:
+        # XXX: should update the dataset stored on S3 with the change below
+ config_path = dataset_path / "data_config.yaml"
+ # rename the `projection.class` key to `projection.class_name` in the config
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ config["projection.class_name"] = config.pop("projection.class")
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+
+ return dataset_path
DATASTORES_EXAMPLES = dict(
- multizarr=dict(
- config_path="tests/datastore_configs/multizarr/data_config.yaml"
- ),
- mllam=dict(config_path="tests/datastore_configs/mllam/example.danra.yaml"),
+ multizarr=dict(root_path="tests/datastore_configs/multizarr"),
+ mllam=dict(root_path="tests/datastore_configs/mllam"),
npyfiles=dict(root_path=download_meps_example_reduced_dataset()),
)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index abb41e92..bd378e98 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -150,6 +150,10 @@ def test_get_dataarray(datastore_name):
"elapsed_forecast_duration",
]
+ if datastore.is_ensemble and category == "state":
+ # assume that only state variables change with ensemble members
+ expected_dims.append("ensemble_member")
+
# XXX: for now we only have a single attribute to get the shape of
# the grid which uses the shape from the "state" category, maybe
# this should change?
diff --git a/tests/test_npy_forecast_dataset.py b/tests/test_npy_forecast_dataset.py
deleted file mode 100644
index 571565dd..00000000
--- a/tests/test_npy_forecast_dataset.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# Standard library
-import os
-
-# Third-party
-import pooch
-import pytest
-
-# First-party
-from neural_lam.create_graph import create_graph as create_graph
-from neural_lam.datastore.npyfiles import NpyFilesDatastore
-from neural_lam.train_model import main as train_model
-from neural_lam.weather_dataset import WeatherDataset
-
-# Disable weights and biases to avoid unnecessary logging
-# and to avoid having to deal with authentication
-os.environ["WANDB_DISABLED"] = "true"
-
-# Initializing variables for the s3 client
-S3_BUCKET_NAME = "mllam-testdata"
-S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
-S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
-S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
-TEST_DATA_KNOWN_HASH = (
- "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7"
-)
-
-
-@pytest.fixture(scope="session")
-def ewc_testdata_path():
- # Download and unzip test data into data/meps_example_reduced
- pooch.retrieve(
- url=S3_FULL_PATH,
- known_hash=TEST_DATA_KNOWN_HASH,
- processor=pooch.Unzip(extract_dir=""),
- path="data",
- fname="meps_example_reduced.zip",
- )
-
- return "data/meps_example_reduced"
-
-
-def test_load_reduced_meps_dataset(ewc_testdata_path):
- datastore = NpyFilesDatastore(root_path=ewc_testdata_path)
- datastore.get_xy(category="state", stacked=True)
-
- datastore.get_dataarray(category="forcing", split="train").unstack(
- "grid_index"
- )
- datastore.get_dataarray(category="state", split="train").unstack(
- "grid_index"
- )
-
- dataset = WeatherDataset(datastore=datastore)
-
- var_names = datastore.config.values["dataset"]["var_names"]
- var_units = datastore.config.values["dataset"]["var_units"]
- var_longnames = datastore.config.values["dataset"]["var_longnames"]
-
- assert len(var_names) == len(var_longnames)
- assert len(var_names) == len(var_units)
-
- # in future the number of grid static features
- # will be provided by the Dataset class itself
- n_grid_static_features = 4
- # Hardcoded in model
- n_input_steps = 2
-
- n_forcing_features = datastore.config.values["dataset"][
- "num_forcing_features"
- ]
- n_state_features = len(var_names)
- n_prediction_timesteps = dataset.ar_steps
-
- nx, ny = datastore.config.values["grid_shape_state"]
- n_grid = nx * ny
-
- # check that the dataset is not empty
- assert len(dataset) > 0
-
- # get the first item
- item = dataset[0]
- init_states = item.init_states
- target_states = item.target_states
- forcing = item.forcing
-
- # check that the shapes of the tensors are correct
- assert init_states.shape == (n_input_steps, n_grid, n_state_features)
- assert target_states.shape == (
- n_prediction_timesteps,
- n_grid,
- n_state_features,
- )
- assert forcing.shape == (
- n_prediction_timesteps,
- n_grid,
- n_forcing_features,
- )
-
- ds_state_norm = datastore.get_normalization_dataarray(category="state")
-
- static_data = {
- "border_mask": datastore.boundary_mask.values,
- "grid_static_features": datastore.get_dataarray(
- category="static", split="train"
- ).values,
- "data_mean": ds_state_norm.state_mean.values,
- "data_std": ds_state_norm.state_std.values,
- "step_diff_mean": ds_state_norm.state_diff_mean.values,
- "step_diff_std": ds_state_norm.state_diff_std.values,
- }
-
- required_props = {
- "border_mask",
- "grid_static_features",
- "step_diff_mean",
- "step_diff_std",
- "data_mean",
- "data_std",
- "param_weights",
- }
-
- # check the sizes of the props
- assert static_data["border_mask"].shape == (n_grid,)
- assert static_data["grid_static_features"].shape == (
- n_grid,
- n_grid_static_features,
- )
- assert static_data["step_diff_mean"].shape == (n_state_features,)
- assert static_data["step_diff_std"].shape == (n_state_features,)
- assert static_data["data_mean"].shape == (n_state_features,)
- assert static_data["data_std"].shape == (n_state_features,)
- assert static_data["param_weights"].shape == (n_state_features,)
-
- assert set(static_data.keys()) == required_props
-
-
-def test_create_graph_reduced_meps_dataset():
- args = [
- "--graph=hierarchical",
- "--hierarchical=1",
- "--data_config=data/meps_example_reduced/data_config.yaml",
- "--levels=2",
- ]
- create_graph(args)
-
-
-def test_train_model_reduced_meps_dataset():
- args = [
- "--model=hi_lam",
- "--data_config=data/meps_example_reduced/data_config.yaml",
- "--n_workers=4",
- "--epochs=1",
- "--graph=hierarchical",
- "--hidden_dim=16",
- "--hidden_layers=1",
- "--processor_layers=1",
- "--ar_steps=1",
- "--eval=val",
- "--n_example_pred=0",
- ]
- train_model(args)
diff --git a/tests/test_training.py b/tests/test_training.py
index 3767fbc0..5e7f4095 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -80,3 +80,20 @@ class ModelArgs:
)
wandb.init()
trainer.fit(model=model, datamodule=data_module)
+
+
+# def test_train_model_reduced_meps_dataset():
+# args = [
+# "--model=hi_lam",
+# "--data_config=data/meps_example_reduced/data_config.yaml",
+# "--n_workers=4",
+# "--epochs=1",
+# "--graph=hierarchical",
+# "--hidden_dim=16",
+# "--hidden_layers=1",
+# "--processor_layers=1",
+# "--ar_steps=1",
+# "--eval=val",
+# "--n_example_pred=0",
+# ]
+# train_model(args)
From fe65a4d95ae6dcabd21f08b780537c471d061435 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 26 Jul 2024 10:43:45 +0200
Subject: [PATCH 151/273] cleanup for datastore examples
---
neural_lam/datastore/npyfiles/config.py | 14 +-------------
neural_lam/train_model.py | 8 +++-----
.../.gitignore | 0
.../mllam/.gitignore | 0
.../mllam/example.danra.yaml | 0
.../multizarr/data_config.yaml | 0
6 files changed, 4 insertions(+), 18 deletions(-)
rename tests/{datastore_configs => datastores_examples}/.gitignore (100%)
rename tests/{datastore_configs => datastores_examples}/mllam/.gitignore (100%)
rename tests/{datastore_configs => datastores_examples}/mllam/example.danra.yaml (100%)
rename tests/{datastore_configs => datastores_examples}/multizarr/data_config.yaml (100%)
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index 545b4b8b..afb08c77 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -18,7 +18,7 @@ class Projection:
kwargs: A dictionary of keyword arguments specific to the projection type.
"""
- class_name: str # = field(metadata={'data_key': 'class'})
+ class_name: str
kwargs: Dict[str, Any]
@@ -56,15 +56,3 @@ class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
dataset: Dataset
grid_shape_state: List[int]
projection: Projection
-
-
-class NpyConfig:
- """Class for loading configuration files.
-
- This class loads a configuration file and provides a way to access
- its values as attributes.
- """
-
- def num_data_vars(self):
- """Return the number of data variables for a given key."""
- return len(self.dataset.var_names)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 39f0cbdf..ffd9bb67 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -44,16 +44,14 @@ def main(input_args=None):
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
- "--datastore-kind",
+ "datastore-kind",
type=str,
choices=["multizarr", "npyfiles", "mllam"],
- default="multizarr",
- help="Kind of datastore to use (default: multizarr)",
+ help="Kind of datastore to use",
)
parser.add_argument(
- "--datastore-path",
+ "datastore-path",
type=str,
- default="tests/datastore_configs/multizarr",
help="The root path for the datastore",
)
parser.add_argument(
diff --git a/tests/datastore_configs/.gitignore b/tests/datastores_examples/.gitignore
similarity index 100%
rename from tests/datastore_configs/.gitignore
rename to tests/datastores_examples/.gitignore
diff --git a/tests/datastore_configs/mllam/.gitignore b/tests/datastores_examples/mllam/.gitignore
similarity index 100%
rename from tests/datastore_configs/mllam/.gitignore
rename to tests/datastores_examples/mllam/.gitignore
diff --git a/tests/datastore_configs/mllam/example.danra.yaml b/tests/datastores_examples/mllam/example.danra.yaml
similarity index 100%
rename from tests/datastore_configs/mllam/example.danra.yaml
rename to tests/datastores_examples/mllam/example.danra.yaml
diff --git a/tests/datastore_configs/multizarr/data_config.yaml b/tests/datastores_examples/multizarr/data_config.yaml
similarity index 100%
rename from tests/datastore_configs/multizarr/data_config.yaml
rename to tests/datastores_examples/multizarr/data_config.yaml
From e533794d88142b5b5eeff9e1adc1a65294c6879c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 26 Jul 2024 10:02:52 +0000
Subject: [PATCH 152/273] training on ohm with danra!
---
neural_lam/datastore/mllam.py | 4 +---
neural_lam/models/ar_model.py | 1 +
neural_lam/train_model.py | 5 ++---
neural_lam/weather_dataset.py | 10 ++++++++++
.../mllam/{example.danra.yaml => data_config.yaml} | 0
5 files changed, 14 insertions(+), 6 deletions(-)
rename tests/datastores_examples/mllam/{example.danra.yaml => data_config.yaml} (100%)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index a83cf31c..7b060faf 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -37,9 +37,7 @@ def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
self._root_path = Path(root_path)
config_path = self._root_path / config_filename
self._config = mdp.Config.from_yaml_file(config_path)
- fp_ds = self._config_path.parent / self._config_path.name.replace(
- ".yaml", ".zarr"
- )
+ fp_ds = self._root_path / config_path.name.replace(".yaml", ".zarr")
if reuse_existing and fp_ds.exists():
self._ds = xr.open_zarr(fp_ds, consolidated=True)
else:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index fec31e5b..cea723b0 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -275,6 +275,7 @@ def validation_step(self, batch, batch_idx):
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
+ if step < len(time_step_loss)
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index ffd9bb67..cf576008 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -44,13 +44,13 @@ def main(input_args=None):
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
- "datastore-kind",
+ "datastore_kind",
type=str,
choices=["multizarr", "npyfiles", "mllam"],
help="Kind of datastore to use",
)
parser.add_argument(
- "datastore-path",
+ "datastore_path",
type=str,
help="The root path for the datastore",
)
@@ -302,7 +302,6 @@ def main(input_args=None):
callbacks=[checkpoint_callback],
check_val_every_n_epoch=args.val_interval,
precision=args.precision,
- devices=1,
)
# Only init once, on rank 0 only
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index ceae3663..ed4856d3 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -39,6 +39,16 @@ def __init__(
)
self.forcing_window_size = forcing_window_size
+ # check that with the provided data-arrays and ar_steps that we have a
+ # non-zero amount of samples
+ if self.__len__() <= 0:
+ raise ValueError(
+ f"The provided datastore only provides {len(self.da_state.time)} "
+ f"time steps for `{split}` split, which is less than the "
+ f"required 2+ar_steps (2+{self.ar_steps}={2+self.ar_steps}) "
+ "for creating a sample with initial and target states."
+ )
+
# Set up for standardization
# TODO: This will become part of ar_model.py soon!
self.standardize = standardize
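The check added above requires at least `2 + ar_steps` time steps per sample. As a hedged sketch (an assumption about the sample count, not the actual `__len__` implementation), the relation between the number of available time steps and the number of samples can be written as:

```python
# Assumed relation, for illustration: with N time steps, 2 initial states and
# ar_steps target states per sample, the number of complete samples is
#     n_samples = N - (2 + ar_steps) + 1
# which is what the ValueError above guards against being <= 0.
def n_samples(n_time_steps: int, ar_steps: int) -> int:
    return n_time_steps - (2 + ar_steps) + 1


assert n_samples(n_time_steps=5, ar_steps=3) == 1
assert n_samples(n_time_steps=4, ar_steps=3) <= 0  # would trigger the ValueError
```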
diff --git a/tests/datastores_examples/mllam/example.danra.yaml b/tests/datastores_examples/mllam/data_config.yaml
similarity index 100%
rename from tests/datastores_examples/mllam/example.danra.yaml
rename to tests/datastores_examples/mllam/data_config.yaml
From 640ac05bfa03e805d820e6afea114687502c6a6e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 5 Aug 2024 13:10:47 +0000
Subject: [PATCH 153/273] use mllam-data-prep v0.2.0
---
pyproject.toml | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 075b0146..d0d8c67f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,9 +24,9 @@ dependencies = [
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
- "mllam-data-prep @ git+https://github.com/mllam/mllam-data-prep",
"parse>=1.20.2",
"dataclass-wizard>=0.22.3",
+ "mllam-data-prep>=0.2.0",
]
requires-python = ">=3.9"
@@ -38,6 +38,8 @@ dev = [
"ipdb>=0.13.13",
"gpustat>=1.1.1",
"zarrdump>=0.4.1",
+ "dask[distributed]>=2024.7.1",
+ "bokeh!=3.0.*,>=2.4.2",
]
[tool.black]
From 0f16f133bd27cbf5ee723f94a77940a380840fcc Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 5 Aug 2024 13:16:57 +0000
Subject: [PATCH 154/273] remove py3.12 from pre-commit
---
.github/workflows/pre-commit.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index ad2b1a9c..71e28ad7 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
From 724548e55e446268d2a28d496f50a06d7f767b2e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 8 Aug 2024 09:36:00 +0000
Subject: [PATCH 155/273] cleanup
---
neural_lam/create_graph.py | 6 +-
neural_lam/datastore/base.py | 12 +-
neural_lam/datastore/mllam.py | 131 +++++++++++++--
neural_lam/datastore/multizarr/__init__.py | 5 +
...orcings.py => create_datetime_forcings.py} | 69 +++++---
.../multizarr/create_normalization_stats.py | 87 ++++++----
neural_lam/datastore/multizarr/store.py | 157 ++++++++++--------
neural_lam/datastore/npyfiles/store.py | 70 ++++++--
neural_lam/train_model.py | 15 +-
neural_lam/weather_dataset.py | 77 ++++++---
tests/conftest.py | 74 ++++++---
.../.gitignore | 0
.../mllam/.gitignore | 0
.../mllam/danra.example.yaml} | 0
.../multizarr/data_config.yaml | 0
tests/test_datasets.py | 3 +-
tests/test_training.py | 2 +-
17 files changed, 502 insertions(+), 206 deletions(-)
rename neural_lam/datastore/multizarr/{create_auxiliary_forcings.py => create_datetime_forcings.py} (82%)
rename tests/{datastores_examples => datastore_examples}/.gitignore (100%)
rename tests/{datastores_examples => datastore_examples}/mllam/.gitignore (100%)
rename tests/{datastores_examples/mllam/data_config.yaml => datastore_examples/mllam/danra.example.yaml} (100%)
rename tests/{datastores_examples => datastore_examples}/multizarr/data_config.yaml (100%)
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 6b062e3d..e5eb44a4 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -564,9 +564,9 @@ def cli(input_args=None):
help="kind of data store to use (default: multizarr)",
)
parser.add_argument(
- "datastore_path",
+ "datastore_config_path",
type=str,
- help="path to the data store",
+ help="path to the data store config",
)
parser.add_argument(
"--name",
@@ -596,7 +596,7 @@ def cli(input_args=None):
args = parser.parse_args(input_args)
DatastoreClass = DATASTORES[args.datastore]
- datastore = DatastoreClass(root_path=args.datastore_path)
+ datastore = DatastoreClass(config_path=args.datastore_config_path)
create_graph_from_datastore(
datastore=datastore,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index c2c2d798..73658126 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -128,10 +128,14 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
pass
@abc.abstractmethod
- def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+ def get_dataarray(
+ self, category: str, split: str
+ ) -> Union[xr.DataArray, None]:
"""Return the processed data (as a single `xr.DataArray`) for the given
category of data and test/train/val-split that covers all the data (in
- space and time) of a given category.
+ space and time) of a given category (state/forcing/static). A datastore
+ must be able to return for the "state" category, but "forcing" and
+ "static" are optional (in which case the method should return `None`).
The returned dataarray is expected to at minimum have dimensions of
`(grid_index, {category}_feature)` so that any spatial dimensions have
@@ -156,14 +160,14 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
Returns
-------
- xr.DataArray
+ xr.DataArray or None
The xarray DataArray object with processed dataset.
"""
pass
@property
@abc.abstractmethod
- def boundary_mask(self):
+ def boundary_mask(self) -> xr.DataArray:
"""Return the boundary mask for the dataset, with spatial dimensions
stacked. Where the value is 1, the grid point is a boundary point, and
where the value is 0, the grid point is not a boundary point.
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 7b060faf..ae2c5d53 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -1,4 +1,6 @@
# Standard library
+import shutil
+import warnings
from pathlib import Path
from typing import List
@@ -6,6 +8,7 @@
import cartopy.crs as ccrs
import mllam_data_prep as mdp
import xarray as xr
+from loguru import logger
from numpy import ndarray
# Local
@@ -15,11 +18,12 @@
class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""
- def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
+ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
"""Construct a new MLLAMDatastore from the configuration file at
`config_path`. A boundary mask is created with `n_boundary_points`
boundary points. If `reuse_existing` is True, the dataset is loaded
- from a zarr file if it exists, otherwise it is created from the
+ from a zarr file if it exists (unless the config has been modified
+ since the zarr was created), otherwise it is created from the
configuration file.
Parameters
@@ -31,16 +35,29 @@ def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
n_boundary_points : int
The number of boundary points to use in the boundary mask.
reuse_existing : bool
- Whether to reuse an existing dataset zarr file if it exists.
+ Whether to reuse an existing dataset zarr file if it exists and its
+ modification time is newer than that of the configuration file.
"""
- config_filename = "data_config.yaml"
- self._root_path = Path(root_path)
- config_path = self._root_path / config_filename
- self._config = mdp.Config.from_yaml_file(config_path)
- fp_ds = self._root_path / config_path.name.replace(".yaml", ".zarr")
+ self._config_path = Path(config_path)
+ self._root_path = self._config_path.parent
+ self._config = mdp.Config.from_yaml_file(self._config_path)
+ fp_ds = self._root_path / self._config_path.name.replace(
+ ".yaml", ".zarr"
+ )
+
+ self._ds = None
if reuse_existing and fp_ds.exists():
- self._ds = xr.open_zarr(fp_ds, consolidated=True)
- else:
+ # check that the zarr directory is newer than the config file
+ if fp_ds.stat().st_mtime > self._config_path.stat().st_mtime:
+ self._ds = xr.open_zarr(fp_ds, consolidated=True)
+ else:
+ logger.warning(
+ "config file has been modified since zarr was created. "
+ "recreating dataset."
+ )
+ shutil.rmtree(fp_ds)
+
+ if self._ds is None:
self._ds = mdp.create_dataset(config=self._config)
if reuse_existing:
self._ds.to_zarr(fp_ds)
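A minimal, self-contained sketch of the cache-invalidation rule used above (the paths are placeholders; the real class additionally recreates the dataset via `mllam-data-prep`):

```python
import shutil
from pathlib import Path


def zarr_cache_is_stale(zarr_path: Path, config_path: Path) -> bool:
    """Return True if the cached zarr is missing or older than the config."""
    if not zarr_path.exists():
        return True
    return zarr_path.stat().st_mtime <= config_path.stat().st_mtime


# usage sketch with placeholder paths
config_path = Path("danra.example.yaml")
zarr_path = config_path.with_suffix(".zarr")
if zarr_path.exists() and zarr_cache_is_stale(zarr_path, config_path):
    shutil.rmtree(zarr_path)  # force recreation from the (newer) config
```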
@@ -48,23 +65,115 @@ def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
@property
def root_path(self) -> Path:
+ """The root path of the dataset.
+
+ Returns
+ -------
+ Path
+ The root path of the dataset.
+ """
return self._root_path
@property
def step_length(self) -> int:
+ """The length of the time steps in hours.
+
+ Returns
+ -------
+ int
+ The length of the time steps in hours.
+ """
da_dt = self._ds["time"].diff("time")
return (da_dt.dt.seconds[0] // 3600).item()
def get_vars_units(self, category: str) -> List[str]:
+ """Return the units of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The units of the variables in the given category.
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
return self._ds[f"{category}_feature_units"].values.tolist()
def get_vars_names(self, category: str) -> List[str]:
+ """Return the names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The names of the variables in the given category.
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
return self._ds[f"{category}_feature"].values.tolist()
def get_num_data_vars(self, category: str) -> int:
- return self._ds[f"{category}_feature"].count().item()
+ """Return the number of variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ int
+ The number of variables in the given category.
+ """
+ return len(self.get_vars_names(category))
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+ """Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). "state" is
+ the only required category; for other categories, the method will
+ return `None` if the category is not found in the datastore.
+
+ The returned dataarray will at minimum have dimensions of `(grid_index,
+ {category}_feature)` so that any spatial dimensions have been stacked
+ into a single dimension and all variables and levels have been stacked
+ into a single feature dimension named by the `category` of data being
+ loaded.
+
+ For categories of data that have a time dimension (i.e. not static
+ data), the dataarray will additionally have `(analysis_time,
+ elapsed_forecast_duration)` dimensions if `is_forecast` is True, or
+ `(time)` if `is_forecast` is False.
+
+ If the data is ensemble data, the dataarray will have an additional
+ `ensemble_member` dimension.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ split : str
+ The time split to filter the dataset (train/val/test).
+
+ Returns
+ -------
+ xr.DataArray or None
+ The xarray DataArray object with processed dataset.
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return None
+
da_category = self._ds[category]
if "time" not in da_category.dims:
diff --git a/neural_lam/datastore/multizarr/__init__.py b/neural_lam/datastore/multizarr/__init__.py
index c1958905..c59f31f4 100644
--- a/neural_lam/datastore/multizarr/__init__.py
+++ b/neural_lam/datastore/multizarr/__init__.py
@@ -1,2 +1,7 @@
# Local
+from . import ( # noqa
+ create_boundary_mask,
+ create_datetime_forcings,
+ create_normalization_stats,
+)
from .store import MultiZarrDatastore # noqa
diff --git a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py
similarity index 82%
rename from neural_lam/datastore/multizarr/create_auxiliary_forcings.py
rename to neural_lam/datastore/multizarr/create_datetime_forcings.py
index c4839be3..3907ca08 100644
--- a/neural_lam/datastore/multizarr/create_auxiliary_forcings.py
+++ b/neural_lam/datastore/multizarr/create_datetime_forcings.py
@@ -10,6 +10,8 @@
# First-party
from neural_lam.datastore.multizarr import MultiZarrDatastore
+DEFAULT_FILENAME = "datetime_forcings.zarr"
+
def get_seconds_in_year(year):
start_of_year = pd.Timestamp(f"{year}-01-01")
@@ -68,39 +70,32 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
return datetime_forcing
-def main():
- """Main function for creating the datetime forcing and boundary mask."""
- parser = argparse.ArgumentParser(
- description="Create the datetime forcing for neural LAM.",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "data_config",
- type=str,
- help="Path to data config file",
- )
- parser.add_argument(
- "--zarr_path",
- type=str,
- default=None,
- help="Path to save the Zarr archive "
- "(default: same directory as the data-config)",
- )
- args = parser.parse_args()
+def create_datetime_forcing_zarr(
+ data_config_path: str,
+ zarr_path: str = None,
+ chunking: dict = {"time": 1},
+):
+ """Create the datetime forcing and save it to a Zarr archive.
- zarr_path = args.zarr_path
+ Parameters
+ ----------
+ data_config_path : str
+ The path to the data config file for the datastore.
+ zarr_path : str, optional
+ The path to save the Zarr archive. If not provided, the archive is
+ saved in the same directory as the data config file.
+ chunking : dict, optional
+ The chunking to use when saving the Zarr archive.
+ """
if zarr_path is None:
- zarr_path = Path(args.data_config).parent / "datetime_forcings.zarr"
+ zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
- datastore = MultiZarrDatastore(config_path=args.data_config)
+ datastore = MultiZarrDatastore(config_path=data_config_path)
da_state = datastore.get_dataarray(category="state", split="train")
da_datetime_forcing = calculate_datetime_forcing(
da_time=da_state.time
).expand_dims({"grid_index": da_state.grid_index})
- chunking = {"time": 1}
-
if "x" in da_state.coords and "y" in da_state.coords:
# copy the x and y coordinates to the datetime forcing
for aux_coord in ["x", "y"]:
@@ -121,5 +116,31 @@ def main():
print(f"Datetime forcing saved to {zarr_path}")
+def main():
+ """Main function for creating the datetime forcing and boundary mask."""
+ parser = argparse.ArgumentParser(
+ description="Create the datetime forcing for neural LAM.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "data_config",
+ type=str,
+ help="Path to data config file",
+ )
+ parser.add_argument(
+ "--zarr_path",
+ type=str,
+ default=None,
+ help="Path to save the Zarr archive "
+ "(default: same directory as the data-config)",
+ )
+ args = parser.parse_args()
+
+ create_datetime_forcing_zarr(
+ data_config_path=args.data_config,
+ zarr_path=args.zarr_path,
+ )
+
+
if __name__ == "__main__":
main()
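With the refactor above, the forcing creation can be called directly from Python as well as from the CLI, which is how the test fixtures later in this series use it. A usage sketch (the config path is a placeholder; the zarr lands next to the config under `DEFAULT_FILENAME` when `zarr_path` is omitted):

```python
# Usage sketch: call the refactored function directly instead of the CLI.
from neural_lam.datastore.multizarr import create_datetime_forcings

create_datetime_forcings.create_datetime_forcing_zarr(
    data_config_path="tests/datastore_examples/multizarr/data_config.yaml",
)
```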
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index 2298e191..46dcc7d7 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -1,5 +1,6 @@
# Standard library
import argparse
+from pathlib import Path
# Third-party
import xarray as xr
@@ -7,6 +8,8 @@
# First-party
from neural_lam.datastore.multizarr import MultiZarrDatastore
+DEFAULT_FILENAME = "normalization.zarr"
+
def compute_stats(da):
mean = da.mean(dim=("time", "grid_index"))
@@ -14,25 +17,26 @@ def compute_stats(da):
return mean, std
-def main():
- parser = argparse.ArgumentParser(
- description="Training arguments",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "data_config",
- type=str,
- help="Path to data config file",
- )
- parser.add_argument(
- "--zarr_path",
- type=str,
- default="normalization.zarr",
- help="Directory where data is stored",
- )
- args = parser.parse_args()
+def create_normalization_stats_zarr(
+ data_config_path: str,
+ zarr_path: str = None,
+):
+ """Compute mean and std.-dev. for state and forcing variables and save them
+ to a Zarr file.
- datastore = MultiZarrDatastore(config_path=args.data_config)
+ Parameters
+ ----------
+ data_config_path : str
+ Path to data config file.
+ zarr_path : str, optional
+ Path to save the normalization statistics to. If not provided, the
+ statistics are saved to the same directory as the data config file with
+ the name `normalization.zarr`.
+ """
+ if zarr_path is None:
+ zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
+
+ datastore = MultiZarrDatastore(config_path=data_config_path)
da_state = datastore.get_dataarray(category="state", split="train")
da_forcing = datastore.get_dataarray(category="forcing", split="train")
@@ -50,28 +54,28 @@ def main():
for group in combined_stats:
vars_to_combine = group["vars"]
- means = da_forcing_mean.sel(variable=vars_to_combine)
- stds = da_forcing_std.sel(variable=vars_to_combine)
+ means = da_forcing_mean.sel(variable_name=vars_to_combine)
+ stds = da_forcing_std.sel(variable_name=vars_to_combine)
- combined_mean = means.mean(dim="variable")
- combined_std = (stds**2).mean(dim="variable") ** 0.5
+ combined_mean = means.mean(dim="variable_name")
+ combined_std = (stds**2).mean(dim="variable_name") ** 0.5
da_forcing_mean.loc[
- dict(variable=vars_to_combine)
+ dict(variable_name=vars_to_combine)
] = combined_mean
da_forcing_std.loc[
- dict(variable=vars_to_combine)
+ dict(variable_name=vars_to_combine)
] = combined_std
window = datastore._config["forcing"]["window"]
da_forcing_mean = xr.concat(
[da_forcing_mean] * window, dim="window"
- ).stack(forcing_variable=("variable", "window"))
+ ).stack(forcing_variable=("variable_name", "window"))
da_forcing_std = xr.concat(
[da_forcing_std] * window, dim="window"
- ).stack(forcing_variable=("variable", "window"))
- vars = da_forcing["variable"].values.tolist()
+ ).stack(forcing_variable=("variable_name", "window"))
+ vars = da_forcing["variable_name"].values.tolist()
window = datastore._config["forcing"]["window"]
forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
@@ -99,14 +103,37 @@ def main():
}
)
.reset_index(["forcing_variable"])
- .drop_vars(["variable", "window"])
+ .drop_vars(["variable_name", "window"])
.assign_coords(forcing_variable=forcing_vars)
)
ds = xr.merge([ds, dsf])
- ds = ds.chunk({"variable": -1, "forcing_variable": -1})
+ ds = ds.chunk({"variable_name": -1, "forcing_variable": -1})
print("Saving dataset as Zarr...")
- ds.to_zarr(args.zarr_path, mode="w")
+ ds.to_zarr(zarr_path, mode="w")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Training arguments",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "data_config",
+ type=str,
+ help="Path to data config file",
+ )
+ parser.add_argument(
+ "--zarr_path",
+ type=str,
+ default="normalization.zarr",
+ help="Directory where data is stored",
+ )
+ args = parser.parse_args()
+
+ create_normalization_stats_zarr(
+ data_config_path=args.data_config, zarr_path=args.zarr_path
+ )
if __name__ == "__main__":
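The grouped forcing statistics above combine per-variable standard deviations as the root of the mean variance rather than the mean of the standard deviations. A tiny numeric sketch of that rule (values are made up):

```python
import numpy as np

# made-up per-variable standard deviations for one group
stds = np.array([0.5, 1.5])
combined_std = np.sqrt((stds**2).mean())  # ~1.118
assert not np.isclose(combined_std, stds.mean())  # differs from the plain mean (1.0)
```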
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 1f874d6e..d3b339ce 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -15,35 +15,71 @@
class MultiZarrDatastore(BaseCartesianDatastore):
- DIMS_TO_KEEP = {"time", "grid_index", "variable"}
+ DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}
- def __init__(self, root_path):
- self._root_path = Path(root_path)
- config_path = self._root_path / "data_config.yaml"
+ def __init__(self, config_path):
+ """Create a multi-zarr datastore from the given configuration file. The
+ configuration file should be a YAML file, the format of which should
+ be inferred from the example configuration file in
+ `tests/datastore_examples/multizarr/data_config.yaml`.
+
+ Parameters
+ ----------
+ config_path : str
+ The path to the configuration file.
+ """
+ self._config_path = Path(config_path)
+ self._root_path = self._config_path.parent
with open(config_path, encoding="utf-8", mode="r") as file:
self._config = yaml.safe_load(file)
@property
def root_path(self):
+ """Return the root path of the datastore.
+
+ Returns
+ -------
+ str
+ The root path of the datastore.
+ """
return self._root_path
- def _normalize_path(self, path):
+ def _normalize_path(self, path) -> str:
+ """
+ Normalize the path of a source dataset defined in the configuration file.
+ Any path that does not start with a protocol (e.g. `s3://`) and is not
+ an absolute path is assumed to be relative to the configuration file.
+
+ Parameters
+ ----------
+ path : str
+ The path to normalize.
+
+ Returns
+ -------
+ str
+ The normalized path.
+ """
# try to parse path to see if it defines a protocol, e.g. s3://
if "://" in path or path.startswith("/"):
pass
else:
# assume path is relative to config file
- path = os.path.join(os.path.dirname(self.config_path), path)
+ path = os.path.join(self._root_path, path)
return path
def open_zarrs(self, category):
"""Open the zarr dataset for the given category.
- Args:
- category (str): The category of the dataset (state/forcing/static).
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
- Returns:
- xr.Dataset: The xarray Dataset object.
+ Returns
+ -------
+ xr.Dataset
+ The xarray Dataset object.
"""
zarr_configs = self._config[category]["zarrs"]
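The `_normalize_path` helper documented above keeps protocol-prefixed and absolute paths untouched and resolves everything else relative to the configuration file's directory. A standalone sketch of that behaviour (paths are examples only):

```python
import os


def normalize_path(path: str, root: str) -> str:
    # keep paths with a protocol (e.g. "s3://...") or absolute paths as-is;
    # resolve everything else relative to the config-file directory
    if "://" in path or path.startswith("/"):
        return path
    return os.path.join(root, path)


root = "tests/datastore_examples/multizarr"
assert normalize_path("s3://bucket/data.zarr", root) == "s3://bucket/data.zarr"
assert normalize_path("/abs/data.zarr", root) == "/abs/data.zarr"
assert normalize_path("datetime_forcings.zarr", root).endswith(
    "multizarr/datetime_forcings.zarr"
)
```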
@@ -178,7 +214,7 @@ def _convert_dataset_to_dataarray(self, dataset):
xr.DataArray: The xarray DataArray object.
"""
if isinstance(dataset, xr.Dataset):
- dataset = dataset.to_array()
+ dataset = dataset.to_array(dim="variable_name")
return dataset
def _filter_dimensions(self, dataset, transpose_array=True):
@@ -193,7 +229,7 @@ def _filter_dimensions(self, dataset, transpose_array=True):
OR xr.DataArray: The xarray DataArray object with filtered dimensions.
"""
dims_to_keep = self.DIMS_TO_KEEP
- dataset_dims = set(list(dataset.dims) + ["variable"])
+ dataset_dims = set(list(dataset.dims) + ["variable_name"])
min_req_dims = dims_to_keep.copy()
min_req_dims.discard("time")
if not min_req_dims.issubset(dataset_dims):
@@ -210,7 +246,7 @@ def _filter_dimensions(self, dataset, transpose_array=True):
dataset.attrs["category"], dataset=dataset
)
dataset = self._stack_grid(dataset)
- dataset_dims = set(list(dataset.dims) + ["variable"])
+ dataset_dims = set(list(dataset.dims) + ["variable_name"])
if min_req_dims.issubset(dataset_dims):
print(
"\033[92mSuccessfully updated dims and "
@@ -223,7 +259,7 @@ def _filter_dimensions(self, dataset, transpose_array=True):
)
return None
- dataset_dims = set(list(dataset.dims) + ["variable"])
+ dataset_dims = set(list(dataset.dims) + ["variable_name"])
dims_to_drop = dataset_dims - dims_to_keep
dataset = dataset.drop_dims(dims_to_drop)
if dims_to_drop:
@@ -241,13 +277,15 @@ def _filter_dimensions(self, dataset, transpose_array=True):
dataset = self._convert_dataset_to_dataarray(dataset)
if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid_index", "variable")
+ dataset = dataset.transpose(
+ "time", "grid_index", "variable_name"
+ )
else:
- dataset = dataset.transpose("grid_index", "variable")
+ dataset = dataset.transpose("grid_index", "variable_name")
dataset_vars = (
list(dataset.data_vars)
if isinstance(dataset, xr.Dataset)
- else dataset["variable"].values.tolist()
+ else dataset["variable_name"].values.tolist()
)
print( # noqa
@@ -345,26 +383,42 @@ def get_xy_extent(self, category):
return extent
@functools.lru_cache()
- def get_normalization_dataarray(self, category):
- """Load the normalization statistics for the dataset.
+ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ """Return the normalization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables. The returned
+ dataarray should at least have dimensions of `({category}_feature)`,
+ but can also include for example `grid_index` (if the normalisation is
+ done per grid point for example).
- Args:
- category (str): The category of the dataset (state/forcing/static).
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
- Returns:
- OR xr.Dataset: The normalization statistics for the dataset.
+ Returns
+ -------
+ xr.Dataset
+ The normalization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
"""
- combined_stats = self._load_and_merge_stats()
- if combined_stats is None:
+ ds_combined_stats = self._load_and_merge_stats()
+ if ds_combined_stats is None:
return None
- combined_stats = self._rename_data_vars(combined_stats)
+ ds_combined_stats = self._rename_data_vars(ds_combined_stats)
- stats = self._select_stats_by_category(combined_stats, category)
- if stats is None:
- return None
+ ops = ["mean", "std"]
+ stats_variables = [f"{category}_{op}" for op in ops]
+ if category == "state":
+ stats_variables += [f"state_diff_{op}" for op in ops]
- return stats
+ ds_stats = ds_combined_stats[stats_variables]
+
+ return ds_stats
def _load_and_merge_stats(self):
"""Load and merge the normalization statistics for the dataset.
@@ -422,7 +476,7 @@ def _select_stats_by_category(self, combined_stats, category):
"""
if category == "state":
stats = combined_stats.loc[
- dict(variable=self.get_vars_names(category=category))
+ dict(variable_name=self.get_vars_names(category=category))
]
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
return stats
@@ -432,9 +486,7 @@ def _select_stats_by_category(self, combined_stats, category):
)
if non_normalized_vars is None:
non_normalized_vars = []
- vars = self.vars_names(category)
- window = self["forcing"]["window"]
- forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
+ forcing_vars = self.vars_names(category)
normalized_vars = [
var for var in forcing_vars if var not in non_normalized_vars
]
@@ -541,7 +593,7 @@ def _rename_dataset_dims_and_vars(self, category, dataset=None):
dataset = self.open_zarrs(category)
elif isinstance(dataset, xr.DataArray):
convert = True
- dataset = dataset.to_dataset("variable")
+ dataset = dataset.to_dataset("variable_name")
dims_mapping = {}
zarr_configs = self._config[category]["zarrs"]
for zarr_config in zarr_configs:
@@ -579,34 +631,6 @@ def _apply_time_split(self, dataset, split="train"):
dataset.attrs["split"] = split
return dataset
- def apply_window(self, category, dataset=None):
- """Apply the forcing window to the forcing dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- dataset (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with the window applied.
- """
- if dataset is None:
- dataset = self.open_zarrs(category)
- if isinstance(dataset, xr.Dataset):
- dataset = self._convert_dataset_to_dataarray(dataset)
- state = self.open_zarrs("state")
- state = self._apply_time_split(state, dataset.attrs["split"])
- state_time = state.time.values
- window = self._config[category]["window"]
- dataset = (
- dataset.sel(time=state_time, method="nearest")
- .pad(time=(window // 2, window // 2), mode="edge")
- .rolling(time=window, center=True)
- .construct("window")
- .stack(variable_window=("variable", "window"))
- )
- dataset = dataset.isel(time=slice(window // 2, -window // 2 + 1))
- return dataset
-
@property
def grid_shape_state(self):
"""Return the shape of the state grid.
@@ -637,13 +661,12 @@ def boundary_mask(self):
"grid_index"
)
- def get_dataarray(self, category, split="train", apply_windowing=True):
+ def get_dataarray(self, category, split="train"):
"""Process the dataset for the given category.
Args:
category (str): The category of the dataset (state/forcing/static).
split (str): The time split to filter the dataset (train/val/test).
- apply_windowing (bool): Whether to apply windowing to the forcing dataset.
Returns:
xr.DataArray: The xarray DataArray object with processed dataset.
@@ -656,9 +679,9 @@ def get_dataarray(self, category, split="train", apply_windowing=True):
dataset = self._rename_dataset_dims_and_vars(category, dataset)
dataset = self._filter_dimensions(dataset)
dataset = self._convert_dataset_to_dataarray(dataset)
- if "window" in self._config[category] and apply_windowing:
- dataset = self.apply_window(category, dataset)
if category == "static" and "time" in dataset.dims:
dataset = dataset.isel(time=0, drop=True)
+ dataset = dataset.rename(dict(variable_name=f"{category}_feature"))
+
return dataset
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 02365a46..674c368d 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -1,3 +1,5 @@
+"""Numpy-files based datastore to support the MEPS example dataset introduced
+in neural-lam v0.1.0."""
# Standard library
import functools
import re
@@ -134,20 +136,39 @@ class NpyFilesDatastore(BaseCartesianDatastore):
def __init__(
self,
- root_path,
+ config_path,
):
+ """Create a new NpyFilesDatastore using the configuration file at the
+ given path. The config file should be a YAML file and will be loaded
+ into an instance of the `NpyDatastoreConfig` dataclass.
+
+ Internally, the datastore uses dask.delayed to load the data from the
+ numpy files, so that the data isn't actually loaded until it's needed.
+
+ Parameters
+ ----------
+ config_path : str
+ The path to the configuration file for the datastore.
+ """
# XXX: This should really be in the config file, not hard-coded in this class
self._num_timesteps = 65
self._step_length = 3 # 3 hours
self._num_ensemble_members = 2
- self._root_path = Path(root_path)
- self.config = NpyDatastoreConfig.from_yaml_file(
- self.root_path / "data_config.yaml"
- )
+ self._config_path = Path(config_path)
+ self._root_path = self._config_path.parent
+ self.config = NpyDatastoreConfig.from_yaml_file(self._config_path)
@property
- def root_path(self):
+ def root_path(self) -> Path:
+ """The root path of the datastore on disk. This is the directory
+ relative to which graphs and other files can be stored.
+
+ Returns
+ -------
+ Path
+ The root path of the datastore
+ """
return self._root_path
def get_dataarray(self, category: str, split: str) -> DataArray:
@@ -403,7 +424,7 @@ def _get_single_timeseries_dataarray(
return da
- def _get_analysis_times(self, split):
+ def _get_analysis_times(self, split) -> List[np.datetime64]:
"""Get the analysis times for the given split by parsing the filenames
of all the files found for the given split.
@@ -529,15 +550,18 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
return arr
@property
- def step_length(self):
- return self._step_length
+ def step_length(self) -> int:
+ """The length of each time step in hours.
- @property
- def coords_projection(self):
- return self.config.coords_projection
+ Returns
+ -------
+ int
+ The length of each time step in hours.
+ """
+ return self._step_length
@property
- def grid_shape_state(self):
+ def grid_shape_state(self) -> CartesianGridShape:
"""The shape of the cartesian grid for the state variables.
Returns
@@ -549,7 +573,15 @@ def grid_shape_state(self):
return CartesianGridShape(x=nx, y=ny)
@property
- def boundary_mask(self):
+ def boundary_mask(self) -> xr.DataArray:
+ """The boundary mask for the dataset. This is a binary mask that is 1
+ where the grid cell is on the boundary of the domain, and 0 otherwise.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions `[grid_index]`.
+ """
xs, ys = self.get_xy(category="state", stacked=False)
assert np.all(xs[:, 0] == xs[:, -1])
assert np.all(ys[0, :] == ys[-1, :])
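The `boundary_mask` docstring above describes a binary mask that is 1 on the domain boundary and 0 in the interior, flattened to a single `grid_index` dimension. A hedged sketch of how such a mask could be built for a rectangular grid (the band width is an illustrative assumption, not taken from the patch):

```python
import numpy as np


def make_boundary_mask(nx: int, ny: int, n_boundary_points: int) -> np.ndarray:
    """1 within `n_boundary_points` of any edge, 0 elsewhere, flattened."""
    mask = np.zeros((ny, nx), dtype=np.float32)
    n = n_boundary_points
    mask[:n, :] = 1
    mask[-n:, :] = 1
    mask[:, :n] = 1
    mask[:, -n:] = 1
    return mask.flatten()


mask = make_boundary_mask(nx=10, ny=8, n_boundary_points=2)
assert mask.shape == (80,)
assert mask[0] == 1  # corner point is on the boundary
```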
@@ -627,8 +659,14 @@ def load_pickled_tensor(fn):
return ds_norm
@functools.cached_property
- def coords_projection(self):
- """Return the projection."""
+ def coords_projection(self) -> ccrs.Projection:
+ """The projection of the spatial coordinates.
+
+ Returns
+ -------
+ ccrs.Projection
+ The projection of the spatial coordinates.
+ """
proj_class_name = self.config.projection.class_name
ProjectionClass = getattr(ccrs, proj_class_name)
proj_params = self.config.projection.kwargs
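The projection is built by looking up the cartopy class named in the config and instantiating it with the config kwargs. A minimal standalone sketch (the class name and kwarg values here are placeholders):

```python
import cartopy.crs as ccrs

# placeholder config-style mapping
projection_cfg = {
    "class_name": "RotatedPole",
    "kwargs": {"pole_longitude": 0.0, "pole_latitude": 90.0},
}
ProjectionClass = getattr(ccrs, projection_cfg["class_name"])
proj = ProjectionClass(**projection_cfg["kwargs"])
print(type(proj).__name__)  # RotatedPole
```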
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index cf576008..dffa7021 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -26,13 +26,13 @@
}
-def _init_datastore(datastore_kind, path):
+def _init_datastore(datastore_kind, config_path):
if datastore_kind == "multizarr":
- datastore = MultiZarrDatastore(root_path=path)
+ datastore = MultiZarrDatastore(config_path=config_path)
elif datastore_kind == "npyfiles":
- datastore = NpyFilesDatastore(root_path=path)
+ datastore = NpyFilesDatastore(config_path=config_path)
elif datastore_kind == "mllam":
- datastore = MLLAMDatastore(root_path=path)
+ datastore = MLLAMDatastore(config_path=config_path)
else:
raise ValueError(f"Unknown datastore kind: {datastore_kind}")
return datastore
@@ -50,9 +50,9 @@ def main(input_args=None):
help="Kind of datastore to use",
)
parser.add_argument(
- "datastore_path",
+ "datastore_config_path",
type=str,
- help="The root path for the datastore",
+ help="Path for the datastore config",
)
parser.add_argument(
"--model",
@@ -246,7 +246,8 @@ def main(input_args=None):
seed.seed_everything(args.seed)
# Create datastore
datastore = _init_datastore(
- datastore_kind=args.datastore_kind, path=args.datastore_path
+ datastore_kind=args.datastore_kind,
+ config_path=args.datastore_config_path,
)
# Create datamodule
data_module = WeatherDataModule(
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index ed4856d3..22604a68 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -178,7 +178,7 @@ def __getitem__(self, idx):
)
da_forcing = self.da_forcing
else:
- da_forcing = xr.DataArray()
+ da_forcing = None
# handle time sampling in a way that is compatible with both analysis
# and forecast data
@@ -186,29 +186,37 @@ def __getitem__(self, idx):
da=da_state, idx=idx, n_steps=2 + self.ar_steps
)
- das_forcing = []
- for n in range(self.forcing_window_size):
- da_ = self._sample_time(
- da=da_forcing,
- idx=idx,
- n_steps=self.ar_steps,
- n_timesteps_offset=n,
- )
- if n > 0:
- da_ = da_.drop_vars("time")
- das_forcing.append(da_)
- da_forcing_windowed = xr.concat(das_forcing, dim="window_sample")
+ if da_forcing is not None:
+ das_forcing = []
+ for n in range(self.forcing_window_size):
+ da_ = self._sample_time(
+ da=da_forcing,
+ idx=idx,
+ n_steps=self.ar_steps,
+ n_timesteps_offset=n,
+ )
+ if n > 0:
+ da_ = da_.drop_vars("time")
+ das_forcing.append(da_)
+ da_forcing_windowed = xr.concat(das_forcing, dim="window_sample")
+
+ # load the data into memory
+ da_state = da_state.load()
+ if da_forcing is not None:
+ da_forcing_windowed = da_forcing_windowed.load()
# ensure the dimensions are in the correct order
da_state = da_state.transpose("time", "grid_index", "state_feature")
- da_forcing_windowed = da_forcing_windowed.transpose(
- "time", "grid_index", "forcing_feature", "window_sample"
- )
+
+ if da_forcing is not None:
+ da_forcing_windowed = da_forcing_windowed.transpose(
+ "time", "grid_index", "forcing_feature", "window_sample"
+ )
da_init_states = da_state.isel(time=slice(None, 2))
da_target_states = da_state.isel(time=slice(2, None))
- batch_times = da_forcing_windowed.time.values.astype(float)
+ batch_times = da_target_states.time.values.astype(float)
if self.standardize:
da_init_states = (
@@ -223,17 +231,28 @@ def __getitem__(self, idx):
da_forcing_windowed - self.da_forcing_mean
) / self.da_forcing_std
- # stack the `forcing_feature` and `window_sample` dimensions into a
- # single `forcing_feature` dimension
- da_forcing_windowed = da_forcing_windowed.stack(
- forcing_feature_windowed=("forcing_feature", "window_sample")
- )
+ if self.da_forcing is not None:
+ # stack the `forcing_feature` and `window_sample` dimensions into a
+ # single `forcing_feature` dimension
+ da_forcing_windowed = da_forcing_windowed.stack(
+ forcing_feature_windowed=("forcing_feature", "window_sample")
+ )
init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
target_states = torch.tensor(
da_target_states.values, dtype=torch.float32
)
- forcing = torch.tensor(da_forcing_windowed.values, dtype=torch.float32)
+
+ if self.da_forcing is None:
+ # create an empty forcing tensor
+ forcing = torch.empty(
+ (self.ar_steps, da_state.grid_index.size, 0),
+ dtype=torch.float32,
+ )
+ else:
+ forcing = torch.tensor(
+ da_forcing_windowed.values, dtype=torch.float32
+ )
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
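When the datastore provides no forcing, the patch substitutes a tensor with zero features so that downstream shapes stay consistent. A tiny sketch of that convention (sizes are illustrative):

```python
import torch

ar_steps, n_grid = 3, 100
# zero-feature forcing placeholder, as used above when no forcing exists
forcing = torch.empty((ar_steps, n_grid, 0), dtype=torch.float32)
assert forcing.shape == (3, 100, 0)

# concatenation along the feature dimension remains well defined
state_like = torch.zeros((ar_steps, n_grid, 5))
assert torch.cat([state_like, forcing], dim=-1).shape == (3, 100, 5)
```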
@@ -276,6 +295,12 @@ def __init__(
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
+ if num_workers > 0:
+ # default to spawn for now, as the default on linux "fork" hangs
+ # when using dask (which the npyfiles datastore uses)
+ self.multiprocessing_context = "spawn"
+ else:
+ self.multiprocessing_context = None
def setup(self, stage=None):
if stage == "fit" or stage is None:
@@ -310,6 +335,8 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
)
def val_dataloader(self):
@@ -319,6 +346,8 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
)
def test_dataloader(self):
@@ -328,4 +357,6 @@ def test_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
)
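The datamodule changes above switch to the "spawn" start method when workers are used, since the default "fork" method can hang with dask-backed datastores, and keep the workers alive between epochs. A self-contained sketch of the same `DataLoader` options:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
num_workers = 2
loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=num_workers,
    shuffle=False,
    # "spawn" avoids fork-related hangs with dask; persistent workers keep
    # worker processes alive between epochs
    multiprocessing_context="spawn" if num_workers > 0 else None,
    persistent_workers=num_workers > 0,
)
```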
diff --git a/tests/conftest.py b/tests/conftest.py
index 9ff25a91..bea9e95f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,20 +7,20 @@
import yaml
# First-party
-from neural_lam.datastore.mllam import MLLAMDatastore
-from neural_lam.datastore.multizarr import MultiZarrDatastore
-from neural_lam.datastore.npyfiles import NpyFilesDatastore
+from neural_lam.datastore import mllam, multizarr, npyfiles
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
os.environ["WANDB_DISABLED"] = "true"
DATASTORES = dict(
- multizarr=MultiZarrDatastore,
- mllam=MLLAMDatastore,
- npyfiles=NpyFilesDatastore,
+ multizarr=multizarr.MultiZarrDatastore,
+ mllam=mllam.MLLAMDatastore,
+ npyfiles=npyfiles.NpyFilesDatastore,
)
+DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples")
+
# Initializing variables for the s3 client
S3_BUCKET_NAME = "mllam-testdata"
S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
@@ -33,9 +33,8 @@
def download_meps_example_reduced_dataset():
# Download and unzip test data into data/meps_example_reduced
- root_path = Path("tests/datastore_configs/npy")
+ root_path = DATASTORE_EXAMPLES_ROOT_PATH / "npy"
dataset_path = root_path / "meps_example_reduced"
- will_download = not dataset_path.exists()
pooch.retrieve(
url=S3_FULL_PATH,
@@ -45,23 +44,60 @@ def download_meps_example_reduced_dataset():
fname="meps_example_reduced.zip",
)
- if will_download:
- # XXX: should update the dataset stored on S3 the change below
- config_path = dataset_path / "data_config.yaml"
- # rename the `projection.class` key to `projection.class_name` in the config
- with open(config_path, "r") as f:
- config = yaml.safe_load(f)
- config["projection.class_name"] = config.pop("projection.class")
+ config_path = dataset_path / "data_config.yaml"
+
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+
+ if "class" in config["projection"]:
+ # XXX: should update the dataset stored on S3 with the change below
+ #
+ # rename the `projection.class` key to `projection.class_name` in the
+ # config. This is because the `class` key is reserved for the class
+ # attribute of the object and so we can't use it to define a python
+ # dataclass
+ config["projection"]["class_name"] = config["projection"].pop("class")
+
with open(config_path, "w") as f:
yaml.dump(config, f)
- return dataset_path
+ return config_path
+
+
+def bootstrap_multizarr_example():
+ multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr"
+ data_config_path = multizarr_path / "data_config.yaml"
+ # here we assume that the data-config refers to the default path
+ # for the "datetime forcings" dataset
+ datetime_forcing_zarr_path = (
+ data_config_path.parent
+ / multizarr.create_datetime_forcings.DEFAULT_FILENAME
+ )
+ if not datetime_forcing_zarr_path.exists():
+ multizarr.create_datetime_forcings.create_datetime_forcing_zarr(
+ data_config_path=data_config_path
+ )
+
+ normalized_forcing_zarr_path = (
+ data_config_path.parent
+ / multizarr.create_normalization_stats.DEFAULT_FILENAME
+ )
+ if not normalized_forcing_zarr_path.exists():
+ multizarr.create_normalization_stats.create_normalization_stats_zarr(
+ data_config_path=data_config_path
+ )
+
+ return data_config_path
DATASTORES_EXAMPLES = dict(
- multizarr=dict(root_path="tests/datastore_configs/multizarr"),
- mllam=dict(root_path="tests/datastore_configs/mllam"),
- npyfiles=dict(root_path=download_meps_example_reduced_dataset()),
+ multizarr=dict(config_path=bootstrap_multizarr_example()),
+ mllam=dict(
+ config_path=DATASTORE_EXAMPLES_ROOT_PATH
+ / "mllam"
+ / "danra.example.yaml"
+ ),
+ npyfiles=dict(config_path=download_meps_example_reduced_dataset()),
)
diff --git a/tests/datastores_examples/.gitignore b/tests/datastore_examples/.gitignore
similarity index 100%
rename from tests/datastores_examples/.gitignore
rename to tests/datastore_examples/.gitignore
diff --git a/tests/datastores_examples/mllam/.gitignore b/tests/datastore_examples/mllam/.gitignore
similarity index 100%
rename from tests/datastores_examples/mllam/.gitignore
rename to tests/datastore_examples/mllam/.gitignore
diff --git a/tests/datastores_examples/mllam/data_config.yaml b/tests/datastore_examples/mllam/danra.example.yaml
similarity index 100%
rename from tests/datastores_examples/mllam/data_config.yaml
rename to tests/datastore_examples/mllam/danra.example.yaml
diff --git a/tests/datastores_examples/multizarr/data_config.yaml b/tests/datastore_examples/multizarr/data_config.yaml
similarity index 100%
rename from tests/datastores_examples/multizarr/data_config.yaml
rename to tests/datastore_examples/multizarr/data_config.yaml
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 72518887..64548d02 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -70,8 +70,9 @@ def test_dataset_item(datastore_name):
pass
+@pytest.mark.parametrize("split", ["train", "val", "test"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_single_batch(datastore_name, split="train"):
+def test_single_batch(datastore_name, split):
"""Check that the `datastore.get_dataarray` method is implemented.
And that it returns an xarray DataArray with the correct dimensions.
diff --git a/tests/test_training.py b/tests/test_training.py
index 5e7f4095..44f863f0 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -29,8 +29,8 @@ def test_training(datastore_name):
trainer = pl.Trainer(
max_epochs=3,
deterministic=True,
- strategy="ddp",
accelerator=device_name,
+ devices=1,
log_every_n_steps=1,
)
From a1b20376d29d4273fee3bc7687ef441e817447e7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 12 Aug 2024 14:10:45 +0000
Subject: [PATCH 156/273] all tests passing!
---
README.md | 74 +++++++++++++++----
neural_lam/datastore/base.py | 12 ++-
.../multizarr/create_boundary_mask.py | 63 +++++++++++-----
.../multizarr/create_datetime_forcings.py | 4 +-
.../multizarr/create_normalization_stats.py | 49 +++++-------
neural_lam/datastore/multizarr/store.py | 42 +++++------
neural_lam/weather_dataset.py | 2 +-
tests/conftest.py | 46 ++++++++++++
.../mllam/danra.example.yaml | 5 +-
tests/datastore_examples/multizarr/.gitignore | 2 +
.../multizarr/data_config.yaml | 8 +-
tests/test_datasets.py | 13 ++--
tests/test_datastores.py | 2 +
tests/test_training.py | 7 +-
14 files changed, 224 insertions(+), 105 deletions(-)
create mode 100644 tests/datastore_examples/multizarr/.gitignore
diff --git a/README.md b/README.md
index e7c2f53d..7ea730c7 100644
--- a/README.md
+++ b/README.md
@@ -45,20 +45,62 @@ Still, some restrictions are inevitable:
-## A note on the limited area setting
-Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
-There are still some parts of the code that is quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ).
-If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
-We would be happy to support such enhancements.
-See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
-
# Using Neural-LAM
-Below follows instructions on how to use Neural-LAM to train and evaluate models.
+Below are instructions on how to use Neural-LAM to train and evaluate models. Once `neural-lam` has been installed, the general process is:
+
+1. Run any pre-processing scripts to generate the necessary derived data that your chosen datastore requires
+2. Run the graph-creation step
+3. Train the model
+
+## Data
+
+To enable flexibility in what input-data sources can be used with neural-lam,
+the input-data representation is split into two parts:
+
+1. a "datastore" (represented by instances of
+ [neural_lam.datastore.BaseDatastore](neural_lam/datastore/base.py)) which
+ takes care of loading a given category (state, forcing or static) and split
+ (train/val/test) of data from disk and returning it as a `xarray.DataArray`.
+ The returned data-array is expected to have the spatial coordinates
+ flattened into a single `grid_index` dimension and all variables and vertical
+ levels stacked into a feature dimension (named as `{category}_feature`). The
+ datastore also provides information about the number, names and units of
+ variables in the data, the boundary mask, normalisation values and grid
+ information.
+
+2. a `pytorch.Dataset`-derived class (called
+ `neural_lam.weather_dataset.WeatherDataset`) which takes care of sampling in
+ time to create individual samples for training, validation and testing. The
+ `WeatherDataset` class is also responsible for normalising the values and
+ returning `torch.Tensor`-objects.
+
+There are currently three different datastores implemented in the codebase:
+
+1. `neural_lam.datastore.NpyFilesDatastore` which reads data from `.npy`-files in
+ the format introduced in neural-lam `v0.1.0`.
+
+2. `neural_lam.datastore.MultiZarrDatastore` which can combine multiple zarr
+ files during train/val/test sampling, with the transformations to facilitate
+ this implemented within `neural_lam.datastore.MultiZarrDatastore`.
+
+3. `neural_lam.datastore.MLLAMDatastore` which can combine multiple zarr
+ datasets either as a preprocessing step or during sampling, but
+ offloads the implementation of the transformations to the
+ [mllam-data-prep](https://github.com/mllam/mllam-data-prep) package.
+
+If none of these options fits your needs, you can create your own datastore by
+subclassing the `neural_lam.datastore.BaseDatastore` class or
+`neural_lam.datastore.BaseCartesianDatastore` class (if your data is stored on
+a Cartesian grid) and implementing the abstract methods.
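+
+As a rough sketch of how these two parts fit together (the constructor and
+argument names below are illustrative and may differ slightly from the actual
+API; see the class docstrings for details):
+
+```python
+from neural_lam.datastore import MLLAMDatastore
+from neural_lam.weather_dataset import WeatherDataset
+
+# load a datastore from its configuration file
+datastore = MLLAMDatastore(
+    config_path="tests/datastore_examples/mllam/danra.example.yaml"
+)
+
+# the dataset samples the datastore in time, normalises the values and
+# returns torch.Tensor objects
+dataset = WeatherDataset(datastore=datastore, split="train")
+init_states, target_states, forcing, batch_times = dataset[0]
+```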
+
## Installation
-The dependencies in `neural-lam` is handled with [pdm](https://pdm.fming.dev/), but you can still install `neural-lam` directly with pip if you prefer. The benefits of using `pdm` are that [pyproject.toml](pyproject.toml) is automatically updated when you add/remove dependencies (with `pdm add ` or `pdm remove ` or `pdm remove
@@ -132,11 +179,12 @@ wandb off
```
## Train Models
-Models can be trained using `python -m neural_lam.train_model`.
+Models can be trained using `python -m neural_lam.train_model <datastore-kind> <config-path>`.
 Run `python -m neural_lam.train_model --help` for a full list of training options.
 A few of the key ones are outlined below:
-* `--data_config`: Path to the data configuration file
+* `<datastore-kind>`: The kind of datastore that you are using (should be one of `npyfiles`, `multizarr` or `mllam`)
+* `<config-path>`: Path to the datastore configuration file
* `--model`: Which model to train
* `--graph`: Which graph to use with the model
* `--processor_layers`: Number of GNN layers to use in the processing part of the model
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 73658126..101a13bc 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -16,14 +16,20 @@ class BaseDatastore(abc.ABC):
access the data in a processed format that can be used for training and
evaluation of neural networks.
+ NOTE: All methods return either primitive types, `numpy.ndarray`,
+ `xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
+ objects. Conversion to `pytorch.Tensor` objects should be done in the
+ `weather_dataset.WeatherDataset` class (which inherits from
+ `torch.utils.data.Dataset` and uses the datastore to access the data).
+
# Forecast vs analysis data
- If the datastore should represent forecast rather than analysis data, then
+    If the datastore is used to represent forecast rather than analysis data, then
the `is_forecast` attribute should be set to True, and returned data from
`get_dataarray` is assumed to have `analysis_time` and `forecast_time` dimensions
(rather than just `time`).
# Ensemble vs deterministic data
- If the datastore should represent ensemble data, then the `is_ensemble`
+ If the datastore is used to represent ensemble data, then the `is_ensemble`
attribute should be set to True, and returned data from `get_dataarray` is
assumed to have an `ensemble_member` dimension.
"""
@@ -108,7 +114,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
should contain a `{category}_mean` and `{category}_std` variable for
each variable in the category. For `category=="state"`, the dataarray
should also contain a `state_diff_mean` and `state_diff_std` variable
- for the one-step differences of the state variables. The return
+ for the one-step differences of the state variables. The returned
dataarray should at least have dimensions of `({category}_feature)`,
but can also include for example `grid_index` (if the normalisation is
done per grid point for example).
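
As a concrete but hypothetical illustration of the structure described above, a
normalization dataset for `category=="state"` could look like the following;
the feature names and statistics here are made up:

```python
import numpy as np
import xarray as xr

state_features = ["T_2M", "PS"]  # hypothetical state feature names
ds_norm = xr.Dataset(
    {
        "state_mean": (("state_feature",), np.array([275.0, 1.0e5])),
        "state_std": (("state_feature",), np.array([10.0, 500.0])),
        # statistics of one-step differences, used when normalising targets
        "state_diff_mean": (("state_feature",), np.array([0.0, 0.0])),
        "state_diff_std": (("state_feature",), np.array([1.0, 50.0])),
    },
    coords={"state_feature": state_features},
)
```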
diff --git a/neural_lam/datastore/multizarr/create_boundary_mask.py b/neural_lam/datastore/multizarr/create_boundary_mask.py
index 038d88be..ae154941 100644
--- a/neural_lam/datastore/multizarr/create_boundary_mask.py
+++ b/neural_lam/datastore/multizarr/create_boundary_mask.py
@@ -1,49 +1,74 @@
# Standard library
from argparse import ArgumentParser
+from pathlib import Path
# Third-party
import numpy as np
import xarray as xr
-# First-party
-from neural_lam.datastore.multizarr import config
+# Local
+from . import config
+
+DEFAULT_FILENAME = "boundary_mask.zarr"
+
+
+def create_boundary_mask(data_config_path, zarr_path, n_boundary_cells):
+ """Create a mask for the boundaries of the grid.
+
+ Parameters
+ ----------
+ data_config_path : str
+ Data configuration.
+ zarr_path : str
+ Path to save the Zarr archive.
+ """
+ data_config_path = config.Config.from_file(str(data_config_path))
+ mask = np.zeros(list(data_config_path.grid_shape_state.values.values()))
+
+ # Set the n_boundary_cells grid-cells closest to each boundary to True
+ mask[:n_boundary_cells, :] = True # top boundary
+ mask[-n_boundary_cells:, :] = True # noqa bottom boundary
+ mask[:, :n_boundary_cells] = True # left boundary
+ mask[:, -n_boundary_cells:] = True # noqa right boundary
+
+ mask = xr.Dataset({"mask": (["y", "x"], mask)})
+
+ print(f"Saving mask to {zarr_path}...")
+ mask.to_zarr(zarr_path, mode="w")
def main():
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
- "--data_config",
+ "data_config",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ help="Path to data config file",
)
parser.add_argument(
"--zarr_path",
type=str,
- default="data/boundary_mask.zarr",
+ default=None,
help="Path to save the Zarr archive "
- "(default: same directory as data/boundary_mask.zarr)",
+ "(default: same directory as data config)",
)
parser.add_argument(
- "--boundaries",
+ "--n_boundary_cells",
type=int,
default=30,
help="Number of grid-cells to set to True along each boundary",
)
args = parser.parse_args()
- data_config = config.Config.from_file(args.data_config)
- mask = np.zeros(list(data_config.grid_shape_state.values.values()))
- # Set the args.boundaries grid-cells closest to each boundary to True
- mask[: args.boundaries, :] = True # top boundary
- mask[-args.boundaries :, :] = True # noqa bottom boundary
- mask[:, : args.boundaries] = True # left boundary
- mask[:, -args.boundaries :] = True # noqa right boundary
+    if args.zarr_path is None:
+        zarr_path = Path(args.data_config).parent / DEFAULT_FILENAME
+    else:
+        zarr_path = Path(args.zarr_path)
- mask = xr.Dataset({"mask": (["y", "x"], mask)})
-
- print(f"Saving mask to {args.zarr_path}...")
- mask.to_zarr(args.zarr_path, mode="w")
+ create_boundary_mask(
+ data_config_path=args.data_config,
+ zarr_path=zarr_path,
+ n_boundary_cells=args.n_boundary_cells,
+ )
if __name__ == "__main__":
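
With the logic factored out into `create_boundary_mask`, the mask can also be
created directly from Python rather than via the CLI. A minimal sketch, with
placeholder paths matching the test setup:

```python
from pathlib import Path

from neural_lam.datastore.multizarr.create_boundary_mask import (
    create_boundary_mask,
)

data_config = Path("tests/datastore_examples/multizarr/data_config.yaml")
create_boundary_mask(
    data_config_path=data_config,
    zarr_path=data_config.parent / "boundary_mask.zarr",
    n_boundary_cells=10,
)
```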
diff --git a/neural_lam/datastore/multizarr/create_datetime_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py
index 3907ca08..82a90147 100644
--- a/neural_lam/datastore/multizarr/create_datetime_forcings.py
+++ b/neural_lam/datastore/multizarr/create_datetime_forcings.py
@@ -7,8 +7,8 @@
import pandas as pd
import xarray as xr
-# First-party
-from neural_lam.datastore.multizarr import MultiZarrDatastore
+# Local
+from .store import MultiZarrDatastore
DEFAULT_FILENAME = "datetime_forcings.zarr"
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index 46dcc7d7..b4cf1be6 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -5,8 +5,8 @@
# Third-party
import xarray as xr
-# First-party
-from neural_lam.datastore.multizarr import MultiZarrDatastore
+# Local
+from .store import MultiZarrDatastore
DEFAULT_FILENAME = "normalization.zarr"
@@ -54,31 +54,20 @@ def create_normalization_stats_zarr(
for group in combined_stats:
vars_to_combine = group["vars"]
- means = da_forcing_mean.sel(variable_name=vars_to_combine)
- stds = da_forcing_std.sel(variable_name=vars_to_combine)
+ da_forcing_means = da_forcing_mean.sel(
+ forcing_feature=vars_to_combine
+ )
+ stds = da_forcing_std.sel(forcing_feature=vars_to_combine)
- combined_mean = means.mean(dim="variable_name")
- combined_std = (stds**2).mean(dim="variable_name") ** 0.5
+ combined_mean = da_forcing_means.mean(dim="forcing_feature")
+ combined_std = (stds**2).mean(dim="forcing_feature") ** 0.5
da_forcing_mean.loc[
- dict(variable_name=vars_to_combine)
+ dict(forcing_feature=vars_to_combine)
] = combined_mean
da_forcing_std.loc[
- dict(variable_name=vars_to_combine)
+ dict(forcing_feature=vars_to_combine)
] = combined_std
-
- window = datastore._config["forcing"]["window"]
-
- da_forcing_mean = xr.concat(
- [da_forcing_mean] * window, dim="window"
- ).stack(forcing_variable=("variable_name", "window"))
- da_forcing_std = xr.concat(
- [da_forcing_std] * window, dim="window"
- ).stack(forcing_variable=("variable_name", "window"))
- vars = da_forcing["variable_name"].values.tolist()
- window = datastore._config["forcing"]["window"]
- forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
-
print(
"Computing mean and std.-dev. for one-step differences...", flush=True
)
@@ -94,21 +83,17 @@ def create_normalization_stats_zarr(
"state_diff_std": diff_std,
}
)
+
if da_forcing is not None:
- dsf = (
- xr.Dataset(
- {
- "forcing_mean": da_forcing_mean,
- "forcing_std": da_forcing_std,
- }
- )
- .reset_index(["forcing_variable"])
- .drop_vars(["variable_name", "window"])
- .assign_coords(forcing_variable=forcing_vars)
+ dsf = xr.Dataset(
+ {
+ "forcing_mean": da_forcing_mean,
+ "forcing_std": da_forcing_std,
+ }
)
ds = xr.merge([ds, dsf])
- ds = ds.chunk({"variable_name": -1, "forcing_variable": -1})
+ ds = ds.chunk({"state_feature": -1, "forcing_feature": -1})
print("Saving dataset as Zarr...")
ds.to_zarr(zarr_path, mode="w")
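
Note that the combined statistics above pool the per-feature variances rather
than averaging the standard deviations directly; a quick numpy check with
made-up values:

```python
import numpy as np

stds = np.array([2.0, 4.0])  # per-feature standard deviations to combine
combined_std = np.sqrt((stds**2).mean())  # pool variances, then take the root
print(round(float(combined_std), 3))  # 3.162, not the plain average of 3.0
```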
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index d3b339ce..1a3a2a89 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -364,24 +364,6 @@ def get_xy(self, category, stacked=True):
return xy
- def get_xy_extent(self, category):
- """Return the extent of the x, y coordinates. This should be a list of
- 4 floats with `[xmin, xmax, ymin, ymax]`
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- list(float): The extent of the x, y coordinates.
- """
- x, y = self.get_xy(category, stacked=False)
- if self.projection.inverted:
- extent = [x.max(), x.min(), y.max(), y.min()]
- else:
- extent = [x.min(), x.max(), y.min(), y.max()]
-
- return extent
-
@functools.lru_cache()
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
"""Return the normalization dataarray for the given category. This
@@ -405,6 +387,22 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
for the mean and standard deviation of the variables (and
differences for state variables).
"""
+ # XXX: the multizarr code didn't include routines for computing the
+ # normalization of "static" features previously, we'll just hack
+ # something in here and assume they are already normalized
+ if category == "static":
+ da_mean = xr.DataArray(
+ np.zeros(self.get_num_data_vars(category)),
+ dims=("static_feature",),
+ coords={"static_feature": self.get_vars_names(category)},
+ )
+ da_std = xr.DataArray(
+ np.ones(self.get_num_data_vars(category)),
+ dims=("static_feature",),
+ coords={"static_feature": self.get_vars_names(category)},
+ )
+ return xr.Dataset(dict(static_mean=da_mean, static_std=da_std))
+
ds_combined_stats = self._load_and_merge_stats()
if ds_combined_stats is None:
return None
@@ -644,7 +642,7 @@ def grid_shape_state(self):
)
@property
- def boundary_mask(self):
+ def boundary_mask(self) -> xr.DataArray:
"""Load the boundary mask for the dataset, with spatial dimensions
stacked.
@@ -657,8 +655,10 @@ def boundary_mask(self):
self._config["boundary"]["mask"]["path"]
)
ds_boundary_mask = xr.open_zarr(boundary_mask_path)
- return ds_boundary_mask.mask.stack(grid_index=("y", "x")).reset_index(
- "grid_index"
+ return (
+ ds_boundary_mask.mask.stack(grid_index=("y", "x"))
+ .reset_index("grid_index")
+ .astype(int)
)
def get_dataarray(self, category, split="train"):
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 22604a68..5ba1d326 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -84,7 +84,7 @@ def __len__(self):
# in the elapsed_forecast_duration dimension, should that be checked here?
return self.da_state.analysis_time.size
else:
- # sample_len = 2 + ar_steps <-- 2 initial states + ar_steps target states
+ # sample_len = 2 + ar_steps (2 initial states + ar_steps target states)
# n_samples = len(self.da_state.time) - sample_len + 1
# = len(self.da_state.time) - 2 - ar_steps + 1
# = len(self.da_state.time) - ar_steps - 1
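
As a quick sanity check of the comment above: with 10 time steps and
`ar_steps = 3`, each sample spans 2 + 3 = 5 consecutive steps, giving
10 - 3 - 1 = 6 possible samples.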
diff --git a/tests/conftest.py b/tests/conftest.py
index bea9e95f..1f4edd1a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,7 @@
# Third-party
import pooch
+import xarray as xr
import yaml
# First-party
@@ -65,7 +66,40 @@ def download_meps_example_reduced_dataset():
def bootstrap_multizarr_example():
+ """Run the steps that are needed to prepare the input data for the
+ multizarr datastore example. This includes:
+
+ - Downloading the two zarr datasets (since training directly from S3 is
+ error-prone as the connection often breaks)
+ - Creating the datetime forcings zarr
+ - Creating the normalization stats zarr
+ - Creating the boundary mask zarr
+ """
multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr"
+ n_boundary_cells = 10
+
+ data_urls = [
+ "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr",
+ "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr",
+ ]
+
+ for url in data_urls:
+ local_path = multizarr_path / "danra" / Path(url).name
+ if local_path.exists():
+ continue
+ print(f"Downloading {url} to {local_path}")
+ ds = xr.open_zarr(url)
+ chunk_dict = {dim: -1 for dim in ds.dims if dim != "time"}
+ chunk_dict["time"] = 20
+ ds = ds.chunk(chunk_dict)
+
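+        # drop any chunk-size encoding carried over from the source store so
+        # that the re-chunking above is what actually gets written to disk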
+ for var in ds.variables:
+ if "chunks" in ds[var].encoding:
+ del ds[var].encoding["chunks"]
+
+ ds.to_zarr(local_path, mode="w")
+ print("DONE")
+
data_config_path = multizarr_path / "data_config.yaml"
    # here assume that the data-config is referring to the default path
# for the "datetime forcings" dataset
@@ -87,6 +121,18 @@ def bootstrap_multizarr_example():
data_config_path=data_config_path
)
+ boundary_mask_path = (
+ data_config_path.parent
+ / multizarr.create_boundary_mask.DEFAULT_FILENAME
+ )
+
+ if not boundary_mask_path.exists():
+ multizarr.create_boundary_mask.create_boundary_mask(
+ data_config_path=data_config_path,
+ n_boundary_cells=n_boundary_cells,
+ zarr_path=boundary_mask_path,
+ )
+
return data_config_path
diff --git a/tests/datastore_examples/mllam/danra.example.yaml b/tests/datastore_examples/mllam/danra.example.yaml
index 5c2d02d7..73aa0dfa 100644
--- a/tests/datastore_examples/mllam/danra.example.yaml
+++ b/tests/datastore_examples/mllam/danra.example.yaml
@@ -59,9 +59,8 @@ inputs:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
dims: [time, x, y]
variables:
- # shouldn't really be using sea-surface pressure as "forcing", but don't
- # have radiation variables in danra yet
- - pres_seasurface
+ # use surface incoming shortwave radiation as forcing
+ - swavr0m
dim_mapping:
time:
method: rename
diff --git a/tests/datastore_examples/multizarr/.gitignore b/tests/datastore_examples/multizarr/.gitignore
new file mode 100644
index 00000000..f2828f46
--- /dev/null
+++ b/tests/datastore_examples/multizarr/.gitignore
@@ -0,0 +1,2 @@
+*.zarr/
+graph/
diff --git a/tests/datastore_examples/multizarr/data_config.yaml b/tests/datastore_examples/multizarr/data_config.yaml
index 0b857761..5d5a4336 100644
--- a/tests/datastore_examples/multizarr/data_config.yaml
+++ b/tests/datastore_examples/multizarr/data_config.yaml
@@ -1,7 +1,7 @@
name: danra
state:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "danra/single_levels.zarr"
dims:
time: time
level: null
@@ -11,7 +11,7 @@ state:
lat_lon_names:
lon: lon
lat: lat
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+ - path: "danra/height_levels.zarr"
dims:
time: time
level: altitude
@@ -41,7 +41,7 @@ state:
- 100
forcing:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "danra/single_levels.zarr"
dims:
time: time
level: null
@@ -82,7 +82,7 @@ forcing:
window: 3 # Number of time steps to use for forcing (odd)
static:
zarrs:
- - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+ - path: "danra/single_levels.zarr"
dims:
level: null
x: x
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 64548d02..7e73f787 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -27,7 +27,7 @@ def test_dataset_item(datastore_name):
datastore = init_datastore(datastore_name)
N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y
- N_pred_steps = 1
+ N_pred_steps = 4
forcing_window_size = 3
dataset = WeatherDataset(
datastore=datastore,
@@ -43,16 +43,19 @@ def test_dataset_item(datastore_name):
init_states, target_states, forcing, batch_times = item
# initial states
+ assert init_states.ndim == 3
assert init_states.shape[0] == 2 # two time steps go into the input
assert init_states.shape[1] == N_gridpoints
assert init_states.shape[2] == datastore.get_num_data_vars("state")
# output states
+ assert target_states.ndim == 3
assert target_states.shape[0] == N_pred_steps
assert target_states.shape[1] == N_gridpoints
assert target_states.shape[2] == datastore.get_num_data_vars("state")
# forcing
+ assert forcing.ndim == 3
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert (
@@ -61,13 +64,13 @@ def test_dataset_item(datastore_name):
)
# batch times
+ assert batch_times.ndim == 1
assert batch_times.shape[0] == N_pred_steps
- # try to run through the whole dataset to ensure slicing and stacking
+ # try to get the last item of the dataset to ensure slicing and stacking
# operations are working as expected and are consistent with the dataset
# length
- for item in iter(dataset):
- pass
+ dataset[len(dataset) - 1]
@pytest.mark.parametrize("split", ["train", "val", "test"])
@@ -118,7 +121,7 @@ class ModelArgs:
)
model_device = model.to(device_name)
- data_loader = DataLoader(dataset, batch_size=2)
+ data_loader = DataLoader(dataset, batch_size=5)
batch = next(iter(data_loader))
batch_device = [part.to(device_name) for part in batch]
model_device.common_step(batch_device)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index bd378e98..198d4460 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -126,6 +126,8 @@ def test_get_normalization_dataarray(datastore_name):
for op in ops:
var_name = f"{category}_{op}"
assert var_name in ds_stats.data_vars
+ da_val = ds_stats[var_name]
+ assert set(da_val.dims) == {f"{category}_feature"}
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
diff --git a/tests/test_training.py b/tests/test_training.py
index 44f863f0..ee532656 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -30,7 +30,10 @@ def test_training(datastore_name):
max_epochs=3,
deterministic=True,
accelerator=device_name,
- devices=1,
+ # XXX: `devices` has to be set to 2 otherwise
+ # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
+ # because it expects to aggregate over multiple devices
+ devices=2,
log_every_n_steps=1,
)
@@ -68,7 +71,7 @@ class ModelArgs:
processor_layers = 4
mesh_aggr = "sum"
lr = 1.0e-3
- val_steps_to_log = [1]
+ val_steps_to_log = [1, 3]
metrics_watch = []
model_args = ModelArgs()
From e35958f4ee3f71f5cf7d2c0551f019874ed0862b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 12 Aug 2024 14:13:39 +0000
Subject: [PATCH 157/273] use mllam-data-prep v0.3.0
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index d0d8c67f..2e43f3bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
"torch-geometric==2.3.1",
"parse>=1.20.2",
"dataclass-wizard>=0.22.3",
- "mllam-data-prep>=0.2.0",
+ "mllam-data-prep[dask-distributed]>=0.3.0",
]
requires-python = ">=3.9"
From 658836a9c2654cc1f97418e17b70835569d404a8 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 13 Aug 2024 12:46:05 +0000
Subject: [PATCH 158/273] remove .DS_Store
---
neural_lam/.DS_Store | Bin 6148 -> 0 bytes
neural_lam/datasets/.DS_Store | Bin 6148 -> 0 bytes
2 files changed, 0 insertions(+), 0 deletions(-)
delete mode 100644 neural_lam/.DS_Store
delete mode 100644 neural_lam/datasets/.DS_Store
diff --git a/neural_lam/.DS_Store b/neural_lam/.DS_Store
deleted file mode 100644
index d0f319116c039464f072e86a8e9d64de0b1a5a9a..0000000000000000000000000000000000000000
GIT binary patch
diff --git a/neural_lam/datasets/.DS_Store b/neural_lam/datasets/.DS_Store
deleted file mode 100644
index f172ab58d31f03adddb2b8b1d35371f1d00616de..0000000000000000000000000000000000000000
GIT binary patch
Date: Tue, 13 Aug 2024 13:59:01 +0000
Subject: [PATCH 159/273] use tmate in gpu pdm cicd
---
.github/workflows/ci-pdm-install-and-test-gpu.yml | 3 +++
1 file changed, 3 insertions(+)
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index f9060361..c605be8b 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -36,6 +36,9 @@ jobs:
python -c "import torch; print(torch.__version__)"
python -c "import torch; assert not torch.__version__.endswith('+cpu')"
+ - name: Setup tmate session
+ uses: mxschmitt/action-tmate@v1
+
- name: Install package (including dev dependencies)
run: |
pdm install
From 3afe0e498275731eda0eaa23030998255a9e6760 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 13 Aug 2024 14:18:50 +0000
Subject: [PATCH 160/273] update pdm gpu cicd setup to pdm venv on nvme drive
---
.github/workflows/ci-pdm-install-and-test-gpu.yml | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index c605be8b..cd533cfa 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -24,17 +24,19 @@ jobs:
- name: Create venv
run: |
+ pdm config venv.in_project False
+ pdm config venv.location /opt/dlami/nvme/venv
pdm venv create --with-pip
pdm use --venv in-project
- name: Install torch (GPU CUDA 12.1)
run: |
- python -m pip install torch --index-url https://download.pytorch.org/whl/cu121
+ pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu121
- name: Print and check torch version
run: |
- python -c "import torch; print(torch.__version__)"
- python -c "import torch; assert not torch.__version__.endswith('+cpu')"
+ pdm run python -c "import torch; print(torch.__version__)"
+ pdm run python -c "import torch; assert not torch.__version__.endswith('+cpu')"
- name: Setup tmate session
uses: mxschmitt/action-tmate@v1
From f3d028b3718023e1457f267d9fd99bd7c2c33d24 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 13 Aug 2024 14:32:10 +0000
Subject: [PATCH 161/273] don't try to use pdm venv in-project
---
.github/workflows/ci-pdm-install-and-test-gpu.yml | 1 -
1 file changed, 1 deletion(-)
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index cd533cfa..3bd46fd9 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -27,7 +27,6 @@ jobs:
pdm config venv.in_project False
pdm config venv.location /opt/dlami/nvme/venv
pdm venv create --with-pip
- pdm use --venv in-project
- name: Install torch (GPU CUDA 12.1)
run: |
From 2c35662ccb2cdc6e7a7650e4c364df31d65c8d41 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 13 Aug 2024 14:44:06 +0000
Subject: [PATCH 162/273] remove tmate
---
.github/workflows/ci-pdm-install-and-test-gpu.yml | 3 ---
1 file changed, 3 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index 3bd46fd9..94e740ce 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -37,9 +37,6 @@ jobs:
pdm run python -c "import torch; print(torch.__version__)"
pdm run python -c "import torch; assert not torch.__version__.endswith('+cpu')"
- - name: Setup tmate session
- uses: mxschmitt/action-tmate@v1
-
- name: Install package (including dev dependencies)
run: |
pdm install
From 5f30255501e35ccc72a888bf6fdc4c5bdf9fc3ed Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 06:54:32 +0000
Subject: [PATCH 163/273] update README with install instructions
---
README.md | 41 ++++++++++++++++++++++++++++-------------
1 file changed, 28 insertions(+), 13 deletions(-)
diff --git a/README.md b/README.md
index 26d844f7..4c90dfea 100644
--- a/README.md
+++ b/README.md
@@ -57,21 +57,36 @@ See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://git
Below follows instructions on how to use Neural-LAM to train and evaluate models.
## Installation
-Follow the steps below to create the necessary python environment.
-1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
-2. Use python 3.9.
-3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
-4. Install required packages specified in `requirements.txt`.
-5. Install PyTorch Geometric version 2.2.0. This can be done by running
-```
-TORCH="2.0.1"
-CUDA="cu117"
+When installing `neural-lam` you have a choice of either installing directly
+with `pip` or using the `pdm` package manager.
+We recommend using `pdm` as it makes it easy to add/remove packages while
+keeping versions consistent (it automatically updates the `pyproject.toml`
+file), makes it easy to handle virtual environments and also installs the
+development toolchain packages.
+
+**Regarding `torch` installation**: because `torch` publishes different package
+variants for different CUDA versions and for CPU-only support, you will need to
+install `torch` separately if you don't want the most recent GPU variant, which
+also expects the most recent version of CUDA to be available on your system.
+
+We cover all the installation options in our [github actions ci/cd
+setup](.github/workflows/) which you can use as a reference.
+
+### Using `pdm`
+
+1. Clone this repository and navigate to the root directory.
+2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)). If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 5.
+3. Create a virtual environment for pdm to use with `pdm venv create --with-pip`.
+4. Install a specific version of `torch` with `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
+5. Install the dependencies with `pdm install`. If you will be developing `neural-lam` we recommend installing the development dependencies with `pdm install --dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately.
+
+### Using `pip`
+
+1. Clone this repository and navigate to the root directory. If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3.
+2. Install a specific version of `torch` with `python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
+3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend installing in editable mode with `python -m pip install -e .` so you can make changes to the code and see the effects immediately. The development dependencies to install are listed in `pyproject.toml`.
-pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 torch-cluster==1.6.1\
- torch-geometric==2.3.1 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
-```
-You will have to adjust the `CUDA` variable to match the CUDA version on your system or to run on CPU. See the [installation webpage](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) for more information.
## Data
Datasets should be stored in a directory called `data`.
From b2b563162d485ca26e0344ff6e648e3222afd918 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 06:56:52 +0000
Subject: [PATCH 164/273] changelog
---
CHANGELOG.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index dfb186f7..5d2d3410 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- added github pull-request template to ease contribution and review process
[\#53](https://github.com/mllam/neural-lam/pull/53), @leifdenby
+- ci/cd setup for running both CPU- and GPU-based testing, with both pdm- and pip-based installs [\#37](https://github.com/mllam/neural-lam/pull/37), @khintz, @leifdenby
+
### Changed
Optional multi-core/GPU support for statistics calculation in `create_parameter_weights.py`
From c8ae8294cb35fa658c2b9131bd124edb1a0cbcf6 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 06:58:42 +0000
Subject: [PATCH 165/273] update ci/cd badges to include gpu + cpu
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 4c90dfea..81ec2766 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg?branch=main)
-![Automatic tests](https://github.com/mllam/neural-lam/actions/workflows/run_tests.yml/badge.svg?branch=main)
+[![test (pdm install, gpu)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-gpu.yml/badge.svg)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-gpu.yml)
+[![test (pdm install, cpu)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-cpu.yml/badge.svg)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-cpu.yml)
From 0b72e9d1b9831804b56926c68bf2a3e3f1fbb491 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 07:51:09 +0000
Subject: [PATCH 166/273] add pyproject-flake8 to precommit config
---
.pre-commit-config.yaml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 815a92e1..7f16cbee 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,3 +35,4 @@ repos:
hooks:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
+ additional_dependencies: [pyproject-flake8]
From 190d1de713a658e717d24f7d8c65323becd22f7a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 07:51:59 +0000
Subject: [PATCH 167/273] use Flake8-pyproject instead
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7f16cbee..dfbf8b60 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,4 +35,4 @@ repos:
hooks:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
- additional_dependencies: [pyproject-flake8]
+ additional_dependencies: [Flake8-pyproject]
From 791af0a8028bd9e60d2c620f383e80c2ae2227b4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 07:58:09 +0000
Subject: [PATCH 168/273] update README
---
README.md | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
diff --git a/README.md b/README.md
index 00562f22..ce8daf69 100644
--- a/README.md
+++ b/README.md
@@ -62,16 +62,10 @@ Follow the steps below to create the necessary python environment.
1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
2. Use python 3.9.
3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
-4. Install required packages specified in `requirements.txt`.
-5. Install PyTorch Geometric version 2.2.0. This can be done by running
+4. Install `neural-lam` with pip:
```
-TORCH="2.0.1"
-CUDA="cu117"
-
-pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 torch-cluster==1.6.1\
- torch-geometric==2.3.1 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
+pip install -e .
```
-You will have to adjust the `CUDA` variable to match the CUDA version on your system or to run on CPU. See the [installation webpage](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) for more information.
## Data
Datasets should be stored in a directory called `data`.
From 799d55e3abd8a7ba34507cf1e2d524be070de89f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 12:30:07 +0000
Subject: [PATCH 169/273] linting fixes
---
.pre-commit-config.yaml | 3 +-
neural_lam/create_graph.py | 62 ++-----
neural_lam/datastore/base.py | 155 +++++++++++++-----
neural_lam/datastore/mllam.py | 114 +++++++++----
.../multizarr/create_boundary_mask.py | 1 +
.../multizarr/create_datetime_forcings.py | 13 +-
.../multizarr/create_normalization_stats.py | 17 +-
neural_lam/datastore/multizarr/store.py | 116 ++++++++-----
neural_lam/datastore/npyfiles/config.py | 16 +-
neural_lam/datastore/npyfiles/store.py | 132 +++++++++------
neural_lam/interaction_net.py | 19 ++-
neural_lam/metrics.py | 40 ++---
neural_lam/models/ar_model.py | 84 ++++------
neural_lam/models/base_graph_model.py | 31 ++--
neural_lam/models/base_hi_graph_model.py | 63 +++----
neural_lam/models/graph_lam.py | 31 ++--
neural_lam/models/hi_lam.py | 22 ++-
neural_lam/models/hi_lam_parallel.py | 17 +-
neural_lam/train_model.py | 23 +--
neural_lam/utils.py | 24 +--
neural_lam/vis.py | 25 +--
neural_lam/weather_dataset.py | 64 ++++----
plot_graph.py | 8 +-
pyproject.toml | 11 +-
tests/conftest.py | 18 +-
tests/test_cli.py | 4 +-
tests/test_datasets.py | 13 +-
tests/test_datastores.py | 25 +--
tests/test_graph_creation.py | 11 +-
tests/test_training.py | 4 +-
30 files changed, 625 insertions(+), 541 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fd40f4d7..91983d9b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,4 +40,5 @@ repos:
rev: v1.7.5
hooks:
- id: docformatter
- args: [--in-place, --recursive]
+ args: [--in-place, --recursive, --config, ./pyproject.toml]
+ additional_dependencies: [tomli]
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index e5eb44a4..6450f134 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -35,9 +35,7 @@ def plot_graph(graph, title=None):
# TODO: indicate direction of directed edges
# Move all to cpu and numpy, compute (in)-degrees
- degrees = (
- pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
- )
+ degrees = pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
edge_index = edge_index.cpu().numpy()
pos = pos.cpu().numpy()
@@ -82,9 +80,7 @@ def sort_nodes_internally(nx_graph):
def save_edges(graph, name, base_path):
- torch.save(
- graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")
- )
+ torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt"))
edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
torch.float32
) # Save as float32
@@ -97,9 +93,7 @@ def save_edges_list(graphs, name, base_path):
os.path.join(base_path, f"{name}_edge_index.pt"),
)
edge_features = [
- torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
- torch.float32
- )
+ torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(torch.float32)
for graph in graphs
] # Save as float32
torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt"))
@@ -130,11 +124,7 @@ def mk_2d_graph(xy, nx, ny):
# add diagonal edges
g.add_edges_from(
[((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
- + [
- ((x + 1, y), (x, y + 1))
- for x in range(nx - 1)
- for y in range(ny - 1)
- ]
+ + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
)
# turn into directed graph
@@ -164,8 +154,7 @@ def create_graph(
hierarchical: bool,
create_plot: bool,
):
- """Create graph components from `xy` grid coordinates and store in
- `graph_dir_path`.
+ """Create graph components from `xy` grid coordinates and store in `graph_dir_path`.
Creates the following files for all graphs:
- g2m_edge_index.pt [2, N_g2m_edges]
@@ -225,6 +214,7 @@ def create_graph(
Returns
-------
None
+
"""
os.makedirs(graph_dir_path, exist_ok=True)
@@ -262,10 +252,7 @@ def create_graph(
if hierarchical:
# Relabel nodes of each level with level index first
- G = [
- prepend_node_index(graph, level_i)
- for level_i, graph in enumerate(G)
- ]
+ G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)]
num_nodes_level = np.array([len(g_level.nodes) for g_level in G])
# First node index in each level in the hierarchical graph
@@ -307,9 +294,7 @@ def create_graph(
# add edge from mesh to grid
G_down.add_edge(u, v)
d = np.sqrt(
- np.sum(
- (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2
- )
+ np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)
)
G_down.edges[u, v]["len"] = d
G_down.edges[u, v]["vdiff"] = (
@@ -334,14 +319,10 @@ def create_graph(
down_graphs.append(pyg_down)
if create_plot:
- plot_graph(
- pyg_down, title=f"Down graph, {from_level} -> {to_level}"
- )
+ plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}")
plt.show()
- plot_graph(
- pyg_down, title=f"Up graph, {to_level} -> {from_level}"
- )
+ plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}")
plt.show()
# Save up and down edges
@@ -426,9 +407,7 @@ def create_graph(
vm = G_bottom_mesh.nodes
vm_xy = np.array([xy for _, xy in vm.data("pos")])
# distance between mesh nodes
- dm = np.sqrt(
- np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
- )
+ dm = np.sqrt(np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2))
# grid nodes
Ny, Nx = xy.shape[1:]
@@ -470,13 +449,9 @@ def create_graph(
u = vg_list[i]
# add edge from grid to mesh
G_g2m.add_edge(u, v)
- d = np.sqrt(
- np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2)
- )
+ d = np.sqrt(np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2))
G_g2m.edges[u, v]["len"] = d
- G_g2m.edges[u, v]["vdiff"] = (
- G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]
- )
+ G_g2m.edges[u, v]["vdiff"] = G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]
pyg_g2m = from_networkx(G_g2m)
@@ -505,13 +480,9 @@ def create_graph(
u = vm_list[i]
# add edge from mesh to grid
G_m2g.add_edge(u, v)
- d = np.sqrt(
- np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2)
- )
+ d = np.sqrt(np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2))
G_m2g.edges[u, v]["len"] = d
- G_m2g.edges[u, v]["vdiff"] = (
- G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]
- )
+ G_m2g.edges[u, v]["vdiff"] = G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]
# relabel nodes to integers (sorted)
G_m2g_int = networkx.convert_node_labels_to_integers(
@@ -578,8 +549,7 @@ def cli(input_args=None):
"--plot",
type=int,
default=0,
- help="If graphs should be plotted during generation "
- "(default: 0 (false))",
+ help="If graphs should be plotted during generation " "(default: 0 (false))",
)
parser.add_argument(
"--levels",
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 101a13bc..1b662fa4 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -11,10 +11,17 @@
class BaseDatastore(abc.ABC):
- """Base class for weather data used in the neural-lam package. A datastore
- defines the interface for accessing weather data by providing methods to
- access the data in a processed format that can be used for training and
- evaluation of neural networks.
+    """Base class for weather data used in the neural-lam package. A datastore
+    defines the interface for accessing weather data by providing methods to access
+    the data in a processed format that can be used for training and evaluation of
+    neural networks.
NOTE: All methods return either primitive types, `numpy.ndarray`,
`xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
@@ -32,6 +39,7 @@ class BaseDatastore(abc.ABC):
If the datastore is used to represent ensemble data, then the `is_ensemble`
attribute should be set to True, and returned data from `get_dataarray` is
assumed to have an `ensemble_member` dimension.
+
"""
is_ensemble: bool = False
@@ -40,13 +48,14 @@ class BaseDatastore(abc.ABC):
@property
@abc.abstractmethod
def root_path(self) -> Path:
- """The root path to the datastore. It is relative to this that any
- derived files (for example the graph components) are stored.
+ """The root path to the datastore. It is relative to this that any derived files
+ (for example the graph components) are stored.
Returns
-------
pathlib.Path
The root path to the datastore.
+
"""
pass
@@ -57,6 +66,7 @@ def step_length(self) -> int:
Returns:
int: The step length in hours.
+
"""
pass
@@ -73,6 +83,7 @@ def get_vars_units(self, category: str) -> List[str]:
-------
List[str]
The units of the variables.
+
"""
pass
@@ -89,6 +100,7 @@ def get_vars_names(self, category: str) -> List[str]:
-------
List[str]
The names of the variables.
+
"""
pass
@@ -105,19 +117,39 @@ def get_num_data_vars(self, category: str) -> int:
-------
int
The number of data variables.
+
"""
pass
@abc.abstractmethod
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This
- should contain a `{category}_mean` and `{category}_std` variable for
- each variable in the category. For `category=="state"`, the dataarray
- should also contain a `state_diff_mean` and `state_diff_std` variable
- for the one-step differences of the state variables. The returned
- dataarray should at least have dimensions of `({category}_feature)`,
- but can also include for example `grid_index` (if the normalisation is
- done per grid point for example).
+        """Return the normalization dataarray for the given category. This should
+        contain a `{category}_mean` and `{category}_std` variable for each variable in
+        the category. For `category=="state"`, the dataarray should also contain a
+        `state_diff_mean` and `state_diff_std` variable for the one-step differences of
+        the state variables. The returned dataarray should at least have dimensions of
+        `({category}_feature)`, but can also include for example `grid_index` (if the
+        normalisation is done per grid point for example).
Parameters
----------
@@ -130,18 +162,30 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
The normalization dataarray for the given category, with variables
for the mean and standard deviation of the variables (and
differences for state variables).
+
"""
pass
@abc.abstractmethod
- def get_dataarray(
- self, category: str, split: str
- ) -> Union[xr.DataArray, None]:
- """Return the processed data (as a single `xr.DataArray`) for the given
- category of data and test/train/val-split that covers all the data (in
- space and time) of a given category (state/forcing/static). A datastore
- must be able to return for the "state" category, but "forcing" and
- "static" are optional (in which case the method should return `None`).
+ def get_dataarray(self, category: str, split: str) -> Union[xr.DataArray, None]:
+        """Return the processed data (as a single `xr.DataArray`) for the given
+        category of data and test/train/val-split that covers all the data (in space
+        and time) of a given category (state/forcing/static). A datastore must be able
+        to return for the "state" category, but "forcing" and "static" are optional (in
+        which case the method should return `None`).
The returned dataarray is expected to at minimum have dimensions of
`(grid_index, {category}_feature)` so that any spatial dimensions have
@@ -168,20 +212,29 @@ def get_dataarray(
-------
xr.DataArray or None
The xarray DataArray object with processed dataset.
+
"""
pass
@property
@abc.abstractmethod
def boundary_mask(self) -> xr.DataArray:
- """Return the boundary mask for the dataset, with spatial dimensions
- stacked. Where the value is 1, the grid point is a boundary point, and
- where the value is 0, the grid point is not a boundary point.
+        """Return the boundary mask for the dataset, with spatial dimensions stacked.
+        Where the value is 1, the grid point is a boundary point, and where the value
+        is 0, the grid point is not a boundary point.
Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `('grid_index',)`.
+
"""
pass
@@ -195,12 +248,21 @@ class CartesianGridShape:
class BaseCartesianDatastore(BaseDatastore):
- """Base class for weather data stored on a Cartesian grid. In addition to
- the methods and attributes required for weather data in general (see
- `BaseDatastore`) for Cartesian gridded source data each `grid_index`
- coordinate value is assume to have an associated `x` and `y`-value so that
- the processed data-arrays can be reshaped back into into 2D xy-gridded
- arrays.
+    """Base class for weather data stored on a Cartesian grid. In addition to the
+    methods and attributes required for weather data in general (see `BaseDatastore`)
+    for Cartesian gridded source data each `grid_index` coordinate value is assumed to
+    have an associated `x` and `y`-value so that the processed data-arrays can be
+    reshaped back into 2D xy-gridded arrays.
In addition the following attributes and methods are required:
- `coords_projection` (property): Projection object for the coordinates.
@@ -208,6 +270,7 @@ class BaseCartesianDatastore(BaseDatastore):
- `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- `get_xy` (method): Return the x, y coordinates of the dataset.
+
"""
CARTESIAN_COORDS = ["y", "x"]
@@ -223,6 +286,7 @@ def coords_projection(self) -> ccrs.Projection:
-------
cartopy.crs.Projection:
The projection object.
+
"""
pass
@@ -236,6 +300,7 @@ def grid_shape_state(self) -> CartesianGridShape:
CartesianGridShape:
The shape of the grid for the state variables, which has `x` and
`y` attributes.
+
"""
pass
@@ -257,13 +322,22 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
+
"""
pass
def get_xy_extent(self, category: str) -> List[float]:
- """Return the extent of the x, y coordinates for a given category of
- data. The extent should be returned as a list of 4 floats with `[xmin,
- xmax, ymin, ymax]` which can then be used to set the extent of a plot.
+        """Return the extent of the x, y coordinates for a given category of data. The
+        extent should be returned as a list of 4 floats with `[xmin, xmax, ymin, ymax]`
+        which can then be used to set the extent of a plot.
Parameters
----------
@@ -274,6 +348,7 @@ def get_xy_extent(self, category: str) -> List[float]:
-------
List[float]
The extent of the x, y coordinates.
+
"""
xy = self.get_xy(category, stacked=False)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
@@ -282,9 +357,8 @@ def get_xy_extent(self, category: str) -> List[float]:
def unstack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
- """Stack the spatial grid coordinates into separate `x` and `y`
- dimensions (the names can be set by the `CARTESIAN_COORDS` attribute)
- to create a 2D grid.
+ """Stack the spatial grid coordinates into separate `x` and `y` dimensions (the
+ names can be set by the `CARTESIAN_COORDS` attribute) to create a 2D grid.
Parameters
----------
@@ -295,6 +369,7 @@ def unstack_grid_coords(
-------
xr.DataArray or xr.Dataset
The dataarray or dataset with the grid coordinates unstacked.
+
"""
return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack(
"grid_index"
@@ -303,9 +378,8 @@ def unstack_grid_coords(
def stack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
- """Stack the spatial grid coordinated (by default `x` and `y`, but this
- can be set by the `CARTESIAN_COORDS` attribute) into a single
- `grid_index` dimension.
+ """Stack the spatial grid coordinated (by default `x` and `y`, but this can be
+ set by the `CARTESIAN_COORDS` attribute) into a single `grid_index` dimension.
Parameters
----------
@@ -316,5 +390,6 @@ def stack_grid_coords(
-------
xr.DataArray or xr.Dataset
The dataarray or dataset with the grid coordinates stacked.
+
"""
return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index ae2c5d53..0d011e5e 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -19,11 +19,23 @@ class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""
def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
- """Construct a new MLLAMDatastore from the configuration file at
- `config_path`. A boundary mask is created with `n_boundary_points`
- boundary points. If `reuse_existing` is True, the dataset is loaded
- from a zarr file if it exists (unless the config has been modified
- since the zarr was created), otherwise it is created from the
+        """Construct a new MLLAMDatastore from the configuration file at `config_path`.
+        A boundary mask is created with `n_boundary_points` boundary points. If
+        `reuse_existing` is True, the dataset is loaded from a zarr file if it exists
+        (unless the config has been modified since the zarr was created), otherwise it
+        is created from the
configuration file.
Parameters
@@ -37,13 +49,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
reuse_existing : bool
Whether to reuse an existing dataset zarr file if it exists and its
creation date is newer than the configuration file.
+
"""
self._config_path = Path(config_path)
self._root_path = self._config_path.parent
self._config = mdp.Config.from_yaml_file(self._config_path)
- fp_ds = self._root_path / self._config_path.name.replace(
- ".yaml", ".zarr"
- )
+ fp_ds = self._root_path / self._config_path.name.replace(".yaml", ".zarr")
self._ds = None
if reuse_existing and fp_ds.exists():
@@ -71,6 +82,7 @@ def root_path(self) -> Path:
-------
Path
The root path of the dataset.
+
"""
return self._root_path
@@ -82,6 +94,7 @@ def step_length(self) -> int:
-------
int
The length of the time steps in hours.
+
"""
da_dt = self._ds["time"].diff("time")
return (da_dt.dt.seconds[0] // 3600).item()
@@ -98,6 +111,7 @@ def get_vars_units(self, category: str) -> List[str]:
-------
List[str]
The units of the variables in the given category.
+
"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
@@ -116,6 +130,7 @@ def get_vars_names(self, category: str) -> List[str]:
-------
List[str]
The names of the variables in the given category.
+
"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
@@ -134,15 +149,29 @@ def get_num_data_vars(self, category: str) -> int:
-------
int
The number of variables in the given category.
+
"""
return len(self.get_vars_names(category))
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
- """Return the processed data (as a single `xr.DataArray`) for the given
- category of data and test/train/val-split that covers all the data (in
- space and time) of a given category (state/forcing/static). "state" is
- the only required category, for other categories, the method will
- return `None` if the category is not found in the datastore.
+        """Return the processed data (as a single `xr.DataArray`) for the given
+        category of data and test/train/val-split that covers all the data (in space
+        and time) of a given category (state/forcing/static). "state" is the only
+        required category, for other categories, the method will return `None` if the
+        category is not found in the datastore.
The returned dataarray will at minimum have dimensions of `(grid_index,
{category}_feature)` so that any spatial dimensions have been stacked
@@ -169,6 +198,7 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
-------
xr.DataArray or None
The xarray DataArray object with processed dataset.
+
"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
@@ -194,11 +224,24 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
return da_category.sel(time=slice(t_start, t_end))
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This
- should contain a `{category}_mean` and `{category}_std` variable for
- each variable in the category. For `category=="state"`, the dataarray
- should also contain a `state_diff_mean` and `state_diff_std` variable
- for the one-step differences of the state variables.
+        """Return the normalization dataarray for the given category. This should
+        contain a `{category}_mean` and `{category}_std` variable for each variable in
+        the category. For `category=="state"`, the dataarray should also contain a
+        `state_diff_mean` and `state_diff_std` variable for the one-step differences of
+        the state variables.
Parameters
----------
@@ -211,6 +254,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
The normalization dataarray for the given category, with variables
for the mean and standard deviation of the variables (and
differences for state variables).
+
"""
ops = ["mean", "std"]
split = "train"
@@ -227,11 +271,23 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
@property
def boundary_mask(self) -> xr.DataArray:
- """Produce a 0/1 mask for the boundary points of the dataset, these
- will sit at the edges of the domain (in x/y extent) and will be used to
- mask out the boundary points from the loss function and to overwrite
- the boundary points from the prediction. For now this is created when
- the mask is requested, but in the future this could be saved to the
+        """Produce a 0/1 mask for the boundary points of the dataset, these will sit at
+        the edges of the domain (in x/y extent) and will be used to mask out the
+        boundary points from the loss function and to overwrite the boundary points
+        from the prediction. For now this is created when the mask is requested, but in
+        the future this could be saved to the
zarr file.
Returns
@@ -239,19 +295,16 @@ def boundary_mask(self) -> xr.DataArray:
xr.DataArray
A 0/1 mask for the boundary points of the dataset, where 1 is a
boundary point and 0 is not.
+
"""
ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
- da_state_variable = (
- ds_unstacked["state"].isel(time=0).isel(state_feature=0)
- )
+ da_state_variable = ds_unstacked["state"].isel(time=0).isel(state_feature=0)
da_domain_allzero = xr.zeros_like(da_state_variable)
ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
x=slice(self._n_boundary_points, -self._n_boundary_points),
y=slice(self._n_boundary_points, -self._n_boundary_points),
)
- ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
- 1
- ).astype(int)
+ ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1).astype(int)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
@property
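The masking logic above leans on xarray aligning the sliced interior back onto the full domain. A standalone sketch of the same idea, with an illustrative grid size and reindex_like standing in for the Dataset assignment used in the patch:

    import numpy as np
    import xarray as xr

    n_boundary_points = 2
    da_domain_allzero = xr.DataArray(
        np.zeros((8, 10)),
        dims=("y", "x"),
        coords={"y": np.arange(8), "x": np.arange(10)},
    )
    # Slice out the interior, re-align onto the full domain (boundary becomes
    # NaN), then turn the NaNs into 1s.
    da_interior = da_domain_allzero.isel(
        x=slice(n_boundary_points, -n_boundary_points),
        y=slice(n_boundary_points, -n_boundary_points),
    )
    da_mask = da_interior.reindex_like(da_domain_allzero).fillna(1).astype(int)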
@@ -262,6 +315,7 @@ def coords_projection(self) -> ccrs.Projection:
-------
ccrs.Projection
The projection of the coordinates.
+
"""
# TODO: danra doesn't contain projection information yet, but the next
# version will for now we hardcode the projection
@@ -276,6 +330,7 @@ def grid_shape_state(self):
-------
CartesianGridShape
The shape of the cartesian grid for the state variables.
+
"""
ds_state = self.unstack_grid_coords(self._ds["state"])
da_x, da_y = ds_state.x, ds_state.y
@@ -299,6 +354,7 @@ def get_xy(self, category: str, stacked: bool) -> ndarray:
value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
+
"""
# assume variables are stored in dimensions [grid_index, ...]
ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category])
diff --git a/neural_lam/datastore/multizarr/create_boundary_mask.py b/neural_lam/datastore/multizarr/create_boundary_mask.py
index ae154941..31966394 100644
--- a/neural_lam/datastore/multizarr/create_boundary_mask.py
+++ b/neural_lam/datastore/multizarr/create_boundary_mask.py
@@ -21,6 +21,7 @@ def create_boundary_mask(data_config_path, zarr_path, n_boundary_cells):
Data configuration.
zarr_path : str
Path to save the Zarr archive.
+
"""
data_config_path = config.Config.from_file(str(data_config_path))
mask = np.zeros(list(data_config_path.grid_shape_state.values.values()))
diff --git a/neural_lam/datastore/multizarr/create_datetime_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py
index 82a90147..7b645cae 100644
--- a/neural_lam/datastore/multizarr/create_datetime_forcings.py
+++ b/neural_lam/datastore/multizarr/create_datetime_forcings.py
@@ -36,6 +36,7 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
- hour_cos: The cosine of the hour of the day, normalized to [0, 1].
- year_sin: The sine of the time of year, normalized to [0, 1].
- year_cos: The cosine of the time of year, normalized to [0, 1].
+
"""
hours_of_day = xr.DataArray(da_time.dt.hour, dims=["time"])
seconds_into_year = xr.DataArray(
@@ -49,10 +50,7 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
dims=["time"],
)
year_seconds = xr.DataArray(
- [
- get_seconds_in_year(pd.Timestamp(dt_obj).year)
- for dt_obj in da_time.values
- ],
+ [get_seconds_in_year(pd.Timestamp(dt_obj).year) for dt_obj in da_time.values],
dims=["time"],
)
hour_angle = (hours_of_day / 12) * np.pi
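To illustrate the cyclic encoding built from hour_angle, a small self-contained sketch follows; the shift into [0, 1] is an assumption based on the docstring above, not code copied from the patch:

    import numpy as np
    import pandas as pd
    import xarray as xr

    da_time = xr.DataArray(
        pd.date_range("2020-01-01", periods=4, freq="3h"), dims=["time"]
    )
    hours_of_day = xr.DataArray(da_time.dt.hour, dims=["time"])
    hour_angle = (hours_of_day / 12) * np.pi
    # Cyclic encoding of the hour of day, shifted into [0, 1] (assumed).
    hour_sin = (np.sin(hour_angle) + 1) / 2
    hour_cos = (np.cos(hour_angle) + 1) / 2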
@@ -85,6 +83,7 @@ def create_datetime_forcing_zarr(
The time DataArray for which to create the datetime forcing.
chunking : dict, optional
The chunking to use when saving the Zarr archive.
+
"""
if zarr_path is None:
zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
@@ -92,9 +91,9 @@ def create_datetime_forcing_zarr(
datastore = MultiZarrDatastore(config_path=data_config_path)
da_state = datastore.get_dataarray(category="state", split="train")
- da_datetime_forcing = calculate_datetime_forcing(
- da_time=da_state.time
- ).expand_dims({"grid_index": da_state.grid_index})
+ da_datetime_forcing = calculate_datetime_forcing(da_time=da_state.time).expand_dims(
+ {"grid_index": da_state.grid_index}
+ )
if "x" in da_state.coords and "y" in da_state.coords:
# copy the x and y coordinates to the datetime forcing
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index b4cf1be6..7a6df4d2 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -21,8 +21,8 @@ def create_normalization_stats_zarr(
data_config_path: str,
zarr_path: str = None,
):
- """Compute mean and std.-dev. for state and forcing variables and save them
- to a Zarr file.
+ """Compute mean and std.-dev. for state and forcing variables and save them to a
+ Zarr file.
Parameters
----------
@@ -32,6 +32,7 @@ def create_normalization_stats_zarr(
Path to save the normalization statistics to. If not provided, the
statistics are saved to the same directory as the data config file with
the name `normalization.zarr`.
+
"""
if zarr_path is None:
zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
@@ -54,9 +55,7 @@ def create_normalization_stats_zarr(
for group in combined_stats:
vars_to_combine = group["vars"]
- da_forcing_means = da_forcing_mean.sel(
- forcing_feature=vars_to_combine
- )
+ da_forcing_means = da_forcing_mean.sel(forcing_feature=vars_to_combine)
stds = da_forcing_std.sel(forcing_feature=vars_to_combine)
combined_mean = da_forcing_means.mean(dim="forcing_feature")
@@ -65,12 +64,8 @@ def create_normalization_stats_zarr(
da_forcing_mean.loc[
dict(forcing_feature=vars_to_combine)
] = combined_mean
- da_forcing_std.loc[
- dict(forcing_feature=vars_to_combine)
- ] = combined_std
- print(
- "Computing mean and std.-dev. for one-step differences...", flush=True
- )
+ da_forcing_std.loc[dict(forcing_feature=vars_to_combine)] = combined_std
+ print("Computing mean and std.-dev. for one-step differences...", flush=True)
state_data_normalized = (da_state - da_state_mean) / da_state_std
state_data_diff_normalized = state_data_normalized.diff(dim="time")
diff_mean, diff_std = compute_stats(state_data_diff_normalized)
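The one-step-difference statistics computed here can be sketched in isolation with toy arrays; the reduction dimensions below are an assumption standing in for compute_stats, which is defined elsewhere in the module:

    import numpy as np
    import xarray as xr

    da_state = xr.DataArray(
        np.random.rand(10, 4, 3), dims=("time", "grid_index", "state_feature")
    )
    da_state_mean = da_state.mean(dim=("time", "grid_index"))
    da_state_std = da_state.std(dim=("time", "grid_index"))
    # Normalize, take one-step differences along time, then per-feature stats.
    state_data_normalized = (da_state - da_state_mean) / da_state_std
    state_data_diff_normalized = state_data_normalized.diff(dim="time")
    diff_mean = state_data_diff_normalized.mean(dim=("time", "grid_index"))
    diff_std = state_data_diff_normalized.std(dim=("time", "grid_index"))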
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 1a3a2a89..18af8457 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -18,15 +18,25 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}
def __init__(self, config_path):
- """Create a multi-zarr datastore from the given configuration file. The
- configuration file should be a YAML file, the format of which is should
- be inferred from the example configuration file in
- `tests/datastore_examples/multizarr/data_config.yml`.
+ """Create a multi-zarr
+ datastore from the
+ given configuration
+ file. The
+ configuration file
+ should be a YAML file,
+ the format of which is
+ should be inferred
+ from the example
+ configuration file in
+ `tests/datastore_examp
+ les/multizarr/data_con
+ fig.yml`.
Parameters
----------
config_path : str
The path to the configuration file.
+
"""
self._config_path = Path(config_path)
self._root_path = self._config_path.parent
@@ -41,6 +51,7 @@ def root_path(self):
-------
str
The root path of the datastore.
+
"""
return self._root_path
@@ -80,6 +91,7 @@ def open_zarrs(self, category):
-------
xr.Dataset
The xarray Dataset object.
+
"""
zarr_configs = self._config[category]["zarrs"]
@@ -104,6 +116,7 @@ def coords_projection(self):
Returns:
cartopy.crs.Projection: The projection object.
+
"""
proj_config = self._config["projection"]
proj_class_name = proj_config["class"]
@@ -117,6 +130,7 @@ def step_length(self):
Returns:
int: The step length in hours.
+
"""
dataset = self.open_zarrs("state")
time = dataset.time.isel(time=slice(0, 2)).values
@@ -133,6 +147,7 @@ def get_vars_names(self, category):
Returns:
list: The names of the variables in the dataset.
+
"""
surface_vars_names = self._config[category].get("surface_vars") or []
atmosphere_vars_names = [
@@ -151,6 +166,7 @@ def get_vars_units(self, category):
Returns:
list: The units of the variables in the dataset.
+
"""
surface_vars_units = self._config[category].get("surface_units") or []
atmosphere_vars_units = [
@@ -169,14 +185,13 @@ def get_num_data_vars(self, category):
Returns:
int: The number of data variables in the dataset.
+
"""
surface_vars = self._config[category].get("surface_vars", [])
atmosphere_vars = self._config[category].get("atmosphere_vars", [])
levels = self._config[category].get("levels", [])
- surface_vars_count = (
- len(surface_vars) if surface_vars is not None else 0
- )
+ surface_vars_count = len(surface_vars) if surface_vars is not None else 0
atmosphere_vars_count = (
len(atmosphere_vars) if atmosphere_vars is not None else 0
)
@@ -192,6 +207,7 @@ def _stack_grid(self, ds):
Returns:
xr.Dataset: The xarray Dataset object with stacked grid dimensions.
+
"""
if "grid_index" in ds.dims:
raise ValueError("Grid dimensions already stacked.")
@@ -212,6 +228,7 @@ def _convert_dataset_to_dataarray(self, dataset):
Returns:
xr.DataArray: The xarray DataArray object.
+
"""
if isinstance(dataset, xr.Dataset):
dataset = dataset.to_array(dim="variable_name")
@@ -227,6 +244,7 @@ def _filter_dimensions(self, dataset, transpose_array=True):
Returns:
xr.Dataset: The xarray Dataset object with filtered dimensions.
OR xr.DataArray: The xarray DataArray object with filtered dimensions.
+
"""
dims_to_keep = self.DIMS_TO_KEEP
dataset_dims = set(list(dataset.dims) + ["variable_name"])
@@ -277,9 +295,7 @@ def _filter_dimensions(self, dataset, transpose_array=True):
dataset = self._convert_dataset_to_dataarray(dataset)
if "time" in dataset.dims:
- dataset = dataset.transpose(
- "time", "grid_index", "variable_name"
- )
+ dataset = dataset.transpose("time", "grid_index", "variable_name")
else:
dataset = dataset.transpose("grid_index", "variable_name")
dataset_vars = (
@@ -304,6 +320,7 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
Returns:
xr.Dataset: The xarray Dataset object with reshaped grid dimensions.
+
"""
if grid_shape is None:
grid_shape = dict(self.grid_shape_state.values.items())
@@ -311,13 +328,9 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
x_coords = np.arange(x_dim)
y_coords = np.arange(y_dim)
- multi_index = pd.MultiIndex.from_product(
- [y_coords, x_coords], names=["y", "x"]
- )
+ multi_index = pd.MultiIndex.from_product([y_coords, x_coords], names=["y", "x"])
- mindex_coords = xr.Coordinates.from_pandas_multiindex(
- multi_index, "grid"
- )
+ mindex_coords = xr.Coordinates.from_pandas_multiindex(multi_index, "grid")
dataset = dataset.drop_vars(["grid", "x", "y"], errors="ignore")
dataset = dataset.assign_coords(mindex_coords)
reshaped_data = dataset.unstack("grid")
@@ -342,13 +355,12 @@ def get_xy(self, category, stacked=True):
value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
+
"""
dataset = self.open_zarrs(category)
xs, ys = dataset.x.values, dataset.y.values
- assert (
- xs.ndim == ys.ndim
- ), "x and y coordinates must have the same dimensions."
+ assert xs.ndim == ys.ndim, "x and y coordinates must have the same dimensions."
if xs.ndim == 1:
x, y = np.meshgrid(xs, ys)
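The shape contract documented for get_xy can be illustrated with plain numpy; the stacking step here is an illustration of the documented shapes, not code from the patch:

    import numpy as np

    xs, ys = np.arange(5), np.arange(3)       # 1-D coordinate vectors
    x, y = np.meshgrid(xs, ys)                # each has shape (N_y, N_x) = (3, 5)
    xy_unstacked = np.stack((x, y), axis=0)   # (2, N_y, N_x)
    xy_stacked = xy_unstacked.reshape(2, -1)  # (2, n_grid_points) = (2, 15)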
@@ -366,14 +378,33 @@ def get_xy(self, category, stacked=True):
@functools.lru_cache()
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This
- should contain a `{category}_mean` and `{category}_std` variable for
- each variable in the category. For `category=="state"`, the dataarray
- should also contain a `state_diff_mean` and `state_diff_std` variable
- for the one-step differences of the state variables. The return
- dataarray should at least have dimensions of `({category}_feature)`,
- but can also include for example `grid_index` (if the normalisation is
- done per grid point for example).
+ """Return the
+ normalization
+ dataarray for the
+ given category. This
+ should contain a
+ `{category}_mean` and
+ `{category}_std`
+ variable for each
+ variable in the
+ category. For
+ `category=="state"`,
+ the dataarray should
+ also contain a
+ `state_diff_mean` and
+ `state_diff_std`
+ variable for the one-
+ step differences of
+ the state variables.
+ The return dataarray
+ should at least have
+ dimensions of `({categ
+ ory}_feature)`, but
+ can also include for
+ example `grid_index`
+ (if the normalisation
+ is done per grid point
+ for example).
Parameters
----------
@@ -386,6 +417,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
The normalization dataarray for the given category, with variables
for the mean and standard deviation of the variables (and
differences for state variables).
+
"""
# XXX: the multizarr code didn't include routines for computing the
# normalization of "static" features previously, we'll just hack
@@ -423,6 +455,7 @@ def _load_and_merge_stats(self):
Returns:
xr.Dataset: The merged normalization statistics for the dataset.
+
"""
combined_stats = None
for i, zarr_config in enumerate(
@@ -449,6 +482,7 @@ def _rename_data_vars(self, combined_stats):
Returns:
xr.Dataset: The combined normalization statistics with renamed data
variables.
+
"""
vars_mapping = {}
for zarr_config in self._config["utilities"]["normalization"]["zarrs"]:
@@ -471,6 +505,7 @@ def _select_stats_by_category(self, combined_stats, category):
Returns:
xr.Dataset: The normalization statistics for the dataset.
+
"""
if category == "state":
stats = combined_stats.loc[
@@ -479,9 +514,7 @@ def _select_stats_by_category(self, combined_stats, category):
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
return stats
elif category == "forcing":
- non_normalized_vars = (
- self.utilities.normalization.non_normalized_vars
- )
+ non_normalized_vars = self.utilities.normalization.non_normalized_vars
if non_normalized_vars is None:
non_normalized_vars = []
forcing_vars = self.vars_names(category)
@@ -517,6 +550,7 @@ def _extract_vars(self, category, ds=None):
Returns:
xr.Dataset: The xarray Dataset object with extracted variables.
+
"""
if ds is None:
ds = self.open_zarrs(category)
@@ -529,9 +563,7 @@ def _extract_vars(self, category, ds=None):
ds_atmosphere = None
if atmoshere_vars is not None:
- ds_atmosphere = self._extract_atmosphere_vars(
- category=category, ds=ds
- )
+ ds_atmosphere = self._extract_atmosphere_vars(category=category, ds=ds)
if ds_surface and ds_atmosphere:
return xr.merge([ds_surface, ds_atmosphere])
@@ -551,15 +583,11 @@ def _extract_atmosphere_vars(self, category, ds):
Returns:
xr.Dataset: The xarray Dataset object with atmosphere variables.
+
"""
- if (
- "level" not in list(ds.dims)
- and self._config[category]["atmosphere_vars"]
- ):
- ds = self._rename_dataset_dims_and_vars(
- ds.attrs["category"], dataset=ds
- )
+ if "level" not in list(ds.dims) and self._config[category]["atmosphere_vars"]:
+ ds = self._rename_dataset_dims_and_vars(ds.attrs["category"], dataset=ds)
data_arrays = [
ds[var].sel(level=level, drop=True).rename(f"{var}_{level}")
@@ -585,6 +613,7 @@ def _rename_dataset_dims_and_vars(self, category, dataset=None):
variables.
OR xr.DataArray: The xarray DataArray object with renamed
dimensions and variables.
+
"""
convert = False
if dataset is None:
@@ -620,6 +649,7 @@ def _apply_time_split(self, dataset, split="train"):
Returns:
xr.Dataset: The xarray Dataset object filtered by the time split.
+
"""
start, end = (
self._config["splits"][split]["start"],
@@ -635,6 +665,7 @@ def grid_shape_state(self):
Returns:
CartesianGridShape: The shape of the state grid.
+
"""
return CartesianGridShape(
x=self._config["grid_shape_state"]["x"],
@@ -643,13 +674,13 @@ def grid_shape_state(self):
@property
def boundary_mask(self) -> xr.DataArray:
- """Load the boundary mask for the dataset, with spatial dimensions
- stacked.
+ """Load the boundary mask for the dataset, with spatial dimensions stacked.
Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `('grid_index',)`.
+
"""
boundary_mask_path = self._normalize_path(
self._config["boundary"]["mask"]["path"]
@@ -670,6 +701,7 @@ def get_dataarray(self, category, split="train"):
Returns:
xr.DataArray: The xarray DataArray object with processed dataset.
+
"""
dataset = self.open_zarrs(category)
dataset = self._extract_vars(category, dataset)
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index afb08c77..5cdb22ea 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -8,14 +8,14 @@
@dataclass
class Projection:
- """Represents the projection information for a dataset, including the type
- of projection and its parameters. Capable of creating a cartopy.crs
- projection object.
+ """Represents the projection information for a dataset, including the type of
+ projection and its parameters. Capable of creating a cartopy.crs projection object.
Attributes:
class_name: The class name of the projection, this should be a valid
cartopy.crs class.
kwargs: A dictionary of keyword arguments specific to the projection type.
+
"""
class_name: str
@@ -24,8 +24,8 @@ class Projection:
@dataclass
class Dataset:
- """Contains information about the dataset, including variable names, units,
- and descriptions.
+ """Contains information about the dataset, including variable names, units, and
+ descriptions.
Attributes:
name: The name of the dataset.
@@ -33,6 +33,7 @@ class Dataset:
var_units: A list of units for each variable.
var_longnames: A list of long, descriptive names for each variable.
num_forcing_features: The number of forcing features in the dataset.
+
"""
name: str
@@ -44,13 +45,14 @@ class Dataset:
@dataclass
class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
- """Configuration for loading and processing a dataset, including dataset
- details, grid shape, and projection information.
+ """Configuration for loading and processing a dataset, including dataset details,
+ grid shape, and projection information.
Attributes:
dataset: An instance of Dataset containing details about the dataset.
grid_shape_state: A list representing the shape of the grid state.
projection: An instance of Projection containing projection details.
+
"""
dataset: Dataset
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 674c368d..630a8dd0 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -1,5 +1,5 @@
-"""Numpy-files based datastore to support the MEPS example dataset introduced
-in neural-lam v0.1.0."""
+"""Numpy-files based datastore to support the MEPS example dataset introduced in neural-
+lam v0.1.0."""
# Standard library
import functools
import re
@@ -138,9 +138,17 @@ def __init__(
self,
config_path,
):
- """Create a new NpyFilesDatastore using the configuration file at the
- given path. The config file should be a YAML file and will be loaded
- into an instance of the `NpyDatastoreConfig` dataclass.
+ """Create a new
+ NpyFilesDatastore
+ using the
+ configuration file at
+ the given path. The
+ config file should be
+ a YAML file and will
+ be loaded into an
+ instance of the
+ `NpyDatastoreConfig`
+ dataclass.
Internally, the datastore uses dask.delayed to load the data from the
numpy files, so that the data isn't actually loaded until it's needed.
@@ -149,6 +157,7 @@ def __init__(
----------
config_path : str
The path to the configuration file for the datastore.
+
"""
# XXX: This should really be in the config file, not hard-coded in this class
self._num_timesteps = 65
@@ -161,21 +170,32 @@ def __init__(
@property
def root_path(self) -> Path:
- """The root path of the datastore on disk. This is the directory
- relative to which graphs and other files can be stored.
+ """The root path of the datastore on disk. This is the directory relative to
+ which graphs and other files can be stored.
Returns
-------
Path
The root path of the datastore
+
"""
return self._root_path
def get_dataarray(self, category: str, split: str) -> DataArray:
- """Get the data array for the given category and split of data. If the
- category is 'state', the data array will be a concatenation of the data
- arrays for all ensemble members. The data will be loaded as a dask
- array, so that the data isn't actually loaded until it's needed.
+ """Get the data array
+ for the given category
+ and split of data. If
+ the category is
+ 'state', the data
+ array will be a
+ concatenation of the
+ data arrays for all
+ ensemble members. The
+ data will be loaded as
+ a dask array, so that
+ the data isn't
+ actually loaded until
+ it's needed.
Parameters
----------
@@ -193,6 +213,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
ensemble_member]`
forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]`
static: `[grid_index, feature]`
+
"""
if category == "state":
das = []
@@ -211,9 +232,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# them separately
features = ["toa_downwelling_shortwave_flux", "column_water"]
das = [
- self._get_single_timeseries_dataarray(
- features=[feature], split=split
- )
+ self._get_single_timeseries_dataarray(features=[feature], split=split)
for feature in features
]
da = xr.concat(das, dim="feature")
@@ -225,9 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# .chunk({"elapsed_forecast_duration": 1}) this time variable is turned
# into a dask array and so execution of the calculation is delayed
# until the feature values are actually used.
- da_forecast_time = (
- da.analysis_time + da.elapsed_forecast_duration
- ).chunk({"elapsed_forecast_duration": 1})
+ da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk(
+ {"elapsed_forecast_duration": 1}
+ )
da_datetime_forcing_features = self._calc_datetime_forcing_features(
da_time=da_forecast_time
)
@@ -248,9 +267,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
features=features, split=split
)
das.append(da)
- da = xr.concat(das, dim="feature").transpose(
- "grid_index", "feature"
- )
+ da = xr.concat(das, dim="feature").transpose("grid_index", "feature")
else:
raise NotImplementedError(category)
@@ -270,11 +287,21 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
def _get_single_timeseries_dataarray(
self, features: List[str], split: str, member: int = None
) -> DataArray:
- """Get the data array spanning the complete time series for a given set
- of features and split of data. For state features the `member` argument
- should be specified to select the ensemble member to load. The data
- will be loaded using dask.delayed, so that the data isn't actually
- loaded until it's needed.
+ """Get the data array
+ spanning the complete
+ time series for a
+ given set of features
+ and split of data. For
+ state features the
+ `member` argument
+ should be specified to
+ select the ensemble
+ member to load. The
+ data will be loaded
+ using dask.delayed, so
+ that the data isn't
+ actually loaded until
+ it's needed.
Parameters
----------
@@ -296,15 +323,12 @@ def _get_single_timeseries_dataarray(
The data array for the given category and split, with dimensions
`[elapsed_forecast_duration, analysis_time, grid_index, feature]` for
all categories of data
+
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
- if member is not None and features != self.get_vars_names(
- category="state"
- ):
- raise ValueError(
- "Member can only be specified for the 'state' category"
- )
+ if member is not None and features != self.get_vars_names(category="state"):
+ raise ValueError("Member can only be specified for the 'state' category")
# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
@@ -387,9 +411,7 @@ def _get_single_timeseries_dataarray(
if features_vary_with_analysis_time:
filepaths = [
fp_samples
- / filename_format.format(
- analysis_time=analysis_time, **file_params
- )
+ / filename_format.format(analysis_time=analysis_time, **file_params)
for analysis_time in coords["analysis_time"]
]
else:
@@ -425,8 +447,8 @@ def _get_single_timeseries_dataarray(
return da
def _get_analysis_times(self, split) -> List[np.datetime64]:
- """Get the analysis times for the given split by parsing the filenames
- of all the files found for the given split.
+ """Get the analysis times for the given split by parsing the filenames of all
+ the files found for the given split.
Parameters
----------
@@ -437,6 +459,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]:
-------
List[dt.datetime]
The analysis times for the given split.
+
"""
pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
@@ -449,9 +472,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]:
times.append(name_parts["analysis_time"])
if len(times) == 0:
- raise ValueError(
- f"No files found in {sample_dir} with pattern {pattern}"
- )
+ raise ValueError(f"No files found in {sample_dir} with pattern {pattern}")
return times
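The pattern substitution above turns a Python format string into a filesystem glob. A sketch with a hypothetical filename format (the real STATE_FILENAME_FORMAT is defined elsewhere in the module):

    import re

    STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
    pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
    pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
    print(pattern)  # nwp_*_mbr*.npy -- usable with glob/pathlib globbing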
@@ -534,6 +555,7 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
+
"""
# the array on disk has shape [2, N_x, N_y], but we want to return it
@@ -557,6 +579,7 @@ def step_length(self) -> int:
-------
int
The length of each time step in hours.
+
"""
return self._step_length
@@ -568,19 +591,21 @@ def grid_shape_state(self) -> CartesianGridShape:
-------
CartesianGridShape
The shape of the cartesian grid for the state variables.
+
"""
nx, ny = self.config.grid_shape_state
return CartesianGridShape(x=nx, y=ny)
@property
def boundary_mask(self) -> xr.DataArray:
- """The boundary mask for the dataset. This is a binary mask that is 1
- where the grid cell is on the boundary of the domain, and 0 otherwise.
+ """The boundary mask for the dataset. This is a binary mask that is 1 where the
+ grid cell is on the boundary of the domain, and 0 otherwise.
Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `[grid_index]`.
+
"""
xs, ys = self.get_xy(category="state", stacked=False)
assert np.all(xs[:, 0] == xs[:, -1])
@@ -595,11 +620,24 @@ def boundary_mask(self) -> xr.DataArray:
return da_mask_stacked_xy
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This
- should contain a `{category}_mean` and `{category}_std` variable for
- each variable in the category. For `category=="state"`, the dataarray
- should also contain a `state_diff_mean` and `state_diff_std` variable
- for the one-step differences of the state variables.
+ """Return the
+ normalization
+ dataarray for the
+ given category. This
+ should contain a
+ `{category}_mean` and
+ `{category}_std`
+ variable for each
+ variable in the
+ category. For
+ `category=="state"`,
+ the dataarray should
+ also contain a
+ `state_diff_mean` and
+ `state_diff_std`
+ variable for the one-
+ step differences of
+ the state variables.
Parameters
----------
@@ -612,6 +650,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
The normalization dataarray for the given category, with variables
for the mean and standard deviation of the variables (and
differences for state variables).
+
"""
def load_pickled_tensor(fn):
@@ -666,6 +705,7 @@ def coords_projection(self) -> ccrs.Projection:
-------
ccrs.Projection
The projection of the spatial coordinates.
+
"""
proj_class_name = self.config.projection.class_name
ProjectionClass = getattr(ccrs, proj_class_name)
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index 4ed3e3eb..5ad0fdca 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -11,6 +11,7 @@ class InteractionNet(pyg.nn.MessagePassing):
"""Implementation of a generic Interaction Network, from Battaglia et al.
(2016)
+
"""
# pylint: disable=arguments-differ
@@ -43,6 +44,7 @@ def __init__(
representation into and use separate MLPs for
(None = no chunking, same MLP)
aggr: Message aggregation method (sum/mean)
+
"""
assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}"
super().__init__(aggr=aggr)
@@ -55,9 +57,7 @@ def __init__(
edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0]
# Store number of receiver nodes according to edge_index
self.num_rec = edge_index[1].max() + 1
- edge_index[0] = (
- edge_index[0] + self.num_rec
- ) # Make sender indices after rec
+ edge_index[0] = edge_index[0] + self.num_rec # Make sender indices after rec
self.register_buffer("edge_index", edge_index, persistent=False)
# Create MLPs
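The sender-index shift being reformatted here is easiest to see on a toy edge_index (illustrative tensors; receivers keep indices below num_rec, senders are moved to start at num_rec):

    import torch

    edge_index = torch.tensor([[0, 1, 2],    # sender node indices
                               [0, 0, 1]])   # receiver node indices
    edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0]
    num_rec = edge_index[1].max() + 1        # 2 receiver nodes
    edge_index[0] = edge_index[0] + num_rec  # senders become 2, 3, 4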
@@ -83,8 +83,8 @@ def __init__(
self.update_edges = update_edges
def forward(self, send_rep, rec_rep, edge_rep):
- """Apply interaction network to update the representations of receiver
- nodes, and optionally the edge representations.
+ """Apply interaction network to update the representations of receiver nodes,
+ and optionally the edge representations.
send_rep: (N_send, d_h), vector representations of sender nodes
rec_rep: (N_rec, d_h), vector representations of receiver nodes
@@ -94,6 +94,7 @@ def forward(self, send_rep, rec_rep, edge_rep):
rec_rep: (N_rec, d_h), updated vector representations of receiver nodes
(optionally) edge_rep: (M, d_h), updated vector representations
of edges
+
"""
# Always concatenate to [rec_nodes, send_nodes] for propagation,
# but only aggregate to rec_nodes
@@ -130,8 +131,11 @@ def aggregate(self, inputs, index, ptr, dim_size):
class SplitMLPs(nn.Module):
"""Module that feeds chunks of input through different MLPs.
- Split up input along dim -2 using given chunk sizes and feeds each
- chunk through separate MLPs.
+ Split up input along dim -2 using given chunk sizes and feeds each
+ chunk through separate MLPs.
+
"""
def __init__(self, mlps, chunk_sizes):
@@ -150,6 +154,7 @@ def forward(self, x):
Returns:
joined_output: (..., N, d), concatenated results from the MLPs
+
"""
chunks = torch.split(x, self.chunk_sizes, dim=-2)
chunk_outputs = [
diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py
index 1ed4fb08..324440a8 100644
--- a/neural_lam/metrics.py
+++ b/neural_lam/metrics.py
@@ -9,11 +9,10 @@ def get_metric(metric_name):
Returns:
metric: function implementing the metric
+
"""
metric_name_lower = metric_name.lower()
- assert (
- metric_name_lower in DEFINED_METRICS
- ), f"Unknown metric: {metric_name}"
+ assert metric_name_lower in DEFINED_METRICS, f"Unknown metric: {metric_name}"
return DEFINED_METRICS[metric_name_lower]
@@ -31,22 +30,17 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
# Only keep grid nodes in mask
if mask is not None:
- metric_entry_vals = metric_entry_vals[
- ..., mask, :
- ] # (..., N', d_state)
+ metric_entry_vals = metric_entry_vals[..., mask, :] # (..., N', d_state)
# Optionally reduce last two dimensions
if average_grid: # Reduce grid first
- metric_entry_vals = torch.mean(
- metric_entry_vals, dim=-2
- ) # (..., d_state)
+ metric_entry_vals = torch.mean(metric_entry_vals, dim=-2) # (..., d_state)
if sum_vars: # Reduce vars second
- metric_entry_vals = torch.sum(
- metric_entry_vals, dim=-1
- ) # (..., N) or (...,)
+ metric_entry_vals = torch.sum(metric_entry_vals, dim=-1) # (..., N) or (...,)
return metric_entry_vals
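The masking and reduction order in mask_and_reduce_metric can be traced with toy tensors (a sketch only):

    import torch

    metric_entry_vals = torch.rand(4, 6, 3)  # (..., N, d_state)
    mask = torch.tensor([True, False, True, True, False, True])
    vals = metric_entry_vals[..., mask, :]   # keep masked grid nodes: (..., N', d_state)
    vals = torch.mean(vals, dim=-2)          # average over grid: (..., d_state)
    metric_val = torch.sum(vals, dim=-1)     # sum over variables: (...,)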
@@ -67,6 +61,7 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
entry_mse = torch.nn.functional.mse_loss(
pred, target, reduction="none"
@@ -97,11 +92,10 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
# Replace pred_std with constant ones
- return wmse(
- pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars
- )
+ return wmse(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars)
def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
@@ -120,6 +114,7 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
entry_mae = torch.nn.functional.l1_loss(
pred, target, reduction="none"
@@ -150,11 +145,10 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
# Replace pred_std with constant ones
- return wmae(
- pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars
- )
+ return wmae(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars)
def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
@@ -173,6 +167,7 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
# Broadcast pred_std if shaped (d_state,), done internally in Normal class
dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state)
@@ -183,11 +178,9 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)
-def crps_gauss(
- pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
-):
- """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form
- expression based on Gaussian predictive distribution.
+def crps_gauss(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
+ """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form expression
+ based on Gaussian predictive distribution.
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -202,6 +195,7 @@ def crps_gauss(
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
+
"""
std_normal = torch.distributions.Normal(
torch.zeros((), device=pred.device), torch.ones((), device=pred.device)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index cea723b0..eadd9445 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -17,14 +17,13 @@ class ARModel(pl.LightningModule):
"""Generic auto-regressive weather model.
Abstract class that can be extended.
+
"""
# pylint: disable=arguments-differ
# Disable to override args/kwargs from superclass
- def __init__(
- self, args, datastore: BaseDatastore, forcing_window_size: int
- ):
+ def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int):
super().__init__()
self.save_hyperparameters(ignore=["datastore"])
self.args = args
@@ -33,17 +32,13 @@ def __init__(
split = "train"
num_state_vars = datastore.get_num_data_vars(category="state")
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
- da_static_features = datastore.get_dataarray(
- category="static", split=split
- )
+ da_static_features = datastore.get_dataarray(category="static", split=split)
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask
# Load static features for grid/data, NB: self.predict_step assumes dimension
# order to be (grid_index, static_feature)
- arr_static = da_static_features.transpose(
- "grid_index", "static_feature"
- ).values
+ arr_static = da_static_features.transpose("grid_index", "static_feature").values
self.register_buffer(
"grid_static_features",
torch.tensor(arr_static, dtype=torch.float32),
@@ -136,9 +131,7 @@ def __init__(
self.spatial_loss_maps = []
def configure_optimizers(self):
- opt = torch.optim.AdamW(
- self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
- )
+ opt = torch.optim.AdamW(self.parameters(), lr=self.args.lr, betas=(0.9, 0.95))
return opt
@property
@@ -185,8 +178,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
# Overwrite border with true state
new_state = (
- self.boundary_mask * border_state
- + self.interior_mask * pred_state
+ self.boundary_mask * border_state + self.interior_mask * pred_state
)
prediction_list.append(new_state)
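The boundary overwrite reformatted above blends the prediction with the true border state per grid node. A toy sketch of the same arithmetic (shapes illustrative):

    import torch

    boundary_mask = torch.tensor([1.0, 0.0, 0.0, 1.0]).unsqueeze(-1)  # (N_grid, 1)
    interior_mask = 1.0 - boundary_mask
    pred_state = torch.rand(4, 2)    # (N_grid, d_state)
    border_state = torch.rand(4, 2)
    # Boundary nodes take the true border state, interior nodes keep the prediction.
    new_state = boundary_mask * border_state + interior_mask * pred_state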
@@ -231,9 +223,7 @@ def training_step(self, batch):
# Compute loss
batch_loss = torch.mean(
- self.loss(
- prediction, target, pred_std, mask=self.interior_mask_bool
- )
+ self.loss(prediction, target, pred_std, mask=self.interior_mask_bool)
) # mean over unrolled times and batch
log_dict = {"train_loss": batch_loss}
@@ -248,12 +238,13 @@ def training_step(self, batch):
return batch_loss
def all_gather_cat(self, tensor_to_gather):
- """Gather tensors across all ranks, and concatenate across dim. 0
- (instead of stacking in new dim. 0)
+ """Gather tensors across all ranks, and concatenate across dim. 0 (instead of
+ stacking in new dim. 0)
tensor_to_gather: (d1, d2, ...), distributed over K ranks
returns: (K*d1, d2, ...)
+
"""
return self.all_gather(tensor_to_gather).flatten(0, 1)
@@ -264,9 +255,7 @@ def validation_step(self, batch, batch_idx):
prediction, target, pred_std, _ = self.common_step(batch)
time_step_loss = torch.mean(
- self.loss(
- prediction, target, pred_std, mask=self.interior_mask_bool
- ),
+ self.loss(prediction, target, pred_std, mask=self.interior_mask_bool),
dim=0,
) # (time_steps-1)
mean_loss = torch.mean(time_step_loss)
@@ -314,9 +303,7 @@ def test_step(self, batch, batch_idx):
# pred_steps, num_grid_nodes, d_f) or (d_f,)
time_step_loss = torch.mean(
- self.loss(
- prediction, target, pred_std, mask=self.interior_mask_bool
- ),
+ self.loss(prediction, target, pred_std, mask=self.interior_mask_bool),
dim=0,
) # (time_steps-1,)
mean_loss = torch.mean(time_step_loss)
@@ -368,19 +355,14 @@ def test_step(self, batch, batch_idx):
# (B, N_log, num_grid_nodes)
# Plot example predictions (on rank 0 only)
- if (
- self.trainer.is_global_zero
- and self.plotted_examples < self.n_example_pred
- ):
+ if self.trainer.is_global_zero and self.plotted_examples < self.n_example_pred:
# Need to plot more example predictions
n_additional_examples = min(
prediction.shape[0],
self.n_example_pred - self.plotted_examples,
)
- self.plot_examples(
- batch, n_additional_examples, prediction=prediction
- )
+ self.plot_examples(batch, n_additional_examples, prediction=prediction)
def plot_examples(self, batch, n_examples, prediction=None):
"""Plot the first n_examples forecasts from batch.
@@ -389,6 +371,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes,
d_f), existing prediction.
Generate if None.
+
"""
if prediction is None:
prediction, target, _, _ = self.common_step(batch)
@@ -457,16 +440,12 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
}
)
- plt.close(
- "all"
- ) # Close all figs for this time step, saves memory
+ plt.close("all") # Close all figs for this time step, saves memory
# Save pred and target as .pt files
torch.save(
pred_slice.cpu(),
- os.path.join(
- wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
- ),
+ os.path.join(wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"),
)
torch.save(
target_slice.cpu(),
@@ -476,14 +455,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
- """Put together a dict with everything to log for one metric. Also
- saves plots as pdf and csv if using test prefix.
+ """Put together a dict with everything to log for one metric. Also saves plots
+ as pdf and csv if using test prefix.
metric_tensor: (pred_steps, d_f), metric values per time and variable
prefix: string, prefix to use for logging metric_name: string, name of
the metric
Return: log_dict: dict with everything to log for given metric
+
"""
log_dict = {}
metric_fig = vis.plot_error_map(
@@ -496,9 +476,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
if prefix == "test":
# Save pdf
- metric_fig.savefig(
- os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
- )
+ metric_fig.savefig(os.path.join(wandb.run.dir, f"{full_log_name}.pdf"))
# Save errors also as csv
np.savetxt(
os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
@@ -522,12 +500,12 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
return log_dict
def aggregate_and_plot_metrics(self, metrics_dict, prefix):
- """Aggregate and create error map plots for all metrics in
- metrics_dict.
+ """Aggregate and create error map plots for all metrics in metrics_dict.
metrics_dict: dictionary with metric_names and list of tensors
with step-evals.
prefix: string, prefix to use for logging
+
"""
log_dict = {}
for metric_name, metric_val_list in metrics_dict.items():
@@ -548,9 +526,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_rescaled = metric_tensor_averaged * self.state_std
# (pred_steps, d_f)
log_dict.update(
- self.create_metric_log_dict(
- metric_rescaled, prefix, metric_name
- )
+ self.create_metric_log_dict(metric_rescaled, prefix, metric_name)
)
if self.trainer.is_global_zero and not self.trainer.sanity_checking:
@@ -560,8 +536,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
def on_test_epoch_end(self):
"""Compute test metrics and make plots at the end of test epoch.
- Will gather stored tensors and perform plotting and logging on
- rank 0.
+ Will gather stored tensors and perform plotting and logging on rank 0.
+
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
@@ -582,9 +558,7 @@ def on_test_epoch_end(self):
self.data_config,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
- for t_i, loss_map in zip(
- self.args.val_steps_to_log, mean_spatial_loss
- )
+ for t_i, loss_map in zip(self.args.val_steps_to_log, mean_spatial_loss)
]
# log all to same wandb key, sequentially
@@ -624,9 +598,7 @@ def on_load_checkpoint(self, checkpoint):
)
)
for old_key in replace_keys:
- new_key = old_key.replace(
- "g2m_gnn.grid_mlp", "encoding_grid_mlp"
- )
+ new_key = old_key.replace("g2m_gnn.grid_mlp", "encoding_grid_mlp")
loaded_state_dict[new_key] = loaded_state_dict[old_key]
del loaded_state_dict[old_key]
if not self.restore_opt:
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index a76fc518..158275dd 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -8,8 +8,8 @@
class BaseGraphModel(ARModel):
- """Base (abstract) class for graph-based models building on the encode-
- process- decode idea."""
+ """Base (abstract) class for graph-based models building on the encode- process-
+ decode idea."""
def __init__(self, args, datastore, forcing_window_size):
super().__init__(
@@ -20,9 +20,7 @@ def __init__(self, args, datastore, forcing_window_size):
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
graph_dir_path = datastore.root_path / "graph" / args.graph
- self.hierarchical, graph_ldict = utils.load_graph(
- graph_dir_path=graph_dir_path
- )
+ self.hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
if isinstance(attr_value, torch.Tensor):
@@ -44,9 +42,7 @@ def __init__(self, args, datastore, forcing_window_size):
# Define sub-models
# Feature embedders for grid
self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
- self.grid_embedder = utils.make_mlp(
- [self.grid_dim] + self.mlp_blueprint_end
- )
+ self.grid_embedder = utils.make_mlp([self.grid_dim] + self.mlp_blueprint_end)
self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)
@@ -72,27 +68,26 @@ def __init__(self, args, datastore, forcing_window_size):
# Output mapping (hidden_dim -> output_dim)
self.output_map = utils.make_mlp(
- [args.hidden_dim] * (args.hidden_layers + 1)
- + [self.grid_output_dim],
+ [args.hidden_dim] * (args.hidden_layers + 1) + [self.grid_output_dim],
layer_norm=False,
) # No layer norm on this one
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of
- mesh nodes that should be ignored in encoding/decoding."""
+ """Compute number of mesh nodes from loaded features, and number of mesh nodes
+ that should be ignored in encoding/decoding."""
raise NotImplementedError("get_num_mesh not implemented")
def embedd_mesh_nodes(self):
- """Embed static mesh features Returns tensor of shape (num_mesh_nodes,
- d_h)"""
+ """Embed static mesh features Returns tensor of shape (num_mesh_nodes, d_h)"""
raise NotImplementedError("embedd_mesh_nodes not implemented")
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the
- representation on the mesh, possible in multiple steps.
+ """Process step of embedd-process-decode framework Processes the representation
+ on the mesh, possible in multiple steps.
mesh_rep: has shape (B, num_mesh_nodes, d_h)
Returns mesh_rep: (B, num_mesh_nodes, d_h)
+
"""
raise NotImplementedError("process_step not implemented")
@@ -147,9 +142,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
) # (B, num_grid_nodes, d_h)
# Map to output dimension, only for grid
- net_output = self.output_map(
- grid_rep
- ) # (B, num_grid_nodes, d_grid_out)
+ net_output = self.output_map(grid_rep) # (B, num_grid_nodes, d_grid_out)
if self.output_std:
pred_delta_mean, pred_std_raw = net_output.chunk(
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 14827f25..8bfc2c3e 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -96,28 +96,34 @@ def __init__(self, args):
)
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of
- mesh nodes that should be ignored in encoding/decoding."""
+ """Compute number of mesh nodes from loaded features, and number of mesh nodes
+ that should be ignored in encoding/decoding."""
num_mesh_nodes = sum(
node_feat.shape[0] for node_feat in self.mesh_static_features
)
- num_mesh_nodes_ignore = (
- num_mesh_nodes - self.mesh_static_features[0].shape[0]
- )
+ num_mesh_nodes_ignore = num_mesh_nodes - self.mesh_static_features[0].shape[0]
return num_mesh_nodes, num_mesh_nodes_ignore
def embedd_mesh_nodes(self):
- """Embed static mesh features This embeds only bottom level, rest is
- done at beginning of processing step Returns tensor of shape
- (num_mesh_nodes[0], d_h)"""
+ """Embed static mesh
+ features This embeds
+ only bottom level,
+ rest is done at
+ beginning of
+ processing step
+ Returns tensor of
+ shape
+ (num_mesh_nodes[0],
+ d_h)"""
return self.mesh_embedders[0](self.mesh_static_features[0])
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the
- representation on the mesh, possible in multiple steps.
+ """Process step of embedd-process-decode framework Processes the representation
+ on the mesh, possible in multiple steps.
mesh_rep: has shape (B, num_mesh_nodes, d_h)
Returns mesh_rep: (B, num_mesh_nodes, d_h)
+
"""
batch_size = mesh_rep.shape[0]
@@ -136,21 +142,15 @@ def process_step(self, mesh_rep):
# Embed edges, expand with batch dimension
mesh_same_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(
- self.mesh_same_embedders, self.m2m_features
- )
+ for emb, edge_feat in zip(self.mesh_same_embedders, self.m2m_features)
]
mesh_up_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(
- self.mesh_up_embedders, self.mesh_up_features
- )
+ for emb, edge_feat in zip(self.mesh_up_embedders, self.mesh_up_features)
]
mesh_down_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(
- self.mesh_down_embedders, self.mesh_down_features
- )
+ for emb, edge_feat in zip(self.mesh_down_embedders, self.mesh_down_features)
]
# - MESH INIT. -
@@ -160,20 +160,14 @@ def process_step(self, mesh_rep):
send_node_rep = mesh_rep_levels[
level_l - 1
] # (B, num_mesh_nodes[l-1], d_h)
- rec_node_rep = mesh_rep_levels[
- level_l
- ] # (B, num_mesh_nodes[l], d_h)
+ rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h)
edge_rep = mesh_up_rep[level_l - 1]
# Apply GNN
- new_node_rep, new_edge_rep = gnn(
- send_node_rep, rec_node_rep, edge_rep
- )
+ new_node_rep, new_edge_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
# Update node and edge vectors in lists
- mesh_rep_levels[
- level_l
- ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
+ mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h)
mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h)
# - PROCESSOR -
@@ -190,18 +184,14 @@ def process_step(self, mesh_rep):
send_node_rep = mesh_rep_levels[
level_l + 1
] # (B, num_mesh_nodes[l+1], d_h)
- rec_node_rep = mesh_rep_levels[
- level_l
- ] # (B, num_mesh_nodes[l], d_h)
+ rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h)
edge_rep = mesh_down_rep[level_l]
# Apply GNN
new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
# Update node and edge vectors in lists
- mesh_rep_levels[
- level_l
- ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
+ mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h)
# Return only bottom level representation
return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h)
@@ -209,8 +199,8 @@ def process_step(self, mesh_rep):
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh
- init and read out.
+ """Internal processor step of hierarchical graph models. Between mesh init and
+ read out.
Each input is list with representations, each with shape
@@ -220,5 +210,6 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
+
"""
raise NotImplementedError("hi_process_step not implemented")
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index 723c4678..55befd02 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -8,20 +8,18 @@
class GraphLAM(BaseGraphModel):
- """Full graph-based LAM model that can be used with different (non-
- hierarchical )graphs.
+ """Full graph-based LAM model that can be used with different (non- hierarchical
+ )graphs.
+
+ Mainly based on GraphCast, but the model from Keisler (2022) is almost identical.
+ Used for GC-LAM and L1-LAM in Oskarsson et al. (2023).
- Mainly based on GraphCast, but the model from Keisler (2022) is
- almost identical. Used for GC-LAM and L1-LAM in Oskarsson et al.
- (2023).
"""
def __init__(self, args, datastore, forcing_window_size):
super().__init__(args, datastore, forcing_window_size)
- assert (
- not self.hierarchical
- ), "GraphLAM does not use a hierarchical mesh graph"
+ assert not self.hierarchical, "GraphLAM does not use a hierarchical mesh graph"
# grid_dim from data + static + batch_static
mesh_dim = self.mesh_static_features.shape[1]
@@ -56,8 +54,8 @@ def __init__(self, args, datastore, forcing_window_size):
)
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of
- mesh nodes that should be ignored in encoding/decoding."""
+ """Compute number of mesh nodes from loaded features, and number of mesh nodes
+ that should be ignored in encoding/decoding."""
return self.mesh_static_features.shape[0], 0
def embedd_mesh_nodes(self):
@@ -65,20 +63,17 @@ def embedd_mesh_nodes(self):
return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h)
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the
- representation on the mesh, possible in multiple steps.
+ """Process step of embedd-process-decode framework Processes the representation
+ on the mesh, possible in multiple steps.
mesh_rep: has shape (B, N_mesh, d_h)
Returns mesh_rep: (B, N_mesh, d_h)
+
"""
# Embed m2m here first
batch_size = mesh_rep.shape[0]
m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h)
- m2m_emb_expanded = self.expand_to_batch(
- m2m_emb, batch_size
- ) # (B, M_mesh, d_h)
+ m2m_emb_expanded = self.expand_to_batch(m2m_emb, batch_size) # (B, M_mesh, d_h)
- mesh_rep, _ = self.processor(
- mesh_rep, m2m_emb_expanded
- ) # (B, N_mesh, d_h)
+ mesh_rep, _ = self.processor(mesh_rep, m2m_emb_expanded) # (B, N_mesh, d_h)
return mesh_rep
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index a7d55ba0..95300185 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -7,10 +7,11 @@
class HiLAM(BaseHiGraphModel):
- """Hierarchical graph model with message passing that goes sequentially
- down and up the hierarchy during processing.
+ """Hierarchical graph model with message passing that goes sequentially down and up
+ the hierarchy during processing.
The Hi-LAM model from Oskarsson et al. (2023)
+
"""
def __init__(self, args):
@@ -79,8 +80,8 @@ def mesh_down_step(
down_gnns,
same_gnns,
):
- """Run down-part of vertical processing, sequentially alternating
- between processing using down edges and same-level edges."""
+ """Run down-part of vertical processing, sequentially alternating between
+ processing using down edges and same-level edges."""
# Run same level processing on level L
mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1](
mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1]
@@ -93,9 +94,7 @@ def mesh_down_step(
reversed(same_gnns[:-1]),
):
# Extract representations
- send_node_rep = mesh_rep_levels[
- level_l + 1
- ] # (B, N_mesh[l+1], d_h)
+ send_node_rep = mesh_rep_levels[level_l + 1] # (B, N_mesh[l+1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
down_edge_rep = mesh_down_rep[level_l]
same_edge_rep = mesh_same_rep[level_l]
@@ -129,9 +128,7 @@ def mesh_up_step(
zip(up_gnns, same_gnns[1:]), start=1
):
# Extract representations
- send_node_rep = mesh_rep_levels[
- level_l - 1
- ] # (B, N_mesh[l-1], d_h)
+ send_node_rep = mesh_rep_levels[level_l - 1] # (B, N_mesh[l-1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
up_edge_rep = mesh_up_rep[level_l - 1]
same_edge_rep = mesh_same_rep[level_l]
@@ -153,8 +150,8 @@ def mesh_up_step(
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh
- init and read out.
+ """Internal processor step of hierarchical graph models. Between mesh init and
+ read out.
Each input is list with representations, each with shape
@@ -164,6 +161,7 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
+
"""
for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip(
self.mesh_down_gnns,
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index fe8152b3..26357281 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -8,11 +8,11 @@
class HiLAMParallel(BaseHiGraphModel):
- """Version of HiLAM where all message passing in the hierarchical mesh (up,
- down, inter-level) is ran in parallel.
+ """Version of HiLAM where all message passing in the hierarchical mesh (up, down,
+ inter-level) is ran in parallel.
+
+ This is a somewhat simpler alternative to the sequential message passing of Hi-LAM.
- This is a somewhat simpler alternative to the sequential message
- passing of Hi-LAM.
"""
def __init__(self, args):
@@ -52,8 +52,8 @@ def __init__(self, args):
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh
- init and read out.
+ """Internal processor step of hierarchical graph models. Between mesh init and
+ read out.
Each input is list with representations, each with shape
@@ -63,6 +63,7 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
+
"""
# First join all node and edge representations to single tensors
@@ -75,9 +76,7 @@ def hi_processor_step(
mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep)
# Split up again for read-out step
- mesh_rep_levels = list(
- torch.split(mesh_rep, self.level_mesh_sizes, dim=1)
- )
+ mesh_rep_levels = list(torch.split(mesh_rep, self.level_mesh_sizes, dim=1))
mesh_edge_rep_sections = torch.split(
mesh_edge_rep, self.edge_split_sections, dim=1
)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 4a69f1aa..4f011b76 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -38,9 +38,7 @@ def _init_datastore(datastore_kind, config_path):
def main(input_args=None):
"""Main function for training and evaluating models."""
- parser = ArgumentParser(
- description="Train or evaluate NeurWP models for LAM"
- )
+ parser = ArgumentParser(description="Train or evaluate NeurWP models for LAM")
parser.add_argument(
"datastore_kind",
type=str,
@@ -85,8 +83,7 @@ def main(input_args=None):
"--restore_opt",
type=int,
default=0,
- help="If optimizer state should be restored with model "
- "(default: 0 (false))",
+ help="If optimizer state should be restored with model " "(default: 0 (false))",
)
parser.add_argument(
"--precision",
@@ -100,8 +97,7 @@ def main(input_args=None):
"--graph",
type=str,
default="multiscale",
- help="Graph to load and use in graph-based model "
- "(default: multiscale)",
+ help="Graph to load and use in graph-based model " "(default: multiscale)",
)
parser.add_argument(
"--hidden_dim",
@@ -149,8 +145,7 @@ def main(input_args=None):
"--control_only",
type=int,
default=0,
- help="Train only on control member of ensemble data "
- "(default: 0 (False))",
+ help="Train only on control member of ensemble data " "(default: 0 (False))",
)
parser.add_argument(
"--loss",
@@ -165,8 +160,7 @@ def main(input_args=None):
"--val_interval",
type=int,
default=1,
- help="Number of epochs training between each validation run "
- "(default: 1)",
+ help="Number of epochs training between each validation run " "(default: 1)",
)
# Evaluation options
@@ -187,8 +181,7 @@ def main(input_args=None):
"--n_example_pred",
type=int,
default=1,
- help="Number of example predictions to plot during evaluation "
- "(default: 1)",
+ help="Number of example predictions to plot during evaluation " "(default: 1)",
)
# Logger Settings
@@ -261,9 +254,7 @@ def main(input_args=None):
# Instantiate model + trainer
if torch.cuda.is_available():
device_name = "cuda"
- torch.set_float32_matmul_precision(
- "high"
- ) # Allows using Tensor Cores on A100s
+ torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s
else:
device_name = "cpu"
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 79de3193..2ebe7b4d 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -9,11 +9,12 @@
class BufferList(nn.Module):
- """A list of torch buffer tensors that sit together as a Module with no
- parameters and only buffers.
+ """A list of torch buffer tensors that sit together as a Module with no parameters
+ and only buffers.
This should be replaced by a native torch BufferList once implemented.
See: https://github.com/pytorch/pytorch/issues/37386
+
"""
def __init__(self, buffer_tensors, persistent=True):
@@ -74,6 +75,7 @@ def load_graph(graph_dir_path, device="cpu"):
- mesh_up_features
- mesh_down_features
- mesh_static_features
+
"""
def loads_file(fn):
@@ -112,9 +114,7 @@ def loads_file(fn):
) # List of (N_mesh[l], d_mesh_static)
# Some checks for consistency
- assert (
- len(m2m_features) == n_levels
- ), "Inconsistent number of levels in mesh"
+ assert len(m2m_features) == n_levels, "Inconsistent number of levels in mesh"
assert (
len(mesh_static_features) == n_levels
), "Inconsistent number of levels in mesh"
@@ -137,23 +137,15 @@ def loads_file(fn):
# Rescale
mesh_up_features = BufferList(
- [
- edge_features / longest_edge
- for edge_features in mesh_up_features
- ],
+ [edge_features / longest_edge for edge_features in mesh_up_features],
persistent=False,
)
mesh_down_features = BufferList(
- [
- edge_features / longest_edge
- for edge_features in mesh_down_features
- ],
+ [edge_features / longest_edge for edge_features in mesh_down_features],
persistent=False,
)
- mesh_static_features = BufferList(
- mesh_static_features, persistent=False
- )
+ mesh_static_features = BufferList(mesh_static_features, persistent=False)
else:
# Extract single mesh level
m2m_edge_index = m2m_edge_index[0]
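
The rescaling collapsed onto single lines above divides every per-level edge-feature
tensor by the longest edge length before wrapping the list in a `BufferList`. A minimal
sketch with dummy tensors (illustrative only; the real code takes the edge lengths from
the loaded graph files):

    import torch

    mesh_up_features = [torch.rand(12, 3), torch.rand(6, 3)]  # (N_edges[l], d_edge)

    # the edge length is assumed here to sit in the first feature column
    longest_edge = max(float(feat[:, 0].max()) for feat in mesh_up_features)

    rescaled = [edge_features / longest_edge for edge_features in mesh_up_features]
    # in the repository this list is then wrapped as BufferList(rescaled, persistent=False)
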
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 98e066c4..e5c970c4 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -53,9 +53,7 @@ def plot_error_map(
ax.set_yticks(np.arange(d_f))
var_names = datastore.get_vars_names(category="state")
var_units = datastore.get_vars_units(category="state")
- y_ticklabels = [
- f"{name} ({unit})" for name, unit in zip(var_names, var_units)
- ]
+ y_ticklabels = [f"{name} ({unit})" for name, unit in zip(var_names, var_units)]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
if title:
@@ -76,6 +74,7 @@ def plot_prediction(
"""Plot example prediction and grond truth.
Each has shape (N_grid,)
+
"""
# Get common scale for values
if vrange is None:
@@ -89,9 +88,7 @@ def plot_prediction(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
- pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
- ) # Faded border region
+ pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region
fig, axes = plt.subplots(
1,
@@ -104,9 +101,7 @@ def plot_prediction(
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(list(datastore.grid_shape_state.values.values()))
- .cpu()
- .numpy()
+ data.reshape(list(datastore.grid_shape_state.values.values())).cpu().numpy()
)
im = ax.imshow(
data_grid,
@@ -143,12 +138,8 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
extent = data_config.get_xy_extent("state")
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(
- list(data_config.grid_shape_state.values.values())
- )
- pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
- ) # Faded border region
+ mask_reshaped = obs_mask.reshape(list(data_config.grid_shape_state.values.values()))
+ pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region
fig, ax = plt.subplots(
figsize=(5, 4.8),
@@ -157,9 +148,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(data_config.grid_shape_state.values.values()))
- .cpu()
- .numpy()
+ error.reshape(list(data_config.grid_shape_state.values.values())).cpu().numpy()
)
im = ax.imshow(
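
The border fading condensed above amounts to clamping a 0/1 boundary mask into an alpha
channel. A stand-alone sketch with a dummy mask tensor (illustrative only, not the
datastore object used in `vis.py`):

    import torch

    mask = torch.zeros(10, 12)  # 0 = boundary, 1 = interior (hypothetical shape)
    mask[2:-2, 2:-2] = 1.0

    pixel_alpha = mask.clamp(0.7, 1).cpu().numpy()  # faded border region
    print(pixel_alpha.min(), pixel_alpha.max())  # ~0.7 at the border, 1.0 inside
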
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 5ba1d326..a8213922 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -15,6 +15,7 @@ class WeatherDataset(torch.utils.data.Dataset):
"""Dataset class for weather data.
This class loads and processes weather data from a given datastore.
+
"""
def __init__(
@@ -31,9 +32,7 @@ def __init__(
self.ar_steps = ar_steps
self.datastore = datastore
- self.da_state = self.datastore.get_dataarray(
- category="state", split=self.split
- )
+ self.da_state = self.datastore.get_dataarray(category="state", split=self.split)
self.da_forcing = self.datastore.get_dataarray(
category="forcing", split=self.split
)
@@ -61,10 +60,8 @@ def __init__(
self.da_state_std = self.ds_state_stats.state_std
if self.da_forcing is not None:
- self.ds_forcing_stats = (
- self.datastore.get_normalization_dataarray(
- category="forcing"
- )
+ self.ds_forcing_stats = self.datastore.get_normalization_dataarray(
+ category="forcing"
)
self.da_forcing_mean = self.ds_forcing_stats.forcing_mean
self.da_forcing_std = self.ds_forcing_stats.forcing_std
@@ -91,11 +88,23 @@ def __len__(self):
return len(self.da_state.time) - self.ar_steps - 1
def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
- """Produce a time slice of the given dataarray `da` (state or forcing)
- starting at `idx` and with `n_steps` steps. The `n_timesteps_offset`
- parameter is used to offset the start of the sample, for example to
- exclude the first two steps when sampling the forcing data (and to
- produce the windowing samples of forcing data by increasing the offset
+ """Produce a time
+ slice of the given
+ dataarray `da` (state
+ or forcing) starting
+ at `idx` and with
+ `n_steps` steps. The
+ `n_timesteps_offset`
+ parameter is used to
+ offset the start of
+ the sample, for
+ example to exclude the
+ first two steps when
+ sampling the forcing
+ data (and to produce
+ the windowing samples
+ of forcing data by
+ increasing the offset
for each window).
Parameters
@@ -109,6 +118,7 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
The index of the time step to start the sample from.
n_steps : int
The number of time steps to include in the sample.
+
"""
# selecting the time slice
if self.datastore.is_forecast:
@@ -129,15 +139,13 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
else:
# only `time` dimension for analysis only data
da = da.isel(
- time=slice(
- idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset
- )
+ time=slice(idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset)
)
return da
def __getitem__(self, idx):
- """Return a single training sample, which consists of the initial
- states, target states, forcing and batch times.
+ """Return a single training sample, which consists of the initial states, target
+ states, forcing and batch times.
The implementation currently uses xarray.DataArray objects for the
     normalisation so that we can make use of xarray's broadcasting
@@ -158,6 +166,7 @@ def __getitem__(self, idx):
A training sample object containing the initial states, target
states, forcing and batch times. The batch times are the times of
the target steps.
+
"""
# handling ensemble data
if self.datastore.is_ensemble:
@@ -182,9 +191,7 @@ def __getitem__(self, idx):
# handle time sampling in a way that is compatible with both analysis
# and forecast data
- da_state = self._sample_time(
- da=da_state, idx=idx, n_steps=2 + self.ar_steps
- )
+ da_state = self._sample_time(da=da_state, idx=idx, n_steps=2 + self.ar_steps)
if da_forcing is not None:
das_forcing = []
@@ -219,9 +226,7 @@ def __getitem__(self, idx):
batch_times = da_target_states.time.values.astype(float)
if self.standardize:
- da_init_states = (
- da_init_states - self.da_state_mean
- ) / self.da_state_std
+ da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std
da_target_states = (
da_target_states - self.da_state_mean
) / self.da_state_std
@@ -239,9 +244,7 @@ def __getitem__(self, idx):
)
init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
- target_states = torch.tensor(
- da_target_states.values, dtype=torch.float32
- )
+ target_states = torch.tensor(da_target_states.values, dtype=torch.float32)
if self.da_forcing is None:
# create an empty forcing tensor
@@ -250,9 +253,7 @@ def __getitem__(self, idx):
dtype=torch.float32,
)
else:
- forcing = torch.tensor(
- da_forcing_windowed.values, dtype=torch.float32
- )
+ forcing = torch.tensor(da_forcing_windowed.values, dtype=torch.float32)
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
@@ -264,8 +265,9 @@ def __getitem__(self, idx):
def __iter__(self):
"""Convenience method to iterate over the dataset.
- This isn't used by pytorch DataLoader which itself implements an
- iterator that uses Dataset.__getitem__ and Dataset.__len__.
+ This isn't used by pytorch DataLoader which itself implements an iterator that
+ uses Dataset.__getitem__ and Dataset.__len__.
+
"""
for i in range(len(self)):
yield self[i]
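
The offset-based slicing that `_sample_time` performs for analysis data can be shown
with a tiny xarray example (synthetic data; the index, step count and offset below are
made up):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.arange(10.0),
        dims=["time"],
        coords={"time": np.arange("2020-01-01", "2020-01-11", dtype="datetime64[D]")},
    )

    idx, n_steps, n_timesteps_offset = 3, 4, 2
    # slice of n_steps time steps starting at idx + n_timesteps_offset
    window = da.isel(
        time=slice(idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset)
    )
    print(window.time.values)  # 4 consecutive days starting at index 5
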
diff --git a/plot_graph.py b/plot_graph.py
index b7b710bf..e84bb627 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -64,9 +64,7 @@ def main():
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
- grid_pos = np.concatenate(
- (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
- )
+ grid_pos = np.concatenate((grid_pos, np.expand_dims(z_grid, axis=1)), axis=1)
# List of edges to plot, (edge_index, color, line_width, label)
edge_plot_list = [
@@ -118,9 +116,7 @@ def main():
z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees
mesh_node_size = mesh_degrees / 2
- mesh_pos = np.concatenate(
- (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1
- )
+ mesh_pos = np.concatenate((mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1)
edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M"))
diff --git a/pyproject.toml b/pyproject.toml
index e661ff46..1c86119c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,6 @@ dev = [
[tool.setuptools]
py-modules = ["neural_lam"]
-[tool.black]
-line-length = 80
-
[tool.isort]
default_section = "THIRDPARTY"
profile = "black"
@@ -70,7 +67,7 @@ known_first_party = [
]
[tool.flake8]
-max-line-length = 80
+max-line-length = 88
ignore = [
"E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373)
"I002", # Don't check for isort configuration
@@ -114,3 +111,9 @@ min-similarity-lines=10
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
+
+
+[tool.docformatter]
+recursive = true
+blank = true
+black = true
diff --git a/tests/conftest.py b/tests/conftest.py
index 1f4edd1a..c8afc109 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -66,14 +66,15 @@ def download_meps_example_reduced_dataset():
def bootstrap_multizarr_example():
- """Run the steps that are needed to prepare the input data for the
- multizarr datastore example. This includes:
+ """Run the steps that are needed to prepare the input data for the multizarr
+ datastore example. This includes:
- Downloading the two zarr datasets (since training directly from S3 is
error-prone as the connection often breaks)
- Creating the datetime forcings zarr
- Creating the normalization stats zarr
- Creating the boundary mask zarr
+
"""
multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr"
n_boundary_cells = 10
@@ -104,8 +105,7 @@ def bootstrap_multizarr_example():
     # here assume that the data-config is referring to the default path
# for the "datetime forcings" dataset
datetime_forcing_zarr_path = (
- data_config_path.parent
- / multizarr.create_datetime_forcings.DEFAULT_FILENAME
+ data_config_path.parent / multizarr.create_datetime_forcings.DEFAULT_FILENAME
)
if not datetime_forcing_zarr_path.exists():
multizarr.create_datetime_forcings.create_datetime_forcing_zarr(
@@ -113,8 +113,7 @@ def bootstrap_multizarr_example():
)
normalized_forcing_zarr_path = (
- data_config_path.parent
- / multizarr.create_normalization_stats.DEFAULT_FILENAME
+ data_config_path.parent / multizarr.create_normalization_stats.DEFAULT_FILENAME
)
if not normalized_forcing_zarr_path.exists():
multizarr.create_normalization_stats.create_normalization_stats_zarr(
@@ -122,8 +121,7 @@ def bootstrap_multizarr_example():
)
boundary_mask_path = (
- data_config_path.parent
- / multizarr.create_boundary_mask.DEFAULT_FILENAME
+ data_config_path.parent / multizarr.create_boundary_mask.DEFAULT_FILENAME
)
if not boundary_mask_path.exists():
@@ -139,9 +137,7 @@ def bootstrap_multizarr_example():
DATASTORES_EXAMPLES = dict(
multizarr=dict(config_path=bootstrap_multizarr_example()),
mllam=dict(
- config_path=DATASTORE_EXAMPLES_ROOT_PATH
- / "mllam"
- / "danra.example.yaml"
+ config_path=DATASTORE_EXAMPLES_ROOT_PATH / "mllam" / "danra.example.yaml"
),
npyfiles=dict(config_path=download_meps_example_reduced_dataset()),
)
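
The `DATASTORES_EXAMPLES` mapping above is what drives the
`@pytest.mark.parametrize("datastore_name", DATASTORES.keys())` decorators used in the
tests. A minimal sketch of the pattern with hypothetical paths (illustrative only):

    from pathlib import Path

    import pytest

    EXAMPLES = {
        "multizarr": dict(config_path=Path("tests/datastore_examples/multizarr/data_config.yaml")),
        "mllam": dict(config_path=Path("tests/datastore_examples/mllam/danra.example.yaml")),
    }

    @pytest.mark.parametrize("datastore_name", EXAMPLES.keys())
    def test_example_config_exists(datastore_name):
        # each datastore kind is exercised once per parametrised test
        assert EXAMPLES[datastore_name]["config_path"].suffix in (".yaml", ".yml")
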
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 19ca1ed8..0dbd04a1 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -5,8 +5,8 @@
def test_import():
- """This test just ensures that each cli entry-point can be imported for
- now, eventually we should test their execution too."""
+ """This test just ensures that each cli entry-point can be imported for now,
+ eventually we should test their execution too."""
assert neural_lam is not None
assert neural_lam.create_graph is not None
assert neural_lam.train_model is not None
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 7e73f787..8ae9d917 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -15,7 +15,7 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_dataset_item(datastore_name):
- """Check that the `datastore.get_dataarray` method is implemented.
+ """Check that the `datasto re.get_dataarray` method is implemented.
Validate the shapes of the tensors match between the different
components of the training sample.
@@ -23,6 +23,7 @@ def test_dataset_item(datastore_name):
init_states: (2, N_grid, d_features)
target_states: (ar_steps, N_grid, d_features)
forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,)
+
"""
datastore = init_datastore(datastore_name)
N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y
@@ -59,8 +60,7 @@ def test_dataset_item(datastore_name):
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert (
- forcing.shape[2]
- == datastore.get_num_data_vars("forcing") * forcing_window_size
+ forcing.shape[2] == datastore.get_num_data_vars("forcing") * forcing_window_size
)
# batch times
@@ -76,15 +76,14 @@ def test_dataset_item(datastore_name):
@pytest.mark.parametrize("split", ["train", "val", "test"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_single_batch(datastore_name, split):
- """Check that the `datastore.get_dataarray` method is implemented.
+ """Check that the `datasto re.get_dataarray` method is implemented.
And that it returns an xarray DataArray with the correct dimensions.
+
"""
datastore = init_datastore(datastore_name)
- device_name = ( # noqa
- torch.device("cuda") if torch.cuda.is_available() else "cpu"
- )
+ device_name = torch.device("cuda") if torch.cuda.is_available() else "cpu" # noqa
graph_name = "1level"
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 198d4460..319c5a7c 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -23,6 +23,7 @@
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
+
"""
# Standard library
@@ -57,9 +58,9 @@ def test_step_length(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
- """Use the `datastore.get_xy` method to get the x, y coordinates of the
- dataset and check that the shape is correct against the
- `datastore.grid_shape_state` property."""
+ """Use the `datastore.get_xy` method to get the x, y coordinates of the dataset and
+    check that the shape is correct against the `datastore.grid_shape_state`
+ property."""
datastore = init_datastore(datastore_name)
# check the shapes of the xy grid
@@ -87,6 +88,7 @@ def test_get_vars(datastore_name):
     are consistent (as in the number of variables is the same) and that the
return types of each are correct.
+
"""
datastore = init_datastore(datastore_name)
@@ -103,7 +105,7 @@ def test_get_vars(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_normalization_dataarray(datastore_name):
- """Check that the `datastore.get_normalization_dataarray` method is
+ """Check that the `datasto re.get_normalization_dataa rray` method is
implemented."""
datastore = init_datastore(datastore_name)
@@ -132,9 +134,10 @@ def test_get_normalization_dataarray(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_dataarray(datastore_name):
- """Check that the `datastore.get_dataarray` method is implemented.
+ """Check that the `datasto re.get_dataarray` method is implemented.
And that it returns an xarray DataArray with the correct dimensions.
+
"""
datastore = init_datastore(datastore_name)
@@ -176,8 +179,8 @@ def test_get_dataarray(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
- """Check that the `datastore.boundary_mask` property is implemented and
- that the returned object is an xarray DataArray with the correct shape."""
+ """Check that the `datastore.boundary_mask` property is implemented and that the
+ returned object is an xarray DataArray with the correct shape."""
datastore = init_datastore(datastore_name)
da_mask = datastore.boundary_mask
@@ -195,8 +198,8 @@ def test_boundary_mask(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy_extent(datastore_name):
- """Check that the `datastore.get_xy_extent` method is implemented and that
- the returned object is a tuple of the correct length."""
+ """Check that the `datastore.get_xy_extent` method is implemented and that the
+ returned object is a tuple of the correct length."""
datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
@@ -247,7 +250,7 @@ def test_get_xy(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_projection(datastore_name):
- """Check that the `datastore.coords_projection` property is implemented."""
+ """Check that the `datasto re.coords_projection` property is implemented."""
datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
@@ -258,7 +261,7 @@ def test_get_projection(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def get_grid_shape_state(datastore_name):
- """Check that the `datastore.grid_shape_state` property is implemented."""
+ """Check that the `datasto re.grid_shape_state` property is implemented."""
datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
index 6384c46f..652e3dce 100644
--- a/tests/test_graph_creation.py
+++ b/tests/test_graph_creation.py
@@ -14,9 +14,10 @@
@pytest.mark.parametrize("graph_name", ["1level", "multiscale", "hierarchical"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_graph_creation(datastore_name, graph_name):
- """Check that the `create_graph_from_datastore` function is implemented.
+ """Check that the `create_ graph_from_datastore` function is implemented.
And that the graph is created in the correct location.
+
"""
datastore = init_datastore(datastore_name)
if graph_name == "hierarchical":
@@ -80,9 +81,7 @@ def test_graph_creation(datastore_name, graph_name):
assert isinstance(result, torch.Tensor)
if file_id.endswith("_index"):
- assert (
- result.shape[0] == 2
- ) # adjacency matrix uses two rows
+ assert result.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert result.shape[1] == d_features
@@ -91,9 +90,7 @@ def test_graph_creation(datastore_name, graph_name):
if not hierarchical:
assert len(result) == 1
else:
- if file_id.startswith("mesh_up") or file_id.startswith(
- "mesh_down"
- ):
+ if file_id.startswith("mesh_up") or file_id.startswith("mesh_down"):
assert len(result) == n_max_levels - 1
else:
assert len(result) == n_max_levels
diff --git a/tests/test_training.py b/tests/test_training.py
index ee532656..94b36980 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,9 +20,7 @@ def test_training(datastore_name):
if torch.cuda.is_available():
device_name = "cuda"
- torch.set_float32_matmul_precision(
- "high"
- ) # Allows using Tensor Cores on A100s
+ torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s
else:
device_name = "cpu"
From 57bbb81364f117df3281ff2368dc7b5e906764d9 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 13:56:57 +0000
Subject: [PATCH 170/273] train only 1 epoch in cicd and print to stdout
---
.github/workflows/ci-pdm-install-and-test-cpu.yml | 2 +-
.github/workflows/ci-pdm-install-and-test-gpu.yml | 2 +-
.github/workflows/ci-pip-install-and-test-cpu.yml | 2 +-
.github/workflows/ci-pip-install-and-test-gpu.yml | 2 +-
tests/test_training.py | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test-cpu.yml b/.github/workflows/ci-pdm-install-and-test-cpu.yml
index 7d31f867..a1734a2c 100644
--- a/.github/workflows/ci-pdm-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-cpu.yml
@@ -47,7 +47,7 @@ jobs:
- name: Run tests
run: |
- pdm run pytest
+ pdm run pytest -s tests/
- name: Save cache data
uses: actions/cache/save@v4
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index 94e740ce..628c082c 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -52,7 +52,7 @@ jobs:
- name: Run tests
run: |
- pdm run pytest
+ pdm run pytest -s tests/
- name: Save cache data
uses: actions/cache/save@v4
diff --git a/.github/workflows/ci-pip-install-and-test-cpu.yml b/.github/workflows/ci-pip-install-and-test-cpu.yml
index c94e70c2..f7741a35 100644
--- a/.github/workflows/ci-pip-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-cpu.yml
@@ -40,7 +40,7 @@ jobs:
- name: Run tests
run: |
- python -m pytest
+ python -m pytest -s tests/
- name: Save cache data
uses: actions/cache/save@v4
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index 4dfc98c8..8cd516f4 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -42,7 +42,7 @@ jobs:
- name: Run tests
run: |
- python -m pytest
+ python -m pytest -s tests/
- name: Save cache data
uses: actions/cache/save@v4
diff --git a/tests/test_training.py b/tests/test_training.py
index 94b36980..19d48e3a 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -25,7 +25,7 @@ def test_training(datastore_name):
device_name = "cpu"
trainer = pl.Trainer(
- max_epochs=3,
+ max_epochs=1,
deterministic=True,
accelerator=device_name,
# XXX: `devices` has to be set to 2 otherwise
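
The CI change above boils down to configuring the Lightning trainer for a single
deterministic epoch. A stand-alone sketch of that configuration (assuming
`pytorch_lightning` is installed; model and data module omitted):

    import pytorch_lightning as pl
    import torch

    if torch.cuda.is_available():
        device_name = "cuda"
        torch.set_float32_matmul_precision("high")  # allows using Tensor Cores on A100s
    else:
        device_name = "cpu"

    trainer = pl.Trainer(
        max_epochs=1,  # keep CI runs short
        deterministic=True,
        accelerator=device_name,
    )
    # trainer.fit(model, train_dataloaders=...) would follow in the actual test
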
From a955ceecf1a7b49b81a2b4f381cafde9727733eb Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 14 Aug 2024 14:28:09 +0000
Subject: [PATCH 171/273] log datastore config
---
neural_lam/datastore/base.py | 14 ++++++++
neural_lam/datastore/mllam.py | 12 +++++++
neural_lam/datastore/multizarr/store.py | 12 +++++++
neural_lam/datastore/npyfiles/store.py | 46 +++++++++++--------------
neural_lam/train_model.py | 4 ++-
tests/test_datastores.py | 14 ++++++++
6 files changed, 75 insertions(+), 27 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 1b662fa4..480476fe 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -1,5 +1,6 @@
# Standard library
import abc
+import collections
import dataclasses
from pathlib import Path
from typing import List, Union
@@ -59,6 +60,19 @@ def root_path(self) -> Path:
"""
pass
+ @property
+ @abc.abstractmethod
+ def config(self) -> collections.abc.Mapping:
+ """The configuration of the datastore.
+
+ Returns
+ -------
+ collections.abc.Mapping
+        The configuration of the datastore; any dict-like object can be returned.
+
+ """
+ pass
+
@property
@abc.abstractmethod
def step_length(self) -> int:
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 0d011e5e..d22f041a 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -86,6 +86,18 @@ def root_path(self) -> Path:
"""
return self._root_path
+ @property
+ def config(self) -> mdp.Config:
+ """The configuration of the dataset.
+
+ Returns
+ -------
+ mdp.Config
+ The configuration of the dataset.
+
+ """
+ return self._config
+
@property
def step_length(self) -> int:
"""The length of the time steps in hours.
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 18af8457..23b33fe2 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -55,6 +55,18 @@ def root_path(self):
"""
return self._root_path
+ @property
+ def config(self) -> dict:
+ """Return the configuration dictionary.
+
+ Returns
+ -------
+ dict
+ The configuration dictionary.
+
+ """
+ return self._config
+
def _normalize_path(self, path) -> str:
"""
Normalize the path of source-dataset defined in the configuration file.
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 630a8dd0..cff20043 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -138,17 +138,9 @@ def __init__(
self,
config_path,
):
- """Create a new
- NpyFilesDatastore
- using the
- configuration file at
- the given path. The
- config file should be
- a YAML file and will
- be loaded into an
- instance of the
- `NpyDatastoreConfig`
- dataclass.
+ """Create a new NpyFilesDatastore using the configuration file at the given
+ path. The config file should be a YAML file and will be loaded into an instance
+ of the `NpyDatastoreConfig` dataclass.
Internally, the datastore uses dask.delayed to load the data from the
numpy files, so that the data isn't actually loaded until it's needed.
@@ -166,7 +158,7 @@ def __init__(
self._config_path = Path(config_path)
self._root_path = self._config_path.parent
- self.config = NpyDatastoreConfig.from_yaml_file(self._config_path)
+ self._config = NpyDatastoreConfig.from_yaml_file(self._config_path)
@property
def root_path(self) -> Path:
@@ -181,21 +173,23 @@ def root_path(self) -> Path:
"""
return self._root_path
+ @property
+ def config(self) -> NpyDatastoreConfig:
+ """The configuration for the datastore.
+
+ Returns
+ -------
+ NpyDatastoreConfig
+ The configuration for the datastore.
+
+ """
+ return self._config
+
def get_dataarray(self, category: str, split: str) -> DataArray:
- """Get the data array
- for the given category
- and split of data. If
- the category is
- 'state', the data
- array will be a
- concatenation of the
- data arrays for all
- ensemble members. The
- data will be loaded as
- a dask array, so that
- the data isn't
- actually loaded until
- it's needed.
+ """Get the data array for the given category and split of data. If the category
+ is 'state', the data array will be a concatenation of the data arrays for all
+ ensemble members. The data will be loaded as a dask array, so that the data
+ isn't actually loaded until it's needed.
Parameters
----------
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 4f011b76..e819c403 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -280,7 +280,9 @@ def main(input_args=None):
save_last=True,
)
logger = pl.loggers.WandbLogger(
- project=args.wandb_project, name=run_name, config=args
+ project=args.wandb_project,
+ name=run_name,
+ config=dict(training=vars(args), datastore=datastore._config),
)
trainer = pl.Trainer(
max_epochs=args.epochs,
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 319c5a7c..512bc5a0 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -16,6 +16,7 @@
`xr.DataArray`) for the given category and test/train/val-split.
- [x] `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
+- [x] `config` (property): Return the configuration of the datastore.
In addition BaseCartesianDatastore must have the following methods and attributes:
- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
@@ -27,6 +28,8 @@
"""
# Standard library
+import collections
+import dataclasses
from pathlib import Path
# Third-party
@@ -47,6 +50,17 @@ def test_root_path(datastore_name):
assert isinstance(datastore.root_path, Path)
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_config(datastore_name):
+ """Check that the `datastore.config` property is implemented."""
+ datastore = init_datastore(datastore_name)
+ # check the config is a mapping or a dataclass
+ config = datastore.config
+ assert isinstance(config, collections.abc.Mapping) or dataclasses.is_dataclass(
+ config
+ )
+
+
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_step_length(datastore_name):
"""Check that the `datastore.step_length` property is implemented."""
From 0a79c74cb1d6bf3e29b999188c0c6e4d9f95078b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 15 Aug 2024 13:25:25 +0000
Subject: [PATCH 172/273] cleanup doctrings
---
neural_lam/datastore/base.py | 14 +---
neural_lam/datastore/mllam.py | 95 +++++++------------------
neural_lam/datastore/multizarr/store.py | 51 +++----------
neural_lam/datastore/npyfiles/store.py | 40 +++--------
4 files changed, 46 insertions(+), 154 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 480476fe..b19cbf23 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -12,17 +12,9 @@
class BaseDatastore(abc.ABC):
- """Base class for weather
- data used in the neural-
- lam package. A datastore
- defines the interface for
- accessing weather data by
- providing methods to
- access the data in a
- processed format that can
- be used for training and
- evaluation of neural
- networks.
+ """Base class for weather data used in the neural- lam package. A datastore defines
+ the interface for accessing weather data by providing methods to access the data in
+ a processed format that can be used for training and evaluation of neural networks.
NOTE: All methods return either primitive types, `numpy.ndarray`,
`xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index d22f041a..5e44837a 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -19,24 +19,11 @@ class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""
def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
- """Construct a new
- MLLAMDatastore from
- the configuration file
- at `config_path`. A
- boundary mask is
- created with
- `n_boundary_points`
- boundary points. If
- `reuse_existing` is
- True, the dataset is
- loaded from a zarr
- file if it exists
- (unless the config has
- been modified since
- the zarr was created),
- otherwise it is
- created from the
- configuration file.
+ """Construct a new MLLAMDatastore from the configuration file at `config_path`.
+ A boundary mask is created with `n_boundary_points` boundary points. If
+ `reuse_existing` is True, the dataset is loaded from a zarr file if it exists
+ (unless the config has been modified since the zarr was created), otherwise it
+ is created from the configuration file.
Parameters
----------
@@ -74,6 +61,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points
+ print("Training with the following features:")
+ for category in ["state", "forcing", "static"]:
+ if len(self.get_vars_names(category)) > 0:
+ print(f"{category}: {' '.join(self.get_vars_names(category))}")
+
@property
def root_path(self) -> Path:
"""The root path of the dataset.
@@ -166,24 +158,11 @@ def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category))
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
- """Return the
- processed data (as a
- single `xr.DataArray`)
- for the given category
- of data and
- test/train/val-split
- that covers all the
- data (in space and
- time) of a given
- category (state/forcin
- g/static). "state" is
- the only required
- category, for other
- categories, the method
- will return `None` if
- the category is not
- found in the
- datastore.
+ """Return the processed data (as a single `xr.DataArray`) for the given category
+ of data and test/train/val-split that covers all the data (in space and time) of
+    a given category (state/forcing/static). "state" is the only required category,
+ for other categories, the method will return `None` if the category is not found
+ in the datastore.
The returned dataarray will at minimum have dimensions of `(grid_index,
{category}_feature)` so that any spatial dimensions have been stacked
@@ -236,23 +215,10 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
return da_category.sel(time=slice(t_start, t_end))
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the
- normalization
- dataarray for the
- given category. This
- should contain a
- `{category}_mean` and
- `{category}_std`
- variable for each
- variable in the
- category. For
- `category=="state"`,
- the dataarray should
- also contain a
- `state_diff_mean` and
- `state_diff_std`
- variable for the one-
- step differences of
+ """Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also contain a
+    `state_diff_mean` and `state_diff_std` variable for the one-step differences of
the state variables.
Parameters
@@ -283,24 +249,11 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
@property
def boundary_mask(self) -> xr.DataArray:
- """Produce a 0/1 mask
- for the boundary
- points of the dataset,
- these will sit at the
- edges of the domain
- (in x/y extent) and
- will be used to mask
- out the boundary
- points from the loss
- function and to
- overwrite the boundary
- points from the
- prediction. For now
- this is created when
- the mask is requested,
- but in the future this
- could be saved to the
- zarr file.
+ """Produce a 0/1 mask for the boundary points of the dataset, these will sit at
+ the edges of the domain (in x/y extent) and will be used to mask out the
+ boundary points from the loss function and to overwrite the boundary points from
+ the prediction. For now this is created when the mask is requested, but in the
+ future this could be saved to the zarr file.
Returns
-------
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 23b33fe2..ebcc65e8 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -18,19 +18,10 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}
def __init__(self, config_path):
- """Create a multi-zarr
- datastore from the
- given configuration
- file. The
- configuration file
- should be a YAML file,
- the format of which is
- should be inferred
- from the example
- configuration file in
- `tests/datastore_examp
- les/multizarr/data_con
- fig.yml`.
+ """Create a multi-zarr datastore from the given configuration file. The
+ configuration file should be a YAML file, the format of which is should be
+ inferred from the example configuration file in `tests/datastore_examp
+ les/multizarr/data_con fig.yml`.
Parameters
----------
@@ -390,33 +381,13 @@ def get_xy(self, category, stacked=True):
@functools.lru_cache()
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the
- normalization
- dataarray for the
- given category. This
- should contain a
- `{category}_mean` and
- `{category}_std`
- variable for each
- variable in the
- category. For
- `category=="state"`,
- the dataarray should
- also contain a
- `state_diff_mean` and
- `state_diff_std`
- variable for the one-
- step differences of
- the state variables.
- The return dataarray
- should at least have
- dimensions of `({categ
- ory}_feature)`, but
- can also include for
- example `grid_index`
- (if the normalisation
- is done per grid point
- for example).
+ """Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also contain a
+        `state_diff_mean` and `state_diff_std` variable for the one-step differences of
+        the state variables. The returned dataarray should at least have dimensions of
+        `({category}_feature)`, but can also include for example `grid_index` (if the
+ normalisation is done per grid point for example).
Parameters
----------
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index cff20043..ff43a626 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -281,21 +281,10 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
def _get_single_timeseries_dataarray(
self, features: List[str], split: str, member: int = None
) -> DataArray:
- """Get the data array
- spanning the complete
- time series for a
- given set of features
- and split of data. For
- state features the
- `member` argument
- should be specified to
- select the ensemble
- member to load. The
- data will be loaded
- using dask.delayed, so
- that the data isn't
- actually loaded until
- it's needed.
+ """Get the data array spanning the complete time series for a given set of
+ features and split of data. For state features the `member` argument should be
+ specified to select the ensemble member to load. The data will be loaded using
+ dask.delayed, so that the data isn't actually loaded until it's needed.
Parameters
----------
@@ -614,23 +603,10 @@ def boundary_mask(self) -> xr.DataArray:
return da_mask_stacked_xy
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the
- normalization
- dataarray for the
- given category. This
- should contain a
- `{category}_mean` and
- `{category}_std`
- variable for each
- variable in the
- category. For
- `category=="state"`,
- the dataarray should
- also contain a
- `state_diff_mean` and
- `state_diff_std`
- variable for the one-
- step differences of
+ """Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also contain a
+        `state_diff_mean` and `state_diff_std` variable for the one-step differences of
the state variables.
Parameters
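
The contract spelled out in these docstrings (a `{category}_mean` and `{category}_std`
variable per feature, plus `state_diff_mean` and `state_diff_std` for the state
category) can be sketched with synthetic numbers; the variable names and values below
are made up:

    import numpy as np
    import xarray as xr

    state_features = ["t2m", "u10", "v10"]  # hypothetical variable names
    ds_stats = xr.Dataset(
        {
            "state_mean": ("state_feature", np.array([280.0, 0.5, -0.3])),
            "state_std": ("state_feature", np.array([10.0, 3.0, 3.0])),
            "state_diff_mean": ("state_feature", np.zeros(3)),
            "state_diff_std": ("state_feature", np.array([0.8, 1.1, 1.1])),
        },
        coords={"state_feature": state_features},
    )
    print(ds_stats["state_mean"].sel(state_feature="t2m").item())
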
From 342229806307110d5f245dba7f711463f760e0d1 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 19 Aug 2024 15:54:01 +0200
Subject: [PATCH 173/273] update changelog
---
CHANGELOG.md | 3 +++
1 file changed, 3 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ac7b5226..4ae6c8be 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
anywhere once the package has been installed).
[\#32](https://github.com/mllam/neural-lam/pull/32), @leifdenby
+- move from `requirements.txt` to `pyproject.toml` for defining package dependencies.
+ [\#37](https://github.com/mllam/neural-lam/pull/37), @leifdenby
+
## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0)
First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication
From 689ef693781bede4e1bc58f3a186214db280d307 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 09:16:46 +0200
Subject: [PATCH 174/273] move dev deps optional dependencies group
---
pyproject.toml | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index fc38e773..d66c0087 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,13 +27,12 @@ dependencies = [
]
requires-python = ">=3.9"
-[tool.pdm.dev-dependencies]
+[project.optional-dependencies]
dev = [
- "pre-commit>=2.15.0",
- "pytest>=8.2.1",
- "pooch>=1.8.1",
+ "pre-commit>=3.8.0",
+ "pytest>=8.3.2",
+ "pooch>=1.8.2",
]
-
[tool.setuptools]
py-modules = ["neural_lam"]
@@ -103,6 +102,8 @@ allow-any-import-level="neural_lam"
[tool.pylint.SIMILARITIES]
min-similarity-lines=10
+
+[tool.pdm]
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
From 9a0d538a683aa104b4ddea76fd295d22e390b63b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 09:23:44 +0200
Subject: [PATCH 175/273] update cicd tests to install dev deps
---
.github/workflows/ci-pdm-install-and-test-cpu.yml | 3 +--
.github/workflows/ci-pdm-install-and-test-gpu.yml | 3 +--
.github/workflows/ci-pip-install-and-test-cpu.yml | 6 +-----
.github/workflows/ci-pip-install-and-test-gpu.yml | 3 +--
4 files changed, 4 insertions(+), 11 deletions(-)
diff --git a/.github/workflows/ci-pdm-install-and-test-cpu.yml b/.github/workflows/ci-pdm-install-and-test-cpu.yml
index 7d31f867..c5da88cc 100644
--- a/.github/workflows/ci-pdm-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-cpu.yml
@@ -29,8 +29,7 @@ jobs:
- name: Install package (including dev dependencies)
run: |
- pdm install
- pdm install --dev
+ pdm install --group :all
- name: Print and check torch version
run: |
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index 94e740ce..9ab4f379 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -39,8 +39,7 @@ jobs:
- name: Install package (including dev dependencies)
run: |
- pdm install
- pdm install --dev
+ pdm install --group :all
- name: Load cache data
uses: actions/cache/restore@v4
diff --git a/.github/workflows/ci-pip-install-and-test-cpu.yml b/.github/workflows/ci-pip-install-and-test-cpu.yml
index c94e70c2..81e402c5 100644
--- a/.github/workflows/ci-pip-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-cpu.yml
@@ -19,11 +19,7 @@ jobs:
- name: Install package (including dev dependencies)
run: |
- python -m pip install .
- # pip can't install from "dev" pdm group in pyproject.toml, should we put these requirements
- # for running tests in a separate group? Using "dev" ensures that the these requirements aren't
- # included in build packages
- python -m pip install pytest pooch
+ python -m pip install ".[dev]"
- name: Print and check torch version
run: |
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index 4dfc98c8..ce68946a 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -24,8 +24,7 @@ jobs:
- name: Install package (including dev dependencies)
run: |
- python -m pip install .
- python -m pip install pytest pooch
+ python -m pip install ".[dev]"
- name: Print and check torch version
run: |
From bddfcaf28625e7ff106b7be834d5228cfc458759 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 09:49:01 +0200
Subject: [PATCH 176/273] update readme with new dev deps group
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 409ebff2..b0705377 100644
--- a/README.md
+++ b/README.md
@@ -80,13 +80,13 @@ setup](.github/workflows/) which you can use as a reference.
2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)). If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 5.
3. Create a virtual environment for pdm to use with `pdm venv create --with-pip`.
4. Install a specific version of `torch` with `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
-5. Install the dependencies with `pdm install`. If you will be developing `neural-lam` we recommend to install the development dependencies with `pdm install --dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately.
+5. Install the dependencies with `pdm install` (by default this installs the base dependencies). If you will be developing `neural-lam` we recommend to install the development dependencies with `pdm install --group dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately.
### Using `pip`
1. Clone this repository and navigate to the root directory. If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3.
2. Install a specific version of `torch` with `python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
-3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend to install in editable mode with `python -m pip install -e .` so you can make changes to the code and see the effects immediately. The development dependencies to install are listed in `pyproject.toml`.
+3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend to install in editable mode and install the development dependencies with `python -m pip install -e ".[dev]"` so you can make changes to the code and see the effects immediately.
## Data
From b96cfdcd7d6dc3be11280de3aee3c0831e71cdf7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 10:17:53 +0200
Subject: [PATCH 177/273] quote the skip step the install readme
---
README.md | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index b0705377..7dc6c7ab 100644
--- a/README.md
+++ b/README.md
@@ -77,14 +77,16 @@ setup](.github/workflows/) which you can use as a reference.
### Using `pdm`
1. Clone this repository and navigate to the root directory.
-2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)). If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 5.
+2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)).
+> If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 5.
3. Create a virtual environment for pdm to use with `pdm venv create --with-pip`.
4. Install a specific version of `torch` with `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
 5. Install the dependencies with `pdm install` (by default this installs the base dependencies). If you will be developing `neural-lam` we recommend to install the development dependencies with `pdm install --group dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately.
### Using `pip`
-1. Clone this repository and navigate to the root directory. If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3.
+1. Clone this repository and navigate to the root directory.
+> If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3.
2. Install a specific version of `torch` with `python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend to install in editable mode and install the development dependencies with `python -m pip install -e ".[dev]"` so you can make changes to the code and see the effects immediately.
From 2600dee50da6757dc68ef089d35185c7304d4087 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 10:34:17 +0200
Subject: [PATCH 178/273] remove unused files
---
neural_lam/datasets/.DS_Store | Bin 6148 -> 0 bytes
neural_lam/datasets/__init__.py | 0
2 files changed, 0 insertions(+), 0 deletions(-)
delete mode 100644 neural_lam/datasets/.DS_Store
delete mode 100644 neural_lam/datasets/__init__.py
diff --git a/neural_lam/datasets/.DS_Store b/neural_lam/datasets/.DS_Store
deleted file mode 100644
index f172ab58d31f03adddb2b8b1d35371f1d00616de..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
Date: Tue, 20 Aug 2024 11:01:56 +0200
Subject: [PATCH 179/273] revert to line length of 80
---
.pre-commit-config.yaml | 6 --
neural_lam/create_graph.py | 58 +++++++++++++----
neural_lam/datastore/base.py | 4 +-
neural_lam/datastore/mllam.py | 12 +++-
.../multizarr/create_datetime_forcings.py | 11 ++--
.../multizarr/create_normalization_stats.py | 12 +++-
neural_lam/datastore/multizarr/store.py | 37 ++++++++---
neural_lam/datastore/npyfiles/store.py | 30 ++++++---
neural_lam/interaction_net.py | 4 +-
neural_lam/metrics.py | 28 ++++++--
neural_lam/models/ar_model.py | 64 ++++++++++++++-----
neural_lam/models/base_graph_model.py | 15 +++--
neural_lam/models/base_hi_graph_model.py | 36 ++++++++---
neural_lam/models/graph_lam.py | 12 +++-
neural_lam/models/hi_lam.py | 8 ++-
neural_lam/models/hi_lam_parallel.py | 4 +-
neural_lam/train_model.py | 23 +++++--
neural_lam/utils.py | 18 ++++--
neural_lam/vis.py | 24 +++++--
neural_lam/weather_dataset.py | 30 ++++++---
plot_graph.py | 8 ++-
pyproject.toml | 5 +-
tests/conftest.py | 13 ++--
tests/test_datasets.py | 7 +-
tests/test_datastores.py | 6 +-
tests/test_graph_creation.py | 8 ++-
tests/test_training.py | 4 +-
27 files changed, 355 insertions(+), 132 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 91983d9b..dfbf8b60 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,9 +36,3 @@ repos:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
additional_dependencies: [Flake8-pyproject]
- - repo: https://github.com/myint/docformatter
- rev: v1.7.5
- hooks:
- - id: docformatter
- args: [--in-place, --recursive, --config, ./pyproject.toml]
- additional_dependencies: [tomli]
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 6450f134..3126543b 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -35,7 +35,9 @@ def plot_graph(graph, title=None):
# TODO: indicate direction of directed edges
# Move all to cpu and numpy, compute (in)-degrees
- degrees = pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
+ degrees = (
+ pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
+ )
edge_index = edge_index.cpu().numpy()
pos = pos.cpu().numpy()
@@ -80,7 +82,9 @@ def sort_nodes_internally(nx_graph):
def save_edges(graph, name, base_path):
- torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt"))
+ torch.save(
+ graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")
+ )
edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
torch.float32
) # Save as float32
@@ -93,7 +97,9 @@ def save_edges_list(graphs, name, base_path):
os.path.join(base_path, f"{name}_edge_index.pt"),
)
edge_features = [
- torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(torch.float32)
+ torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
+ torch.float32
+ )
for graph in graphs
] # Save as float32
torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt"))
@@ -124,7 +130,11 @@ def mk_2d_graph(xy, nx, ny):
# add diagonal edges
g.add_edges_from(
[((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
- + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+ + [
+ ((x + 1, y), (x, y + 1))
+ for x in range(nx - 1)
+ for y in range(ny - 1)
+ ]
)
# turn into directed graph
@@ -252,7 +262,10 @@ def create_graph(
if hierarchical:
# Relabel nodes of each level with level index first
- G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)]
+ G = [
+ prepend_node_index(graph, level_i)
+ for level_i, graph in enumerate(G)
+ ]
num_nodes_level = np.array([len(g_level.nodes) for g_level in G])
# First node index in each level in the hierarchical graph
@@ -294,7 +307,9 @@ def create_graph(
# add edge from mesh to grid
G_down.add_edge(u, v)
d = np.sqrt(
- np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)
+ np.sum(
+ (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2
+ )
)
G_down.edges[u, v]["len"] = d
G_down.edges[u, v]["vdiff"] = (
@@ -319,10 +334,14 @@ def create_graph(
down_graphs.append(pyg_down)
if create_plot:
- plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}")
+ plot_graph(
+ pyg_down, title=f"Down graph, {from_level} -> {to_level}"
+ )
plt.show()
- plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}")
+ plot_graph(
+ pyg_down, title=f"Up graph, {to_level} -> {from_level}"
+ )
plt.show()
# Save up and down edges
@@ -407,7 +426,9 @@ def create_graph(
vm = G_bottom_mesh.nodes
vm_xy = np.array([xy for _, xy in vm.data("pos")])
# distance between mesh nodes
- dm = np.sqrt(np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2))
+ dm = np.sqrt(
+ np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
+ )
# grid nodes
Ny, Nx = xy.shape[1:]
@@ -449,9 +470,13 @@ def create_graph(
u = vg_list[i]
# add edge from grid to mesh
G_g2m.add_edge(u, v)
- d = np.sqrt(np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2))
+ d = np.sqrt(
+ np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2)
+ )
G_g2m.edges[u, v]["len"] = d
- G_g2m.edges[u, v]["vdiff"] = G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]
+ G_g2m.edges[u, v]["vdiff"] = (
+ G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]
+ )
pyg_g2m = from_networkx(G_g2m)
@@ -480,9 +505,13 @@ def create_graph(
u = vm_list[i]
# add edge from mesh to grid
G_m2g.add_edge(u, v)
- d = np.sqrt(np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2))
+ d = np.sqrt(
+ np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2)
+ )
G_m2g.edges[u, v]["len"] = d
- G_m2g.edges[u, v]["vdiff"] = G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]
+ G_m2g.edges[u, v]["vdiff"] = (
+ G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]
+ )
# relabel nodes to integers (sorted)
G_m2g_int = networkx.convert_node_labels_to_integers(
@@ -549,7 +578,8 @@ def cli(input_args=None):
"--plot",
type=int,
default=0,
- help="If graphs should be plotted during generation " "(default: 0 (false))",
+ help="If graphs should be plotted during generation "
+ "(default: 0 (false))",
)
parser.add_argument(
"--levels",
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index b19cbf23..bbbb62dd 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -173,7 +173,9 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
pass
@abc.abstractmethod
- def get_dataarray(self, category: str, split: str) -> Union[xr.DataArray, None]:
+ def get_dataarray(
+ self, category: str, split: str
+ ) -> Union[xr.DataArray, None]:
"""Return the
processed data (as a
single `xr.DataArray`)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 5e44837a..36abe3de 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -41,7 +41,9 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._config_path = Path(config_path)
self._root_path = self._config_path.parent
self._config = mdp.Config.from_yaml_file(self._config_path)
- fp_ds = self._root_path / self._config_path.name.replace(".yaml", ".zarr")
+ fp_ds = self._root_path / self._config_path.name.replace(
+ ".yaml", ".zarr"
+ )
self._ds = None
if reuse_existing and fp_ds.exists():
@@ -263,13 +265,17 @@ def boundary_mask(self) -> xr.DataArray:
"""
ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
- da_state_variable = ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ da_state_variable = (
+ ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ )
da_domain_allzero = xr.zeros_like(da_state_variable)
ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
x=slice(self._n_boundary_points, -self._n_boundary_points),
y=slice(self._n_boundary_points, -self._n_boundary_points),
)
- ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(1).astype(int)
+ ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
+ 1
+ ).astype(int)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
@property
diff --git a/neural_lam/datastore/multizarr/create_datetime_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py
index 7b645cae..d728faaf 100644
--- a/neural_lam/datastore/multizarr/create_datetime_forcings.py
+++ b/neural_lam/datastore/multizarr/create_datetime_forcings.py
@@ -50,7 +50,10 @@ def calculate_datetime_forcing(da_time: xr.DataArray):
dims=["time"],
)
year_seconds = xr.DataArray(
- [get_seconds_in_year(pd.Timestamp(dt_obj).year) for dt_obj in da_time.values],
+ [
+ get_seconds_in_year(pd.Timestamp(dt_obj).year)
+ for dt_obj in da_time.values
+ ],
dims=["time"],
)
hour_angle = (hours_of_day / 12) * np.pi
@@ -91,9 +94,9 @@ def create_datetime_forcing_zarr(
datastore = MultiZarrDatastore(config_path=data_config_path)
da_state = datastore.get_dataarray(category="state", split="train")
- da_datetime_forcing = calculate_datetime_forcing(da_time=da_state.time).expand_dims(
- {"grid_index": da_state.grid_index}
- )
+ da_datetime_forcing = calculate_datetime_forcing(
+ da_time=da_state.time
+ ).expand_dims({"grid_index": da_state.grid_index})
if "x" in da_state.coords and "y" in da_state.coords:
# copy the x and y coordinates to the datetime forcing
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index 7a6df4d2..11da134b 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -55,7 +55,9 @@ def create_normalization_stats_zarr(
for group in combined_stats:
vars_to_combine = group["vars"]
- da_forcing_means = da_forcing_mean.sel(forcing_feature=vars_to_combine)
+ da_forcing_means = da_forcing_mean.sel(
+ forcing_feature=vars_to_combine
+ )
stds = da_forcing_std.sel(forcing_feature=vars_to_combine)
combined_mean = da_forcing_means.mean(dim="forcing_feature")
@@ -64,8 +66,12 @@ def create_normalization_stats_zarr(
da_forcing_mean.loc[
dict(forcing_feature=vars_to_combine)
] = combined_mean
- da_forcing_std.loc[dict(forcing_feature=vars_to_combine)] = combined_std
- print("Computing mean and std.-dev. for one-step differences...", flush=True)
+ da_forcing_std.loc[
+ dict(forcing_feature=vars_to_combine)
+ ] = combined_std
+ print(
+ "Computing mean and std.-dev. for one-step differences...", flush=True
+ )
state_data_normalized = (da_state - da_state_mean) / da_state_std
state_data_diff_normalized = state_data_normalized.diff(dim="time")
diff_mean, diff_std = compute_stats(state_data_diff_normalized)
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index ebcc65e8..3fc714db 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -194,7 +194,9 @@ def get_num_data_vars(self, category):
atmosphere_vars = self._config[category].get("atmosphere_vars", [])
levels = self._config[category].get("levels", [])
- surface_vars_count = len(surface_vars) if surface_vars is not None else 0
+ surface_vars_count = (
+ len(surface_vars) if surface_vars is not None else 0
+ )
atmosphere_vars_count = (
len(atmosphere_vars) if atmosphere_vars is not None else 0
)
@@ -298,7 +300,9 @@ def _filter_dimensions(self, dataset, transpose_array=True):
dataset = self._convert_dataset_to_dataarray(dataset)
if "time" in dataset.dims:
- dataset = dataset.transpose("time", "grid_index", "variable_name")
+ dataset = dataset.transpose(
+ "time", "grid_index", "variable_name"
+ )
else:
dataset = dataset.transpose("grid_index", "variable_name")
dataset_vars = (
@@ -331,9 +335,13 @@ def _reshape_grid_to_2d(self, dataset, grid_shape=None):
x_coords = np.arange(x_dim)
y_coords = np.arange(y_dim)
- multi_index = pd.MultiIndex.from_product([y_coords, x_coords], names=["y", "x"])
+ multi_index = pd.MultiIndex.from_product(
+ [y_coords, x_coords], names=["y", "x"]
+ )
- mindex_coords = xr.Coordinates.from_pandas_multiindex(multi_index, "grid")
+ mindex_coords = xr.Coordinates.from_pandas_multiindex(
+ multi_index, "grid"
+ )
dataset = dataset.drop_vars(["grid", "x", "y"], errors="ignore")
dataset = dataset.assign_coords(mindex_coords)
reshaped_data = dataset.unstack("grid")
@@ -363,7 +371,9 @@ def get_xy(self, category, stacked=True):
dataset = self.open_zarrs(category)
xs, ys = dataset.x.values, dataset.y.values
- assert xs.ndim == ys.ndim, "x and y coordinates must have the same dimensions."
+ assert (
+ xs.ndim == ys.ndim
+ ), "x and y coordinates must have the same dimensions."
if xs.ndim == 1:
x, y = np.meshgrid(xs, ys)
@@ -497,7 +507,9 @@ def _select_stats_by_category(self, combined_stats, category):
stats = stats.drop_vars(["forcing_mean", "forcing_std"])
return stats
elif category == "forcing":
- non_normalized_vars = self.utilities.normalization.non_normalized_vars
+ non_normalized_vars = (
+ self.utilities.normalization.non_normalized_vars
+ )
if non_normalized_vars is None:
non_normalized_vars = []
forcing_vars = self.vars_names(category)
@@ -546,7 +558,9 @@ def _extract_vars(self, category, ds=None):
ds_atmosphere = None
if atmoshere_vars is not None:
- ds_atmosphere = self._extract_atmosphere_vars(category=category, ds=ds)
+ ds_atmosphere = self._extract_atmosphere_vars(
+ category=category, ds=ds
+ )
if ds_surface and ds_atmosphere:
return xr.merge([ds_surface, ds_atmosphere])
@@ -569,8 +583,13 @@ def _extract_atmosphere_vars(self, category, ds):
"""
- if "level" not in list(ds.dims) and self._config[category]["atmosphere_vars"]:
- ds = self._rename_dataset_dims_and_vars(ds.attrs["category"], dataset=ds)
+ if (
+ "level" not in list(ds.dims)
+ and self._config[category]["atmosphere_vars"]
+ ):
+ ds = self._rename_dataset_dims_and_vars(
+ ds.attrs["category"], dataset=ds
+ )
data_arrays = [
ds[var].sel(level=level, drop=True).rename(f"{var}_{level}")
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index ff43a626..923983c2 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -226,7 +226,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# them separately
features = ["toa_downwelling_shortwave_flux", "column_water"]
das = [
- self._get_single_timeseries_dataarray(features=[feature], split=split)
+ self._get_single_timeseries_dataarray(
+ features=[feature], split=split
+ )
for feature in features
]
da = xr.concat(das, dim="feature")
@@ -238,9 +240,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# .chunk({"elapsed_forecast_duration": 1}) this time variable is turned
# into a dask array and so execution of the calculation is delayed
# until the feature values are actually used.
- da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk(
- {"elapsed_forecast_duration": 1}
- )
+ da_forecast_time = (
+ da.analysis_time + da.elapsed_forecast_duration
+ ).chunk({"elapsed_forecast_duration": 1})
da_datetime_forcing_features = self._calc_datetime_forcing_features(
da_time=da_forecast_time
)
@@ -261,7 +263,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
features=features, split=split
)
das.append(da)
- da = xr.concat(das, dim="feature").transpose("grid_index", "feature")
+ da = xr.concat(das, dim="feature").transpose(
+ "grid_index", "feature"
+ )
else:
raise NotImplementedError(category)
@@ -310,8 +314,12 @@ def _get_single_timeseries_dataarray(
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
- if member is not None and features != self.get_vars_names(category="state"):
- raise ValueError("Member can only be specified for the 'state' category")
+ if member is not None and features != self.get_vars_names(
+ category="state"
+ ):
+ raise ValueError(
+ "Member can only be specified for the 'state' category"
+ )
# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
@@ -394,7 +402,9 @@ def _get_single_timeseries_dataarray(
if features_vary_with_analysis_time:
filepaths = [
fp_samples
- / filename_format.format(analysis_time=analysis_time, **file_params)
+ / filename_format.format(
+ analysis_time=analysis_time, **file_params
+ )
for analysis_time in coords["analysis_time"]
]
else:
@@ -455,7 +465,9 @@ def _get_analysis_times(self, split) -> List[np.datetime64]:
times.append(name_parts["analysis_time"])
if len(times) == 0:
- raise ValueError(f"No files found in {sample_dir} with pattern {pattern}")
+ raise ValueError(
+ f"No files found in {sample_dir} with pattern {pattern}"
+ )
return times
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index 5ad0fdca..2dd0a8a2 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -57,7 +57,9 @@ def __init__(
edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0]
# Store number of receiver nodes according to edge_index
self.num_rec = edge_index[1].max() + 1
- edge_index[0] = edge_index[0] + self.num_rec # Make sender indices after rec
+ edge_index[0] = (
+ edge_index[0] + self.num_rec
+ ) # Make sender indices after rec
self.register_buffer("edge_index", edge_index, persistent=False)
# Create MLPs
diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py
index 324440a8..6beaa4f7 100644
--- a/neural_lam/metrics.py
+++ b/neural_lam/metrics.py
@@ -12,7 +12,9 @@ def get_metric(metric_name):
"""
metric_name_lower = metric_name.lower()
- assert metric_name_lower in DEFINED_METRICS, f"Unknown metric: {metric_name}"
+ assert (
+ metric_name_lower in DEFINED_METRICS
+ ), f"Unknown metric: {metric_name}"
return DEFINED_METRICS[metric_name_lower]
@@ -34,13 +36,19 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
"""
# Only keep grid nodes in mask
if mask is not None:
- metric_entry_vals = metric_entry_vals[..., mask, :] # (..., N', d_state)
+ metric_entry_vals = metric_entry_vals[
+ ..., mask, :
+ ] # (..., N', d_state)
# Optionally reduce last two dimensions
if average_grid: # Reduce grid first
- metric_entry_vals = torch.mean(metric_entry_vals, dim=-2) # (..., d_state)
+ metric_entry_vals = torch.mean(
+ metric_entry_vals, dim=-2
+ ) # (..., d_state)
if sum_vars: # Reduce vars second
- metric_entry_vals = torch.sum(metric_entry_vals, dim=-1) # (..., N) or (...,)
+ metric_entry_vals = torch.sum(
+ metric_entry_vals, dim=-1
+ ) # (..., N) or (...,)
return metric_entry_vals
@@ -95,7 +103,9 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
"""
# Replace pred_std with constant ones
- return wmse(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars)
+ return wmse(
+ pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars
+ )
def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
@@ -148,7 +158,9 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
"""
# Replace pred_std with constant ones
- return wmae(pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars)
+ return wmae(
+ pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars
+ )
def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
@@ -178,7 +190,9 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)
-def crps_gauss(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
+def crps_gauss(
+ pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
+):
"""(Negative) Continuous Ranked Probability Score (CRPS) Closed-form expression
based on Gaussian predictive distribution.
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index eadd9445..708cad54 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -23,7 +23,9 @@ class ARModel(pl.LightningModule):
# pylint: disable=arguments-differ
# Disable to override args/kwargs from superclass
- def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int):
+ def __init__(
+ self, args, datastore: BaseDatastore, forcing_window_size: int
+ ):
super().__init__()
self.save_hyperparameters(ignore=["datastore"])
self.args = args
@@ -32,13 +34,17 @@ def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int):
split = "train"
num_state_vars = datastore.get_num_data_vars(category="state")
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
- da_static_features = datastore.get_dataarray(category="static", split=split)
+ da_static_features = datastore.get_dataarray(
+ category="static", split=split
+ )
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask
# Load static features for grid/data, NB: self.predict_step assumes dimension
# order to be (grid_index, static_feature)
- arr_static = da_static_features.transpose("grid_index", "static_feature").values
+ arr_static = da_static_features.transpose(
+ "grid_index", "static_feature"
+ ).values
self.register_buffer(
"grid_static_features",
torch.tensor(arr_static, dtype=torch.float32),
@@ -131,7 +137,9 @@ def __init__(self, args, datastore: BaseDatastore, forcing_window_size: int):
self.spatial_loss_maps = []
def configure_optimizers(self):
- opt = torch.optim.AdamW(self.parameters(), lr=self.args.lr, betas=(0.9, 0.95))
+ opt = torch.optim.AdamW(
+ self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
+ )
return opt
@property
@@ -178,7 +186,8 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
# Overwrite border with true state
new_state = (
- self.boundary_mask * border_state + self.interior_mask * pred_state
+ self.boundary_mask * border_state
+ + self.interior_mask * pred_state
)
prediction_list.append(new_state)
@@ -223,7 +232,9 @@ def training_step(self, batch):
# Compute loss
batch_loss = torch.mean(
- self.loss(prediction, target, pred_std, mask=self.interior_mask_bool)
+ self.loss(
+ prediction, target, pred_std, mask=self.interior_mask_bool
+ )
) # mean over unrolled times and batch
log_dict = {"train_loss": batch_loss}
@@ -255,7 +266,9 @@ def validation_step(self, batch, batch_idx):
prediction, target, pred_std, _ = self.common_step(batch)
time_step_loss = torch.mean(
- self.loss(prediction, target, pred_std, mask=self.interior_mask_bool),
+ self.loss(
+ prediction, target, pred_std, mask=self.interior_mask_bool
+ ),
dim=0,
) # (time_steps-1)
mean_loss = torch.mean(time_step_loss)
@@ -303,7 +316,9 @@ def test_step(self, batch, batch_idx):
# pred_steps, num_grid_nodes, d_f) or (d_f,)
time_step_loss = torch.mean(
- self.loss(prediction, target, pred_std, mask=self.interior_mask_bool),
+ self.loss(
+ prediction, target, pred_std, mask=self.interior_mask_bool
+ ),
dim=0,
) # (time_steps-1,)
mean_loss = torch.mean(time_step_loss)
@@ -355,14 +370,19 @@ def test_step(self, batch, batch_idx):
# (B, N_log, num_grid_nodes)
# Plot example predictions (on rank 0 only)
- if self.trainer.is_global_zero and self.plotted_examples < self.n_example_pred:
+ if (
+ self.trainer.is_global_zero
+ and self.plotted_examples < self.n_example_pred
+ ):
# Need to plot more example predictions
n_additional_examples = min(
prediction.shape[0],
self.n_example_pred - self.plotted_examples,
)
- self.plot_examples(batch, n_additional_examples, prediction=prediction)
+ self.plot_examples(
+ batch, n_additional_examples, prediction=prediction
+ )
def plot_examples(self, batch, n_examples, prediction=None):
"""Plot the first n_examples forecasts from batch.
@@ -440,12 +460,16 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
}
)
- plt.close("all") # Close all figs for this time step, saves memory
+ plt.close(
+ "all"
+ ) # Close all figs for this time step, saves memory
# Save pred and target as .pt files
torch.save(
pred_slice.cpu(),
- os.path.join(wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"),
+ os.path.join(
+ wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
+ ),
)
torch.save(
target_slice.cpu(),
@@ -476,7 +500,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
if prefix == "test":
# Save pdf
- metric_fig.savefig(os.path.join(wandb.run.dir, f"{full_log_name}.pdf"))
+ metric_fig.savefig(
+ os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
+ )
# Save errors also as csv
np.savetxt(
os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
@@ -526,7 +552,9 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_rescaled = metric_tensor_averaged * self.state_std
# (pred_steps, d_f)
log_dict.update(
- self.create_metric_log_dict(metric_rescaled, prefix, metric_name)
+ self.create_metric_log_dict(
+ metric_rescaled, prefix, metric_name
+ )
)
if self.trainer.is_global_zero and not self.trainer.sanity_checking:
@@ -558,7 +586,9 @@ def on_test_epoch_end(self):
self.data_config,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
- for t_i, loss_map in zip(self.args.val_steps_to_log, mean_spatial_loss)
+ for t_i, loss_map in zip(
+ self.args.val_steps_to_log, mean_spatial_loss
+ )
]
# log all to same wandb key, sequentially
@@ -598,7 +628,9 @@ def on_load_checkpoint(self, checkpoint):
)
)
for old_key in replace_keys:
- new_key = old_key.replace("g2m_gnn.grid_mlp", "encoding_grid_mlp")
+ new_key = old_key.replace(
+ "g2m_gnn.grid_mlp", "encoding_grid_mlp"
+ )
loaded_state_dict[new_key] = loaded_state_dict[old_key]
del loaded_state_dict[old_key]
if not self.restore_opt:
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 158275dd..decddc7f 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -20,7 +20,9 @@ def __init__(self, args, datastore, forcing_window_size):
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
graph_dir_path = datastore.root_path / "graph" / args.graph
- self.hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
+ self.hierarchical, graph_ldict = utils.load_graph(
+ graph_dir_path=graph_dir_path
+ )
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
if isinstance(attr_value, torch.Tensor):
@@ -42,7 +44,9 @@ def __init__(self, args, datastore, forcing_window_size):
# Define sub-models
# Feature embedders for grid
self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
- self.grid_embedder = utils.make_mlp([self.grid_dim] + self.mlp_blueprint_end)
+ self.grid_embedder = utils.make_mlp(
+ [self.grid_dim] + self.mlp_blueprint_end
+ )
self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)
@@ -68,7 +72,8 @@ def __init__(self, args, datastore, forcing_window_size):
# Output mapping (hidden_dim -> output_dim)
self.output_map = utils.make_mlp(
- [args.hidden_dim] * (args.hidden_layers + 1) + [self.grid_output_dim],
+ [args.hidden_dim] * (args.hidden_layers + 1)
+ + [self.grid_output_dim],
layer_norm=False,
) # No layer norm on this one
@@ -142,7 +147,9 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
) # (B, num_grid_nodes, d_h)
# Map to output dimension, only for grid
- net_output = self.output_map(grid_rep) # (B, num_grid_nodes, d_grid_out)
+ net_output = self.output_map(
+ grid_rep
+ ) # (B, num_grid_nodes, d_grid_out)
if self.output_std:
pred_delta_mean, pred_std_raw = net_output.chunk(
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index 8bfc2c3e..d98af155 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -101,7 +101,9 @@ def get_num_mesh(self):
num_mesh_nodes = sum(
node_feat.shape[0] for node_feat in self.mesh_static_features
)
- num_mesh_nodes_ignore = num_mesh_nodes - self.mesh_static_features[0].shape[0]
+ num_mesh_nodes_ignore = (
+ num_mesh_nodes - self.mesh_static_features[0].shape[0]
+ )
return num_mesh_nodes, num_mesh_nodes_ignore
def embedd_mesh_nodes(self):
@@ -142,15 +144,21 @@ def process_step(self, mesh_rep):
# Embed edges, expand with batch dimension
mesh_same_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(self.mesh_same_embedders, self.m2m_features)
+ for emb, edge_feat in zip(
+ self.mesh_same_embedders, self.m2m_features
+ )
]
mesh_up_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(self.mesh_up_embedders, self.mesh_up_features)
+ for emb, edge_feat in zip(
+ self.mesh_up_embedders, self.mesh_up_features
+ )
]
mesh_down_rep = [
self.expand_to_batch(emb(edge_feat), batch_size)
- for emb, edge_feat in zip(self.mesh_down_embedders, self.mesh_down_features)
+ for emb, edge_feat in zip(
+ self.mesh_down_embedders, self.mesh_down_features
+ )
]
# - MESH INIT. -
@@ -160,14 +168,20 @@ def process_step(self, mesh_rep):
send_node_rep = mesh_rep_levels[
level_l - 1
] # (B, num_mesh_nodes[l-1], d_h)
- rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h)
+ rec_node_rep = mesh_rep_levels[
+ level_l
+ ] # (B, num_mesh_nodes[l], d_h)
edge_rep = mesh_up_rep[level_l - 1]
# Apply GNN
- new_node_rep, new_edge_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
+ new_node_rep, new_edge_rep = gnn(
+ send_node_rep, rec_node_rep, edge_rep
+ )
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h)
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h)
# - PROCESSOR -
@@ -184,14 +198,18 @@ def process_step(self, mesh_rep):
send_node_rep = mesh_rep_levels[
level_l + 1
] # (B, num_mesh_nodes[l+1], d_h)
- rec_node_rep = mesh_rep_levels[level_l] # (B, num_mesh_nodes[l], d_h)
+ rec_node_rep = mesh_rep_levels[
+ level_l
+ ] # (B, num_mesh_nodes[l], d_h)
edge_rep = mesh_down_rep[level_l]
# Apply GNN
new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep)
# Update node and edge vectors in lists
- mesh_rep_levels[level_l] = new_node_rep # (B, num_mesh_nodes[l], d_h)
+ mesh_rep_levels[
+ level_l
+ ] = new_node_rep # (B, num_mesh_nodes[l], d_h)
# Return only bottom level representation
return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h)
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index 55befd02..b995ecc9 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -19,7 +19,9 @@ class GraphLAM(BaseGraphModel):
def __init__(self, args, datastore, forcing_window_size):
super().__init__(args, datastore, forcing_window_size)
- assert not self.hierarchical, "GraphLAM does not use a hierarchical mesh graph"
+ assert (
+ not self.hierarchical
+ ), "GraphLAM does not use a hierarchical mesh graph"
# grid_dim from data + static + batch_static
mesh_dim = self.mesh_static_features.shape[1]
@@ -73,7 +75,11 @@ def process_step(self, mesh_rep):
# Embed m2m here first
batch_size = mesh_rep.shape[0]
m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h)
- m2m_emb_expanded = self.expand_to_batch(m2m_emb, batch_size) # (B, M_mesh, d_h)
+ m2m_emb_expanded = self.expand_to_batch(
+ m2m_emb, batch_size
+ ) # (B, M_mesh, d_h)
- mesh_rep, _ = self.processor(mesh_rep, m2m_emb_expanded) # (B, N_mesh, d_h)
+ mesh_rep, _ = self.processor(
+ mesh_rep, m2m_emb_expanded
+ ) # (B, N_mesh, d_h)
return mesh_rep
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 95300185..0b05b687 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -94,7 +94,9 @@ def mesh_down_step(
reversed(same_gnns[:-1]),
):
# Extract representations
- send_node_rep = mesh_rep_levels[level_l + 1] # (B, N_mesh[l+1], d_h)
+ send_node_rep = mesh_rep_levels[
+ level_l + 1
+ ] # (B, N_mesh[l+1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
down_edge_rep = mesh_down_rep[level_l]
same_edge_rep = mesh_same_rep[level_l]
@@ -128,7 +130,9 @@ def mesh_up_step(
zip(up_gnns, same_gnns[1:]), start=1
):
# Extract representations
- send_node_rep = mesh_rep_levels[level_l - 1] # (B, N_mesh[l-1], d_h)
+ send_node_rep = mesh_rep_levels[
+ level_l - 1
+ ] # (B, N_mesh[l-1], d_h)
rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h)
up_edge_rep = mesh_up_rep[level_l - 1]
same_edge_rep = mesh_same_rep[level_l]
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 26357281..8655746c 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -76,7 +76,9 @@ def hi_processor_step(
mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep)
# Split up again for read-out step
- mesh_rep_levels = list(torch.split(mesh_rep, self.level_mesh_sizes, dim=1))
+ mesh_rep_levels = list(
+ torch.split(mesh_rep, self.level_mesh_sizes, dim=1)
+ )
mesh_edge_rep_sections = torch.split(
mesh_edge_rep, self.edge_split_sections, dim=1
)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index e819c403..3e37ebff 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -38,7 +38,9 @@ def _init_datastore(datastore_kind, config_path):
def main(input_args=None):
"""Main function for training and evaluating models."""
- parser = ArgumentParser(description="Train or evaluate NeurWP models for LAM")
+ parser = ArgumentParser(
+ description="Train or evaluate NeurWP models for LAM"
+ )
parser.add_argument(
"datastore_kind",
type=str,
@@ -83,7 +85,8 @@ def main(input_args=None):
"--restore_opt",
type=int,
default=0,
- help="If optimizer state should be restored with model " "(default: 0 (false))",
+ help="If optimizer state should be restored with model "
+ "(default: 0 (false))",
)
parser.add_argument(
"--precision",
@@ -97,7 +100,8 @@ def main(input_args=None):
"--graph",
type=str,
default="multiscale",
- help="Graph to load and use in graph-based model " "(default: multiscale)",
+ help="Graph to load and use in graph-based model "
+ "(default: multiscale)",
)
parser.add_argument(
"--hidden_dim",
@@ -145,7 +149,8 @@ def main(input_args=None):
"--control_only",
type=int,
default=0,
- help="Train only on control member of ensemble data " "(default: 0 (False))",
+ help="Train only on control member of ensemble data "
+ "(default: 0 (False))",
)
parser.add_argument(
"--loss",
@@ -160,7 +165,8 @@ def main(input_args=None):
"--val_interval",
type=int,
default=1,
- help="Number of epochs training between each validation run " "(default: 1)",
+ help="Number of epochs training between each validation run "
+ "(default: 1)",
)
# Evaluation options
@@ -181,7 +187,8 @@ def main(input_args=None):
"--n_example_pred",
type=int,
default=1,
- help="Number of example predictions to plot during evaluation " "(default: 1)",
+ help="Number of example predictions to plot during evaluation "
+ "(default: 1)",
)
# Logger Settings
@@ -254,7 +261,9 @@ def main(input_args=None):
# Instantiate model + trainer
if torch.cuda.is_available():
device_name = "cuda"
- torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s
+ torch.set_float32_matmul_precision(
+ "high"
+ ) # Allows using Tensor Cores on A100s
else:
device_name = "cpu"
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 2ebe7b4d..7d2bd228 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -114,7 +114,9 @@ def loads_file(fn):
) # List of (N_mesh[l], d_mesh_static)
# Some checks for consistency
- assert len(m2m_features) == n_levels, "Inconsistent number of levels in mesh"
+ assert (
+ len(m2m_features) == n_levels
+ ), "Inconsistent number of levels in mesh"
assert (
len(mesh_static_features) == n_levels
), "Inconsistent number of levels in mesh"
@@ -137,15 +139,23 @@ def loads_file(fn):
# Rescale
mesh_up_features = BufferList(
- [edge_features / longest_edge for edge_features in mesh_up_features],
+ [
+ edge_features / longest_edge
+ for edge_features in mesh_up_features
+ ],
persistent=False,
)
mesh_down_features = BufferList(
- [edge_features / longest_edge for edge_features in mesh_down_features],
+ [
+ edge_features / longest_edge
+ for edge_features in mesh_down_features
+ ],
persistent=False,
)
- mesh_static_features = BufferList(mesh_static_features, persistent=False)
+ mesh_static_features = BufferList(
+ mesh_static_features, persistent=False
+ )
else:
# Extract single mesh level
m2m_edge_index = m2m_edge_index[0]
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index e5c970c4..d0b1a428 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -53,7 +53,9 @@ def plot_error_map(
ax.set_yticks(np.arange(d_f))
var_names = datastore.get_vars_names(category="state")
var_units = datastore.get_vars_units(category="state")
- y_ticklabels = [f"{name} ({unit})" for name, unit in zip(var_names, var_units)]
+ y_ticklabels = [
+ f"{name} ({unit})" for name, unit in zip(var_names, var_units)
+ ]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
if title:
@@ -88,7 +90,9 @@ def plot_prediction(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
- pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region
+ pixel_alpha = (
+ mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ ) # Faded border region
fig, axes = plt.subplots(
1,
@@ -101,7 +105,9 @@ def plot_prediction(
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(list(datastore.grid_shape_state.values.values())).cpu().numpy()
+ data.reshape(list(datastore.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
)
im = ax.imshow(
data_grid,
@@ -138,8 +144,12 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
extent = data_config.get_xy_extent("state")
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(list(data_config.grid_shape_state.values.values()))
- pixel_alpha = mask_reshaped.clamp(0.7, 1).cpu().numpy() # Faded border region
+ mask_reshaped = obs_mask.reshape(
+ list(data_config.grid_shape_state.values.values())
+ )
+ pixel_alpha = (
+ mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ ) # Faded border region
fig, ax = plt.subplots(
figsize=(5, 4.8),
@@ -148,7 +158,9 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(data_config.grid_shape_state.values.values())).cpu().numpy()
+ error.reshape(list(data_config.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
)
im = ax.imshow(
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index a8213922..9b567afc 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -32,7 +32,9 @@ def __init__(
self.ar_steps = ar_steps
self.datastore = datastore
- self.da_state = self.datastore.get_dataarray(category="state", split=self.split)
+ self.da_state = self.datastore.get_dataarray(
+ category="state", split=self.split
+ )
self.da_forcing = self.datastore.get_dataarray(
category="forcing", split=self.split
)
@@ -60,8 +62,10 @@ def __init__(
self.da_state_std = self.ds_state_stats.state_std
if self.da_forcing is not None:
- self.ds_forcing_stats = self.datastore.get_normalization_dataarray(
- category="forcing"
+ self.ds_forcing_stats = (
+ self.datastore.get_normalization_dataarray(
+ category="forcing"
+ )
)
self.da_forcing_mean = self.ds_forcing_stats.forcing_mean
self.da_forcing_std = self.ds_forcing_stats.forcing_std
@@ -139,7 +143,9 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
else:
# only `time` dimension for analysis only data
da = da.isel(
- time=slice(idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset)
+ time=slice(
+ idx + n_timesteps_offset, idx + n_steps + n_timesteps_offset
+ )
)
return da
@@ -191,7 +197,9 @@ def __getitem__(self, idx):
# handle time sampling in a way that is compatible with both analysis
# and forecast data
- da_state = self._sample_time(da=da_state, idx=idx, n_steps=2 + self.ar_steps)
+ da_state = self._sample_time(
+ da=da_state, idx=idx, n_steps=2 + self.ar_steps
+ )
if da_forcing is not None:
das_forcing = []
@@ -226,7 +234,9 @@ def __getitem__(self, idx):
batch_times = da_target_states.time.values.astype(float)
if self.standardize:
- da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std
+ da_init_states = (
+ da_init_states - self.da_state_mean
+ ) / self.da_state_std
da_target_states = (
da_target_states - self.da_state_mean
) / self.da_state_std
@@ -244,7 +254,9 @@ def __getitem__(self, idx):
)
init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
- target_states = torch.tensor(da_target_states.values, dtype=torch.float32)
+ target_states = torch.tensor(
+ da_target_states.values, dtype=torch.float32
+ )
if self.da_forcing is None:
# create an empty forcing tensor
@@ -253,7 +265,9 @@ def __getitem__(self, idx):
dtype=torch.float32,
)
else:
- forcing = torch.tensor(da_forcing_windowed.values, dtype=torch.float32)
+ forcing = torch.tensor(
+ da_forcing_windowed.values, dtype=torch.float32
+ )
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
diff --git a/plot_graph.py b/plot_graph.py
index e84bb627..b7b710bf 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -64,7 +64,9 @@ def main():
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
- grid_pos = np.concatenate((grid_pos, np.expand_dims(z_grid, axis=1)), axis=1)
+ grid_pos = np.concatenate(
+ (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
+ )
# List of edges to plot, (edge_index, color, line_width, label)
edge_plot_list = [
@@ -116,7 +118,9 @@ def main():
z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees
mesh_node_size = mesh_degrees / 2
- mesh_pos = np.concatenate((mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1)
+ mesh_pos = np.concatenate(
+ (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1
+ )
edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M"))
diff --git a/pyproject.toml b/pyproject.toml
index 1c86119c..49f6732f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,9 @@ dev = [
[tool.setuptools]
py-modules = ["neural_lam"]
+[tool.black]
+line-length = 80
+
[tool.isort]
default_section = "THIRDPARTY"
profile = "black"
@@ -67,7 +70,7 @@ known_first_party = [
]
[tool.flake8]
-max-line-length = 88
+max-line-length = 80
ignore = [
"E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373)
"I002", # Don't check for isort configuration
diff --git a/tests/conftest.py b/tests/conftest.py
index c8afc109..f0d1c2f5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -105,7 +105,8 @@ def bootstrap_multizarr_example():
    # here assume that the data-config is referring to the default path
# for the "datetime forcings" dataset
datetime_forcing_zarr_path = (
- data_config_path.parent / multizarr.create_datetime_forcings.DEFAULT_FILENAME
+ data_config_path.parent
+ / multizarr.create_datetime_forcings.DEFAULT_FILENAME
)
if not datetime_forcing_zarr_path.exists():
multizarr.create_datetime_forcings.create_datetime_forcing_zarr(
@@ -113,7 +114,8 @@ def bootstrap_multizarr_example():
)
normalized_forcing_zarr_path = (
- data_config_path.parent / multizarr.create_normalization_stats.DEFAULT_FILENAME
+ data_config_path.parent
+ / multizarr.create_normalization_stats.DEFAULT_FILENAME
)
if not normalized_forcing_zarr_path.exists():
multizarr.create_normalization_stats.create_normalization_stats_zarr(
@@ -121,7 +123,8 @@ def bootstrap_multizarr_example():
)
boundary_mask_path = (
- data_config_path.parent / multizarr.create_boundary_mask.DEFAULT_FILENAME
+ data_config_path.parent
+ / multizarr.create_boundary_mask.DEFAULT_FILENAME
)
if not boundary_mask_path.exists():
@@ -137,7 +140,9 @@ def bootstrap_multizarr_example():
DATASTORES_EXAMPLES = dict(
multizarr=dict(config_path=bootstrap_multizarr_example()),
mllam=dict(
- config_path=DATASTORE_EXAMPLES_ROOT_PATH / "mllam" / "danra.example.yaml"
+ config_path=DATASTORE_EXAMPLES_ROOT_PATH
+ / "mllam"
+ / "danra.example.yaml"
),
npyfiles=dict(config_path=download_meps_example_reduced_dataset()),
)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8ae9d917..06deeaa4 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -60,7 +60,8 @@ def test_dataset_item(datastore_name):
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert (
- forcing.shape[2] == datastore.get_num_data_vars("forcing") * forcing_window_size
+ forcing.shape[2]
+ == datastore.get_num_data_vars("forcing") * forcing_window_size
)
# batch times
@@ -83,7 +84,9 @@ def test_single_batch(datastore_name, split):
"""
datastore = init_datastore(datastore_name)
- device_name = torch.device("cuda") if torch.cuda.is_available() else "cpu" # noqa
+ device_name = (
+ torch.device("cuda") if torch.cuda.is_available() else "cpu"
+ ) # noqa
graph_name = "1level"
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 512bc5a0..0955b8cc 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -56,9 +56,9 @@ def test_config(datastore_name):
datastore = init_datastore(datastore_name)
# check the config is a mapping or a dataclass
config = datastore.config
- assert isinstance(config, collections.abc.Mapping) or dataclasses.is_dataclass(
- config
- )
+ assert isinstance(
+ config, collections.abc.Mapping
+ ) or dataclasses.is_dataclass(config)
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
index 652e3dce..3a36109f 100644
--- a/tests/test_graph_creation.py
+++ b/tests/test_graph_creation.py
@@ -81,7 +81,9 @@ def test_graph_creation(datastore_name, graph_name):
assert isinstance(result, torch.Tensor)
if file_id.endswith("_index"):
- assert result.shape[0] == 2 # adjacency matrix uses two rows
+ assert (
+ result.shape[0] == 2
+ ) # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert result.shape[1] == d_features
@@ -90,7 +92,9 @@ def test_graph_creation(datastore_name, graph_name):
if not hierarchical:
assert len(result) == 1
else:
- if file_id.startswith("mesh_up") or file_id.startswith("mesh_down"):
+ if file_id.startswith("mesh_up") or file_id.startswith(
+ "mesh_down"
+ ):
assert len(result) == n_max_levels - 1
else:
assert len(result) == n_max_levels
diff --git a/tests/test_training.py b/tests/test_training.py
index 19d48e3a..33dd8203 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,7 +20,9 @@ def test_training(datastore_name):
if torch.cuda.is_available():
device_name = "cuda"
- torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s
+ torch.set_float32_matmul_precision(
+ "high"
+ ) # Allows using Tensor Cores on A100s
else:
device_name = "cpu"
From 46b37f85624832186887399fb96ecc9c7d8b7087 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 13:09:01 +0200
Subject: [PATCH 180/273] revert docstring formatting changes
---
neural_lam/interaction_net.py | 36 ++++++------
neural_lam/metrics.py | 34 +++++------
neural_lam/models/ar_model.py | 72 ++++++++++++++----------
neural_lam/models/base_graph_model.py | 23 +++++---
neural_lam/models/base_hi_graph_model.py | 36 ++++++------
neural_lam/models/graph_lam.py | 28 +++++----
neural_lam/models/hi_lam.py | 43 ++++++++------
neural_lam/models/hi_lam_parallel.py | 15 ++---
neural_lam/utils.py | 16 ++++--
neural_lam/vis.py | 8 ++-
neural_lam/weather_dataset.py | 55 ++++++++----------
11 files changed, 201 insertions(+), 165 deletions(-)
diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py
index 2dd0a8a2..2f45b03f 100644
--- a/neural_lam/interaction_net.py
+++ b/neural_lam/interaction_net.py
@@ -8,10 +8,9 @@
class InteractionNet(pyg.nn.MessagePassing):
- """Implementation of a generic Interaction Network, from Battaglia et al.
-
- (2016)
-
+ """
+ Implementation of a generic Interaction Network,
+ from Battaglia et al. (2016)
"""
# pylint: disable=arguments-differ
@@ -28,7 +27,8 @@ def __init__(
aggr_chunk_sizes=None,
aggr="sum",
):
- """Create a new InteractionNet.
+ """
+ Create a new InteractionNet
edge_index: (2,M), Edges in pyg format
input_dim: Dimensionality of input representations,
@@ -44,7 +44,6 @@ def __init__(
representation into and use separate MLPs for
(None = no chunking, same MLP)
aggr: Message aggregation method (sum/mean)
-
"""
assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}"
super().__init__(aggr=aggr)
@@ -85,8 +84,9 @@ def __init__(
self.update_edges = update_edges
def forward(self, send_rep, rec_rep, edge_rep):
- """Apply interaction network to update the representations of receiver nodes,
- and optionally the edge representations.
+ """
+ Apply interaction network to update the representations of receiver
+ nodes, and optionally the edge representations.
send_rep: (N_send, d_h), vector representations of sender nodes
rec_rep: (N_rec, d_h), vector representations of receiver nodes
@@ -96,7 +96,6 @@ def forward(self, send_rep, rec_rep, edge_rep):
rec_rep: (N_rec, d_h), updated vector representations of receiver nodes
(optionally) edge_rep: (M, d_h), updated vector representations
of edges
-
"""
# Always concatenate to [rec_nodes, send_nodes] for propagation,
# but only aggregate to rec_nodes
@@ -116,7 +115,9 @@ def forward(self, send_rep, rec_rep, edge_rep):
return rec_rep
def message(self, x_j, x_i, edge_attr):
- """Compute messages from node j to node i."""
+ """
+ Compute messages from node j to node i.
+ """
return self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1))
# pylint: disable-next=signature-differs
@@ -131,13 +132,10 @@ def aggregate(self, inputs, index, ptr, dim_size):
class SplitMLPs(nn.Module):
- """Module that feeds chunks of input through different MLPs.
-
- Split up input along dim
- -2 using given chunk sizes
- and feeds each chunk
- through separate MLPs.
-
+ """
+ Module that feeds chunks of input through different MLPs.
+ Split up input along dim -2 using given chunk sizes and feeds
+ each chunk through separate MLPs.
"""
def __init__(self, mlps, chunk_sizes):
@@ -150,13 +148,13 @@ def __init__(self, mlps, chunk_sizes):
self.chunk_sizes = chunk_sizes
def forward(self, x):
- """Chunk up input and feed through MLPs.
+ """
+ Chunk up input and feed through MLPs
x: (..., N, d), where N = sum(chunk_sizes)
Returns:
joined_output: (..., N, d), concatenated results from the MLPs
-
"""
chunks = torch.split(x, self.chunk_sizes, dim=-2)
chunk_outputs = [
diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py
index 6beaa4f7..7db2cca6 100644
--- a/neural_lam/metrics.py
+++ b/neural_lam/metrics.py
@@ -3,13 +3,13 @@
def get_metric(metric_name):
- """Get a defined metric with given name.
+ """
+ Get a defined metric with given name
metric_name: str, name of the metric
Returns:
metric: function implementing the metric
-
"""
metric_name_lower = metric_name.lower()
assert (
@@ -19,7 +19,8 @@ def get_metric(metric_name):
def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
- """Masks and (optionally) reduces entry-wise metric values.
+ """
+ Masks and (optionally) reduces entry-wise metric values
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -32,7 +33,6 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
# Only keep grid nodes in mask
if mask is not None:
@@ -54,7 +54,8 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """Weighted Mean Squared Error.
+ """
+ Weighted Mean Squared Error
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -69,7 +70,6 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
entry_mse = torch.nn.functional.mse_loss(
pred, target, reduction="none"
@@ -85,7 +85,8 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """(Unweighted) Mean Squared Error.
+ """
+ (Unweighted) Mean Squared Error
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -100,7 +101,6 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
# Replace pred_std with constant ones
return wmse(
@@ -109,7 +109,8 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """Weighted Mean Absolute Error.
+ """
+ Weighted Mean Absolute Error
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -124,7 +125,6 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
entry_mae = torch.nn.functional.l1_loss(
pred, target, reduction="none"
@@ -140,7 +140,8 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """(Unweighted) Mean Absolute Error.
+ """
+ (Unweighted) Mean Absolute Error
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -155,7 +156,6 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
# Replace pred_std with constant ones
return wmae(
@@ -164,7 +164,8 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
- """Negative Log Likelihood loss, for isotropic Gaussian likelihood.
+ """
+ Negative Log Likelihood loss, for isotropic Gaussian likelihood
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -179,7 +180,6 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
# Broadcast pred_std if shaped (d_state,), done internally in Normal class
dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state)
@@ -193,8 +193,9 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def crps_gauss(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
):
- """(Negative) Continuous Ranked Probability Score (CRPS) Closed-form expression
- based on Gaussian predictive distribution.
+ """
+ (Negative) Continuous Ranked Probability Score (CRPS)
+ Closed-form expression based on Gaussian predictive distribution
(...,) is any number of batch dimensions, potentially different
but broadcastable
@@ -209,7 +210,6 @@ def crps_gauss(
Returns:
metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
depending on reduction arguments.
-
"""
std_normal = torch.distributions.Normal(
torch.zeros((), device=pred.device), torch.ones((), device=pred.device)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 708cad54..a0a7880c 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -14,10 +14,9 @@
class ARModel(pl.LightningModule):
- """Generic auto-regressive weather model.
-
+ """
+ Generic auto-regressive weather model.
Abstract class that can be extended.
-
"""
# pylint: disable=arguments-differ
@@ -40,8 +39,8 @@ def __init__(
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask
- # Load static features for grid/data, NB: self.predict_step assumes dimension
- # order to be (grid_index, static_feature)
+ # Load static features for grid/data, NB: self.predict_step assumes
+ # dimension order to be (grid_index, static_feature)
arr_static = da_static_features.transpose(
"grid_index", "static_feature"
).values
@@ -144,12 +143,16 @@ def configure_optimizers(self):
@property
def interior_mask_bool(self):
- """Get the interior mask as a boolean (N,) mask."""
+ """
+ Get the interior mask as a boolean (N,) mask.
+ """
return self.interior_mask[:, 0].to(torch.bool)
@staticmethod
def expand_to_batch(x, batch_size):
- """Expand tensor with initial batch dimension."""
+ """
+ Expand tensor with initial batch dimension
+ """
return x.unsqueeze(0).expand(batch_size, -1, -1)
def predict_step(self, prev_state, prev_prev_state, forcing):
@@ -211,11 +214,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
return prediction, pred_std
def common_step(self, batch):
- """Predict on single batch batch consists of: init_states: (B, 2,
+ """
+ Predict on single batch batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
- num_grid_nodes, d_forcing), where index 0 corresponds to index 1 of
- init_states."""
+ num_grid_nodes, d_forcing),
+ where index 0 corresponds to index 1 of init_states
+ """
(init_states, target_states, forcing_features, batch_times) = batch
prediction, pred_std = self.unroll_prediction(
@@ -227,7 +232,9 @@ def common_step(self, batch):
return prediction, target_states, pred_std, batch_times
def training_step(self, batch):
- """Train on single batch."""
+ """
+ Train on single batch
+ """
prediction, target, pred_std, _ = self.common_step(batch)
# Compute loss
@@ -249,20 +256,22 @@ def training_step(self, batch):
return batch_loss
def all_gather_cat(self, tensor_to_gather):
- """Gather tensors across all ranks, and concatenate across dim. 0 (instead of
- stacking in new dim. 0)
+ """
+ Gather tensors across all ranks, and concatenate across dim. 0 (instead
+ of stacking in new dim. 0)
tensor_to_gather: (d1, d2, ...), distributed over K ranks
returns: (K*d1, d2, ...)
-
"""
return self.all_gather(tensor_to_gather).flatten(0, 1)
    # newer lightning versions require batch_idx argument, even if unused
# pylint: disable-next=unused-argument
def validation_step(self, batch, batch_idx):
- """Run validation on single batch."""
+ """
+ Run validation on single batch
+ """
prediction, target, pred_std, _ = self.common_step(batch)
time_step_loss = torch.mean(
@@ -299,7 +308,9 @@ def validation_step(self, batch, batch_idx):
self.val_metrics["mse"].append(entry_mses)
def on_validation_epoch_end(self):
- """Compute val metrics at the end of val epoch."""
+ """
+ Compute val metrics at the end of val epoch
+ """
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.val_metrics, prefix="val")
@@ -309,7 +320,9 @@ def on_validation_epoch_end(self):
# pylint: disable-next=unused-argument
def test_step(self, batch, batch_idx):
- """Run test on single batch."""
+ """
+ Run test on single batch
+ """
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
@@ -385,13 +398,13 @@ def test_step(self, batch, batch_idx):
)
def plot_examples(self, batch, n_examples, prediction=None):
- """Plot the first n_examples forecasts from batch.
+ """
+ Plot the first n_examples forecasts from batch
batch: batch with data to plot corresponding forecasts for n_examples:
number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes,
d_f), existing prediction.
Generate if None.
-
"""
if prediction is None:
prediction, target, _, _ = self.common_step(batch)
@@ -479,15 +492,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
)
def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
- """Put together a dict with everything to log for one metric. Also saves plots
- as pdf and csv if using test prefix.
+ """
+ Put together a dict with everything to log for one metric. Also saves
+ plots as pdf and csv if using test prefix.
metric_tensor: (pred_steps, d_f), metric values per time and variable
prefix: string, prefix to use for logging metric_name: string, name of
the metric
Return: log_dict: dict with everything to log for given metric
-
"""
log_dict = {}
metric_fig = vis.plot_error_map(
@@ -526,12 +539,12 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
return log_dict
def aggregate_and_plot_metrics(self, metrics_dict, prefix):
- """Aggregate and create error map plots for all metrics in metrics_dict.
+ """
+ Aggregate and create error map plots for all metrics in metrics_dict
metrics_dict: dictionary with metric_names and list of tensors
with step-evals.
prefix: string, prefix to use for logging
-
"""
log_dict = {}
for metric_name, metric_val_list in metrics_dict.items():
@@ -562,10 +575,9 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
plt.close("all") # Close all figs
def on_test_epoch_end(self):
- """Compute test metrics and make plots at the end of test epoch.
-
- Will gather stored tensors and perform plotting and logging on rank 0.
-
+ """
+ Compute test metrics and make plots at the end of test epoch. Will
+ gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
@@ -615,7 +627,9 @@ def on_test_epoch_end(self):
self.spatial_loss_maps.clear()
def on_load_checkpoint(self, checkpoint):
- """Perform any changes to state dict before loading checkpoint."""
+ """
+ Perform any changes to state dict before loading checkpoint
+ """
loaded_state_dict = checkpoint["state_dict"]
# Fix for loading older models after InteractionNet refactoring, where
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index decddc7f..16897e4f 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -8,8 +8,10 @@
class BaseGraphModel(ARModel):
- """Base (abstract) class for graph-based models building on the encode- process-
- decode idea."""
+ """
+ Base (abstract) class for graph-based models building on
+ the encode-process-decode idea.
+ """
def __init__(self, args, datastore, forcing_window_size):
super().__init__(
@@ -78,21 +80,26 @@ def __init__(self, args, datastore, forcing_window_size):
) # No layer norm on this one
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of mesh nodes
- that should be ignored in encoding/decoding."""
+ """
+ Compute number of mesh nodes from loaded features,
+ and number of mesh nodes that should be ignored in encoding/decoding
+ """
raise NotImplementedError("get_num_mesh not implemented")
def embedd_mesh_nodes(self):
- """Embed static mesh features Returns tensor of shape (num_mesh_nodes, d_h)"""
+ """
+ Embed static mesh features
+ Returns tensor of shape (num_mesh_nodes, d_h)
+ """
raise NotImplementedError("embedd_mesh_nodes not implemented")
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the representation
- on the mesh, possible in multiple steps.
+ """
+ Process step of encode-process-decode framework
+ Processes the representation on the mesh, possibly in multiple steps
mesh_rep: has shape (B, num_mesh_nodes, d_h)
Returns mesh_rep: (B, num_mesh_nodes, d_h)
-
"""
raise NotImplementedError("process_step not implemented")
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index d98af155..a2ebcc1b 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -8,7 +8,9 @@
class BaseHiGraphModel(BaseGraphModel):
- """Base class for hierarchical graph models."""
+ """
+ Base class for hierarchical graph models.
+ """
def __init__(self, args):
super().__init__(args)
@@ -96,8 +98,10 @@ def __init__(self, args):
)
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of mesh nodes
- that should be ignored in encoding/decoding."""
+ """
+ Compute number of mesh nodes from loaded features,
+ and number of mesh nodes that should be ignored in encoding/decoding
+ """
num_mesh_nodes = sum(
node_feat.shape[0] for node_feat in self.mesh_static_features
)
@@ -107,25 +111,21 @@ def get_num_mesh(self):
return num_mesh_nodes, num_mesh_nodes_ignore
def embedd_mesh_nodes(self):
- """Embed static mesh
- features This embeds
- only bottom level,
- rest is done at
- beginning of
+ """
+ Embed static mesh features
+ This embeds only the bottom level; the rest is done at the beginning of the
processing step
- Returns tensor of
- shape
- (num_mesh_nodes[0],
- d_h)"""
+ Returns tensor of shape (num_mesh_nodes[0], d_h)
+ """
return self.mesh_embedders[0](self.mesh_static_features[0])
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the representation
- on the mesh, possible in multiple steps.
+ """
+ Process step of encode-process-decode framework
+ Processes the representation on the mesh, possibly in multiple steps
mesh_rep: has shape (B, num_mesh_nodes, d_h)
Returns mesh_rep: (B, num_mesh_nodes, d_h)
-
"""
batch_size = mesh_rep.shape[0]
@@ -217,8 +217,9 @@ def process_step(self, mesh_rep):
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh init and
- read out.
+ """
+ Internal processor step of hierarchical graph models.
+ Between mesh init and read out.
Each input is list with representations, each with shape
@@ -228,6 +229,5 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
-
"""
raise NotImplementedError("hi_process_step not implemented")
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index b995ecc9..a4c726b1 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -8,12 +8,11 @@
class GraphLAM(BaseGraphModel):
- """Full graph-based LAM model that can be used with different (non- hierarchical
- )graphs.
-
- Mainly based on GraphCast, but the model from Keisler (2022) is almost identical.
- Used for GC-LAM and L1-LAM in Oskarsson et al. (2023).
-
+ """
+ Full graph-based LAM model that can be used with different
+ (non-hierarchical) graphs. Mainly based on GraphCast, but the model from
+ Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in
+ Oskarsson et al. (2023).
"""
def __init__(self, args, datastore, forcing_window_size):
@@ -56,21 +55,26 @@ def __init__(self, args, datastore, forcing_window_size):
)
def get_num_mesh(self):
- """Compute number of mesh nodes from loaded features, and number of mesh nodes
- that should be ignored in encoding/decoding."""
+ """
+ Compute number of mesh nodes from loaded features,
+ and number of mesh nodes that should be ignored in encoding/decoding
+ """
return self.mesh_static_features.shape[0], 0
def embedd_mesh_nodes(self):
- """Embed static mesh features Returns tensor of shape (N_mesh, d_h)"""
+ """
+ Embed static mesh features
+ Returns tensor of shape (N_mesh, d_h)
+ """
return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h)
def process_step(self, mesh_rep):
- """Process step of embedd-process-decode framework Processes the representation
- on the mesh, possible in multiple steps.
+ """
+ Process step of encode-process-decode framework
+ Processes the representation on the mesh, possibly in multiple steps
mesh_rep: has shape (B, N_mesh, d_h)
Returns mesh_rep: (B, N_mesh, d_h)
-
"""
# Embed m2m here first
batch_size = mesh_rep.shape[0]
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 0b05b687..3d6905c7 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -1,17 +1,16 @@
# Third-party
from torch import nn
-# Local
-from ..interaction_net import InteractionNet
-from .base_hi_graph_model import BaseHiGraphModel
+# First-party
+from neural_lam.interaction_net import InteractionNet
+from neural_lam.models.base_hi_graph_model import BaseHiGraphModel
class HiLAM(BaseHiGraphModel):
- """Hierarchical graph model with message passing that goes sequentially down and up
- the hierarchy during processing.
-
+ """
+ Hierarchical graph model with message passing that goes sequentially down
+ and up the hierarchy during processing.
The Hi-LAM model from Oskarsson et al. (2023)
-
"""
def __init__(self, args):
@@ -34,7 +33,9 @@ def __init__(self, args):
) # Nested lists (proc_steps, num_levels)
def make_same_gnns(self, args):
- """Make intra-level GNNs."""
+ """
+ Make intra-level GNNs.
+ """
return nn.ModuleList(
[
InteractionNet(
@@ -47,7 +48,9 @@ def make_same_gnns(self, args):
)
def make_up_gnns(self, args):
- """Make GNNs for processing steps up through the hierarchy."""
+ """
+ Make GNNs for processing steps up through the hierarchy.
+ """
return nn.ModuleList(
[
InteractionNet(
@@ -60,7 +63,9 @@ def make_up_gnns(self, args):
)
def make_down_gnns(self, args):
- """Make GNNs for processing steps down through the hierarchy."""
+ """
+ Make GNNs for processing steps down through the hierarchy.
+ """
return nn.ModuleList(
[
InteractionNet(
@@ -80,8 +85,10 @@ def mesh_down_step(
down_gnns,
same_gnns,
):
- """Run down-part of vertical processing, sequentially alternating between
- processing using down edges and same-level edges."""
+ """
+ Run down-part of vertical processing, sequentially alternating between
+ processing using down edges and same-level edges.
+ """
# Run same level processing on level L
mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1](
mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1]
@@ -117,8 +124,10 @@ def mesh_down_step(
def mesh_up_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, same_gnns
):
- """Run up-part of vertical processing, sequentially alternating between
- processing using up edges and same-level edges."""
+ """
+ Run up-part of vertical processing, sequentially alternating between
+ processing using up edges and same-level edges.
+ """
# Run same level processing on level 0
mesh_rep_levels[0], mesh_same_rep[0] = same_gnns[0](
@@ -154,8 +163,9 @@ def mesh_up_step(
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh init and
- read out.
+ """
+ Internal processor step of hierarchical graph models.
+ Between mesh init and read out.
Each input is list with representations, each with shape
@@ -165,7 +175,6 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
-
"""
for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip(
self.mesh_down_gnns,
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 8655746c..80181ec0 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -8,11 +8,12 @@
class HiLAMParallel(BaseHiGraphModel):
- """Version of HiLAM where all message passing in the hierarchical mesh (up, down,
- inter-level) is ran in parallel.
-
- This is a somewhat simpler alternative to the sequential message passing of Hi-LAM.
+ """
+ Version of HiLAM where all message passing in the hierarchical mesh (up,
+ down, inter-level) is run in parallel.
+ This is a somewhat simpler alternative to the sequential message passing
+ of Hi-LAM.
"""
def __init__(self, args):
@@ -52,8 +53,9 @@ def __init__(self, args):
def hi_processor_step(
self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
):
- """Internal processor step of hierarchical graph models. Between mesh init and
- read out.
+ """
+ Internal processor step of hierarchical graph models.
+ Between mesh init and read out.
Each input is list with representations, each with shape
@@ -63,7 +65,6 @@ def hi_processor_step(
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns same lists
-
"""
# First join all node and edge representations to single tensors
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 7d2bd228..0b4c39a4 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -9,12 +9,12 @@
class BufferList(nn.Module):
- """A list of torch buffer tensors that sit together as a Module with no parameters
- and only buffers.
+ """
+ A list of torch buffer tensors that sit together as a Module with no
+ parameters and only buffers.
This should be replaced by a native torch BufferList once implemented.
See: https://github.com/pytorch/pytorch/issues/37386
-
"""
def __init__(self, buffer_tensors, persistent=True):
@@ -211,8 +211,10 @@ def make_mlp(blueprint, layer_norm=True):
def fractional_plot_bundle(fraction):
- """Get the tueplots bundle, but with figure width as a fraction of the page
- width."""
+ """
+ Get the tueplots bundle, but with figure width as a fraction of
+ the page width.
+ """
# If latex is not available, some visualizations might not render
# correctly, but will at least not raise an error. Alternatively, use
# unicode raised numbers.
@@ -228,7 +230,9 @@ def fractional_plot_bundle(fraction):
def init_wandb_metrics(wandb_logger, val_steps):
- """Set up wandb metrics to track."""
+ """
+ Set up wandb metrics to track
+ """
experiment = wandb_logger.experiment
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index d0b1a428..542b6ab7 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -73,7 +73,8 @@ def plot_prediction(
title=None,
vrange=None,
):
- """Plot example prediction and grond truth.
+ """
+ Plot example prediction and ground truth.
Each has shape (N_grid,)
@@ -133,7 +134,10 @@ def plot_prediction(
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
- """Plot errors over spatial map Error and obs_mask has shape (N_grid,)"""
+ """
+ Plot errors over spatial map
+ Error and obs_mask have shape (N_grid,)
+ """
# Get common scale for values
if vrange is None:
vmin = error.min().cpu().item()
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 9b567afc..449021be 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -44,10 +44,11 @@ def __init__(
# non-zero amount of samples
if self.__len__() <= 0:
raise ValueError(
- f"The provided datastore only provides {len(self.da_state.time)} "
- f"time steps for `{split}` split, which is less than the "
- f"required 2+ar_steps (2+{self.ar_steps}={2+self.ar_steps}) "
- "for creating a sample with initial and target states."
+ "The provided datastore only provides "
+ f"{len(self.da_state.time)} time steps for `{split}` split, "
+ f"which is less than the required 2+ar_steps "
+ f"(2+{self.ar_steps}={2+self.ar_steps}) for creating a sample "
+ "with initial and target states."
)
# Set up for standardization
@@ -81,34 +82,25 @@ def __len__(self):
f"({self.da_state.ensemble_member.size})",
UserWarning,
)
- # XXX: we should maybe check that the 2+ar_steps actually fits
- # in the elapsed_forecast_duration dimension, should that be checked here?
+ # XXX: we should maybe check that the 2+ar_steps actually fits in
+ # the elapsed_forecast_duration dimension, should that be checked
+ # here?
return self.da_state.analysis_time.size
else:
- # sample_len = 2 + ar_steps (2 initial states + ar_steps target states)
+ # sample_len = 2 + ar_steps
+ # (2 initial states + ar_steps target states)
# n_samples = len(self.da_state.time) - sample_len + 1
# = len(self.da_state.time) - 2 - ar_steps + 1
# = len(self.da_state.time) - ar_steps - 1
return len(self.da_state.time) - self.ar_steps - 1
def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
- """Produce a time
- slice of the given
- dataarray `da` (state
- or forcing) starting
- at `idx` and with
- `n_steps` steps. The
- `n_timesteps_offset`
- parameter is used to
- offset the start of
- the sample, for
- example to exclude the
- first two steps when
- sampling the forcing
- data (and to produce
- the windowing samples
- of forcing data by
- increasing the offset
+ """
+ Produce a time slice of the given dataarray `da` (state or forcing)
+ starting at `idx` and with `n_steps` steps. The `n_timesteps_offset`
+ parameter is used to offset the start of the sample, for example to
+ exclude the first two steps when sampling the forcing data (and to
+ produce the windowing samples of forcing data by increasing the offset
for each window).
Parameters
@@ -150,8 +142,9 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
return da
def __getitem__(self, idx):
- """Return a single training sample, which consists of the initial states, target
- states, forcing and batch times.
+ """
+ Return a single training sample, which consists of the initial states,
+ target states, forcing and batch times.
The implementation currently uses xarray.DataArray objects for the
normalisation so that we can make use of xarray's broadcasting
@@ -176,7 +169,8 @@ def __getitem__(self, idx):
"""
# handling ensemble data
if self.datastore.is_ensemble:
- # for the now the strategy is to simply select a random ensemble member
+ # for now the strategy is to simply select a random ensemble
+ # member
# XXX: this could be changed to include all ensemble members by
# splitting `idx` into two parts, one for the analysis time and one
# for the ensemble member and then increasing self.__len__ to
@@ -277,10 +271,11 @@ def __getitem__(self, idx):
return init_states, target_states, forcing, batch_times
def __iter__(self):
- """Convenience method to iterate over the dataset.
+ """
+ Convenience method to iterate over the dataset.
- This isn't used by pytorch DataLoader which itself implements an iterator that
- uses Dataset.__getitem__ and Dataset.__len__.
+ This isn't used by pytorch DataLoader which itself implements an
+ iterator that uses Dataset.__getitem__ and Dataset.__len__.
"""
for i in range(len(self)):
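As a usage sketch of the iteration convenience described above (class locations, constructor signatures and the config path are assumptions for illustration, not taken from these patches):

# Hypothetical usage sketch; adjust names and arguments to the actual code.
from neural_lam.datastore.mllam import MLLAMDatastore
from neural_lam.weather_dataset import WeatherDataset

datastore = MLLAMDatastore(config_path="path/to/data_config.yaml")  # made-up path
dataset = WeatherDataset(datastore=datastore, split="train", ar_steps=3)

# Equivalent to indexing with range(len(dataset)), via __len__/__getitem__
for init_states, target_states, forcing, batch_times in dataset:
    print(init_states.shape, target_states.shape, forcing.shape)
    break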
From 3cd0f8b13aba933481d9159d709c46a2d983f49e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 13:13:36 +0200
Subject: [PATCH 181/273] pin numpy to <2.0.0
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index d66c0087..4770c19f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [
# PEP 621 project metadata
# See https://www.python.org/dev/peps/pep-0621/
dependencies = [
- "numpy>=1.24.2",
+ "numpy<2.0.0,>=1.24.2",
"wandb>=0.13.10",
"scipy>=1.10.0",
"pytorch-lightning>=2.0.3",
From 1f661c6e008d69b7c0daf0fff8a2135b62b0848d Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 20 Aug 2024 15:45:20 +0200
Subject: [PATCH 182/273] fix flake8 linting errors
---
neural_lam/datastore/base.py | 130 +++++++-----------
neural_lam/datastore/mllam.py | 53 +++----
.../multizarr/create_normalization_stats.py | 5 +-
neural_lam/datastore/multizarr/store.py | 53 ++++---
neural_lam/datastore/npyfiles/config.py | 16 ++-
neural_lam/datastore/npyfiles/store.py | 111 ++++++++-------
neural_lam/weather_dataset.py | 4 +-
tests/conftest.py | 5 +-
tests/test_datastores.py | 27 ++--
9 files changed, 202 insertions(+), 202 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index bbbb62dd..e046bc02 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -12,9 +12,11 @@
class BaseDatastore(abc.ABC):
- """Base class for weather data used in the neural- lam package. A datastore defines
- the interface for accessing weather data by providing methods to access the data in
- a processed format that can be used for training and evaluation of neural networks.
+ """
+ Base class for weather data used in the neural-lam package. A datastore
+ defines the interface for accessing weather data by providing methods to
+ access the data in a processed format that can be used for training and
+ evaluation of neural networks.
NOTE: All methods return either primitive types, `numpy.ndarray`,
`xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
@@ -25,8 +27,8 @@ class BaseDatastore(abc.ABC):
# Forecast vs analysis data
If the datastore is used to represent forecast rather than analysis data, then
the `is_forecast` attribute should be set to True, and returned data from
- `get_dataarray` is assumed to have `analysis_time` and `forecast_time` dimensions
- (rather than just `time`).
+ `get_dataarray` is assumed to have `analysis_time` and `forecast_time`
+ dimensions (rather than just `time`).
# Ensemble vs deterministic data
If the datastore is used to represent ensemble data, then the `is_ensemble`
@@ -41,8 +43,9 @@ class BaseDatastore(abc.ABC):
@property
@abc.abstractmethod
def root_path(self) -> Path:
- """The root path to the datastore. It is relative to this that any derived files
- (for example the graph components) are stored.
+ """
+ The root path to the datastore. It is relative to this that any derived
+ files (for example the graph components) are stored.
Returns
-------
@@ -60,7 +63,8 @@ def config(self) -> collections.abc.Mapping:
Returns
-------
collections.abc.Mapping
- The configuration of the datastore, any dict like object can be returned.
+ The configuration of the datastore, any dict like object can be
+ returned.
"""
pass
@@ -129,33 +133,15 @@ def get_num_data_vars(self, category: str) -> int:
@abc.abstractmethod
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the
- normalization
- dataarray for the
- given category. This
- should contain a
- `{category}_mean` and
- `{category}_std`
- variable for each
- variable in the
- category. For
- `category=="state"`,
- the dataarray should
- also contain a
- `state_diff_mean` and
- `state_diff_std`
- variable for the one-
- step differences of
- the state variables.
- The returned dataarray
- should at least have
- dimensions of `({categ
- ory}_feature)`, but
- can also include for
- example `grid_index`
- (if the normalisation
- is done per grid point
- for example).
+ """
+ Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each
+ variable in the category. For `category=="state"`, the dataarray should
+ also contain a `state_diff_mean` and `state_diff_std` variable for the
+ one-step differences of the state variables. The returned dataarray
+ should at least have dimensions of `({category}_feature)`, but can
+ also include for example `grid_index` (if the normalisation is done per
+ grid point for example).
Parameters
----------
@@ -176,23 +162,12 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
def get_dataarray(
self, category: str, split: str
) -> Union[xr.DataArray, None]:
- """Return the
- processed data (as a
- single `xr.DataArray`)
- for the given category
- of data and
- test/train/val-split
- that covers all the
- data (in space and
- time) of a given
- category (state/forcin
- g/static). A datastore
- must be able to return
- for the "state"
- category, but
- "forcing" and "static"
- are optional (in which
- case the method should
+ """
+ Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). A
+ datastore must be able to return data for the "state" category, but
+ "forcing" and "static" are optional (in which case the method should
return `None`).
The returned dataarray is expected to at minimum have dimensions of
@@ -227,21 +202,16 @@ def get_dataarray(
@property
@abc.abstractmethod
def boundary_mask(self) -> xr.DataArray:
- """Return the boundary
- mask for the dataset,
- with spatial
- dimensions stacked.
- Where the value is 1,
- the grid point is a
- boundary point, and
- where the value is 0,
- the grid point is not
- a boundary point.
+ """
+ Return the boundary mask for the dataset, with spatial dimensions
+ stacked. Where the value is 1, the grid point is a boundary point, and
+ where the value is 0, the grid point is not a boundary point.
Returns
-------
xr.DataArray
- The boundary mask for the dataset, with dimensions `('grid_index',)`.
+ The boundary mask for the dataset, with dimensions
+ `('grid_index',)`.
"""
pass
@@ -326,26 +296,20 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
Returns
-------
np.ndarray
- The x, y coordinates of the dataset, returned differently based on the
- value of `stacked`:
- - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where
+ n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""
pass
def get_xy_extent(self, category: str) -> List[float]:
- """Return the extent
- of the x, y
- coordinates for a
- given category of
- data. The extent
- should be returned as
- a list of 4 floats
- with `[xmin, xmax,
- ymin, ymax]` which can
- then be used to set
- the extent of a plot.
+ """
+ Return the extent of the x, y coordinates for a given category of data.
+ The extent should be returned as a list of 4 floats with `[xmin, xmax,
+ ymin, ymax]` which can then be used to set the extent of a plot.
Parameters
----------
@@ -365,8 +329,10 @@ def get_xy_extent(self, category: str) -> List[float]:
def unstack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
- """Stack the spatial grid coordinates into separate `x` and `y` dimensions (the
- names can be set by the `CARTESIAN_COORDS` attribute) to create a 2D grid.
+ """
+ Unstack the spatial grid coordinates into separate `x` and `y` dimensions
+ (the names can be set by the `CARTESIAN_COORDS` attribute) to create a
+ 2D grid.
Parameters
----------
@@ -386,8 +352,10 @@ def unstack_grid_coords(
def stack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
- """Stack the spatial grid coordinated (by default `x` and `y`, but this can be
- set by the `CARTESIAN_COORDS` attribute) into a single `grid_index` dimension.
+ """
+ Stack the spatial grid coordinates (by default `x` and `y`, but this
+ can be set by the `CARTESIAN_COORDS` attribute) into a single
+ `grid_index` dimension.
Parameters
----------
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 36abe3de..fcb06030 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -19,11 +19,13 @@ class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""
def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
- """Construct a new MLLAMDatastore from the configuration file at `config_path`.
- A boundary mask is created with `n_boundary_points` boundary points. If
- `reuse_existing` is True, the dataset is loaded from a zarr file if it exists
- (unless the config has been modified since the zarr was created), otherwise it
- is created from the configuration file.
+ """
+ Construct a new MLLAMDatastore from the configuration file at
+ `config_path`. A boundary mask is created with `n_boundary_points`
+ boundary points. If `reuse_existing` is True, the dataset is loaded
+ from a zarr file if it exists (unless the config has been modified
+ since the zarr was created), otherwise it is created from the
+ configuration file.
Parameters
----------
@@ -160,11 +162,12 @@ def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category))
def get_dataarray(self, category: str, split: str) -> xr.DataArray:
- """Return the processed data (as a single `xr.DataArray`) for the given category
- of data and test/train/val-split that covers all the data (in space and time) of
- a given category (state/forcin g/static). "state" is the only required category,
- for other categories, the method will return `None` if the category is not found
- in the datastore.
+ """
+ Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). "state" is
+ the only required category, for other categories, the method will
+ return `None` if the category is not found in the datastore.
The returned dataarray will at minimum have dimensions of `(grid_index,
{category}_feature)` so that any spatial dimensions have been stacked
@@ -217,11 +220,12 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
return da_category.sel(time=slice(t_start, t_end))
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each variable in
- the category. For `category=="state"`, the dataarray should also contain a
- `state_diff_mean` and `state_diff_std` variable for the one- step differences of
- the state variables.
+ """
+ Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each
+ variable in the category. For `category=="state"`, the dataarray should
+ also contain a `state_diff_mean` and `state_diff_std` variable for the
+ one-step differences of the state variables.
Parameters
----------
@@ -251,11 +255,13 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
@property
def boundary_mask(self) -> xr.DataArray:
- """Produce a 0/1 mask for the boundary points of the dataset, these will sit at
- the edges of the domain (in x/y extent) and will be used to mask out the
- boundary points from the loss function and to overwrite the boundary points from
- the prediction. For now this is created when the mask is requested, but in the
- future this could be saved to the zarr file.
+ """
+ Produce a 0/1 mask for the boundary points of the dataset; these will
+ sit at the edges of the domain (in x/y extent) and will be used to mask
+ out the boundary points from the loss function and to overwrite the
+ boundary points from the prediction. For now this is created when the
+ mask is requested, but in the future this could be saved to the zarr
+ file.
Returns
-------
@@ -321,9 +327,10 @@ def get_xy(self, category: str, stacked: bool) -> ndarray:
Returns
-------
np.ndarray
- The x, y coordinates of the dataset, returned differently based on the
- value of `stacked`:
- - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where
+ n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index 11da134b..83dc2581 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -21,8 +21,9 @@ def create_normalization_stats_zarr(
data_config_path: str,
zarr_path: str = None,
):
- """Compute mean and std.-dev. for state and forcing variables and save them to a
- Zarr file.
+ """
+ Compute mean and std.-dev. for state and forcing variables and save them to
+ a Zarr file.
Parameters
----------
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
index 3fc714db..baf3be9e 100644
--- a/neural_lam/datastore/multizarr/store.py
+++ b/neural_lam/datastore/multizarr/store.py
@@ -18,10 +18,11 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}
def __init__(self, config_path):
- """Create a multi-zarr datastore from the given configuration file. The
- configuration file should be a YAML file, the format of which is should be
- inferred from the example configuration file in `tests/datastore_examp
- les/multizarr/data_con fig.yml`.
+ """
+ Create a multi-zarr datastore from the given configuration file. The
+ configuration file should be a YAML file, the format of which should
+ be inferred from the example configuration file in
+ `tests/datastore_examples/multizarr/data_config.yml`.
Parameters
----------
@@ -61,8 +62,9 @@ def config(self) -> dict:
def _normalize_path(self, path) -> str:
"""
Normalize the path of source-dataset defined in the configuration file.
- This assumes that any paths that do not start with a protocol (e.g. `s3://`)
- or are not absolute paths, are relative to the configuration file.
+ This assumes that any paths that do not start with a protocol (e.g.
+ `s3://`) or are not absolute paths, are relative to the configuration
+ file.
Parameters
----------
@@ -83,7 +85,8 @@ def _normalize_path(self, path) -> str:
return path
def open_zarrs(self, category):
- """Open the zarr dataset for the given category.
+ """
+ Open the zarr dataset for the given category.
Parameters
----------
@@ -113,7 +116,8 @@ def open_zarrs(self, category):
@functools.cached_property
def coords_projection(self):
- """Return the projection object for the coordinates.
+ """
+ Return the projection object for the coordinates.
The projection object is used to plot the coordinates on a map.
@@ -248,7 +252,8 @@ def _filter_dimensions(self, dataset, transpose_array=True):
Returns:
xr.Dataset: The xarray Dataset object with filtered dimensions.
- OR xr.DataArray: The xarray DataArray object with filtered dimensions.
+ OR xr.DataArray: The xarray DataArray object with filtered
+ dimensions.
"""
dims_to_keep = self.DIMS_TO_KEEP
@@ -362,9 +367,10 @@ def get_xy(self, category, stacked=True):
Returns
-------
np.ndarray
- The x, y coordinates of the dataset, returned differently based on the
- value of `stacked`:
- - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where
+ n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""
@@ -391,13 +397,15 @@ def get_xy(self, category, stacked=True):
@functools.lru_cache()
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each variable in
- the category. For `category=="state"`, the dataarray should also contain a
- `state_diff_mean` and `state_diff_std` variable for the one- step differences of
- the state variables. The return dataarray should at least have dimensions of
- `({categ ory}_feature)`, but can also include for example `grid_index` (if the
- normalisation is done per grid point for example).
+ """
+ Return the normalization dataarray for the given category. This should
+ contain a `{category}_mean` and `{category}_std` variable for each
+ variable in the category. For `category=="state"`, the dataarray should
+ also contain a `state_diff_mean` and `state_diff_std` variable for the
+ one-step differences of the state variables. The returned dataarray
+ should at least have dimensions of `({category}_feature)`, but can
+ also include for example `grid_index` (if the normalisation is done per
+ grid point for example).
Parameters
----------
@@ -676,12 +684,15 @@ def grid_shape_state(self):
@property
def boundary_mask(self) -> xr.DataArray:
- """Load the boundary mask for the dataset, with spatial dimensions stacked.
+ """
+ Load the boundary mask for the dataset, with spatial dimensions
+ stacked.
Returns
-------
xr.DataArray
- The boundary mask for the dataset, with dimensions `('grid_index',)`.
+ The boundary mask for the dataset, with dimensions
+ `('grid_index',)`.
"""
boundary_mask_path = self._normalize_path(
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfiles/config.py
index 5cdb22ea..b483ac67 100644
--- a/neural_lam/datastore/npyfiles/config.py
+++ b/neural_lam/datastore/npyfiles/config.py
@@ -8,13 +8,15 @@
@dataclass
class Projection:
- """Represents the projection information for a dataset, including the type of
- projection and its parameters. Capable of creating a cartopy.crs projection object.
+ """Represents the projection information for a dataset, including the type
+ of projection and its parameters. Capable of creating a cartopy.crs
+ projection object.
Attributes:
class_name: The class name of the projection, this should be a valid
cartopy.crs class.
- kwargs: A dictionary of keyword arguments specific to the projection type.
+ kwargs: A dictionary of keyword arguments specific to the projection
+ type.
"""
@@ -24,8 +26,8 @@ class Projection:
@dataclass
class Dataset:
- """Contains information about the dataset, including variable names, units, and
- descriptions.
+ """Contains information about the dataset, including variable names, units,
+ and descriptions.
Attributes:
name: The name of the dataset.
@@ -45,8 +47,8 @@ class Dataset:
@dataclass
class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
- """Configuration for loading and processing a dataset, including dataset details,
- grid shape, and projection information.
+ """Configuration for loading and processing a dataset, including dataset
+ details, grid shape, and projection information.
Attributes:
dataset: An instance of Dataset containing details about the dataset.
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 923983c2..71a1cb1f 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -1,5 +1,7 @@
-"""Numpy-files based datastore to support the MEPS example dataset introduced in neural-
-lam v0.1.0."""
+"""
+Numpy-files based datastore to support the MEPS example dataset introduced in
+neural-lam v0.1.0.
+"""
# Standard library
import functools
import re
@@ -39,7 +41,8 @@ class NpyFilesDatastore(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
to be stored in a directory structure where each sample is stored in a
- separate file. The file-name format is assumed to be '{STATE_FILENAME_FORMAT}'
+ separate file. The file-name format is assumed to be
+ '{STATE_FILENAME_FORMAT}'
The MEPS dataset is organised into three splits: train, val, and test. Each
split has a set of files which are:
@@ -138,9 +141,10 @@ def __init__(
self,
config_path,
):
- """Create a new NpyFilesDatastore using the configuration file at the given
- path. The config file should be a YAML file and will be loaded into an instance
- of the `NpyDatastoreConfig` dataclass.
+ """
+ Create a new NpyFilesDatastore using the configuration file at the
+ given path. The config file should be a YAML file and will be loaded
+ into an instance of the `NpyDatastoreConfig` dataclass.
Internally, the datastore uses dask.delayed to load the data from the
numpy files, so that the data isn't actually loaded until it's needed.
@@ -151,7 +155,8 @@ def __init__(
The path to the configuration file for the datastore.
"""
- # XXX: This should really be in the config file, not hard-coded in this class
+ # XXX: This should really be in the config file, not hard-coded in this
+ # class
self._num_timesteps = 65
self._step_length = 3 # 3 hours
self._num_ensemble_members = 2
@@ -162,8 +167,9 @@ def __init__(
@property
def root_path(self) -> Path:
- """The root path of the datastore on disk. This is the directory relative to
- which graphs and other files can be stored.
+ """
+ The root path of the datastore on disk. This is the directory relative
+ to which graphs and other files can be stored.
Returns
-------
@@ -186,26 +192,30 @@ def config(self) -> NpyDatastoreConfig:
return self._config
def get_dataarray(self, category: str, split: str) -> DataArray:
- """Get the data array for the given category and split of data. If the category
- is 'state', the data array will be a concatenation of the data arrays for all
- ensemble members. The data will be loaded as a dask array, so that the data
- isn't actually loaded until it's needed.
+ """
+ Get the data array for the given category and split of data. If the
+ category is 'state', the data array will be a concatenation of the data
+ arrays for all ensemble members. The data will be loaded as a dask
+ array, so that the data isn't actually loaded until it's needed.
Parameters
----------
category : str
- The category of the data to load. One of 'state', 'forcing', or 'static'.
+ The category of the data to load. One of 'state', 'forcing', or
+ 'static'.
split : str
- The dataset split to load the data for. One of 'train', 'val', or 'test'.
+ The dataset split to load the data for. One of 'train', 'val', or
+ 'test'.
Returns
-------
xr.DataArray
The data array for the given category and split, with dimensions
per category:
- state: `[elapsed_forecast_duration, analysis_time, grid_index, feature,
- ensemble_member]`
- forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]`
+ state: `[elapsed_forecast_duration, analysis_time, grid_index,
+ feature, ensemble_member]`
+ forcing: `[elapsed_forecast_duration, analysis_time, grid_index,
+ feature]`
static: `[grid_index, feature]`
"""
@@ -235,11 +245,12 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# add datetime forcing as a feature
# to do this we create a forecast time variable which has the
- # dimensions of (analysis_time, elapsed_forecast_duration) with values
- # that are the actual forecast time of each time step. By calling
- # .chunk({"elapsed_forecast_duration": 1}) this time variable is turned
- # into a dask array and so execution of the calculation is delayed
- # until the feature values are actually used.
+ # dimensions of (analysis_time, elapsed_forecast_duration) with
+ # values that are the actual forecast time of each time step. By
+ # calling .chunk({"elapsed_forecast_duration": 1}) this time
+ # variable is turned into a dask array and so execution of the
+ # calculation is delayed until the feature values are actually
+ # used.
da_forecast_time = (
da.analysis_time + da.elapsed_forecast_duration
).chunk({"elapsed_forecast_duration": 1})
@@ -285,10 +296,12 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
def _get_single_timeseries_dataarray(
self, features: List[str], split: str, member: int = None
) -> DataArray:
- """Get the data array spanning the complete time series for a given set of
- features and split of data. For state features the `member` argument should be
- specified to select the ensemble member to load. The data will be loaded using
- dask.delayed, so that the data isn't actually loaded until it's needed.
+ """
+ Get the data array spanning the complete time series for a given set of
+ features and split of data. For state features the `member` argument
+ should be specified to select the ensemble member to load. The data
+ will be loaded using dask.delayed, so that the data isn't actually
+ loaded until it's needed.
Parameters
----------
@@ -300,16 +313,18 @@ def _get_single_timeseries_dataarray(
'static' category this should be the list of static features to
load.
split : str
- The dataset split to load the data for. One of 'train', 'val', or 'test'.
+ The dataset split to load the data for. One of 'train', 'val', or
+ 'test'.
member : int, optional
- The ensemble member to load. Only applicable for the 'state' category.
+ The ensemble member to load. Only applicable for the 'state'
+ category.
Returns
-------
xr.DataArray
The data array for the given category and split, with dimensions
- `[elapsed_forecast_duration, analysis_time, grid_index, feature]` for
- all categories of data
+ `[elapsed_forecast_duration, analysis_time, grid_index, feature]`
+ for all categories of data
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
@@ -355,7 +370,8 @@ def _get_single_timeseries_dataarray(
file_dims = ["y", "x", "feature"]
add_feature_dim = True
features_vary_with_analysis_time = False
- # XXX: border_mask is the same for all splits, and so saved in static/
+ # XXX: border_mask is the same for all splits, and so saved in
+ # static/
fp_samples = self.root_path / "static"
elif features == ["x", "y"]:
filename_format = "nwp_xy.npy"
@@ -426,12 +442,6 @@ def _get_single_timeseries_dataarray(
else:
arr_all = arrays[0]
- # if features == ["column_water"]:
- # # for column water, we need to repeat the array for each forecast time
- # # first insert a new axis for the forecast time
- # arr_all = np.expand_dims(arr_all, 1)
- # # and then repeat
- # arr_all = dask.array.repeat(arr_all, self._num_timesteps, axis=1)
da = xr.DataArray(arr_all, dims=dims, coords=coords)
# stack the [x, y] dimensions into a `grid_index` dimension
@@ -440,8 +450,8 @@ def _get_single_timeseries_dataarray(
return da
def _get_analysis_times(self, split) -> List[np.datetime64]:
- """Get the analysis times for the given split by parsing the filenames of all
- the files found for the given split.
+ """Get the analysis times for the given split by parsing the filenames
+ of all the files found for the given split.
Parameters
----------
@@ -546,9 +556,10 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
Returns
-------
np.ndarray
- The x, y coordinates of the dataset, returned differently based on the
- value of `stacked`:
- - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(2, n_grid_points)` where
+ n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""
@@ -593,8 +604,8 @@ def grid_shape_state(self) -> CartesianGridShape:
@property
def boundary_mask(self) -> xr.DataArray:
- """The boundary mask for the dataset. This is a binary mask that is 1 where the
- grid cell is on the boundary of the domain, and 0 otherwise.
+ """The boundary mask for the dataset. This is a binary mask that is 1
+ where the grid cell is on the boundary of the domain, and 0 otherwise.
Returns
-------
@@ -615,11 +626,11 @@ def boundary_mask(self) -> xr.DataArray:
return da_mask_stacked_xy
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each variable in
- the category. For `category=="state"`, the dataarray should also contain a
- `state_diff_mean` and `state_diff_std` variable for the one- step differences of
- the state variables.
+ """Return the normalization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables.
Parameters
----------
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 449021be..2eff2489 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -47,8 +47,8 @@ def __init__(
"The provided datastore only provides "
f"{len(self.da_state.time)} time steps for `{split}` split, "
f"which is less than the required 2+ar_steps "
- f"(2+{self.ar_steps}={2+self.ar_steps}) for creating a sample "
- "with initial and target states."
+ f"(2+{self.ar_steps}={2 + self.ar_steps}) for creating a "
+ "sample with initial and target states."
)
# Set up for standardization
diff --git a/tests/conftest.py b/tests/conftest.py
index f0d1c2f5..fdbcb627 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -79,9 +79,10 @@ def bootstrap_multizarr_example():
multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr"
n_boundary_cells = 10
+ base_url = "https://mllam-test-data.s3.eu-north-1.amazonaws.com/"
data_urls = [
- "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr",
- "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr",
+ base_url + "single_levels.zarr",
+ base_url + "height_levels.zarr",
]
for url in data_urls:
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 0955b8cc..f87a26e9 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -6,8 +6,10 @@
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
-- [x] `get_vars_units` (method): Get the units of the variables in the given category.
-- [x] `get_vars_names` (method): Get the names of the variables in the given category.
+- [x] `get_vars_units` (method): Get the units of the variables in the given
+ category.
+- [x] `get_vars_names` (method): Get the names of the variables in the given
+ category.
- [x] `get_num_data_vars` (method): Get the number of data variables in the
given category.
- [x] `get_normalization_dataarray` (method): Return the normalization
@@ -18,7 +20,8 @@
with spatial dimensions stacked.
- [x] `config` (property): Return the configuration of the datastore.
-In addition BaseCartesianDatastore must have the following methods and attributes:
+In addition BaseCartesianDatastore must have the following methods and
+attributes:
- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
@@ -72,9 +75,9 @@ def test_step_length(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
- """Use the `datastore.get_xy` method to get the x, y coordinates of the dataset and
- check that the shape is correct against the `da tastore.grid_shape_state`
- property."""
+ """Use the `datastore.get_xy` method to get the x, y coordinates of the
+ dataset and check that the shape is correct against the
+ `datastore.grid_shape_state` property."""
datastore = init_datastore(datastore_name)
# check the shapes of the xy grid
@@ -82,10 +85,6 @@ def test_datastore_grid_xy(datastore_name):
nx, ny = grid_shape.x, grid_shape.y
for stacked in [True, False]:
xy = datastore.get_xy("static", stacked=stacked)
- """
- - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y.
- - `stacked==False`: shape `(2, N_y, N_x)`
- """
if stacked:
assert xy.shape == (2, nx * ny)
else:
@@ -193,8 +192,8 @@ def test_get_dataarray(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
- """Check that the `datastore.boundary_mask` property is implemented and that the
- returned object is an xarray DataArray with the correct shape."""
+ """Check that the `datastore.boundary_mask` property is implemented and
+ that the returned object is an xarray DataArray with the correct shape."""
datastore = init_datastore(datastore_name)
da_mask = datastore.boundary_mask
@@ -212,8 +211,8 @@ def test_boundary_mask(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy_extent(datastore_name):
- """Check that the `datastore.get_xy_extent` method is implemented and that the
- returned object is a tuple of the correct length."""
+ """Check that the `datastore.get_xy_extent` method is implemented and that
+ the returned object is a tuple of the correct length."""
datastore = init_datastore(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
From 4838872a1bc10778396951f01df4f1ff622dc222 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 18:04:25 +0100
Subject: [PATCH 183/273] Update neural_lam/weather_dataset.py
Clarify that for forecast datasets the training sample always starts with the first timestep, rather than creating samples from all lead times (as it was before)
Co-authored-by: Joel Oskarsson
---
neural_lam/weather_dataset.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 2eff2489..1739507d 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -119,9 +119,10 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
# selecting the time slice
if self.datastore.is_forecast:
# this implies that the data will have both `analysis_time` and
- # `elapsed_forecast_duration` dimensions for forecasts we for now
- # simply select a analysis time and then the next ar_steps forecast
- # times
+ # `elapsed_forecast_duration` dimensions for forecasts. We for now
+ # simply select an analysis time and the first `n_steps` forecast
+ # times (given no offset). Note that this means that we get one sample
+ # per forecast, always starting at forecast time 2.
da = da.isel(
analysis_time=idx,
elapsed_forecast_duration=slice(
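To make the sampling strategy described in the comment above concrete, here is a toy xarray sketch (made-up dimension sizes, not the project's actual data or code) of taking one sample per forecast from the first lead times:

# Toy illustration of forecast-mode sampling: one sample per analysis_time,
# covering the first 2 + ar_steps lead times.
import numpy as np
import xarray as xr

ar_steps = 3
da = xr.DataArray(
    np.random.rand(5, 10, 4),  # (analysis_time, elapsed_forecast_duration, grid_index)
    dims=["analysis_time", "elapsed_forecast_duration", "grid_index"],
)

idx = 2  # which forecast (analysis time) the sample is drawn from
sample = da.isel(
    analysis_time=idx,
    elapsed_forecast_duration=slice(0, 2 + ar_steps),
)
print(dict(sample.sizes))  # {'elapsed_forecast_duration': 5, 'grid_index': 4}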
From b59e7e5478f84d2dc59d1874f9eb6e87c8e6051a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 18:07:54 +0100
Subject: [PATCH 184/273] Update
neural_lam/datastore/multizarr/create_normalization_stats.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/multizarr/create_normalization_stats.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
index 83dc2581..e4bbd353 100644
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ b/neural_lam/datastore/multizarr/create_normalization_stats.py
@@ -102,7 +102,7 @@ def create_normalization_stats_zarr(
def main():
parser = argparse.ArgumentParser(
- description="Training arguments",
+ description="Create standardization statistics",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
From 75b1fe7ad977c4f04ff203b435e1b45a062d53bd Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 19:51:00 +0100
Subject: [PATCH 185/273] Update neural_lam/datastore/npyfiles/store.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/npyfiles/store.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 71a1cb1f..2edab0cd 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -510,7 +510,7 @@ def get_vars_units(self, category: str) -> torch.List[str]:
elif category == "forcing":
return [
"W/m^2",
- "kg/m^2",
+ "1",
"1",
"1",
"1",
From 7e736cb2c6407e470a49e5962d2b4e59c7c14a2b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 19:51:14 +0100
Subject: [PATCH 186/273] Update neural_lam/datastore/npyfiles/store.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/npyfiles/store.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 2edab0cd..a36096ad 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -27,7 +27,7 @@
TOA_SW_DOWN_FLUX_FILENAME_FORMAT = (
"nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy"
)
-COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
+OPEN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
def _load_np(fp, add_feature_dim):
From 613a7e29a0806e5c20484432db926412a89707df Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 19:52:22 +0100
Subject: [PATCH 187/273] Update neural_lam/datastore/npyfiles/store.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/npyfiles/store.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index a36096ad..6f06746e 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -661,8 +661,8 @@ def load_pickled_tensor(fn):
flux_mean, flux_std = flux_stats
# manually add hour sin/cos and day-of-year sin/cos stats for now
# the mean/std for column_water is hardcoded for now
- mean_values = np.array([flux_mean, 0.34033957, 0.0, 0.0, 0.0, 0.0])
- std_values = np.array([flux_std, 0.4661307, 1.0, 1.0, 1.0, 1.0])
+ mean_values = np.array([flux_mean, 0.0, 0.0, 0.0, 0.0, 0.0])
+ std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])
elif category == "static":
ds_static = self.get_dataarray(category="static", split="train")
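The sin/cos entries in the hardcoded statistics above correspond to cyclically encoded datetime forcings; a minimal sketch of how such features can be computed (toy example, not the project's implementation):

# Toy sketch: cyclic encoding of time-of-day and day-of-year as forcing features.
import numpy as np
import pandas as pd

times = pd.date_range("2020-01-01", periods=24, freq="3h")
hour_angle = 2 * np.pi * times.hour / 24
year_angle = 2 * np.pi * times.dayofyear / 365.25

# Four forcing features per time step: hour sin/cos and day-of-year sin/cos
datetime_forcing = np.stack(
    [np.sin(hour_angle), np.cos(hour_angle), np.sin(year_angle), np.cos(year_angle)],
    axis=-1,
)
print(datetime_forcing.shape)  # (24, 4)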
From 65e199bc8a8f83c067ac455841b87fba080a866c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 19:52:38 +0100
Subject: [PATCH 188/273] Update tests/test_training.py
Co-authored-by: Joel Oskarsson
---
tests/test_training.py | 17 -----------------
1 file changed, 17 deletions(-)
diff --git a/tests/test_training.py b/tests/test_training.py
index 33dd8203..09dab0fa 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -83,20 +83,3 @@ class ModelArgs:
)
wandb.init()
trainer.fit(model=model, datamodule=data_module)
-
-
-# def test_train_model_reduced_meps_dataset():
-# args = [
-# "--model=hi_lam",
-# "--data_config=data/meps_example_reduced/data_config.yaml",
-# "--n_workers=4",
-# "--epochs=1",
-# "--graph=hierarchical",
-# "--hidden_dim=16",
-# "--hidden_layers=1",
-# "--processor_layers=1",
-# "--ar_steps=1",
-# "--eval=val",
-# "--n_example_pred=0",
-# ]
-# train_model(args)
From 4435e26249ba6244e536bd7565307cb921ae2f45 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 19:59:19 +0100
Subject: [PATCH 189/273] Update tests/test_datasets.py
Co-authored-by: Joel Oskarsson
---
tests/test_datasets.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 06deeaa4..b9c74ea6 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -15,7 +15,7 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_dataset_item(datastore_name):
- """Check that the `datasto re.get_dataarray` method is implemented.
+ """Check that the `datastore.get_dataarray` method is implemented.
Validate the shapes of the tensors match between the different
components of the training sample.
From 469340826cf74abc8f86378aedb7eab6ed492e99 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Sun, 8 Sep 2024 20:14:19 +0100
Subject: [PATCH 190/273] Update README.md
Co-authored-by: Joel Oskarsson
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index c7b1266a..ec94e2b3 100644
--- a/README.md
+++ b/README.md
@@ -77,8 +77,8 @@ the input-data representation is split into two parts:
There are currently three different datastores implemented in the codebase:
-1. `neural_lam.datastore.NpyDataStore` which reads data from `.npy`-files in
- the format introduced in neural-lam `v0.1.0`.
+1. `neural_lam.datastore.NpyDataStore` which reads MEPS data from `.npy`-files in
+ the format introduced in neural-lam `v0.1.0`. Note that this datastore is specific to the format of the MEPS dataset, but can act as an example for how to create similar numpy-based datastores.
2. `neural_lam.datastore.MultizarrDatastore` which can combine multiple zarr
files during train/val/test sampling, with the transformations to facilitate
From 2dfed2c2527a7b48a524c8e308a272f016018518 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 16:22:09 +0200
Subject: [PATCH 191/273] update README
---
README.md | 20 ++++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index c7b1266a..baea9a68 100644
--- a/README.md
+++ b/README.md
@@ -142,16 +142,28 @@ It should thus be useful to make sure that your python environment is set up cor
## Pre-processing
+There are two main steps in the pre-processing pipeline: creating the graph and creating additional features/normalisation/boundary-masks.
+
+The amount of pre-processing required will depend on what kind of datastore you will be using for training.
+
+### Additional inputs
+
+#### MultiZarr Datastore
+
+* `python -m neural_lam.create_boundary_mask`
+* `python -m neural_lam.create_datetime_forcings`
+* `python -m neural_lam.create_norm`
+
+#### NpyFiles Datastore
+
+#### MLLAM Datastore
+
An overview of how the different pre-processing steps, training and files depend on each other is given in this figure:
In order to start training models at least three pre-processing steps have to be run:
-* `python -m neural_lam.create_mesh`
-* `python -m neural_lam.create_grid_features`
-* `python -m neural_lam.create_parameter_weights`
-
### Create graph
Run `python -m neural_lam.create_mesh` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_mesh --help` for a list of options).
The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as:
From 66c663f011d13058c637fec256a9063a927d564e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 16:41:58 +0200
Subject: [PATCH 192/273] column_water -> open_water_fraction
---
neural_lam/datastore/npyfiles/store.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 6f06746e..cbc24cf5 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -56,8 +56,8 @@ class NpyFilesDatastore(BaseCartesianDatastore):
The top-of-atmosphere downwelling shortwave flux at `time`. The
dimensions of the array are `[forecast_timestep, y, x]`.
- - `{COLUMN_WATER_FILENAME_FORMAT}`:
- The column water at `time`. The dimensions of the array are
+ - `{OPEN_WATER_FILENAME_FORMAT}`:
+ The open water fraction at `time`. The dimensions of the array are
`[y, x]`.
@@ -234,7 +234,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
elif category == "forcing":
# the forcing features are in separate files, so we need to load
# them separately
- features = ["toa_downwelling_shortwave_flux", "column_water"]
+ features = ["toa_downwelling_shortwave_flux", "open_water_fraction"]
das = [
self._get_single_timeseries_dataarray(
features=[feature], split=split
@@ -353,8 +353,8 @@ def _get_single_timeseries_dataarray(
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
add_feature_dim = True
- elif features == ["column_water"]:
- filename_format = COLUMN_WATER_FILENAME_FORMAT
+ elif features == ["open_water_fraction"]:
+ filename_format = OPEN_WATER_FILENAME_FORMAT
file_dims = ["y", "x", "feature"]
add_feature_dim = True
elif features == ["surface_geopotential"]:
@@ -529,7 +529,7 @@ def get_vars_names(self, category: str) -> torch.List[str]:
# the config
return [
"toa_downwelling_shortwave_flux",
- "column_water",
+ "open_water_fraction",
"sin_hour",
"cos_hour",
"sin_year",
@@ -660,7 +660,7 @@ def load_pickled_tensor(fn):
flux_stats = load_pickled_tensor("flux_stats.pt") # (2,)
flux_mean, flux_std = flux_stats
# manually add hour sin/cos and day-of-year sin/cos stats for now
- # the mean/std for column_water is hardcoded for now
+ # the mean/std for open_water_fraction is hardcoded for now
mean_values = np.array([flux_mean, 0.0, 0.0, 0.0, 0.0, 0.0])
std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])
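Because the file-backed forcing features are loaded one at a time (the `das = [...]` comprehension above builds one single-feature DataArray per entry), they have to be joined along the feature dimension afterwards. A hedged xarray sketch of that step with toy shapes; the dimension names mirror the diff but the data is random:

```python
import numpy as np
import xarray as xr

times, grid = np.arange(3), np.arange(5)
das = [
    xr.DataArray(
        np.random.rand(3, 5, 1),
        dims=("time", "grid_index", "feature"),
        coords={"time": times, "grid_index": grid, "feature": [name]},
    )
    for name in ["toa_downwelling_shortwave_flux", "open_water_fraction"]
]

# Concatenating along the existing feature dimension yields one forcing array.
da_forcing = xr.concat(das, dim="feature")
print(dict(da_forcing.sizes))  # {'time': 3, 'grid_index': 5, 'feature': 2}
```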
From 11a79781c86cbbb8a75e1c0a9c739c762bcfde60 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 16:43:14 +0200
Subject: [PATCH 193/273] fix linting
---
neural_lam/weather_dataset.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 1739507d..cd7b0e59 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -121,8 +121,8 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
# this implies that the data will have both `analysis_time` and
# `elapsed_forecast_duration` dimensions for forecasts. We for now
# simply select an analysis time and the first `n_steps` forecast

- # times (given no offset). Note that this means that we get one sample
- # per forecast, always starting at forecast time 2.
+ # times (given no offset). Note that this means that we get one
+ # sample per forecast, always starting at forecast time 2.
da = da.isel(
analysis_time=idx,
elapsed_forecast_duration=slice(
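For forecast-shaped data a sample is picked by selecting one `analysis_time` and slicing a window along `elapsed_forecast_duration`, as the reflowed comment describes. A toy sketch of that indexing pattern; the window start used here is illustrative and not necessarily the exact offset applied in `_sample_time`:

```python
import numpy as np
import xarray as xr

# Toy forecast data: 5 analysis times, 7 forecast steps, 10 grid points.
da = xr.DataArray(
    np.random.rand(5, 7, 10),
    dims=("analysis_time", "elapsed_forecast_duration", "grid_index"),
)

idx, n_steps, n_timesteps_offset = 2, 3, 0

sample = da.isel(
    analysis_time=idx,
    elapsed_forecast_duration=slice(
        n_timesteps_offset, n_timesteps_offset + n_steps
    ),
)
print(dict(sample.sizes))  # {'elapsed_forecast_duration': 3, 'grid_index': 10}
```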
From a41c314f548d269a09ba5759beb820ff2c3bc212 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 16:59:33 +0200
Subject: [PATCH 194/273] static data same for all splits
---
neural_lam/datastore/base.py | 3 ++-
tests/test_datastores.py | 11 ++++++++++-
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index e046bc02..423e147e 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -168,7 +168,8 @@ def get_dataarray(
space and time) of a given category (state/forcing/static). A
datastore must be able to return for the "state" category, but
"forcing" and "static" are optional (in which case the method should
- return `None`).
+ return `None`). For the "static" category the `split` is allowed to be
+ `None` because the static data is the same for all splits.
The returned dataarray is expected to at minimum have dimensions of
`(grid_index, {category}_feature)` so that any spatial dimensions have
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index f87a26e9..a2f35427 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -157,7 +157,16 @@ def test_get_dataarray(datastore_name):
for category in ["state", "forcing", "static"]:
n_features = {}
- for split in ["train", "val", "test"]:
+ if category in ["state", "forcing"]:
+ splits = ["train", "val", "test"]
+ elif category == "static":
+ # static data should be the same for all splits, so split
+ # should be allowed to be None
+ splits = ["train", "val", "test", None]
+ else:
+ raise NotImplementedError(category)
+
+ for split in splits:
expected_dims = ["grid_index", f"{category}_feature"]
if category != "static":
if not datastore.is_forecast:
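The docstring change above lets callers request static fields without naming a split, since the same array backs train, val and test. A self-contained toy sketch of that contract (this is not the real datastore interface, just the idea the test exercises):

```python
import numpy as np
import xarray as xr


class ToyDatastore:
    # Toy stand-in: static data ignores the split argument entirely.
    def __init__(self):
        self._static = xr.DataArray(
            np.random.rand(10, 2), dims=("grid_index", "static_feature")
        )

    def get_dataarray(self, category, split):
        if category == "static":
            # split may be "train", "val", "test" or None - it is ignored
            return self._static
        raise NotImplementedError(category)


ds = ToyDatastore()
assert ds.get_dataarray("static", split=None).equals(
    ds.get_dataarray("static", split="train")
)
```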
From 6f1efd657e76fa1290b33d671c2910cf42602e46 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:07:06 +0200
Subject: [PATCH 195/273] forcing_window_size from args
---
neural_lam/models/ar_model.py | 5 ++++-
neural_lam/models/base_graph_model.py | 6 ++----
neural_lam/models/graph_lam.py | 4 ++--
3 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index a0a7880c..203b20c5 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -23,7 +23,9 @@ class ARModel(pl.LightningModule):
# Disable to override args/kwargs from superclass
def __init__(
- self, args, datastore: BaseDatastore, forcing_window_size: int
+ self,
+ args,
+ datastore: BaseDatastore,
):
super().__init__()
self.save_hyperparameters(ignore=["datastore"])
@@ -38,6 +40,7 @@ def __init__(
)
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask
+ forcing_window_size = args.forcing_window_size
# Load static features for grid/data, NB: self.predict_step assumes
# dimension order to be (grid_index, static_feature)
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 16897e4f..b9dce90f 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -13,10 +13,8 @@ class BaseGraphModel(ARModel):
the encode-process-decode idea.
"""
- def __init__(self, args, datastore, forcing_window_size):
- super().__init__(
- args, datastore=datastore, forcing_window_size=forcing_window_size
- )
+ def __init__(self, args, datastore):
+ super().__init__(args, datastore=datastore)
# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index a4c726b1..9288e539 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -15,8 +15,8 @@ class GraphLAM(BaseGraphModel):
Oskarsson et al. (2023).
"""
- def __init__(self, args, datastore, forcing_window_size):
- super().__init__(args, datastore, forcing_window_size)
+ def __init__(self, args, datastore):
+ super().__init__(args, datastore)
assert (
not self.hierarchical
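With this change the window size travels inside the argparse namespace instead of being a separate constructor argument. A sketch of the new calling convention using a stand-in namespace and class; the attribute name `forcing_window_size` is taken from the diff, everything else is illustrative:

```python
from types import SimpleNamespace

# Stand-in for the argparse namespace handed to the models.
args = SimpleNamespace(forcing_window_size=3)


class ToyARModel:
    # Mirrors the new convention: read the window size from `args` internally.
    def __init__(self, args, datastore=None):
        self.forcing_window_size = args.forcing_window_size


model = ToyARModel(args)
print(model.forcing_window_size)  # 3
```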
From bacb9ec3b05ea0847ffc63f964ef02eb650e5c99 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:07:46 +0200
Subject: [PATCH 196/273] Update neural_lam/datastore/base.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 423e147e..6859a82d 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -25,7 +25,7 @@ class BaseDatastore(abc.ABC):
`torch.utils.data.Dataset` and uses the datastore to access the data).
# Forecast vs analysis data
- If the datastore is used represent forecast rather than analysis data, then
+ If the datastore is used to represent forecast rather than analysis data, then
the `is_forecast` attribute should be set to True, and returned data from
`get_dataarray` is assumed to have `analysis_time` and `forecast_time`
dimensions (rather than just `time`).
From 4a9db4eb5d64a958ccf6b804e143b16c5d4ea347 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:14:58 +0200
Subject: [PATCH 197/273] only use first ensemble member in datastores
---
neural_lam/weather_dataset.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index cd7b0e59..b9ac8f09 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -2,7 +2,6 @@
import warnings
# Third-party
-import numpy as np
import pytorch_lightning as pl
import torch
import xarray as xr
@@ -170,13 +169,17 @@ def __getitem__(self, idx):
"""
# handling ensemble data
if self.datastore.is_ensemble:
- # for the now the strategy is to simply select a random ensemble
+ # for now the strategy is to only include the first ensemble
# member
# XXX: this could be changed to include all ensemble members by
# splitting `idx` into two parts, one for the analysis time and one
# for the ensemble member and then increasing self.__len__ to
# include all ensemble members
- i_ensemble = np.random.randint(self.da_state.ensemble_member.size)
+ warnings.warn(
+ "only use of ensemble member 0 (the first member) is "
+ "implemented for ensemble data"
+ )
+ i_ensemble = 0
da_state = self.da_state.isel(ensemble_member=i_ensemble)
else:
da_state = self.da_state
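The `XXX` note above suggests covering every ensemble member by decomposing the flat sample index instead of always taking member 0. One possible decomposition, shown purely as a sketch of that idea (not the implemented behaviour):

```python
# Hypothetical scheme: enlarge __len__ to n_samples * n_members and split idx.
n_samples = 100   # e.g. number of usable analysis-time indices
n_members = 5     # size of the ensemble_member dimension


def split_index(idx, n_members):
    # Consecutive idx values iterate over members first, then over samples.
    i_sample, i_member = divmod(idx, n_members)
    return i_sample, i_member


assert split_index(0, n_members) == (0, 0)
assert split_index(7, n_members) == (1, 2)
assert split_index(n_samples * n_members - 1, n_members) == (99, 4)
```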
From bcaa919993ceb7bf922dec91680dcb1a16ea57c8 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:16:18 +0200
Subject: [PATCH 198/273] Update neural_lam/datastore/base.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 6859a82d..1be54ec2 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -87,7 +87,7 @@ def get_vars_units(self, category: str) -> List[str]:
Parameters
----------
category : str
- The category of the variables.
+ The category of the variables (state/forcing/static).
Returns
-------
From 90bc5948e41377d89fac696be2cb15a56b2aabb7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:16:34 +0200
Subject: [PATCH 199/273] Update neural_lam/datastore/base.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 1be54ec2..66c18670 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -104,7 +104,7 @@ def get_vars_names(self, category: str) -> List[str]:
Parameters
----------
category : str
- The category of the variables.
+ The category of the variables (state/forcing/static).
Returns
-------
From 5bda935e8934761a1b307399e4a5772aae332c3e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:19:18 +0200
Subject: [PATCH 200/273] Update neural_lam/datastore/base.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 66c18670..b8c7afa8 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -121,7 +121,7 @@ def get_num_data_vars(self, category: str) -> int:
Parameters
----------
category : str
- The category of the variables.
+ The category of the variables (state/forcing/static).
Returns
-------
From 8e7931df500d8b74d8bf5b0dec00e5da87d8ecab Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 17:38:29 +0200
Subject: [PATCH 201/273] remove all multizarr functionality
---
neural_lam/create_graph.py | 5 +-
neural_lam/datastore/multizarr/__init__.py | 7 -
neural_lam/datastore/multizarr/config.py | 43 -
.../multizarr/create_boundary_mask.py | 76 --
.../multizarr/create_datetime_forcings.py | 148 ----
.../multizarr/create_normalization_stats.py | 127 ---
neural_lam/datastore/multizarr/store.py | 732 ------------------
neural_lam/train_model.py | 7 +-
tests/conftest.py | 78 +-
9 files changed, 4 insertions(+), 1219 deletions(-)
delete mode 100644 neural_lam/datastore/multizarr/__init__.py
delete mode 100644 neural_lam/datastore/multizarr/config.py
delete mode 100644 neural_lam/datastore/multizarr/create_boundary_mask.py
delete mode 100644 neural_lam/datastore/multizarr/create_datetime_forcings.py
delete mode 100644 neural_lam/datastore/multizarr/create_normalization_stats.py
delete mode 100644 neural_lam/datastore/multizarr/store.py
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index ba910987..4ce0811b 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -15,7 +15,6 @@
# Local
from .datastore.base import BaseCartesianDatastore
from .datastore.mllam import MLLAMDatastore
-from .datastore.multizarr import MultiZarrDatastore
from .datastore.npyfiles import NpyFilesDatastore
@@ -534,7 +533,6 @@ def create_graph(
DATASTORES = dict(
- multizarr=MultiZarrDatastore,
mllam=MLLAMDatastore,
npyfiles=NpyFilesDatastore,
)
@@ -562,9 +560,8 @@ def cli(input_args=None):
parser.add_argument(
"datastore",
type=str,
- default="multizarr",
choices=DATASTORES.keys(),
- help="kind of data store to use (default: multizarr)",
+ help="kind of data store to use",
)
parser.add_argument(
"datastore_config_path",
diff --git a/neural_lam/datastore/multizarr/__init__.py b/neural_lam/datastore/multizarr/__init__.py
deleted file mode 100644
index c59f31f4..00000000
--- a/neural_lam/datastore/multizarr/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Local
-from . import ( # noqa
- create_boundary_mask,
- create_datetime_forcings,
- create_normalization_stats,
-)
-from .store import MultiZarrDatastore # noqa
diff --git a/neural_lam/datastore/multizarr/config.py b/neural_lam/datastore/multizarr/config.py
deleted file mode 100644
index 1f0a1def..00000000
--- a/neural_lam/datastore/multizarr/config.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Standard library
-from pathlib import Path
-
-# Third-party
-import yaml
-
-
-class Config:
- """Class to load and access the configuration file."""
-
- def __init__(self, values):
- self.values = values
-
- @classmethod
- def from_file(cls, filepath):
- """Load the configuration file from the given path."""
- if filepath.endswith(".yaml"):
- with open(filepath, encoding="utf-8", mode="r") as file:
- return cls(values=yaml.safe_load(file))
- else:
- raise NotImplementedError(Path(filepath).suffix)
-
- def __getattr__(self, name):
- """Recursively access the values in the configuration."""
- keys = name.split(".")
- value = self.values
- for key in keys:
- try:
- value = value[key]
- except KeyError:
- raise AttributeError(f"Key '{key}' not found in {value}")
- if isinstance(value, dict):
- return Config(values=value)
- return value
-
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return Config(values=value)
- return value
-
- def __contains__(self, key):
- return key in self.values
diff --git a/neural_lam/datastore/multizarr/create_boundary_mask.py b/neural_lam/datastore/multizarr/create_boundary_mask.py
deleted file mode 100644
index 31966394..00000000
--- a/neural_lam/datastore/multizarr/create_boundary_mask.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Standard library
-from argparse import ArgumentParser
-from pathlib import Path
-
-# Third-party
-import numpy as np
-import xarray as xr
-
-# Local
-from . import config
-
-DEFAULT_FILENAME = "boundary_mask.zarr"
-
-
-def create_boundary_mask(data_config_path, zarr_path, n_boundary_cells):
- """Create a mask for the boundaries of the grid.
-
- Parameters
- ----------
- data_config_path : str
- Data configuration.
- zarr_path : str
- Path to save the Zarr archive.
-
- """
- data_config_path = config.Config.from_file(str(data_config_path))
- mask = np.zeros(list(data_config_path.grid_shape_state.values.values()))
-
- # Set the n_boundary_cells grid-cells closest to each boundary to True
- mask[:n_boundary_cells, :] = True # top boundary
- mask[-n_boundary_cells:, :] = True # noqa bottom boundary
- mask[:, :n_boundary_cells] = True # left boundary
- mask[:, -n_boundary_cells:] = True # noqa right boundary
-
- mask = xr.Dataset({"mask": (["y", "x"], mask)})
-
- print(f"Saving mask to {zarr_path}...")
- mask.to_zarr(zarr_path, mode="w")
-
-
-def main():
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "data_config",
- type=str,
- help="Path to data config file",
- )
- parser.add_argument(
- "--zarr_path",
- type=str,
- default=None,
- help="Path to save the Zarr archive "
- "(default: same directory as data config)",
- )
- parser.add_argument(
- "--n_boundary_cells",
- type=int,
- default=30,
- help="Number of grid-cells to set to True along each boundary",
- )
- args = parser.parse_args()
-
- if args.zarr_path is None:
- zarr_path = Path(args.data_config).parent / DEFAULT_FILENAME
- else:
- zarr_path = Path(args.zarr_path)
-
- create_boundary_mask(
- data_config_path=args.data_config,
- zarr_path=zarr_path,
- n_boundary_cells=args.n_boundary_cells,
- )
-
-
-if __name__ == "__main__":
- main()
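The removed script builds the mask by flagging the outermost `n_boundary_cells` rows and columns of the state grid. The core of that logic fits in a few numpy lines; the grid size below is made up:

```python
import numpy as np

y_dim, x_dim, n_boundary_cells = 8, 10, 2

mask = np.zeros((y_dim, x_dim), dtype=bool)
mask[:n_boundary_cells, :] = True    # top rows
mask[-n_boundary_cells:, :] = True   # bottom rows
mask[:, :n_boundary_cells] = True    # left columns
mask[:, -n_boundary_cells:] = True   # right columns

# Everything within n_boundary_cells of an edge is True, the interior stays False.
print(int(mask.sum()))  # 80 total cells - 4*6 interior cells = 56
```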
diff --git a/neural_lam/datastore/multizarr/create_datetime_forcings.py b/neural_lam/datastore/multizarr/create_datetime_forcings.py
deleted file mode 100644
index d728faaf..00000000
--- a/neural_lam/datastore/multizarr/create_datetime_forcings.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Standard library
-import argparse
-from pathlib import Path
-
-# Third-party
-import numpy as np
-import pandas as pd
-import xarray as xr
-
-# Local
-from .store import MultiZarrDatastore
-
-DEFAULT_FILENAME = "datetime_forcings.zarr"
-
-
-def get_seconds_in_year(year):
- start_of_year = pd.Timestamp(f"{year}-01-01")
- start_of_next_year = pd.Timestamp(f"{year + 1}-01-01")
- return (start_of_next_year - start_of_year).total_seconds()
-
-
-def calculate_datetime_forcing(da_time: xr.DataArray):
- """Compute the datetime forcing for a given set of timesteps, assuming that
- timesteps is a DataArray with a type of `np.datetime64`.
-
- Parameters
- ----------
- timesteps : xr.DataArray
- The timesteps for which to compute the datetime forcing.
-
- Returns
- -------
- xr.Dataset
- The datetime forcing, with the following variables:
- - hour_sin: The sine of the hour of the day, normalized to [0, 1].
- - hour_cos: The cosine of the hour of the day, normalized to [0, 1].
- - year_sin: The sine of the time of year, normalized to [0, 1].
- - year_cos: The cosine of the time of year, normalized to [0, 1].
-
- """
- hours_of_day = xr.DataArray(da_time.dt.hour, dims=["time"])
- seconds_into_year = xr.DataArray(
- [
- (
- pd.Timestamp(dt_obj)
- - pd.Timestamp(f"{pd.Timestamp(dt_obj).year}-01-01")
- ).total_seconds()
- for dt_obj in da_time.values
- ],
- dims=["time"],
- )
- year_seconds = xr.DataArray(
- [
- get_seconds_in_year(pd.Timestamp(dt_obj).year)
- for dt_obj in da_time.values
- ],
- dims=["time"],
- )
- hour_angle = (hours_of_day / 12) * np.pi
- year_angle = (seconds_into_year / year_seconds) * 2 * np.pi
- datetime_forcing = xr.Dataset(
- {
- "hour_sin": np.sin(hour_angle),
- "hour_cos": np.cos(hour_angle),
- "year_sin": np.sin(year_angle),
- "year_cos": np.cos(year_angle),
- },
- coords={"time": da_time},
- )
- datetime_forcing = (datetime_forcing + 1) / 2
- return datetime_forcing
-
-
-def create_datetime_forcing_zarr(
- data_config_path: str,
- zarr_path: str = None,
- chunking: dict = {"time": 1},
-):
- """Create the datetime forcing and save it to a Zarr archive.
-
- Parameters
- ----------
- zarr_path : str
- The path to save the Zarr archive.
- da_time : xr.DataArray
- The time DataArray for which to create the datetime forcing.
- chunking : dict, optional
- The chunking to use when saving the Zarr archive.
-
- """
- if zarr_path is None:
- zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
-
- datastore = MultiZarrDatastore(config_path=data_config_path)
- da_state = datastore.get_dataarray(category="state", split="train")
-
- da_datetime_forcing = calculate_datetime_forcing(
- da_time=da_state.time
- ).expand_dims({"grid_index": da_state.grid_index})
-
- if "x" in da_state.coords and "y" in da_state.coords:
- # copy the x and y coordinates to the datetime forcing
- for aux_coord in ["x", "y"]:
- da_datetime_forcing.coords[aux_coord] = da_state[aux_coord]
-
- da_datetime_forcing = da_datetime_forcing.set_index(
- grid_index=("y", "x")
- ).unstack("grid_index")
- chunking["x"] = -1
- chunking["y"] = -1
- else:
- chunking["grid_index"] = -1
-
- da_datetime_forcing = da_datetime_forcing.chunk(chunking)
-
- da_datetime_forcing.to_zarr(zarr_path, mode="w")
- print(da_datetime_forcing)
- print(f"Datetime forcing saved to {zarr_path}")
-
-
-def main():
- """Main function for creating the datetime forcing and boundary mask."""
- parser = argparse.ArgumentParser(
- description="Create the datetime forcing for neural LAM.",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "data_config",
- type=str,
- help="Path to data config file",
- )
- parser.add_argument(
- "--zarr_path",
- type=str,
- default=None,
- help="Path to save the Zarr archive "
- "(default: same directory as the data-config)",
- )
- args = parser.parse_args()
-
- create_datetime_forcing_zarr(
- data_config_path=args.data_config,
- zarr_path=args.zarr_path,
- )
-
-
-if __name__ == "__main__":
- main()
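The removed `calculate_datetime_forcing` maps hour-of-day and second-of-year onto sine/cosine pairs and then rescales them from [-1, 1] to [0, 1]. A small worked sketch of the same arithmetic for a single, arbitrarily chosen timestamp:

```python
import numpy as np
import pandas as pd

t = pd.Timestamp("2021-06-15 18:00")

# Hour-of-day angle: a full cycle every 24 h, so 18:00 -> 1.5 * pi.
hour_angle = (t.hour / 12) * np.pi

# Fraction of the year elapsed, measured in seconds since 1 January.
start_of_year = pd.Timestamp(f"{t.year}-01-01")
seconds_in_year = (
    pd.Timestamp(f"{t.year + 1}-01-01") - start_of_year
).total_seconds()
year_angle = ((t - start_of_year).total_seconds() / seconds_in_year) * 2 * np.pi

# sin/cos land in [-1, 1]; the trailing (x + 1) / 2 rescales them to [0, 1].
hour_sin = (np.sin(hour_angle) + 1) / 2
hour_cos = (np.cos(hour_angle) + 1) / 2
year_sin = (np.sin(year_angle) + 1) / 2
year_cos = (np.cos(year_angle) + 1) / 2
print(round(hour_sin, 3), round(hour_cos, 3))  # 0.0 0.5 (sin=-1, cos=0 at 18:00)
```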
diff --git a/neural_lam/datastore/multizarr/create_normalization_stats.py b/neural_lam/datastore/multizarr/create_normalization_stats.py
deleted file mode 100644
index e4bbd353..00000000
--- a/neural_lam/datastore/multizarr/create_normalization_stats.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Standard library
-import argparse
-from pathlib import Path
-
-# Third-party
-import xarray as xr
-
-# Local
-from .store import MultiZarrDatastore
-
-DEFAULT_FILENAME = "normalization.zarr"
-
-
-def compute_stats(da):
- mean = da.mean(dim=("time", "grid_index"))
- std = da.std(dim=("time", "grid_index"))
- return mean, std
-
-
-def create_normalization_stats_zarr(
- data_config_path: str,
- zarr_path: str = None,
-):
- """
- Compute mean and std.-dev. for state and forcing variables and save them to
- a Zarr file.
-
- Parameters
- ----------
- data_config_path : str
- Path to data config file.
- zarr_path : str, optional
- Path to save the normalization statistics to. If not provided, the
- statistics are saved to the same directory as the data config file with
- the name `normalization.zarr`.
-
- """
- if zarr_path is None:
- zarr_path = Path(data_config_path).parent / DEFAULT_FILENAME
-
- datastore = MultiZarrDatastore(config_path=data_config_path)
-
- da_state = datastore.get_dataarray(category="state", split="train")
- da_forcing = datastore.get_dataarray(category="forcing", split="train")
-
- print("Computing mean and std.-dev. for parameters...", flush=True)
- da_state_mean, da_state_std = compute_stats(da_state)
-
- if da_forcing is not None:
- da_forcing_mean, da_forcing_std = compute_stats(da_forcing)
- combined_stats = datastore._config["utilities"]["normalization"][
- "combined_stats"
- ]
-
- if combined_stats is not None:
- for group in combined_stats:
- vars_to_combine = group["vars"]
-
- da_forcing_means = da_forcing_mean.sel(
- forcing_feature=vars_to_combine
- )
- stds = da_forcing_std.sel(forcing_feature=vars_to_combine)
-
- combined_mean = da_forcing_means.mean(dim="forcing_feature")
- combined_std = (stds**2).mean(dim="forcing_feature") ** 0.5
-
- da_forcing_mean.loc[
- dict(forcing_feature=vars_to_combine)
- ] = combined_mean
- da_forcing_std.loc[
- dict(forcing_feature=vars_to_combine)
- ] = combined_std
- print(
- "Computing mean and std.-dev. for one-step differences...", flush=True
- )
- state_data_normalized = (da_state - da_state_mean) / da_state_std
- state_data_diff_normalized = state_data_normalized.diff(dim="time")
- diff_mean, diff_std = compute_stats(state_data_diff_normalized)
-
- ds = xr.Dataset(
- {
- "state_mean": da_state_mean,
- "state_std": da_state_std,
- "state_diff_mean": diff_mean,
- "state_diff_std": diff_std,
- }
- )
-
- if da_forcing is not None:
- dsf = xr.Dataset(
- {
- "forcing_mean": da_forcing_mean,
- "forcing_std": da_forcing_std,
- }
- )
- ds = xr.merge([ds, dsf])
-
- ds = ds.chunk({"state_feature": -1, "forcing_feature": -1})
- print("Saving dataset as Zarr...")
- ds.to_zarr(zarr_path, mode="w")
-
-
-def main():
- parser = argparse.ArgumentParser(
- description="Create standardization statistics",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "data_config",
- type=str,
- help="Path to data config file",
- )
- parser.add_argument(
- "--zarr_path",
- type=str,
- default="normalization.zarr",
- help="Directory where data is stored",
- )
- args = parser.parse_args()
-
- create_normalization_stats_zarr(
- data_config_path=args.data_config, zarr_path=args.zarr_path
- )
-
-
-if __name__ == "__main__":
- main()
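The removed statistics script pools groups of forcing variables by averaging their means and combining their standard deviations as the square root of the mean variance (`(stds**2).mean() ** 0.5` above). A tiny numeric check of that formula with invented values:

```python
import numpy as np

means = np.array([1.0, 3.0, 5.0])   # invented per-variable means
stds = np.array([2.0, 2.0, 4.0])    # invented per-variable stds

combined_mean = means.mean()                 # 3.0
combined_std = np.sqrt((stds ** 2).mean())   # sqrt((4 + 4 + 16) / 3) ~ 2.83
print(combined_mean, round(combined_std, 2))
```

Note that this pools only the within-variable variances; the spread between the individual means is not folded into the combined standard deviation.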
diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py
deleted file mode 100644
index baf3be9e..00000000
--- a/neural_lam/datastore/multizarr/store.py
+++ /dev/null
@@ -1,732 +0,0 @@
-# Standard library
-import functools
-import os
-from pathlib import Path
-
-# Third-party
-import cartopy.crs as ccrs
-import numpy as np
-import pandas as pd
-import xarray as xr
-import yaml
-
-# Local
-from ..base import BaseCartesianDatastore, CartesianGridShape
-
-
-class MultiZarrDatastore(BaseCartesianDatastore):
- DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}
-
- def __init__(self, config_path):
- """
- Create a multi-zarr datastore from the given configuration file. The
- configuration file should be a YAML file, the format of which should
- be inferred from the example configuration file in
- `tests/datastore_examples/multizarr/data_config.yml`.
-
- Parameters
- ----------
- config_path : str
- The path to the configuration file.
-
- """
- self._config_path = Path(config_path)
- self._root_path = self._config_path.parent
- with open(config_path, encoding="utf-8", mode="r") as file:
- self._config = yaml.safe_load(file)
-
- @property
- def root_path(self):
- """Return the root path of the datastore.
-
- Returns
- -------
- str
- The root path of the datastore.
-
- """
- return self._root_path
-
- @property
- def config(self) -> dict:
- """Return the configuration dictionary.
-
- Returns
- -------
- dict
- The configuration dictionary.
-
- """
- return self._config
-
- def _normalize_path(self, path) -> str:
- """
- Normalize the path of source-dataset defined in the configuration file.
- This assumes that any paths that do not start with a protocol (e.g.
- `s3://`) or are not absolute paths, are relative to the configuration
- file.
-
- Parameters
- ----------
- path : str
- The path to normalize.
-
- Returns
- -------
- str
- The normalized path.
- """
- # try to parse path to see if it defines a protocol, e.g. s3://
- if "://" in path or path.startswith("/"):
- pass
- else:
- # assume path is relative to config file
- path = os.path.join(self._root_path, path)
- return path
-
- def open_zarrs(self, category):
- """
- Open the zarr dataset for the given category.
-
- Parameters
- ----------
- category : str
- The category of the dataset (state/forcing/static).
-
- Returns
- -------
- xr.Dataset
- The xarray Dataset object.
-
- """
- zarr_configs = self._config[category]["zarrs"]
-
- datasets = []
- for config in zarr_configs:
- dataset_path = self._normalize_path(config["path"])
-
- try:
- dataset = xr.open_zarr(dataset_path, consolidated=True)
- except Exception as e:
- raise Exception("Error opening dataset:", dataset_path) from e
- datasets.append(dataset)
- merged_dataset = xr.merge(datasets)
- merged_dataset.attrs["category"] = category
- return merged_dataset
-
- @functools.cached_property
- def coords_projection(self):
- """
- Return the projection object for the coordinates.
-
- The projection object is used to plot the coordinates on a map.
-
- Returns:
- cartopy.crs.Projection: The projection object.
-
- """
- proj_config = self._config["projection"]
- proj_class_name = proj_config["class"]
- proj_class = getattr(ccrs, proj_class_name)
- proj_params = proj_config.get("kwargs", {})
- return proj_class(**proj_params)
-
- @functools.cached_property
- def step_length(self):
- """Return the step length of the dataset in hours.
-
- Returns:
- int: The step length in hours.
-
- """
- dataset = self.open_zarrs("state")
- time = dataset.time.isel(time=slice(0, 2)).values
- step_length_ns = time[1] - time[0]
- step_length_hours = step_length_ns / np.timedelta64(1, "h")
- return int(step_length_hours)
-
- @functools.lru_cache()
- def get_vars_names(self, category):
- """Return the names of the variables in the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- list: The names of the variables in the dataset.
-
- """
- surface_vars_names = self._config[category].get("surface_vars") or []
- atmosphere_vars_names = [
- f"{var}_{level}"
- for var in (self._config[category].get("atmosphere_vars") or [])
- for level in (self._config[category].get("levels") or [])
- ]
- return surface_vars_names + atmosphere_vars_names
-
- @functools.lru_cache()
- def get_vars_units(self, category):
- """Return the units of the variables in the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- list: The units of the variables in the dataset.
-
- """
- surface_vars_units = self._config[category].get("surface_units") or []
- atmosphere_vars_units = [
- unit
- for unit in (self._config[category].get("atmosphere_units") or [])
- for _ in (self._config[category].get("levels") or [])
- ]
- return surface_vars_units + atmosphere_vars_units
-
- @functools.lru_cache()
- def get_num_data_vars(self, category):
- """Return the number of data variables in the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- int: The number of data variables in the dataset.
-
- """
- surface_vars = self._config[category].get("surface_vars", [])
- atmosphere_vars = self._config[category].get("atmosphere_vars", [])
- levels = self._config[category].get("levels", [])
-
- surface_vars_count = (
- len(surface_vars) if surface_vars is not None else 0
- )
- atmosphere_vars_count = (
- len(atmosphere_vars) if atmosphere_vars is not None else 0
- )
- levels_count = len(levels) if levels is not None else 0
-
- return surface_vars_count + atmosphere_vars_count * levels_count
-
- def _stack_grid(self, ds):
- """Stack the grid dimensions of the dataset.
-
- Args:
- ds (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with stacked grid dimensions.
-
- """
- if "grid_index" in ds.dims:
- raise ValueError("Grid dimensions already stacked.")
- else:
- if "x" not in ds.dims or "y" not in ds.dims:
- self._rename_dataset_dims_and_vars(dataset=ds)
- ds = ds.stack(grid_index=("y", "x")).reset_index("grid_index")
- # reset the grid_index coordinates to have integer values, otherwise
- # the serialisation to zarr will fail
- ds["grid_index"] = np.arange(len(ds["grid_index"]))
- return ds
-
- def _convert_dataset_to_dataarray(self, dataset):
- """Convert the Dataset to a Dataarray.
-
- Args:
- dataset (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.DataArray: The xarray DataArray object.
-
- """
- if isinstance(dataset, xr.Dataset):
- dataset = dataset.to_array(dim="variable_name")
- return dataset
-
- def _filter_dimensions(self, dataset, transpose_array=True):
- """Drop the dimensions and filter the data_vars of the dataset.
-
- Args:
- dataset (xr.Dataset): The xarray Dataset object.
- transpose_array (bool): Whether to transpose the array.
-
- Returns:
- xr.Dataset: The xarray Dataset object with filtered dimensions.
- OR xr.DataArray: The xarray DataArray object with filtered
- dimensions.
-
- """
- dims_to_keep = self.DIMS_TO_KEEP
- dataset_dims = set(list(dataset.dims) + ["variable_name"])
- min_req_dims = dims_to_keep.copy()
- min_req_dims.discard("time")
- if not min_req_dims.issubset(dataset_dims):
- missing_dims = min_req_dims - dataset_dims
- print(
- f"\033[91mMissing required dimensions in dataset: "
- f"{missing_dims}\033[0m"
- )
- print(
- "\033[91mAttempting to update dims and "
- "vars based on zarr config...\033[0m"
- )
- dataset = self._rename_dataset_dims_and_vars(
- dataset.attrs["category"], dataset=dataset
- )
- dataset = self._stack_grid(dataset)
- dataset_dims = set(list(dataset.dims) + ["variable_name"])
- if min_req_dims.issubset(dataset_dims):
- print(
- "\033[92mSuccessfully updated dims and "
- "vars based on zarr config.\033[0m"
- )
- else:
- print(
- "\033[91mFailed to update dims and "
- "vars based on zarr config.\033[0m"
- )
- return None
-
- dataset_dims = set(list(dataset.dims) + ["variable_name"])
- dims_to_drop = dataset_dims - dims_to_keep
- dataset = dataset.drop_dims(dims_to_drop)
- if dims_to_drop:
- print(
- "\033[91mDropped dimensions: --",
- dims_to_drop,
- "-- from dataset.\033[0m",
- )
- print(
- "\033[91mAny data vars dependent "
- "on these variables were dropped!\033[0m"
- )
-
- if transpose_array:
- dataset = self._convert_dataset_to_dataarray(dataset)
-
- if "time" in dataset.dims:
- dataset = dataset.transpose(
- "time", "grid_index", "variable_name"
- )
- else:
- dataset = dataset.transpose("grid_index", "variable_name")
- dataset_vars = (
- list(dataset.data_vars)
- if isinstance(dataset, xr.Dataset)
- else dataset["variable_name"].values.tolist()
- )
-
- print( # noqa
- f"\033[94mYour {dataset.attrs['category']} xr.Dataarray has the "
- f"following variables: {dataset_vars} \033[0m",
- )
-
- return dataset
-
- def _reshape_grid_to_2d(self, dataset, grid_shape=None):
- """Reshape the grid to 2D for stacked data without multi-index.
-
- Args:
- dataset (xr.Dataset): The xarray Dataset object.
- grid_shape (dict): The shape of the grid.
-
- Returns:
- xr.Dataset: The xarray Dataset object with reshaped grid dimensions.
-
- """
- if grid_shape is None:
- grid_shape = dict(self.grid_shape_state.values.items())
- x_dim, y_dim = (grid_shape["x"], grid_shape["y"])
-
- x_coords = np.arange(x_dim)
- y_coords = np.arange(y_dim)
- multi_index = pd.MultiIndex.from_product(
- [y_coords, x_coords], names=["y", "x"]
- )
-
- mindex_coords = xr.Coordinates.from_pandas_multiindex(
- multi_index, "grid"
- )
- dataset = dataset.drop_vars(["grid", "x", "y"], errors="ignore")
- dataset = dataset.assign_coords(mindex_coords)
- reshaped_data = dataset.unstack("grid")
-
- return reshaped_data
-
- @functools.lru_cache()
- def get_xy(self, category, stacked=True):
- """Return the x, y coordinates of the dataset.
-
- Parameters
- ----------
- category : str
- The category of the dataset (state/forcing/static).
- stacked : bool
- Whether to stack the x, y coordinates.
-
- Returns
- -------
- np.ndarray
- The x, y coordinates of the dataset, returned differently based on
- the value of `stacked`:
- - `stacked==True`: shape `(2, n_grid_points)` where
- n_grid_points=N_x*N_y.
- - `stacked==False`: shape `(2, N_y, N_x)`
-
- """
- dataset = self.open_zarrs(category)
- xs, ys = dataset.x.values, dataset.y.values
-
- assert (
- xs.ndim == ys.ndim
- ), "x and y coordinates must have the same dimensions."
-
- if xs.ndim == 1:
- x, y = np.meshgrid(xs, ys)
- elif xs.ndim == 2:
- x, y = xs, ys
- else:
- raise ValueError("Invalid dimensions for x, y coordinates.")
-
- xy = np.stack((x, y), axis=0) # (2, N_y, N_x)
-
- if stacked:
- xy = xy.reshape(2, -1) # (2, n_grid_points)
-
- return xy
-
- @functools.lru_cache()
- def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """
- Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each
- variable in the category. For `category=="state"`, the dataarray should
- also contain a `state_diff_mean` and `state_diff_std` variable for the
- one-step differences of the state variables. The returned dataarray
- should at least have dimensions of `({category}_feature)`, but can
- also include for example `grid_index` (if the normalisation is done per
- grid point for example).
-
- Parameters
- ----------
- category : str
- The category of the dataset (state/forcing/static).
-
- Returns
- -------
- xr.Dataset
- The normalization dataarray for the given category, with variables
- for the mean and standard deviation of the variables (and
- differences for state variables).
-
- """
- # XXX: the multizarr code didn't include routines for computing the
- # normalization of "static" features previously, we'll just hack
- # something in here and assume they are already normalized
- if category == "static":
- da_mean = xr.DataArray(
- np.zeros(self.get_num_data_vars(category)),
- dims=("static_feature",),
- coords={"static_feature": self.get_vars_names(category)},
- )
- da_std = xr.DataArray(
- np.ones(self.get_num_data_vars(category)),
- dims=("static_feature",),
- coords={"static_feature": self.get_vars_names(category)},
- )
- return xr.Dataset(dict(static_mean=da_mean, static_std=da_std))
-
- ds_combined_stats = self._load_and_merge_stats()
- if ds_combined_stats is None:
- return None
-
- ds_combined_stats = self._rename_data_vars(ds_combined_stats)
-
- ops = ["mean", "std"]
- stats_variables = [f"{category}_{op}" for op in ops]
- if category == "state":
- stats_variables += [f"state_diff_{op}" for op in ops]
-
- ds_stats = ds_combined_stats[stats_variables]
-
- return ds_stats
-
- def _load_and_merge_stats(self):
- """Load and merge the normalization statistics for the dataset.
-
- Returns:
- xr.Dataset: The merged normalization statistics for the dataset.
-
- """
- combined_stats = None
- for i, zarr_config in enumerate(
- self._config["utilities"]["normalization"]["zarrs"]
- ):
- stats_path = self._normalize_path(zarr_config["path"])
- if not os.path.exists(stats_path):
- raise FileNotFoundError(
- f"Normalization statistics not found at path: {stats_path}"
- )
- stats = xr.open_zarr(stats_path, consolidated=True)
- if i == 0:
- combined_stats = stats
- else:
- combined_stats = xr.merge([stats, combined_stats])
- return combined_stats
-
- def _rename_data_vars(self, combined_stats):
- """Rename the data variables of the normalization statistics.
-
- Args:
- combined_stats (xr.Dataset): The combined normalization statistics.
-
- Returns:
- xr.Dataset: The combined normalization statistics with renamed data
- variables.
-
- """
- vars_mapping = {}
- for zarr_config in self._config["utilities"]["normalization"]["zarrs"]:
- vars_mapping.update(zarr_config["stats_vars"])
-
- return combined_stats.rename_vars(
- {
- v: k
- for k, v in vars_mapping.items()
- if v in list(combined_stats.data_vars)
- }
- )
-
- def _select_stats_by_category(self, combined_stats, category):
- """Select the normalization statistics for the given category.
-
- Args:
- combined_stats (xr.Dataset): The combined normalization statistics.
- category (str): The category of the dataset (state/forcing/static).
-
- Returns:
- xr.Dataset: The normalization statistics for the dataset.
-
- """
- if category == "state":
- stats = combined_stats.loc[
- dict(variable_name=self.get_vars_names(category=category))
- ]
- stats = stats.drop_vars(["forcing_mean", "forcing_std"])
- return stats
- elif category == "forcing":
- non_normalized_vars = (
- self.utilities.normalization.non_normalized_vars
- )
- if non_normalized_vars is None:
- non_normalized_vars = []
- forcing_vars = self.vars_names(category)
- normalized_vars = [
- var for var in forcing_vars if var not in non_normalized_vars
- ]
- non_normalized_vars = [
- var for var in forcing_vars if var in non_normalized_vars
- ]
- stats_normalized = combined_stats.loc[
- dict(forcing_variable=normalized_vars)
- ]
- if non_normalized_vars:
- stats_non_normalized = combined_stats.loc[
- dict(forcing_variable=non_normalized_vars)
- ]
- stats = xr.merge([stats_normalized, stats_non_normalized])
- else:
- stats = stats_normalized
- stats_normalized = stats_normalized[["forcing_mean", "forcing_std"]]
-
- return stats
- else:
- print(f"Invalid category: {category}")
- return None
-
- def _extract_vars(self, category, ds=None):
- """Extract (select) the data variables from the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- dataset (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with extracted variables.
-
- """
- if ds is None:
- ds = self.open_zarrs(category)
- surface_vars = self._config[category].get("surface_vars")
- atmoshere_vars = self._config[category].get("atmosphere_vars")
-
- ds_surface = None
- if surface_vars is not None:
- ds_surface = ds[surface_vars]
-
- ds_atmosphere = None
- if atmoshere_vars is not None:
- ds_atmosphere = self._extract_atmosphere_vars(
- category=category, ds=ds
- )
-
- if ds_surface and ds_atmosphere:
- return xr.merge([ds_surface, ds_atmosphere])
- elif ds_surface:
- return ds_surface
- elif ds_atmosphere:
- return ds_atmosphere
- else:
- raise ValueError(f"No variables found in dataset {category}")
-
- def _extract_atmosphere_vars(self, category, ds):
- """Extract the atmosphere variables from the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- ds (xr.Dataset): The xarray Dataset object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with atmosphere variables.
-
- """
-
- if (
- "level" not in list(ds.dims)
- and self._config[category]["atmosphere_vars"]
- ):
- ds = self._rename_dataset_dims_and_vars(
- ds.attrs["category"], dataset=ds
- )
-
- data_arrays = [
- ds[var].sel(level=level, drop=True).rename(f"{var}_{level}")
- for var in self._config[category]["atmosphere_vars"]
- for level in self._config[category]["levels"]
- ]
-
- if self._config[category]["atmosphere_vars"]:
- return xr.merge(data_arrays)
- else:
- return xr.Dataset()
-
- def _rename_dataset_dims_and_vars(self, category, dataset=None):
- """Rename the dimensions and variables of the dataset.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- dataset (xr.Dataset): The xarray Dataset object. OR xr.DataArray:
- The xarray DataArray object.
-
- Returns:
- xr.Dataset: The xarray Dataset object with renamed dimensions and
- variables.
- OR xr.DataArray: The xarray DataArray object with renamed
- dimensions and variables.
-
- """
- convert = False
- if dataset is None:
- dataset = self.open_zarrs(category)
- elif isinstance(dataset, xr.DataArray):
- convert = True
- dataset = dataset.to_dataset("variable_name")
- dims_mapping = {}
- zarr_configs = self._config[category]["zarrs"]
- for zarr_config in zarr_configs:
- dims_mapping.update(zarr_config["dims"])
-
- dataset = dataset.rename_dims(
- {
- v: k
- for k, v in dims_mapping.items()
- if k not in dataset.dims and v in dataset.dims
- }
- )
- dataset = dataset.rename_vars(
- {v: k for k, v in dims_mapping.items() if v in dataset.coords}
- )
- if convert:
- dataset = dataset.to_array()
- return dataset
-
- def _apply_time_split(self, dataset, split="train"):
- """Filter the dataset by the time split.
-
- Args:
- dataset (xr.Dataset): The xarray Dataset object.
- split (str): The time split to filter the dataset.
-
- Returns:
- xr.Dataset: The xarray Dataset object filtered by the time split.
-
- """
- start, end = (
- self._config["splits"][split]["start"],
- self._config["splits"][split]["end"],
- )
- dataset = dataset.sel(time=slice(start, end))
- dataset.attrs["split"] = split
- return dataset
-
- @property
- def grid_shape_state(self):
- """Return the shape of the state grid.
-
- Returns:
- CartesianGridShape: The shape of the state grid.
-
- """
- return CartesianGridShape(
- x=self._config["grid_shape_state"]["x"],
- y=self._config["grid_shape_state"]["y"],
- )
-
- @property
- def boundary_mask(self) -> xr.DataArray:
- """
- Load the boundary mask for the dataset, with spatial dimensions
- stacked.
-
- Returns
- -------
- xr.DataArray
- The boundary mask for the dataset, with dimensions
- `('grid_index',)`.
-
- """
- boundary_mask_path = self._normalize_path(
- self._config["boundary"]["mask"]["path"]
- )
- ds_boundary_mask = xr.open_zarr(boundary_mask_path)
- return (
- ds_boundary_mask.mask.stack(grid_index=("y", "x"))
- .reset_index("grid_index")
- .astype(int)
- )
-
- def get_dataarray(self, category, split="train"):
- """Process the dataset for the given category.
-
- Args:
- category (str): The category of the dataset (state/forcing/static).
- split (str): The time split to filter the dataset (train/val/test).
-
- Returns:
- xr.DataArray: The xarray DataArray object with processed dataset.
-
- """
- dataset = self.open_zarrs(category)
- dataset = self._extract_vars(category, dataset)
- if category != "static":
- dataset = self._apply_time_split(dataset, split)
- dataset = self._stack_grid(dataset)
- dataset = self._rename_dataset_dims_and_vars(category, dataset)
- dataset = self._filter_dimensions(dataset)
- dataset = self._convert_dataset_to_dataarray(dataset)
- if category == "static" and "time" in dataset.dims:
- dataset = dataset.isel(time=0, drop=True)
-
- dataset = dataset.rename(dict(variable_name=f"{category}_feature"))
-
- return dataset
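The removed `get_xy` returns grid coordinates either as a `(2, N_y, N_x)` array or flattened to `(2, n_grid_points)`. A short numpy sketch of that shape convention with made-up grid sizes:

```python
import numpy as np

xs, ys = np.arange(4), np.arange(3)   # N_x = 4, N_y = 3

x, y = np.meshgrid(xs, ys)            # each of shape (N_y, N_x)
xy = np.stack((x, y), axis=0)         # (2, N_y, N_x)
print(xy.shape)                       # (2, 3, 4)

xy_stacked = xy.reshape(2, -1)        # (2, n_grid_points), n_grid_points = N_x * N_y
print(xy_stacked.shape)               # (2, 12)
```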
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index b817ad69..c5132a47 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -12,7 +12,6 @@
# Local
from . import utils
from .datastore.mllam import MLLAMDatastore
-from .datastore.multizarr import MultiZarrDatastore
from .datastore.npyfiles import NpyFilesDatastore
from .models import GraphLAM, HiLAM, HiLAMParallel
from .weather_dataset import WeatherDataModule
@@ -25,9 +24,7 @@
def _init_datastore(datastore_kind, config_path):
- if datastore_kind == "multizarr":
- datastore = MultiZarrDatastore(config_path=config_path)
- elif datastore_kind == "npyfiles":
+ if datastore_kind == "npyfiles":
datastore = NpyFilesDatastore(config_path=config_path)
elif datastore_kind == "mllam":
datastore = MLLAMDatastore(config_path=config_path)
@@ -44,7 +41,7 @@ def main(input_args=None):
parser.add_argument(
"datastore_kind",
type=str,
- choices=["multizarr", "npyfiles", "mllam"],
+ choices=["npyfiles", "mllam"],
help="Kind of datastore to use",
)
parser.add_argument(
diff --git a/tests/conftest.py b/tests/conftest.py
index fdbcb627..39d104ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,18 +4,16 @@
# Third-party
import pooch
-import xarray as xr
import yaml
# First-party
-from neural_lam.datastore import mllam, multizarr, npyfiles
+from neural_lam.datastore import mllam, npyfiles
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
os.environ["WANDB_DISABLED"] = "true"
DATASTORES = dict(
- multizarr=multizarr.MultiZarrDatastore,
mllam=mllam.MLLAMDatastore,
npyfiles=npyfiles.NpyFilesDatastore,
)
@@ -65,81 +63,7 @@ def download_meps_example_reduced_dataset():
return config_path
-def bootstrap_multizarr_example():
- """Run the steps that are needed to prepare the input data for the multizarr
- datastore example. This includes:
-
- - Downloading the two zarr datasets (since training directly from S3 is
- error-prone as the connection often breaks)
- - Creating the datetime forcings zarr
- - Creating the normalization stats zarr
- - Creating the boundary mask zarr
-
- """
- multizarr_path = DATASTORE_EXAMPLES_ROOT_PATH / "multizarr"
- n_boundary_cells = 10
-
- base_url = "https://mllam-test-data.s3.eu-north-1.amazonaws.com/"
- data_urls = [
- base_url + "single_levels.zarr",
- base_url + "height_levels.zarr",
- ]
-
- for url in data_urls:
- local_path = multizarr_path / "danra" / Path(url).name
- if local_path.exists():
- continue
- print(f"Downloading {url} to {local_path}")
- ds = xr.open_zarr(url)
- chunk_dict = {dim: -1 for dim in ds.dims if dim != "time"}
- chunk_dict["time"] = 20
- ds = ds.chunk(chunk_dict)
-
- for var in ds.variables:
- if "chunks" in ds[var].encoding:
- del ds[var].encoding["chunks"]
-
- ds.to_zarr(local_path, mode="w")
- print("DONE")
-
- data_config_path = multizarr_path / "data_config.yaml"
- # here assume that the data-config is referring to the default path
- # for the "datetime forcings" dataset
- datetime_forcing_zarr_path = (
- data_config_path.parent
- / multizarr.create_datetime_forcings.DEFAULT_FILENAME
- )
- if not datetime_forcing_zarr_path.exists():
- multizarr.create_datetime_forcings.create_datetime_forcing_zarr(
- data_config_path=data_config_path
- )
-
- normalized_forcing_zarr_path = (
- data_config_path.parent
- / multizarr.create_normalization_stats.DEFAULT_FILENAME
- )
- if not normalized_forcing_zarr_path.exists():
- multizarr.create_normalization_stats.create_normalization_stats_zarr(
- data_config_path=data_config_path
- )
-
- boundary_mask_path = (
- data_config_path.parent
- / multizarr.create_boundary_mask.DEFAULT_FILENAME
- )
-
- if not boundary_mask_path.exists():
- multizarr.create_boundary_mask.create_boundary_mask(
- data_config_path=data_config_path,
- n_boundary_cells=n_boundary_cells,
- zarr_path=boundary_mask_path,
- )
-
- return data_config_path
-
-
DATASTORES_EXAMPLES = dict(
- multizarr=dict(config_path=bootstrap_multizarr_example()),
mllam=dict(
config_path=DATASTORE_EXAMPLES_ROOT_PATH
/ "mllam"
From 6998683ff5ce198b3eed5fe702ef7d84bbbf458f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:44:19 +0200
Subject: [PATCH 202/273] cleanup and test fixes for recent changes
---
neural_lam/datastore/__init__.py | 19 +++++++++++++++++++
neural_lam/datastore/init.py | 0
neural_lam/datastore/npyfiles/store.py | 18 +++++++++++++++---
neural_lam/train_model.py | 15 ++-------------
tests/conftest.py | 25 ++++++++++---------------
tests/test_datasets.py | 8 ++++----
tests/test_datastores.py | 26 +++++++++++++-------------
tests/test_graph_creation.py | 4 ++--
tests/test_training.py | 6 +++---
9 files changed, 68 insertions(+), 53 deletions(-)
create mode 100644 neural_lam/datastore/init.py
diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py
index ef20291a..479d31a9 100644
--- a/neural_lam/datastore/__init__.py
+++ b/neural_lam/datastore/__init__.py
@@ -1,2 +1,21 @@
# Local
from .mllam import MLLAMDatastore # noqa
+from .npyfiles import NpyFilesDatastore # noqa
+
+DATASTORES = dict(
+ mllam=MLLAMDatastore,
+ npyfiles=NpyFilesDatastore,
+)
+
+
+def init_datastore(datastore_kind, config_path):
+ DatastoreClass = DATASTORES.get(datastore_kind)
+
+ if DatastoreClass is None:
+ raise NotImplementedError(
+ f"Datastore kind {datastore_kind} is not implemented"
+ )
+
+ datastore = DatastoreClass(config_path=config_path)
+
+ return datastore
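The new `init_datastore` helper resolves the datastore class from the `DATASTORES` registry and instantiates it with the given config path. An illustrative call, assuming the package is installed and the example config from the test data exists at that path:

```python
from neural_lam.datastore import init_datastore

datastore = init_datastore(
    datastore_kind="mllam",
    config_path="tests/datastore_examples/mllam/danra.example.yaml",
)

# Unknown kinds fail loudly instead of silently falling back to a default.
try:
    init_datastore("does-not-exist", config_path="irrelevant.yaml")
except NotImplementedError as err:
    print(err)
```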
diff --git a/neural_lam/datastore/init.py b/neural_lam/datastore/init.py
new file mode 100644
index 00000000..e69de29b
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index cbc24cf5..fa2f152f 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -327,7 +327,18 @@ def _get_single_timeseries_dataarray(
for all categories of data
"""
- assert split in ("train", "val", "test"), "Unknown dataset split"
+ if (
+ set(features).difference(self.get_vars_names(category="static"))
+ == set()
+ ):
+ assert split in (
+ "train",
+ "val",
+ "test",
+ None,
+ ), "Unknown dataset split"
+ else:
+ assert split in ("train", "val", "test"), "Unknown dataset split"
if member is not None and features != self.get_vars_names(
category="state"
@@ -339,8 +350,6 @@ def _get_single_timeseries_dataarray(
# XXX: we here assume that the grid shape is the same for all categories
grid_shape = self.grid_shape_state
- fp_samples = self.root_path / "samples" / split
-
file_params = {}
add_feature_dim = False
features_vary_with_analysis_time = True
@@ -349,14 +358,17 @@ def _get_single_timeseries_dataarray(
file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
# only select one member for now
file_params["member_id"] = member
+ fp_samples = self.root_path / "samples" / split
elif features == ["toa_downwelling_shortwave_flux"]:
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
add_feature_dim = True
+ fp_samples = self.root_path / "samples" / split
elif features == ["open_water_fraction"]:
filename_format = OPEN_WATER_FILENAME_FORMAT
file_dims = ["y", "x", "feature"]
add_feature_dim = True
+ fp_samples = self.root_path / "samples" / split
elif features == ["surface_geopotential"]:
filename_format = "surface_geopotential.npy"
file_dims = ["y", "x", "feature"]
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index c5132a47..890f80fe 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -11,8 +11,7 @@
# Local
from . import utils
-from .datastore.mllam import MLLAMDatastore
-from .datastore.npyfiles import NpyFilesDatastore
+from .datastore import init_datastore
from .models import GraphLAM, HiLAM, HiLAMParallel
from .weather_dataset import WeatherDataModule
@@ -23,16 +22,6 @@
}
-def _init_datastore(datastore_kind, config_path):
- if datastore_kind == "npyfiles":
- datastore = NpyFilesDatastore(config_path=config_path)
- elif datastore_kind == "mllam":
- datastore = MLLAMDatastore(config_path=config_path)
- else:
- raise ValueError(f"Unknown datastore kind: {datastore_kind}")
- return datastore
-
-
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
@@ -238,7 +227,7 @@ def main(input_args=None):
seed.seed_everything(args.seed)
# Create datastore
- datastore = _init_datastore(
+ datastore = init_datastore(
datastore_kind=args.datastore_kind,
config_path=args.datastore_config_path,
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 39d104ed..f5679c66 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,17 +7,12 @@
import yaml
# First-party
-from neural_lam.datastore import mllam, npyfiles
+from neural_lam.datastore import init_datastore
# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
os.environ["WANDB_DISABLED"] = "true"
-DATASTORES = dict(
- mllam=mllam.MLLAMDatastore,
- npyfiles=npyfiles.NpyFilesDatastore,
-)
-
DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples")
# Initializing variables for the s3 client
@@ -64,15 +59,15 @@ def download_meps_example_reduced_dataset():
DATASTORES_EXAMPLES = dict(
- mllam=dict(
- config_path=DATASTORE_EXAMPLES_ROOT_PATH
- / "mllam"
- / "danra.example.yaml"
- ),
- npyfiles=dict(config_path=download_meps_example_reduced_dataset()),
+ mllam=(DATASTORE_EXAMPLES_ROOT_PATH / "mllam" / "danra.example.yaml"),
+ npyfiles=download_meps_example_reduced_dataset(),
)
-def init_datastore(datastore_name):
- DatastoreClass = DATASTORES[datastore_name]
- return DatastoreClass(**DATASTORES_EXAMPLES[datastore_name])
+def init_datastore_example(datastore_kind):
+ datastore = init_datastore(
+ datastore_kind=datastore_kind,
+ config_path=DATASTORES_EXAMPLES[datastore_kind],
+ )
+
+ return datastore
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index b9c74ea6..a556c9f5 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -4,7 +4,7 @@
# Third-party
import pytest
import torch
-from test_datastores import DATASTORES, init_datastore
+from conftest import DATASTORES, init_datastore_example
from torch.utils.data import DataLoader
# First-party
@@ -25,7 +25,7 @@ def test_dataset_item(datastore_name):
forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,)
"""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
N_gridpoints = datastore.grid_shape_state.x * datastore.grid_shape_state.y
N_pred_steps = 4
@@ -82,7 +82,7 @@ def test_single_batch(datastore_name, split):
And that it returns an xarray DataArray with the correct dimensions.
"""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
device_name = (
torch.device("cuda") if torch.cuda.is_available() else "cpu"
@@ -102,6 +102,7 @@ class ModelArgs:
hidden_layers = 1
processor_layers = 4
mesh_aggr = "sum"
+ forcing_window_size = 3
args = ModelArgs()
@@ -118,7 +119,6 @@ class ModelArgs:
model = GraphLAM( # noqa
args=args,
- forcing_window_size=dataset.forcing_window_size,
datastore=datastore,
)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index a2f35427..ce2f9585 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -40,7 +40,7 @@
import numpy as np
import pytest
import xarray as xr
-from conftest import DATASTORES, init_datastore
+from conftest import DATASTORES, init_datastore_example
# First-party
from neural_lam.datastore.base import BaseCartesianDatastore
@@ -49,14 +49,14 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_root_path(datastore_name):
"""Check that the `datastore.root_path` property is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
assert isinstance(datastore.root_path, Path)
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_config(datastore_name):
"""Check that the `datastore.config` property is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
# check the config is a mapping or a dataclass
config = datastore.config
assert isinstance(
@@ -67,7 +67,7 @@ def test_config(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_step_length(datastore_name):
"""Check that the `datastore.step_length` property is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
step_length = datastore.step_length
assert isinstance(step_length, int)
assert step_length > 0
@@ -78,7 +78,7 @@ def test_datastore_grid_xy(datastore_name):
"""Use the `datastore.get_xy` method to get the x, y coordinates of the
    dataset and check that the shape is correct against the
    `datastore.grid_shape_state` property."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
# check the shapes of the xy grid
grid_shape = datastore.grid_shape_state
@@ -103,7 +103,7 @@ def test_get_vars(datastore_name):
return types of each are correct.
"""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
for category in ["state", "forcing", "static"]:
units = datastore.get_vars_units(category)
@@ -120,7 +120,7 @@ def test_get_vars(datastore_name):
def test_get_normalization_dataarray(datastore_name):
"""Check that the `datasto re.get_normalization_dataa rray` method is
implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
for category in ["state", "forcing", "static"]:
ds_stats = datastore.get_normalization_dataarray(category=category)
@@ -153,7 +153,7 @@ def test_get_dataarray(datastore_name):
"""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
for category in ["state", "forcing", "static"]:
n_features = {}
@@ -203,7 +203,7 @@ def test_get_dataarray(datastore_name):
def test_boundary_mask(datastore_name):
"""Check that the `datastore.boundary_mask` property is implemented and
that the returned object is an xarray DataArray with the correct shape."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
da_mask = datastore.boundary_mask
assert isinstance(da_mask, xr.DataArray)
@@ -222,7 +222,7 @@ def test_boundary_mask(datastore_name):
def test_get_xy_extent(datastore_name):
"""Check that the `datastore.get_xy_extent` method is implemented and that
the returned object is a tuple of the correct length."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -244,7 +244,7 @@ def test_get_xy_extent(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy(datastore_name):
"""Check that the `datastore.get_xy` method is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -273,7 +273,7 @@ def test_get_xy(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_projection(datastore_name):
"""Check that the `datasto re.coords_projection` property is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -284,7 +284,7 @@ def test_get_projection(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def get_grid_shape_state(datastore_name):
"""Check that the `datasto re.grid_shape_state` property is implemented."""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
index 3a36109f..47ddc4fe 100644
--- a/tests/test_graph_creation.py
+++ b/tests/test_graph_creation.py
@@ -5,7 +5,7 @@
# Third-party
import pytest
import torch
-from test_datastores import DATASTORES, init_datastore
+from conftest import DATASTORES, init_datastore_example
# First-party
from neural_lam.create_graph import create_graph_from_datastore
@@ -19,7 +19,7 @@ def test_graph_creation(datastore_name, graph_name):
And that the graph is created in the correct location.
"""
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if graph_name == "hierarchical":
hierarchical = True
n_max_levels = 3
diff --git a/tests/test_training.py b/tests/test_training.py
index 09dab0fa..1f1969a9 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -6,7 +6,7 @@
import pytorch_lightning as pl
import torch
import wandb
-from test_datastores import DATASTORES, init_datastore
+from conftest import DATASTORES, init_datastore_example
# First-party
from neural_lam.create_graph import create_graph_from_datastore
@@ -16,7 +16,7 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_training(datastore_name):
- datastore = init_datastore(datastore_name)
+ datastore = init_datastore_example(datastore_name)
if torch.cuda.is_available():
device_name = "cuda"
@@ -73,12 +73,12 @@ class ModelArgs:
lr = 1.0e-3
val_steps_to_log = [1, 3]
metrics_watch = []
+ forcing_window_size = 3
model_args = ModelArgs()
model = GraphLAM( # noqa
args=model_args,
- forcing_window_size=data_module.forcing_window_size,
datastore=datastore,
)
wandb.init()
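
The `DATASTORES` registry and `init_datastore` helper introduced above replace the ad-hoc `_init_datastore` in `train_model.py`. A minimal sketch of how the helper could be called directly, assuming the package is importable and a datastore config exists at the (hypothetical) path below:

    # Sketch only: resolve a datastore class from the registry and instantiate it.
    from neural_lam.datastore import DATASTORES, init_datastore

    print(sorted(DATASTORES))  # ["mllam", "npyfiles"] at this point in the patch series

    datastore = init_datastore(
        datastore_kind="npyfiles",
        config_path="data/meps_example_reduced/data_config.yaml",  # hypothetical path
    )

    # Unknown kinds raise NotImplementedError instead of silently falling through.
    try:
        init_datastore(datastore_kind="bogus", config_path=None)
    except NotImplementedError as err:
        print(err)
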
From 735d324d3b53f005861fe63add276d8d129c9432 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:46:25 +0200
Subject: [PATCH 203/273] fix linting
---
neural_lam/datastore/base.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index b8c7afa8..7bd72005 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -25,9 +25,9 @@ class BaseDatastore(abc.ABC):
`torch.utils.data.Dataset` and uses the datastore to access the data).
# Forecast vs analysis data
- If the datastore is used to represent forecast rather than analysis data, then
- the `is_forecast` attribute should be set to True, and returned data from
- `get_dataarray` is assumed to have `analysis_time` and `forecast_time`
+ If the datastore is used to represent forecast rather than analysis data,
+ then the `is_forecast` attribute should be set to True, and returned data
+ from `get_dataarray` is assumed to have `analysis_time` and `forecast_time`
dimensions (rather than just `time`).
# Ensemble vs deterministic data
From 5f2d919cdfb581dbf5d8392659d72d40217a804f Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:55:20 +0200
Subject: [PATCH 204/273] remove multizar example files
---
tests/datastore_examples/multizarr/.gitignore | 2 -
.../multizarr/data_config.yaml | 168 ------------------
2 files changed, 170 deletions(-)
delete mode 100644 tests/datastore_examples/multizarr/.gitignore
delete mode 100644 tests/datastore_examples/multizarr/data_config.yaml
diff --git a/tests/datastore_examples/multizarr/.gitignore b/tests/datastore_examples/multizarr/.gitignore
deleted file mode 100644
index f2828f46..00000000
--- a/tests/datastore_examples/multizarr/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-*.zarr/
-graph/
diff --git a/tests/datastore_examples/multizarr/data_config.yaml b/tests/datastore_examples/multizarr/data_config.yaml
deleted file mode 100644
index 5d5a4336..00000000
--- a/tests/datastore_examples/multizarr/data_config.yaml
+++ /dev/null
@@ -1,168 +0,0 @@
-name: danra
-state:
- zarrs:
- - path: "danra/single_levels.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- - path: "danra/height_levels.zarr"
- dims:
- time: time
- level: altitude
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- surface_vars:
- - u10m
- - v10m
- - t2m
- surface_units:
- - m/s
- - m/s
- - K
- atmosphere_vars:
- - u
- - v
- - t
- atmosphere_units:
- - m/s
- - m/s
- - K
- levels:
- - 100
-forcing:
- zarrs:
- - path: "danra/single_levels.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- - path: "datetime_forcings.zarr"
- dims:
- time: time
- level: null
- x: x
- y: y
- grid: null
- surface_vars:
- - cape_column # just as a technical test
- - icei0m
- - vis0m
- - xhail0m
- - hour_cos
- - hour_sin
- - year_cos
- - year_sin
- surface_units:
- - J/kg
- - kg/m^2 # just as a technical test :)
- - m
- - m
- - ""
- - ""
- - ""
- - ""
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
- window: 3 # Number of time steps to use for forcing (odd)
-static:
- zarrs:
- - path: "danra/single_levels.zarr"
- dims:
- level: null
- x: x
- y: y
- grid: null
- lat_lon_names:
- lon: lon
- lat: lat
- surface_vars:
- - pres0m # just as a technical test
- surface_units:
- - Pa
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
-boundary:
- zarrs: # This is not used currently, but soon ERA% boundaries will be used
- - path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
- dims:
- time: time
- level: level
- x: longitude
- y: latitude
- grid: null
- lat_lon_names:
- lon: longitude
- lat: latitude
- mask:
- path: "boundary_mask.zarr"
- dims:
- x: x
- y: y
- surface_vars:
- - t2m
- surface_units:
- - K
- atmosphere_vars: null
- atmosphere_units: null
- levels: null
- window: 3
-utilities:
- normalization:
- zarrs:
- - path: "normalization.zarr"
- stats_vars:
- state_mean: state_mean
- state_std: state_std
- forcing_mean: forcing_mean
- forcing_std: forcing_std
- diff_mean: diff_mean
- diff_std: diff_std
- combined_stats:
- - vars:
- - icei0m
- - vis0m
- - vars:
- - cape_column
- - xhail0m
- non_normalized_vars:
- - hour_cos
- - hour_sin
- - year_cos
- - year_sin
-grid_shape_state:
- y: 589
- x: 789
-splits:
- train:
- start: 1990-09-01T00
- end: 1990-09-11T00
- val:
- start: 1990-09-11T03
- end: 1990-09-13T09
- test:
- start: 1990-09-11T03
- end: 1990-09-13T09
-projection:
- class: LambertConformal # Name of class in cartopy.crs
- kwargs:
- central_longitude: 25
- central_latitude: 56.4
- standard_parallels: [50.4, 61.6]
- inverted: false # Whether the projection is inverted
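
The removed example stored its map projection as a cartopy class name plus keyword arguments. For reference, a small sketch of how such a block can be turned into a projection object; the parsing itself is illustrative, only the values are taken from the YAML above:

    # Sketch only: build a cartopy CRS from a {"class": ..., "kwargs": ...} mapping.
    import cartopy.crs as ccrs

    projection_config = {
        "class": "LambertConformal",
        "kwargs": {
            "central_longitude": 25,
            "central_latitude": 56.4,
            "standard_parallels": [50.4, 61.6],
        },
    }

    ProjectionClass = getattr(ccrs, projection_config["class"])
    crs = ProjectionClass(**projection_config["kwargs"])
    print(crs.proj4_params)
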
From 5263d2ced70b5cca93f042dd9497e662756671b7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:55:38 +0200
Subject: [PATCH 205/273] normalization -> standardization
---
neural_lam/datastore/base.py | 19 ++++++++++---------
neural_lam/datastore/mllam.py | 16 ++++++++--------
neural_lam/datastore/npyfiles/store.py | 8 ++++----
neural_lam/models/ar_model.py | 4 +++-
neural_lam/weather_dataset.py | 15 ++++++++-------
5 files changed, 33 insertions(+), 29 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 7bd72005..e0e0d667 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -132,15 +132,16 @@ def get_num_data_vars(self, category: str) -> int:
pass
@abc.abstractmethod
- def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
- Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each
- variable in the category. For `category=="state"`, the dataarray should
- also contain a `state_diff_mean` and `state_diff_std` variable for the
- one- step differences of the state variables. The returned dataarray
-        should at least have dimensions of `({category}_feature)`, but can
- also include for example `grid_index` (if the normalisation is done per
+ Return the standardization (i.e. scaling to mean of 0.0 and standard
+ deviation of 1.0) dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also
+ contain a `state_diff_mean` and `state_diff_std` variable for the one-
+ step differences of the state variables. The returned dataarray should
+        at least have dimensions of `({category}_feature)`, but can also
+ include for example `grid_index` (if the standardization is done per
grid point for example).
Parameters
@@ -151,7 +152,7 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables
+ The standardization dataarray for the given category, with variables
for the mean and standard deviation of the variables (and
differences for state variables).
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index fcb06030..15886b9e 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -219,13 +219,13 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
)
return da_category.sel(time=slice(t_start, t_end))
- def get_normalization_dataarray(self, category: str) -> xr.Dataset:
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
- Return the normalization dataarray for the given category. This should
- contain a `{category}_mean` and `{category}_std` variable for each
- variable in the category. For `category=="state"`, the dataarray should
- also contain a `state_diff_mean` and `state_diff_std` variable for the
- one- step differences of the state variables.
+ Return the standardization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+        for the one-step differences of the state variables.
Parameters
----------
@@ -235,8 +235,8 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables
- for the mean and standard deviation of the variables (and
+ The standardization dataarray for the given category, with
+ variables for the mean and standard deviation of the variables (and
differences for state variables).
"""
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index fa2f152f..6b2e72f4 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -637,8 +637,8 @@ def boundary_mask(self) -> xr.DataArray:
da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
return da_mask_stacked_xy
- def get_normalization_dataarray(self, category: str) -> xr.Dataset:
- """Return the normalization dataarray for the given category. This
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
+ """Return the standardization dataarray for the given category. This
should contain a `{category}_mean` and `{category}_std` variable for
each variable in the category. For `category=="state"`, the dataarray
should also contain a `state_diff_mean` and `state_diff_std` variable
@@ -652,8 +652,8 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Returns
-------
xr.Dataset
- The normalization dataarray for the given category, with variables
- for the mean and standard deviation of the variables (and
+ The standardization dataarray for the given category, with
+ variables for the mean and standard deviation of the variables (and
differences for state variables).
"""
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 203b20c5..ec1649a4 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -38,7 +38,9 @@ def __init__(
da_static_features = datastore.get_dataarray(
category="static", split=split
)
- da_state_stats = datastore.get_normalization_dataarray(category="state")
+ da_state_stats = datastore.get_standardization_dataarray(
+ category="state"
+ )
da_boundary_mask = datastore.boundary_mask
forcing_window_size = args.forcing_window_size
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index b9ac8f09..72fa5d54 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -54,7 +54,7 @@ def __init__(
# TODO: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
- self.ds_state_stats = self.datastore.get_normalization_dataarray(
+ self.ds_state_stats = self.datastore.get_standardization_dataarray(
category="state"
)
@@ -63,7 +63,7 @@ def __init__(
if self.da_forcing is not None:
self.ds_forcing_stats = (
- self.datastore.get_normalization_dataarray(
+ self.datastore.get_standardization_dataarray(
category="forcing"
)
)
@@ -147,11 +147,12 @@ def __getitem__(self, idx):
target states, forcing and batch times.
The implementation currently uses xarray.DataArray objects for the
- normalisation so that we can make us of xarray's broadcasting
- capabilities. This makes it possible to normalise with both global
- means, but also for example where a grid-point mean has been computed.
- This code will have to be replace if normalisation is to be done on the
- GPU to handle different shapes of the normalisation.
+ standardization (scaling to mean 0.0 and standard deviation of 1.0) so
+        that we can make use of xarray's broadcasting capabilities. This makes
+        it possible to standardize with both global means, but also for
+        example where a grid-point mean has been computed. This code will have
+        to be replaced if standardization is to be done on the GPU to handle
+ different shapes of the standardization.
Parameters
----------
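
After this rename, per-category statistics come from `get_standardization_dataarray`. A short sketch of applying them with xarray broadcasting, assuming an already constructed `datastore` object for which the "state" category is available:

    # Sketch only: scale state data to zero mean and unit standard deviation.
    da_state = datastore.get_dataarray(category="state", split="train")
    ds_stats = datastore.get_standardization_dataarray(category="state")

    # ds_stats holds `state_mean`/`state_std` (plus `state_diff_mean`/`state_diff_std`)
    # indexed by state_feature, so broadcasting aligns the statistics per feature.
    da_state_standardized = (da_state - ds_stats["state_mean"]) / ds_stats["state_std"]

Because the statistics stay as labelled xarray objects, the same expression works whether the means are global per feature or resolved per grid point, which is the flexibility the docstring above describes.
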
From ba1bec33c0fb1ab9401cab92d9f1e0776ff7a2d3 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:58:58 +0200
Subject: [PATCH 206/273] fix import for tests
---
tests/test_datasets.py | 3 ++-
tests/test_datastores.py | 3 ++-
tests/test_graph_creation.py | 3 ++-
tests/test_training.py | 3 ++-
4 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index a556c9f5..ad03a880 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -4,11 +4,12 @@
# Third-party
import pytest
import torch
-from conftest import DATASTORES, init_datastore_example
+from conftest import init_datastore_example
from torch.utils.data import DataLoader
# First-party
from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataset
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index ce2f9585..dea99e96 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -40,9 +40,10 @@
import numpy as np
import pytest
import xarray as xr
-from conftest import DATASTORES, init_datastore_example
+from conftest import init_datastore_example
# First-party
+from neural_lam.datastore import DATASTORES
from neural_lam.datastore.base import BaseCartesianDatastore
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
index 47ddc4fe..f11ee7d7 100644
--- a/tests/test_graph_creation.py
+++ b/tests/test_graph_creation.py
@@ -5,10 +5,11 @@
# Third-party
import pytest
import torch
-from conftest import DATASTORES, init_datastore_example
+from conftest import init_datastore_example
# First-party
from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
@pytest.mark.parametrize("graph_name", ["1level", "multiscale", "hierarchical"])
diff --git a/tests/test_training.py b/tests/test_training.py
index 1f1969a9..23b32191 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -6,10 +6,11 @@
import pytorch_lightning as pl
import torch
import wandb
-from conftest import DATASTORES, init_datastore_example
+from conftest import init_datastore_example
# First-party
from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataModule
From d04d15e4f8d19b236dd0a0791f9ac239c734071c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 10 Sep 2024 18:59:59 +0200
Subject: [PATCH 207/273] Update neural_lam/datastore/base.py
Co-authored-by: Joel Oskarsson
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index e0e0d667..3a943c18 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -166,7 +166,7 @@ def get_dataarray(
"""
Return the processed data (as a single `xr.DataArray`) for the given
category of data and test/train/val-split that covers all the data (in
- space and time) of a given category (state/forcin g/static). A
+ space and time) of a given category (state/forcing/static). A
datastore must be able to return for the "state" category, but
"forcing" and "static" are optional (in which case the method should
return `None`). For the "static" category the `split` is allowed to be
From 743d7a1fdee785da8a890617967f3c16cb6d87b4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:21:07 +0200
Subject: [PATCH 208/273] fix coord issues and add datastore example plotting
cli
---
neural_lam/create_graph.py | 9 +-
neural_lam/datastore/base.py | 42 ++---
neural_lam/datastore/mllam.py | 50 +++++-
neural_lam/datastore/npyfiles/store.py | 48 +++---
neural_lam/datastore/plot_example.py | 150 ++++++++++++++++++
neural_lam/weather_dataset.py | 207 +++++++++++++++++++++----
tests/test_datasets.py | 78 +++++++++-
tests/test_datastores.py | 23 ++-
8 files changed, 515 insertions(+), 92 deletions(-)
create mode 100644 neural_lam/datastore/plot_example.py
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 4ce0811b..0b267f67 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -13,9 +13,8 @@
from torch_geometric.utils.convert import from_networkx
# Local
+from .datastore import DATASTORES
from .datastore.base import BaseCartesianDatastore
-from .datastore.mllam import MLLAMDatastore
-from .datastore.npyfiles import NpyFilesDatastore
def plot_graph(graph, title=None):
@@ -532,12 +531,6 @@ def create_graph(
save_edges(pyg_m2g, "m2g", graph_dir_path)
-DATASTORES = dict(
- mllam=MLLAMDatastore,
- npyfiles=NpyFilesDatastore,
-)
-
-
def create_graph_from_datastore(
datastore: BaseCartesianDatastore,
output_root_path: str,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 3a943c18..8e6d6e8d 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -9,6 +9,7 @@
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
+from pandas.core.indexes.multi import MultiIndex
class BaseDatastore(abc.ABC):
@@ -228,21 +229,13 @@ class CartesianGridShape:
class BaseCartesianDatastore(BaseDatastore):
- """Base class for weather
- data stored on a Cartesian
- grid. In addition to the
- methods and attributes
- required for weather data
- in general (see
- `BaseDatastore`) for
- Cartesian gridded source
- data each `grid_index`
- coordinate value is assume
- to have an associated `x`
- and `y`-value so that the
- processed data-arrays can
- be reshaped back into into
- 2D xy-gridded arrays.
+ """
+ Base class for weather data stored on a Cartesian grid. In addition to the
+ methods and attributes required for weather data in general (see
+ `BaseDatastore`) for Cartesian gridded source data each `grid_index`
+    coordinate value is assumed to have an associated `x` and `y`-value so that
+    the processed data-arrays can be reshaped back into 2D xy-gridded
+ arrays.
In addition the following attributes and methods are required:
- `coords_projection` (property): Projection object for the coordinates.
@@ -253,7 +246,7 @@ class BaseCartesianDatastore(BaseDatastore):
"""
- CARTESIAN_COORDS = ["y", "x"]
+ CARTESIAN_COORDS = ["x", "y"]
@property
@abc.abstractmethod
@@ -347,9 +340,20 @@ def unstack_grid_coords(
The dataarray or dataset with the grid coordinates unstacked.
"""
- return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack(
- "grid_index"
- )
+ # check whether `grid_index` is a multi-index
+ if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex):
+ da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS)
+
+ da_or_ds_unstacked = da_or_ds.unstack("grid_index")
+
+ # ensure that the x, y dimensions are in the correct order
+ dims = da_or_ds_unstacked.dims
+ xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS]
+
+ if xy_dim_order != self.CARTESIAN_COORDS:
+ da_or_ds_unstacked = da_or_ds_unstacked.transpose("y", "x")
+
+ return da_or_ds_unstacked
def stack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py
index 15886b9e..b0867d02 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mllam.py
@@ -70,6 +70,19 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
if len(self.get_vars_names(category)) > 0:
print(f"{category}: {' '.join(self.get_vars_names(category))}")
+ # find out the dimension order for the stacking to grid-index
+ dim_order = None
+ for input_dataset in self._config.inputs.values():
+ dim_order_ = input_dataset.dim_mapping["grid_index"].dims
+ if dim_order is None:
+ dim_order = dim_order_
+ else:
+ assert (
+ dim_order == dim_order_
+ ), "all inputs must have the same dimension order"
+
+ self.CARTESIAN_COORDS = dim_order
+
@property
def root_path(self) -> Path:
"""The root path of the dataset.
@@ -202,6 +215,14 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
da_category = self._ds[category]
+ # set units on x y coordinates if missing
+ for coord in ["x", "y"]:
+ if "units" not in da_category[coord].attrs:
+ da_category[coord].attrs["units"] = "m"
+
+ # set multi-index for grid-index
+ da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS)
+
if "time" not in da_category.dims:
return da_category
else:
@@ -294,10 +315,26 @@ def coords_projection(self) -> ccrs.Projection:
The projection of the coordinates.
"""
- # TODO: danra doesn't contain projection information yet, but the next
- # version will for now we hardcode the projection
- # XXX: this is wrong
- return ccrs.PlateCarree()
+ # XXX: this should move to config
+ kwargs = {
+ "LoVInDegrees": 25.0,
+ "LaDInDegrees": 56.7,
+ "Latin1InDegrees": 56.7,
+ "Latin2InDegrees": 56.7,
+ }
+
+        lon_0 = kwargs["LoVInDegrees"]  # Longitude of the projection origin
+        lat_0 = kwargs["LaDInDegrees"]  # Latitude of the projection origin
+        lat_1 = kwargs["Latin1InDegrees"]  # First standard parallel
+        lat_2 = kwargs["Latin2InDegrees"]  # Second standard parallel
+
+ crs = ccrs.LambertConformal(
+ central_longitude=lon_0,
+ central_latitude=lat_0,
+ standard_parallels=(lat_1, lat_2),
+ )
+
+ return crs
@property
def grid_shape_state(self):
@@ -346,10 +383,11 @@ def get_xy(self, category: str, stacked: bool) -> ndarray:
da_xy = xr.concat([da_x, da_y], dim="grid_coord")
if stacked:
- da_xy = da_xy.stack(grid_index=("y", "x")).transpose(
+ da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose(
"grid_coord", "grid_index"
)
else:
- da_xy = da_xy.transpose("grid_coord", "y", "x")
+ dims = ["grid_coord", "y", "x"]
+ da_xy = da_xy.transpose(*dims)
return da_xy.values
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 6b2e72f4..9f4d90e4 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -347,8 +347,7 @@ def _get_single_timeseries_dataarray(
"Member can only be specified for the 'state' category"
)
- # XXX: we here assume that the grid shape is the same for all categories
- grid_shape = self.grid_shape_state
+ concat_axis = 0
file_params = {}
add_feature_dim = False
@@ -387,7 +386,8 @@ def _get_single_timeseries_dataarray(
fp_samples = self.root_path / "static"
elif features == ["x", "y"]:
filename_format = "nwp_xy.npy"
- file_dims = ["y", "x", "feature"]
+ # NB: for x, y the feature dimension is the first one
+ file_dims = ["feature", "y", "x"]
features_vary_with_analysis_time = False
# XXX: x, y are the same for all splits, and so saved in static/
fp_samples = self.root_path / "static"
@@ -403,6 +403,12 @@ def _get_single_timeseries_dataarray(
coords = {}
arr_shape = []
+
+ xs, ys = self.get_xy(category="state", stacked=False)
+ assert np.all(xs[0, :] == xs[-1, :])
+ assert np.all(ys[:, 0] == ys[:, -1])
+ x = xs[0, :]
+ y = ys[:, 0]
for d in dims:
if d == "elapsed_forecast_duration":
coord_values = (
@@ -413,9 +419,9 @@ def _get_single_timeseries_dataarray(
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
elif d == "y":
- coord_values = np.arange(grid_shape.y)
+ coord_values = y
elif d == "x":
- coord_values = np.arange(grid_shape.x)
+ coord_values = x
elif d == "feature":
coord_values = features
else:
@@ -450,7 +456,7 @@ def _get_single_timeseries_dataarray(
]
if features_vary_with_analysis_time:
- arr_all = dask.array.stack(arrays, axis=0)
+ arr_all = dask.array.stack(arrays, axis=concat_axis)
else:
arr_all = arrays[0]
@@ -568,17 +574,17 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
Returns
-------
np.ndarray
- The x, y coordinates of the dataset, returned differently based on
- the value of `stacked`:
+ The x, y coordinates of the dataset (with x first then y second),
+ returned differently based on the value of `stacked`:
- `stacked==True`: shape `(2, n_grid_points)` where
n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
"""
- # the array on disk has shape [2, N_x, N_y], but we want to return it
- # as [2, N_y, N_x] so we swap the axes
- arr = np.load(self.root_path / "static" / "nwp_xy.npy").swapaxes(1, 2)
+ # the array on disk has shape [2, N_y, N_x], with the first dimension
+ # being [x, y]
+ arr = np.load(self.root_path / "static" / "nwp_xy.npy")
assert arr.shape[0] == 2, "Expected 2D array"
grid_shape = self.grid_shape_state
@@ -611,7 +617,7 @@ def grid_shape_state(self) -> CartesianGridShape:
The shape of the cartesian grid for the state variables.
"""
- nx, ny = self.config.grid_shape_state
+ ny, nx = self.config.grid_shape_state
return CartesianGridShape(x=nx, y=ny)
@property
@@ -626,10 +632,10 @@ def boundary_mask(self) -> xr.DataArray:
"""
xs, ys = self.get_xy(category="state", stacked=False)
- assert np.all(xs[:, 0] == xs[:, -1])
- assert np.all(ys[0, :] == ys[-1, :])
- x = xs[:, 0]
- y = ys[0, :]
+ assert np.all(xs[0, :] == xs[-1, :])
+ assert np.all(ys[:, 0] == ys[:, -1])
+ x = xs[0, :]
+ y = ys[:, 0]
values = np.load(self.root_path / "static" / "border_mask.npy")
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
@@ -677,11 +683,11 @@ def load_pickled_tensor(fn):
std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])
elif category == "static":
- ds_static = self.get_dataarray(category="static", split="train")
- ds_static_mean = ds_static.mean(dim=["grid_index"])
- ds_static_std = ds_static.std(dim=["grid_index"])
- mean_values = ds_static_mean["static_feature"].values
- std_values = ds_static_std["static_feature"].values
+ da_static = self.get_dataarray(category="static", split="train")
+ da_static_mean = da_static.mean(dim=["grid_index"]).compute()
+ da_static_std = da_static.std(dim=["grid_index"]).compute()
+ mean_values = da_static_mean.values
+ std_values = da_static_std.values
else:
raise NotImplementedError(f"Category {category} not supported")
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
new file mode 100644
index 00000000..53bc6d5e
--- /dev/null
+++ b/neural_lam/datastore/plot_example.py
@@ -0,0 +1,150 @@
+# Third-party
+import matplotlib.pyplot as plt
+
+
+def plot_example_from_datastore(
+ category, datastore, col_dim, split="train", standardize=True, selection={}
+):
+ """
+ Create a plot of the data from the datastore.
+
+ Parameters
+ ----------
+ category : str
+ Category of data to plot, one of "state", "forcing", or "static".
+ datastore : Datastore
+ Datastore to retrieve data from.
+ col_dim : str
+ Dimension to use for plot facetting into columns. This can be a
+ template string that can be formatted with the category name.
+ split : str, optional
+ Split of data to plot, by default "train".
+ standardize : bool, optional
+ Whether to standardize the data before plotting, by default True.
+ selection : dict, optional
+ Selections to apply to the dataarray, for example
+        `time="1990-09-03T0:00"` would select this single timestep, by default
+ {}.
+
+ Returns
+ -------
+ Figure
+ Matplotlib figure object.
+ """
+ da = datastore.get_dataarray(category=category, split=split)
+ if standardize:
+ da_stats = datastore.get_standardization_dataarray(category=category)
+ da = (da - da_stats[f"{category}_mean"]) / da_stats[f"{category}_std"]
+ da = datastore.unstack_grid_coords(da)
+
+ if len(selection) > 0:
+ da = da.sel(**selection)
+
+ col = col_dim.format(category=category)
+
+ # check that the column dimension exists and that the resulting shape is 2D
+ if col not in da.dims:
+ raise ValueError(f"Column dimension {col} not found in dataarray.")
+ if not len(da.isel({col: 0}).squeeze().shape) == 2:
+ raise ValueError(
+ f"Column dimension {col} and selection {selection} does not "
+ "result in a 2D dataarray. Please adjust the column dimension "
+ "and/or selection."
+ )
+
+ crs = datastore.coords_projection
+ g = da.plot(
+ x="x",
+ y="y",
+ col=col,
+ col_wrap=min(4, int(da[col].count())),
+ subplot_kws={"projection": crs},
+ transform=crs,
+ size=4,
+ )
+ for ax in g.axes.flat:
+ ax.coastlines()
+ ax.gridlines(draw_labels=["left", "bottom"])
+
+ return g.fig
+
+
+if __name__ == "__main__":
+ # Standard library
+ import argparse
+
+ # Local
+ from . import init_datastore
+
+ def _parse_dict(arg_str):
+ key, value = arg_str.split("=")
+ for op in [int, float]:
+ try:
+ value = op(value)
+ break
+ except ValueError:
+ pass
+ return key, value
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument("datastore_kind", help="Kind of datastore to use.")
+ parser.add_argument(
+ "config_path", help="Path to the datastore configuration file."
+ )
+ parser.add_argument(
+ "--category",
+ default="state",
+ help="Category of data to plot",
+ choices=["state", "forcing", "static"],
+ )
+ parser.add_argument(
+ "--split", default="train", help="Split of data to plot"
+ )
+ parser.add_argument(
+ "--col-dim",
+ default="{category}_feature",
+ help="Dimension to use for plot facetting into columns",
+ )
+ parser.add_argument(
+ "--disable-standardize",
+ dest="standardize",
+ action="store_false",
+ help="Disable standardization of data",
+ )
+ # add the ability to create dictionary of kwargs
+ parser.add_argument(
+ "--selection",
+ nargs="+",
+ default=[],
+ type=_parse_dict,
+ help=(
+ "Selections to apply to the dataarray, for example "
+            '`time="1990-09-03T0:00"` would select this single timestep',
+ ),
+ )
+ args = parser.parse_args()
+
+ selection = dict(args.selection)
+
+ # check that column dimension is not in the selection
+ if args.col_dim.format(category=args.category) in selection:
+ raise ValueError(
+ f"Column dimension {args.col_dim.format(category=args.category)} "
+ f"cannot be in the selection ({selection}). Please adjust the "
+ "column dimension and/or selection."
+ )
+
+ datastore = init_datastore(
+ datastore_kind=args.datastore_kind, config_path=args.config_path
+ )
+ plot_example_from_datastore(
+ args.category,
+ datastore,
+ split=args.split,
+ col_dim=args.col_dim,
+ standardize=args.standardize,
+ selection=selection,
+ )
+ plt.show()
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 72fa5d54..ed330e29 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,7 +1,10 @@
# Standard library
+import datetime
import warnings
+from typing import Union
# Third-party
+import numpy as np
import pytorch_lightning as pl
import torch
import xarray as xr
@@ -141,32 +144,26 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
)
return da
- def __getitem__(self, idx):
+ def _build_item_dataarrays(self, idx):
"""
- Return a single training sample, which consists of the initial states,
- target states, forcing and batch times.
-
- The implementation currently uses xarray.DataArray objects for the
- standardization (scaling to mean 0.0 and standard deviation of 1.0) so
-        that we can make use of xarray's broadcasting capabilities. This makes
-        it possible to standardize with both global means, but also for
-        example where a grid-point mean has been computed. This code will have
-        to be replaced if standardization is to be done on the GPU to handle
- different shapes of the standardization.
+ Create the dataarrays for the initial states, target states and forcing
+ data for the sample at index `idx`.
Parameters
----------
idx : int
- The index of the sample to return, this will refer to the time of
- the initial state.
+ The index of the sample to create the dataarrays for.
Returns
-------
- init_states : TrainingSample
- A training sample object containing the initial states, target
- states, forcing and batch times. The batch times are the times of
- the target steps.
-
+ da_init_states : xr.DataArray
+ The dataarray for the initial states.
+ da_target_states : xr.DataArray
+ The dataarray for the target states.
+ da_forcing_windowed : xr.DataArray
+ The dataarray for the forcing data, windowed for the sample.
+ da_target_times : xr.DataArray
+ The dataarray for the target times.
"""
# handling ensemble data
if self.datastore.is_ensemble:
@@ -230,7 +227,7 @@ def __getitem__(self, idx):
da_init_states = da_state.isel(time=slice(None, 2))
da_target_states = da_state.isel(time=slice(2, None))
- batch_times = da_target_states.time.values.astype(float)
+ da_target_times = da_target_states.time
if self.standardize:
da_init_states = (
@@ -251,29 +248,81 @@ def __getitem__(self, idx):
da_forcing_windowed = da_forcing_windowed.stack(
forcing_feature_windowed=("forcing_feature", "window_sample")
)
+ else:
+ # create an empty forcing tensor with the right shape
+ da_forcing_windowed = xr.DataArray(
+ data=np.empty(
+ (self.ar_steps, da_state.grid_index.size, 0),
+ ),
+ dims=("time", "grid_index", "forcing_feature"),
+ coords={
+ "time": da_target_times,
+ "grid_index": da_state.grid_index,
+ "forcing_feature": [],
+ },
+ )
+
+ return (
+ da_init_states,
+ da_target_states,
+ da_forcing_windowed,
+ da_target_times,
+ )
+
+ def __getitem__(self, idx):
+ """
+ Return a single training sample, which consists of the initial states,
+ target states, forcing and batch times.
+
+ The implementation currently uses xarray.DataArray objects for the
+ standardization (scaling to mean 0.0 and standard deviation of 1.0) so
+        that we can make use of xarray's broadcasting capabilities. This makes
+        it possible to standardize with both global means, but also for
+        example where a grid-point mean has been computed. This code will have
+        to be replaced if standardization is to be done on the GPU to handle
+ different shapes of the standardization.
+
+ Parameters
+ ----------
+ idx : int
+ The index of the sample to return, this will refer to the time of
+ the initial state.
+
+ Returns
+ -------
+ init_states : TrainingSample
+ A training sample object containing the initial states, target
+ states, forcing and batch times. The batch times are the times of
+ the target steps.
+
+ """
+ (
+ da_init_states,
+ da_target_states,
+ da_forcing_windowed,
+ da_target_times,
+ ) = self._build_item_dataarrays(idx=idx)
+
+ tensor_dtype = torch.float32
- init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
+ init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype)
target_states = torch.tensor(
- da_target_states.values, dtype=torch.float32
+ da_target_states.values, dtype=tensor_dtype
)
- if self.da_forcing is None:
- # create an empty forcing tensor
- forcing = torch.empty(
- (self.ar_steps, da_state.grid_index.size, 0),
- dtype=torch.float32,
- )
- else:
- forcing = torch.tensor(
- da_forcing_windowed.values, dtype=torch.float32
- )
+ target_times = torch.tensor(
+ da_target_times.astype("datetime64[ns]").astype("int64").values,
+ dtype=torch.int64,
+ )
+
+ forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype)
# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
# forcing: (ar_steps, N_grid, d_windowed_forcing)
- # batch_times: (ar_steps,)
+ # target_times: (ar_steps,)
- return init_states, target_states, forcing, batch_times
+ return init_states, target_states, forcing, target_times
def __iter__(self):
"""
@@ -286,6 +335,98 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]
+ def create_dataarray_from_tensor(
+ self,
+ tensor: torch.Tensor,
+ time: Union[datetime.datetime, list[datetime.datetime]],
+ category: str,
+ ):
+ """
+        Construct an xarray.DataArray from a `torch.Tensor` with coordinates
+ for `grid_index`, `time` and `{category}_feature` matching the shape
+ and number of times provided and add the x/y coordinates from the
+ datastore.
+
+        The number of times provided is expected to match the shape of the
+ tensor. For a 2D tensor, the dimensions are assumed to be (grid_index,
+ {category}_feature) and only a single time should be provided. For a 3D
+ tensor, the dimensions are assumed to be (time, grid_index,
+ {category}_feature) and a list of times should be provided.
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+            The tensor to construct the DataArray from; this is assumed to have
+ the same dimension ordering as returned by the __getitem__ method
+ (i.e. time, grid_index, {category}_feature).
+ time : datetime.datetime or list[datetime.datetime]
+ The time or times of the tensor.
+ category : str
+ The category of the tensor, either "state", "forcing" or "static".
+
+ Returns
+ -------
+ da : xr.DataArray
+ The constructed DataArray.
+ """
+
+ def _is_listlike(obj):
+ # match list, tuple, numpy array
+ return hasattr(obj, "__iter__") and not isinstance(obj, str)
+
+ add_time_as_dim = False
+ if len(tensor.shape) == 2:
+ dims = ["grid_index", f"{category}_feature"]
+ if _is_listlike(time):
+ raise ValueError(
+ "Expected a single time for a 2D tensor with assumed "
+ "dimensions (grid_index, {category}_feature), but got "
+ f"{len(time)} times"
+ )
+ elif len(tensor.shape) == 3:
+ add_time_as_dim = True
+ dims = ["time", "grid_index", f"{category}_feature"]
+ if not _is_listlike(time):
+ raise ValueError(
+ "Expected a list of times for a 3D tensor with assumed "
+ "dimensions (time, grid_index, {category}_feature), but "
+ "got a single time"
+ )
+ else:
+ raise ValueError(
+ "Expected tensor to have 2 or 3 dimensions, but got "
+ f"{len(tensor.shape)}"
+ )
+
+ da_datastore_state = getattr(self, f"da_{category}")
+ da_grid_index = da_datastore_state.grid_index
+ da_state_feature = da_datastore_state.state_feature
+
+ coords = {
+ f"{category}_feature": da_state_feature,
+ "grid_index": da_grid_index,
+ }
+ if add_time_as_dim:
+ coords["time"] = time
+
+ da = xr.DataArray(
+ tensor.numpy(),
+ dims=dims,
+ coords=coords,
+ )
+
+ for grid_coord in ["x", "y"]:
+ if (
+ grid_coord in da_datastore_state.coords
+ and grid_coord not in da.coords
+ ):
+ da.coords[grid_coord] = da_datastore_state[grid_coord]
+
+ if not add_time_as_dim:
+ da.coords["time"] = time
+
+ return da
+
class WeatherDataModule(pl.LightningDataModule):
"""DataModule for weather data."""
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index ad03a880..efe2b1c4 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -2,6 +2,7 @@
from pathlib import Path
# Third-party
+import numpy as np
import pytest
import torch
from conftest import init_datastore_example
@@ -15,7 +16,7 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
-def test_dataset_item(datastore_name):
+def test_dataset_item_shapes(datastore_name):
"""Check that the `datastore.get_dataarray` method is implemented.
Validate the shapes of the tensors match between the different
@@ -42,7 +43,7 @@ def test_dataset_item(datastore_name):
# unpack the item, this is the current return signature for
# WeatherDataset.__getitem__
- init_states, target_states, forcing, batch_times = item
+ init_states, target_states, forcing, target_times = item
# initial states
assert init_states.ndim == 3
@@ -66,8 +67,8 @@ def test_dataset_item(datastore_name):
)
# batch times
- assert batch_times.ndim == 1
- assert batch_times.shape[0] == N_pred_steps
+ assert target_times.ndim == 1
+ assert target_times.shape[0] == N_pred_steps
# try to get the last item of the dataset to ensure slicing and stacking
# operations are working as expected and are consistent with the dataset
@@ -75,6 +76,75 @@ def test_dataset_item(datastore_name):
dataset[len(dataset) - 1]
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataset_item_create_dataarray_from_tensor(datastore_name):
+ datastore = init_datastore_example(datastore_name)
+
+ N_pred_steps = 4
+ forcing_window_size = 3
+ dataset = WeatherDataset(
+ datastore=datastore,
+ split="train",
+ ar_steps=N_pred_steps,
+ forcing_window_size=forcing_window_size,
+ )
+
+ idx = 0
+
+ # unpack the item, this is the current return signature for
+ # WeatherDataset.__getitem__
+ _, target_states, _, target_times_arr = dataset[idx]
+ _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays(
+ idx=idx
+ )
+
+ target_times = np.array(target_times_arr, dtype="datetime64[ns]")
+ np.testing.assert_equal(target_times, da_target_times_true.values)
+
+ da_target = dataset.create_dataarray_from_tensor(
+ tensor=target_states, category="state", time=target_times
+ )
+
+ # conversion to torch.float32 may lead to loss of precision
+ np.testing.assert_allclose(
+ da_target.values, da_target_true.values, rtol=1e-6
+ )
+ assert da_target.dims == da_target_true.dims
+ for dim in da_target.dims:
+ np.testing.assert_equal(
+ da_target[dim].values, da_target_true[dim].values
+ )
+
+ # test unstacking the grid coordinates
+ da_target_unstacked = datastore.unstack_grid_coords(da_target)
+ assert all(
+ coord_name in da_target_unstacked.coords for coord_name in ["x", "y"]
+ )
+
+ # check construction of a single time
+ da_target_single = dataset.create_dataarray_from_tensor(
+ tensor=target_states[0], category="state", time=target_times[0]
+ )
+
+ # check that the content is the same
+ # conversion to torch.float32 may lead to loss of precision
+ np.testing.assert_allclose(
+ da_target_single.values, da_target_true[0].values, rtol=1e-6
+ )
+ assert da_target_single.dims == da_target_true[0].dims
+ for dim in da_target_single.dims:
+ np.testing.assert_equal(
+ da_target_single[dim].values, da_target_true[0][dim].values
+ )
+
+ # test unstacking the grid coordinates
+ da_target_single_unstacked = datastore.unstack_grid_coords(da_target_single)
+ assert all(
+ coord_name in da_target_single_unstacked.coords
+ for coord_name in ["x", "y"]
+ )
+
+
@pytest.mark.parametrize("split", ["train", "val", "test"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_single_batch(datastore_name, split):
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index dea99e96..1388e4e0 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -27,6 +27,7 @@
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
+- [x] `stack_grid_coords` (method): Stack the grid coordinates of the dataset
"""
@@ -124,7 +125,7 @@ def test_get_normalization_dataarray(datastore_name):
datastore = init_datastore_example(datastore_name)
for category in ["state", "forcing", "static"]:
- ds_stats = datastore.get_normalization_dataarray(category=category)
+ ds_stats = datastore.get_standardization_dataarray(category=category)
# check that the returned object is an xarray DataArray
# and that it has the correct variables
@@ -295,3 +296,23 @@ def get_grid_shape_state(datastore_name):
assert len(grid_shape) == 2
assert all(isinstance(e, int) for e in grid_shape)
assert all(e > 0 for e in grid_shape)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_stacking_grid_coords(datastore_name):
+ """Check that the `datastore.stack_grid_coords` method is implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseCartesianDatastore):
+ pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+ da_static = datastore.get_dataarray("static", split=None)
+
+ da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
+ da_static_test = datastore.stack_grid_coords(da_static_unstacked)
+
+ # XXX: for the moment unstacking doesn't guarantee the order of the
+ # dimensions maybe we should enforce this?
+ da_static_test = da_static_test.transpose(*da_static.dims)
+
+ xr.testing.assert_equal(da_static, da_static_test)
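
The new `_build_item_dataarrays`/`create_dataarray_from_tensor` pair makes it possible to go from the plain tensors returned by `WeatherDataset.__getitem__` back to labelled xarray objects, mirroring what `test_dataset_item_create_dataarray_from_tensor` above exercises. A short sketch of the round trip, assuming an already initialised `datastore`:

    # Sketch only: tensor -> DataArray round trip for the target states.
    import numpy as np

    from neural_lam.weather_dataset import WeatherDataset

    dataset = WeatherDataset(
        datastore=datastore,
        split="train",
        ar_steps=4,
        forcing_window_size=3,
    )

    _, target_states, _, target_times = dataset[0]
    times = np.array(target_times, dtype="datetime64[ns]")

    da_target = dataset.create_dataarray_from_tensor(
        tensor=target_states, category="state", time=times
    )

    # The stacked grid index can be unstacked back to 2D x/y coordinates, e.g. for plotting.
    da_target_2d = datastore.unstack_grid_coords(da_target)
    print(da_target_2d.dims)
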
From ac10d7df3c25d7fd15437363eecc5d839f484b5b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:27:08 +0200
Subject: [PATCH 209/273] add lru_cache to get_xy_extent
---
neural_lam/datastore/base.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 8e6d6e8d..22731e0e 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -2,6 +2,7 @@
import abc
import collections
import dataclasses
+import functools
from pathlib import Path
from typing import List, Union
@@ -300,6 +301,7 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
"""
pass
+ @functools.lru_cache
def get_xy_extent(self, category: str) -> List[float]:
"""
Return the extent of the x, y coordinates for a given category of data.
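
Since `get_xy_extent` derives the extent from the full coordinate arrays on every call, memoising it means repeated lookups (for example once per plotted figure) only pay that cost once per category. A tiny self-contained sketch of the same `functools.lru_cache` pattern, independent of the datastore classes:

    # Sketch only: lru_cache memoises per argument value, so the body runs once per category.
    import functools

    class Extents:
        def __init__(self):
            self.calls = 0

        @functools.lru_cache
        def get_xy_extent(self, category: str):
            self.calls += 1  # stands in for loading and reducing the coordinate arrays
            return [0.0, 1.0, 0.0, 1.0]

    e = Extents()
    e.get_xy_extent("state")
    e.get_xy_extent("state")
    print(e.calls)  # 1: the second call is served from the cache

One caveat of caching on an instance method is that the cache keeps a reference to `self` for the lifetime of the process; for long-lived datastore objects that is presumably acceptable here.
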
From bf8172a4d75414a236ae6b1a61287588bda80e86 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:32:46 +0200
Subject: [PATCH 210/273] MLLAMDatastore -> MDPDatastore
---
neural_lam/datastore/__init__.py | 4 ++--
neural_lam/datastore/{mllam.py => mdp.py} | 12 +++++++++---
neural_lam/train_model.py | 4 ++--
3 files changed, 13 insertions(+), 7 deletions(-)
rename neural_lam/datastore/{mllam.py => mdp.py} (96%)
diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py
index 479d31a9..901841db 100644
--- a/neural_lam/datastore/__init__.py
+++ b/neural_lam/datastore/__init__.py
@@ -1,9 +1,9 @@
# Local
-from .mllam import MLLAMDatastore # noqa
+from .mdp import MDPDatastore # noqa
from .npyfiles import NpyFilesDatastore # noqa
DATASTORES = dict(
- mllam=MLLAMDatastore,
+ mdp=MDPDatastore,
npyfiles=NpyFilesDatastore,
)
diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mdp.py
similarity index 96%
rename from neural_lam/datastore/mllam.py
rename to neural_lam/datastore/mdp.py
index b0867d02..f40a74a5 100644
--- a/neural_lam/datastore/mllam.py
+++ b/neural_lam/datastore/mdp.py
@@ -15,12 +15,18 @@
from .base import BaseCartesianDatastore, CartesianGridShape
-class MLLAMDatastore(BaseCartesianDatastore):
- """Datastore class for the MLLAM dataset."""
+class MDPDatastore(BaseCartesianDatastore):
+ """
+ Datastore class for datasets made with the mllam_data_prep library
+ (https://github.com/mllam/mllam-data-prep). This class wraps the
+ `mllam_data_prep` library to do the necessary transforms to create the
+ different categories (state/forcing/static) of data, with the actual
+ transform to do being specified in the configuration file.
+ """
def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
"""
- Construct a new MLLAMDatastore from the configuration file at
+ Construct a new MDPDatastore from the configuration file at
`config_path`. A boundary mask is created with `n_boundary_points`
boundary points. If `reuse_existing` is True, the dataset is loaded
from a zarr file if it exists (unless the config has been modified
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 890f80fe..0913319e 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -11,7 +11,7 @@
# Local
from . import utils
-from .datastore import init_datastore
+from .datastore import DATASTORES, init_datastore
from .models import GraphLAM, HiLAM, HiLAMParallel
from .weather_dataset import WeatherDataModule
@@ -30,7 +30,7 @@ def main(input_args=None):
parser.add_argument(
"datastore_kind",
type=str,
- choices=["npyfiles", "mllam"],
+ choices=DATASTORES.keys(),
help="Kind of datastore to use",
)
parser.add_argument(
From 90ca40041ed072340c68f5aba00fcf8353c74b99 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:35:29 +0200
Subject: [PATCH 211/273] missed renames for MDPDatastore
---
README.md | 4 ++--
tests/conftest.py | 2 +-
tests/datastore_examples/{mllam => mdp}/.gitignore | 0
tests/datastore_examples/{mllam => mdp}/danra.example.yaml | 0
4 files changed, 3 insertions(+), 3 deletions(-)
rename tests/datastore_examples/{mllam => mdp}/.gitignore (100%)
rename tests/datastore_examples/{mllam => mdp}/danra.example.yaml (100%)
diff --git a/README.md b/README.md
index 6fa4bd98..e4a1989b 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ There are currently three different datastores implemented in the codebase:
files during train/val/test sampling, with the transformations to facilitate
this implemented within `neural_lam.datastore.MultizarrDatastore`.
-3. `neural_lam.datastore.MLLAMDatastore` which can combine multiple zarr
+3. `neural_lam.datastore.MDPDatastore` which can combine multiple zarr
datasets either as a preprocessing step or during sampling, but
offloads the implementation of the transformations to the
[mllam-data-prep](https://github.com/mllam/mllam-data-prep) package.
@@ -156,7 +156,7 @@ The amount of pre-processing required will depend on what kind of datastore you
#### NpyFiles Datastore
-#### MLLAM Datastore
+#### MDP (mllam-data-prep) Datastore
An overview of how the different pre-processing steps, training and files depend on each other is given in this figure:
diff --git a/tests/conftest.py b/tests/conftest.py
index f5679c66..a5440275 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -59,7 +59,7 @@ def download_meps_example_reduced_dataset():
DATASTORES_EXAMPLES = dict(
- mllam=(DATASTORE_EXAMPLES_ROOT_PATH / "mllam" / "danra.example.yaml"),
+ mdp=(DATASTORE_EXAMPLES_ROOT_PATH / "mdp" / "danra.example.yaml"),
npyfiles=download_meps_example_reduced_dataset(),
)
diff --git a/tests/datastore_examples/mllam/.gitignore b/tests/datastore_examples/mdp/.gitignore
similarity index 100%
rename from tests/datastore_examples/mllam/.gitignore
rename to tests/datastore_examples/mdp/.gitignore
diff --git a/tests/datastore_examples/mllam/danra.example.yaml b/tests/datastore_examples/mdp/danra.example.yaml
similarity index 100%
rename from tests/datastore_examples/mllam/danra.example.yaml
rename to tests/datastore_examples/mdp/danra.example.yaml
From 154139d32359ca73a453245a869819bba76cbb7a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:44:42 +0200
Subject: [PATCH 212/273] update graph plot for datastores
---
plot_graph.py => neural_lam/plot_graph.py | 26 +++++++++++++++++------
1 file changed, 19 insertions(+), 7 deletions(-)
rename plot_graph.py => neural_lam/plot_graph.py (89%)
diff --git a/plot_graph.py b/neural_lam/plot_graph.py
similarity index 89%
rename from plot_graph.py
rename to neural_lam/plot_graph.py
index b0f20b51..ad2d732c 100644
--- a/plot_graph.py
+++ b/neural_lam/plot_graph.py
@@ -1,4 +1,5 @@
# Standard library
+import os
from argparse import ArgumentParser
# Third-party
@@ -8,7 +9,9 @@
# First-party
from neural_lam import utils
-from neural_lam.datastore.multizarr import config
+
+# Local
+from .datastore import DATASTORES, init_datastore
MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.2
@@ -19,10 +22,15 @@ def main():
"""Plot graph structure in 3D using plotly."""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
- "--data_config",
+ "datastore_kind",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ choices=DATASTORES.keys(),
+ help="Kind of datastore to use",
+ )
+ parser.add_argument(
+ "datastore_config_path",
+ type=str,
+ help="Path for the datastore config",
)
parser.add_argument(
"--graph",
@@ -42,14 +50,18 @@ def main():
)
args = parser.parse_args()
- data_config = config.Config.from_file(args.data_config)
- xy = data_config.get_xy("state") # (2, N_y, N_x)
+ datastore = init_datastore(
+ datastore_kind=args.datastore_kind,
+ config_path=args.datastore_config_path,
+ )
+ xy = datastore.get_xy("state", stacked=False) # (2, N_y, N_x)
xy = xy.reshape(2, -1).T # (N_grid, 2)
pos_max = np.max(np.abs(xy))
grid_pos = xy / pos_max # Divide by maximum coordinate
# Load graph data
- hierarchical, graph_ldict = utils.load_graph(args.graph)
+ graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph)
+ hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
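For reference, a hedged sketch of the datastore calls the updated plotting script now builds on (the kind and config path are the example values from the test suite, not the only valid ones):

```python
# Sketch of the calls plot_graph.py now makes instead of reading the old
# multizarr data config; kind and path are example values, adjust as needed.
from neural_lam.datastore import init_datastore

datastore = init_datastore(
    datastore_kind="mdp",
    config_path="tests/datastore_examples/mdp/danra.example.yaml",
)
xy = datastore.get_xy("state", stacked=False)  # (2, N_y, N_x)
grid_pos = xy.reshape(2, -1).T  # (N_grid, 2), used as graph node positions
print(grid_pos.shape)
```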
From 50ee0b0f555384b8d203840c01168687753a46cd Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:46:09 +0200
Subject: [PATCH 213/273] use relative import
---
neural_lam/plot_graph.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py
index ad2d732c..ab848ccf 100644
--- a/neural_lam/plot_graph.py
+++ b/neural_lam/plot_graph.py
@@ -7,10 +7,8 @@
import plotly.graph_objects as go
import torch_geometric as pyg
-# First-party
-from neural_lam import utils
-
# Local
+from . import utils
from .datastore import DATASTORES, init_datastore
MESH_HEIGHT = 0.1
From 7dfd570c1377e0f6e1a276e7684e09f39881c660 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 12:46:24 +0200
Subject: [PATCH 214/273] add long_names and refactor npyfiles create weights
---
neural_lam/datastore/base.py | 17 +++++++++
neural_lam/datastore/mdp.py | 20 ++++++++++
.../npyfiles}/create_parameter_weights.py | 38 +++++++++----------
neural_lam/datastore/npyfiles/store.py | 7 ++++
tests/test_datastores.py | 5 +++
5 files changed, 68 insertions(+), 19 deletions(-)
rename neural_lam/{ => datastore/npyfiles}/create_parameter_weights.py (93%)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 22731e0e..47cac66b 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -116,6 +116,23 @@ def get_vars_names(self, category: str) -> List[str]:
"""
pass
+ @abc.abstractmethod
+ def get_vars_long_names(self, category: str) -> List[str]:
+ """Get the long names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The long names of the variables.
+
+ """
+ pass
+
@abc.abstractmethod
def get_num_data_vars(self, category: str) -> int:
"""Get the number of data variables in the given category.
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index f40a74a5..18a8df26 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -164,6 +164,26 @@ def get_vars_names(self, category: str) -> List[str]:
return []
return self._ds[f"{category}_feature"].values.tolist()
+ def get_vars_long_names(self, category: str) -> List[str]:
+ """
+ Return the long names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The long names of the variables in the given category.
+
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
+ return self._ds[f"{category}_feature_long_name"].values.tolist()
+
def get_num_data_vars(self, category: str) -> int:
"""Return the number of variables in the given category.
diff --git a/neural_lam/create_parameter_weights.py b/neural_lam/datastore/npyfiles/create_parameter_weights.py
similarity index 93%
rename from neural_lam/create_parameter_weights.py
rename to neural_lam/datastore/npyfiles/create_parameter_weights.py
index 4867e609..bdf54011 100644
--- a/neural_lam/create_parameter_weights.py
+++ b/neural_lam/datastore/npyfiles/create_parameter_weights.py
@@ -11,7 +11,8 @@
from tqdm import tqdm
# Local
-from . import WeatherDataset, config
+from ... import WeatherDataset
+from ...datastore import init_datastore
class PaddedWeatherDataset(torch.utils.data.Dataset):
@@ -131,10 +132,7 @@ def main():
"""
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ "datastore_config", type=str, help="Path to data config file"
)
parser.add_argument(
"--batch_size",
@@ -164,7 +162,9 @@ def main():
rank = get_rank()
world_size = get_world_size()
- config_loader = config.Config.from_file(args.data_config)
+ datastore = init_datastore(
+ datastore_kind="npyfiles", config_path=args.datastore_config
+ )
if distributed:
@@ -175,9 +175,7 @@ def main():
torch.cuda.set_device(device) if torch.cuda.is_available() else None
if rank == 0:
- static_dir_path = os.path.join(
- "data", config_loader.dataset.name, "static"
- )
+ static_dir_path = os.path.join(datastore.root_path, "static")
# Create parameter weights based on height
# based on fig A.1 in graph cast paper
w_dict = {
@@ -191,7 +189,7 @@ def main():
w_list = np.array(
[
w_dict[par.split("_")[-2]]
- for par in config_loader.dataset.var_longnames
+ for par in datastore.get_vars_long_names(category="state")
]
)
print("Saving parameter weights...")
@@ -200,12 +198,13 @@ def main():
w_list.astype("float32"),
)
- # Load dataset without any subsampling
+ # XXX: is this correct?
+ ar_steps = 61
ds = WeatherDataset(
- config_loader.dataset.name,
+ datastore=datastore,
split="train",
- subsample_step=1,
- pred_length=63,
+ ar_steps=ar_steps,
+ # pred_length=63,
standardize=False,
)
if distributed:
@@ -231,7 +230,7 @@ def main():
print("Computing mean and std.-dev. for parameters...")
means, squares, flux_means, flux_squares = [], [], [], []
- for init_batch, target_batch, forcing_batch in tqdm(loader):
+ for init_batch, target_batch, forcing_batch, _ in tqdm(loader):
if distributed:
init_batch, target_batch, forcing_batch = (
init_batch.to(device),
@@ -299,10 +298,9 @@ def main():
if rank == 0:
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
- config_loader.dataset.name,
+ datastore=datastore,
split="train",
- subsample_step=1,
- pred_length=63,
+ ar_steps=ar_steps,
standardize=True,
) # Re-load with standardization
if distributed:
@@ -327,7 +325,9 @@ def main():
diff_means, diff_squares = [], []
- for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
+ for init_batch, target_batch, _, _ in tqdm(
+ loader_standard, disable=rank != 0
+ ):
if distributed:
init_batch, target_batch = init_batch.to(device), target_batch.to(
device
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py
index 9f4d90e4..03160599 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfiles/store.py
@@ -558,6 +558,13 @@ def get_vars_names(self, category: str) -> torch.List[str]:
else:
raise NotImplementedError(f"Category {category} not supported")
+ def get_vars_long_names(self, category: str) -> List[str]:
+ if category == "state":
+ return self.config.dataset.var_longnames
+ else:
+ # TODO: should we add these?
+ return self.get_vars_names(category=category)
+
def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category=category))
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 1388e4e0..b64852f5 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -10,6 +10,8 @@
category.
- [x] `get_vars_names` (method): Get the names of the variables in the given
category.
+- [x] `get_vars_long_names` (method): Get the long names of the variables in
+ the given category.
- [x] `get_num_data_vars` (method): Get the number of data variables in the
given category.
- [x] `get_normalization_dataarray` (method): Return the normalization
@@ -99,6 +101,7 @@ def test_get_vars(datastore_name):
- `datastore.get_vars_units`
- `datastore.get_vars_names`
+ - `datastore.get_vars_long_names`
- `datastore.get_num_data_vars`
are consistent (as in the number of variables are the same) and that the
@@ -110,11 +113,13 @@ def test_get_vars(datastore_name):
for category in ["state", "forcing", "static"]:
units = datastore.get_vars_units(category)
names = datastore.get_vars_names(category)
+ long_names = datastore.get_vars_long_names(category)
num_vars = datastore.get_num_data_vars(category)
assert len(units) == len(names) == num_vars
assert isinstance(units, list)
assert isinstance(names, list)
+ assert isinstance(long_names, list)
assert isinstance(num_vars, int)
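A short usage sketch of the new accessor alongside the existing metadata methods (the config path is the example file from the test suite; any datastore instance would do):

```python
# Usage sketch: the three per-variable metadata accessors are expected to
# stay aligned, which is what the extended test_get_vars asserts.
from neural_lam.datastore import MDPDatastore

datastore = MDPDatastore(
    config_path="tests/datastore_examples/mdp/danra.example.yaml"
)
names = datastore.get_vars_names(category="state")
long_names = datastore.get_vars_long_names(category="state")
units = datastore.get_vars_units(category="state")

for name, long_name, unit in zip(names, long_names, units):
    print(f"{name:>12s}  {long_name} [{unit}]")
```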
From 2b45b5a072b5aec764ba0708c51c606b5da5fe97 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 23 Sep 2024 11:12:58 +0200
Subject: [PATCH 215/273] Update neural_lam/weather_dataset.py
Co-authored-by: Joel Oskarsson
---
neural_lam/weather_dataset.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index ed330e29..fa545b51 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -116,7 +116,8 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
The index of the time step to start the sample from.
n_steps : int
The number of time steps to include in the sample.
-
+ n_timestep_offset : int
+ A number of timesteps to use as offset from the start time of the slice
"""
# selecting the time slice
if self.datastore.is_forecast:
From aee0b1c8be4babb5fdf1eb34c84c15c8e0a66f22 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 23 Sep 2024 11:13:39 +0200
Subject: [PATCH 216/273] Update neural_lam/weather_dataset.py
Co-authored-by: Joel Oskarsson
---
neural_lam/weather_dataset.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index fa545b51..2b24835a 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -108,7 +108,7 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
Parameters
----------
da : xr.DataArray
- The dataarray to sample from. This is expected to have a `time`
+ The dataarray to slice. This is expected to have a `time`
dimension if the datastore is providing analysis only data, and a
`analysis_time` and `elapsed_forecast_duration` dimensions if the
datastore is providing forecast data.
From 8453c2b482533e90e8f75884fc4ada9a4fdb4a6e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 15:09:44 +0200
Subject: [PATCH 217/273] Update neural_lam/models/ar_model.py
Co-authored-by: Joel Oskarsson
---
neural_lam/models/ar_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index ec1649a4..c1ceff81 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -291,7 +291,7 @@ def validation_step(self, batch, batch_idx):
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
- if step < len(time_step_loss)
+ if step <= len(time_step_loss)
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
From 7f32557347570002500ec7e6ff65264c013458b5 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 15:23:10 +0200
Subject: [PATCH 218/273] Update neural_lam/weather_dataset.py
Co-authored-by: Joel Oskarsson
---
neural_lam/weather_dataset.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 2b24835a..f0251224 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -492,7 +492,7 @@ def train_dataloader(self):
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
- shuffle=False,
+ shuffle=True,
multiprocessing_context=self.multiprocessing_context,
persistent_workers=True,
)
From 67998b836345336dae4514786c027d3ac1d10347 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 17:37:16 +0200
Subject: [PATCH 219/273] read projection from datastore config extra section
---
README.md | 42 --------------
neural_lam/datastore/mdp.py | 55 ++++++++++++-------
neural_lam/datastore/plot_example.py | 6 +-
pyproject.toml | 2 +-
.../datastore_examples/mdp/danra.example.yaml | 16 ++++--
5 files changed, 51 insertions(+), 70 deletions(-)
diff --git a/README.md b/README.md
index e4a1989b..f23637c9 100644
--- a/README.md
+++ b/README.md
@@ -258,48 +258,6 @@ Except for training and pre-processing scripts all the source code can be found
Model classes, including abstract base classes, are located in `neural_lam/models`.
Notebooks for visualization and analysis are located in `docs`.
-
-## Format of data directory
-It is possible to store multiple datasets in the `data` directory.
-Each dataset contains a set of files with static features and a set of samples.
-The samples are split into different sub-directories for training, validation and testing.
-The directory structure is shown with examples below.
-Script names within parenthesis denote the script used to generate the file.
-```
-data
-├── dataset1
-│ ├── samples - Directory with data samples
-│ │ ├── train - Training data
-│ │ │ ├── nwp_2022040100_mbr000.npy - A time series sample
-│ │ │ ├── nwp_2022040100_mbr001.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_2022043012_mbr001.npy
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy - Solar flux forcing
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022043012.npy
-│ │ │ ├── wtr_2022040100.npy - Open water features for one sample
-│ │ │ ├── wtr_2022040112.npy
-│ │ │ ├── ...
-│ │ │ └── wtr_202204012.npy
-│ │ ├── val - Validation data
-│ │ └── test - Test data
-│ └── static - Directory with graph information and static features
-│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset)
-│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset)
-│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset)
-│ ├── grid_features.pt - Static features of grid nodes (neural_lam.create_grid_features)
-│ ├── parameter_mean.pt - Means of state parameters (neural_lam.create_parameter_weights)
-│ ├── parameter_std.pt - Std.-dev. of state parameters (neural_lam.create_parameter_weights)
-│ ├── diff_mean.pt - Means of one-step differences (neural_lam.create_parameter_weights)
-│ ├── diff_std.pt - Std.-dev. of one-step differences (neural_lam.create_parameter_weights)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (neural_lam.create_parameter_weights)
-│ └── parameter_weights.npy - Loss weights for different state parameters (neural_lam.create_parameter_weights)
-├── dataset2
-├── ...
-└── datasetN
-```
-
## Format of graph directory
The `graphs` directory contains generated graph structures that can be used by different graph-based models.
The structure is shown with examples below:
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 18a8df26..7384396d 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -333,7 +333,17 @@ def boundary_mask(self) -> xr.DataArray:
@property
def coords_projection(self) -> ccrs.Projection:
- """Return the projection of the coordinates.
+ """
+ Return the projection of the coordinates.
+
+ NOTE: currently this expects the projection information to be in the
+ `extra` section of the configuration file, with a `projection` key
+ containing a `class_name` and `kwargs` for constructing the
+ `cartopy.crs.Projection` object. This is a temporary solution until
+ the projection information can be parsed in the produced dataset
+ itself. `mllam-data-prep` ignores the contents of the `extra` section
+ of the config file which is why we need to check that the necessary
+ parts are there.
Returns
-------
@@ -341,26 +351,33 @@ def coords_projection(self) -> ccrs.Projection:
The projection of the coordinates.
"""
- # XXX: this should move to config
- kwargs = {
- "LoVInDegrees": 25.0,
- "LaDInDegrees": 56.7,
- "Latin1InDegrees": 56.7,
- "Latin2InDegrees": 56.7,
- }
-
- lon_0 = kwargs["LoVInDegrees"] # Latitude of first standard parallel
- lat_0 = kwargs["LaDInDegrees"] # Latitude of second standard parallel
- lat_1 = kwargs["Latin1InDegrees"] # Origin latitude
- lat_2 = kwargs["Latin2InDegrees"] # Origin longitude
+ if "projection" not in self._config.extra:
+ raise ValueError(
+ "projection information not found in the configuration file "
+ f"({self._config_path}). Please add the projection information"
+ "to the `extra` section of the config, by adding a "
+ "`projection` key with the class name and kwargs of the "
+ "projection."
+ )
- crs = ccrs.LambertConformal(
- central_longitude=lon_0,
- central_latitude=lat_0,
- standard_parallels=(lat_1, lat_2),
- )
+ projection_info = self._config.extra["projection"]
+ if "class_name" not in projection_info:
+ raise ValueError(
+ "class_name not found in the projection information. Please "
+ "add the class name of the projection to the `projection` key "
+ "in the `extra` section of the config."
+ )
+ if "kwargs" not in projection_info:
+ raise ValueError(
+ "kwargs not found in the projection information. Please add "
+ "the keyword arguments of the projection to the `projection` "
+ "key in the `extra` section of the config."
+ )
- return crs
+ class_name = projection_info["class_name"]
+ ProjectionClass = getattr(ccrs, class_name)
+ kwargs = projection_info["kwargs"]
+ return ProjectionClass(**kwargs)
@property
def grid_shape_state(self):
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
index 53bc6d5e..b68d33af 100644
--- a/neural_lam/datastore/plot_example.py
+++ b/neural_lam/datastore/plot_example.py
@@ -119,10 +119,8 @@ def _parse_dict(arg_str):
nargs="+",
default=[],
type=_parse_dict,
- help=(
- "Selections to apply to the dataarray, for example "
- '`time="1990-09-03T0:00" would select this single timestep',
- ),
+ help="Selections to apply to the dataarray, for example "
+ "`time='1990-09-03T0:00' would select this single timestep",
)
args = parser.parse_args()
diff --git a/pyproject.toml b/pyproject.toml
index fc3fbf9e..15d59be2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
"torch-geometric==2.3.1",
"parse>=1.20.2",
"dataclass-wizard>=0.22.3",
- "mllam-data-prep[dask-distributed]>=0.3.0",
+ "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep/@feat/extra-section-in-config",
]
requires-python = ">=3.9"
diff --git a/tests/datastore_examples/mdp/danra.example.yaml b/tests/datastore_examples/mdp/danra.example.yaml
index 73aa0dfa..0801f832 100644
--- a/tests/datastore_examples/mdp/danra.example.yaml
+++ b/tests/datastore_examples/mdp/danra.example.yaml
@@ -1,4 +1,4 @@
-schema_version: v0.2.0
+schema_version: v0.2.0+dev
dataset_version: v0.1.0
output:
@@ -49,7 +49,7 @@ inputs:
state_feature:
method: stack_variables_by_var_name
dims: [altitude]
- name_format: f"{var_name}{altitude}m"
+ name_format: "{var_name}{altitude}m"
grid_index:
method: stack
dims: [x, y]
@@ -70,7 +70,7 @@ inputs:
dims: [x, y]
forcing_feature:
method: stack_variables_by_var_name
- name_format: f"{var_name}"
+ name_format: "{var_name}"
target_output_variable: forcing
danra_lsm:
@@ -84,5 +84,13 @@ inputs:
dims: [x, y]
static_feature:
method: stack_variables_by_var_name
- name_format: f"{var_name}"
+ name_format: "{var_name}"
target_output_variable: static
+
+extra:
+ projection:
+ class_name: LambertConformal
+ kwargs:
+ central_longitude: 25.0
+ central_latitude: 56.7
+ standard_parallels: [56.7, 56.7]
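Resolving the projection from the `extra` section amounts to looking the class name up on `cartopy.crs` and instantiating it with the configured kwargs. A minimal sketch of that mechanism, with the dict literal standing in for the parsed `extra.projection` block from the YAML above:

```python
# Sketch of the dynamic projection lookup done in MDPDatastore.coords_projection.
import cartopy.crs as ccrs

projection_info = {
    "class_name": "LambertConformal",
    "kwargs": {
        "central_longitude": 25.0,
        "central_latitude": 56.7,
        "standard_parallels": [56.7, 56.7],
    },
}

ProjectionClass = getattr(ccrs, projection_info["class_name"])
projection = ProjectionClass(**projection_info["kwargs"])
print(projection)
```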
From ac7e46a1b8202ed4c4a90b7fe46fee632be30218 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 17:53:29 +0200
Subject: [PATCH 220/273] NpyFilesDatastore -> NpyFilesDatastoreMEPS
---
neural_lam/datastore/__init__.py | 4 ++--
neural_lam/datastore/npyfiles/__init__.py | 2 --
neural_lam/datastore/npyfilesmeps/__init__.py | 2 ++
neural_lam/datastore/{npyfiles => npyfilesmeps}/config.py | 0
.../{npyfiles => npyfilesmeps}/create_parameter_weights.py | 2 +-
neural_lam/datastore/{npyfiles => npyfilesmeps}/store.py | 2 +-
neural_lam/weather_dataset.py | 5 +++--
tests/conftest.py | 2 +-
8 files changed, 10 insertions(+), 9 deletions(-)
delete mode 100644 neural_lam/datastore/npyfiles/__init__.py
create mode 100644 neural_lam/datastore/npyfilesmeps/__init__.py
rename neural_lam/datastore/{npyfiles => npyfilesmeps}/config.py (100%)
rename neural_lam/datastore/{npyfiles => npyfilesmeps}/create_parameter_weights.py (99%)
rename neural_lam/datastore/{npyfiles => npyfilesmeps}/store.py (99%)
diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py
index 901841db..8bda69c0 100644
--- a/neural_lam/datastore/__init__.py
+++ b/neural_lam/datastore/__init__.py
@@ -1,10 +1,10 @@
# Local
from .mdp import MDPDatastore # noqa
-from .npyfiles import NpyFilesDatastore # noqa
+from .npyfilesmeps import NpyFilesDatastoreMEPS # noqa
DATASTORES = dict(
mdp=MDPDatastore,
- npyfiles=NpyFilesDatastore,
+ npyfilesmeps=NpyFilesDatastoreMEPS,
)
diff --git a/neural_lam/datastore/npyfiles/__init__.py b/neural_lam/datastore/npyfiles/__init__.py
deleted file mode 100644
index 3bf6fadb..00000000
--- a/neural_lam/datastore/npyfiles/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Local
-from .store import NpyFilesDatastore # noqa
diff --git a/neural_lam/datastore/npyfilesmeps/__init__.py b/neural_lam/datastore/npyfilesmeps/__init__.py
new file mode 100644
index 00000000..397a5075
--- /dev/null
+++ b/neural_lam/datastore/npyfilesmeps/__init__.py
@@ -0,0 +1,2 @@
+# Local
+from .store import NpyFilesDatastoreMEPS # noqa
diff --git a/neural_lam/datastore/npyfiles/config.py b/neural_lam/datastore/npyfilesmeps/config.py
similarity index 100%
rename from neural_lam/datastore/npyfiles/config.py
rename to neural_lam/datastore/npyfilesmeps/config.py
diff --git a/neural_lam/datastore/npyfiles/create_parameter_weights.py b/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
similarity index 99%
rename from neural_lam/datastore/npyfiles/create_parameter_weights.py
rename to neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
index bdf54011..81baffe5 100644
--- a/neural_lam/datastore/npyfiles/create_parameter_weights.py
+++ b/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
@@ -163,7 +163,7 @@ def main():
rank = get_rank()
world_size = get_world_size()
datastore = init_datastore(
- datastore_kind="npyfiles", config_path=args.datastore_config
+ datastore_kind="npyfilesmeps", config_path=args.datastore_config
)
if distributed:
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfilesmeps/store.py
similarity index 99%
rename from neural_lam/datastore/npyfiles/store.py
rename to neural_lam/datastore/npyfilesmeps/store.py
index 03160599..10cb374d 100644
--- a/neural_lam/datastore/npyfiles/store.py
+++ b/neural_lam/datastore/npyfilesmeps/store.py
@@ -37,7 +37,7 @@ def _load_np(fp, add_feature_dim):
return arr
-class NpyFilesDatastore(BaseCartesianDatastore):
+class NpyFilesDatastoreMEPS(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
to be stored in a directory structure where each sample is stored in a
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index f0251224..e7122e9d 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -117,7 +117,8 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0):
n_steps : int
The number of time steps to include in the sample.
n_timestep_offset : int
- A number of timesteps to use as offset from the start time of the slice
+ A number of timesteps to use as offset from the start time of the
+ slice
"""
# selecting the time slice
if self.datastore.is_forecast:
@@ -455,7 +456,7 @@ def __init__(
self.test_dataset = None
if num_workers > 0:
# default to spawn for now, as the default on linux "fork" hangs
- # when using dask (which the npyfiles datastore uses)
+ # when using dask (which the npyfilesmeps datastore uses)
self.multiprocessing_context = "spawn"
else:
self.multiprocessing_context = None
diff --git a/tests/conftest.py b/tests/conftest.py
index a5440275..12854a0e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -60,7 +60,7 @@ def download_meps_example_reduced_dataset():
DATASTORES_EXAMPLES = dict(
mdp=(DATASTORE_EXAMPLES_ROOT_PATH / "mdp" / "danra.example.yaml"),
- npyfiles=download_meps_example_reduced_dataset(),
+ npyfilesmeps=download_meps_example_reduced_dataset(),
)
From b7bf506ae572ec4247d6ef5696c39803409d162e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 17:56:02 +0200
Subject: [PATCH 221/273] revert to training with 1 AR step by default
---
neural_lam/train_model.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 0913319e..8bad0f0c 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -125,9 +125,9 @@ def main(input_args=None):
parser.add_argument(
"--ar_steps_train",
type=int,
- default=3,
+ default=1,
help="Number of steps to unroll prediction for in loss function "
- "(default: 3)",
+ "(default: 1)",
)
parser.add_argument(
"--control_only",
From 5df2ecf3e217f6513fa15d417ccdc538846b30e5 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 17:58:21 +0200
Subject: [PATCH 222/273] add missing kwarg to BaseHiGraphModel.__init__
---
neural_lam/models/base_hi_graph_model.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index a2ebcc1b..a81df88f 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -12,8 +12,8 @@ class BaseHiGraphModel(BaseGraphModel):
Base class for hierarchical graph models.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, datastore):
+ super().__init__(args, datastore=datastore)
# Track number of nodes, edges on each level
# Flatten lists for efficient embedding
From d4d438ff372ce6568c62717c4bd857de83f888a5 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 17:59:43 +0200
Subject: [PATCH 223/273] add missing kwarg to HiLAM.__init__
---
neural_lam/models/hi_lam.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index ec05d59c..68ed0e6e 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -13,8 +13,8 @@ class HiLAM(BaseHiGraphModel):
The Hi-LAM model from Oskarsson et al. (2023)
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, datastore):
+ super().__init__(args, datastore=datastore)
# Make down GNNs, both for down edges and same level
self.mesh_down_gnns = nn.ModuleList(
From 1889771efcc04ce304988e54045950fe4cc46e29 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 18:01:00 +0200
Subject: [PATCH 224/273] add missing kwarg to HiLAMParallel
---
neural_lam/models/hi_lam_parallel.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index 80181ec0..5c1872ab 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -16,8 +16,8 @@ class HiLAMParallel(BaseHiGraphModel):
of Hi-LAM.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, datastore):
+ super().__init__(args, datastore=datastore)
# Processor GNNs
# Create the complete edge_index combining all edges for processing
From 2c3bbde08db82f232fa8f80b9b52f4b53410600a Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 18:09:50 +0200
Subject: [PATCH 225/273] check that there are enough forecast steps given ar_steps
---
neural_lam/weather_dataset.py | 21 +++++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index e7122e9d..da1c7a30 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -76,7 +76,11 @@ def __init__(
def __len__(self):
if self.datastore.is_forecast:
# for now we simply create a single sample for each analysis time
- # and then the next ar_steps forecast times
+ # and then take the first (2 + ar_steps) forecast times. In
+ # addition we only use the first ensemble member (if ensemble data
+ # has been provided).
+ # This means that for each analysis time we get a single sample
+
if self.datastore.is_ensemble:
warnings.warn(
"only using first ensemble member, so dataset size is "
@@ -84,9 +88,18 @@ def __len__(self):
f"({self.da_state.ensemble_member.size})",
UserWarning,
)
- # XXX: we should maybe check that the 2+ar_steps actually fits in
- # the elapsed_forecast_duration dimension, should that be checked
- # here?
+
+ # check that there are enough forecast steps available to create
+ # samples given the number of autoregressive steps requested
+ n_forecast_steps = self.da_state.elapsed_forecast_duration.size
+ if n_forecast_steps < 2 + self.ar_steps:
+ raise ValueError(
+ "The number of forecast steps available "
+ f"({n_forecast_steps}) is less than the required "
+ f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
+ "creating a sample with initial and target states."
+ )
+
return self.da_state.analysis_time.size
else:
# sample_len = 2 + ar_steps
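The check added above reflects how a sample is assembled for forecast data: two initial states plus `ar_steps` target states must fit within each forecast. A tiny illustrative sketch of the same condition (values are made up):

```python
# Illustration of the sample-length requirement enforced in __len__ for
# forecast data: 2 initial states + ar_steps target states per sample.
ar_steps = 3
n_forecast_steps = 4  # e.g. elapsed_forecast_duration.size

required = 2 + ar_steps
if n_forecast_steps < required:
    raise ValueError(
        f"need at least {required} forecast steps, got {n_forecast_steps}"
    )
```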
From f0a151b6a4d78738f2c50d28a83533521fa29931 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 27 Sep 2024 18:11:51 +0200
Subject: [PATCH 226/273] remove numpy<2.0.0 version cap
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 15d59be2..da6664cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [
# PEP 621 project metadata
# See https://www.python.org/dev/peps/pep-0621/
dependencies = [
- "numpy<2.0.0,>=1.24.2",
+ "numpy>=1.24.2",
"wandb>=0.13.10",
"scipy>=1.10.0",
"pytorch-lightning>=2.0.3",
From f3566b043e6ce0b61776716fac30743d90ddb915 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 08:39:46 +0000
Subject: [PATCH 227/273] tweak print statement wording in mdp
---
neural_lam/datastore/mdp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 7384396d..e6f7888d 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -71,7 +71,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points
- print("Training with the following features:")
+ print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
print(f"{category}: {' '.join(self.get_vars_names(category))}")
From dba94b33eba80f64088a8660c275898d1b80a529 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 08:50:27 +0000
Subject: [PATCH 228/273] fix missed removed argument from cli
---
neural_lam/train_model.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 8bad0f0c..d2da1dae 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -253,9 +253,7 @@ def main(input_args=None):
# Load model parameters Use new args for model
ModelClass = MODELS[args.model]
- model = ModelClass(
- args, datastore=datastore, forcing_window_size=args.forcing_window_size
- )
+ model = ModelClass(args, datastore=datastore)
if args.eval:
prefix = f"eval-{args.eval}-"
From bca1482e2a69057fe4f809e3d0e732569d8f5154 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:19:12 +0000
Subject: [PATCH 229/273] remove wandb config log comment, we log now
---
neural_lam/train_model.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index d2da1dae..d24ab611 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -292,8 +292,6 @@ def main(input_args=None):
utils.init_wandb_metrics(
logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
- # TODO: should we save the datastore config here?
- # wandb.save()
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
From fc973c4c7bbef013bc2c6f19a16544a2a193a0aa Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:19:26 +0000
Subject: [PATCH 230/273] ensure loading from checkpoint during training is possible
---
neural_lam/train_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index d24ab611..a1918994 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -295,7 +295,7 @@ def main(input_args=None):
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
- trainer.fit(model=model, datamodule=data_module)
+ trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
if __name__ == "__main__":
From 9fcf06e5811d5b3c934a1a93f56a626cdb957555 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:26:41 +0000
Subject: [PATCH 231/273] get step_length from datastore in plot_error_map
---
neural_lam/models/ar_model.py | 1 -
neural_lam/vis.py | 5 ++---
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index c1ceff81..a171c1ba 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -511,7 +511,6 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
metric_fig = vis.plot_error_map(
errors=metric_tensor,
datastore=self._datastore,
- step_length=self.step_length,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 542b6ab7..bd991399 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -9,9 +9,7 @@
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(
- errors, datastore: BaseCartesianDatastore, title=None, step_length=1
-):
+def plot_error_map(errors, datastore: BaseCartesianDatastore, title=None):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
@@ -19,6 +17,7 @@ def plot_error_map(
"""
errors_np = errors.T.cpu().numpy() # (d_f, pred_steps)
d_f, pred_steps = errors_np.shape
+ step_length = datastore.step_length
# Normalize all errors to [0,1] for color map
max_errors = errors_np.max(axis=1) # d_f
From 2bbe666fab3df2be1471b7eaebffa245cef48716 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:28:35 +0000
Subject: [PATCH 232/273] remove step_length attr in ARModel
---
neural_lam/models/ar_model.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index a171c1ba..d8d7e99d 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -118,8 +118,6 @@ def __init__(
"interior_mask", 1.0 - self.boundary_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border
- # Number of hours per pred. step
- self.step_length = datastore.step_length
self.val_metrics = {
"mse": [],
}
@@ -457,7 +455,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
obs_mask=self.interior_mask[:, 0],
datastore=self.datastore,
title=f"{var_name} ({var_unit}), "
- f"t={t_i} ({self.step_length * t_i} h)",
+ f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
From b41ed2f98b551f26fe0d7f8c4e1fb00181130b7d Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:30:15 +0000
Subject: [PATCH 233/273] remove unused obs_mask arg for vis.plot_prediction
---
neural_lam/models/ar_model.py | 1 -
neural_lam/vis.py | 1 -
2 files changed, 2 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index d8d7e99d..b08797e5 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -452,7 +452,6 @@ def plot_examples(self, batch, n_examples, prediction=None):
vis.plot_prediction(
pred=pred_t[:, var_i],
target=target_t[:, var_i],
- obs_mask=self.interior_mask[:, 0],
datastore=self.datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index bd991399..4b5fd6a1 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -67,7 +67,6 @@ def plot_error_map(errors, datastore: BaseCartesianDatastore, title=None):
def plot_prediction(
pred,
target,
- obs_mask,
datastore: BaseCartesianDatastore,
title=None,
vrange=None,
From 7e46194b91bd2b7c00850bbb925b59e1fa98d091 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 1 Oct 2024 14:48:48 +0000
Subject: [PATCH 234/273] ensure no reference to multizarr "data_config"
---
neural_lam/models/ar_model.py | 21 ++++++++-------------
neural_lam/vis.py | 15 ++++++++-------
2 files changed, 16 insertions(+), 20 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index b08797e5..f879f618 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -525,17 +525,13 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
)
# Check if metrics are watched, log exact values for specific vars
+ var_names = self._datastore.get_vars_names(category="state")
if full_log_name in self.args.metrics_watch:
for var_i, timesteps in self.args.var_leads_metrics_watch.items():
- var = self.data_config.vars_names("state")[var_i]
- log_dict.update(
- {
- f"{full_log_name}_{var}_step_{step}": metric_tensor[
- step - 1, var_i
- ] # 1-indexed in data_config
- for step in timesteps
- }
- )
+ var_name = var_names[var_i]
+ for step in timesteps:
+ key = f"{full_log_name}_{var_name}_step_{step}"
+ log_dict[key] = metric_tensor[step - 1, var_i]
return log_dict
@@ -594,9 +590,8 @@ def on_test_epoch_end(self):
loss_map_figs = [
vis.plot_spatial_error(
- loss_map,
- self.interior_mask[:, 0],
- self.data_config,
+ error=loss_map,
+ datastore=self._datastore,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
for t_i, loss_map in zip(
@@ -611,7 +606,7 @@ def on_test_epoch_end(self):
# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(
- loss_map, self.interior_mask[:, 0], self.data_config
+ error=loss_map, datastore=self._datastore
)
for loss_map in mean_spatial_loss
]
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 4b5fd6a1..9653c3fc 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -131,7 +131,9 @@ def plot_prediction(
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
+def plot_spatial_error(
+ error, datastore: BaseCartesianDatastore, title=None, vrange=None
+):
"""
Plot errors over spatial map
Error and obs_mask has shape (N_grid,)
@@ -143,24 +145,23 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
else:
vmin, vmax = vrange
- extent = data_config.get_xy_extent("state")
+ extent = datastore.get_xy_extent("state")
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(
- list(data_config.grid_shape_state.values.values())
- )
+ da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+ mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
fig, ax = plt.subplots(
figsize=(5, 4.8),
- subplot_kw={"projection": data_config.coords_projection},
+ subplot_kw={"projection": datastore.coords_projection},
)
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(data_config.grid_shape_state.values.values()))
+ error.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
.numpy()
)
From b57bc7ac0c10d9467f04f2500a078543c9b310a1 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 09:54:39 +0200
Subject: [PATCH 235/273] introduce neural-lam config
---
neural_lam/config.py | 123 +++++++++++++++++++++++
neural_lam/create_graph.py | 16 +--
neural_lam/train_model.py | 25 ++---
pyproject.toml | 1 +
tests/datastore_examples/mdp/config.yaml | 7 ++
5 files changed, 147 insertions(+), 25 deletions(-)
create mode 100644 neural_lam/config.py
create mode 100644 tests/datastore_examples/mdp/config.yaml
diff --git a/neural_lam/config.py b/neural_lam/config.py
new file mode 100644
index 00000000..7524e3bd
--- /dev/null
+++ b/neural_lam/config.py
@@ -0,0 +1,123 @@
+# Standard library
+import dataclasses
+from pathlib import Path
+from typing import Dict, Union
+
+# Third-party
+import dataclass_wizard
+
+# Local
+from .datastore import (
+ DATASTORES,
+ MDPDatastore,
+ NpyFilesDatastoreMEPS,
+ init_datastore,
+)
+
+
+class DatastoreKindStr(str):
+ VALID_KINDS = DATASTORES.keys()
+
+ def __new__(cls, value):
+ if value not in cls.VALID_KINDS:
+ raise ValueError(f"Invalid datastore kind: {value}")
+ return super().__new__(cls, value)
+
+
+@dataclasses.dataclass
+class DatastoreSelection:
+ """
+ Configuration for selecting a datastore to use with neural-lam.
+
+ Attributes
+ ----------
+ kind : DatastoreKindStr
+ The kind of datastore to use, currently `mdp` or `npyfilesmeps` are
+ implemented.
+ config_path : str
+ The path to the configuration file for the selected datastore, this is
+ assumed to be relative to the configuration file for neural-lam.
+ """
+
+ kind: DatastoreKindStr
+ config_path: str
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+ """
+ Configuration related to training neural-lam
+
+ Attributes
+ ----------
+ state_feature_weights : Dict[str, float]
+ The weights for each state feature in the datastore to use in the loss
+ function during training.
+ """
+
+ state_feature_weights: Dict[str, float]
+
+
+@dataclasses.dataclass
+class NeuralLAMConfig(dataclass_wizard.YAMLWizard):
+ """
+ Dataclass for Neural-LAM configuration. This class is used to load and
+ store the configuration for using Neural-LAM.
+
+ Attributes
+ ----------
+ datastore : DatastoreSelection
+ The configuration for the datastore to use.
+ training : TrainingConfig
+ The configuration for training the model.
+ """
+
+ datastore: DatastoreSelection
+ training: TrainingConfig
+
+
+def load_config_and_datastore(
+ config_path: str,
+) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
+ """
+ Load the neural-lam configuration and the datastore specified in the
+ configuration.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the Neural-LAM configuration file.
+
+ Returns
+ -------
+ tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]
+ The Neural-LAM configuration and the loaded datastore.
+ """
+ config = NeuralLAMConfig.from_yaml_file(config_path)
+ # datastore config is assumed to be relative to the config file
+ datastore_config_path = (
+ Path(config_path).parent / config.datastore.config_path
+ )
+ datastore = init_datastore(
+ datastore_kind=config.datastore.kind, config_path=datastore_config_path
+ )
+
+ # TODO: This check should maybe be moved somewhere else, but I'm not sure
+ # where right now... check that the config state feature weights include a
+ # weight for each state feature
+ state_feature_names = datastore.get_vars_names(category="state")
+ named_feature_weights = config.training.state_feature_weights.keys()
+
+ if set(named_feature_weights) != set(state_feature_names):
+ additional_features = set(named_feature_weights) - set(
+ state_feature_names
+ )
+ missing_features = set(state_feature_names) - set(named_feature_weights)
+ raise ValueError(
+ f"State feature weights must be provided for each state feature in "
+ f"the datastore ({state_feature_names}). {missing_features} are "
+ "missing and weights are defined for the features "
+ f"{additional_features} which are not in the datastore."
+ )
+
+ return config, datastore
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 0b267f67..4f656d05 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -13,7 +13,7 @@
from torch_geometric.utils.convert import from_networkx
# Local
-from .datastore import DATASTORES
+from .config import load_config_and_datastore
from .datastore.base import BaseCartesianDatastore
@@ -551,15 +551,9 @@ def create_graph_from_datastore(
def cli(input_args=None):
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
- "datastore",
+ "--config",
type=str,
- choices=DATASTORES.keys(),
- help="kind of data store to use",
- )
- parser.add_argument(
- "datastore_config_path",
- type=str,
- help="path to the data store config",
+ default="tests/datastore_examples/mdp/config.yaml",
)
parser.add_argument(
"--name",
@@ -586,8 +580,8 @@ def cli(input_args=None):
)
args = parser.parse_args(input_args)
- DatastoreClass = DATASTORES[args.datastore]
- datastore = DatastoreClass(config_path=args.datastore_config_path)
+ # Load neural-lam configuration and datastore to use
+ _, datastore = load_config_and_datastore(config_path=args.config)
create_graph_from_datastore(
datastore=datastore,
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index a1918994..e2700bc0 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -11,7 +11,7 @@
# Local
from . import utils
-from .datastore import DATASTORES, init_datastore
+from .config import load_config_and_datastore
from .models import GraphLAM, HiLAM, HiLAMParallel
from .weather_dataset import WeatherDataModule
@@ -28,15 +28,9 @@ def main(input_args=None):
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
- "datastore_kind",
+ "--config",
type=str,
- choices=DATASTORES.keys(),
- help="Kind of datastore to use",
- )
- parser.add_argument(
- "datastore_config_path",
- type=str,
- help="Path for the datastore config",
+ default="tests/datastore_examples/mdp/config.yaml",
)
parser.add_argument(
"--model",
@@ -226,11 +220,14 @@ def main(input_args=None):
# Set seed
seed.seed_everything(args.seed)
- # Create datastore
- datastore = init_datastore(
- datastore_kind=args.datastore_kind,
- config_path=args.datastore_config_path,
- )
+ # Load neural-lam configuration and datastore to use
+ config, datastore = load_config_and_datastore(config_path=args.config)
+ # TODO: config.training.state_feature_weights need passing in somewhere,
+ # probably to ARModel, so that it can be used in the loss function
+ assert (
+ config.training.state_feature_weights
+ ), "No state feature weights found in config"
+
# Create datamodule
data_module = WeatherDataModule(
datastore=datastore,
diff --git a/pyproject.toml b/pyproject.toml
index da6664cf..349e459d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ known_first_party = [
# Add first-party modules that may be misclassified by isort
"neural_lam",
]
+line_length = 80
[tool.flake8]
max-line-length = 80
diff --git a/tests/datastore_examples/mdp/config.yaml b/tests/datastore_examples/mdp/config.yaml
new file mode 100644
index 00000000..44a87ca4
--- /dev/null
+++ b/tests/datastore_examples/mdp/config.yaml
@@ -0,0 +1,7 @@
+datastore:
+ kind: mdp
+ config_path: danra.example.yaml
+training:
+ state_feature_weights:
+ u100m: 1.0
+ v100m: 1.0
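With the example `config.yaml` above, the intended entry point is `load_config_and_datastore`; a usage sketch using the paths introduced in this patch:

```python
# Usage sketch for the new top-level neural-lam configuration.
from neural_lam.config import load_config_and_datastore

config, datastore = load_config_and_datastore(
    config_path="tests/datastore_examples/mdp/config.yaml"
)

print(config.datastore.kind)                  # "mdp"
print(config.training.state_feature_weights)  # {"u100m": 1.0, "v100m": 1.0}
print(datastore.get_vars_names(category="state"))
```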
From 2b30715d4794cb0cdefe44790f175d8fe3ed9dc1 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 09:57:01 +0200
Subject: [PATCH 236/273] include meps neural-lam config example
---
tests/datastore_examples/.gitignore | 3 ++-
tests/datastore_examples/npy/config_meps.yaml | 13 +++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
create mode 100644 tests/datastore_examples/npy/config_meps.yaml
diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore
index 2d0a57fd..82c481f7 100644
--- a/tests/datastore_examples/.gitignore
+++ b/tests/datastore_examples/.gitignore
@@ -1 +1,2 @@
-npy/
+npy/*.zip
+npy/meps_example_reduced/
diff --git a/tests/datastore_examples/npy/config_meps.yaml b/tests/datastore_examples/npy/config_meps.yaml
new file mode 100644
index 00000000..ec8336b0
--- /dev/null
+++ b/tests/datastore_examples/npy/config_meps.yaml
@@ -0,0 +1,13 @@
+datastore:
+ kind: npyfilesmeps
+ config_path: meps_example_reduced/data_config.yaml
+training:
+ state_feature_weights:
+ nlwrs_0: 1.0
+ nswrs_0: 1.0
+ pres_0g: 1.0
+ pres_0s: 1.0
+ r_2: 1.0
+ r_65: 1.0
+ t_2: 1.0
+ t_65: 1.0
From 8e7b2e6284609fb4d79eece6b816d7b553b1d94e Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 10:20:00 +0200
Subject: [PATCH 237/273] fix extra space typo in BaseDatastore
---
neural_lam/datastore/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 47cac66b..7a919b57 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -159,7 +159,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
the category. For `category=="state"`, the dataarray should also
contain a `state_diff_mean` and `state_diff_std` variable for the one-
step differences of the state variables. The returned dataarray should
- at least have dimensions of `({categ ory}_feature)`, but can also
+ at least have dimensions of `({category}_feature)`, but can also
include for example `grid_index` (if the standardization is done per
grid point for example).
From e0300fb456ff5e6818bf54603da031c49a96d48b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 14:39:55 +0200
Subject: [PATCH 238/273] add check and print of train/test/val split in
MDPDatastore
---
neural_lam/datastore/mdp.py | 19 ++++++++++++++++++-
1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index e6f7888d..f92a175d 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -74,7 +74,24 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
- print(f"{category}: {' '.join(self.get_vars_names(category))}")
+ var_names = self.get_vars_names(category)
+ print(f" {category:<8s}: {' '.join(var_names)}")
+
+ # check that all three train/val/test splits are available
+ required_splits = ["train", "val", "test"]
+ available_splits = list(self._ds.splits.split_name.values)
+ if not all(split in available_splits for split in required_splits):
+ raise ValueError(
+ f"Missing required splits: {required_splits} in available "
+ f"splits: {available_splits}"
+ )
+
+ print("With the following splits (over time):")
+ for split in required_splits:
+ da_split = self._ds.splits.sel(split_name=split)
+ da_split_start = da_split.sel(split_part="start").load().item()
+ da_split_end = da_split.sel(split_part="end").load().item()
+ print(f" {split:<8s}: {da_split_start} to {da_split_end}")
# find out the dimension order for the stacking to grid-index
dim_order = None
From a921e353078e406cfc6c93179526311d3defa7ab Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 16:39:14 +0200
Subject: [PATCH 239/273] add experimental mlflow server support
Metric tracking is currently disabled because neural-lam previously used a Logger.define_metric method, which isn't available with the MLflow logger in pytorch-lightning as far as I'm aware.
---
neural_lam/config.py | 2 ++
neural_lam/train_model.py | 42 +++++++++++++++++++-----
neural_lam/utils.py | 15 ++++++---
pyproject.toml | 1 +
tests/datastore_examples/mdp/config.yaml | 2 ++
5 files changed, 49 insertions(+), 13 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 7524e3bd..5880da4f 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -56,6 +56,8 @@ class TrainingConfig:
"""
state_feature_weights: Dict[str, float]
+ logger: str = "wandb"
+ logger_url: str = None
@dataclasses.dataclass
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index e2700bc0..6f97f5e8 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -8,6 +8,7 @@
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities import seed
+from loguru import logger
# Local
from . import utils
@@ -22,6 +23,31 @@
}
+def _setup_training_logger(config, datastore, args, run_name):
+ if config.training.logger == "wandb":
+ logger = pl.loggers.WandbLogger(
+ project=args.wandb_project,
+ name=run_name,
+ config=dict(training=vars(args), datastore=datastore._config),
+ )
+ elif config.training.logger == "mlflow":
+ url = config.training.logger_url
+ if url is None:
+ raise ValueError(
+ "MLFlow logger requires a URL to the MLFlow server"
+ )
+ logger = pl.loggers.MLFlowLogger(
+ experiment_name=args.wandb_project,
+ tracking_uri=url,
+ )
+ logger.log_hyperparams(
+ dict(training=vars(args), datastore=datastore._config)
+ )
+
+ return logger
+
+
+@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
@@ -260,6 +286,11 @@ def main(input_args=None):
f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
)
+
+ training_logger = _setup_training_logger(
+ config=config, datastore=datastore, args=args, run_name=run_name
+ )
+
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=f"saved_models/{run_name}",
filename="min_val_loss",
@@ -267,17 +298,12 @@ def main(input_args=None):
mode="min",
save_last=True,
)
- logger = pl.loggers.WandbLogger(
- project=args.wandb_project,
- name=run_name,
- config=dict(training=vars(args), datastore=datastore._config),
- )
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
strategy="ddp",
accelerator=device_name,
- logger=logger,
+ logger=training_logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],
check_val_every_n_epoch=args.val_interval,
@@ -286,8 +312,8 @@ def main(input_args=None):
# Only init once, on rank 0 only
if trainer.global_rank == 0:
- utils.init_wandb_metrics(
- logger, val_steps=args.val_steps_to_log
+ utils.init_training_logger_metrics(
+ training_logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 0b4c39a4..9d5ecf24 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -1,9 +1,11 @@
# Standard library
import os
import shutil
+import warnings
# Third-party
import torch
+from pytorch_lightning.loggers import WandbLogger
from torch import nn
from tueplots import bundles, figsizes
@@ -229,11 +231,14 @@ def fractional_plot_bundle(fraction):
return bundle
-def init_wandb_metrics(wandb_logger, val_steps):
+def init_training_logger_metrics(training_logger, val_steps):
"""
Set up wandb metrics to track
"""
- experiment = wandb_logger.experiment
- experiment.define_metric("val_mean_loss", summary="min")
- for step in val_steps:
- experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+ experiment = training_logger.experiment
+ if isinstance(training_logger, WandbLogger):
+ experiment.define_metric("val_mean_loss", summary="min")
+ for step in val_steps:
+ experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+ else:
+ warnings.warn("Only WandbLogger is supported for tracking metrics")
diff --git a/pyproject.toml b/pyproject.toml
index 349e459d..b723e322 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
"parse>=1.20.2",
"dataclass-wizard>=0.22.3",
"mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep/@feat/extra-section-in-config",
+ "mlflow>=2.16.2",
]
requires-python = ">=3.9"
diff --git a/tests/datastore_examples/mdp/config.yaml b/tests/datastore_examples/mdp/config.yaml
index 44a87ca4..bcaf589a 100644
--- a/tests/datastore_examples/mdp/config.yaml
+++ b/tests/datastore_examples/mdp/config.yaml
@@ -2,6 +2,8 @@ datastore:
kind: mdp
config_path: danra.example.yaml
training:
+ logger: mlflow
+ logger_url: https://mlflow.dmidev.org
state_feature_weights:
u100m: 1.0
v100m: 1.0
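
Stripped of the surrounding argument parsing, the logger selection this patch introduces amounts to a small factory over the two pytorch-lightning logger classes. The sketch below is an approximation under that reading; `logger_kind`, `logger_url`, `project_name` and `run_name` are placeholder arguments rather than the project's actual config fields.

# Hedged sketch of choosing a pytorch-lightning logger from configuration;
# all argument names are placeholders.
import pytorch_lightning as pl


def setup_logger(logger_kind, logger_url, project_name, run_name):
    if logger_kind == "wandb":
        return pl.loggers.WandbLogger(project=project_name, name=run_name)
    if logger_kind == "mlflow":
        if logger_url is None:
            raise ValueError("MLflow logger requires a URL to the MLflow server")
        return pl.loggers.MLFlowLogger(
            experiment_name=project_name, tracking_uri=logger_url
        )
    raise ValueError(f"Unknown logger kind: {logger_kind}")
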
From 0f302596c0e1d3183564e4369fc90b188db31b82 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 3 Oct 2024 05:21:50 +0000
Subject: [PATCH 240/273] more fixes for mlflow logging support
---
neural_lam/models/ar_model.py | 29 ++++++++++++++++-------------
1 file changed, 16 insertions(+), 13 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index f879f618..54d01e20 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -6,7 +6,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
-import wandb
# Local
from .. import metrics, vis
@@ -467,9 +466,9 @@ def plot_examples(self, batch, n_examples, prediction=None):
]
example_i = self.plotted_examples
- wandb.log(
+ self.logger.log_image(
{
- f"{var_name}_example_{example_i}": wandb.Image(fig)
+ f"{var_name}_example_{example_i}": fig
for var_name, fig in zip(
self._datastore.get_vars_names("state"), var_figs
)
@@ -483,13 +482,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
torch.save(
pred_slice.cpu(),
os.path.join(
- wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
+ self.logger.save_dir,
+ f"example_pred_{self.plotted_examples}.pt",
),
)
torch.save(
target_slice.cpu(),
os.path.join(
- wandb.run.dir, f"example_target_{self.plotted_examples}.pt"
+ self.logger.save_dir,
+ f"example_target_{self.plotted_examples}.pt",
),
)
@@ -510,16 +511,16 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
datastore=self._datastore,
)
full_log_name = f"{prefix}_{metric_name}"
- log_dict[full_log_name] = wandb.Image(metric_fig)
+ log_dict[full_log_name] = metric_fig
if prefix == "test":
# Save pdf
metric_fig.savefig(
- os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
+ os.path.join(self.logger.save_dir, f"{full_log_name}.pdf")
)
# Save errors also as csv
np.savetxt(
- os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
+ os.path.join(self.logger.save_dir, f"{full_log_name}.csv"),
metric_tensor.cpu().numpy(),
delimiter=",",
)
@@ -568,7 +569,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)
if self.trainer.is_global_zero and not self.trainer.sanity_checking:
- wandb.log(log_dict) # Log all
+ self.logger.log_image(log_dict) # Log all
plt.close("all") # Close all figs
def on_test_epoch_end(self):
@@ -599,9 +600,9 @@ def on_test_epoch_end(self):
)
]
- # log all to same wandb key, sequentially
+ # log all to same key, sequentially
for fig in loss_map_figs:
- wandb.log({"test_loss": wandb.Image(fig)})
+ self.logger.log_image({"test_loss": fig})
# also make without title and save as pdf
pdf_loss_map_figs = [
@@ -610,14 +611,16 @@ def on_test_epoch_end(self):
)
for loss_map in mean_spatial_loss
]
- pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
+ pdf_loss_maps_dir = os.path.join(
+ self.logger.save_dir, "spatial_loss_maps"
+ )
os.makedirs(pdf_loss_maps_dir, exist_ok=True)
for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
# save mean spatial loss as .pt file also
torch.save(
mean_spatial_loss.cpu(),
- os.path.join(wandb.run.dir, "mean_spatial_loss.pt"),
+ os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"),
)
self.spatial_loss_maps.clear()
From 3fbe2d095c58c085f635b1c0de9201c391c4082e Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Thu, 3 Oct 2024 12:07:26 +0000
Subject: [PATCH 241/273] Make wandb work again with the pytorch_lightning logger
---
neural_lam/models/ar_model.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 54d01e20..45cbd247 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -568,8 +568,16 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)
)
+ # Ensure that log_dict has structure for logging as dict(str, plt.Figure)
+ assert all(
+ isinstance(key, str) and isinstance(value, plt.Figure)
+ for key, value in log_dict.items()
+ )
+
if self.trainer.is_global_zero and not self.trainer.sanity_checking:
- self.logger.log_image(log_dict) # Log all
+ for key, figure in log_dict.items():
+ self.logger.log_image(key=key, images=[figure])
+
plt.close("all") # Close all figs
def on_test_epoch_end(self):
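
The change above replaces a single wandb.log of a figure dict with one log_image call per key, which is the calling convention pytorch-lightning's WandbLogger expects. A self-contained sketch of that loop follows; `logger` and `log_dict` are generic placeholders, not the model's attributes.

# Sketch of logging matplotlib figures one key at a time through a
# pytorch-lightning logger; `logger` and `log_dict` are generic placeholders.
import matplotlib.pyplot as plt


def log_figures(logger, log_dict):
    """Log a dict of {name: matplotlib Figure}, one image per key."""
    assert all(
        isinstance(key, str) and isinstance(fig, plt.Figure)
        for key, fig in log_dict.items()
    )
    for key, fig in log_dict.items():
        # WandbLogger.log_image takes a key and a list of images/figures.
        logger.log_image(key=key, images=[fig])
    plt.close("all")  # free the figures once they are logged
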
From e0284a81750918fb2e88de40186d14623012089c Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Fri, 4 Oct 2024 12:07:04 +0000
Subject: [PATCH 242/273] upload of artifact to mlflow works, but instantiates
a new experiment
---
neural_lam/train_model.py | 55 ++++++++++++++++++++++++++++++++++++---
pyproject.toml | 1 +
2 files changed, 52 insertions(+), 4 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 6f97f5e8..bcf920da 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -22,6 +22,24 @@
"hi_lam_parallel": HiLAMParallel,
}
+class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
+
+ def log_image(self, key, images):
+ import mlflow
+ import io
+ from PIL import Image
+ # Need to save the image to a temporary file, then log that file
+ # mlflow.log_image, should do this automatically, but it doesn't work
+ temporary_image = f"{key}.png"
+ images[0].savefig(temporary_image)
+
+ img = Image.open(temporary_image)
+ print(images)
+ print(images[0])
+ mlflow.log_image(img, f"{key}.png")
+
+ #mlflow.log_figure(images[0], key)
+
def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
@@ -36,17 +54,42 @@ def _setup_training_logger(config, datastore, args, run_name):
raise ValueError(
"MLFlow logger requires a URL to the MLFlow server"
)
- logger = pl.loggers.MLFlowLogger(
+ # logger = pl.loggers.MLFlowLogger(
+ # experiment_name=args.wandb_project,
+ # tracking_uri=url,
+ # )
+ logger = CustomMLFlowLogger(
experiment_name=args.wandb_project,
tracking_uri=url,
)
+ print(logger)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)
+ print("Logged hyperparams")
+ print(run_name)
+
+ print(logger.__str__)
+ # logger.log_image = log_image
return logger
+# def log_image(key, images):
+# import mlflow
+
+# # Log the image
+# # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images
+# # For mlflow a matplotlib figure should use log_figure instead of log_image
+# # Need to save the image to a temporary file, then log that file
+# # mlflow.log_image, should do this automatically, but it doesn't work
+# temporary_image = f"/tmp/key.png"
+# images[0].savefig(temporary_image)
+
+# mlflow.log_figure(temporary_image, key)
+# mlflow.log_figure(images[0], key)
+
+
@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
@@ -301,7 +344,10 @@ def main(input_args=None):
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
- strategy="ddp",
+ #strategy="ddp",
+ #devices=2,
+ devices=[1, 3],
+ strategy="auto",
accelerator=device_name,
logger=training_logger,
log_every_n_steps=1,
@@ -309,7 +355,7 @@ def main(input_args=None):
check_val_every_n_epoch=args.val_interval,
precision=args.precision,
)
-
+ import ipdb
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_training_logger_metrics(
@@ -318,7 +364,8 @@ def main(input_args=None):
if args.eval:
trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
- trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
+ with ipdb.launch_ipdb_on_exception():
+ trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index b723e322..be26adb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
"dataclass-wizard>=0.22.3",
"mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep/@feat/extra-section-in-config",
"mlflow>=2.16.2",
+ "boto3>=1.35.32",
]
requires-python = ">=3.9"
From 7eed79b4ca57856ff6e4eb3935b38a4902770850 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 7 Oct 2024 06:40:41 +0000
Subject: [PATCH 243/273] make mlflow use the same experiment run id as
 pl.loggers.MLFlowLogger
---
neural_lam/train_model.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index bcf920da..fd010f59 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -28,6 +28,12 @@ def log_image(self, key, images):
import mlflow
import io
from PIL import Image
+
+ # Retrieve the active run ID from the logger
+ run_id = self.run_id
+ # Ensure mlflow uses the same run
+ mlflow.start_run(run_id=run_id)
+
# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but it doesn't work
temporary_image = f"{key}.png"
@@ -39,6 +45,7 @@ def log_image(self, key, images):
mlflow.log_image(img, f"{key}.png")
#mlflow.log_figure(images[0], key)
+ mlflow.end_run()
def _setup_training_logger(config, datastore, args, run_name):
@@ -346,7 +353,7 @@ def main(input_args=None):
deterministic=True,
#strategy="ddp",
#devices=2,
- devices=[1, 3],
+ devices=[0, 1],
strategy="auto",
accelerator=device_name,
logger=training_logger,
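
The fix in this patch points module-level mlflow calls at the run that pl.loggers.MLFlowLogger already opened, by starting a run with the logger's run_id. Outside the subclass, the same idea can be sketched as below; the artifact path is illustrative and the global tracking URI is assumed to match the logger's.

# Sketch: aim module-level mlflow calls at the run an MLFlowLogger created.
# Assumes the global tracking URI (e.g. MLFLOW_TRACKING_URI) matches the
# logger's; the artifact path is illustrative.
import mlflow
from pytorch_lightning.loggers import MLFlowLogger


def log_artifact_to_logger_run(pl_logger: MLFlowLogger, artifact_path: str) -> None:
    # Re-open the logger's run, log into it, and close it again on exit.
    with mlflow.start_run(run_id=pl_logger.run_id):
        mlflow.log_artifact(artifact_path)
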
From 27408f235ea4c70842fb9a707e1d5802a39d993b Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 7 Oct 2024 10:13:25 +0000
Subject: [PATCH 244/273] logger artifact working for both wandb and mlflow
---
neural_lam/models/ar_model.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 45cbd247..53126792 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -575,7 +575,13 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)
if self.trainer.is_global_zero and not self.trainer.sanity_checking:
+
+ current_epoch = self.trainer.current_epoch
+
for key, figure in log_dict.items():
+ if not isinstance(self.logger, pl.loggers.WandbLogger):
+ key = f"{key}-{current_epoch}"
+
self.logger.log_image(key=key, images=[figure])
plt.close("all") # Close all figs
From e61a9e79b9b34fb980fc015e02926510d06a6265 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 7 Oct 2024 10:44:51 +0000
Subject: [PATCH 245/273] support mlflow system metrics logging
---
neural_lam/train_model.py | 38 +++++++-------------------------------
1 file changed, 7 insertions(+), 31 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index fd010f59..ccd74139 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -10,6 +10,8 @@
from lightning_fabric.utilities import seed
from loguru import logger
+import mlflow
+
# Local
from . import utils
from .config import load_config_and_datastore
@@ -24,29 +26,22 @@
class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
+ def __init__(self, experiment_name, tracking_uri):
+ super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
+ mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
+
+
def log_image(self, key, images):
- import mlflow
- import io
from PIL import Image
- # Retrieve the active run ID from the logger
- run_id = self.run_id
- # Ensure mlflow uses the same run
- mlflow.start_run(run_id=run_id)
-
# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but it doesn't work
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)
img = Image.open(temporary_image)
- print(images)
- print(images[0])
mlflow.log_image(img, f"{key}.png")
- #mlflow.log_figure(images[0], key)
- mlflow.end_run()
-
def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
@@ -69,33 +64,14 @@ def _setup_training_logger(config, datastore, args, run_name):
experiment_name=args.wandb_project,
tracking_uri=url,
)
- print(logger)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)
print("Logged hyperparams")
- print(run_name)
-
- print(logger.__str__)
- # logger.log_image = log_image
return logger
-# def log_image(key, images):
-# import mlflow
-
-# # Log the image
-# # https://learn.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics?view=azureml-api-2&tabs=interactive#log-images
-# # For mlflow a matplotlib figure should use log_figure instead of log_image
-# # Need to save the image to a temporary file, then log that file
-# # mlflow.log_image, should do this automatically, but it doesn't work
-# temporary_image = f"/tmp/key.png"
-# images[0].savefig(temporary_image)
-
-# mlflow.log_figure(temporary_image, key)
-# mlflow.log_figure(images[0], key)
-
@logger.catch
def main(input_args=None):
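
Pulled together, the custom-logger idea in this patch could look roughly like the sketch below. It is an approximation rather than the project's final class: it assumes the MLflow server is reachable at the given tracking URI, and it uses mlflow.log_figure instead of the temporary-PNG workaround the patch itself resorts to.

# Hedged sketch of an MLFlowLogger subclass that also records system metrics
# and logs matplotlib figures via the module-level mlflow API.
import mlflow
import pytorch_lightning as pl


class SketchMLFlowLogger(pl.loggers.MLFlowLogger):
    def __init__(self, experiment_name, tracking_uri):
        super().__init__(
            experiment_name=experiment_name, tracking_uri=tracking_uri
        )
        # Point the fluent mlflow API at the same server and run, and enable
        # CPU/GPU/memory system metrics (supported in recent mlflow versions).
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)

    def log_image(self, key, images):
        # mlflow.log_figure stores a matplotlib figure directly as an artifact;
        # the patch itself routes through a temporary PNG instead.
        mlflow.log_figure(images[0], f"{key}.png")
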
From b53bab50ad8e859d5bdeca7911f504e68eeb8554 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 7 Oct 2024 11:15:14 +0000
Subject: [PATCH 246/273] support model logging for mlflow
---
neural_lam/train_model.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index ccd74139..f21ac96e 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -11,6 +11,9 @@
from loguru import logger
import mlflow
+# for logging the model:
+import mlflow.pytorch
+from mlflow.models import infer_signature
# Local
from . import utils
@@ -29,7 +32,8 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
def __init__(self, experiment_name, tracking_uri):
super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
-
+ mlflow.log_param("run_id", self.run_id)
+ #mlflow.pytorch.autolog() # Can be used to log the model, but without signature
def log_image(self, key, images):
from PIL import Image
@@ -42,6 +46,11 @@ def log_image(self, key, images):
img = Image.open(temporary_image)
mlflow.log_image(img, f"{key}.png")
+ def log_model(self, model):
+ # Create model signature
+ #signature = infer_signature(X.numpy(), model(X).detach().numpy())
+ mlflow.pytorch.log_model(model, "model")
+
def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
@@ -350,6 +359,9 @@ def main(input_args=None):
with ipdb.launch_ipdb_on_exception():
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
+ # Log the model
+ training_logger.log_model(model)
+
if __name__ == "__main__":
main()
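
As the commented-out infer_signature call hints, mlflow.pytorch.log_model is normally paired with a signature inferred from a sample input/output pair. A generic sketch of that pairing follows, with a toy model and random input as stand-ins for anything neural-lam specific.

# Sketch: logging a PyTorch model to MLflow with an inferred signature.
# The model and example input are toy stand-ins, not neural-lam components.
import mlflow.pytorch
import torch
from mlflow.models import infer_signature

model = torch.nn.Linear(4, 2)
example_input = torch.randn(3, 4)

with torch.no_grad():
    example_output = model(example_input)

signature = infer_signature(example_input.numpy(), example_output.numpy())

with mlflow.start_run():
    mlflow.pytorch.log_model(model, "model", signature=signature)
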
From de27e9a9676dbf3115ed7e2691493c73aa265fc6 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 7 Oct 2024 16:59:23 +0000
Subject: [PATCH 247/273] log model
---
neural_lam/train_model.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index f21ac96e..792a00f6 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -48,7 +48,7 @@ def log_image(self, key, images):
def log_model(self, model):
# Create model signature
- #signature = infer_signature(X.numpy(), model(X).detach().numpy())
+ #signature = infer_signature(train_dataset.numpy(), model(train_dataset).detach().numpy())
mlflow.pytorch.log_model(model, "model")
@@ -361,7 +361,7 @@ def main(input_args=None):
# Log the model
training_logger.log_model(model)
-
+ # data_module.train_dataloader().dataset.data
if __name__ == "__main__":
main()
From 89d8cde0701e44552ad662d5d5bb78fe943d321d Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Wed, 13 Nov 2024 11:13:34 +0000
Subject: [PATCH 248/273] test system metrics
---
neural_lam/models/ar_model.py | 1 +
neural_lam/train_model.py | 8 ++++++--
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 53126792..416f72ef 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -579,6 +579,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
current_epoch = self.trainer.current_epoch
for key, figure in log_dict.items():
+ # For other loggers than wandb, add epoch to key
if not isinstance(self.logger, pl.loggers.WandbLogger):
key = f"{key}-{current_epoch}"
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 792a00f6..e232b90c 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -326,6 +326,7 @@ def main(input_args=None):
config=config, datastore=datastore, args=args, run_name=run_name
)
+
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=f"saved_models/{run_name}",
filename="min_val_loss",
@@ -338,7 +339,7 @@ def main(input_args=None):
deterministic=True,
#strategy="ddp",
#devices=2,
- devices=[0, 1],
+ devices=[0, 1, 2],
strategy="auto",
accelerator=device_name,
logger=training_logger,
@@ -359,9 +360,12 @@ def main(input_args=None):
with ipdb.launch_ipdb_on_exception():
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
+ # Get a sample of training data to log
+ #sample_data = data_module.train_dataset
+ #print("Logging sample data")
+ #print(sample_data.train_dataset)
# Log the model
training_logger.log_model(model)
- # data_module.train_dataloader().dataset.data
if __name__ == "__main__":
main()
From 54c7ca71db819b6356bb5b06af188b9589789f35 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Fri, 15 Nov 2024 08:30:19 +0000
Subject: [PATCH 249/273] make mlflow work for eval mode as well
---
neural_lam/models/ar_model.py | 37 +++++++++++++++++++++++------------
neural_lam/train_model.py | 23 +++++++++++++++++-----
neural_lam/vis.py | 10 ++++++----
3 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 416f72ef..218939a1 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -451,7 +451,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
vis.plot_prediction(
pred=pred_t[:, var_i],
target=target_t[:, var_i],
- datastore=self.datastore,
+ datastore=self._datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
@@ -466,14 +466,23 @@ def plot_examples(self, batch, n_examples, prediction=None):
]
example_i = self.plotted_examples
- self.logger.log_image(
- {
- f"{var_name}_example_{example_i}": fig
- for var_name, fig in zip(
- self._datastore.get_vars_names("state"), var_figs
- )
- }
- )
+
+
+
+ for var_name, fig in zip(
+ self._datastore.get_vars_names("state"), var_figs
+ ):
+ key=f"{var_name}_example_{example_i}"
+ self.logger.log_image(key=key, images=[fig])
+
+ # self.logger.log_image(
+ # {
+ # f"{var_name}_example_{example_i}": fig
+ # for var_name, fig in zip(
+ # self._datastore.get_vars_names("state"), var_figs
+ # )
+ # }
+ # )
plt.close(
"all"
) # Close all figs for this time step, saves memory
@@ -608,7 +617,8 @@ def on_test_epoch_end(self):
vis.plot_spatial_error(
error=loss_map,
datastore=self._datastore,
- title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
+ title=f"Test loss, t={t_i} "
+ f"({self._datastore.step_length * t_i} h)",
)
for t_i, loss_map in zip(
self.args.val_steps_to_log, mean_spatial_loss
@@ -616,8 +626,11 @@ def on_test_epoch_end(self):
]
# log all to same key, sequentially
- for fig in loss_map_figs:
- self.logger.log_image({"test_loss": fig})
+ for i, fig in enumerate(loss_map_figs):
+ key=f"test_loss"
+ if not isinstance(self.logger, pl.loggers.WandbLogger):
+ key=f"{key}_{i}"
+ self.logger.log_image(key=key, images=[fig])
# also make without title and save as pdf
pdf_loss_map_figs = [
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index e232b90c..ec599c68 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -34,6 +34,12 @@ def __init__(self, experiment_name, tracking_uri):
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.log_param("run_id", self.run_id)
#mlflow.pytorch.autolog() # Can be used to log the model, but without signature
+ #mlflow.save_dir = "mlruns"
+ #self.save_dir = "mlruns"
+
+ @property
+ def save_dir(self):
+ return "mlruns"
def log_image(self, key, images):
from PIL import Image
@@ -52,6 +58,7 @@ def log_model(self, model):
mlflow.pytorch.log_model(model, "model")
+
def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
logger = pl.loggers.WandbLogger(
@@ -76,7 +83,6 @@ def _setup_training_logger(config, datastore, args, run_name):
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)
- print("Logged hyperparams")
return logger
@@ -337,10 +343,13 @@ def main(input_args=None):
trainer = pl.Trainer(
max_epochs=args.epochs,
deterministic=True,
- #strategy="ddp",
+ strategy="ddp",
#devices=2,
- devices=[0, 1, 2],
- strategy="auto",
+ #devices=[1,2],
+ #devices=[0, 1, 2],
+ #strategy="auto",
+ devices=1,
+ num_nodes=1,
accelerator=device_name,
logger=training_logger,
log_every_n_steps=1,
@@ -355,7 +364,11 @@ def main(input_args=None):
training_logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
if args.eval:
- trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
+ trainer.test(
+ model=model,
+ datamodule=data_module,
+ ckpt_path=args.load,
+ )
else:
with ipdb.launch_ipdb_on_exception():
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 9653c3fc..353f50e0 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -90,7 +90,8 @@ def plot_prediction(
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ #mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ mask_reshaped.clip(0.7, 1)
) # Faded border region
fig, axes = plt.subplots(
@@ -104,7 +105,7 @@ def plot_prediction(
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(list(datastore.grid_shape_state.values.values()))
+ data.reshape(datastore.grid_shape_state.x, datastore.grid_shape_state.y)
.cpu()
.numpy()
)
@@ -151,7 +152,8 @@ def plot_spatial_error(
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ #mask_reshaped.clamp(0.7, 1).cpu().numpy()
+ mask_reshaped.clip(0.7, 1)
) # Faded border region
fig, ax = plt.subplots(
@@ -161,7 +163,7 @@ def plot_spatial_error(
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(datastore.grid_shape_state.values.values()))
+ error.reshape(list([datastore.grid_shape_state.x, datastore.grid_shape_state.y]))
.cpu()
.numpy()
)
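
The clamp-to-clip switch in vis.py follows from the boundary mask now arriving as a plain numpy array (da_mask.values) rather than a torch tensor: tensors use .clamp() and must be moved to CPU before .numpy(), while numpy arrays expose .clip() directly. A small illustration of the equivalence:

# torch tensors and numpy arrays expose the same operation under different
# names: clamp (torch, possibly on GPU) versus clip (numpy, host memory).
import numpy as np
import torch

mask_tensor = torch.tensor([0.0, 0.5, 1.0])
mask_array = np.array([0.0, 0.5, 1.0])

faded_from_tensor = mask_tensor.clamp(0.7, 1).cpu().numpy()  # torch route
faded_from_array = mask_array.clip(0.7, 1)  # numpy route, as in the patch

assert np.allclose(faded_from_tensor, faded_from_array)
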
From a47de0c9a66ffbf89d607e6181469887d32b9489 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Thu, 21 Nov 2024 09:23:56 +0000
Subject: [PATCH 250/273] dummy prints to identify workflow
---
neural_lam/models/ar_model.py | 16 +++++++++++++---
neural_lam/train_model.py | 6 ++++++
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 218939a1..9f39aecd 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -330,6 +330,9 @@ def test_step(self, batch, batch_idx):
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
+ if self.args.save_predictions:
+ print("Saving predictions")
+
time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
@@ -399,6 +402,10 @@ def test_step(self, batch, batch_idx):
batch, n_additional_examples, prediction=prediction
)
+ # Save predictions if requested
+ if self.args.save_predictions:
+ print("Saving predictions")
+
def plot_examples(self, batch, n_examples, prediction=None):
"""
Plot the first n_examples forecasts from batch
@@ -467,12 +474,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
example_i = self.plotted_examples
-
-
for var_name, fig in zip(
self._datastore.get_vars_names("state"), var_figs
):
- key=f"{var_name}_example_{example_i}"
+
+ if not isinstance(self.logger, pl.loggers.WandbLogger):
+ key=f"{var_name}_example_{t_i}"
+ else:
+ key=f"{var_name}_example_{example_i}"
+
self.logger.log_image(key=key, images=[fig])
# self.logger.log_image(
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index ec599c68..2601f292 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -234,6 +234,12 @@ def main(input_args=None):
help="Number of example predictions to plot during evaluation "
"(default: 1)",
)
+ parser.add_argument(
+ "--save_predictions",
+ action="store_true",
+ help="If predictions should be saved to disk as a zarr dataset "
+ "(default: false)",
+ )
# Logger Settings
parser.add_argument(
From 10a4494310270245be828db4a7595bbdb89523fa Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Thu, 21 Nov 2024 09:46:10 +0000
Subject: [PATCH 251/273] update mlflow in eval mode
---
neural_lam/models/ar_model.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 9f39aecd..df6d268f 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -330,9 +330,6 @@ def test_step(self, batch, batch_idx):
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
- if self.args.save_predictions:
- print("Saving predictions")
-
time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
@@ -402,9 +399,6 @@ def test_step(self, batch, batch_idx):
batch, n_additional_examples, prediction=prediction
)
- # Save predictions if requested
- if self.args.save_predictions:
- print("Saving predictions")
def plot_examples(self, batch, n_examples, prediction=None):
"""
From 78e874d60b978e9b513fffa207c53201b8ce514e Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Mon, 25 Nov 2024 13:39:43 +0000
Subject: [PATCH 252/273] inspect plot routines
---
neural_lam/config.py | 3 +++
neural_lam/datastore/mdp.py | 5 ++---
neural_lam/datastore/plot_example.py | 3 ++-
neural_lam/models/ar_model.py | 8 --------
neural_lam/train_model.py | 6 +++---
neural_lam/vis.py | 6 ++----
6 files changed, 12 insertions(+), 19 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 53673774..2a962997 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -103,6 +103,9 @@ class TrainingConfig:
ManualStateFeatureWeighting, UniformFeatureWeighting
] = dataclasses.field(default_factory=UniformFeatureWeighting)
+ logger: str = "wandb"
+ logger_url: str = None
+
@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 10593a82..7cf8159a 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -394,9 +394,8 @@ def coords_projection(self) -> ccrs.Projection:
class_name = projection_info["class_name"]
ProjectionClass = getattr(ccrs, class_name)
- kwargs = projection_info["kwargs"]
-
- globe_kwargs = kwargs.pop("globe", {})
+ kwargs = projection_info["kwargs"].copy()
+ globe_kwargs = kwargs.pop("globe", {}).copy()
if len(globe_kwargs) > 0:
kwargs["globe"] = ccrs.Globe(**globe_kwargs)
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
index 2d477271..a74f7409 100644
--- a/neural_lam/datastore/plot_example.py
+++ b/neural_lam/datastore/plot_example.py
@@ -186,4 +186,5 @@ def _parse_dict(arg_str):
selection=selection,
index_selection=index_selection,
)
- plt.show()
+ #plt.show()
+ plt.savefig('plot_example.png')
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index cbfeba52..d015b04a 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -489,14 +489,6 @@ def plot_examples(self, batch, n_examples, prediction=None):
self.logger.log_image(key=key, images=[fig])
- # self.logger.log_image(
- # {
- # f"{var_name}_example_{example_i}": fig
- # for var_name, fig in zip(
- # self._datastore.get_vars_names("state"), var_figs
- # )
- # }
- # )
plt.close(
"all"
) # Close all figs for this time step, saves memory
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 1da75611..f29f522f 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -349,12 +349,12 @@ def main(input_args=None):
max_epochs=args.epochs,
deterministic=True,
strategy="ddp",
- #devices=2,
+ #devices=3,
#devices=[1,2],
#devices=[0, 1, 2],
#strategy="auto",
- devices=1,
- num_nodes=1,
+ devices=1, # For eval mode
+ num_nodes=1, # For eval mode
accelerator=device_name,
logger=training_logger,
log_every_n_steps=1,
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 71e6ddf0..ef8c2722 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -90,7 +90,6 @@ def plot_prediction(
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
- #mask_reshaped.clamp(0.7, 1).cpu().numpy()
mask_reshaped.clip(0.7, 1)
) # Faded border region
@@ -105,7 +104,7 @@ def plot_prediction(
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(list(datastore.grid_shape_state.values.values()))
+ data.reshape(datastore.grid_shape_state.x, datastore.grid_shape_state.y)
.cpu()
.numpy()
)
@@ -152,7 +151,6 @@ def plot_spatial_error(
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
- #mask_reshaped.clamp(0.7, 1).cpu().numpy()
mask_reshaped.clip(0.7, 1)
) # Faded border region
@@ -163,7 +161,7 @@ def plot_spatial_error(
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(datastore.grid_shape_state.values.values()))
+ error.reshape(list([datastore.grid_shape_state.x, datastore.grid_shape_state.y]))
.cpu()
.numpy()
)
From 5904cbe9da67d3e98eaab0cebd501a2ad0ded7f3 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Mon, 25 Nov 2024 16:42:21 +0100
Subject: [PATCH 253/273] identified issue, cleanup next
---
neural_lam/datastore/base.py | 9 ++++-
neural_lam/datastore/mdp.py | 5 ++-
neural_lam/models/ar_model.py | 46 ++++++++++++++++++++--
neural_lam/train_model.py | 2 +-
neural_lam/vis.py | 73 +++++++++++++++++++++++++----------
5 files changed, 107 insertions(+), 28 deletions(-)
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 0317c2e5..b0055e39 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
- xy = self.get_xy(category, stacked=False)
- extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+ xy = self.get_xy(category, stacked=True)
+ extent = [
+ xy[:, 0].min(),
+ xy[:, 0].max(),
+ xy[:, 1].min(),
+ xy[:, 1].max(),
+ ]
return [float(v) for v in extent]
@property
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 10593a82..0d1aac7b 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -1,4 +1,5 @@
# Standard library
+import copy
import warnings
from functools import cached_property
from pathlib import Path
@@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection:
class_name = projection_info["class_name"]
ProjectionClass = getattr(ccrs, class_name)
- kwargs = projection_info["kwargs"]
+ # need to copy otherwise we modify the dict stored in the dataclass
+ # in-place
+ kwargs = copy.deepcopy(projection_info["kwargs"])
globe_kwargs = kwargs.pop("globe", {})
if len(globe_kwargs) > 0:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index bc4c6719..b55143f0 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -7,12 +7,14 @@
import pytorch_lightning as pl
import torch
import wandb
+from loguru import logger
# Local
from .. import metrics, vis
from ..config import NeuralLAMConfig
from ..datastore import BaseDatastore
from ..loss_weighting import get_state_feature_weighting
+from ..weather_dataset import WeatherDataset
class ARModel(pl.LightningModule):
@@ -147,6 +149,14 @@ def __init__(
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []
+ def _create_dataarray_from_tensor(self, tensor, time, split, category):
+ weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
+ time = np.array(time, dtype="datetime64[ns]")
+ da = weather_dataset.create_dataarray_from_tensor(
+ tensor=tensor, time=time, category=category
+ )
+ return da
+
def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
@@ -406,10 +416,13 @@ def test_step(self, batch, batch_idx):
)
self.plot_examples(
- batch, n_additional_examples, prediction=prediction
+ batch,
+ n_additional_examples,
+ prediction=prediction,
+ split="test",
)
- def plot_examples(self, batch, n_examples, prediction=None):
+ def plot_examples(self, batch, n_examples, split, prediction=None):
"""
Plot the first n_examples forecasts from batch
@@ -422,18 +435,34 @@ def plot_examples(self, batch, n_examples, prediction=None):
prediction, target, _, _ = self.common_step(batch)
target = batch[1]
+ time = batch[3]
# Rescale to original data scale
prediction_rescaled = prediction * self.state_std + self.state_mean
target_rescaled = target * self.state_std + self.state_mean
# Iterate over the examples
- for pred_slice, target_slice in zip(
- prediction_rescaled[:n_examples], target_rescaled[:n_examples]
+ for pred_slice, target_slice, time_slice in zip(
+ prediction_rescaled[:n_examples],
+ target_rescaled[:n_examples],
+ time[:n_examples],
):
# Each slice is (pred_steps, num_grid_nodes, d_f)
self.plotted_examples += 1 # Increment already here
+ da_prediction = self._create_dataarray_from_tensor(
+ tensor=pred_slice,
+ time=time_slice,
+ split=split,
+ category="state",
+ ).unstack("grid_index")
+ da_target = self._create_dataarray_from_tensor(
+ tensor=target_slice,
+ time=time_slice,
+ split=split,
+ category="state",
+ ).unstack("grid_index")
+
var_vmin = (
torch.minimum(
pred_slice.flatten(0, 1).min(dim=0)[0],
@@ -465,6 +494,10 @@ def plot_examples(self, batch, n_examples, prediction=None):
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
+ da_prediction=da_prediction.isel(
+ state_feature=var_i
+ ).squeeze(),
+ da_target=da_target.isel(state_feature=var_i).squeeze(),
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
@@ -476,6 +509,11 @@ def plot_examples(self, batch, n_examples, prediction=None):
]
example_i = self.plotted_examples
+ for i, fig in enumerate(var_figs):
+ fn = f"example_{i}_{example_i}_t{t_i}.png"
+ fig.savefig(fn)
+ logger.info(f"Saved example plot to {fn}")
+
wandb.log(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 74146c89..9d1d5039 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -23,7 +23,7 @@
}
-@logger.catch
+@logger.catch(reraise=True)
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index b9d18b39..357a8977 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -68,6 +68,8 @@ def plot_prediction(
pred,
target,
datastore: BaseRegularGridDatastore,
+ da_prediction=None,
+ da_target=None,
title=None,
vrange=None,
):
@@ -88,10 +90,8 @@ def plot_prediction(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
- mask_reshaped = da_mask.values
- pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
- ) # Faded border region
+ mask_values = np.invert(da_mask.values.astype(bool)).astype(float)
+ pixel_alpha = mask_values.clip(0.7, 1) # Faded border region
fig, axes = plt.subplots(
1,
@@ -100,29 +100,62 @@ def plot_prediction(
subplot_kw={"projection": datastore.coords_projection},
)
+ use_xarray = True
+
# Plot pred and target
- for ax, data in zip(axes, (target, pred)):
+
+ if not use_xarray:
+ for ax, data in zip(axes, (target, pred)):
+ ax.coastlines() # Add coastline outlines
+ data_grid = (
+ data.reshape(
+ [datastore.grid_shape_state.x, datastore.grid_shape_state.y]
+ )
+ .T.cpu()
+ .numpy()
+ )
+ im = ax.imshow(
+ data_grid,
+ origin="lower",
+ extent=extent,
+ alpha=pixel_alpha,
+ vmin=vmin,
+ vmax=vmax,
+ cmap="plasma",
+ )
+
+ cbar = fig.colorbar(im, aspect=30)
+ cbar.ax.tick_params(labelsize=10)
+
+ x = da_target.x.values
+ y = da_target.y.values
+ extent = [x.min(), x.max(), y.min(), y.max()]
+ for ax, da in zip(axes, (da_target, da_prediction)):
ax.coastlines() # Add coastline outlines
- data_grid = (
- data.reshape(list(datastore.grid_shape_state.values.values()))
- .cpu()
- .numpy()
- )
- im = ax.imshow(
- data_grid,
+ im = da.plot.imshow(
+ ax=ax,
origin="lower",
+ x="x",
extent=extent,
- alpha=pixel_alpha,
+ alpha=pixel_alpha.T,
vmin=vmin,
vmax=vmax,
cmap="plasma",
+ transform=datastore.coords_projection,
)
+ # da.plot.pcolormesh(
+ # ax=ax,
+ # x="x",
+ # vmin=vmin,
+ # vmax=vmax,
+ # transform=datastore.coords_projection,
+ # cmap="plasma",
+ # )
+
# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
axes[1].set_title("Prediction", size=15)
- cbar = fig.colorbar(im, aspect=30)
- cbar.ax.tick_params(labelsize=10)
if title:
fig.suptitle(title, size=20)
@@ -150,9 +183,7 @@ def plot_spatial_error(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
- pixel_alpha = (
- mask_reshaped.clamp(0.7, 1).cpu().numpy()
- ) # Faded border region
+ pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region
fig, ax = plt.subplots(
figsize=(5, 4.8),
@@ -161,8 +192,10 @@ def plot_spatial_error(
ax.coastlines() # Add coastline outlines
error_grid = (
- error.reshape(list(datastore.grid_shape_state.values.values()))
- .cpu()
+ error.reshape(
+ [datastore.grid_shape_state.x, datastore.grid_shape_state.y]
+ )
+ .T.cpu()
.numpy()
)
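
The mdp.py change above guards against a subtle mutation bug: kwargs.pop("globe", {}) removes the "globe" entry from the dict stored in the configuration, so the projection builds correctly only the first time. Deep-copying first avoids that, as in this standalone sketch with illustrative projection values.

# Sketch: build a cartopy projection from config without mutating the config.
# The projection class and kwargs below are illustrative values only.
import copy

import cartopy.crs as ccrs

projection_info = {
    "class_name": "LambertConformal",
    "kwargs": {
        "central_longitude": 25.0,
        "central_latitude": 56.0,
        "globe": {"semimajor_axis": 6367470.0},
    },
}

ProjectionClass = getattr(ccrs, projection_info["class_name"])
# Deep-copy so the .pop() below does not alter the stored configuration.
kwargs = copy.deepcopy(projection_info["kwargs"])
globe_kwargs = kwargs.pop("globe", {})
if globe_kwargs:
    kwargs["globe"] = ccrs.Globe(**globe_kwargs)

projection = ProjectionClass(**kwargs)
assert "globe" in projection_info["kwargs"]  # the config is left untouched
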
From efe03027842a22139d6554d68ffee7b6ebe0ad73 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 26 Nov 2024 13:46:05 +0100
Subject: [PATCH 254/273] use xarray plot only
---
neural_lam/models/ar_model.py | 47 +++++++++++++++++++++++++++--------
neural_lam/vis.py | 43 +++-----------------------------
2 files changed, 39 insertions(+), 51 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index b55143f0..0af25367 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -1,5 +1,6 @@
# Standard library
import os
+from typing import List, Union
# Third-party
import matplotlib.pyplot as plt
@@ -7,7 +8,7 @@
import pytorch_lightning as pl
import torch
import wandb
-from loguru import logger
+import xarray as xr
# Local
from .. import metrics, vis
@@ -149,7 +150,35 @@ def __init__(
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []
- def _create_dataarray_from_tensor(self, tensor, time, split, category):
+ def _create_dataarray_from_tensor(
+ self,
+ tensor: torch.Tensor,
+ time: Union[int, List[int]],
+ split: str,
+ category: str,
+ ) -> xr.DataArray:
+ """
+ Create an `xr.DataArray` from a tensor, with the correct dimensions and
+ coordinates to match the datastore used by the model. This function in
+ in effect is the inverse of what is returned by
+ `WeatherDataset.__getitem__`.
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ The tensor to convert to a `xr.DataArray` with dimensions [time,
+ grid_index, feature]
+ time : Union[int,List[int]]
+ The time index or indices for the data, given as integers or a list
+ of integers representing epoch time in nanoseconds.
+ split : str
+ The split of the data, either 'train', 'val', or 'test'
+ category : str
+ The category of the data, either 'state' or 'forcing'
+ """
+ # TODO: creating an instance of WeatherDataset here on every call is
+ # not how this should be done but whether WeatherDataset should be
+ # provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
time = np.array(time, dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
@@ -482,14 +511,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
var_vranges = list(zip(var_vmin, var_vmax))
# Iterate over prediction horizon time steps
- for t_i, (pred_t, target_t) in enumerate(
- zip(pred_slice, target_slice), start=1
- ):
+ for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
# Create one figure per variable at this time step
var_figs = [
vis.plot_prediction(
- pred=pred_t[:, var_i],
- target=target_t[:, var_i],
datastore=self._datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
@@ -509,10 +534,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
]
example_i = self.plotted_examples
- for i, fig in enumerate(var_figs):
- fn = f"example_{i}_{example_i}_t{t_i}.png"
- fig.savefig(fn)
- logger.info(f"Saved example plot to {fn}")
+ # for i, fig in enumerate(var_figs):
+ # fn = f"example_{i}_{example_i}_t{t_i}.png"
+ # fig.savefig(fn)
+ # logger.info(f"Saved example plot to {fn}")
wandb.log(
{
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 357a8977..47c68e4f 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -65,8 +65,6 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
- pred,
- target,
datastore: BaseRegularGridDatastore,
da_prediction=None,
da_target=None,
@@ -81,8 +79,8 @@ def plot_prediction(
"""
# Get common scale for values
if vrange is None:
- vmin = min(vals.min().cpu().item() for vals in (pred, target))
- vmax = max(vals.max().cpu().item() for vals in (pred, target))
+ vmin = min(da_prediction.min(), da_target.min())
+ vmax = max(da_prediction.max(), da_target.max())
else:
vmin, vmax = vrange
@@ -100,39 +98,13 @@ def plot_prediction(
subplot_kw={"projection": datastore.coords_projection},
)
- use_xarray = True
-
# Plot pred and target
-
- if not use_xarray:
- for ax, data in zip(axes, (target, pred)):
- ax.coastlines() # Add coastline outlines
- data_grid = (
- data.reshape(
- [datastore.grid_shape_state.x, datastore.grid_shape_state.y]
- )
- .T.cpu()
- .numpy()
- )
- im = ax.imshow(
- data_grid,
- origin="lower",
- extent=extent,
- alpha=pixel_alpha,
- vmin=vmin,
- vmax=vmax,
- cmap="plasma",
- )
-
- cbar = fig.colorbar(im, aspect=30)
- cbar.ax.tick_params(labelsize=10)
-
x = da_target.x.values
y = da_target.y.values
extent = [x.min(), x.max(), y.min(), y.max()]
for ax, da in zip(axes, (da_target, da_prediction)):
ax.coastlines() # Add coastline outlines
- im = da.plot.imshow(
+ da.plot.imshow(
ax=ax,
origin="lower",
x="x",
@@ -144,15 +116,6 @@ def plot_prediction(
transform=datastore.coords_projection,
)
- # da.plot.pcolormesh(
- # ax=ax,
- # x="x",
- # vmin=vmin,
- # vmax=vmax,
- # transform=datastore.coords_projection,
- # cmap="plasma",
- # )
-
# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
axes[1].set_title("Prediction", size=15)
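
Plotting from the unstacked DataArrays, as this patch does, lets xarray handle the reshape and the coordinate extent. The essential call pattern is roughly the sketch below, with synthetic data, coordinates and projection standing in for the datastore's.

# Sketch: plot a 2D DataArray on a cartopy GeoAxes with xarray's imshow.
# Data, coordinates and projection are synthetic stand-ins.
import cartopy.crs as ccrs
import matplotlib

matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

proj = ccrs.PlateCarree()  # stand-in for the datastore's projection
x = np.linspace(5.0, 15.0, 40)
y = np.linspace(45.0, 55.0, 30)
da = xr.DataArray(
    np.random.rand(30, 40), coords={"y": y, "x": x}, dims=("y", "x")
)

fig, ax = plt.subplots(subplot_kw={"projection": proj})
ax.coastlines()
da.plot.imshow(
    ax=ax,
    x="x",
    origin="lower",
    transform=proj,  # CRS of the data coordinates, not of the axes
    cmap="plasma",
    vmin=0.0,
    vmax=1.0,
)
fig.savefig("prediction_sketch.png")
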
From a489c2ed974397ea230d2e61b842d8d9384867dc Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 26 Nov 2024 14:07:06 +0100
Subject: [PATCH 255/273] don't reraise
---
neural_lam/train_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 9d1d5039..74146c89 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -23,7 +23,7 @@
}
-@logger.catch(reraise=True)
+@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
From 242d08bcb5374cdd90aecfd49f501ed233f1ce0c Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 26 Nov 2024 14:50:03 +0100
Subject: [PATCH 256/273] remove debug plot
---
neural_lam/models/ar_model.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 0af25367..c875688b 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -534,10 +534,6 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
]
example_i = self.plotted_examples
- # for i, fig in enumerate(var_figs):
- # fn = f"example_{i}_{example_i}_t{t_i}.png"
- # fig.savefig(fn)
- # logger.info(f"Saved example plot to {fn}")
wandb.log(
{
From c1f706c29542d770ed49e910f8b9bd5caff1fdec Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Tue, 26 Nov 2024 16:04:24 +0100
Subject: [PATCH 257/273] remove extent calc used in diagnosing issue
---
neural_lam/vis.py | 3 ---
1 file changed, 3 deletions(-)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 47c68e4f..c814aacf 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -99,9 +99,6 @@ def plot_prediction(
)
# Plot pred and target
- x = da_target.x.values
- y = da_target.y.values
- extent = [x.min(), x.max(), y.min(), y.max()]
for ax, da in zip(axes, (da_target, da_prediction)):
ax.coastlines() # Add coastline outlines
da.plot.imshow(
From 88ec9dc30ca597427a51485296d66d6307d566f5 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Thu, 28 Nov 2024 07:58:09 +0000
Subject: [PATCH 258/273] Test order of dimensions in eval plots
---
neural_lam/vis.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index ef8c2722..b24d77bd 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -104,7 +104,7 @@ def plot_prediction(
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
data_grid = (
- data.reshape(datastore.grid_shape_state.x, datastore.grid_shape_state.y)
+ data.reshape(datastore.grid_shape_state.y, datastore.grid_shape_state.x)
.cpu()
.numpy()
)
From 90f8918217fb7f84737c8ae3001e744b488cd5dd Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Thu, 28 Nov 2024 09:35:15 +0000
Subject: [PATCH 259/273] fix tensors on cpu and plot time index
---
neural_lam/models/ar_model.py | 14 +++++++++++---
neural_lam/weather_dataset.py | 2 +-
2 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index f98af756..6d05b2ab 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -181,7 +181,7 @@ def _create_dataarray_from_tensor(
# not how this should be done but whether WeatherDataset should be
# provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
- time = np.array(time, dtype="datetime64[ns]")
+ time = np.array(time.cpu(), dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
tensor=tensor, time=time, category=category
)
@@ -514,6 +514,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
# Iterate over prediction horizon time steps
for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
# Create one figure per variable at this time step
+ print(t_i)
+ #print(da_prediction)
+ t_val = t_i * self._datastore.step_length
+ print(t_val)
var_figs = [
vis.plot_prediction(
datastore=self._datastore,
@@ -521,9 +525,13 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
da_prediction=da_prediction.isel(
- state_feature=var_i
+ state_feature=var_i,
+ time=t_i - 1
+ ).squeeze(),
+ da_target=da_target.isel(
+ state_feature=var_i,
+ time=t_i - 1
).squeeze(),
- da_target=da_target.isel(state_feature=var_i).squeeze(),
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 532e3c90..bfb02c8c 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -581,7 +581,7 @@ def _is_listlike(obj):
coords["time"] = time
da = xr.DataArray(
- tensor.numpy(),
+ tensor.cpu().numpy(),
dims=dims,
coords=coords,
)
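
Both one-line fixes above address the same failure mode: numpy cannot consume CUDA tensors, so the data tensor and the epoch-time tensor have to be moved to host memory before conversion. A short illustration (on a CPU-only machine .cpu() is a no-op, so the sketch runs either way):

# Sketch: move tensors to host memory before handing them to numpy/xarray.
import numpy as np
import torch

values = torch.rand(4, 3)  # would live on the GPU during training
epoch_ns = torch.tensor([1_700_000_000_000_000_000])  # nanoseconds since epoch

values_np = values.cpu().numpy()  # .numpy() on a CUDA tensor raises
time = np.array(epoch_ns.cpu(), dtype="datetime64[ns]")

print(values_np.shape, time)
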
From 53f0ea4dc379d093a582d1fd564e0a43b8ac223a Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Fri, 29 Nov 2024 10:41:13 +0000
Subject: [PATCH 260/273] restore tests/test_datasets.py
---
tests/test_datasets.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 224d0a25..cefda402 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -36,10 +36,13 @@ def test_dataset_item_shapes(datastore_name):
N_pred_steps = 4
num_past_forcing_steps = 1
num_future_forcing_steps = 1
- forcing_window_size=forcing_window_size,
-
- num_past_forcing_steps=num_past_forcing_steps,
- num_future_forcing_steps=num_future_forcing_steps,
+ dataset = WeatherDataset(
+ datastore=datastore,
+ split="train",
+ ar_steps=N_pred_steps,
+ num_past_forcing_steps=num_past_forcing_steps,
+ num_future_forcing_steps=num_future_forcing_steps,
+ )
item = dataset[0]
@@ -255,4 +258,4 @@ def test_dataset_length(dataset_config):
# Check that we can actually get last and first sample
dataset[0]
- dataset[expected_len - 1]
+ dataset[expected_len - 1]
\ No newline at end of file
From cfc249ffa98dbd7d8389659df6df503bb8f4ae44 Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Fri, 29 Nov 2024 10:48:53 +0000
Subject: [PATCH 261/273] cleaning up with a focus on linting
---
neural_lam/config.py | 17 -----
.../compute_standardization_stats.py | 1 +
.../npyfilesmeps/create_parameter_weights.py | 1 +
neural_lam/datastore/plot_example.py | 4 +-
neural_lam/models/ar_model.py | 38 ++++--------
neural_lam/train_model.py | 62 +++++++++----------
neural_lam/utils.py | 10 ++-
tests/test_datasets.py | 2 +-
8 files changed, 53 insertions(+), 82 deletions(-)
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 2a962997..349ffd99 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -43,23 +43,6 @@ class DatastoreSelection:
config_path: str
-@dataclasses.dataclass
-class TrainingConfig:
- """
- Configuration related to training neural-lam
-
- Attributes
- ----------
- state_feature_weights : Dict[str, float]
- The weights for each state feature in the datastore to use in the loss
- function during training.
- """
-
- state_feature_weights: Dict[str, float]
- logger: str = "wandb"
- logger_url: str = None
-
-
@dataclasses.dataclass
class ManualStateFeatureWeighting:
"""
diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
index 76a647d6..f2c80e8a 100644
--- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
+++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
@@ -410,5 +410,6 @@ def cli():
distributed=distributed,
)
+
if __name__ == "__main__":
cli()
diff --git a/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py b/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
index 76a647d6..f2c80e8a 100644
--- a/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
+++ b/neural_lam/datastore/npyfilesmeps/create_parameter_weights.py
@@ -410,5 +410,6 @@ def cli():
distributed=distributed,
)
+
if __name__ == "__main__":
cli()
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
index a74f7409..5e1a57e0 100644
--- a/neural_lam/datastore/plot_example.py
+++ b/neural_lam/datastore/plot_example.py
@@ -186,5 +186,5 @@ def _parse_dict(arg_str):
selection=selection,
index_selection=index_selection,
)
- #plt.show()
- plt.savefig('plot_example.png')
+ # plt.show()
+ plt.savefig("plot_example.png")
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 6d05b2ab..ec950799 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -7,7 +7,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
-import wandb
import xarray as xr
# Local
@@ -514,10 +513,6 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
# Iterate over prediction horizon time steps
for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
# Create one figure per variable at this time step
- print(t_i)
- #print(da_prediction)
- t_val = t_i * self._datastore.step_length
- print(t_val)
var_figs = [
vis.plot_prediction(
datastore=self._datastore,
@@ -525,12 +520,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
da_prediction=da_prediction.isel(
- state_feature=var_i,
- time=t_i - 1
+ state_feature=var_i, time=t_i - 1
).squeeze(),
da_target=da_target.isel(
- state_feature=var_i,
- time=t_i - 1
+ state_feature=var_i, time=t_i - 1
).squeeze(),
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
@@ -548,21 +541,13 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
self._datastore.get_vars_names("state"), var_figs
):
- if not isinstance(self.logger, pl.loggers.WandbLogger):
- key=f"{var_name}_example_{t_i}"
+ if isinstance(self.logger, pl.loggers.WandbLogger):
+ key = f"{var_name}_example_{example_i}"
else:
- key=f"{var_name}_example_{example_i}"
-
- self.logger.log_image(key=key, images=[fig])
-
- # wandb.log(
- # {
- # f"{var_name}_example_{example_i}": wandb.Image(fig)
- # for var_name, fig in zip(
- # self._datastore.get_vars_names("state"), var_figs
- # )
- # }
- # )
+ key = f"{var_name}_example"
+
+ self.logger.log_image(key=key, images=[fig], step=t_i)
+
plt.close(
"all"
) # Close all figs for this time step, saves memory
@@ -657,7 +642,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
)
)
- # Ensure that log_dict has structure for logging as dict(str, plt.Figure)
+ # Ensure that log_dict has structure for
+ # logging as dict(str, plt.Figure)
assert all(
isinstance(key, str) and isinstance(value, plt.Figure)
for key, value in log_dict.items()
@@ -707,9 +693,9 @@ def on_test_epoch_end(self):
# log all to same key, sequentially
for i, fig in enumerate(loss_map_figs):
- key=f"test_loss"
+ key = "test_loss"
if not isinstance(self.logger, pl.loggers.WandbLogger):
- key=f"{key}_{i}"
+ key = f"{key}_{i}"
self.logger.log_image(key=key, images=[fig])
# also make without title and save as pdf
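The ar_model.py hunks above change how example figures are keyed per logger: WandbLogger keeps the per-example key, while other loggers get a bare per-variable key and receive the lead time through log_image's step argument. A minimal sketch of that selection logic (not part of the patch), assuming a Lightning logger exposing log_image(key=..., images=[...], step=...) and with var_name, fig, example_i and t_i standing in for the loop variables above:

    # Sketch only; mirrors the key selection in plot_examples above.
    import pytorch_lightning as pl

    def log_example_figure(logger, fig, var_name, example_i, t_i):
        if isinstance(logger, pl.loggers.WandbLogger):
            # wandb: one key per variable and example
            key = f"{var_name}_example_{example_i}"
        else:
            # other loggers: one key per variable, step carries the lead time
            key = f"{var_name}_example"
        logger.log_image(key=key, images=[fig], step=t_i)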
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index f29f522f..d8e097b3 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -5,16 +5,15 @@
from argparse import ArgumentParser
# Third-party
+import mlflow
+
+# for logging the model:
+import mlflow.pytorch
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities import seed
from loguru import logger
-import mlflow
-# for logging the model:
-import mlflow.pytorch
-from mlflow.models import infer_signature
-
# Local
from . import utils
from .config import load_config_and_datastore
@@ -27,23 +26,26 @@
"hi_lam_parallel": HiLAMParallel,
}
-class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
+class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
def __init__(self, experiment_name, tracking_uri):
- super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
+ super().__init__(
+ experiment_name=experiment_name, tracking_uri=tracking_uri
+ )
mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.log_param("run_id", self.run_id)
- #mlflow.pytorch.autolog() # Can be used to log the model, but without signature
- #mlflow.save_dir = "mlruns"
- #self.save_dir = "mlruns"
@property
def save_dir(self):
return "mlruns"
- def log_image(self, key, images):
+ def log_image(self, key, images, step=None):
+ # Third-party
from PIL import Image
+ if step is not None:
+ key = f"{key}_{step}"
+
# Need to save the image to a temporary file, then log that file
         # mlflow.log_image should do this automatically, but it doesn't work
temporary_image = f"{key}.png"
@@ -54,11 +56,11 @@ def log_image(self, key, images):
def log_model(self, model):
# Create model signature
- #signature = infer_signature(train_dataset.numpy(), model(train_dataset).detach().numpy())
+ # signature = infer_signature(train_dataset.numpy(),
+ # model(train_dataset).detach().numpy())
mlflow.pytorch.log_model(model, "model")
-
def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
logger = pl.loggers.WandbLogger(
@@ -72,10 +74,6 @@ def _setup_training_logger(config, datastore, args, run_name):
raise ValueError(
"MLFlow logger requires a URL to the MLFlow server"
)
- # logger = pl.loggers.MLFlowLogger(
- # experiment_name=args.wandb_project,
- # tracking_uri=url,
- # )
logger = CustomMLFlowLogger(
experiment_name=args.wandb_project,
tracking_uri=url,
@@ -87,7 +85,6 @@ def _setup_training_logger(config, datastore, args, run_name):
return logger
-
@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
@@ -337,7 +334,6 @@ def main(input_args=None):
config=config, datastore=datastore, args=args, run_name=run_name
)
-
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=f"saved_models/{run_name}",
filename="min_val_loss",
@@ -349,12 +345,12 @@ def main(input_args=None):
max_epochs=args.epochs,
deterministic=True,
strategy="ddp",
- #devices=3,
- #devices=[1,2],
- #devices=[0, 1, 2],
- #strategy="auto",
- devices=1, # For eval mode
- num_nodes=1, # For eval mode
+ # devices=3,
+ # devices=[1,2],
+ # devices=[0, 1, 2],
+ # strategy="auto",
+ devices=1, # For eval mode
+ num_nodes=1, # For eval mode
accelerator=device_name,
logger=training_logger,
log_every_n_steps=1,
@@ -362,28 +358,28 @@ def main(input_args=None):
check_val_every_n_epoch=args.val_interval,
precision=args.precision,
)
- import ipdb
+
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_training_logger_metrics(
training_logger, val_steps=args.val_steps_to_log
- ) # Do after wandb.init
+ ) # Do after initializing logger
if args.eval:
trainer.test(
model=model,
datamodule=data_module,
ckpt_path=args.load,
- )
+ )
else:
- with ipdb.launch_ipdb_on_exception():
- trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
+ trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
# Get a sample of training data to log
- #sample_data = data_module.train_dataset
- #print("Logging sample data")
- #print(sample_data.train_dataset)
+ # sample_data = data_module.train_dataset
+ # print("Logging sample data")
+ # print(sample_data.train_dataset)
# Log the model
training_logger.log_model(model)
+
if __name__ == "__main__":
main()
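The CustomMLFlowLogger shown above wraps pl.loggers.MLFlowLogger, starts an MLflow run with system metrics, writes each figure to a temporary PNG before logging it, and logs the trained model via mlflow.pytorch.log_model. A hedged usage sketch follows; the experiment name and tracking URI are placeholders, not values from this patch:

    # Sketch only; assumes a reachable MLflow tracking server.
    import pytorch_lightning as pl
    from neural_lam.train_model import CustomMLFlowLogger

    training_logger = CustomMLFlowLogger(
        experiment_name="neural-lam",
        tracking_uri="https://mlflow.example.org",
    )
    trainer = pl.Trainer(
        max_epochs=1,
        devices=1,
        num_nodes=1,
        logger=training_logger,
        log_every_n_steps=1,
    )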
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 7b61c906..7bd7d88e 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -5,7 +5,7 @@
# Third-party
import torch
-from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
from torch import nn
from tueplots import bundles, figsizes
@@ -237,12 +237,16 @@ def fractional_plot_bundle(fraction):
def init_training_logger_metrics(training_logger, val_steps):
"""
- Set up wandb metrics to track
+ Set up logger metrics to track
"""
experiment = training_logger.experiment
if isinstance(training_logger, WandbLogger):
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")
+ elif isinstance(training_logger, MLFlowLogger):
+ pass
else:
- warnings.warn("Only WandbLogger is supported for tracking metrics")
+ warnings.warn(
+ "Only WandbLogger & MLFlowLogger is supported for tracking metrics"
+ )
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index cefda402..419aece0 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -258,4 +258,4 @@ def test_dataset_length(dataset_config):
# Check that we can actually get last and first sample
dataset[0]
- dataset[expected_len - 1]
\ No newline at end of file
+ dataset[expected_len - 1]
From b218c8b2bdf2f1d96c5a118f59e3d2102496769e Mon Sep 17 00:00:00 2001
From: Kasper Hintz
Date: Fri, 29 Nov 2024 11:54:54 +0000
Subject: [PATCH 262/273] update tests
---
.../mdp/danra_100m_winds/config.yaml | 2 ++
.../mllam/graph/multiscale/g2m_edge_index.pt | Bin 7803519 -> 0 bytes
.../mllam/graph/multiscale/g2m_features.pt | Bin 5852917 -> 0 bytes
.../mllam/graph/multiscale/m2g_edge_index.pt | Bin 29743359 -> 0 bytes
.../mllam/graph/multiscale/m2g_features.pt | Bin 22307765 -> 0 bytes
.../mllam/graph/multiscale/m2m_edge_index.pt | Bin 7512895 -> 0 bytes
.../mllam/graph/multiscale/m2m_features.pt | Bin 5634933 -> 0 bytes
.../mllam/graph/multiscale/mesh_features.pt | Bin 473594 -> 0 bytes
.../mllam/graph/multiscale/mesh_graph_1.png | Bin 686565 -> 0 bytes
tests/datastore_examples/npy/config_meps.yaml | 13 -------------
10 files changed, 2 insertions(+), 13 deletions(-)
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/g2m_edge_index.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/g2m_features.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/m2g_edge_index.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/m2g_features.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/m2m_edge_index.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/m2m_features.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/mesh_features.pt
delete mode 100644 tests/datastore_examples/mllam/graph/multiscale/mesh_graph_1.png
delete mode 100644 tests/datastore_examples/npy/config_meps.yaml
diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
index 0bb5c5ec..9ddcebdd 100644
--- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
+++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -2,6 +2,8 @@ datastore:
kind: mdp
config_path: danra.datastore.yaml
training:
+ logger: mlflow
+ logger_url: https://mlflow.dmidev.org
state_feature_weighting:
__config_class__: ManualStateFeatureWeighting
weights:
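For reference, a reconstruction of the affected part of the danra_100m_winds example config after this hunk; only the keys visible in the diff context are shown, and the weights entries are elided as in the hunk:

    datastore:
      kind: mdp
      config_path: danra.datastore.yaml
    training:
      logger: mlflow
      logger_url: https://mlflow.dmidev.org
      state_feature_weighting:
        __config_class__: ManualStateFeatureWeighting
        weights:
          ...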
diff --git a/tests/datastore_examples/mllam/graph/multiscale/g2m_edge_index.pt b/tests/datastore_examples/mllam/graph/multiscale/g2m_edge_index.pt
deleted file mode 100644
index ddba0447c8c8d5780977b8afcb16870d931e5dec..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
literal 7803519