diff --git a/perception_eval/perception_eval/common/deepcopy.py b/perception_eval/perception_eval/common/deepcopy.py new file mode 100644 index 00000000..5ab4cbe4 --- /dev/null +++ b/perception_eval/perception_eval/common/deepcopy.py @@ -0,0 +1,38 @@ +from copy import deepcopy + + +# https://stackoverflow.com/a/24621200/4732868 +def deepcopy_with_sharing(obj, shared_attribute_names, memo=None): + """ + Deepcopy an object, except for a given list of attributes, which should + be shared between the original object and its copy. + + obj is some object + shared_attribute_names: A list of strings identifying the attributes that + should be shared between the original and its copy. + memo is the dictionary passed into __deepcopy__. Ignore this argument if + not calling from within __deepcopy__. + """ + assert isinstance(shared_attribute_names, (list, tuple)) + shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names} + + if hasattr(obj, "__deepcopy__"): + # Do hack to prevent infinite recursion in call to deepcopy + deepcopy_method = obj.__deepcopy__ + obj.__deepcopy__ = None + + for attr in shared_attribute_names: + del obj.__dict__[attr] + + clone = deepcopy(obj) + + for attr, val in shared_attributes.items(): + setattr(obj, attr, val) + setattr(clone, attr, val) + + if hasattr(obj, "__deepcopy__"): + # Undo hack + obj.__deepcopy__ = deepcopy_method + del clone.__deepcopy__ + + return clone diff --git a/perception_eval/perception_eval/evaluation/matching/objects_filter.py b/perception_eval/perception_eval/evaluation/matching/objects_filter.py index eab315a0..3a82f539 100644 --- a/perception_eval/perception_eval/evaluation/matching/objects_filter.py +++ b/perception_eval/perception_eval/evaluation/matching/objects_filter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import copy from typing import Dict from typing import List from typing import Optional @@ -123,7 +124,7 @@ def filter_object_results( is_target = False if is_target: - filtered_object_results.append(object_result) + filtered_object_results.append(copy(object_result)) return filtered_object_results @@ -198,7 +199,7 @@ def filter_objects( transforms=transforms, ) if is_target: - filtered_objects.append(object_) + filtered_objects.append(copy(object_)) return filtered_objects diff --git a/perception_eval/perception_eval/evaluation/result/object_result.py b/perception_eval/perception_eval/evaluation/result/object_result.py index d4c6e9df..6adf39d2 100644 --- a/perception_eval/perception_eval/evaluation/result/object_result.py +++ b/perception_eval/perception_eval/evaluation/result/object_result.py @@ -25,6 +25,7 @@ from perception_eval.common import DynamicObject from perception_eval.common import DynamicObject2D from perception_eval.common import ObjectType +from perception_eval.common.deepcopy import deepcopy_with_sharing from perception_eval.common.evaluation_task import EvaluationTask from perception_eval.common.label import LabelType from perception_eval.common.label import TrafficLightLabel @@ -106,6 +107,9 @@ def __init__( self.iou_3d = None self.plane_distance = None + def __deepcopy__(self, memo): + return deepcopy_with_sharing(self, shared_attribute_names = ['estimated_object', 'ground_truth_object'], memo=memo) + def get_status( self, matching_mode: MatchingMode, diff --git a/perception_eval/perception_eval/evaluation/result/perception_frame_result.py b/perception_eval/perception_eval/evaluation/result/perception_frame_result.py index 0d9b5a08..373b64a3 100644 --- a/perception_eval/perception_eval/evaluation/result/perception_frame_result.py +++ b/perception_eval/perception_eval/evaluation/result/perception_frame_result.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations - +from copy import copy, deepcopy from typing import Dict from typing import List from typing import Optional diff --git a/perception_eval/perception_eval/tool/utils.py b/perception_eval/perception_eval/tool/utils.py index 0c1c4e23..3c7ab2df 100644 --- a/perception_eval/perception_eval/tool/utils.py +++ b/perception_eval/perception_eval/tool/utils.py @@ -476,7 +476,6 @@ def filter_frame_by_distance( PerceptionFrameResult: Filtered frame results. """ ret_frame = deepcopy(frame) - if min_distance is not None: min_distance_list = [min_distance] * len(ret_frame.target_labels) else: