From 022b50f64284c937f4304961e62843a46691e127 Mon Sep 17 00:00:00 2001
From: 7eu7d7
Date: Fri, 10 Sep 2021 15:45:32 +0800
Subject: [PATCH] change activation

---
 environment.py            |  6 ++--
 models.py                 |  6 ++--
 render.py                 |  3 +-
 test.py                   | 24 +++++++--------
 test_ys.py => test_sim.py | 24 ++++++++-------
 train_sim.py              | 62 +++++++++++++++++++++++++++++++++++++++
 6 files changed, 94 insertions(+), 31 deletions(-)
 rename test_ys.py => test_sim.py (51%)
 create mode 100644 train_sim.py

diff --git a/environment.py b/environment.py
index 3c5f409..65e2f9a 100644
--- a/environment.py
+++ b/environment.py
@@ -84,7 +84,7 @@ def render(self):
         pass
 
 class Fishing_sim:
-    def __init__(self, bar_range=(0.14, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5),
+    def __init__(self, bar_range=(0.18, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5),
                  move_speed_range=(-0.3,0.3), tick_count=60, step_tick=15, stop_tick=60*15,
                  drag_force=0.4, down_speed=0.015, stable_speed=-0.32, drawer=None):
         self.bar_range=bar_range
@@ -141,7 +141,7 @@ def tick(self):
         else:
             self.score-=1
 
-        if self.ticks>self.stop_tick or self.score<=-10000:
+        if self.ticks>self.stop_tick or self.score<=-100000:
             return True
 
         self.pointer+=self.v
@@ -167,7 +167,7 @@ def step(self, action):
         for x in range(self.step_tick):
             if self.tick():
                 done=True
-        return (self.low,self.low+self.len,self.pointer), (self.score-score_before)/self.step_tick, done
+        return self.get_state(), (self.score-score_before)/self.step_tick, done
 
     def render(self):
         if self.drawer:
diff --git a/models.py b/models.py
index 2798713..b4143c3 100644
--- a/models.py
+++ b/models.py
@@ -4,9 +4,9 @@
 class FishNet(nn.Sequential):
     def __init__(self, in_ch, out_ch):
         layers=[
-            nn.Linear(in_ch, 10),
-            nn.ReLU(),
-            nn.Linear(10, out_ch)
+            nn.Linear(in_ch, 16),
+            nn.LeakyReLU(),
+            nn.Linear(16, out_ch)
         ]
         super(FishNet, self).__init__(*layers)
         self.apply(weight_init)
diff --git a/render.py b/render.py
index 675bfc8..895b539 100644
--- a/render.py
+++ b/render.py
@@ -20,7 +20,8 @@ def draw(self, low, high, pointer, ticks):
         plt.imshow(img)
         plt.title(f'tick:{ticks}')
         #plt.draw()
-        self.call_back()
+        if self.call_back:
+            self.call_back()
         plt.pause(0.0001)
         plt.clf()
 
diff --git a/test.py b/test.py
index 60841f9..8c6a845 100644
--- a/test.py
+++ b/test.py
@@ -4,7 +4,6 @@
 import torch
 import argparse
 import os
-from matplotlib.animation import FFMpegWriter
 
 parser = argparse.ArgumentParser(description='Test Genshin finsing with DQN')
 parser.add_argument('--n_states', default=3, type=int)
@@ -14,23 +13,20 @@
 args = parser.parse_args()
 
 if __name__ == '__main__':
-    writer = FFMpegWriter(fps=60)
-    render = PltRender(call_back=lambda: writer.grab_frame())
     net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
-    env = Fishing_sim(step_tick=args.step_tick, drawer=render, stop_tick=10000)
+    env = Fishing(delay=0.1)
 
     net.load_state_dict(torch.load(args.model_dir))
     net.eval()
 
-    state = env.reset()
-    with writer.saving(render.fig, 'out.mp4', 100):
-        for i in range(2000):
-            env.render()
+    state = env.step(0)[0]
+    for i in range(2000):
+        env.render()
 
-            state = torch.FloatTensor(state).unsqueeze(0)
-            action = net(state)
-            action = torch.argmax(action, dim=1).numpy()
-            state, reward, done = env.step(action)
-            if done:
-                break
\ No newline at end of file
+        state = torch.FloatTensor(state).unsqueeze(0)
+        action = net(state)
+        action = torch.argmax(action, dim=1).numpy()
+        state, reward, done = env.step(action)
+        if done:
+            break
\ No newline at end of file
diff --git a/test_ys.py b/test_sim.py
similarity index 51%
rename from test_ys.py
rename to test_sim.py
index 8c6a845..60841f9 100644
--- a/test_ys.py
+++ b/test_sim.py
@@ -4,6 +4,7 @@
 import torch
 import argparse
 import os
+from matplotlib.animation import FFMpegWriter
 
 parser = argparse.ArgumentParser(description='Test Genshin finsing with DQN')
 parser.add_argument('--n_states', default=3, type=int)
@@ -13,20 +14,23 @@
 args = parser.parse_args()
 
 if __name__ == '__main__':
+    writer = FFMpegWriter(fps=60)
+    render = PltRender(call_back=lambda: writer.grab_frame())
     net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
-    env = Fishing(delay=0.1)
+    env = Fishing_sim(step_tick=args.step_tick, drawer=render, stop_tick=10000)
 
     net.load_state_dict(torch.load(args.model_dir))
     net.eval()
 
-    state = env.step(0)[0]
-    for i in range(2000):
-        env.render()
+    state = env.reset()
+    with writer.saving(render.fig, 'out.mp4', 100):
+        for i in range(2000):
+            env.render()
 
-        state = torch.FloatTensor(state).unsqueeze(0)
-        action = net(state)
-        action = torch.argmax(action, dim=1).numpy()
-        state, reward, done = env.step(action)
-        if done:
-            break
\ No newline at end of file
+            state = torch.FloatTensor(state).unsqueeze(0)
+            action = net(state)
+            action = torch.argmax(action, dim=1).numpy()
+            state, reward, done = env.step(action)
+            if done:
+                break
\ No newline at end of file
diff --git a/train_sim.py b/train_sim.py
new file mode 100644
index 0000000..0eb00ec
--- /dev/null
+++ b/train_sim.py
@@ -0,0 +1,62 @@
+from agent import DQN
+from models import FishNet
+from environment import *
+import torch
+import argparse
+import os
+from render import *
+
+parser = argparse.ArgumentParser(description='Train Genshin fishing simulation with DQN')
+parser.add_argument('--batch_size', default=32, type=int)
+parser.add_argument('--n_states', default=3, type=int)
+parser.add_argument('--n_actions', default=2, type=int)
+parser.add_argument('--step_tick', default=12, type=int)
+parser.add_argument('--n_episode', default=400, type=int)
+parser.add_argument('--save_dir', default='./output', type=str)
+parser.add_argument('--resume', default=None, type=str)
+args = parser.parse_args()
+
+if not os.path.exists(args.save_dir):
+    os.makedirs(args.save_dir)
+
+net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
+if args.resume:
+    net.load_state_dict(torch.load(args.resume))
+
+agent = DQN(net, args.batch_size, args.n_states, args.n_actions, memory_capacity=2000)
+env = Fishing_sim(step_tick=args.step_tick, drawer=PltRender())
+
+if __name__ == '__main__':
+    # Start training
+    print("\nCollecting experience...")
+    net.train()
+    for i_episode in range(args.n_episode):
+        #keyboard.wait('r')
+        # run up to args.n_episode episodes of the fishing simulation
+        s = env.reset()
+        ep_r = 0
+        while True:
+            if i_episode>200 and i_episode%20==0:
+                env.render()
+            # take action based on the current state
+            a = agent.choose_action(s)
+            # obtain the reward and next state and some other information
+            s_, r, done = env.step(a)
+
+            # store the transitions of states
+            agent.store_transition(s, a, r, s_)
+
+            ep_r += r
+            # if the experience replay buffer is filled, DQN begins to learn or update
+            # its parameters.
+            if agent.memory_counter > agent.memory_capacity:
+                agent.train_step()
+                if done:
+                    print('Ep: ', i_episode, ' |', 'Ep_r: ', round(ep_r, 2))
+
+            if done:
+                # if game is over, then skip the while loop.
+                break
+            # use next state to update the current state.
+            s = s_
+    torch.save(net.state_dict(), os.path.join(args.save_dir, f'fish_ys_net_{i_episode}.pth'))
\ No newline at end of file
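
Note on the new training script: train_sim.py imports DQN from agent.py, which
this patch does not touch. Below is a minimal sketch of the agent interface the
training loop assumes (choose_action, store_transition, memory_counter,
memory_capacity, train_step), included for reference only; the actual agent.py
in the repository may differ. The epsilon, gamma, learning rate, and
target-network sync interval here are illustrative assumptions, and
terminal-state masking of the TD target is omitted for brevity.

    import copy
    import numpy as np
    import torch
    import torch.nn as nn

    class DQN:
        def __init__(self, net, batch_size, n_states, n_actions, memory_capacity=2000,
                     epsilon=0.9, gamma=0.9, lr=0.01, target_replace_iter=100):
            self.eval_net = net
            self.target_net = copy.deepcopy(net)   # periodically synced target copy
            self.batch_size = batch_size
            self.n_states = n_states
            self.n_actions = n_actions
            self.memory_capacity = memory_capacity
            # ring buffer, one row per transition: (s, a, r, s_)
            self.memory = np.zeros((memory_capacity, n_states * 2 + 2), dtype=np.float32)
            self.memory_counter = 0
            self.learn_step = 0
            self.epsilon = epsilon
            self.gamma = gamma
            self.target_replace_iter = target_replace_iter
            self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
            self.loss_func = nn.MSELoss()

        def choose_action(self, s):
            # epsilon-greedy: act greedily w.r.t. the eval net most of the time
            if np.random.uniform() < self.epsilon:
                s = torch.FloatTensor(s).unsqueeze(0)
                with torch.no_grad():
                    return int(torch.argmax(self.eval_net(s), dim=1).item())
            return np.random.randint(0, self.n_actions)

        def store_transition(self, s, a, r, s_):
            # overwrite the oldest row once the buffer is full
            idx = self.memory_counter % self.memory_capacity
            self.memory[idx] = np.hstack((s, [a, r], s_))
            self.memory_counter += 1

        def train_step(self):
            # hard-update the target network every target_replace_iter steps
            if self.learn_step % self.target_replace_iter == 0:
                self.target_net.load_state_dict(self.eval_net.state_dict())
            self.learn_step += 1

            # sample a random minibatch of stored transitions
            idx = np.random.choice(self.memory_capacity, self.batch_size)
            batch = self.memory[idx]
            s = torch.FloatTensor(batch[:, :self.n_states])
            a = torch.LongTensor(batch[:, self.n_states:self.n_states + 1])
            r = torch.FloatTensor(batch[:, self.n_states + 1:self.n_states + 2])
            s_ = torch.FloatTensor(batch[:, -self.n_states:])

            # TD target: r + gamma * max_a' Q_target(s', a')
            q_eval = self.eval_net(s).gather(1, a)
            q_next = self.target_net(s_).detach()
            q_target = r + self.gamma * q_next.max(1, keepdim=True)[0]

            loss = self.loss_func(q_eval, q_target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

With an agent of this shape in place, training runs as
`python train_sim.py --n_episode 400 --save_dir ./output`, and a saved
checkpoint can then be replayed in the simulator with test_sim.py.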