Skip to content

Commit

Permalink
Routine updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
ATATC committed Sep 14, 2024
1 parent 6e454c1 commit 1ebda1e
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions dataset/cxr_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, datamodule_cfg, dataloader_init_args):
self.dataloader_init_args = dataloader_init_args
if self.cfg["use_pseudo_label"]:
print("Using pseudo label")
self.vin_df = pd.read_csv(self.cfg["vinbig_pseudo_train_df_path"])
# self.vin_df = pd.read_csv(self.cfg["vinbig_pseudo_train_df_path"])
self.nih_df = pd.read_csv(self.cfg["nih_pseudo_train_df_path"])
self.chexpert_df = pd.read_csv(self.cfg["chexpert_pseudo_train_df_path"])

Expand All @@ -37,13 +37,14 @@ def setup(self, stage):
self.val_dataset = CxrStudyIdDataset(self.cfg, val_df, transforms_val)

if self.cfg["use_pseudo_label"]:
vin_dataset = VinDataset(self.cfg, self.vin_df, transforms_train)
# vin_dataset = VinDataset(self.cfg, self.vin_df, transforms_train)
nih_dataset = NihDataset(self.cfg, self.nih_df, transforms_train)
chexpert_dataset = ChexpertDataset(self.cfg, self.chexpert_df, transforms_train)
print(f"vin len: {len(vin_dataset)}")
# print(f"vin len: {len(vin_dataset)}")
print(f"nih len: {len(nih_dataset)}")
print(f"chexpert len: {len(chexpert_dataset)}")
self.train_dataset = ConcatDataset([self.train_dataset, vin_dataset, nih_dataset, chexpert_dataset])
# self.train_dataset = ConcatDataset([self.train_dataset, vin_dataset, nih_dataset, chexpert_dataset])
self.train_dataset = ConcatDataset([self.train_dataset, nih_dataset, chexpert_dataset])

print(f"train len: {len(self.train_dataset)}")
print(f"val len: {len(self.val_dataset)}")
Expand Down

0 comments on commit 1ebda1e

Please sign in to comment.