Dynamic files world reset #265

Merged · 18 commits · Oct 21, 2024
Showing changes from 14 commits
40 changes: 31 additions & 9 deletions algorithms/sb3/ppo/ippo.py
@@ -53,6 +53,7 @@ def __init__(
self.exp_config = exp_config
self.mlp_class = mlp_class
self.mlp_config = mlp_config
self.resample_counter = 0
super().__init__(*args, **kwargs)

def collect_rollouts(
@@ -67,6 +68,33 @@ def collect_rollouts(
assert (
self._last_obs is not None
), "No previous observation was provided"

# Check resampling criterion and resample batch of scenarios if needed
if self.env.exp_config.resample_scenarios:
if self.env.exp_config.resample_criterion == "global_step":
if self.resample_counter >= self.env.exp_config.resample_freq:
print(
f"Resampling {self.env.num_worlds} scenarios at global_step {self.num_timesteps:,}..."
)
# Re-initialize the scenes and controlled agents mask
self.env.resample_scenario_batch()
self.resample_counter = 0
# Get new initial observation
self._last_obs = self.env.reset()
# Update storage shapes
self.n_envs = env.num_valid_controlled_agents_across_worlds
rollout_buffer.n_envs = self.n_envs
self._last_episode_starts = (
self.env._env.get_dones().clone()[
~self.env.dead_agent_mask
]
)

else:
raise NotImplementedError(
f"Resampling criterion {self.env.exp_config.resample_criterion} not implemented"
)

# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)

@@ -165,16 +193,11 @@ def collect_rollouts(

new_obs, rewards, dones, infos = env.step(clipped_actions)

# # (dc) DEBUG
# mask = ~torch.isnan(rewards)
# if (
# self._last_obs[mask].max() > 1
# or self._last_obs[mask].min() < -1
# ):
# logging.error("New observation is out of bounds")

# EDIT_2: Increment the global step by the number of valid samples in rollout step
self.num_timesteps += int((~rewards.isnan()).float().sum().item())
self.resample_counter += int(
(~rewards.isnan()).float().sum().item()
)
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
@@ -197,7 +220,6 @@
self._last_episode_starts = dones

# # # # # END LOOP # # # # #

total_steps = self.n_envs * n_rollout_steps
elapsed_time = time.perf_counter() - time_rollout
fps = total_steps / elapsed_time
20 changes: 20 additions & 0 deletions baselines/ippo/README.md
@@ -6,6 +6,26 @@ To run the multi-agent IPPO baseline using stable-baselines 3 (SB3):
python baselines/ippo/run_sb3_ppo.py
```

### Resampling the data

The configuration for resampling traffic scenarios includes:

- **`resample_scenarios`**: A boolean that enables traffic scenario resampling when set to `True`.

- **`resample_criterion`**: Set to `"global_step"`, indicating resampling occurs based on the global step count.

- **`resample_freq`**: How often to resample, in global steps (`100_000` in the default config below); recommended to be a multiple of `num_worlds * n_steps`.

- **`resample_mode`**: Set to `"random"` for random selection of new scenarios.

```
# RESAMPLE TRAFFIC SCENARIOS
resample_scenarios: bool = True
resample_criterion: str = "global_step" # Options: "global_step"
resample_freq: int = 100_000 # Resample every k steps (recommended to be a multiple of num_worlds * n_steps)
resample_mode: str = "random" # Options: "random"
```
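
For orientation, here is a minimal sketch (not part of this diff) of how these options fit together; it assumes `ExperimentConfig` is the dataclass from `baselines/ippo/config.py` and that `env_config`/`scene_config` are built as in `run_sb3_ppo.py`:

```python
from baselines.ippo.config import ExperimentConfig

# Resampling is driven entirely by the experiment config.
exp_config = ExperimentConfig(
    resample_scenarios=True,           # turn resampling on
    resample_criterion="global_step",  # currently the only supported criterion
    resample_freq=100_000,             # counted in valid agent steps
    resample_mode="random",            # currently the only supported mode
)

# The wrapper stores exp_config, and IPPO's collect_rollouts() checks
# env.exp_config to decide when to call env.resample_scenario_batch():
#
#   env = SB3MultiAgentEnv(
#       config=env_config,
#       scene_config=scene_config,  # initial scenes; resampling later draws from exp_config.data_dir
#       exp_config=exp_config,
#       max_cont_agents=env_config.max_num_agents_in_scene,
#       device=exp_config.device,
#   )
```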

## Implemented networks

### Classic Observations
10 changes: 8 additions & 2 deletions baselines/ippo/config.py
@@ -10,7 +10,7 @@ class ExperimentConfig:
"""Configurations for experiments."""

# DATASET
data_dir: str = "data/examples"
data_dir: str = "data/processed/examples"

# NUM PARALLEL ENVIRONMENTS & DEVICE
num_worlds: int = 50  # Number of parallel environments
@@ -27,6 +27,12 @@ class ExperimentConfig:
goal_achieved_weight: float = 1.0
off_road_weight: float = 0.0

# RESAMPLE TRAFFIC SCENARIOS
resample_scenarios: bool = True
resample_criterion: str = "global_step" # Options: "global_step"
resample_freq: int = 100_000 # Resample every k steps (recommended to be a multiple of num_worlds * n_steps)
resample_mode: str = "random" # Options: "random"

# RENDERING
render: bool = True
render_mode: str = "rgb_array"
@@ -63,7 +69,7 @@ class ExperimentConfig:
gae_lambda: float = 0.95
clip_range: float = 0.2
vf_coef: float = 0.5
n_steps: int = 91
n_steps: int = 91 # Number of steps per rollout
num_minibatches: int = 5 # Used to determine the minibatch size
verbose: int = 0
total_timesteps: int = 2e7
1 change: 1 addition & 0 deletions baselines/ippo/run_sb3_ppo.py
@@ -47,6 +47,7 @@ def train(exp_config: ExperimentConfig, scene_config: SceneConfig):
env = SB3MultiAgentEnv(
config=env_config,
scene_config=scene_config,
exp_config=exp_config,
# Control up to all agents in the scene
max_cont_agents=env_config.max_num_agents_in_scene,
device=exp_config.device,
15 changes: 15 additions & 0 deletions pygpudrive/env/README.md
@@ -125,6 +125,21 @@ The `SceneConfig` dataclass is used to configure how scenes are selected from a
- `discipline`: The method for selecting scenes, defaulting to `SelectionDiscipline.PAD_N`. (See options in Table below)
- `k_unique_scenes`: Specifies the number of unique scenes to select, if applicable.

### Resampling traffic scenarios

The `reinit_scenarios` function in `pygpudrive/env/base_env.py` reinitializes the simulator with new traffic scenarios as follows:

1. **Scene re-initialization:**
This function updates the simulation maps by calling `self.sim.set_maps(dataset)`, replacing the current scenes with those provided in `dataset`, which should be a list of paths to traffic scenarios.

2. **Controlled agent mask re-initialization:**
It reinitializes the controlled-agent mask using `self.get_controlled_agents_mask()`, which determines which agents are user-controlled (this changes with the specific traffic scenes used).

3. **Agent count update:**
The function updates `self.max_agent_count` (the maximum number of agents per world) and recomputes `self.num_valid_controlled_agents_across_worlds`, the total number of active controlled agents across all scenarios.

See the `resample_scenario_batch()` method in `pygpudrive/env/wrappers/sb3_wrapper.py` for an example of how you can use this function with IPPO.
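
For reference, below is a minimal standalone sketch of this flow (the environment object `env`, the data directory, and the return value of `reset()` are assumptions for illustration; the SB3 wrapper performs the equivalent steps internally):

```python
import os
import random

# Assumes `env` is an already-constructed GPUDrive environment (a subclass of
# the base env in pygpudrive/env/base_env.py) and data_dir holds processed scenarios.
data_dir = "data/processed/examples"
scene_paths = [
    os.path.join(data_dir, f)
    for f in sorted(os.listdir(data_dir))
    if f.startswith("tfrecord")
]

# Draw one scene per parallel world (assumes enough unique scenes exist).
dataset = random.sample(scene_paths, env.num_worlds)

# Swap the maps and refresh the controlled-agent bookkeeping.
env.reinit_scenarios(dataset)
print(env.num_valid_controlled_agents_across_worlds)

# Reset to start collecting experience in the new batch of scenarios.
obs = env.reset()
```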

## Render

Render settings can be changed using the `RenderConfig`.
25 changes: 22 additions & 3 deletions pygpudrive/env/base_env.py
@@ -1,4 +1,5 @@
import os
from typing import List, Optional
import gymnasium as gym
from pygpudrive.env.config import RenderConfig, RenderMode
from pygpudrive.env.viz import PyGameVisualizer
@@ -57,7 +58,10 @@ def _set_reward_params(self):
"""
reward_params = gpudrive.RewardParams()

if self.config.reward_type == "sparse_on_goal_achieved" or self.config.reward_type == "weighted_combination":
if (
self.config.reward_type == "sparse_on_goal_achieved"
or self.config.reward_type == "weighted_combination"
):
reward_params.rewardType = gpudrive.RewardType.OnGoalAchieved
else:
raise ValueError(f"Invalid reward type: {self.config.reward_type}")
@@ -119,8 +123,10 @@ def _setup_environment_parameters(self):

if self.config.lidar_obs:
if not self.config.lidar_obs and self.config.disable_classic_obs:
raise ValueError("Lidar observations must be enabled if classic observations are disabled.")

raise ValueError(
"Lidar observations must be enabled if classic observations are disabled."
)

else:
params.enableLidar = self.config.lidar_obs
params.disableClassicalObs = self.config.disable_classic_obs
@@ -259,6 +265,19 @@ def render(self, world_render_idx=0, color_objects_by_actor=None):
}:
return self.visualizer.getRender()

def reinit_scenarios(self, dataset: List[str]):
"""Resample the scenes."""

# Resample the scenes
self.sim.set_maps(dataset)

# Re-initialize the controlled agents mask
self.cont_agent_mask = self.get_controlled_agents_mask()
self.max_agent_count = self.cont_agent_mask.shape[1]
self.num_valid_controlled_agents_across_worlds = (
self.cont_agent_mask.sum().item()
)

def close(self):
"""Destroy the simulator and visualizer."""
del self.sim
2 changes: 1 addition & 1 deletion pygpudrive/env/config.py
@@ -91,7 +91,7 @@ class EnvConfig:

# Reward settings
reward_type: str = (
"sparse_on_goal_achieved" # Alternatively, "weighted_combination"
"sparse_on_goal_achieved" # Alternatively, "weighted_combination"
)

dist_to_goal_threshold: float = (
1 change: 1 addition & 0 deletions pygpudrive/env/env_torch.py
@@ -27,6 +27,7 @@
):
# Initialization of environment configurations
self.config = config
self.scene_config = scene_config
self.num_worlds = scene_config.num_scenes
self.max_cont_agents = max_cont_agents
self.device = device
50 changes: 50 additions & 0 deletions pygpudrive/env/wrappers/sb3_wrapper.py
@@ -2,7 +2,9 @@
import logging
from typing import Optional, Sequence
import torch
import os
import gymnasium as gym
import random
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import (
VecEnv,
@@ -25,6 +27,7 @@ class SB3MultiAgentEnv(VecEnv):
def __init__(
self,
config,
exp_config,
scene_config,
max_cont_agents,
device,
@@ -37,6 +40,13 @@ def __init__(
device=device,
)
self.config = config
self.exp_config = exp_config
self.all_scene_paths = [
os.path.join(self.exp_config.data_dir, scene)
for scene in sorted(os.listdir(self.exp_config.data_dir))
if scene.startswith("tfrecord")
]
self.unique_scene_paths = list(set(self.all_scene_paths))
self.num_worlds = self._env.num_worlds
self.max_agent_count = self._env.max_agent_count
self.num_envs = self._env.cont_agent_mask.sum().item()
@@ -198,6 +208,46 @@ def seed(self, seed=None):
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds

def resample_scenario_batch(self):
"""Swap out the dataset."""
if self.exp_config.resample_mode == "random":
total_unique = len(self.unique_scene_paths)

# Check if N is greater than the number of unique scenes
if self.num_worlds <= total_unique:
dataset = random.sample(
self.unique_scene_paths, self.num_worlds
)

# If N is greater, repeat the unique scenes until we get N scenes
dataset = []
while len(dataset) < self.num_worlds:
dataset.extend(
random.sample(self.unique_scene_paths, total_unique)
)
if len(dataset) > self.num_worlds:
dataset = dataset[
: self.num_worlds
] # Trim the result to N scenes
else:
raise NotImplementedError(
f"Resample mode {self.exp_config.resample_mode} is currently not supported."
)

# Re-initialize the simulator with the new dataset
print(
f"Re-initializing sim with {len(set(dataset))} {self.exp_config.resample_mode} unique scenes.\n"
)
self._env.reinit_scenarios(dataset)

# Update controlled agent mask
self.controlled_agent_mask = self._env.cont_agent_mask.clone()
self.max_agent_count = self.controlled_agent_mask.shape[1]
self.num_valid_controlled_agents_across_worlds = (
self.controlled_agent_mask.sum().item()
)
self.num_envs = self.controlled_agent_mask.sum().item()

def _update_info_dict(self, info, indices) -> None:
"""Update the info logger."""

3 changes: 2 additions & 1 deletion src/bindings.cpp
@@ -124,7 +124,8 @@ namespace gpudrive
.def("rgb_tensor", &Manager::rgbTensor)
.def("depth_tensor", &Manager::depthTensor)
.def("response_type_tensor", &Manager::responseTypeTensor)
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor);
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor)
.def("set_maps", &Manager::setMaps);
}

}
6 changes: 3 additions & 3 deletions src/headless.cpp
@@ -34,9 +34,9 @@ int main(int argc, char *argv[])
}

uint64_t num_steps = std::stoul(argv[2]);
std::vector<std::string> scenes = {"../data/examples/tfrecord-00001-of-01000_307.json",
"../data/examples/tfrecord-00003-of-01000_109.json",
"../data/examples/tfrecord-00012-of-01000_389.json"};
std::vector<std::string> scenes = {"../data/processed/examples/tfrecord-00001-of-01000_307.json",
"../data/processed/examples/tfrecord-00003-of-01000_109.json",
"../data/processed/examples/tfrecord-00012-of-01000_389.json"};
uint64_t num_worlds = scenes.size();

bool rand_actions = false;