From c20a363ad6017f66dd67770b8aa214f198015205 Mon Sep 17 00:00:00 2001
From: Hyeonchang Jeon
Date: Sat, 12 Oct 2024 04:34:29 +0900
Subject: [PATCH] Make action space dynamic & Train IPPO with multi-discrete + deltaLocal (#13)

* Ongoing: implement RL with multi-discrete

* [Ongoing] Debug

* Debug for training 1 scene

* Run multi-discrete space

* Change action space dynamically

---
 algorithms/sb3/callbacks.py      | 8 +++++++-
 algorithms/sb3/ppo/ippo.py       | 4 +++-
 algorithms/sb3/rollout_buffer.py | 3 ++-
 baselines/ippo/run_sb3_ppo.py    | 4 ++--
 4 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/algorithms/sb3/callbacks.py b/algorithms/sb3/callbacks.py
index 46f4d91e..246fc33f 100755
--- a/algorithms/sb3/callbacks.py
+++ b/algorithms/sb3/callbacks.py
@@ -5,6 +5,8 @@
 import wandb
 from stable_baselines3.common.callbacks import BaseCallback
 from time import perf_counter
+import gymnasium as gym
+EPISODE_LENGTH = 91
 
 
 class MultiAgentCallback(BaseCallback):
@@ -177,8 +179,12 @@ def _create_and_log_video(
         """Make a video and log to wandb."""
         policy = self.model
         base_env = self.locals["env"]._env
+        if isinstance(base_env.action_space, gym.spaces.Discrete):
+            action_size = (base_env.num_worlds, base_env.max_agent_count)
+        else:
+            action_size = (base_env.num_worlds, base_env.max_agent_count, 3)
         action_tensor = torch.zeros(
-            (base_env.num_worlds, base_env.max_agent_count, 3) # todo: fix the dim
+            action_size
         )
 
         obs = base_env.reset()
diff --git a/algorithms/sb3/ppo/ippo.py b/algorithms/sb3/ppo/ippo.py
index 323a8a84..4f8872e1 100755
--- a/algorithms/sb3/ppo/ippo.py
+++ b/algorithms/sb3/ppo/ippo.py
@@ -10,6 +10,7 @@
 from stable_baselines3.common.utils import get_schedule_fn
 from stable_baselines3.common.vec_env import VecEnv
 from torch import nn
+import gymnasium as gym
 
 # Import masked rollout buffer class
 from algorithms.sb3.rollout_buffer import MaskedRolloutBuffer
@@ -94,8 +95,9 @@ def collect_rollouts(
 
             # EDIT_1: Mask out invalid observations (NaN axes and/or dead agents)
             # Create dummy actions, values and log_probs (NaN)
+            action_size = (self.n_envs,) if isinstance(env.action_space, gym.spaces.Discrete) else (self.n_envs, 3)
             actions = torch.full(
-                fill_value=float("nan"), size=(self.n_envs, 3) # todo: should change based on action space
+                fill_value=float("nan"), size=action_size  # todo: should change based on action space
             ).to(self.device)
             log_probs = torch.full(
                 fill_value=float("nan"),
diff --git a/algorithms/sb3/rollout_buffer.py b/algorithms/sb3/rollout_buffer.py
index 8f8c903d..778b8d50 100755
--- a/algorithms/sb3/rollout_buffer.py
+++ b/algorithms/sb3/rollout_buffer.py
@@ -214,8 +214,9 @@ def get(
 
             # Flatten data
             # EDIT_5: And mask out invalid samples
+            tensors_ = ["observations"] if len(self.actions.shape) == 3 else ["observations", "actions"]
             for tensor in _tensor_names:
-                if tensor in ["observations", "actions"]:
+                if tensor in tensors_:
                     self.__dict__[tensor] = self.swap_and_flatten(
                         self.__dict__[tensor]
                     )[self.valid_samples_mask.flatten(), :]
diff --git a/baselines/ippo/run_sb3_ppo.py b/baselines/ippo/run_sb3_ppo.py
index 93a31beb..6d385bca 100755
--- a/baselines/ippo/run_sb3_ppo.py
+++ b/baselines/ippo/run_sb3_ppo.py
@@ -39,7 +39,7 @@ def train(exp_config: ExperimentConfig, scene_config: SceneConfig, action_type:
 
     # CONFIG
    env_config = EnvConfig(
-        dynamics_model="delta_local",
+        dynamics_model="classic",
         dx=torch.round(
             torch.linspace(-6.0, 6.0, 20), decimals=3
         ),
@@ -148,4 +148,4 @@ def train(exp_config: ExperimentConfig, scene_config: SceneConfig, action_type:
         k_unique_scenes=exp_config.k_unique_scenes,
     )
 
-    train(exp_config, scene_config, action_type="multi_discrete")
+    train(exp_config, scene_config, action_type="discrete")
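
For reference, the shape-selection pattern that the callbacks.py and ippo.py hunks introduce can be summarized as a small standalone sketch. The helper name make_action_tensor, the Box space, and the example dimensions below are illustrative assumptions, not code from this repository:

import gymnasium as gym
import torch


# Illustrative helper (hypothetical, not part of the repo):
def make_action_tensor(action_space, *leading_dims, fill_value=float("nan")):
    """Choose the action tensor shape from the env's action space.

    Discrete spaces take one scalar action per agent, while the other
    spaces used here (multi-discrete / delta_local) expect 3 components.
    leading_dims stands in for (num_worlds, max_agent_count) or (n_envs,).
    """
    if isinstance(action_space, gym.spaces.Discrete):
        size = leading_dims
    else:
        size = (*leading_dims, 3)
    return torch.full(size, fill_value)


# Hypothetical usage mirroring the two call sites:
video_actions = make_action_tensor(gym.spaces.Discrete(400), 4, 128, fill_value=0.0)
dummy_actions = make_action_tensor(gym.spaces.Box(-1.0, 1.0, shape=(3,)), 256)
print(video_actions.shape)  # torch.Size([4, 128])
print(dummy_actions.shape)  # torch.Size([256, 3])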