diff --git a/src/utils/datasets/corrupt_label_dataset.py b/src/utils/datasets/corrupt_label_dataset.py index f9802d19..c0a0a350 100644 --- a/src/utils/datasets/corrupt_label_dataset.py +++ b/src/utils/datasets/corrupt_label_dataset.py @@ -26,7 +26,8 @@ def __init__( assert hasattr(dataset, "device") self.device = dataset.device - if IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_ids") and IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_labels"): + if IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_ids") and \ + IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_labels"): self.corrupt_indices = IC.load(path=cache_path, file_id=f"{dataset_id}_corrupt_ids") self.corrupt_labels = IC.load(path=cache_path, file_id=f"{dataset_id}_corrupt_labels", device=self.device) else: