Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Apr 18, 2024
1 parent ae17352 commit 704ca82
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/utils/datasets/activation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, layer_dir, device="cpu"):
def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]:
assert idx < len(self.files), "Layer index is out of bounds!"
fl = self.files[idx]
av = torch.load(fl)
av = torch.load(fl, map_location=self.device)
return av

def __len__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/utils/datasets/mark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def __init__(
cache_path: str = "./datasets",
p: float = 0.3,
cls_to_mark: int = 2,
only_train: bool = False,
mark_fn: Optional[Union[Callable, str]]
mark_fn: Optional[Union[Callable, str]]=None,
only_train: bool = False
):
super().__init__()
self.dataset = dataset
Expand All @@ -31,7 +31,7 @@ def __init__(
self.mark_image=mark_fn
else:
self.mark_image=self.mark_image_contour_and_square

if IC.exists(path=cache_path, file_id=f"{dataset_id}_mark_ids"):
self.mark_indices = IC.load(path="./datasets", file_id=f"{dataset_id}_mark_ids")
else:
Expand Down

0 comments on commit 704ca82

Please sign in to comment.