Skip to content

Commit

Permalink
Added `use_custom_reward` flag
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Apr 18, 2024
1 parent 89859d5 commit 4ed5af1
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
8 changes: 3 additions & 5 deletions agent_zoo/neurips23_start_kit/reward_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ def __init__(
eval_mode=False,
early_stop_agent_num=0,
stat_prefix=None,
use_custom_reward=True,
# Custom reward wrapper args
heal_bonus_weight=0,
explore_bonus_weight=0,
clip_unique_event=3,
):
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
self.stat_prefix = stat_prefix

self.heal_bonus_weight = heal_bonus_weight
self.explore_bonus_weight = explore_bonus_weight
self.clip_unique_event = clip_unique_event
Expand Down Expand Up @@ -77,8 +79,4 @@ def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncat

reward += healing_bonus + explore_bonus

# NOTE: Disable death reward. This is temporarily for the task-conditioning experiment.
# if terminated is True:
# reward = 0

return reward, terminated, truncated, info
4 changes: 3 additions & 1 deletion agent_zoo/takeru/reward_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ def __init__(
eval_mode=False,
early_stop_agent_num=0,
stat_prefix=None,
use_custom_reward=True,
# Custom reward wrapper args
explore_bonus_weight=0,
clip_unique_event=3,
disable_give=True,
):
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
self.stat_prefix = stat_prefix

self.explore_bonus_weight = explore_bonus_weight
self.clip_unique_event = clip_unique_event
self.disable_give = disable_give
Expand Down
3 changes: 2 additions & 1 deletion agent_zoo/yaofeng/reward_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
eval_mode=False,
early_stop_agent_num=0,
stat_prefix=None,
use_custom_reward=True,
# Custom reward wrapper args
hp_bonus_weight=0,
exp_bonus_weight=0,
Expand All @@ -27,7 +28,7 @@ def __init__(
disable_give=True,
donot_attack_dangerous_npc=True,
):
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
self.stat_prefix = stat_prefix
self.default_spawn_immunity = env.config.COMBAT_SPAWN_IMMUNITY

Expand Down
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ recurrent:
reward_wrapper:
eval_mode: False
early_stop_agent_num: 8
use_custom_reward: True

neurips23_start_kit:
reward_wrapper:
Expand Down
15 changes: 12 additions & 3 deletions reinforcement_learning/stat_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ def __init__(
eval_mode=False,
early_stop_agent_num=0,
stat_prefix=None,
use_custom_reward=True,
):
super().__init__(env)
self.env_done = False
self.early_stop_agent_num = early_stop_agent_num
self.eval_mode = eval_mode
self._reset_episode_stats()
self._stat_prefix = stat_prefix
self.use_custom_reward = use_custom_reward

def seed(self, seed):
self.env.seed(seed)
Expand Down Expand Up @@ -72,9 +74,16 @@ def step(self, action):
trunc, info = self._process_stats_and_early_stop(
agent_id, rewards[agent_id], terms[agent_id], truncs[agent_id], infos[agent_id]
)
rew, term, trunc, info = self.reward_terminated_truncated_info(
agent_id, rewards[agent_id], terms[agent_id], trunc, info
)

if self.use_custom_reward is True:
rew, term, trunc, info = self.reward_terminated_truncated_info(
agent_id, rewards[agent_id], terms[agent_id], trunc, info
)
else:
# NOTE: Also disable death penalty, which is not from the task
rew = 0 if terms[agent_id] is True else rewards[agent_id]
term = terms[agent_id]

rewards[agent_id] = rew
terms[agent_id] = term
truncs[agent_id] = trunc
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def update_args(args, mode=None):
# Make default or syllabus-based env_creator
syllabus = None
if args.syllabus is True:
# NOTE: Setting use_custom_reward to False will ignore the agent's custom reward
# and only use the env-provided reward from the curriculum tasks
args.reward_wrapper.use_custom_reward = False
syllabus, env_creator = syllabus_wrapper.make_syllabus_env_creator(args, agent_module)
else:
args.env.curriculum_file_path = args.curriculum
Expand Down

0 comments on commit 4ed5af1

Please sign in to comment.