Skip to content

Commit

Permalink
Fix verif leadtime hours and misc
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Jan 20, 2025
1 parent 9e76ecb commit 00e348b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
9 changes: 4 additions & 5 deletions bris/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def __init__(
self.forecast_length = forecast_length
self.latitudes = data_reader.latitudes
self.longitudes = data_reader.longitudes
# self.required_variables = required_variables[0] # Assume we only have one decoder

# this makes it backwards compatible with older
# anemoi-models versions. I.e legendary gnome, etc..
Expand All @@ -134,7 +133,7 @@ def __init__(
self.release_cache = release_cache


def set_static_forcings(self, data_reader, selection):
def set_static_forcings(self, data_reader, selection) -> None:

self.static_forcings = {}
data = torch.from_numpy(data_reader[0].squeeze(axis=1).swapaxes(0,1))
Expand Down Expand Up @@ -166,7 +165,7 @@ def set_static_forcings(self, data_reader, selection):

del data_normalized

def set_variable_indices(self, required_variables: list):
def set_variable_indices(self, required_variables: list) -> None:
required_variables = required_variables[0] #Assume one decoder
variable_indices_input = list()
variable_indices_output = list()
Expand All @@ -183,7 +182,7 @@ def set_variable_indices(self, required_variables: list):
def forward(self, x: torch.Tensor)-> torch.Tensor:
return self.model(x, self.model_comm_group)

def advance_input_predict(self, x, y_pred, time):
def advance_input_predict(self, x: torch.Tensor, y_pred: torch.Tensor, time: np.datetime64) -> torch.Tensor:
x = x.roll(-1, dims=1)

#Get prognostic variables:
Expand All @@ -200,7 +199,7 @@ def advance_input_predict(self, x, y_pred, time):
return x

@torch.inference_mode
def predict_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
def predict_step(self, batch: tuple, batch_idx: int) -> dict:

multistep = self.metadata.config.training.multistep_input

Expand Down
16 changes: 8 additions & 8 deletions bris/outputs/verif.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
)

self.ipoints = gridpp.Points(self.pm.lats, self.pm.lons, self.pm.altitudes)
self.ipoints_tuple = np.column_stack((self.pm.lats, self.pm.lons))
self.ipoints_array = np.column_stack((self.pm.lats, self.pm.lons))
self.ialtitudes = self.pm.altitudes

obs_lats = list()
Expand All @@ -78,13 +78,13 @@ def __init__(
self.obs_altitudes = np.array(obs_altitudes, np.float32)
self.obs_ids = np.array(obs_ids, np.int32)
self.opoints = gridpp.Points(self.obs_lats, self.obs_lons, self.obs_altitudes)
self.opoints_tuple = np.column_stack((self.obs_lats, self.obs_lons))
self.opoints_array = np.column_stack((self.obs_lats, self.obs_lons))

self.triangulation = self.ipoints_tuple
if not self._is_gridded_input and self.ipoints_tuple.shape[0] > 3:
self.triangulation = self.ipoints_array
if not self._is_gridded_input and self.ipoints_array.shape[0] > 3:
# This speeds up interpolation from irregular points to observation points
# but Delaunay needs enough points for this to work
self.triangulation = Delaunay(self.ipoints_tuple)
self.triangulation = Delaunay(self.ipoints_array)

# The intermediate will only store the final output locations
intermediate_pm = PredictMetadata(
Expand Down Expand Up @@ -129,7 +129,7 @@ def _add_forecast(self, times: list, ensemble_member: int, pred: np.array):
interpolator = scipy.interpolate.LinearNDInterpolator(
self.triangulation, self.ialtitudes
)
interpolated_altitudes = interpolator(self.opoints_tuple)
interpolated_altitudes = interpolator(self.opoints_array)
altitude_correction = self.opoints.get_elevs() - interpolated_altitudes

num_leadtimes = pred.shape[0]
Expand All @@ -142,7 +142,7 @@ def _add_forecast(self, times: list, ensemble_member: int, pred: np.array):
interpolator = scipy.interpolate.LinearNDInterpolator(
self.triangulation, pred[lt, :, 0]
)
interpolated_pred[lt, :, 0] = interpolator(self.opoints_tuple)
interpolated_pred[lt, :, 0] = interpolator(self.opoints_array)
if altitude_correction is not None:
interpolated_pred[lt, :, 0] += (
self.elev_gradient * altitude_correction
Expand Down Expand Up @@ -183,7 +183,7 @@ def finalize(self):
coords["time"] = (["time"], [], cf.get_attributes("time"))
coords["leadtime"] = (
["leadtime"],
self.intermediate.pm.leadtimes.astype(np.float32),
self.intermediate.pm.leadtimes.astype(np.float32) / 3600,
{"units": "hour"},
)
coords["location"] = (["location"], self.obs_ids)
Expand Down

0 comments on commit 00e348b

Please sign in to comment.