From 4daf1d39b7b1825966bda62b8d74ef89672866c5 Mon Sep 17 00:00:00 2001
From: Mikel
Date: Thu, 1 Aug 2024 10:54:47 +0200
Subject: [PATCH] ppo_train.py: Option to periodically save agent's model

---
 ppo_train.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/ppo_train.py b/ppo_train.py
index 1b2e62053..8c6214687 100644
--- a/ppo_train.py
+++ b/ppo_train.py
@@ -42,6 +42,11 @@ class Args:
     mt_wd: str = "./"
     """Directory where the Minetest working directories will be created (defaults to the current one)"""
     frameskip: int = 4
+    """Number of frames to skip between observations"""
+    save_agent: bool = False
+    """Save the agent's model (disabled by default)"""
+    save_num: int = 5
+    """Number of times to save the agent's model"""
 
     # Algorithm specific arguments
     env_id: str = "Craftium/ChopTree-v0"
@@ -102,14 +107,6 @@ def thunk():
         else:
             env = gym.make(env_id, **craftium_kwargs)
         env = gym.wrappers.RecordEpisodeStatistics(env)
-        # env = NoopResetEnv(env, noop_max=30)
-        # env = MaxAndSkipEnv(env, skip=4)
-        # env = EpisodicLifeEnv(env)
-        # if "FIRE" in env.unwrapped.get_action_meanings():
-        #     env = FireResetEnv(env)
-        # env = ClipRewardEnv(env)
-        # env = gym.wrappers.ResizeObservation(env, (84, 84))
-        # env = gym.wrappers.GrayScaleObservation(env)
         env = gym.wrappers.FrameStack(env, 4)
         return env
 
@@ -175,6 +172,10 @@ def get_action_and_value(self, x, action=None):
         "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
+    if args.save_agent:
+        agent_path = f"agents/{run_name}"
+        os.makedirs(agent_path, exist_ok=True)
+
     # TRY NOT TO MODIFY: seeding
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -335,5 +336,11 @@ def get_action_and_value(self, x, action=None):
         print("SPS:", int(global_step / (time.time() - start_time)))
         writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
+        if args.save_agent and \
+           (iteration % (args.num_iterations//args.save_num) == 0 \
+            or iteration == args.num_iterations):
+            print("Saving agent...")
+            torch.save(agent, f"{agent_path}/agent_step_{global_step}.pt")
+
     envs.close()
     writer.close()
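
Usage sketch (not part of the commit): loading one of the saved checkpoints.
The run name and step number in the path below are made up for illustration.
Because the patch calls torch.save(agent, ...), the whole module is pickled
rather than a state_dict, so the Agent class must be resolvable when the file
is loaded, and newer PyTorch releases additionally require weights_only=False
for full-module pickles. Note also that the save interval,
num_iterations // save_num, assumes save_num <= num_iterations.

    import torch

    from ppo_train import Agent  # lets pickle resolve the Agent class on load

    # Illustrative path; the patch writes agents/<run_name>/agent_step_<global_step>.pt
    agent = torch.load(
        "agents/my_run/agent_step_100000.pt",
        map_location="cpu",
        weights_only=False,  # checkpoint is a pickled module, not a state_dict
    )
    agent.eval()

    # Actions can then be sampled with the method from the diff context above:
    # action, logprob, entropy, value = agent.get_action_and_value(obs)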