Improvements, env cleanup and scripts #108

Merged · 23 commits · May 13, 2024
91 changes: 58 additions & 33 deletions algorithms/ppo/sb3/callbacks.py
@@ -16,6 +16,7 @@ def __init__(
super().__init__(**kwargs)
self.config = config
self.wandb_run = wandb_run
self.num_rollouts = 0

def _on_training_start(self) -> None:
"""
@@ -42,77 +43,101 @@ def _on_rollout_end(self) -> None:
Triggered before updating the policy.
"""

# Get the total number of controlled agents we are controlling
# The number of controllable agents is different per scenario
num_episodes_in_rollout = self.locals["env"].num_episodes
num_controlled_agents = self.locals["env"]._tot_controlled_valid_agents

# Filter out all nans
rewards = np.nan_to_num(
rollout_rewards = np.nan_to_num(
(self.locals["rollout_buffer"].rewards.cpu().detach().numpy()),
nan=0,
)

# Get the total number of controlled agents we are controlling
# The number of controllable agents is different per scenario
num_controlled_agents = self.locals[
"env"
]._get_sum_controlled_valid_agents

print(f"num_controlled_agents: {num_controlled_agents}")

# Number of episodes in the rollout
num_episodes_in_rollout = (
np.nan_to_num(
(
self.locals["rollout_buffer"]
.episode_starts.cpu()
.detach()
.numpy()
),
nan=0,
).sum()
/ num_controlled_agents
mean_reward_per_agent_per_episode = rollout_rewards.sum() / (
num_episodes_in_rollout * num_controlled_agents
)

mean_reward_per_agent_per_episode = (
rewards.sum() / num_episodes_in_rollout / num_controlled_agents
)

observations = (
rollout_observations = (
self.locals["rollout_buffer"].observations.cpu().detach().numpy()
)

# Average info across agents and episodes
rollout_info = self.locals["env"].infos
for key, value in rollout_info.items():
self.locals["env"].infos[key] = value / (
num_episodes_in_rollout * num_controlled_agents
)
self.logger.record(f"rollout/{key}", self.locals["env"].infos[key])

# Log
self.logger.record("rollout/global_step", self.num_timesteps)
self.logger.record(
"rollout/num_episodes_in_rollout",
num_episodes_in_rollout.item(),
num_episodes_in_rollout,
)
self.logger.record("rollout/sum_reward", rewards.sum())
self.logger.record("rollout/sum_reward", rollout_rewards.sum())
self.logger.record(
"rollout/avg_reward", mean_reward_per_agent_per_episode.item()
)
self.logger.record("rollout/obs_max", observations.max())
self.logger.record("rollout/obs_min", observations.min())
self.logger.record("rollout/obs_max", rollout_observations.max())
self.logger.record("rollout/obs_min", rollout_observations.min())

# Render the environment
if self.config.render:
self._create_and_log_video()
if self.num_rollouts % self.config.render_freq == 0:
self._create_and_log_video()

self.num_rollouts += 1

def _create_and_log_video(self):
def _batchify_and_filter_obs(self, obs, env, render_world_idx=0):
# Unsqueeze
obs = obs.reshape((env.num_worlds, env.max_agent_count, -1))

# Only select obs for the render env
obs = obs[render_world_idx, :, :]

return obs[env.controlled_agent_mask[render_world_idx, :]]

def _pad_actions(self, pred_actions, env, render_world_idx):
"""Currently we're only rendering the 0th world index."""

actions = torch.full(
(env.num_worlds, env.max_agent_count), fill_value=float("nan")
).to("cpu")

world_cont_agent_mask = env.controlled_agent_mask[
render_world_idx, :
].to("cpu")

actions[render_world_idx, :][world_cont_agent_mask] = torch.Tensor(
pred_actions
).to("cpu")
return actions

def _create_and_log_video(self, render_world_idx=0):
"""Make a video and log to wandb.
Note: Currently only works for a single world."""
policy = self.model
env = self.locals["env"]

obs = env.reset()
obs = self._batchify_and_filter_obs(obs, env)

frames = []

for _ in range(90):

action, _ = policy.predict(obs.detach().cpu().numpy())
action = self._pad_actions(action, env, render_world_idx)

# Step the environment
obs, _, _, _ = env.step(action)
obs = self._batchify_and_filter_obs(obs, env)

frame = env.render()
frames.append(frame.T)

frames = np.array(frames)

wandb.log({"video": wandb.Video(frames, fps=5, format="gif")})
wandb.log({"video": wandb.Video(frames, fps=10, format="gif")})
6 changes: 4 additions & 2 deletions algorithms/ppo/sb3/mappo.py → algorithms/ppo/sb3/ippo.py
@@ -34,7 +34,7 @@ def explained_variance(
return torch.nan if var_y == 0 else 1 - torch.var(y_true - y_pred) / var_y


class MAPPO(PPO):
class IPPO(PPO):
"""Adapted Proximal Policy Optimization algorithm (PPO) that is compatible with multi-agent environments."""

def __init__(
@@ -182,9 +182,11 @@ def collect_rollouts(
)

callback.update_locals(locals())

callback.on_rollout_end()

# Reset logger info (num_episodes and infos)
env._reset_rollout_loggers()

return True

def _setup_model(self) -> None:
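
The hunk above only shows the return line of the adapted `explained_variance` helper. A self-contained sketch of what such a helper looks like, under the assumption that it receives flat tensors of value predictions and empirical returns (the function body beyond the return line is not part of this diff):

```python
import torch

def explained_variance(y_pred: torch.Tensor, y_true: torch.Tensor):
    """Fraction of the variance in the empirical returns that the value
    function explains; NaN when the returns have zero variance."""
    var_y = torch.var(y_true)
    return torch.nan if var_y == 0 else 1 - torch.var(y_true - y_pred) / var_y

values = torch.tensor([0.9, 2.1, 2.9, 4.2])    # value predictions
returns = torch.tensor([1.0, 2.0, 3.0, 4.0])   # empirical returns
print(explained_variance(values, returns))     # ~0.99 for a good value fit
```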
33 changes: 18 additions & 15 deletions algorithms/ppo/sb3/rollout_buffer.py
@@ -29,6 +29,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
device: Union[torch.device, str] = "auto",
storage_device: Union[torch.device, str] = "cpu", #TODO(ev) add storage device to config
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
@@ -39,48 +40,49 @@
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.storage_device = storage_device
self.reset()

def reset(self) -> None:
"""Reset the buffer."""
self.observations = torch.zeros(
(self.buffer_size, self.n_envs, *self.obs_shape),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.actions = torch.zeros(
(self.buffer_size, self.n_envs, self.action_dim),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.rewards = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.returns = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.episode_starts = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.values = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.log_probs = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.advantages = torch.zeros(
(self.buffer_size, self.n_envs),
device=self.device,
device=self.storage_device,
dtype=torch.float32,
)
self.generator_ready = False
@@ -110,12 +112,12 @@ def add(
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.observations[self.pos] = obs
self.actions[self.pos] = action
self.rewards[self.pos] = reward
self.episode_starts[self.pos] = episode_start
self.values[self.pos] = value.flatten()
self.log_probs[self.pos] = log_prob.clone()
self.observations[self.pos] = obs.to(self.storage_device)
self.actions[self.pos] = action.to(self.storage_device)
self.rewards[self.pos] = reward.to(self.storage_device)
self.episode_starts[self.pos] = episode_start.to(self.storage_device)
self.values[self.pos] = value.flatten().to(self.storage_device)
self.log_probs[self.pos] = log_prob.clone().to(self.storage_device)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
@@ -125,7 +127,8 @@
) -> None:
"""GAE (General Advantage Estimation) to compute advantages and returns."""
# Convert to numpy
last_values = last_values.clone().flatten()
last_values = last_values.clone().flatten().to(self.storage_device)
dones = dones.clone().flatten().to(self.storage_device)

last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
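
The change to the rollout buffer keeps the large rollout tensors on a separate `storage_device` (CPU by default) and moves incoming, possibly GPU-resident tensors over inside `add()`. A stripped-down sketch of the pattern, with a hypothetical `TinyBuffer` standing in for the SB3 `RolloutBuffer` subclass:

```python
import torch

class TinyBuffer:
    """Minimal sketch: allocate storage on `storage_device`, convert on add()."""

    def __init__(self, buffer_size: int, n_envs: int, obs_dim: int,
                 storage_device: str = "cpu"):
        self.storage_device = storage_device
        self.pos = 0
        self.observations = torch.zeros(
            (buffer_size, n_envs, obs_dim),
            device=storage_device, dtype=torch.float32,
        )
        self.rewards = torch.zeros(
            (buffer_size, n_envs), device=storage_device, dtype=torch.float32
        )

    def add(self, obs: torch.Tensor, reward: torch.Tensor) -> None:
        # .to() is a no-op when the tensor already lives on storage_device,
        # and a device copy when it arrives from the GPU.
        self.observations[self.pos] = obs.to(self.storage_device)
        self.rewards[self.pos] = reward.to(self.storage_device)
        self.pos += 1

buf = TinyBuffer(buffer_size=4, n_envs=2, obs_dim=3)
buf.add(torch.randn(2, 3), torch.zeros(2))  # inputs may come from "cuda"
```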
26 changes: 15 additions & 11 deletions baselines/config.py
@@ -1,23 +1,27 @@

from networks.basic_ffn import FeedForwardPolicy
from dataclasses import dataclass
import torch


@dataclass
class ExperimentConfig:
"""
Configurations for experiments.
"""
"""Configurations for experiments."""

# General
device: str = "cuda"

# Rendering options
render: bool = False
render_mode: str = "rgb_array"
render_freq: int = 1
render_freq: int = 10

# TODO: Logging
log_dir: str = "logs"

# Hyperparameters
policy: str = "MlpPolicy"
policy: torch.nn.Module = FeedForwardPolicy
seed: int = 42
n_steps: int = 2048
batch_size: int = 256
verbose: int = 0
n_steps: int = 180
batch_size: int = 180
verbose: int = 0
total_timesteps: int = 50_000_000
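
With the config change above, `policy` now holds the network class itself (`FeedForwardPolicy`) rather than the `"MlpPolicy"` string, and videos are only logged every `render_freq` rollouts. A hypothetical instantiation mirroring how the defaults are overridden in the run script below:

```python
from baselines.config import ExperimentConfig

# Only `render` is overridden; the other fields keep the dataclass defaults.
exp_config = ExperimentConfig(
    render=True,  # turn on video logging in the callback
)

print(exp_config.policy)       # <class 'networks.basic_ffn.FeedForwardPolicy'>
print(exp_config.render_freq)  # 10 -> a video every 10th rollout
print(exp_config.n_steps, exp_config.batch_size)  # 180 180
```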
28 changes: 17 additions & 11 deletions baselines/run_ppo_sb3.py
@@ -1,4 +1,6 @@
import wandb
import torch
import torch

# Import the EnvConfig dataclass
from pygpudrive.env.config import EnvConfig
@@ -9,35 +11,38 @@
from algorithms.ppo.sb3.callbacks import MultiAgentCallback

# Import adapted PPO version
from algorithms.ppo.sb3.mappo import MAPPO
from algorithms.ppo.sb3.ippo import IPPO

from baselines.config import ExperimentConfig

torch.cuda.empty_cache()

if __name__ == "__main__":

env_config = EnvConfig(
ego_state=True,
road_map_obs=False,
road_map_obs=True,
partner_obs=True,
normalize_obs=False,
norm_obs=False,
sample_method="rand_n",
)

exp_config = ExperimentConfig(
render=False,
render=True,
)

# Make SB3-compatible environment
env = SB3MultiAgentEnv(
config=env_config,
num_worlds=2,
max_cont_agents=10,
data_dir="waymo_data",
device="cuda",
max_cont_agents=128,
data_dir="formatted_json_v2_no_tl_train",
device=exp_config.device,
)

run = wandb.init(
project="rl_benchmarking",
group="different_scenes",
project="rl_bench",
group="render_test",
sync_tensorboard=True,
)
run_id = run.id
@@ -48,21 +53,22 @@
wandb_run=run if run_id is not None else None,
)

model = MAPPO(
model = IPPO(
policy=exp_config.policy,
n_steps=exp_config.n_steps,
batch_size=exp_config.batch_size,
env=env,
seed=exp_config.seed,
verbose=exp_config.verbose,
device=exp_config.device,
tensorboard_log=f"runs/{run_id}"
if run_id is not None
else None, # Sync with wandb
)

# Learn
model.learn(
total_timesteps=10_000_000,
total_timesteps=exp_config.total_timesteps,
callback=custom_callback,
)
