diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 7bc390f..9491f88 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -442,11 +442,11 @@ def train( train_ids, test_ids = train_test_split(list_ids, train_size=train_size) # Generators - training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids, device=self.device) + training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids) training_generator = torch.utils.data.DataLoader( training_set, batch_size=batch_size, shuffle=shuffle ) - val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids, device=self.device) + val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids) val_generator = torch.utils.data.DataLoader( val_set, batch_size=batch_size, shuffle=shuffle )