Skip to content

Commit

Permalink
feat: add MatchingPolicy
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Dec 29, 2023
1 parent 3d0658f commit 6af7a92
Show file tree
Hide file tree
Showing 11 changed files with 196 additions and 83 deletions.
2 changes: 2 additions & 0 deletions perception_eval/perception_eval/common/label/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .types import LabelType
from .types import SemanticLabel
from .types import TrafficLightLabel
from .utils import is_same_label
from .utils import set_target_lists

__all__ = (
Expand All @@ -13,5 +14,6 @@
"LabelType",
"SemanticLabel",
"TrafficLightLabel",
"is_same_label",
"set_target_lists",
)
15 changes: 15 additions & 0 deletions perception_eval/perception_eval/common/label/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from perception_eval.object import ObjectType

from .converter import LabelConverter
from .types import LabelType

Expand Down Expand Up @@ -32,3 +34,16 @@ def set_target_lists(
if target_labels is None or len(target_labels) == 0:
return [label for label in label_converter.label_type]
return [label_converter.convert_name(name) for name in target_labels]


def is_same_label(object1: ObjectType, object2: ObjectType) -> bool:
"""Return `True`, if both objects has the same label.
Args:
object1 (ObjectType): An object.
object2 (ObjectType): An object.
Returns:
bool: `True` if the label is same.
"""
return object1.semantic_label == object2.semantic_label
17 changes: 17 additions & 0 deletions perception_eval/perception_eval/common/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from enum import Enum
import logging
from typing import Dict
from typing import TYPE_CHECKING
from typing import Union

from perception_eval.common.evaluation_task import EvaluationTask

if TYPE_CHECKING:
from perception_eval.object import ObjectType


class FrameID(Enum):
# 3D
Expand Down Expand Up @@ -91,6 +95,19 @@ def from_task(cls, task: Union[str, EvaluationTask]) -> FrameID:
raise ValueError(f"Unexpected task: {task}")


def is_same_frame_id(object1: ObjectType, object2: ObjectType) -> bool:
"""Returns `True` if the both objects has same frame id.
Args:
object1 (ObjectType): An object.
object2 (ObjectType): An object.
Returns:
bool: `True` if the frame id is same.
"""
return object1.frame_id == object2.frame_id


class Visibility(Enum):
"""Visibility status class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Tuple
from typing import Union

from perception_eval.matching import MatchingPolicy
from perception_eval.metrics import MetricsScoreConfig

from .evaluation_config_base import EvaluationConfigBase
Expand Down Expand Up @@ -84,6 +85,8 @@ def __init__(
load_raw_data=load_raw_data,
)

self.matching_policy = MatchingPolicy.from_dict(config_dict)

self.metrics_config = MetricsScoreConfig(self.metrics_param)

def _extract_params(self, cfg: Dict[str, Any]) -> Tuple[PerceptionFilterParam, PerceptionMetricsParam]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from perception_eval.common.label import LabelType
from perception_eval.config import PerceptionEvaluationConfig
from perception_eval.dataset import FrameGroundTruth
from perception_eval.matching import MatchingPolicy
from perception_eval.object import ObjectType
from perception_eval.result import DynamicObjectWithPerceptionResult
from perception_eval.result import PerceptionFrameConfig
Expand Down Expand Up @@ -66,6 +67,10 @@ def __init__(self, config: PerceptionEvaluationConfig) -> None:
def target_labels(self) -> List[LabelType]:
return self.config.target_labels

@property
def matching_policy(self) -> MatchingPolicy:
return self.config.matching_policy

@property
def metrics_config(self):
return self.config.metrics_config
Expand Down Expand Up @@ -170,8 +175,7 @@ def _filter_objects(
estimated_objects=estimated_objects,
ground_truth_objects=frame_ground_truth.objects,
target_labels=self.target_labels,
# allow_matching_unknown=self.label_param["allow_matching_unknown"], TODO
# matchable_thresholds=self.filtering_params["max_matchable_radii"],
matching_policy=self.matching_policy,
)

if self.filter_param.target_uuids is not None:
Expand Down
2 changes: 2 additions & 0 deletions perception_eval/perception_eval/matching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .matching_policy import MatchingPolicy
from .object_matching import CenterDistanceMatching
from .object_matching import IOU2dMatching
from .object_matching import IOU3dMatching
Expand All @@ -20,6 +21,7 @@
from .object_matching import PlaneDistanceMatching

__all__ = (
"MatchingPolicy",
"CenterDistanceMatching",
"IOU2dMatching",
"IOU3dMatching",
Expand Down
122 changes: 122 additions & 0 deletions perception_eval/perception_eval/matching/matching_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from enum import Enum
from numbers import Number
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from perception_eval.common.label import is_same_label
from perception_eval.common.schema import is_same_frame_id
from perception_eval.common.threshold import get_label_threshold

from .object_matching import CenterDistanceMatching
from .object_matching import IOU2dMatching
from .object_matching import IOU3dMatching
from .object_matching import MatchingMode
from .object_matching import PlaneDistanceMatching

if TYPE_CHECKING:
from perception_eval.common.label import LabelType
from perception_eval.object import ObjectType

from .object_matching import MatchingMethod


class MatchingLabelPolicy(Enum):
STRICT = "STRICT"
ALLOW_UNKNOWN = "ALLOW_UNKNOWN"
ALLOW_ANY = "ALLOW_ANY"

@classmethod
def from_str(cls, name: str) -> MatchingLabelPolicy:
name = name.upper()
assert name in cls.__members__, f"{name} is not enum member"
return cls.__members__[name]

def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
if ground_truth.semantic_label.is_fp() or self == MatchingLabelPolicy.ALLOW_ANY:
return True
elif self == MatchingLabelPolicy.ALLOW_UNKNOWN:
return is_same_label(estimation, ground_truth) or estimation.semantic_label.is_unknown()
else: # STRICT
return is_same_label(estimation, ground_truth)


class MatchingPolicy:
def __init__(
self,
matching_mode: Optional[Union[str, MatchingMode]] = None,
label_policy: Optional[Union[str, MatchingLabelPolicy]] = None,
matchable_thresholds: Optional[List[Number]] = None,
) -> None:
if matching_mode is None:
self.matching_mode = MatchingMode.CENTERDISTANCE
elif isinstance(matching_mode, str):
self.matching_mode = MatchingMode.from_str(matching_mode)
else:
self.matching_mode = matching_mode

self.matching_module, self.maximize = self.get_matching_module(self.matching_mode)

if label_policy is None:
self.label_policy = MatchingLabelPolicy.STRICT
elif isinstance(label_policy, str):
self.label_policy = MatchingLabelPolicy.from_str(label_policy)
else:
self.label_policy = label_policy

self.matchable_thresholds = matchable_thresholds

@classmethod
def from_dict(cls, cfg: Dict[str, Any]) -> MatchingPolicy:
matching_mode = cfg.get("matching_mode")
label_policy = cfg.get("matching_label_policy")
matchable_thresholds = cfg.get("matchable_thresholds")
return cls(matching_mode=matching_mode, label_policy=label_policy, matchable_thresholds=matchable_thresholds)

@staticmethod
def get_matching_module(matching_mode: MatchingMode) -> Tuple[Callable, bool]:
if matching_mode == MatchingMode.CENTERDISTANCE:
matching_method_module: CenterDistanceMatching = CenterDistanceMatching
maximize: bool = False
elif matching_mode == MatchingMode.PLANEDISTANCE:
matching_method_module: PlaneDistanceMatching = PlaneDistanceMatching
maximize: bool = False
elif matching_mode == MatchingMode.IOU2D:
matching_method_module: IOU2dMatching = IOU2dMatching
maximize: bool = True
elif matching_mode == MatchingMode.IOU3D:
matching_method_module: IOU3dMatching = IOU3dMatching
maximize: bool = True
else:
raise ValueError(f"Unsupported matching mode: {matching_mode}")

return matching_method_module, maximize

def is_matchable(self, estimation: ObjectType, ground_truth: ObjectType) -> bool:
return self.label_policy.is_matchable(estimation, ground_truth) and is_same_frame_id(estimation, ground_truth)

def get_matching_score(
self,
estimation: ObjectType,
ground_truth: ObjectType,
target_labels: List[LabelType],
) -> Optional[float]:
threshold: Optional[float] = get_label_threshold(
ground_truth.semantic_label,
target_labels,
self.matchable_thresholds,
)

matching_method: MatchingMethod = self.matching_module(estimation, ground_truth)

if threshold is None or (threshold is not None and matching_method.is_better_than(threshold)):
return matching_method.value
else:
return None
5 changes: 5 additions & 0 deletions perception_eval/perception_eval/matching/object_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class MatchingMode(Enum):
def __str__(self) -> str:
return self.value

@classmethod
def from_str(cls, name: str) -> MatchingMode:
assert name in cls.__members__, f"{name} is not enum member"
return cls.__members__[name]


class MatchingMethod(ABC):
"""A base class for matching method class.
Expand Down
Loading

0 comments on commit 6af7a92

Please sign in to comment.