From d355ef5030c80faa50e4f440cddee052597dd53b Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 12 Nov 2024 21:00:01 +0100 Subject: [PATCH] bugfix for earlier unstacking dim order fix in datastores --- neural_lam/datastore/base.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 25e7d01..418b3b3 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -466,19 +466,27 @@ def stack_grid_coords( da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) # find the feature dimension, which has named with the format # `{category}_feature` + + # ensure that grid_index is the first dimension, and the feature + # dimension is the second + dim_order = ["grid_index"] + potential_feature_dims = [ - d for d in da_or_ds_stacked.dims if d.endswith("_feature") + d for d in da_or_ds_stacked.dims if d.endswith("feature") ] - if not len(potential_feature_dims) == 1: + n_feature_dims = len(potential_feature_dims) + if n_feature_dims == 0: + pass + elif n_feature_dims == 1: + feature_dim = potential_feature_dims[0] + dim_order.append(feature_dim) + else: raise ValueError( "Expected exactly one feature dimension in the stacked data, " f"got {potential_feature_dims}" ) - feature_dim = potential_feature_dims[0] - # ensure that grid_index is the first dimension, and the feature - # dimension is the second - return da_or_ds_stacked.transpose("grid_index", feature_dim, ...) + return da_or_ds_stacked.transpose(*dim_order, ...) @property @functools.lru_cache