
Commit

Getting ready for running RL with multi-discrete, delta_local action space (#11)

* Ongoing: implement RL with multi-discrete

* Debug for training 1 scene

* Run multi-discrete space
KevinJeon authored and KKGB committed Oct 22, 2024
1 parent 74989f8 commit 2a544a4
Showing 7 changed files with 62 additions and 26 deletions.
2 changes: 1 addition & 1 deletion algorithms/sb3/callbacks.py
100644 → 100755
@@ -178,7 +178,7 @@ def _create_and_log_video(
policy = self.model
base_env = self.locals["env"]._env
action_tensor = torch.zeros(
- (base_env.num_worlds, base_env.max_agent_count)
+ (base_env.num_worlds, base_env.max_agent_count, 3) # todo: fix the dim
)

obs = base_env.reset()
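The new trailing dimension of 3 matches the three components of the delta_local action (dx, dy, dyaw); the "todo: fix the dim" comment suggests it should eventually be derived from the action space rather than hard-coded. A minimal sketch of one way to do that, assuming a gym MultiDiscrete space with one entry per component (the helper name is hypothetical):

```python
import gym
import torch


def make_action_placeholder(base_env):
    """Allocate the video-callback action placeholder with a trailing
    component axis only when the action space is multi-discrete."""
    space = base_env.action_space
    if isinstance(space, gym.spaces.MultiDiscrete):
        action_dim = len(space.nvec)  # e.g. 3 for (dx, dy, dyaw)
        return torch.zeros(
            (base_env.num_worlds, base_env.max_agent_count, action_dim)
        )
    # Discrete fallback: one action index per agent, as before this commit
    return torch.zeros((base_env.num_worlds, base_env.max_agent_count))
```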
19 changes: 9 additions & 10 deletions algorithms/sb3/ppo/ippo.py
100644 → 100755
@@ -82,9 +82,9 @@ def collect_rollouts(

while n_steps < n_rollout_steps:
if (
- self.use_sde
- and self.sde_sample_freq > 0
- and n_steps % self.sde_sample_freq == 0
+ self.use_sde
+ and self.sde_sample_freq > 0
+ and n_steps % self.sde_sample_freq == 0
):
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
@@ -95,21 +95,21 @@ def collect_rollouts(
# EDIT_1: Mask out invalid observations (NaN axes and/or dead agents)
# Create dummy actions, values and log_probs (NaN)
actions = torch.full(
- fill_value=float("nan"), size=(self.n_envs,)
+ fill_value=float("nan"), size=(self.n_envs, 3) # todo: should change based on action space
).to(self.device)
log_probs = torch.full(
fill_value=float("nan"),
- size=(self.n_envs,),
+ size=(self.n_envs, ),
dtype=torch.float32,
).to(self.device)
values = (
torch.full(
fill_value=float("nan"),
- size=(self.n_envs,),
+ size=(self.n_envs, ),
dtype=torch.float32,
)
- .unsqueeze(dim=1)
- .to(self.device)
+ .unsqueeze(dim=1)
+ .to(self.device)
)

# Get indices of alive agent ids
@@ -131,10 +131,9 @@ def collect_rollouts(
obs_tensor_alive
)
nn_fps = actions_tmp.shape[0] / (
- time.perf_counter() - time_actions
+ time.perf_counter() - time_actions
)
self.logger.record("rollout/nn_fps", nn_fps)

# Predict actions, vals and log_probs given obs
(
actions[alive_agent_mask.squeeze(dim=1)],
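The pattern in collect_rollouts stays the same as before: allocate NaN placeholders for every controlled agent, query the policy only on alive observations, then scatter the results back through the alive mask. The only change is that the actions placeholder now needs a component axis, hence size=(self.n_envs, 3) and its todo. A self-contained illustration of that scatter step, with made-up shapes and a random stand-in for the policy output:

```python
import torch

n_envs, action_dim = 8, 3  # action_dim would be 3 for the (dx, dy, dyaw) multi-discrete space
actions = torch.full((n_envs, action_dim), float("nan"))
alive_agent_mask = torch.tensor(
    [True, True, False, True, False, True, True, True]
).unsqueeze(1)

# The policy is evaluated only on alive agents (stand-in for self.policy(obs_tensor_alive))...
alive_actions = torch.randint(0, 20, (int(alive_agent_mask.sum()), action_dim)).float()

# ...and its output is written back into the NaN placeholder at the alive positions.
actions[alive_agent_mask.squeeze(dim=1)] = alive_actions
assert torch.isnan(actions[2]).all() and not torch.isnan(actions[0]).any()
```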
18 changes: 16 additions & 2 deletions algorithms/sb3/rollout_buffer.py
100644 → 100755
@@ -176,6 +176,21 @@ def compute_returns_and_advantage(
self.advantages
).any(), "Advantages arr contains NaN values: Check GAE computation"

+ # def swap_and_flatten(self, arr: np.ndarray) -> np.ndarray:
+ # """
+ # Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
+ # to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features)
+ # to [n_steps * n_envs, ...] (which maintain the order)
+ #
+ # :param arr:
+ # :return:
+ # """
+ # shape = arr.shape
+ # print(shape)
+ # if len(shape) < 3:
+ # shape = (*shape, 1)
+ # return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

def get(
self, batch_size: Optional[int] = None
) -> Generator[RolloutBufferSamples, None, None]:
@@ -200,7 +215,7 @@ def get(
# Flatten data
# EDIT_5: And mask out invalid samples
for tensor in _tensor_names:
- if tensor == "observations":
+ if tensor in ["observations", "actions"]:
self.__dict__[tensor] = self.swap_and_flatten(
self.__dict__[tensor]
)[self.valid_samples_mask.flatten(), :]
@@ -212,7 +227,6 @@ def get(
assert not torch.isnan(
self.__dict__[tensor]
).any(), f"{tensor} tensor contains NaN values; something went wrong"

self.generator_ready = True

# EDIT_6: Compute total number of samples and create indices
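Together with the commented-out swap_and_flatten above, the change in get() extends the masking path to the actions buffer: multi-dimensional tensors keep their trailing feature axis while the (step, env) axes are flattened and invalid samples are dropped. A standalone sketch of that flow, with invented shapes and a random validity mask:

```python
import torch

n_steps, n_envs, action_dim = 4, 5, 3
actions = torch.randn(n_steps, n_envs, action_dim)      # [n_steps, n_envs, ...]
valid_samples_mask = torch.rand(n_steps, n_envs) > 0.3  # False where the agent was dead / padded

# Swap the (step, env) axes, then flatten them while preserving the trailing action axis.
flat_actions = actions.transpose(0, 1).reshape(n_steps * n_envs, action_dim)
flat_mask = valid_samples_mask.transpose(0, 1).reshape(-1)

# Keep only valid transitions, mirroring `[self.valid_samples_mask.flatten(), :]`.
valid_actions = flat_actions[flat_mask, :]
print(valid_actions.shape)  # (num_valid, 3)
```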
7 changes: 3 additions & 4 deletions baselines/ippo/config.py
100644 → 100755
@@ -10,11 +10,10 @@ class ExperimentConfig:
"""Configurations for experiments."""

# DATASET
- data_dir: str = "data/processed/examples"
+ data_dir: str = "/data/formatted_json_v2_no_tl_train/" #todo: to be changed

# NUM PARALLEL ENVIRONMENTS & DEVICE
- num_worlds: int = 50 # Number of parallel environmentss
+ num_worlds: int = 1 # Number of parallel environments
# How to select scenes from the dataset
selection_discipline = SelectionDiscipline.K_UNIQUE_N # K_UNIQUE_N / PAD_N
k_unique_scenes: int = 3
@@ -31,7 +30,7 @@ class ExperimentConfig:
render: bool = True
render_mode: str = "rgb_array"
render_freq: int = 50 # Render every k rollouts
- render_n_worlds: int = 3 # Number of worlds to render
+ render_n_worlds: int = 1 # Number of worlds to render

# TRACK THE TIME IT TAKES TO GET TO 95% GOAL RATE
track_time_to_solve: bool = False
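Since ExperimentConfig is a dataclass parsed with pyrallis (see run_sb3_ppo.py below), the values hard-coded here for debugging (num_worlds = 1 and the absolute data_dir marked with a todo) can also be overridden at launch time instead of being edited in place. A small sketch; the flag names are inferred from the field names and the import path is assumed:

```python
import pyrallis

from baselines.ippo.config import ExperimentConfig

# Roughly equivalent to:
#   python baselines/ippo/run_sb3_ppo.py --num_worlds 50 --data_dir data/processed/examples
exp_config = pyrallis.parse(config_class=ExperimentConfig)
print(exp_config.num_worlds, exp_config.data_dir, exp_config.render_n_worlds)
```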
24 changes: 20 additions & 4 deletions baselines/ippo/run_sb3_ppo.py
100644 → 100755
@@ -34,15 +34,31 @@ def func(progress_remaining: float) -> float:
return func


- def train(env_config: EnvConfig, exp_config: ExperimentConfig, scene_config: SceneConfig, action_type: str = "discrete"):
+ def train(exp_config: ExperimentConfig, scene_config: SceneConfig, action_type: str = "discrete"):
"""Run PPO training with stable-baselines3."""

+ # CONFIG
+ env_config = EnvConfig(
+ dynamics_model="delta_local",
+ dx=torch.round(
+ torch.linspace(-6.0, 6.0, 20), decimals=3
+ ),
+ dy=torch.round(
+ torch.linspace(-6.0, 6.0, 20), decimals=3
+ ),
+ dyaw=torch.round(
+ torch.linspace(-np.pi, np.pi, 20), decimals=3
+ ),
+ )

# MAKE SB3-COMPATIBLE ENVIRONMENT
env = SB3MultiAgentEnv(
config=env_config,
scene_config=scene_config,
# Control up to all agents in the scene
max_cont_agents=env_config.max_num_agents_in_scene,
device=exp_config.device,
+ action_type=action_type
)

# SET MINIBATCH SIZE BASED ON ROLLOUT LENGTH
@@ -72,6 +88,7 @@ def train(env_config: EnvConfig, exp_config: ExperimentConfig, scene_config: Sce
custom_callback = MultiAgentCallback(
config=exp_config,
wandb_run=run if run_id is not None else None,
+ # wandb_run=None,
)

# INITIALIZE IPPO
@@ -104,12 +121,11 @@ def train(env_config: EnvConfig, exp_config: ExperimentConfig, scene_config: Sce
callback=custom_callback,
)

- run.finish()
+ # run.finish()
env.close()


if __name__ == "__main__":

exp_config = pyrallis.parse(config_class=ExperimentConfig)

env_config = EnvConfig(
@@ -132,4 +148,4 @@ def train(env_config: EnvConfig, exp_config: ExperimentConfig, scene_config: Sce
k_unique_scenes=exp_config.k_unique_scenes,
)

- train(env_config, exp_config, scene_config, action_type="discrete")
+ train(exp_config, scene_config, action_type="multi_discrete")
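The EnvConfig built inside train() discretizes each delta_local component into a 20-value grid, so a multi-discrete action becomes a triple of bin indices over a 20 x 20 x 20 space. A sketch of that mapping; the MultiDiscrete construction and the lookup illustrate the idea and are not code taken from the env:

```python
import gym
import numpy as np
import torch

dx = torch.round(torch.linspace(-6.0, 6.0, 20), decimals=3)
dy = torch.round(torch.linspace(-6.0, 6.0, 20), decimals=3)
dyaw = torch.round(torch.linspace(-np.pi, np.pi, 20), decimals=3)

# One bin count per component -> MultiDiscrete([20, 20, 20])
action_space = gym.spaces.MultiDiscrete([len(dx), len(dy), len(dyaw)])

# Sampling gives a triple of indices; indexing the grids recovers the continuous
# deltas that the delta_local dynamics would apply.
idx = action_space.sample()
delta = (dx[idx[0]].item(), dy[idx[1]].item(), dyaw[idx[2]].item())
print(idx, delta)
```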
1 change: 0 additions & 1 deletion pygpudrive/env/env_torch.py
@@ -37,7 +37,6 @@ def __init__(

# Initialize simulator with parameters
self.sim = self._initialize_simulator(params, scene_config)

# Controlled agents setup
self.cont_agent_mask = self.get_controlled_agents_mask()
self.max_agent_count = self.cont_agent_mask.shape[1]
17 changes: 13 additions & 4 deletions pygpudrive/env/wrappers/sb3_wrapper.py
@@ -31,6 +31,7 @@ def __init__(
scene_config,
max_cont_agents,
device,
+ action_type,
render_mode="rgb_array",
):
kwargs={
@@ -47,7 +48,8 @@ def __init__(
self.num_envs = self._env.cont_agent_mask.sum().item()
self.device = device
self.controlled_agent_mask = self._env.cont_agent_mask.clone()
- self.action_space = gym.spaces.Discrete(self._env.action_space.n)
+ self.action_space = self._env.action_space
+ print(f'wrapper action space {self.action_space} {action_type}')
self.observation_space = gym.spaces.Box(
-np.inf, np.inf, self._env.observation_space.shape, np.float32
)
@@ -57,9 +59,7 @@ def __init__(
self.agent_step = torch.zeros(
(self.num_worlds, self.max_agent_count)
).to(self.device)
- self.actions_tensor = torch.zeros(
- (self.num_worlds, self.max_agent_count)
- ).to(self.device)
+ self._set_action_tensor(action_type, config.dynamics_model)
# Storage: Fill buffer with nan values
self.buf_rews = torch.full(
(self.num_worlds, self.max_agent_count), fill_value=float("nan")
@@ -73,11 +73,20 @@ def __init__(
).to(self.device)

self.num_episodes = 0
+ self.info_dict = {
+ "off_road": 0,
+ "veh_collisions": 0,
+ "non_veh_collision": 0,
+ "goal_achieved": 0,
+ }

def _reset_seeds(self) -> None:
"""Reset all environments' seeds."""
self._seeds = None

+ def _set_action_tensor(self, action_type, dynamics_model):
+ pass

def reset(self, world_idx=None, seed=None):
"""Reset environment and return initial observations.
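The wrapper now forwards the env's own action space and defers the per-step action buffer to _set_action_tensor, which this commit leaves as a stub (pass). A hypothetical sketch of what it might grow into, based on the tensor allocation it replaces; this is an assumption, not the commit's implementation:

```python
import torch


def _set_action_tensor(self, action_type, dynamics_model):
    """Allocate self.actions_tensor with a component axis for multi-discrete actions
    (hypothetical sketch; the commit currently leaves this method as `pass`)."""
    if action_type == "multi_discrete" and dynamics_model == "delta_local":
        # One bin index per component (dx, dy, dyaw)
        shape = (self.num_worlds, self.max_agent_count, 3)
    else:
        # Single discrete action index per agent, as before this commit
        shape = (self.num_worlds, self.max_agent_count)
    self.actions_tensor = torch.zeros(shape).to(self.device)
```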
