From 86d2c3f0ae0ca0c053a20ff3ba07d3cc7b837efa Mon Sep 17 00:00:00 2001 From: luvargas2 Date: Thu, 25 Apr 2024 14:47:28 -0400 Subject: [PATCH 1/2] Add all other class confidences to the result --- PytorchWildlife/models/classification/resnet/amazon.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/PytorchWildlife/models/classification/resnet/amazon.py b/PytorchWildlife/models/classification/resnet/amazon.py index c427fd518..c82a0f03c 100644 --- a/PytorchWildlife/models/classification/resnet/amazon.py +++ b/PytorchWildlife/models/classification/resnet/amazon.py @@ -93,6 +93,8 @@ def results_generation(self, logits, img_ids, id_strip=None): probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] + confidences = probs[0].tolist() + result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] results = [] for pred, img_id, conf in zip(preds, img_ids, confs): @@ -100,6 +102,7 @@ def results_generation(self, logits, img_ids, id_strip=None): r["prediction"] = self.CLASS_NAMES[pred.item()] r["class_id"] = pred.item() r["confidence"] = conf.item() + r["all_confidences"] = result results.append(r) return results From c2841a316e2587f9ec0dfb9b0ea1b93c3a3907fb Mon Sep 17 00:00:00 2001 From: luvargas2 Date: Thu, 25 Apr 2024 15:08:44 -0400 Subject: [PATCH 2/2] Add all other class confidences to the result --- PytorchWildlife/models/classification/resnet/serengeti.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/PytorchWildlife/models/classification/resnet/serengeti.py b/PytorchWildlife/models/classification/resnet/serengeti.py index 3c32652e7..fa7671d51 100644 --- a/PytorchWildlife/models/classification/resnet/serengeti.py +++ b/PytorchWildlife/models/classification/resnet/serengeti.py @@ -67,6 +67,8 @@ def results_generation(self, logits, img_ids, id_strip=None): probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] + confidences = probs[0].tolist() + result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] results = [] for pred, img_id, conf in zip(preds, img_ids, confs): @@ -74,6 +76,7 @@ def results_generation(self, logits, img_ids, id_strip=None): r["prediction"] = self.CLASS_NAMES[pred.item()] r["class_id"] = pred.item() r["confidence"] = conf.item() + r["all_confidences"] = result results.append(r) return results