Skip to content

Commit

Permalink
Compatibility with static (non temporal) datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Oct 15, 2024
1 parent 786e256 commit 125e4f0
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions icu_benchmarks/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ def __init__(
# data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1)
# )
# Get the row indicators for the data to be able to match predicted labels
self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]]
self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours())
self.features_df = data[split][Segment.features].drop(self.vars["SEQUENCE"])
if "SEQUENCE" in self.vars and self.vars["SEQUENCE"] in data[split][Segment.features].columns:
# We have a time series dataset
self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]]
self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours())
self.features_df = data[split][Segment.features]
self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
self.grouping_df = self.grouping_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
else:
# We have a static dataset
logging.info("Using static dataset")
self.row_indicators = data[split][Segment.features][self.vars["GROUP"]]
self.features_df = data[split][Segment.features]
# calculate basic info for the data
self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0]
self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1)
Expand Down Expand Up @@ -151,12 +161,19 @@ def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]:
"""
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float)
rep = self.features_df

if len(labels) == self.num_stays:
# order of groups could be random, we make sure not to change it
# rep = rep.groupby(level=self.vars["GROUP"], sort=False).last()
rep = rep.group_by(self.vars["GROUP"]).last()
else:
# Adding segment count for each stay id and timestep.
rep = rep.with_columns(
pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")
)
rep = rep.to_numpy().astype(float)

logging.debug(f"rep shape: {rep.shape}")
logging.debug(f"labels shape: {labels.shape}")
return rep, labels, self.row_indicators.to_numpy()

def to_tensor(self):
Expand Down

0 comments on commit 125e4f0

Please sign in to comment.