diff --git a/agent_zoo/neurips23_start_kit/reward_wrapper.py b/agent_zoo/neurips23_start_kit/reward_wrapper.py
index 02b866c..a7aa2f3 100644
--- a/agent_zoo/neurips23_start_kit/reward_wrapper.py
+++ b/agent_zoo/neurips23_start_kit/reward_wrapper.py
@@ -9,13 +9,15 @@ def __init__(
         eval_mode=False,
         early_stop_agent_num=0,
         stat_prefix=None,
+        use_custom_reward=True,
         # Custom reward wrapper args
         heal_bonus_weight=0,
         explore_bonus_weight=0,
         clip_unique_event=3,
     ):
-        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
+        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
         self.stat_prefix = stat_prefix
+
         self.heal_bonus_weight = heal_bonus_weight
         self.explore_bonus_weight = explore_bonus_weight
         self.clip_unique_event = clip_unique_event
@@ -77,8 +79,4 @@ def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncat
 
         reward += healing_bonus + explore_bonus
 
-        # NOTE: Disable death reward. This is temporarly for task-conditioning exp.
-        # if terminated is True:
-        #     reward = 0
-
         return reward, terminated, truncated, info
diff --git a/agent_zoo/takeru/reward_wrapper.py b/agent_zoo/takeru/reward_wrapper.py
index cbd4e47..175fd5b 100644
--- a/agent_zoo/takeru/reward_wrapper.py
+++ b/agent_zoo/takeru/reward_wrapper.py
@@ -9,13 +9,15 @@ def __init__(
         eval_mode=False,
         early_stop_agent_num=0,
         stat_prefix=None,
+        use_custom_reward=True,
         # Custom reward wrapper args
         explore_bonus_weight=0,
         clip_unique_event=3,
         disable_give=True,
     ):
-        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
+        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
         self.stat_prefix = stat_prefix
+
         self.explore_bonus_weight = explore_bonus_weight
         self.clip_unique_event = clip_unique_event
         self.disable_give = disable_give
diff --git a/agent_zoo/yaofeng/reward_wrapper.py b/agent_zoo/yaofeng/reward_wrapper.py
index 7ff2d07..93bb46b 100644
--- a/agent_zoo/yaofeng/reward_wrapper.py
+++ b/agent_zoo/yaofeng/reward_wrapper.py
@@ -16,6 +16,7 @@ def __init__(
         eval_mode=False,
         early_stop_agent_num=0,
         stat_prefix=None,
+        use_custom_reward=True,
         # Custom reward wrapper args
         hp_bonus_weight=0,
         exp_bonus_weight=0,
@@ -27,7 +28,7 @@ def __init__(
         disable_give=True,
         donot_attack_dangerous_npc=True,
     ):
-        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix)
+        super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward)
         self.stat_prefix = stat_prefix
 
         self.default_spawn_immunity = env.config.COMBAT_SPAWN_IMMUNITY
diff --git a/config.yaml b/config.yaml
index cb0e529..019a705 100644
--- a/config.yaml
+++ b/config.yaml
@@ -98,6 +98,7 @@ recurrent:
 reward_wrapper:
   eval_mode: False
   early_stop_agent_num: 8
+  use_custom_reward: True
 
 neurips23_start_kit:
   reward_wrapper:
diff --git a/reinforcement_learning/stat_wrapper.py b/reinforcement_learning/stat_wrapper.py
index c018e44..a184f9e 100644
--- a/reinforcement_learning/stat_wrapper.py
+++ b/reinforcement_learning/stat_wrapper.py
@@ -13,6 +13,7 @@ def __init__(
         eval_mode=False,
         early_stop_agent_num=0,
         stat_prefix=None,
+        use_custom_reward=True,
     ):
         super().__init__(env)
         self.env_done = False
@@ -20,6 +21,7 @@ def __init__(
         self.eval_mode = eval_mode
         self._reset_episode_stats()
         self._stat_prefix = stat_prefix
+        self.use_custom_reward = use_custom_reward
 
     def seed(self, seed):
         self.env.seed(seed)
@@ -72,9 +74,16 @@ def step(self, action):
             trunc, info = self._process_stats_and_early_stop(
                 agent_id, rewards[agent_id], terms[agent_id], truncs[agent_id], infos[agent_id]
             )
-            rew, term, trunc, info = self.reward_terminated_truncated_info(
-                agent_id, rewards[agent_id], terms[agent_id], trunc, info
-            )
+
+            if self.use_custom_reward is True:
+                rew, term, trunc, info = self.reward_terminated_truncated_info(
+                    agent_id, rewards[agent_id], terms[agent_id], trunc, info
+                )
+            else:
+                # NOTE: Also disable death penalty, which is not from the task
+                rew = 0 if terms[agent_id] is True else rewards[agent_id]
+                term = terms[agent_id]
+
             rewards[agent_id] = rew
             terms[agent_id] = term
             truncs[agent_id] = trunc
diff --git a/train.py b/train.py
index 4cbc46b..11ed288 100644
--- a/train.py
+++ b/train.py
@@ -217,6 +217,9 @@ def update_args(args, mode=None):
     # Make default or syllabus-based env_creator
     syllabus = None
     if args.syllabus is True:
+        # NOTE: Setting use_custom_reward to False will ignore the agent's custom reward
+        # and only use the env-provided reward from the curriculum tasks
+        args.reward_wrapper.use_custom_reward = False
         syllabus, env_creator = syllabus_wrapper.make_syllabus_env_creator(args, agent_module)
     else:
         args.env.curriculum_file_path = args.curriculum
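
In effect, the new `use_custom_reward` flag selects, per agent and per step, between the agent's shaped reward and the raw env-provided task reward. The sketch below distills that selection; `select_reward` and its literal inputs are hypothetical illustrations, not code from this diff, and only the branch semantics mirror the change to `stat_wrapper.py` above.

```python
# Hypothetical distillation of the reward selection added in stat_wrapper.py.
def select_reward(use_custom_reward, env_reward, terminated, custom_reward_fn):
    """Pick the per-agent reward the wrapper emits for one step."""
    if use_custom_reward:
        # Agent-specific shaping (heal/explore bonuses, etc.) stays active.
        return custom_reward_fn(env_reward, terminated)
    # Curriculum mode: pass the task reward through unchanged, and zero out
    # the death penalty because it does not come from the task itself.
    return 0 if terminated else env_reward

# Example with a made-up shaping function that adds a flat bonus.
shaped = lambda r, t: r + 0.5
assert select_reward(True, 1.0, False, shaped) == 1.5   # custom reward active
assert select_reward(False, 1.0, False, shaped) == 1.0  # raw task reward only
assert select_reward(False, 1.0, True, shaped) == 0     # death penalty dropped
```

With `args.syllabus` enabled, `train.py` now forces `args.reward_wrapper.use_custom_reward = False`, so the learner is trained on the curriculum task reward alone.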