Skip to content

Commit

Permalink
fix: fix index issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean1572 committed Jul 11, 2024
1 parent c9a4606 commit 41aef87
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyha_analyzer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
cfg: config.Config,
onehot:bool = False,
) -> None:
self.samples = df[~(df[cfg.file_name_col].isnull())]
self.samples = df[~(df[cfg.file_name_col].isnull())].reset_index()
if onehot:
if self.samples.iloc[0][species].shape[0] != len(species):
logger.error(species)
Expand Down Expand Up @@ -279,7 +279,7 @@ def __getitem__(self, index): #-> Any:
target = self.samples.loc[index, self.classes].values.astype(np.int32)
target = torch.Tensor(target)

return image, target
return image, target, index

def get_num_classes(self) -> int:
""" Returns number of classes
Expand Down
6 changes: 3 additions & 3 deletions pyha_analyzer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def train(model: TimmModel,
log_pred = []
log_labels = []

for i, (mels, labels) in enumerate(data_loader):
for i, (mels, labels, idxs) in enumerate(data_loader):

optimizer.zero_grad()

Expand Down Expand Up @@ -209,7 +209,7 @@ def valid(model: Any,
dl_iter = tqdm(data_loader, position=5, total=num_valid_samples)

with torch.no_grad():
for index, (mels, labels) in enumerate(dl_iter):
for index, (mels, labels, idxs) in enumerate(dl_iter):
if index > num_valid_samples:
# Stop early if not doing full validation
break
Expand Down Expand Up @@ -271,7 +271,7 @@ def inference_valid(model: Any,
dl_iter = tqdm(data_loader, position=5, total=num_valid_samples)

with torch.no_grad():
for _, (mels, labels) in enumerate(dl_iter):
for _, (mels, labels, idxs) in enumerate(dl_iter):
_, outputs = run_batch(model, mels, labels)
log_pred.append(torch.clone(outputs.cpu()).detach())
log_label.append(torch.clone(labels.cpu()).detach())
Expand Down

0 comments on commit 41aef87

Please sign in to comment.