Dynamic files world reset (Emerge-Lab#265)
* Introduce new types for Map Resets

* Copy Maps to sim

* Hook up to reset

* Don't throw exceptions

* Hook up resetmap

* Hook up resetmap

* Fix the BVH tree limits

* Add resampling of scenarios method

* Add optional list param

* Fix cuda compile

* Missing semicolon

* Integrate resampling of scenarios with gym env and IPPO

* Revert max objects const

* Update default and remove redundancy

* Small improvements

* Fix

---------

Co-authored-by: Daphne Cornelisse <[email protected]>
2 people authored and KKGB committed Oct 23, 2024
1 parent dc62091 commit bcd4091
Showing 18 changed files with 307 additions and 40 deletions.
40 changes: 31 additions & 9 deletions algorithms/sb3/ppo/ippo.py
@@ -53,6 +53,7 @@ def __init__(
self.exp_config = exp_config
self.mlp_class = mlp_class
self.mlp_config = mlp_config
self.resample_counter = 0
super().__init__(*args, **kwargs)

def collect_rollouts(
@@ -67,6 +68,33 @@ def collect_rollouts(
assert (
self._last_obs is not None
), "No previous observation was provided"

# Check resampling criterion and resample batch of scenarios if needed
if self.env.exp_config.resample_scenarios:
if self.env.exp_config.resample_criterion == "global_step":
if self.resample_counter >= self.env.exp_config.resample_freq:
print(
f"Resampling {self.env.num_worlds} scenarios at global_step {self.num_timesteps:,}..."
)
# Re-initialize the scenes and controlled agents mask
self.env.resample_scenario_batch()
self.resample_counter = 0
# Get new initial observation
self._last_obs = self.env.reset()
# Update storage shapes
self.n_envs = env.num_valid_controlled_agents_across_worlds
rollout_buffer.n_envs = self.n_envs
self._last_episode_starts = (
self.env._env.get_dones().clone()[
~self.env.dead_agent_mask
]
)

else:
raise NotImplementedError(
f"Resampling criterion {self.env.exp_config.resample_criterion} not implemented"
)

# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)

@@ -165,16 +193,11 @@ def collect_rollouts(

new_obs, rewards, dones, infos = env.step(clipped_actions)

# # (dc) DEBUG
# mask = ~torch.isnan(rewards)
# if (
# self._last_obs[mask].max() > 1
# or self._last_obs[mask].min() < -1
# ):
# logging.error("New observation is out of bounds")

# EDIT_2: Increment the global step by the number of valid samples in rollout step
self.num_timesteps += int((~rewards.isnan()).float().sum().item())
self.resample_counter += int(
(~rewards.isnan()).float().sum().item()
)
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
@@ -197,7 +220,6 @@ def collect_rollouts(
self._last_episode_starts = dones

# # # # # END LOOP # # # # #

total_steps = self.n_envs * n_rollout_steps
elapsed_time = time.perf_counter() - time_rollout
fps = total_steps / elapsed_time
20 changes: 20 additions & 0 deletions baselines/ippo/README.md
@@ -6,6 +6,26 @@ To run the multi-agent IPPO baseline using stable-baselines 3 (SB3):
python baselines/ippo/run_sb3_ppo.py
```

### Resampling the data

The configuration for resampling traffic scenarios includes:

- **`resample_scenarios`**: A boolean; set to `True` to enable traffic scenario resampling.

- **`resample_criterion`**: Set to `"global_step"`, indicating resampling occurs based on the global step count.

- **`resample_freq`**: How often, in global steps, to resample scenarios; recommended to be a multiple of `num_worlds * n_steps`.

- **`resample_mode`**: Set to `"random"` for random selection of new scenarios.

```
# RESAMPLE TRAFFIC SCENARIOS
resample_scenarios: bool = True
resample_criterion: str = "global_step" # Options: "global_step"
resample_freq: int = 100_000 # Resample every k steps (recommended to be a multiple of num_worlds * n_steps)
resample_mode: str = "random" # Options: "random"
```
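
During training, the IPPO runner accumulates the number of valid (non-NaN) agent steps in a `resample_counter` and triggers a resample once that counter reaches `resample_freq`. Below is a minimal sketch of that check, condensed from the `collect_rollouts()` changes in `algorithms/sb3/ppo/ippo.py`; `maybe_resample` is a hypothetical helper written for illustration, not part of the codebase.

```
# Minimal sketch (not part of the codebase): assumes `env` is an
# SB3MultiAgentEnv with resampling configured in its exp_config.
def maybe_resample(env, resample_counter: int) -> int:
    cfg = env.exp_config
    if cfg.resample_scenarios and cfg.resample_criterion == "global_step":
        if resample_counter >= cfg.resample_freq:
            env.resample_scenario_batch()  # swap in a new batch of scenarios
            env.reset()                    # fresh observations for the new scenes
            resample_counter = 0           # restart the step counter
    return resample_counter
```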

## Implemented networks

### Classic Observations
8 changes: 7 additions & 1 deletion baselines/ippo/config.py
@@ -27,6 +27,12 @@ class ExperimentConfig:
goal_achieved_weight: float = 1.0
off_road_weight: float = 0.0

# RESAMPLE TRAFFIC SCENARIOS
resample_scenarios: bool = True
resample_criterion: str = "global_step" # Options: "global_step"
resample_freq: int = 1e6 # Resample every k steps (recommended to be a multiple of num_worlds * n_steps)
resample_mode: str = "random" # Options: "random"

# RENDERING
render: bool = True
render_mode: str = "rgb_array"
@@ -60,7 +66,7 @@ class ExperimentConfig:
gae_lambda: float = 0.95
clip_range: float = 0.2
vf_coef: float = 0.5
n_steps: int = 91
n_steps: int = 91 # Number of steps per rollout
num_minibatches: int = 5 # Used to determine the minibatch size
verbose: int = 0
total_timesteps: int = 2e7
1 change: 1 addition & 0 deletions baselines/ippo/run_sb3_ppo.py
@@ -40,6 +40,7 @@ def train(env_config: EnvConfig, exp_config: ExperimentConfig, scene_config: Sce
env = SB3MultiAgentEnv(
config=env_config,
scene_config=scene_config,
exp_config=exp_config,
# Control up to all agents in the scene
max_cont_agents=env_config.max_num_agents_in_scene,
device=exp_config.device,
24 changes: 22 additions & 2 deletions pygpudrive/env/README.md
@@ -35,7 +35,12 @@ env = GPUDriveTorchEnv(
Step the environment using:

```Python
obs, reward, done, info = env.step_dynamics(action)
env.step_dynamics(actions)

# Extract info
obs = env.get_obs()
reward = env.get_rewards()
done = env.get_dones()
```

Further configuration details are available in `config.py`.
@@ -116,7 +121,7 @@ A reward of +1 is assigned when an agent is within the `dist_to_goal_threshold`

Upon initialization, every vehicle starts at the beginning of the expert trajectory.

## Dataset
## Dataset

The `SceneConfig` dataclass is used to configure how scenes are selected from a dataset. It has four attributes:

@@ -125,6 +130,21 @@ The `SceneConfig` dataclass is used to configure how scenes are selected from a
- `discipline`: The method for selecting scenes, defaulting to `SelectionDiscipline.PAD_N`. (See options in Table below)
- `k_unique_scenes`: Specifies the number of unique scenes to select, if applicable.
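
As a rough illustration, the sketch below instantiates a `SceneConfig`; the `path` and `num_scenes` field names are assumptions inferred from how the config is used elsewhere (e.g. `scene_config.num_scenes`), so check `config.py` for the authoritative definition.

```Python
from pygpudrive.env.config import SceneConfig, SelectionDiscipline

# Hypothetical example: run 3 parallel worlds drawn from a local dataset folder.
scene_config = SceneConfig(
    path="data/processed/examples",        # assumed field: folder with traffic scenarios
    num_scenes=3,                          # assumed field: number of parallel worlds
    discipline=SelectionDiscipline.PAD_N,  # default selection discipline
    k_unique_scenes=None,                  # only used by disciplines that select k unique scenes
)
```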

### Resampling traffic scenarios

The `reinit_scenarios()` method in `pygpudrive/env/base_env.py` reinitializes the simulator with new traffic scenarios as follows:

1. **Scene re-initialization:**
This method updates the simulation maps by calling `self.sim.set_maps(dataset)`, replacing the current scenes with those in the provided `dataset`, which should be a list of paths to traffic scenarios.

2. **Controlled agent mask re-initialization:**
It reinitializes the controlled agents' mask using `self.get_controlled_agents_mask()`, which determines which agents are user-controlled (this changes with the specific traffic scenes used).

3. **Agent count update:**
The function updates `self.max_agent_count` to reflect the number of controlled agents and recomputes `self.num_valid_controlled_agents_across_worlds`, the total number of active controlled agents across all scenarios.

See the `resample_scenario_batch()` method in `pygpudrive/env/wrappers/sb3_wrapper.py` for an example of how you can use this method with IPPO.
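
As a quick, hedged usage sketch (assuming `env` is an already-constructed environment exposing `reinit_scenarios()`, and that the dataset directory, here a placeholder path, contains `tfrecord*` scenario files; the path handling mirrors `resample_scenario_batch()`):

```Python
import os
import random

data_dir = "data/processed/examples"  # placeholder path to traffic scenarios

# Gather candidate scenario files and randomly draw one per world.
scene_paths = sorted(
    os.path.join(data_dir, f)
    for f in os.listdir(data_dir)
    if f.startswith("tfrecord")
)
dataset = random.sample(scene_paths, env.num_worlds)

# Swap the scenes into the simulator; the controlled-agent mask and agent
# counts are refreshed inside reinit_scenarios().
env.reinit_scenarios(dataset)

# Reset to get observations for the newly loaded scenes (as done after
# resampling in IPPO).
obs = env.reset()
```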

## Render

Render settings can be changed using the `RenderConfig`.
31 changes: 28 additions & 3 deletions pygpudrive/env/base_env.py
@@ -1,4 +1,5 @@
import os
from typing import List, Optional
import gymnasium as gym
from pygpudrive.env.config import RenderConfig, RenderMode
from pygpudrive.env.viz import PyGameVisualizer
@@ -57,7 +58,10 @@ def _set_reward_params(self):
"""
reward_params = gpudrive.RewardParams()

if self.config.reward_type == "sparse_on_goal_achieved" or self.config.reward_type == "weighted_combination":
if (
self.config.reward_type == "sparse_on_goal_achieved"
or self.config.reward_type == "weighted_combination"
):
reward_params.rewardType = gpudrive.RewardType.OnGoalAchieved
else:
raise ValueError(f"Invalid reward type: {self.config.reward_type}")
@@ -119,8 +123,10 @@ def _setup_environment_parameters(self):

if self.config.lidar_obs:
if not self.config.lidar_obs and self.config.disable_classic_obs:
raise ValueError("Lidar observations must be enabled if classic observations are disabled.")

raise ValueError(
"Lidar observations must be enabled if classic observations are disabled."
)

else:
params.enableLidar = self.config.lidar_obs
params.disableClassicalObs = self.config.disable_classic_obs
@@ -243,6 +249,25 @@ def render(self, world_render_idx=0, color_objects_by_actor=None):
}:
return self.visualizer.getRender()

def reinit_scenarios(self, dataset: List[str]):
"""Resample the scenes.
Args:
dataset (List[str]): List of scene names to resample.
Returns:
None
"""

# Resample the scenes
self.sim.set_maps(dataset)

# Re-initialize the controlled agents mask
self.cont_agent_mask = self.get_controlled_agents_mask()
self.max_agent_count = self.cont_agent_mask.shape[1]
self.num_valid_controlled_agents_across_worlds = (
self.cont_agent_mask.sum().item()
)

def close(self):
"""Destroy the simulator and visualizer."""
del self.sim
2 changes: 1 addition & 1 deletion pygpudrive/env/config.py
@@ -103,7 +103,7 @@ class EnvConfig:

# Reward settings
reward_type: str = (
"sparse_on_goal_achieved" # Alternatively, "weighted_combination"
"sparse_on_goal_achieved" # Alternatively, "weighted_combination"
)

dist_to_goal_threshold: float = (
1 change: 1 addition & 0 deletions pygpudrive/env/env_torch.py
@@ -26,6 +26,7 @@ def __init__(
):
# Initialization of environment configurations
self.config = config
self.scene_config = scene_config
self.num_worlds = scene_config.num_scenes
self.max_cont_agents = max_cont_agents
self.device = device
48 changes: 48 additions & 0 deletions pygpudrive/env/wrappers/sb3_wrapper.py
@@ -2,7 +2,9 @@
import logging
from typing import Optional, Sequence
import torch
import os
import gymnasium as gym
import random
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import (
VecEnv,
@@ -28,6 +30,7 @@ class SB3MultiAgentEnv(VecEnv):
def __init__(
self,
config,
exp_config,
scene_config,
max_cont_agents,
device,
@@ -42,6 +45,13 @@ def __init__(
self._env = make(dynamics_id=DynamicsModel.DELTA_LOCAL, action_id=ActionSpace.DISCRETE, kwargs=kwargs)

self.config = config
self.exp_config = exp_config
self.all_scene_paths = [
os.path.join(self.exp_config.data_dir, scene)
for scene in sorted(os.listdir(self.exp_config.data_dir))
if scene.startswith("tfrecord")
]
self.unique_scene_paths = list(set(self.all_scene_paths))
self.num_worlds = self._env.num_worlds
self.max_agent_count = self._env.max_agent_count
self.num_envs = self._env.cont_agent_mask.sum().item()
@@ -203,6 +213,44 @@ def seed(self, seed=None):
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds

def resample_scenario_batch(self):
"""Swap out the dataset."""
if self.exp_config.resample_mode == "random":
total_unique = len(self.unique_scene_paths)

# Check if N is greater than the number of unique scenes
if self.num_worlds <= total_unique:
dataset = random.sample(
self.unique_scene_paths, self.num_worlds
)

# If N is greater, repeat the unique scenes until we get N scenes
dataset = []
while len(dataset) < self.num_worlds:
dataset.extend(
random.sample(self.unique_scene_paths, total_unique)
)
if len(dataset) > self.num_worlds:
dataset = dataset[
: self.num_worlds
] # Trim the result to N scenes
else:
raise NotImplementedError(
f"Resample mode {self.exp_config.resample_mode} is currently not supported."
)

# Re-initialize the simulator with the new dataset
print(
f"Re-initializing sim with {len(set(dataset))} {self.exp_config.resample_mode} unique scenes.\n"
)
self._env.reinit_scenarios(dataset)

# Update controlled agent mask
self.controlled_agent_mask = self._env.cont_agent_mask.clone()
self.max_agent_count = self._env.max_agent_count
self.num_valid_controlled_agents_across_worlds = self._env.num_valid_controlled_agents_across_worlds
self.num_envs = self.controlled_agent_mask.sum().item()

def _update_info_dict(self, info, indices) -> None:
"""Update the info logger."""

3 changes: 2 additions & 1 deletion src/bindings.cpp
@@ -124,7 +124,8 @@ namespace gpudrive
.def("rgb_tensor", &Manager::rgbTensor)
.def("depth_tensor", &Manager::depthTensor)
.def("response_type_tensor", &Manager::responseTypeTensor)
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor);
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor)
.def("set_maps", &Manager::setMaps);
}

}
6 changes: 3 additions & 3 deletions src/headless.cpp
@@ -34,9 +34,9 @@ int main(int argc, char *argv[])
}

uint64_t num_steps = std::stoul(argv[2]);
std::vector<std::string> scenes = {"../data/examples/tfrecord-00001-of-01000_307.json",
"../data/examples/tfrecord-00003-of-01000_109.json",
"../data/examples/tfrecord-00012-of-01000_389.json"};
std::vector<std::string> scenes = {"../data/processed/examples/tfrecord-00001-of-01000_307.json",
"../data/processed/examples/tfrecord-00003-of-01000_109.json",
"../data/processed/examples/tfrecord-00012-of-01000_389.json"};
uint64_t num_worlds = scenes.size();

bool rand_actions = false;