Skip to content

Commit

Permalink
ppo_train.py: Option to periodically save agent's model
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikel committed Aug 1, 2024
1 parent 7f22da4 commit 4daf1d3
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions ppo_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class Args:
mt_wd: str = "./"
"""Directory where the Minetest working directories will be created (defaults to the current one)"""
frameskip: int = 4
"""Number of frames to skip between observations"""
save_agent: bool = False
"""Save the agent's model (disabled by default)"""
save_num: int = 5
"""Number of times to save the agent's model"""

# Algorithm specific arguments
env_id: str = "Craftium/ChopTree-v0"
Expand Down Expand Up @@ -102,14 +107,6 @@ def thunk():
else:
env = gym.make(env_id, **craftium_kwargs)
env = gym.wrappers.RecordEpisodeStatistics(env)
# env = NoopResetEnv(env, noop_max=30)
# env = MaxAndSkipEnv(env, skip=4)
# env = EpisodicLifeEnv(env)
# if "FIRE" in env.unwrapped.get_action_meanings():
# env = FireResetEnv(env)
# env = ClipRewardEnv(env)
# env = gym.wrappers.ResizeObservation(env, (84, 84))
# env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
return env

Expand Down Expand Up @@ -175,6 +172,10 @@ def get_action_and_value(self, x, action=None):
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# Checkpointing is opt-in: only when it is enabled do we resolve the per-run
# output directory and make sure it exists before training starts.
if args.save_agent:
    agent_path = f"agents/{run_name}"
    # exist_ok=True tolerates a directory that is already present.
    os.makedirs(agent_path, exist_ok=True)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
Expand Down Expand Up @@ -335,5 +336,11 @@ def get_action_and_value(self, x, action=None):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# Periodically checkpoint the agent: roughly `save_num` evenly spaced saves
# over the run, plus a save on the last iteration.
if args.save_agent:
    # Clamp both the divisor and the resulting interval to >= 1. Otherwise
    # save_num == 0 raises ZeroDivisionError on the floor division, and
    # save_num > num_iterations makes the interval 0, so the modulo below
    # raises ZeroDivisionError in the middle of training.
    save_interval = max(1, args.num_iterations // max(1, args.save_num))
    if iteration % save_interval == 0 or iteration == args.num_iterations:
        print("Saving agent...")
        # NOTE(review): this pickles the whole module object rather than its
        # state_dict, so loading requires the same class definition to be
        # importable. Kept as-is to preserve the existing checkpoint format.
        torch.save(agent, f"{agent_path}/agent_step_{global_step}.pt")

envs.close()
writer.close()

0 comments on commit 4daf1d3

Please sign in to comment.