diff --git a/dataset/cxr_datamodule.py b/dataset/cxr_datamodule.py index 9bb7492..4d1be5c 100644 --- a/dataset/cxr_datamodule.py +++ b/dataset/cxr_datamodule.py @@ -29,7 +29,9 @@ def setup(self, stage): # split train/val msss = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=self.cfg["val_split"], random_state=self.cfg["seed"]) - print(self.cfg["classes"], self.df[self.cfg["classes"]]) + print(self.cfg["classes"]) + print(self.df) + print(self.df[self.cfg["classes"]]) train_idx, val_idx = next(msss.split(self.df, self.df[self.cfg["classes"]].values)) train_df = self.df.iloc[train_idx] val_df = self.df.iloc[val_idx]