-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_model.py
81 lines (66 loc) · 2.54 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from multiprocessing import Pool
from tqdm import tqdm
import argparse
import time
import os
from environment.tetris import Tetris
from agents.DQNAgent import DQNAgent
from agents.DumbAgent import DumbAgent
def load_agent(model, model_path):
tqdm.write(f"Loading model:{model} from {model_path}")
match model:
case "dqn":
return DQNAgent(5, model_path, epsilon=0)
case "genetic":
return DumbAgent(5, model_path)
# Command-line interface for evaluating a trained agent.
parser = argparse.ArgumentParser(
    prog="Test model", description="Test and evaluate model"
)
parser.add_argument("--model", choices=["dqn", "genetic"], default="dqn")
parser.add_argument(
    "--path", default="models/dqn_10_20.pt", help="Model file to load."
)
# NOTE(review): --render and --plot are not read by the code in this file's
# evaluation path — confirm whether they are still needed.
parser.add_argument("--render", action=argparse.BooleanOptionalAction)
parser.add_argument("--plot", action=argparse.BooleanOptionalAction)
# Numeric options need type=int: without it, values given on the command line
# arrive as strings and break range()/integer comparisons in the workers.
parser.add_argument("--cols", type=int, default=10, help="Board columns.")
parser.add_argument("--rows", type=int, default=20, help="Board rows.")
parser.add_argument(
    "--max_steps", type=int, default=2000, help="Maximum moves per game."
)
parser.add_argument(
    "--samples", type=int, default=5, help="Number of games to play."
)
parser.add_argument(
    "--out",
    default="results",
    help="Directory in which a timestamped results folder is created.",
)
parser.add_argument("--level_multi", action=argparse.BooleanOptionalAction)
def worker_function(args):
    """Play one evaluation game and return its final score and line-clear data.

    Used as the task function for a ``multiprocessing.Pool``. It builds its own
    environment and loads its own agent, so only ``args`` is shared with the
    parent process.

    Args:
        args: Parsed CLI namespace; reads ``cols``, ``rows``, ``level_multi``,
            ``model``, ``path`` and ``max_steps``.

    Returns:
        ``(score, line_clears)``, where ``score`` is the last score returned by
        ``env.step`` and ``line_clears`` is
        ``list(env.line_clear_types.values())``.
    """
    env = Tetris(args.cols, args.rows, args.level_multi)
    agent = load_agent(args.model, args.path)
    env.reset()
    steps = 0
    done = False
    while not done:
        # Ask the agent for an action given the candidate next states.
        next_states = env.get_possible_states()
        best_action = agent.act(next_states)
        done, score, _ = env.step(*best_action)
        steps += 1
        # Cap the game at max_steps moves. Use ">=" (not ">") so exactly
        # max_steps moves are played. The check comes after the first step so
        # `score` is always bound.
        if steps >= args.max_steps:
            break
    return score, list(env.line_clear_types.values())
if __name__ == "__main__":
    args = parser.parse_args()
    # Whole seconds since the epoch, used to timestamp this run's directory.
    now = str(int(time.time()))
    path = os.path.join(args.out, f"{args.model}_{now}")
    os.makedirs(path, exist_ok=True)

    # Play `samples` independent games in parallel; every task gets the same
    # args. Results arrive in completion order (imap_unordered).
    with Pool() as pool:
        results = list(
            tqdm(
                pool.imap_unordered(
                    worker_function,
                    [args] * args.samples,
                    chunksize=1,
                ),
                total=args.samples,
            )
        )

    # Files are opened only here, inside `with`, so they are always closed.
    # Each game writes one line: its score, and its space-separated
    # line-clear values.
    with open(os.path.join(path, "scores.txt"), "w", newline="") as scores, open(
        os.path.join(path, "line_history.txt"), "w", newline=""
    ) as line_history:
        for score, line_history_data in results:
            scores.write(f"{score}\n")
            line_history.write(" ".join(map(str, line_history_data)) + "\n")