Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MatchingPolicy #116

Merged
merged 3 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions perception_eval/perception_eval/common/label/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .types import LabelType
from .types import SemanticLabel
from .types import TrafficLightLabel
from .utils import is_same_label
from .utils import set_target_lists

__all__ = (
Expand All @@ -13,5 +14,6 @@
"LabelType",
"SemanticLabel",
"TrafficLightLabel",
"is_same_label",
"set_target_lists",
)
15 changes: 15 additions & 0 deletions perception_eval/perception_eval/common/label/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from perception_eval.object import ObjectType

from .converter import LabelConverter
from .types import LabelType

Expand Down Expand Up @@ -32,3 +34,16 @@ def set_target_lists(
if target_labels is None or len(target_labels) == 0:
return [label for label in label_converter.label_type]
return [label_converter.convert_name(name) for name in target_labels]


def is_same_label(object1: ObjectType, object2: ObjectType) -> bool:
"""Return `True`, if both objects has the same label.

Args:
object1 (ObjectType): An object.
object2 (ObjectType): An object.

Returns:
bool: `True` if the label is same.
"""
return object1.semantic_label == object2.semantic_label
17 changes: 17 additions & 0 deletions perception_eval/perception_eval/common/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from enum import Enum
import logging
from typing import Dict
from typing import TYPE_CHECKING
from typing import Union

from perception_eval.common.evaluation_task import EvaluationTask

if TYPE_CHECKING:
from perception_eval.object import ObjectType


class FrameID(Enum):
# 3D
Expand Down Expand Up @@ -91,6 +95,19 @@ def from_task(cls, task: Union[str, EvaluationTask]) -> FrameID:
raise ValueError(f"Unexpected task: {task}")


def is_same_frame_id(object1: ObjectType, object2: ObjectType) -> bool:
"""Returns `True` if the both objects has same frame id.

Args:
object1 (ObjectType): An object.
object2 (ObjectType): An object.

Returns:
bool: `True` if the frame id is same.
"""
return object1.frame_id == object2.frame_id


class Visibility(Enum):
"""Visibility status class.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Tuple
from typing import Union

from perception_eval.matching import MatchingPolicy
from perception_eval.metrics import MetricsScoreConfig

from .evaluation_config_base import EvaluationConfigBase
Expand Down Expand Up @@ -84,6 +85,8 @@ def __init__(
load_raw_data=load_raw_data,
)

self.matching_policy = MatchingPolicy.from_dict(config_dict)

self.metrics_config = MetricsScoreConfig(self.metrics_param)

def _extract_params(self, cfg: Dict[str, Any]) -> Tuple[PerceptionFilterParam, PerceptionMetricsParam]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from perception_eval.common.label import LabelType
from perception_eval.config import PerceptionEvaluationConfig
from perception_eval.dataset import FrameGroundTruth
from perception_eval.matching import MatchingPolicy
from perception_eval.object import ObjectType
from perception_eval.result import DynamicObjectWithPerceptionResult
from perception_eval.result import PerceptionFrameConfig
Expand Down Expand Up @@ -66,6 +67,10 @@ def __init__(self, config: PerceptionEvaluationConfig) -> None:
def target_labels(self) -> List[LabelType]:
return self.config.target_labels

@property
def matching_policy(self) -> MatchingPolicy:
return self.config.matching_policy

@property
def metrics_config(self):
return self.config.metrics_config
Expand Down Expand Up @@ -170,8 +175,7 @@ def _filter_objects(
estimated_objects=estimated_objects,
ground_truth_objects=frame_ground_truth.objects,
target_labels=self.target_labels,
# allow_matching_unknown=self.label_param["allow_matching_unknown"], TODO
# matchable_thresholds=self.filtering_params["max_matchable_radii"],
matching_policy=self.matching_policy,
)

if self.filter_param.target_uuids is not None:
Expand Down
2 changes: 2 additions & 0 deletions perception_eval/perception_eval/matching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .matching_policy import MatchingPolicy
from .object_matching import CenterDistanceMatching
from .object_matching import IOU2dMatching
from .object_matching import IOU3dMatching
Expand All @@ -20,6 +21,7 @@
from .object_matching import PlaneDistanceMatching

__all__ = (
"MatchingPolicy",
"CenterDistanceMatching",
"IOU2dMatching",
"IOU3dMatching",
Expand Down
122 changes: 122 additions & 0 deletions perception_eval/perception_eval/matching/matching_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from enum import Enum
from numbers import Number
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from perception_eval.common.label import is_same_label
from perception_eval.common.schema import is_same_frame_id
from perception_eval.common.threshold import get_label_threshold

from .object_matching import CenterDistanceMatching
from .object_matching import IOU2dMatching
from .object_matching import IOU3dMatching
from .object_matching import MatchingMode
from .object_matching import PlaneDistanceMatching

if TYPE_CHECKING:
from perception_eval.common.label import LabelType
from perception_eval.object import ObjectType

from .object_matching import MatchingMethod


class MatchingLabelPolicy(Enum):
STRICT = "STRICT"
ALLOW_UNKNOWN = "ALLOW_UNKNOWN"
ALLOW_ANY = "ALLOW_ANY"

@classmethod
def from_str(cls, name: str) -> MatchingLabelPolicy:
name = name.upper()
assert name in cls.__members__, f"{name} is not enum member"
return cls.__members__[name]

def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
if ground_truth.semantic_label.is_fp() or self == MatchingLabelPolicy.ALLOW_ANY:
return True
elif self == MatchingLabelPolicy.ALLOW_UNKNOWN:
return is_same_label(estimation, ground_truth) or estimation.semantic_label.is_unknown()
else: # STRICT
return is_same_label(estimation, ground_truth)


class MatchingPolicy:
def __init__(
self,
matching_mode: Optional[Union[str, MatchingMode]] = None,
label_policy: Optional[Union[str, MatchingLabelPolicy]] = None,
matchable_thresholds: Optional[List[Number]] = None,
) -> None:
if matching_mode is None:
self.matching_mode = MatchingMode.CENTERDISTANCE
elif isinstance(matching_mode, str):
self.matching_mode = MatchingMode.from_str(matching_mode)
else:
self.matching_mode = matching_mode

self.matching_module, self.maximize = self.get_matching_module(self.matching_mode)

if label_policy is None:
self.label_policy = MatchingLabelPolicy.STRICT
elif isinstance(label_policy, str):
self.label_policy = MatchingLabelPolicy.from_str(label_policy)
else:
self.label_policy = label_policy

self.matchable_thresholds = matchable_thresholds

@classmethod
def from_dict(cls, cfg: Dict[str, Any]) -> MatchingPolicy:
matching_mode = cfg.get("matching_mode")
label_policy = cfg.get("matching_label_policy")
matchable_thresholds = cfg.get("matchable_thresholds")
return cls(matching_mode=matching_mode, label_policy=label_policy, matchable_thresholds=matchable_thresholds)

@staticmethod
def get_matching_module(matching_mode: MatchingMode) -> Tuple[Callable, bool]:
if matching_mode == MatchingMode.CENTERDISTANCE:
matching_method_module: CenterDistanceMatching = CenterDistanceMatching
maximize: bool = False
elif matching_mode == MatchingMode.PLANEDISTANCE:
matching_method_module: PlaneDistanceMatching = PlaneDistanceMatching
maximize: bool = False
elif matching_mode == MatchingMode.IOU2D:
matching_method_module: IOU2dMatching = IOU2dMatching
maximize: bool = True
elif matching_mode == MatchingMode.IOU3D:
matching_method_module: IOU3dMatching = IOU3dMatching
maximize: bool = True
else:
raise ValueError(f"Unsupported matching mode: {matching_mode}")

return matching_method_module, maximize

def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
return self.label_policy.is_matchable(estimation, ground_truth) and is_same_frame_id(estimation, ground_truth)

def get_matching_score(
self,
estimation: ObjectType,
ground_truth: ObjectType,
target_labels: List[LabelType],
) -> Optional[float]:
threshold: Optional[float] = get_label_threshold(
ground_truth.semantic_label,
target_labels,
self.matchable_thresholds,
)

matching_method: MatchingMethod = self.matching_module(estimation, ground_truth)

if threshold is None or (threshold is not None and matching_method.is_better_than(threshold)):
return matching_method.value
else:
return None
5 changes: 5 additions & 0 deletions perception_eval/perception_eval/matching/object_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class MatchingMode(Enum):
def __str__(self) -> str:
return self.value

@classmethod
def from_str(cls, name: str) -> MatchingMode:
assert name in cls.__members__, f"{name} is not enum member"
return cls.__members__[name]


class MatchingMethod(ABC):
"""A base class for matching method class.
Expand Down
Loading
Loading