Skip to content

Commit

Permalink
Preprocessing improvements and failsafe when using just static features
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Oct 15, 2024
1 parent 34e0573 commit ab3ab4b
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions icu_benchmarks/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,55 @@ def apply(self, data, vars) -> dict[dict[pl.DataFrame]]:
Returns:
Preprocessed data.
"""
logging.info("Preprocessing dynamic features.")

data = self._process_dynamic(data, vars)
if self.use_static_features and len(vars[Segment.static]) > 0:
# Check if dynamic features are present
if self.use_static_features and all(Segment.static in value for value in data.values()) and len(vars[Segment.static]) > 0:
logging.info("Preprocessing static features.")
data = self._process_static(data, vars)
else:
self.use_static_features = False

if all(Segment.dynamic in value for value in data.values()):
logging.info("Preprocessing dynamic features.")
logging.info(data.keys())
data = self._process_dynamic(data, vars)
if self.use_static_features:
# Join static and dynamic data.
data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join(
data[Split.train][Segment.static], on=vars["GROUP"]
)
data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join(
data[Split.val][Segment.static], on=vars["GROUP"]
)
data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join(
data[Split.test][Segment.static], on=vars["GROUP"]
)

# Set index to grouping variable
data[Split.train][Segment.static] = data[Split.train][Segment.static]#.set_index(vars["GROUP"])
data[Split.val][Segment.static] = data[Split.val][Segment.static]#.set_index(vars["GROUP"])
data[Split.test][Segment.static] = data[Split.test][Segment.static]#.set_index(vars["GROUP"])
# Remove static features from splits
data[Split.train][Segment.features] = data[Split.train].pop(Segment.static)
data[Split.val][Segment.features] = data[Split.val].pop(Segment.static)
data[Split.test][Segment.features] = data[Split.test].pop(Segment.static)

# Join static and dynamic data.
data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join(
data[Split.train][Segment.static], on=vars["GROUP"]
)
data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join(
data[Split.val][Segment.static], on=vars["GROUP"]
)
data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join(
data[Split.test][Segment.static], on=vars["GROUP"]
)

# Remove static features from splits
# Create feature splits
data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic)
data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic)
data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic)
elif self.use_static_features:
data[Split.train][Segment.features] = data[Split.train].pop(Segment.static)
data[Split.val][Segment.features] = data[Split.val].pop(Segment.static)
data[Split.test][Segment.features] = data[Split.test].pop(Segment.static)

# Create feature splits
data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic)
data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic)
data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic)

else:
raise Exception(f"No recognized data segments data to preprocess. Available: {data.keys()}")
logging.debug("Data head")
logging.debug(data[Split.train][Segment.features].head())
logging.debug(data[Split.train][Segment.outcome])
for split in [Split.train, Split.val, Split.test]:
if vars["SEQUENCE"] in data[split][Segment.outcome] and len(data[split][Segment.features]) != len(data[split][Segment.outcome]):
raise Exception(f"Data and outcome length mismatch in {split} split: "
f"features: {len(data[split][Segment.features])}, outcome: {len(data[split][Segment.outcome])}")
data[Split.train][Segment.features] = data[Split.train][Segment.features].unique()
data[Split.val][Segment.features] = data[Split.val][Segment.features].unique()
data[Split.test][Segment.features] = data[Split.test][Segment.features].unique()

logging.info(f"Generate features: {self.generate_features}")
return data

Expand Down Expand Up @@ -321,6 +335,7 @@ def apply(self, data, vars) -> dict[dict[pd.DataFrame]]:

logging.debug("Data head")
logging.debug(data[Split.train][Segment.features].head())
logging.debug(data[Split.train][Segment.outcome].head())
logging.info(f"Generate features: {self.generate_features}")
return data

Expand Down

0 comments on commit ab3ab4b

Please sign in to comment.