Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
7eu7d7 committed Sep 10, 2021
0 parents commit 5944fd0
Show file tree
Hide file tree
Showing 13 changed files with 530 additions and 0 deletions.
86 changes: 86 additions & 0 deletions agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import torch
from torch import nn
from copy import deepcopy
import numpy as np

class DQN:
def __init__(self, base_net, batch_size, n_states, n_actions, memory_capacity=2000, epsilon=0.9, gamma=0.9, rep_frep=100, lr=0.01):
self.eval_net = base_net
self.target_net = deepcopy(base_net)

self.batch_size=batch_size
self.epsilon=epsilon
self.gamma=gamma
self.n_states=n_states
self.n_actions=n_actions
self.memory_capacity=memory_capacity
self.rep_frep=rep_frep

self.learn_step_counter = 0 # count the steps of learning process
self.memory_counter = 0 # counter used for experience replay buffer

# of columns depends on 4 elements, s, a, r, s_, the total is N_STATES*2 + 2---#
self.memory = np.zeros((memory_capacity, n_states * 2 + 2))

self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
self.loss_func = nn.MSELoss()

def choose_action(self, x):
# This function is used to make decision based upon epsilon greedy
x = torch.FloatTensor(x).unsqueeze(0) # add 1 dimension to input state x
# input only one sample
if np.random.uniform() < self.epsilon: # greedy
# use epsilon-greedy approach to take action
actions_value = self.eval_net.forward(x)
# torch.max() returns a tensor composed of max value along the axis=dim and corresponding index
# what we need is the index in this function, representing the action of cart.
action = torch.argmax(actions_value, dim=1).numpy() # return the argmax index
else: # random
action = np.random.randint(0, self.n_actions)
return action

def store_transition(self, s, a, r, s_):
# This function acts as experience replay buffer
transition = np.hstack((s, [a, r], s_)) # horizontally stack these vectors
# if the capacity is full, then use index to replace the old memory with new one
index = self.memory_counter % self.memory_capacity
self.memory[index, :] = transition
self.memory_counter += 1

def train_step(self):
# Define how the whole DQN works including sampling batch of experiences,
# when and how to update parameters of target network, and how to implement
# backward propagation.

# update the target network every fixed steps
if self.learn_step_counter % self.rep_frep == 0:
# Assign the parameters of eval_net to target_net
self.target_net.load_state_dict(self.eval_net.state_dict())
self.learn_step_counter += 1

# Determine the index of Sampled batch from buffer
sample_index = np.random.choice(self.memory_capacity, self.batch_size) # randomly select some data from buffer
# extract experiences of batch size from buffer.
b_memory = self.memory[sample_index, :]
# extract vectors or matrices s,a,r,s_ from batch memory and convert these to torch Variables
# that are convenient to back propagation
b_s = torch.FloatTensor(b_memory[:, :self.n_states])
# convert long int type to tensor
b_a = torch.LongTensor(b_memory[:, self.n_states:self.n_states + 1].astype(int))
b_r = torch.FloatTensor(b_memory[:, self.n_states + 1:self.n_states + 2])
b_s_ = torch.FloatTensor(b_memory[:, -self.n_states:])

# calculate the Q value of state-action pair
q_eval = self.eval_net(b_s).gather(1, b_a) # (batch_size, 1)
# print(q_eval)
# calculate the q value of next state
q_next = self.target_net(b_s_).detach() # detach from computational graph, don't back propagate
# select the maximum q value
# print(q_next)
# q_next.max(1) returns the max value along the axis=1 and its corresponding index
q_target = b_r + self.gamma * q_next.max(dim=1)[0].view(self.batch_size, 1) # (batch_size, 1)
loss = self.loss_func(q_eval, q_target)

self.optimizer.zero_grad() # reset the gradient to zero
loss.backward()
self.optimizer.step() # execute back propagation for one step
34 changes: 34 additions & 0 deletions capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pyautogui
import cv2
import numpy as np
import time

def match_img(img, target):
h, w = target.shape[:2]
res = cv2.matchTemplate(img, target, cv2.TM_CCOEFF)
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
return (*max_loc, max_loc[0] + w, max_loc[1] + h, max_loc[0] + w//2, max_loc[1] + h//2)

img = cv2.imread('imgs/a.png')[94:94 + 103, 712:712 + 496, :]
t_l = cv2.imread('imgs/target_left.png')
t_r = cv2.imread('imgs/target_right.png')
t_n = cv2.imread('imgs/target_now.png')

start=time.time()
img2 = pyautogui.screenshot(region=[712, 94, 496, 103])
bbox_l=match_img(img, t_l)
cv2.rectangle(img, bbox_l[0:2], bbox_l[2:4], (255,0,0), 2) # 画出矩形位置
bbox_r=match_img(img, t_r)
cv2.rectangle(img, bbox_r[0:2], bbox_r[2:4], (0,255,0), 2) # 画出矩形位置
bbox_n=match_img(img, t_n)
cv2.rectangle(img, bbox_n[0:2], bbox_n[2:4], (0,0,255), 2) # 画出矩形位置
end=time.time()
print(end-start)
cv2.imshow('a',img)
cv2.waitKey()

'''
#959,166 r=21 ring center
img = pyautogui.screenshot(region=[712, 94, 496, 103]) # x,y,w,h
img.save('bar.png')'''

175 changes: 175 additions & 0 deletions environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import numpy as np
import sys
import cv2
from pymouse import *
import pyautogui
import time
from copy import deepcopy

class Fishing:
def __init__(self, delay=0.1, max_step=100):
self.mosue = PyMouse()
self.t_l = cv2.imread('imgs/target_left.png')
self.t_r = cv2.imread('imgs/target_right.png')
self.t_n = cv2.imread('imgs/target_now.png')
self.std_color=np.array([192,255,255])
self.r_ring=21
self.delay=delay
self.max_step=max_step
self.count=0

def reset(self):
self.step_count=0
self.last_score=0

img = pyautogui.screenshot(region=[712-10, 94, 496+20, 103])
self.img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
return self.get_state()

def match_img(self, img, target):
h, w = target.shape[:2]
res = cv2.matchTemplate(img, target, cv2.TM_CCOEFF)
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
return (*max_loc, max_loc[0] + w, max_loc[1] + h, max_loc[0] + w // 2, max_loc[1] + h // 2)

def drag(self):
self.mosue.click(1630,995)

def do_action(self, action):
if action==1:
self.drag()

def scale(self, x):
return (x-5-10)/484

def get_state(self):
bbox_l = self.match_img(self.img, self.t_l)
bbox_r = self.match_img(self.img, self.t_r)
bbox_n = self.match_img(self.img, self.t_n)

img=deepcopy(self.img)
cv2.rectangle(img, bbox_l[0:2], bbox_l[2:4], (255, 0, 0), 2) # 画出矩形位置
cv2.rectangle(img, bbox_r[0:2], bbox_r[2:4], (0, 255, 0), 2) # 画出矩形位置
cv2.rectangle(img, bbox_n[0:2], bbox_n[2:4], (0, 0, 255), 2) # 画出矩形位置
cv2.imwrite(f'./img_tmp/{self.count}.jpg',img)
self.count+=1

return self.scale(bbox_l[4]),self.scale(bbox_r[4]),self.scale(bbox_n[4])

def check_done(self):
cx,cy=247+10,72
for x in range(2,360):
px=int(cx+self.r_ring*np.sin(np.deg2rad(x)))
py=int(cy+self.r_ring*np.cos(np.deg2rad(x)))
if np.mean(np.abs(self.img[py,px,:]-self.std_color))>3:
return x
return 360

def step(self, action):
self.do_action(action)

time.sleep(self.delay-0.05)
img = pyautogui.screenshot(region=[712-10, 94, 496+20, 103])
self.img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
self.step_count+=1

score=self.check_done()
#print(score)
reward=score-self.last_score
self.last_score=score

return self.get_state(), reward, (self.step_count>self.max_step or score>350)

def render(self):
pass

class Fishing_sim:
def __init__(self, bar_range=(0.14, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5),
move_speed_range=(-0.3,0.3), tick_count=60, step_tick=15, stop_tick=60*15,
drag_force=0.4, down_speed=0.015, stable_speed=-0.32, drawer=None):
self.bar_range=bar_range
self.move_range=move_range
self.resize_freq_range=resize_freq_range
self.move_speed_range=(move_speed_range[0]/tick_count, move_speed_range[1]/tick_count)
self.tick_count=tick_count

self.step_tick=step_tick
self.stop_tick=stop_tick
self.drag_force=drag_force/tick_count
self.down_speed=down_speed/tick_count
self.stable_speed=stable_speed/tick_count

self.drawer=drawer

self.reset()

def reset(self):
self.len = np.random.uniform(*self.bar_range)
self.low = np.random.uniform(0,1-self.len)
self.pointer = np.random.uniform(0,1)
self.v=0

self.resize_tick = 0
self.move_tick = 0
self.move_speed = 0

self.score = 100
self.ticks = 0

return (self.low,self.low+self.len,self.pointer)

def drag(self):
self.v=self.drag_force

def move_bar(self):
if self.move_tick<=0:
self.move_tick=np.random.uniform(*self.move_range)
self.move_speed=np.random.uniform(*self.move_speed_range)
self.low=np.clip(self.low+self.move_speed, a_min=0, a_max=1-self.len)
self.move_tick-=1

def resize_bar(self):
if self.resize_tick<=0:
self.resize_tick=np.random.uniform(*self.resize_freq_range)
self.len=min(np.random.uniform(*self.bar_range),1-self.low)
self.resize_tick-=1

def tick(self):
self.ticks+=1
if self.pointer>self.low and self.pointer<self.low+self.len:
self.score+=1
else:
self.score-=1

if self.ticks>self.stop_tick or self.score<=-10000:
return True

self.pointer+=self.v
self.pointer=np.clip(self.pointer, a_min=0, a_max=1)
self.v=max(self.v-self.down_speed, self.stable_speed)

self.move_bar()
self.resize_bar()
return False

def do_action(self, action):
if action==1:
self.drag()

def get_state(self):
return self.low,self.low+self.len,self.pointer

def step(self, action):
self.do_action(action)

done=False
score_before=self.score
for x in range(self.step_tick):
if self.tick():
done=True
return (self.low,self.low+self.len,self.pointer), (self.score-score_before)/self.step_tick, done

def render(self):
if self.drawer:
self.drawer.draw(self.low, self.low+self.len,self.pointer,self.ticks)

Binary file added imgs/target_left.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/target_now.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/target_right.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# This is a sample Python script.

# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.


def print_hi(name):
# Use a breakpoint in the code line below to debug your script.
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
print_hi('PyCharm')

# See PyCharm help at https://www.jetbrains.com/help/pycharm/
18 changes: 18 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from torch import nn

class FishNet(nn.Sequential):
def __init__(self, in_ch, out_ch):
layers=[
nn.Linear(in_ch, 10),
nn.ReLU(),
nn.Linear(10, out_ch)
]
super(FishNet, self).__init__(*layers)
self.apply(weight_init)

def weight_init(m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
31 changes: 31 additions & 0 deletions play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from environment import Fishing_sim
import time
from tkinter import *
import threading
import numpy as np

env=Fishing_sim()

# mouse callback function
def drag(event):
env.drag()

w,h=500,100
root = Tk()
root.geometry('500x100')
cv = Canvas(root, bg='black', width=w, height=h)
root.bind("<Button-1>", drag)

low=cv.create_rectangle(int(env.low*w)-3, 0, int(env.low*w)+3, h, fill='blue')
high=cv.create_rectangle(int((env.low+env.len)*w)-3, 0, int((env.low+env.len)*w)+3, h, fill='green')
pointer=cv.create_rectangle(int(env.pointer*w)-3, 0, int(env.pointer*w)+3, h, fill='red')
cv.pack()

def update():
env.tick()
cv.coords(low, int(env.low*w)-3, 0, int(env.low*w)+3, h)
cv.coords(high, int((env.low+env.len)*w)-3, 0, int((env.low+env.len)*w)+3, h)
cv.coords(pointer, int(env.pointer*w)-3, 0, int(env.pointer*w)+3, h)
root.after(int(np.round(1000/env.tick_count)), update)
root.after(int(np.round(1000/env.tick_count)), update)
root.mainloop()
Loading

0 comments on commit 5944fd0

Please sign in to comment.