-
Notifications
You must be signed in to change notification settings - Fork 3
/
sb3_train.py
69 lines (55 loc) · 2.22 KB
/
sb3_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from argparse import ArgumentParser
from uuid import uuid4
import os
from stable_baselines3 import A2C, PPO
from stable_baselines3.common import logger
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecFrameStack
import gymnasium as gym
import craftium
def _positive_int(value):
    """Parse *value* as a strictly positive ``int`` (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1, so
            argparse reports a clear usage error at startup instead of the
            script failing later inside environment creation or training.
    """
    from argparse import ArgumentTypeError

    try:
        number = int(value)
    except ValueError:
        raise ArgumentTypeError(f"expected a positive integer, got {value!r}") from None
    if number < 1:
        raise ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; an explicit list allows
            calling this from tests or notebooks.

    Returns:
        argparse.Namespace with ``run_name``, ``runs_dir``, ``env_id``,
        ``total_timesteps``, ``num_envs`` and ``method`` attributes.

    Raises:
        SystemExit: on invalid options (argparse's standard behavior), e.g. a
            non-positive ``--num-envs`` or ``--total-timesteps``.
    """
    parser = ArgumentParser()
    # fmt: off
    parser.add_argument("--run-name", type=str, default=None,
        help="Unique name for the run. Defaults to '<env-id>-<method>--<random uuid>'.")
    parser.add_argument("--runs-dir", type=str, default="./run-logs/",
        help="Name of the directory where run's data is stored. Defaults to './run-logs/'")
    parser.add_argument("--env-id", type=str, default="Craftium/ChopTree-v0",
        help="Name (registered) of the environment.")
    # Both counts must be >= 1: zero/negative values would only fail much later.
    parser.add_argument("--total-timesteps", type=_positive_int, default=10_000_000,
        help="Number of timesteps to train for.")
    parser.add_argument("--num-envs", type=_positive_int, default=4,
        help="Number of environments to use.")
    parser.add_argument("--method", type=str, default="a2c", choices=["ppo", "a2c"],
        help="RL method to use to optimize the agent.")
    # fmt: on
    return parser.parse_args(argv)
def make_env(env_id):
    """Build a zero-argument factory that creates one environment for *env_id*.

    Each call of the returned factory runs ``gym.make(env_id, ...)`` with the
    Craftium options ``frameskip=3``, ``rgb_observations=False`` and
    ``gray_scale_keepdim=True`` (presumably grayscale frames that keep a
    channel axis; confirm against the craftium docs), resets the new
    environment once, and returns it.
    """

    def _thunk():
        # Keyword options forwarded verbatim to gym.make.
        env_options = {
            "frameskip": 3,
            "rgb_observations": False,
            "gray_scale_keepdim": True,
        }
        environment = gym.make(env_id, **env_options)
        environment.reset()
        return environment

    return _thunk
def main():
    """Train an SB3 agent (PPO or A2C, ``CnnPolicy``) on a Craftium env.

    Steps:
    - Parse the CLI options.
    - Build ``--num-envs`` environments in a ``DummyVecEnv``, stack 3
      frames, and wrap the result in ``VecMonitor``.
    - Train for ``--total-timesteps`` steps, logging to stdout and CSV
      under ``<runs-dir>/<run-name>``.
    - Save the trained model to ``<runs-dir>/<run-name>/model.zip``.

    The environments are closed even if training raises or is interrupted.
    """
    args = parse_args()
    if args.run_name is None:
        run_name = f"{args.env_id.replace('/', '-')}-{args.method}--{uuid4()}"
    else:
        run_name = args.run_name

    # configure SB3 logger
    log_path = os.path.join(args.runs_dir, run_name)
    print(f"** Storing run's data in {log_path}")
    new_logger = logger.configure(log_path, ["stdout", "csv"])

    envs = DummyVecEnv([make_env(args.env_id) for _ in range(args.num_envs)])
    envs = VecFrameStack(envs, 3)  # stack the 3 most recent observations
    envs = VecMonitor(envs)
    try:
        method_class = PPO if args.method == "ppo" else A2C
        model = method_class("CnnPolicy", envs, verbose=1)
        model.set_logger(new_logger)
        model.learn(total_timesteps=args.total_timesteps)
        # Persist the trained policy; without this the whole run's weights are lost.
        # SB3 appends the ".zip" suffix to a path without an extension.
        model_path = os.path.join(log_path, "model")
        model.save(model_path)
        print(f"** Saved trained model to {model_path}.zip")
    finally:
        # Close the environments on every exit path (including exceptions and
        # Ctrl-C) so they are not leaked — presumably each one owns a game
        # process; confirm against craftium.
        envs.close()


if __name__ == "__main__":
    main()