From 9edfec37af343be4675de402a1b7d11f7731ddd7 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 2 Dec 2024 16:33:41 +0100 Subject: [PATCH] Fix boundary masking bug for static features --- neural_lam/models/ar_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 1a24136..ceadb85 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -69,12 +69,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask[:, 0].to(torch.bool)], + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, )