Skip to content

Commit

Permalink
Merge pull request #962 from StanfordVL/fix/key-array-caching
Browse files Browse the repository at this point in the history
Fix Key Array Caching
  • Loading branch information
cgokmen authored Oct 24, 2024
2 parents 599489a + a44b201 commit 942ca2b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
6 changes: 1 addition & 5 deletions omnigibson/sensors/vision_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,6 @@ def _remap_modality(self, modality, obs, info, raw_obs):
obs[modality], info[modality] = self._remap_instance_segmentation(
obs[modality],
id_to_labels,
obs["seg_semantic"],
info["seg_semantic"],
id=(modality == "seg_instance_id"),
)
elif "bbox" in modality:
Expand Down Expand Up @@ -387,16 +385,14 @@ def _remap_semantic_segmentation(self, img, id_to_labels):

return VisionSensor.SEMANTIC_REMAPPER.remap(replicator_mapping, semantic_class_id_to_name(), img, image_keys)

def _remap_instance_segmentation(self, img, id_to_labels, semantic_img, semantic_labels, id=False):
def _remap_instance_segmentation(self, img, id_to_labels, id=False):
"""
Remap the instance segmentation image to our own instance IDs.
Also, correct the id_to_labels input with our new labels and return it.
Args:
img (th.tensor): Instance segmentation image to remap
id_to_labels (dict): Dictionary of instance IDs to class labels
semantic_img (th.tensor): Semantic segmentation image to use for instance registry
semantic_labels (dict): Dictionary of semantic IDs to class labels
id (bool): Whether to remap for instance ID segmentation
Returns:
th.tensor: Remapped instance segmentation image
Expand Down
14 changes: 14 additions & 0 deletions omnigibson/utils/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ class Remapper:
def __init__(self):
self.key_array = th.empty(0, dtype=th.int32, device="cuda") # Initialize the key_array as empty
self.known_ids = set()
self.unlabelled_ids = set()
self.warning_printed = set()

def clear(self):
"""Resets the key_array to empty."""
self.key_array = th.empty(0, dtype=th.int32, device="cuda")
self.known_ids = set()
self.unlabelled_ids = set()

def remap(self, old_mapping, new_mapping, image, image_keys=None):
"""
Expand Down Expand Up @@ -109,6 +111,15 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
# Copy the previous key array into the new key array
self.key_array[: len(prev_key_array)] = prev_key_array

# Retrospectively inspect our cached ids against the old mapping and update the key array
updated_ids = set()
for unlabelled_id in self.unlabelled_ids:
if unlabelled_id in old_mapping and old_mapping[unlabelled_id] != "unlabelled":
# If an object was previously unlabelled but now has a label, we need to update the key array
updated_ids.add(unlabelled_id)
self.unlabelled_ids -= updated_ids
self.known_ids -= updated_ids

new_keys = old_mapping.keys() - self.known_ids
if new_keys:
self.known_ids.update(new_keys)
Expand All @@ -118,6 +129,9 @@ def remap(self, old_mapping, new_mapping, image, image_keys=None):
new_key = next((k for k, v in new_mapping.items() if v == label), None)
assert new_key is not None, f"Could not find a new key for label {label} in new_mapping!"
self.key_array[key] = new_key
if label == "unlabelled":
# Some objects in the image might be unlabelled first but later get a valid label later, so we keep track of them
self.unlabelled_ids.add(key)

# For all the values that exist in the image but not in old_mapping.keys(), we map them to whichever key in
# new_mapping that equals to 'unlabelled'. This is needed because some values in the image don't necessarily
Expand Down

0 comments on commit 942ca2b

Please sign in to comment.