Skip to content

Commit

Permalink
Update: apply action (discrete, multi-discrete)
Browse files Browse the repository at this point in the history
  • Loading branch information
KKGB committed Oct 23, 2024
1 parent e03d5c2 commit 573276e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
4 changes: 2 additions & 2 deletions baselines/ippo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class ExperimentConfig:
data_dir: str = "/data/formatted_json_v2_no_tl_train/" #todo: to be changed

# NUM PARALLEL ENVIRONMENTS & DEVICE
num_worlds: int = 1 # Number of parallel environments
num_worlds: int = 50 # Number of parallel environments
# How to select scenes from the dataset
selection_discipline = SelectionDiscipline.K_UNIQUE_N # K_UNIQUE_N / PAD_N
selection_discipline = SelectionDiscipline.PAD_N # K_UNIQUE_N / PAD_N
k_unique_scenes: int = 3
device: str = "cuda" # or "cpu"

Expand Down
8 changes: 6 additions & 2 deletions pygpudrive/env/env_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def _set_discrete_action_space(self) -> None:
def step_dynamics(self, actions, use_indices=True):
if actions is not None:
if use_indices:
actions = actions.squeeze(dim=2).long().to(self.device) if actions.dim() == 3 else actions.long().to(self.device)
actions = (
torch.nan_to_num(actions, nan=0).long().to(self.device)
)
action_value_tensor = self.action_keys_tensor[actions]
else:
action_value_tensor = torch.nan_to_num(actions, nan=0).float().to(self.device)
Expand Down Expand Up @@ -618,7 +620,9 @@ def _set_multi_discrete_action_space(self) -> None:
def step_dynamics(self, actions, use_indices=True):
if actions is not None:
if use_indices:
actions = actions.squeeze(dim=3).long().to(self.device) if actions.dim() == 4 else actions.long().to(self.device)
actions = (
torch.nan_to_num(actions, nan=0).long().to(self.device)
)
action_value_tensor = self.action_keys_tensor[actions[...,0], actions[...,1], actions[...,2]]
else:
action_value_tensor = torch.nan_to_num(actions, nan=0).float().to(self.device)
Expand Down
19 changes: 18 additions & 1 deletion pygpudrive/env/wrappers/sb3_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,24 @@ def _reset_seeds(self) -> None:
self._seeds = None

def _set_action_tensor(self, action_type, dynamics_model):
    """Allocate the zero-filled action buffer on ``self.actions_tensor``.

    The buffer always has shape ``(num_worlds, max_agent_count)`` plus an
    optional trailing dimension chosen by the requested combination:

    * ``'discrete'`` (any dynamics model): no trailing dimension, i.e. one
      value per agent.
    * ``'continuous'`` with ``'bicycle'``: trailing dimension of 2.
    * ``'continuous'`` or ``'multi_discrete'`` with ``'delta_local'``:
      trailing dimension of 3.

    Args:
        action_type: One of ``'discrete'``, ``'continuous'`` or
            ``'multi_discrete'``.
        dynamics_model: ``'bicycle'`` or ``'delta_local'``. Not inspected
            when ``action_type`` is ``'discrete'``.

    Raises:
        NotImplementedError: If the (action_type, dynamics_model) pair is
            not one of the combinations listed above.
    """
    # NOTE(review): the trailing sizes presumably match the number of action
    # components each dynamics model consumes -- confirm against the env.
    if action_type == 'discrete':
        trailing_dims = ()
    elif action_type == 'continuous' and dynamics_model == 'bicycle':
        trailing_dims = (2,)
    elif (
        action_type in ('continuous', 'multi_discrete')
        and dynamics_model == 'delta_local'
    ):
        trailing_dims = (3,)
    else:
        raise NotImplementedError(f"Set action_tensors error: ({action_type}, {dynamics_model}) pairs are not supported.")

    # Allocate directly on the target device; creating the tensor on the
    # default device and then calling ``.to(...)`` would add a needless copy.
    self.actions_tensor = torch.zeros(
        (self.num_worlds, self.max_agent_count, *trailing_dims),
        device=self.device,
    )

def reset(self, world_idx=None, seed=None):
"""Reset environment and return initial observations.
Expand Down

0 comments on commit 573276e

Please sign in to comment.