
Commit

change activate
7eu7d7 committed Sep 10, 2021
1 parent 5944fd0 commit 022b50f
Showing 6 changed files with 94 additions and 31 deletions.
environment.py (6 changes: 3 additions & 3 deletions)
@@ -84,7 +84,7 @@ def render(self):
         pass
 
 class Fishing_sim:
-    def __init__(self, bar_range=(0.14, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5),
+    def __init__(self, bar_range=(0.18, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5),
                  move_speed_range=(-0.3,0.3), tick_count=60, step_tick=15, stop_tick=60*15,
                  drag_force=0.4, down_speed=0.015, stable_speed=-0.32, drawer=None):
         self.bar_range=bar_range
@@ -141,7 +141,7 @@ def tick(self):
         else:
             self.score-=1
 
-        if self.ticks>self.stop_tick or self.score<=-10000:
+        if self.ticks>self.stop_tick or self.score<=-100000:
             return True
 
         self.pointer+=self.v
@@ -167,7 +167,7 @@ def step(self, action):
         for x in range(self.step_tick):
             if self.tick():
                 done=True
-        return (self.low,self.low+self.len,self.pointer), (self.score-score_before)/self.step_tick, done
+        return self.get_state(), (self.score-score_before)/self.step_tick, done
 
     def render(self):
         if self.drawer:
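Note: the inline state tuple is now built by a get_state() helper, keeping callers agnostic of the state layout. A minimal sketch of that helper, assuming it simply mirrors the tuple the old return statement built inline (the method body itself is not shown in this commit):

    def get_state(self):
        # assumed to match the old inline tuple:
        # (bar lower edge, bar upper edge, pointer position)
        return (self.low, self.low + self.len, self.pointer)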
models.py (6 changes: 3 additions & 3 deletions)
@@ -4,9 +4,9 @@
 class FishNet(nn.Sequential):
     def __init__(self, in_ch, out_ch):
         layers=[
-            nn.Linear(in_ch, 10),
-            nn.ReLU(),
-            nn.Linear(10, out_ch)
+            nn.Linear(in_ch, 16),
+            nn.LeakyReLU(),
+            nn.Linear(16, out_ch)
         ]
         super(FishNet, self).__init__(*layers)
         self.apply(weight_init)
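The hidden layer grows from 10 to 16 units and ReLU becomes LeakyReLU, whose small slope for negative inputs keeps gradients flowing where a plain ReLU unit would go dead. A quick usage sketch (the 3-state/2-action sizes are the defaults used by the test and training scripts; weight_init is defined elsewhere in models.py):

    import torch
    from models import FishNet

    net = FishNet(in_ch=3, out_ch=2)     # 3 state values in, one Q-value per action out
    q = net(torch.zeros(1, 3))           # forward pass on a dummy state
    action = torch.argmax(q, dim=1)      # greedy action, as in test.py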
render.py (3 changes: 2 additions & 1 deletion)
@@ -20,7 +20,8 @@ def draw(self, low, high, pointer, ticks):
         plt.imshow(img)
         plt.title(f'tick:{ticks}')
         #plt.draw()
-        self.call_back()
+        if self.call_back:
+            self.call_back()
         plt.pause(0.0001)
         plt.clf()
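Guarding the callback makes frame grabbing optional, so the same drawer serves both live viewing and video capture. Both call sites from this commit, side by side (train_sim.py constructs PltRender() with no arguments, so call_back presumably defaults to None):

    from matplotlib.animation import FFMpegWriter
    from render import PltRender

    live = PltRender()                                            # display only
    writer = FFMpegWriter(fps=60)
    recording = PltRender(call_back=lambda: writer.grab_frame())  # grab a frame per draw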
test.py (24 changes: 10 additions & 14 deletions)
@@ -4,7 +4,6 @@
 import torch
 import argparse
 import os
-from matplotlib.animation import FFMpegWriter

 parser = argparse.ArgumentParser(description='Test Genshin fishing with DQN')
 parser.add_argument('--n_states', default=3, type=int)
@@ -14,23 +13,20 @@
 args = parser.parse_args()

 if __name__ == '__main__':
-    writer = FFMpegWriter(fps=60)
-    render = PltRender(call_back=lambda: writer.grab_frame())
 
     net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
-    env = Fishing_sim(step_tick=args.step_tick, drawer=render, stop_tick=10000)
+    env = Fishing(delay=0.1)
 
     net.load_state_dict(torch.load(args.model_dir))
 
     net.eval()
-    state = env.reset()
-    with writer.saving(render.fig, 'out.mp4', 100):
-        for i in range(2000):
-            env.render()
+    state = env.step(0)[0]
+    for i in range(2000):
+        env.render()
 
-            state = torch.FloatTensor(state).unsqueeze(0)
-            action = net(state)
-            action = torch.argmax(action, dim=1).numpy()
-            state, reward, done = env.step(action)
-            if done:
-                break
+        state = torch.FloatTensor(state).unsqueeze(0)
+        action = net(state)
+        action = torch.argmax(action, dim=1).numpy()
+        state, reward, done = env.step(action)
+        if done:
+            break
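With the simulator moved to test_sim.py, test.py now drives the real-game Fishing environment: there is no reset(), so an initial no-op step fetches the first state, and the network then acts greedily. A hypothetical stub of the interface the script relies on (the real class lives in environment.py and reads the game screen; every value here is illustrative):

    class FishingStub:
        # stand-in for environment.Fishing, for interface illustration only
        def __init__(self, delay=0.1):
            self.delay = delay           # pause between emulated inputs, in seconds
        def step(self, action):
            state = (0.2, 0.6, 0.4)      # dummy (bar_low, bar_high, pointer)
            return state, 0.0, False     # (state, reward, done), as test.py unpacks
        def render(self):
            pass                         # the real game renders itself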
test_ys.py → test_sim.py (24 changes: 14 additions & 10 deletions)
@@ -4,6 +4,7 @@
 import torch
 import argparse
 import os
+from matplotlib.animation import FFMpegWriter

 parser = argparse.ArgumentParser(description='Test Genshin fishing with DQN')
 parser.add_argument('--n_states', default=3, type=int)
@@ -13,20 +14,23 @@
 args = parser.parse_args()

 if __name__ == '__main__':
+    writer = FFMpegWriter(fps=60)
+    render = PltRender(call_back=lambda: writer.grab_frame())
 
     net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
-    env = Fishing(delay=0.1)
+    env = Fishing_sim(step_tick=args.step_tick, drawer=render, stop_tick=10000)
 
     net.load_state_dict(torch.load(args.model_dir))
 
     net.eval()
-    state = env.step(0)[0]
-    for i in range(2000):
-        env.render()
+    state = env.reset()
+    with writer.saving(render.fig, 'out.mp4', 100):
+        for i in range(2000):
+            env.render()
 
-        state = torch.FloatTensor(state).unsqueeze(0)
-        action = net(state)
-        action = torch.argmax(action, dim=1).numpy()
-        state, reward, done = env.step(action)
-        if done:
-            break
+            state = torch.FloatTensor(state).unsqueeze(0)
+            action = net(state)
+            action = torch.argmax(action, dim=1).numpy()
+            state, reward, done = env.step(action)
+            if done:
+                break
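The renamed simulator test records the rendered episode to out.mp4 through matplotlib's FFMpegWriter, which needs an ffmpeg binary on the PATH. The recording pattern in isolation, as a minimal self-contained sketch (the stand-in figure and frame count are placeholders):

    import matplotlib
    matplotlib.use('Agg')  # off-screen rendering; an assumption, not in the script
    import matplotlib.pyplot as plt
    from matplotlib.animation import FFMpegWriter

    fig = plt.figure()
    writer = FFMpegWriter(fps=60)
    with writer.saving(fig, 'out.mp4', 100):  # third positional argument is dpi
        for tick in range(120):
            plt.clf()
            plt.title(f'tick:{tick}')
            writer.grab_frame()               # append the current figure as one frame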
train_sim.py (62 changes: 62 additions & 0 deletions)
@@ -0,0 +1,62 @@
+from agent import DQN
+from models import FishNet
+from environment import *
+import torch
+import argparse
+import os
+from render import *
+
+parser = argparse.ArgumentParser(description='Train Genshin fishing simulation with DQN')
+parser.add_argument('--batch_size', default=32, type=int)
+parser.add_argument('--n_states', default=3, type=int)
+parser.add_argument('--n_actions', default=2, type=int)
+parser.add_argument('--step_tick', default=12, type=int)
+parser.add_argument('--n_episode', default=400, type=int)
+parser.add_argument('--save_dir', default='./output', type=str)
+parser.add_argument('--resume', default=None, type=str)
+args = parser.parse_args()
+
+if not os.path.exists(args.save_dir):
+    os.makedirs(args.save_dir)
+
+net = FishNet(in_ch=args.n_states, out_ch=args.n_actions)
+if args.resume:
+    net.load_state_dict(torch.load(args.resume))
+
+agent = DQN(net, args.batch_size, args.n_states, args.n_actions, memory_capacity=2000)
+env = Fishing_sim(step_tick=args.step_tick, drawer=PltRender())
+
+if __name__ == '__main__':
+    # Start training
+    print("\nCollecting experience...")
+    net.train()
+    for i_episode in range(args.n_episode):
+        #keyboard.wait('r')
+        # run one episode of the fishing simulator
+        s = env.reset()
+        ep_r = 0
+        while True:
+            if i_episode>200 and i_episode%20==0:
+                env.render()
+            # take an action based on the current state
+            a = agent.choose_action(s)
+            # obtain the reward and next state from the environment
+            s_, r, done = env.step(a)
+
+            # store the transition
+            agent.store_transition(s, a, r, s_)
+
+            ep_r += r
+            # once the experience replay buffer is filled, DQN starts
+            # updating its parameters.
+            if agent.memory_counter > agent.memory_capacity:
+                agent.train_step()
+                if done:
+                    print('Ep: ', i_episode, ' |', 'Ep_r: ', round(ep_r, 2))
+
+            if done:
+                # the episode is over; leave the while loop.
+                break
+            # use the next state as the current state.
+            s = s_
+        torch.save(net.state_dict(), os.path.join(args.save_dir, f'fish_ys_net_{i_episode}.pth'))
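The training script touches four parts of the agent: choose_action, store_transition, the memory_counter/memory_capacity pair, and train_step. The DQN class itself lives in agent.py and is not part of this commit; below is one plausible minimal sketch of that interface (gamma, epsilon, lr, and the buffer layout are assumptions, and a production DQN would normally bootstrap from a separate target network rather than the online one):

    import numpy as np
    import torch
    import torch.nn as nn

    class DQNSketch:
        def __init__(self, net, batch_size, n_states, n_actions,
                     memory_capacity=2000, gamma=0.9, epsilon=0.1, lr=1e-3):
            self.net, self.batch_size = net, batch_size
            self.n_states, self.n_actions = n_states, n_actions
            self.memory_capacity = memory_capacity
            # each row stores one transition: s, a, r, s_
            self.memory = np.zeros((memory_capacity, n_states * 2 + 2), dtype=np.float32)
            self.memory_counter = 0
            self.gamma, self.epsilon = gamma, epsilon
            self.optim = torch.optim.Adam(net.parameters(), lr=lr)
            self.loss_fn = nn.MSELoss()

        def choose_action(self, s):
            # epsilon-greedy over the Q-network's outputs
            if np.random.rand() < self.epsilon:
                return np.random.randint(self.n_actions)
            with torch.no_grad():
                q = self.net(torch.FloatTensor(s).unsqueeze(0))
            return int(torch.argmax(q, dim=1).item())

        def store_transition(self, s, a, r, s_):
            # ring buffer: overwrite the oldest transition once full
            idx = self.memory_counter % self.memory_capacity
            self.memory[idx] = np.hstack([s, [a, r], s_])
            self.memory_counter += 1

        def train_step(self):
            # valid once the buffer is full, matching the guard in train_sim.py
            idx = np.random.choice(self.memory_capacity, self.batch_size)
            batch = torch.from_numpy(self.memory[idx])
            s = batch[:, :self.n_states]
            a = batch[:, self.n_states].long().unsqueeze(1)
            r = batch[:, self.n_states + 1].unsqueeze(1)
            s_ = batch[:, -self.n_states:]
            # regress Q(s, a) toward r + gamma * max_a' Q(s', a')
            q_eval = self.net(s).gather(1, a)
            with torch.no_grad():
                q_target = r + self.gamma * self.net(s_).max(1, keepdim=True)[0]
            loss = self.loss_fn(q_eval, q_target)
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()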
