diff --git a/dataset/cxr_datamodule.py b/dataset/cxr_datamodule.py index b721a0a..9e4c84b 100644 --- a/dataset/cxr_datamodule.py +++ b/dataset/cxr_datamodule.py @@ -19,7 +19,7 @@ def __init__(self, datamodule_cfg, dataloader_init_args): self.dataloader_init_args = dataloader_init_args if self.cfg["use_pseudo_label"]: print("Using pseudo label") - self.vin_df = pd.read_csv(self.cfg["vinbig_pseudo_train_df_path"]) + # self.vin_df = pd.read_csv(self.cfg["vinbig_pseudo_train_df_path"]) self.nih_df = pd.read_csv(self.cfg["nih_pseudo_train_df_path"]) self.chexpert_df = pd.read_csv(self.cfg["chexpert_pseudo_train_df_path"]) @@ -37,13 +37,14 @@ def setup(self, stage): self.val_dataset = CxrStudyIdDataset(self.cfg, val_df, transforms_val) if self.cfg["use_pseudo_label"]: - vin_dataset = VinDataset(self.cfg, self.vin_df, transforms_train) + # vin_dataset = VinDataset(self.cfg, self.vin_df, transforms_train) nih_dataset = NihDataset(self.cfg, self.nih_df, transforms_train) chexpert_dataset = ChexpertDataset(self.cfg, self.chexpert_df, transforms_train) - print(f"vin len: {len(vin_dataset)}") + # print(f"vin len: {len(vin_dataset)}") print(f"nih len: {len(nih_dataset)}") print(f"chexpert len: {len(chexpert_dataset)}") - self.train_dataset = ConcatDataset([self.train_dataset, vin_dataset, nih_dataset, chexpert_dataset]) + # self.train_dataset = ConcatDataset([self.train_dataset, vin_dataset, nih_dataset, chexpert_dataset]) + self.train_dataset = ConcatDataset([self.train_dataset, nih_dataset, chexpert_dataset]) print(f"train len: {len(self.train_dataset)}") print(f"val len: {len(self.val_dataset)}")