diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7fca0b6..a6da05c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -70,12 +70,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask[:, 0].to(torch.bool)], + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, )