diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index a303a478..eb544d64 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -34,9 +34,19 @@ def __init__( # data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) # ) # Get the row indicators for the data to be able to match predicted labels - self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]] - self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) - self.features_df = data[split][Segment.features].drop(self.vars["SEQUENCE"]) + if "SEQUENCE" in self.vars and self.vars["SEQUENCE"] in data[split][Segment.features].columns: + # We have a time series dataset + self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]] + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) + self.features_df = data[split][Segment.features] + self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) + self.features_df = self.features_df.drop(self.vars["SEQUENCE"]) + self.grouping_df = self.grouping_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) + else: + # We have a static dataset + logging.info("Using static dataset") + self.row_indicators = data[split][Segment.features][self.vars["GROUP"]] + self.features_df = data[split][Segment.features] # calculate basic info for the data self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0] self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1) @@ -151,12 +161,19 @@ def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]: """ labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) rep = self.features_df + if len(labels) == self.num_stays: # order of groups could be random, we make sure not to change it # rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() rep = rep.group_by(self.vars["GROUP"]).last() + else: + # Adding segment count for each stay id and timestep. + rep = rep.with_columns( + pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter") + ) rep = rep.to_numpy().astype(float) - + logging.debug(f"rep shape: {rep.shape}") + logging.debug(f"labels shape: {labels.shape}") return rep, labels, self.row_indicators.to_numpy() def to_tensor(self):