
Separate interior state and boundary forcing to only predict state #84

Closed · wants to merge 92 commits. Changes shown below are from 1 commit.

Commits:
5df1bff
add datastore_boundary to neural_lam
sadamov Nov 18, 2024
46590ef
complete integration of boundary in weatherDataset
sadamov Nov 18, 2024
b990f49
Add test to check timestep length and spacing
sadamov Nov 18, 2024
3fd1d6b
setting default mdp boundary to 0 gridcells
sadamov Nov 18, 2024
1f2499c
implement time-based slicing
sadamov Nov 18, 2024
1af1481
remove all interior_mask and boundary_mask
sadamov Nov 19, 2024
d545cb7
added gcsfs dependency for era5 weatherbench download
sadamov Nov 19, 2024
5c1a7d7
added new era5 datastore config for boundary
sadamov Nov 19, 2024
30e4f05
removed left-over boundary-mask references
sadamov Nov 19, 2024
6a8c593
make check for existing category in datastore more flexible (for boun…
sadamov Nov 19, 2024
17c920d
implement xarray based (mostly) time slicing and windowing
sadamov Nov 20, 2024
7919995
cleanup analysis based time-slicing
sadamov Nov 21, 2024
9bafcee
implement datastore_boundary in existing tests
sadamov Nov 19, 2024
ce06bbc
allow for grid shape retrieval from forcing data
sadamov Nov 21, 2024
884b5c6
rearrange time slicing, boundary first
sadamov Nov 21, 2024
5904cbe
identified issue, cleanup next
leifdenby Nov 25, 2024
efe0302
use xarray plot only
leifdenby Nov 26, 2024
a489c2e
don't reraise
leifdenby Nov 26, 2024
242d08b
remove debug plot
leifdenby Nov 26, 2024
c1f706c
remove extent calc used in diagnosing issue
leifdenby Nov 26, 2024
cf8e3e4
add type annotation
leifdenby Nov 29, 2024
85160ce
ensure tensor copy to cpu mem before data-array creation
leifdenby Nov 29, 2024
52c4528
apply time-indexing to support ar_steps_val > 1
leifdenby Nov 29, 2024
b96d8eb
renaming test datastores
sadamov Nov 30, 2024
72da25f
adding num_past/future_boundary_step args
sadamov Nov 30, 2024
244f1cc
using combined config file
sadamov Nov 30, 2024
a9cc36e
proper handling of state/forcing/boundary in dataset
sadamov Nov 30, 2024
dcc0b46
datastore_boundars=None introduced
sadamov Nov 30, 2024
a3b3bde
bug fix for file retrieval per member
sadamov Nov 30, 2024
3ffc413
rename datastore for tests
sadamov Nov 30, 2024
85aad66
aligned time with danra for easier boundary testing
sadamov Nov 30, 2024
64f057f
Fixed test for temporal embedding
sadamov Nov 30, 2024
6205dbd
pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard
leifdenby Dec 2, 2024
551cd26
allow boundary as input to ar_model.common_step
sadamov Dec 2, 2024
fc95350
linting
sadamov Dec 2, 2024
01fa807
improved docstrings and added some assertions
sadamov Dec 2, 2024
5a749f3
update mdp dependency
sadamov Dec 2, 2024
45ba607
remove boundary datastore from tests that don't need it
sadamov Dec 2, 2024
f36f360
fix scope of _get_slice_time
sadamov Dec 2, 2024
105108e
fix scope of _get_time_step
sadamov Dec 2, 2024
d760145
Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov…
sadamov Dec 2, 2024
ae0cf76
added information about optional boundary datastore
sadamov Dec 2, 2024
9af27e0
add datastore_boundary to neural_lam
sadamov Nov 18, 2024
c25fb30
complete integration of boundary in weatherDataset
sadamov Nov 18, 2024
505ceeb
Add test to check timestep length and spacing
sadamov Nov 18, 2024
e733066
setting default mdp boundary to 0 gridcells
sadamov Nov 18, 2024
d8349a4
implement time-based slicing
sadamov Nov 18, 2024
fd791bf
remove all interior_mask and boundary_mask
sadamov Nov 19, 2024
ae82cdb
added gcsfs dependency for era5 weatherbench download
sadamov Nov 19, 2024
34a6cc7
added new era5 datastore config for boundary
sadamov Nov 19, 2024
2dc67a0
removed left-over boundary-mask references
sadamov Nov 19, 2024
9f8628e
make check for existing category in datastore more flexible (for boun…
sadamov Nov 19, 2024
388c79d
implement xarray based (mostly) time slicing and windowing
sadamov Nov 20, 2024
2529969
cleanup analysis based time-slicing
sadamov Nov 21, 2024
179a035
implement datastore_boundary in existing tests
sadamov Nov 19, 2024
2daeb16
allow for grid shape retrieval from forcing data
sadamov Nov 21, 2024
cbcdcae
rearrange time slicing, boundary first
sadamov Nov 21, 2024
e6ace27
renaming test datastores
sadamov Nov 30, 2024
42818f0
adding num_past/future_boundary_step args
sadamov Nov 30, 2024
0103b6e
using combined config file
sadamov Nov 30, 2024
0896344
proper handling of state/forcing/boundary in dataset
sadamov Nov 30, 2024
355423c
datastore_boundars=None introduced
sadamov Nov 30, 2024
121d460
bug fix for file retrieval per member
sadamov Nov 30, 2024
7e82eef
rename datastore for tests
sadamov Nov 30, 2024
320d7c4
aligned time with danra for easier boundary testing
sadamov Nov 30, 2024
f18dcc2
Fixed test for temporal embedding
sadamov Nov 30, 2024
e6327d8
allow boundary as input to ar_model.common_step
sadamov Dec 2, 2024
1374a19
linting
sadamov Dec 2, 2024
779f3e9
improved docstrings and added some assertions
sadamov Dec 2, 2024
f126ec2
remove boundary datastore from tests that don't need it
sadamov Dec 2, 2024
4b656da
fix scope of _get_time_step
sadamov Dec 2, 2024
75db4b8
added information about optional boundary datastore
sadamov Dec 2, 2024
58b4af6
Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov…
sadamov Dec 2, 2024
4c17545
moved gcsfs to dev group
sadamov Dec 3, 2024
a700350
linting
sadamov Dec 3, 2024
315aa0f
Propagate separation of state and boundary change through training loop
joeloskarsson Oct 28, 2024
1967221
Start building graphs with wmg
joeloskarsson Nov 4, 2024
cb74e3f
Change forward pass to concat according to enforced node ordering
joeloskarsson Nov 11, 2024
9715ed8
wip to make tests pass
joeloskarsson Nov 11, 2024
336fba9
Fix edge index manipulation to make training work again
joeloskarsson Nov 12, 2024
ce3ea6d
Work on fixing plotting functionality
joeloskarsson Nov 12, 2024
a520505
Linting
joeloskarsson Nov 13, 2024
793e6c0
Add optional separate grid embedder for boundary
joeloskarsson Nov 13, 2024
3515460
Make new graph creation script main and only one
joeloskarsson Nov 13, 2024
05d91f1
Fix some typos and forgot code
joeloskarsson Nov 13, 2024
3eba43c
Correct handling of node indices for m2g when using decode_mask
joeloskarsson Nov 27, 2024
f1b7359
Linting and bugfixes
joeloskarsson Nov 28, 2024
fa6c9e3
Make graph creation and plotting work with datastores
joeloskarsson Dec 2, 2024
4d85384
Fix graph loading and boundary mask
joeloskarsson Dec 2, 2024
9edfec3
Fix boundary masking bug for static features
joeloskarsson Dec 2, 2024
6e1c53c
Add flag making boundary forcing optional in models
joeloskarsson Dec 3, 2024
4bcaa4b
Linting
joeloskarsson Dec 3, 2024
remove all interior_mask and boundary_mask
sadamov committed Nov 21, 2024
commit 1af1481e6884f89ccf39befa37e0d61ed16bbcc3
17 changes: 0 additions & 17 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
@@ -228,23 +228,6 @@ def get_dataarray(
"""
pass

@cached_property
@abc.abstractmethod
def boundary_mask(self) -> xr.DataArray:
"""
Return the boundary mask for the dataset, with spatial dimensions
stacked. Where the value is 1, the grid point is a boundary point, and
where the value is 0, the grid point is not a boundary point.

Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions
`('grid_index',)`.

"""
pass

@abc.abstractmethod
def get_xy(self, category: str) -> np.ndarray:
"""
34 changes: 0 additions & 34 deletions neural_lam/datastore/mdp.py
@@ -318,40 +318,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
return ds_stats

@cached_property
def boundary_mask(self) -> xr.DataArray:
"""
Produce a 0/1 mask for the boundary points of the dataset, these will
sit at the edges of the domain (in x/y extent) and will be used to mask
out the boundary points from the loss function and to overwrite the
boundary points from the prediction. For now this is created when the
mask is requested, but in the future this could be saved to the zarr
file.

Returns
-------
xr.DataArray
A 0/1 mask for the boundary points of the dataset, where 1 is a
boundary point and 0 is not.

"""
if self._n_boundary_points > 0:
ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
da_state_variable = (
ds_unstacked["state"].isel(time=0).isel(state_feature=0)
)
da_domain_allzero = xr.zeros_like(da_state_variable)
ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
x=slice(self._n_boundary_points, -self._n_boundary_points),
y=slice(self._n_boundary_points, -self._n_boundary_points),
)
ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
1
).astype(int)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
else:
return None

@property
def coords_projection(self) -> ccrs.Projection:
"""
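The removed MDP `boundary_mask` builds the mask by slicing away the interior of the domain. A minimal standalone NumPy sketch of that slicing logic (hypothetical, xarray-free; the original also returned `None` when `n_boundary` was 0):

```python
import numpy as np

def make_boundary_mask(ny: int, nx: int, n_boundary: int) -> np.ndarray:
    """Return a 0/1 mask that is 1 on the outer n_boundary rows/columns.

    Standalone sketch of the slicing pattern the removed MDP
    boundary_mask property implemented with xarray. Assumes n_boundary >= 1.
    """
    assert n_boundary >= 1
    mask = np.ones((ny, nx), dtype=int)
    # Zero out the interior, leaving a ring of width n_boundary set to 1
    mask[n_boundary:-n_boundary, n_boundary:-n_boundary] = 0
    return mask

# 5x5 grid with a 1-cell ring: 16 boundary cells, 9 interior cells
mask = make_boundary_mask(5, 5, 1)
```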
28 changes: 0 additions & 28 deletions neural_lam/datastore/npyfilesmeps/store.py
@@ -668,34 +668,6 @@ def grid_shape_state(self) -> CartesianGridShape:
ny, nx = self.config.grid_shape_state
return CartesianGridShape(x=nx, y=ny)

@cached_property
def boundary_mask(self) -> xr.DataArray:
"""The boundary mask for the dataset. This is a binary mask that is 1
where the grid cell is on the boundary of the domain, and 0 otherwise.

Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `[grid_index]`.

"""
xy = self.get_xy(category="state", stacked=False)
xs = xy[:, :, 0]
ys = xy[:, :, 1]
# Check if x-coordinates are constant along columns
assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
# Check if y-coordinates are constant along rows
assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
# Extract unique x and y coordinates
x = xs[:, 0] # Unique x-coordinates (changes along the first axis)
y = ys[0, :] # Unique y-coordinates (changes along the second axis)
values = np.load(self.root_path / "static" / "border_mask.npy")
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
)
da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
return da_mask_stacked_xy

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""Return the standardization dataarray for the given category. This
should contain a `{category}_mean` and `{category}_std` variable for
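Before wrapping `border_mask.npy` in a DataArray, the removed npyfilesmeps property verified the grid is regular and recovered 1-D axes from the 2-D coordinate arrays. A standalone sketch of those checks (hypothetical helper, not the datastore API):

```python
import numpy as np

def split_regular_grid_axes(xy):
    """Recover 1-D x/y axes from an (Nx, Ny, 2) coordinate array.

    Mirrors the regularity assertions the removed boundary_mask ran:
    x must be constant along columns, y constant along rows.
    """
    xs, ys = xy[:, :, 0], xy[:, :, 1]
    assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
    assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
    return xs[:, 0], ys[0, :]

x_axis = np.array([0.0, 1.0, 2.0])
y_axis = np.array([10.0, 20.0])
# meshgrid with indexing="ij" matches the layout checked above
xy = np.stack(np.meshgrid(x_axis, y_axis, indexing="ij"), axis=-1)
x_rec, y_rec = split_regular_grid_axes(xy)
```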
53 changes: 10 additions & 43 deletions neural_lam/models/ar_model.py
@@ -42,7 +42,6 @@ def __init__(
da_state_stats = datastore.get_standardization_dataarray(
category="state"
)
da_boundary_mask = datastore.boundary_mask
num_past_forcing_steps = args.num_past_forcing_steps
num_future_forcing_steps = args.num_future_forcing_steps

@@ -115,18 +114,6 @@ def __init__(
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)

boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
).unsqueeze(
1
) # add feature dim

self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
"interior_mask", 1.0 - self.boundary_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border

self.val_metrics = {
"mse": [],
}
@@ -153,13 +140,6 @@ def configure_optimizers(self):
)
return opt

@property
def interior_mask_bool(self):
"""
Get the interior mask as a boolean (N,) mask.
"""
return self.interior_mask[:, 0].to(torch.bool)

@staticmethod
def expand_to_batch(x, batch_size):
"""
@@ -191,27 +171,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

for i in range(pred_steps):
forcing = forcing_features[:, i]
border_state = true_states[:, i]

pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
)
# state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
# d_f) or None

# Overwrite border with true state
new_state = (
self.boundary_mask * border_state
+ self.interior_mask * pred_state
)

prediction_list.append(new_state)
prediction_list.append(pred_state)
if self.output_std:
pred_std_list.append(pred_std)

# Update conditioning states
prev_prev_state = prev_state
prev_state = new_state
prev_state = pred_state

prediction = torch.stack(
prediction_list, dim=1
@@ -249,12 +222,14 @@ def training_step(self, batch):
"""
prediction, target, pred_std, _ = self.common_step(batch)

# Compute loss
# Compute loss - mean over unrolled times and batch
batch_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
)
) # mean over unrolled times and batch
)

log_dict = {"train_loss": batch_loss}
self.log_dict(
@@ -287,9 +262,7 @@ def validation_step(self, batch, batch_idx):
prediction, target, pred_std, _ = self.common_step(batch)

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
),
self.loss(prediction, target, pred_std),
dim=0,
) # (time_steps-1)
mean_loss = torch.mean(time_step_loss)
@@ -314,7 +287,6 @@ def validation_step(self, batch, batch_idx):
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.val_metrics["mse"].append(entry_mses)
@@ -341,9 +313,7 @@ def test_step(self, batch, batch_idx):
# pred_steps, num_grid_nodes, d_f) or (d_f,)

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
),
self.loss(prediction, target, pred_std),
dim=0,
) # (time_steps-1,)
mean_loss = torch.mean(time_step_loss)
@@ -372,16 +342,13 @@ def test_step(self, batch, batch_idx):
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.test_metrics[metric_name].append(batch_metric_vals)

if self.output_std:
# Store output std. per variable, spatially averaged
mean_pred_std = torch.mean(
pred_std[..., self.interior_mask_bool, :], dim=-2
) # (B, pred_steps, d_f)
mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f)
self.test_metrics["output_std"].append(mean_pred_std)

# Save per-sample spatial loss for specific times
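The core change in `unroll_prediction`: each rollout step used to blend the prediction with the true state on boundary nodes, whereas after this PR the raw prediction is carried forward. A minimal NumPy sketch of the removed blend (shapes are illustrative, not the model's actual tensors):

```python
import numpy as np

def blend_with_boundary(pred_state, border_state, boundary_mask):
    """Old unroll behaviour: overwrite boundary nodes with the true state.

    pred_state/border_state: (num_grid_nodes, d_f); boundary_mask:
    (num_grid_nodes, 1) with 1 on boundary nodes. The PR removes this
    blend so the model only predicts (and conditions on) interior state.
    """
    return boundary_mask * border_state + (1.0 - boundary_mask) * pred_state

pred = np.ones((4, 1))                         # model prediction
true = np.zeros((4, 1))                        # true state, boundary source
mask = np.array([[1.0], [0.0], [0.0], [1.0]])  # nodes 0 and 3 are boundary
new_state = blend_with_boundary(pred, true, mask)
# boundary nodes take the true value (0), interior keeps the prediction (1)
```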
16 changes: 0 additions & 16 deletions neural_lam/vis.py
@@ -86,13 +86,6 @@ def plot_prediction(

extent = datastore.get_xy_extent("state")

# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region

fig, axes = plt.subplots(
1,
2,
@@ -112,7 +105,6 @@ def plot_prediction(
data_grid,
origin="lower",
extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
cmap="plasma",
@@ -147,13 +139,6 @@ def plot_spatial_error(

extent = datastore.get_xy_extent("state")

# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region

fig, ax = plt.subplots(
figsize=(5, 4.8),
subplot_kw={"projection": datastore.coords_projection},
@@ -170,7 +155,6 @@ def plot_spatial_error(
error_grid,
origin="lower",
extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
cmap="OrRd",
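The deleted plotting code derived a per-pixel alpha from the boundary mask to fade one region of the field. Note it called the torch-style `.clamp(0.7, 1).cpu()` on `da_mask.values`, which is a NumPy array; a NumPy sketch of the intended operation uses `np.clip` (hypothetical standalone helper):

```python
import numpy as np

def border_alpha(mask_2d, min_alpha=0.7):
    """Per-pixel alpha in [min_alpha, 1] derived from a 0/1 mask.

    Sketch of the removed pixel_alpha computation: clipping maps
    mask==0 pixels to min_alpha (faded) and leaves mask==1 pixels
    fully opaque.
    """
    return np.clip(mask_2d.astype(float), min_alpha, 1.0)

mask = np.array([[1, 0], [0, 1]])
alpha = border_alpha(mask)
# mask==1 pixels stay at 1.0; mask==0 pixels fade to 0.7
```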
22 changes: 0 additions & 22 deletions tests/dummy_datastore.py
@@ -148,12 +148,6 @@ def __init__(
times = [self.T0 + dt * i for i in range(n_timesteps)]
self.ds.coords["time"] = times

# Add boundary mask
self.ds["boundary_mask"] = xr.DataArray(
np.random.choice([0, 1], size=(n_points_1d, n_points_1d)),
dims=["x", "y"],
)

# Stack the spatial dimensions into grid_index
self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS)

@@ -342,22 +336,6 @@ def get_dataarray(
dim_order = self.expected_dim_order(category=category)
return self.ds[category].transpose(*dim_order)

@cached_property
def boundary_mask(self) -> xr.DataArray:
"""
Return the boundary mask for the dataset, with spatial dimensions
stacked. Where the value is 1, the grid point is a boundary point, and
where the value is 0, the grid point is not a boundary point.

Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions
`('grid_index',)`.

"""
return self.ds["boundary_mask"]

def get_xy(self, category: str, stacked: bool) -> ndarray:
"""Return the x, y coordinates of the dataset.

21 changes: 0 additions & 21 deletions tests/test_datastores.py
@@ -18,8 +18,6 @@
dataarray for the given category.
- `get_dataarray` (method): Return the processed data (as a single
`xr.DataArray`) for the given category and test/train/val-split.
- `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
- `config` (property): Return the configuration of the datastore.

In addition BaseRegularGridDatastore must have the following methods and
@@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name):
assert n_features["train"] == n_features["val"] == n_features["test"]


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
"""Check that the `datastore.boundary_mask` property is implemented and
that the returned object is an xarray DataArray with the correct shape."""
datastore = init_datastore_example(datastore_name)
da_mask = datastore.boundary_mask

assert isinstance(da_mask, xr.DataArray)
assert set(da_mask.dims) == {"grid_index"}
assert da_mask.dtype == "int"
assert set(da_mask.values) == {0, 1}
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size

if isinstance(datastore, BaseRegularGridDatastore):
grid_shape = datastore.grid_shape_state
assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy_extent(datastore_name):
"""Check that the `datastore.get_xy_extent` method is implemented and that