From d3ea1031237544850222c9fb2fe10baabaaf93dc Mon Sep 17 00:00:00 2001 From: ATATC Date: Fri, 13 Sep 2024 23:32:02 -0400 Subject: [PATCH] Routine updates. --- dataset/cxr_datamodule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dataset/cxr_datamodule.py b/dataset/cxr_datamodule.py index 9bb7492..4d1be5c 100644 --- a/dataset/cxr_datamodule.py +++ b/dataset/cxr_datamodule.py @@ -29,7 +29,9 @@ def setup(self, stage): # split train/val msss = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=self.cfg["val_split"], random_state=self.cfg["seed"]) - print(self.cfg["classes"], self.df[self.cfg["classes"]]) + print(self.cfg["classes"]) + print(self.df) + print(self.df[self.cfg["classes"]]) train_idx, val_idx = next(msss.split(self.df, self.df[self.cfg["classes"]].values)) train_df = self.df.iloc[train_idx] val_df = self.df.iloc[val_idx]