diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index 27f847cb..a00303b5 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -82,41 +82,55 @@ def apply(self, data, vars) -> dict[dict[pl.DataFrame]]: Returns: Preprocessed data. """ - logging.info("Preprocessing dynamic features.") - - data = self._process_dynamic(data, vars) - if self.use_static_features and len(vars[Segment.static]) > 0: + # Check if dynamic features are present + if self.use_static_features and all(Segment.static in value for value in data.values()) and len(vars[Segment.static]) > 0: logging.info("Preprocessing static features.") data = self._process_static(data, vars) + else: + self.use_static_features = False + + if all(Segment.dynamic in value for value in data.values()): + logging.info("Preprocessing dynamic features.") + logging.info(data.keys()) + data = self._process_dynamic(data, vars) + if self.use_static_features: + # Join static and dynamic data. + data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join( + data[Split.train][Segment.static], on=vars["GROUP"] + ) + data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join( + data[Split.val][Segment.static], on=vars["GROUP"] + ) + data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join( + data[Split.test][Segment.static], on=vars["GROUP"] + ) - # Set index to grouping variable - data[Split.train][Segment.static] = data[Split.train][Segment.static]#.set_index(vars["GROUP"]) - data[Split.val][Segment.static] = data[Split.val][Segment.static]#.set_index(vars["GROUP"]) - data[Split.test][Segment.static] = data[Split.test][Segment.static]#.set_index(vars["GROUP"]) + # Remove static features from splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) - # Join static and dynamic data. - data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join( - data[Split.train][Segment.static], on=vars["GROUP"] - ) - data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join( - data[Split.val][Segment.static], on=vars["GROUP"] - ) - data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join( - data[Split.test][Segment.static], on=vars["GROUP"] - ) - - # Remove static features from splits + # Create feature splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) + elif self.use_static_features: data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) - - # Create feature splits - data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) - data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) - data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) - + else: + raise Exception(f"No recognized data segments data to preprocess. Available: {data.keys()}") logging.debug("Data head") logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome]) + for split in [Split.train, Split.val, Split.test]: + if vars["SEQUENCE"] in data[split][Segment.outcome] and len(data[split][Segment.features]) != len(data[split][Segment.outcome]): + raise Exception(f"Data and outcome length mismatch in {split} split: " + f"features: {len(data[split][Segment.features])}, outcome: {len(data[split][Segment.outcome])}") + data[Split.train][Segment.features] = data[Split.train][Segment.features].unique() + data[Split.val][Segment.features] = data[Split.val][Segment.features].unique() + data[Split.test][Segment.features] = data[Split.test][Segment.features].unique() + logging.info(f"Generate features: {self.generate_features}") return data @@ -321,6 +335,7 @@ def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: logging.debug("Data head") logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome].head()) logging.info(f"Generate features: {self.generate_features}") return data