Group metrics by labels #245

Status: Open · wants to merge 6 commits into base: master
40 changes: 38 additions & 2 deletions nucleus/metrics/base.py
@@ -1,7 +1,7 @@
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterable, List
from typing import Dict, Iterable, List

from nucleus.annotation import AnnotationList
from nucleus.prediction import PredictionList
@@ -10,6 +10,16 @@
class MetricResult(ABC):
"""Base MetricResult class"""

@property
@abstractmethod
def results(self) -> Dict[str, float]:
"""Interface for item results"""

@property
def extra_info(self) -> Dict[str, str]:
"""Overload this to pass extra info about the item to show in the UI"""
return {}


@dataclass
class ScalarResult(MetricResult):
@@ -27,6 +37,14 @@ class ScalarResult(MetricResult):
value: float
weight: float = 1.0

@property
def results(self) -> Dict[str, float]:
return {"value": self.value}

@property
def extra_info(self) -> Dict[str, str]:
return {"weight:": str(self.weight)}

@staticmethod
def aggregate(results: Iterable["ScalarResult"]) -> "ScalarResult":
"""Aggregates results using a weighted average."""
@@ -37,6 +55,22 @@ def aggregate(results: Iterable["ScalarResult"]) -> "ScalarResult":
return ScalarResult(value, total_weight)


@dataclass
class GroupedScalarResult(MetricResult):
group_to_scalar: Dict[str, ScalarResult]

@property
def results(self) -> Dict[str, float]:
group_results = {
group: scalar.value
for group, scalar in self.group_to_scalar.items()
}
group_results["all_groups"] = ScalarResult.aggregate(
self.group_to_scalar.values()
).value
return group_results


class Metric(ABC):
"""Abstract class for defining a metric, which takes a list of annotations
and predictions and returns a scalar.
@@ -93,7 +127,9 @@ def __call__(
"""A metric must override this method and return a metric result, given annotations and predictions."""

@abstractmethod
def aggregate_score(self, results: List[MetricResult]) -> ScalarResult:
def aggregate_score(
self, results: List[MetricResult]
) -> Dict[str, ScalarResult]:
"""A metric must define how to aggregate results from single items to a single ScalarResult.

E.g. to calculate a R2 score with sklearn you could define a custom metric class ::
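For orientation, here is a small sketch (not part of the diff) of how the new `GroupedScalarResult.results` property flattens per-label scalars and adds a weighted aggregate under the `"all_groups"` key. It assumes this branch is installed so the import resolves; the values are made up.

```python
# Illustrative sketch only; values are made up.
from nucleus.metrics.base import GroupedScalarResult, ScalarResult

grouped = GroupedScalarResult(
    group_to_scalar={
        "car": ScalarResult(value=0.8, weight=2.0),     # two matched items
        "person": ScalarResult(value=0.5, weight=1.0),  # one matched item
    }
)

# Per-group values plus the weighted average across all groups:
# {'car': 0.8, 'person': 0.5, 'all_groups': 0.7}
print(grouped.results)
```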
41 changes: 30 additions & 11 deletions nucleus/metrics/categorization_metrics.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Set, Tuple, Union
from typing import Dict, List, Set, Tuple, Union

from sklearn.metrics import f1_score

@@ -33,16 +33,28 @@ class CategorizationResult(MetricResult):
predictions: List[CategoryPrediction]

@property
def value(self):
def results(self) -> Dict[str, float]:
annotation_labels = to_taxonomy_labels(self.annotations)
prediction_labels = to_taxonomy_labels(self.predictions)

# TODO: Change task.py interface such that we can return label matching
# NOTE: Returning 1 if all taxonomy labels match else 0
value = f1_score(
list(annotation_labels), list(prediction_labels), average="macro"
)
return value
results = {
"f1_macro": f1_score(
list(annotation_labels),
list(prediction_labels),
average="macro",
)
}
return results

@property
def extra_info(self) -> Dict[str, str]:
annotation_labels = to_taxonomy_labels(self.annotations)
prediction_labels = to_taxonomy_labels(self.predictions)
return {
"annotations": ", ".join(annotation_labels),
"predictions": ", ".join(prediction_labels),
}


class CategorizationMetric(Metric):
@@ -80,7 +92,7 @@ def eval(
pass

@abstractmethod
def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult: # type: ignore[override]
def aggregate_score(self, results: List[CategorizationResult]) -> Dict[str, ScalarResult]: # type: ignore[override]
pass

def __call__(
@@ -189,11 +201,18 @@ def eval(
annotations=annotations, predictions=predictions
)

def aggregate_score(self, results: List[CategorizationResult]) -> ScalarResult: # type: ignore[override]
def aggregate_score(self, results: List[CategorizationResult]) -> Dict[str, ScalarResult]: # type: ignore[override]
gt = []
predicted = []
for result in results:
gt.extend(list(to_taxonomy_labels(result.annotations)))
predicted.extend(list(to_taxonomy_labels(result.predictions)))
value = f1_score(gt, predicted, average=self.f1_method)
return ScalarResult(value)
aggregate_scores = {}
aggregate_scores["macro"] = f1_score(gt, predicted, average="macro")
aggregate_scores["weighted"] = f1_score(
gt, predicted, average="weighted"
)
return {
result_label: ScalarResult(val)
for result_label, val in aggregate_scores.items()
}
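The reworked `aggregate_score` now reports both averaging methods instead of a single `ScalarResult`. A standalone sketch of the underlying sklearn calls on toy labels (not the library API, just the aggregation idea):

```python
# Toy data standing in for the pooled ground-truth and predicted taxonomy labels.
from sklearn.metrics import f1_score

gt = ["cat", "dog", "dog", "bird"]
predicted = ["cat", "dog", "cat", "bird"]

aggregate_scores = {
    "macro": f1_score(gt, predicted, average="macro"),
    "weighted": f1_score(gt, predicted, average="weighted"),
}
# Roughly {'macro': 0.78, 'weighted': 0.75} for this toy data.
print(aggregate_scores)
```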
43 changes: 43 additions & 0 deletions nucleus/metrics/label_grouper.py
@@ -0,0 +1,43 @@
from typing import Any, List

import numpy as np
import pandas as pd


class LabelsGrouper:
def __init__(self, annotations_or_predictions_list: List[Any]):
self.items = annotations_or_predictions_list
if len(self.items) > 0:
assert hasattr(
self.items[0], "label"
), f"Expected items to have attribute 'label' found none on {repr(self.items[0])}"
self.codes, self.labels = pd.factorize(
[item.label for item in self.items]
)
self.group_idx = 0

def __iter__(self):
self.group_idx = 0
return self

def __next__(self):
if self.group_idx >= len(self.labels):
raise StopIteration
label = self.labels[self.group_idx]
label_items = list(
np.take(self.items, np.where(self.codes == self.group_idx)[0])
)
self.group_idx += 1
return label, label_items

def label_group(self, label: str) -> List[Any]:
if len(self.items) == 0:
return []
# np.where returns an index array; check its length explicitly so a missing
# label falls through to the empty result instead of an ambiguous truth test
idx = np.where(self.labels == label)[0]
if len(idx) > 0:
label_items = list(
np.take(self.items, np.where(self.codes == idx[0])[0])
)
return label_items
else:
return []
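A usage sketch for the new `LabelsGrouper` (assuming this branch is installed). `Item` is a hypothetical stand-in for any object carrying a `label` attribute, such as a `BoxAnnotation` or `BoxPrediction`.

```python
from dataclasses import dataclass

from nucleus.metrics.label_grouper import LabelsGrouper


@dataclass
class Item:  # hypothetical stand-in with the required `label` attribute
    label: str
    score: float


items = [Item("car", 0.9), Item("person", 0.7), Item("car", 0.4)]
grouper = LabelsGrouper(items)

# Iteration yields (label, items with that label) in first-seen order.
for label, group in grouper:
    print(label, [i.score for i in group])
# car [0.9, 0.4]
# person [0.7]

# Direct lookup of a single label group.
print(grouper.label_group("person"))  # [Item(label='person', score=0.7)]
```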
61 changes: 47 additions & 14 deletions nucleus/metrics/polygon_metrics.py
@@ -1,14 +1,16 @@
import sys
from abc import abstractmethod
from typing import List, Union
from collections import defaultdict
from typing import Dict, List, Union

import numpy as np

from nucleus.annotation import AnnotationList, BoxAnnotation, PolygonAnnotation
from nucleus.prediction import BoxPrediction, PolygonPrediction, PredictionList

from .base import Metric, ScalarResult
from .base import GroupedScalarResult, Metric, ScalarResult
from .filters import confidence_filter, polygon_label_filter
from .label_grouper import LabelsGrouper
from .metric_utils import compute_average_precision
from .polygon_utils import (
BoxOrPolygonAnnotation,
@@ -80,19 +82,44 @@ def eval(

def __init__(
self,
enforce_label_match: bool = False,
enforce_label_match: bool = True,
Review comment (Contributor): :nit: please adjust comment below

confidence_threshold: float = 0.0,
):
"""Initializes PolygonMetric abstract object.

Args:
enforce_label_match: whether to enforce that annotation and prediction labels must match. Default False
enforce_label_match: whether to enforce that annotation and prediction labels must match. Default True
confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
"""
self.enforce_label_match = enforce_label_match
assert 0 <= confidence_threshold <= 1
self.confidence_threshold = confidence_threshold

def eval_grouped(
self,
annotations: List[Union[BoxAnnotation, PolygonAnnotation]],
predictions: List[Union[BoxPrediction, PolygonPrediction]],
) -> GroupedScalarResult:
grouped_annotations = LabelsGrouper(annotations)
grouped_predictions = LabelsGrouper(predictions)
results = {}
for label, label_annotations in grouped_annotations:
# TODO(gunnar): Enforce label match -> Why is that a parameter? Should we generally allow IOU matches
# between different labels?!?
Review comment (Contributor) on lines +107 to +108: In general we should have an option to allow this. E.g. you need to compute matches across the classes for the confusion matrix.

match_predictions = (
grouped_predictions.label_group(label)
if self.enforce_label_match
else predictions
)
eval_fn = label_match_wrapper(self.eval)
result = eval_fn(
label_annotations,
match_predictions,
enforce_label_match=self.enforce_label_match,
)
results[label] = result
return GroupedScalarResult(group_to_scalar=results)

@abstractmethod
def eval(
self,
@@ -102,12 +129,20 @@ def eval(
# Main evaluation function that subclasses must override.
pass

def aggregate_score(self, results: List[ScalarResult]) -> ScalarResult: # type: ignore[override]
return ScalarResult.aggregate(results)
def aggregate_score(self, results: List[GroupedScalarResult]) -> Dict[str, ScalarResult]: # type: ignore[override]
label_to_values = defaultdict(list)
for item_result in results:
for label, label_result in item_result.group_to_scalar.items():
label_to_values[label].append(label_result)
scores = {
label: ScalarResult.aggregate(values)
for label, values in label_to_values.items()
}
return scores

def __call__(
self, annotations: AnnotationList, predictions: PredictionList
) -> ScalarResult:
) -> GroupedScalarResult:
if self.confidence_threshold > 0:
predictions = confidence_filter(
predictions, self.confidence_threshold
@@ -119,11 +154,9 @@ def __call__(
polygon_predictions.extend(predictions.box_predictions)
polygon_predictions.extend(predictions.polygon_predictions)

eval_fn = label_match_wrapper(self.eval)
result = eval_fn(
result = self.eval_grouped(
polygon_annotations,
polygon_predictions,
enforce_label_match=self.enforce_label_match,
)
return result

@@ -166,7 +199,7 @@ class PolygonIOU(PolygonMetric):
# TODO: Remove defaults once these are surfaced more cleanly to users.
def __init__(
self,
enforce_label_match: bool = False,
enforce_label_match: bool = True,
iou_threshold: float = 0.0,
confidence_threshold: float = 0.0,
):
@@ -234,7 +267,7 @@ class PolygonPrecision(PolygonMetric):
# TODO: Remove defaults once these are surfaced more cleanly to users.
def __init__(
self,
enforce_label_match: bool = False,
enforce_label_match: bool = True,
iou_threshold: float = 0.5,
confidence_threshold: float = 0.0,
):
@@ -303,7 +336,7 @@ class PolygonRecall(PolygonMetric):
# TODO: Remove defaults once these are surfaced more cleanly to users.
def __init__(
self,
enforce_label_match: bool = False,
enforce_label_match: bool = True,
iou_threshold: float = 0.5,
confidence_threshold: float = 0.0,
):
@@ -460,7 +493,7 @@ def __init__(
0 <= iou_threshold <= 1
), "IoU threshold must be between 0 and 1."
self.iou_threshold = iou_threshold
super().__init__(enforce_label_match=False, confidence_threshold=0)
super().__init__(enforce_label_match=True, confidence_threshold=0)

def eval(
self,
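To illustrate the new flow: `__call__` now returns a `GroupedScalarResult` per item, and `aggregate_score` merges them into one weighted-average `ScalarResult` per label. A hedged sketch with hand-built per-item results standing in for real `PolygonIOU` evaluations (assuming this branch is installed):

```python
from nucleus.metrics.base import GroupedScalarResult, ScalarResult
from nucleus.metrics.polygon_metrics import PolygonIOU

metric = PolygonIOU()  # enforce_label_match now defaults to True

# Hand-built stand-ins for what metric(annotations, predictions) would return per item.
item_results = [
    GroupedScalarResult(
        group_to_scalar={
            "car": ScalarResult(0.9, weight=3),
            "person": ScalarResult(0.6, weight=1),
        }
    ),
    GroupedScalarResult(group_to_scalar={"car": ScalarResult(0.7, weight=1)}),
]

scores = metric.aggregate_score(item_results)
# One weighted-average ScalarResult per label, e.g. car -> 0.85, person -> 0.6
for label, scalar in scores.items():
    print(label, scalar.value)
```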
2 changes: 1 addition & 1 deletion nucleus/metrics/polygon_utils.py
@@ -273,7 +273,7 @@ def wrapper(
annotations: List[BoxOrPolygonAnnotation],
predictions: List[BoxOrPolygonPrediction],
*args,
enforce_label_match: bool = False,
enforce_label_match: bool = True,
**kwargs,
) -> ScalarResult:
# Simply return the metric if we are not enforcing label matches.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ click = ">=7.1.2,<9.0" # NOTE: COLAB has 7.1.2 and has problems updating
rich = "^10.15.2"
shellingham = "^1.4.0"
scikit-learn = ">=0.24.0"
pandas = ">=1.0"

[tool.poetry.dev-dependencies]
poetry = "^1.1.5"
5 changes: 3 additions & 2 deletions tests/metrics/test_categorization_metrics.py
@@ -29,9 +29,10 @@ def test_perfect_match_f1_score():
)
)

assert results
assert [res.value for res in results]
aggregate_result = metric.aggregate_score(results)
assert aggregate_result.value == 1
for result_label, scalar in aggregate_result.items():
assert scalar.value == 1


def test_no_match_f1_score():
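The assertion pattern above generalizes to any dict-valued aggregate. A standalone sketch with a hand-built result dict in place of `metric.aggregate_score(results)` (hypothetical values, assuming this branch is installed):

```python
from nucleus.metrics.base import ScalarResult

aggregate_result = {"macro": ScalarResult(1.0), "weighted": ScalarResult(1.0)}

for result_label, scalar in aggregate_result.items():
    assert scalar.value == 1, f"{result_label} should be a perfect score"
```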