Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add gt frame interpolation function #120

Merged
merged 10 commits into from
Feb 6, 2024
177 changes: 177 additions & 0 deletions perception_eval/perception_eval/common/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
import logging
from typing import Any
from typing import Dict
Expand All @@ -28,9 +29,13 @@
from perception_eval.common.dataset_utils import _sample_to_frame
from perception_eval.common.dataset_utils import _sample_to_frame_2d
from perception_eval.common.evaluation_task import EvaluationTask
from perception_eval.common.geometry import interpolate_homogeneous_matrix
from perception_eval.common.geometry import interpolate_object_list
from perception_eval.common.label import LabelConverter
from perception_eval.common.object import DynamicObject
from perception_eval.common.schema import FrameID
from perception_eval.util.math import get_pose_transform_matrix
from pyquaternion import Quaternion
from tqdm import tqdm


Expand Down Expand Up @@ -288,3 +293,175 @@ def get_now_frame(
return None
else:
return ground_truth_now_frame


def get_interpolated_now_frame(
ground_truth_frames: List[FrameGroundTruth],
unix_time: int,
threshold_min_time: int,
) -> Optional[FrameGroundTruth]:
"""Get interpolated ground truth frame in specified unix time.
It searches before and after frames which satisfy the time difference condition and if found both, interpolate them.

Args:
ground_truth_frames (List[FrameGroundTruth]): FrameGroundTruth instance list.
unix_time (int): Unix time [us].
threshold_min_time (int): Min time for unix time difference [us].

Returns:
Optional[FrameGroundTruth]:
The ground truth frame whose unix time is most close to args unix time
from dataset.
If the difference time between unix time parameter and the most close time
ground truth frame is larger than threshold_min_time, return None.

Examples:
>>> ground_truth_frames = load_all_datasets(...)
>>> get_interpolated_now_frame(ground_truth_frames, 1624157578750212, 7500)
<perception_eval.common.dataset.FrameGroundTruth object at 0x7f66040c36a0>
"""
# extract closest two frames
before_frame = None
after_frame = None
dt_before = 0.0
dt_after = 0.0
for ground_truth_frame in ground_truth_frames:
diff_time = unix_time - ground_truth_frame.unix_time
if diff_time >= 0:
before_frame = ground_truth_frame
dt_before = diff_time
else:
after_frame = ground_truth_frame
dt_after = -diff_time
if before_frame is not None and after_frame is not None:
break

# disable frame if time difference is too large
if dt_before > threshold_min_time:
before_frame = None
if dt_after > threshold_min_time:
after_frame = None

# check frame availability
if before_frame is None and after_frame is None:
logging.info(f"No frame is available for interpolation")
return None
elif before_frame is None:
logging.info(f"Only after frame is available for interpolation")
return after_frame
elif after_frame is None:
logging.info(f"Only before frame is available for interpolation")
return before_frame
else:
# do interpolation
return interpolate_ground_truth_frames(before_frame, after_frame, unix_time)


def interpolate_ground_truth_frames(
before_frame: FrameGroundTruth,
after_frame: FrameGroundTruth,
unix_time: int,
):
"""Interpolate ground truth frame with linear interpolation.

Args:
before_frame (FrameGroundTruth): input frame1
after_frame (FrameGroundTruth): input frame2
unix_time (int): target time
"""
# interpolate ego2map
ego2map = interpolate_homogeneous_matrix(
before_frame.ego2map, after_frame.ego2map, before_frame.unix_time, after_frame.unix_time, unix_time
)

# TODO: Need refactor for simplicity
# if frame is base_link, need to interpolate with global coordinate
# 1. convert object list to global
before_frame_objects = convert_objects_to_global(before_frame.objects, before_frame.ego2map)
after_frame_objects = convert_objects_to_global(after_frame.objects, after_frame.ego2map)

# 2. interpolate objects
object_list = interpolate_object_list(
before_frame_objects, after_frame_objects, before_frame.unix_time, after_frame.unix_time, unix_time
)
# 3. convert object list to base_link
# object_list = convert_objects_to_base_link(object_list, ego2map)

# interpolate raw data
output_frame = deepcopy(before_frame)
output_frame.ego2map = ego2map
output_frame.objects = object_list
output_frame.unix_time = unix_time
return output_frame


def convert_objects_to_global(
object_list: List[ObjectType],
ego2map: np.ndarray,
) -> List[ObjectType]:
"""Convert object list to global coordinate.

Args:
object_list (List[ObjectType]): object list
ego2map (np.ndarray): ego2map matrix

Returns:
List[ObjectType]: object list in global coordinate
"""
output_object_list = []
for object in object_list:
if object.frame_id == "map":
output_object_list.append(deepcopy(object))
continue
elif object.frame_id == "base_link":
src: np.ndarray = get_pose_transform_matrix(
position=object.state.position,
rotation=object.state.orientation.rotation_matrix,
)
dst: np.ndarray = ego2map.dot(src)
updated_position: np.ndarray = tuple(dst[:3, 3].flatten())
updated_rotation: np.ndarray = Quaternion(dst[:3, :3])
output_object = deepcopy(object)
output_object.state.position = updated_position
output_object.state.orientation = updated_rotation
output_object.frame_id = "map"
output_object_list.append(output_object)
else:
raise NotImplementedError(f"Unexpected frame_id: {object.frame_id}")
return output_object_list


def convert_objects_to_base_link(
object_list: List[ObjectType],
ego2map: np.ndarray,
) -> List[ObjectType]:
"""Convert object list to base_link coordinate.

Args:
object_list (List[ObjectType]): object list
ego2map (np.ndarray): ego2map matrix

Returns:
List[ObjectType]: object list in base_link coordinate
"""
output_object_list = []
for object in object_list:
if object.frame_id == "base_link":
output_object_list.append(deepcopy(object))
continue
elif object.frame_id == "map":
src: np.ndarray = get_pose_transform_matrix(
position=object.state.position,
rotation=object.state.orientation.rotation_matrix,
)
dst: np.ndarray = np.linalg.inv(ego2map).dot(src)
updated_position: np.ndarray = tuple(dst[:3, 3].flatten())
updated_rotation: Quaternion = Quaternion(matrix=dst[:3, :3])
output_object = deepcopy(object)
output_object.state.position = updated_position
output_object.state.orientation = updated_rotation
output_object.frame_id = "base_link"
output_object_list.append(output_object)
else:
raise NotImplementedError(f"Unexpected frame_id: {object.frame_id}")
return output_object_list
Loading
Loading