Skip to content

Commit

Permalink
Merge pull request #792 from StanfordVL/fix/flatdim
Browse files Browse the repository at this point in the history
Fix gymnasium flatdim
  • Loading branch information
hang-yin authored Jul 11, 2024
2 parents fd82eed + 45d15df commit 59bd883
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
9 changes: 5 additions & 4 deletions omnigibson/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from omnigibson.utils.config_utils import parse_config
from omnigibson.utils.gym_utils import (
GymObservable,
maxdim,
recursively_generate_compatible_dict,
recursively_generate_flat_dict,
)
Expand Down Expand Up @@ -337,12 +338,12 @@ def _load_observation_space(self):
for robot in self.robots:
# Load the observation space for the robot
robot_obs = robot.load_observation_space()
if gym.spaces.utils.flatdim(robot_obs) > 0:
if maxdim(robot_obs) > 0:
obs_space[robot.name] = robot_obs

# Also load the task obs space
task_space = self._task.load_observation_space()
if gym.spaces.utils.flatdim(task_space) > 0:
if maxdim(task_space) > 0:
obs_space["task"] = task_space

# Also load any external sensors
Expand Down Expand Up @@ -470,11 +471,11 @@ def get_obs(self):

# Grab all observations from each robot
for robot in self.robots:
if gym.spaces.utils.flatdim(robot.observation_space) > 0:
if maxdim(robot.observation_space) > 0:
obs[robot.name], info[robot.name] = robot.get_obs()

# Add task observations
if gym.spaces.utils.flatdim(self._task.observation_space) > 0:
if maxdim(self._task.observation_space) > 0:
obs["task"] = self._task.get_obs(env=self)

# Add external sensor observations if they exist
Expand Down
6 changes: 3 additions & 3 deletions omnigibson/sensors/vision_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class VisionSensor(BaseSensor):
Args:
relative_prim_path (str): Scene-local prim path of the Sensor to encapsulate or create.
name (str): Name for the object. Names need to be unique per scene.
modalities (str or list of str): Modality(s) supported by this sensor. Default is "all", which corresponds
to all modalities being used. Otherwise, valid options should be part of cls.all_modalities.
modalities (str or list of str): Modality(s) supported by this sensor. Default is "rgb".
Otherwise, valid options should be part of cls.all_modalities.
For this vision sensor, this includes any of:
{rgb, depth, depth_linear, normal, seg_semantic, seg_instance, flow, bbox_2d_tight,
bbox_2d_loose, bbox_3d, camera}
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
self,
relative_prim_path,
name,
modalities="all",
modalities=["rgb"],
enabled=True,
noise=None,
load_config=None,
Expand Down
20 changes: 20 additions & 0 deletions omnigibson/utils/gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,23 @@ def load_observation_space(self):
log.debug(f"Loaded obs space dictionary for: {self.__class__.__name__}")

return self.observation_space


def maxdim(space):
"""
Helper function to get the maximum dimension of a gym space
Args:
space (gym.spaces.Space): Gym space to get the maximum dimension of
Returns:
int: Maximum dimension of the gym space
"""
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
return sum([maxdim(s) for s in space.spaces.values()])
elif isinstance(space, (gym.spaces.Box, gym.spaces.Discrete, gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
return gym.spaces.utils.flatdim(space)
elif isinstance(space, (gym.spaces.Sequence, gym.spaces.Graph)):
return float("inf")
else:
raise ValueError(f"Unsupported gym space type: {type(space)}")

0 comments on commit 59bd883

Please sign in to comment.