Make action space dynamic & Train IPPO with multi-discrete + deltaLocal (#13)

* Ongoing: implement RL with multi-discrete

* [Ongoing] Debug

* Debug for training 1 scene

* Run multi-discrete space

* Change action space dynamically
KevinJeon authored and KKGB committed Oct 22, 2024
1 parent 2a544a4 commit c20a363
Showing 4 changed files with 14 additions and 5 deletions.
8 changes: 7 additions & 1 deletion algorithms/sb3/callbacks.py
@@ -5,6 +5,8 @@
 import wandb
 from stable_baselines3.common.callbacks import BaseCallback
 from time import perf_counter
+import gymnasium as gym
+EPISODE_LENGTH = 91


 class MultiAgentCallback(BaseCallback):
@@ -177,8 +179,12 @@ def _create_and_log_video(
         """Make a video and log to wandb."""
         policy = self.model
         base_env = self.locals["env"]._env
+        if isinstance(base_env.action_space, gym.spaces.Discrete):
+            action_size = (base_env.num_worlds, base_env.max_agent_count)
+        else:
+            action_size = (base_env.num_worlds, base_env.max_agent_count, 3)
         action_tensor = torch.zeros(
-            (base_env.num_worlds, base_env.max_agent_count, 3)  # todo: fix the dim
+            action_size
         )

         obs = base_env.reset()
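Note: the branch above sizes the dummy action tensor per action space — one flat index per agent under Discrete, a 3-component action under MultiDiscrete (plausibly dx/dy/dyaw under delta_local). A minimal standalone sketch of the same shape logic, with hypothetical world/agent counts standing in for base_env.num_worlds and base_env.max_agent_count:

import gymnasium as gym
import torch

num_worlds, max_agent_count = 2, 4  # hypothetical stand-ins for the env attributes

for space in (gym.spaces.Discrete(20), gym.spaces.MultiDiscrete([20, 20, 20])):
    if isinstance(space, gym.spaces.Discrete):
        action_size = (num_worlds, max_agent_count)     # one index per agent
    else:
        action_size = (num_worlds, max_agent_count, 3)  # one index per action component
    action_tensor = torch.zeros(action_size)
    print(type(space).__name__, tuple(action_tensor.shape))
# Discrete (2, 4)
# MultiDiscrete (2, 4, 3)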
4 changes: 3 additions & 1 deletion algorithms/sb3/ppo/ippo.py
@@ -10,6 +10,7 @@
 from stable_baselines3.common.utils import get_schedule_fn
 from stable_baselines3.common.vec_env import VecEnv
 from torch import nn
+import gymnasium as gym

 # Import masked rollout buffer class
 from algorithms.sb3.rollout_buffer import MaskedRolloutBuffer
@@ -94,8 +95,9 @@ def collect_rollouts(

         # EDIT_1: Mask out invalid observations (NaN axes and/or dead agents)
         # Create dummy actions, values and log_probs (NaN)
+        action_size = (self.n_envs,) if isinstance(env.action_space, gym.spaces.Discrete) else (self.n_envs, 3)
         actions = torch.full(
-            fill_value=float("nan"), size=(self.n_envs, 3)  # todo: should change based on action space
+            fill_value=float("nan"), size=action_size,  # todo: should change based on action space
         ).to(self.device)
         log_probs = torch.full(
             fill_value=float("nan"),
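Note: the NaN fill makes agents that never received a real action easy to detect when masking rollouts. A minimal sketch of the idea, with a hypothetical n_envs and a MultiDiscrete space:

import gymnasium as gym
import torch

n_envs = 4  # hypothetical stand-in for self.n_envs
action_space = gym.spaces.MultiDiscrete([20, 20, 20])

# Dummy actions start as NaN; only live agents get real values written in.
action_size = (n_envs,) if isinstance(action_space, gym.spaces.Discrete) else (n_envs, 3)
actions = torch.full(size=action_size, fill_value=float("nan"))
actions[0] = torch.tensor([3.0, 17.0, 9.0])  # pretend env 0 acted
print(torch.isnan(actions).any(dim=-1))      # tensor([False,  True,  True,  True])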
3 changes: 2 additions & 1 deletion algorithms/sb3/rollout_buffer.py
@@ -214,8 +214,9 @@ def get(

         # Flatten data
         # EDIT_5: And mask out invalid samples
+        tensors_ = ["observations"] if len(self.actions.shape) == 3 else ["observations", "actions"]
         for tensor in _tensor_names:
-            if tensor in ["observations", "actions"]:
+            if tensor in tensors_:
                 self.__dict__[tensor] = self.swap_and_flatten(
                     self.__dict__[tensor]
                 )[self.valid_samples_mask.flatten(), :]
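Note: swap_and_flatten is the usual SB3 (buffer, env, ...) -> (buffer * env, ...) reshape, after which valid_samples_mask selects live rows; 3-dim (multi-discrete) actions are excluded from this 2-D masking path. A self-contained sketch under those assumptions, with hypothetical sizes:

import numpy as np

def swap_and_flatten(arr):
    # SB3-style reshape: (buffer_size, n_envs, ...) -> (buffer_size * n_envs, ...)
    shape = arr.shape
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

buffer_size, n_envs, obs_dim = 3, 2, 5  # hypothetical sizes
observations = np.random.randn(buffer_size, n_envs, obs_dim)
valid_samples_mask = np.array([[True, False], [True, True], [False, True]])

flat_mask = swap_and_flatten(valid_samples_mask)         # shape (6,)
flat_obs = swap_and_flatten(observations)[flat_mask, :]  # only the 4 valid rows survive
print(flat_obs.shape)  # (4, 5)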
4 changes: 2 additions & 2 deletions baselines/ippo/run_sb3_ppo.py
@@ -39,7 +39,7 @@ def train(exp_config: ExperimentConfig, scene_config: SceneConfig, action_type:

     # CONFIG
     env_config = EnvConfig(
-        dynamics_model="delta_local",
+        dynamics_model="classic",
         dx=torch.round(
             torch.linspace(-6.0, 6.0, 20), decimals=3
         ),
@@ -148,4 +148,4 @@ def train(exp_config: ExperimentConfig, scene_config: SceneConfig, action_type:
         k_unique_scenes=exp_config.k_unique_scenes,
     )

-    train(exp_config, scene_config, action_type="multi_discrete")
+    train(exp_config, scene_config, action_type="discrete")
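Note: the dx grid above snaps a continuous control axis to 20 evenly spaced bins, which is what makes a multi-discrete action space possible: one Discrete index per axis. A minimal sketch of decoding such an action (the dy/dyaw grids here are assumptions mirroring dx, not taken from the repo):

import gymnasium as gym
import torch

# 20 bins per axis, rounded to 3 decimals; dy/dyaw ranges are assumed for illustration.
dx = torch.round(torch.linspace(-6.0, 6.0, 20), decimals=3)
dy = torch.round(torch.linspace(-6.0, 6.0, 20), decimals=3)
dyaw = torch.round(torch.linspace(-torch.pi, torch.pi, 20), decimals=3)

# One discrete index per component -> a 3-dim multi-discrete action.
action_space = gym.spaces.MultiDiscrete([len(dx), len(dy), len(dyaw)])
idx = action_space.sample()  # e.g. array([3, 17, 9])
action = torch.stack([dx[idx[0]], dy[idx[1]], dyaw[idx[2]]])
print(action)  # continuous (dx, dy, dyaw) decoded from the discrete indices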
