-
Notifications
You must be signed in to change notification settings - Fork 718
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
7eu7d7
committed
Sep 10, 2021
0 parents
commit 5944fd0
Showing
13 changed files
with
530 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
from torch import nn | ||
from copy import deepcopy | ||
import numpy as np | ||
|
||
class DQN: | ||
def __init__(self, base_net, batch_size, n_states, n_actions, memory_capacity=2000, epsilon=0.9, gamma=0.9, rep_frep=100, lr=0.01): | ||
self.eval_net = base_net | ||
self.target_net = deepcopy(base_net) | ||
|
||
self.batch_size=batch_size | ||
self.epsilon=epsilon | ||
self.gamma=gamma | ||
self.n_states=n_states | ||
self.n_actions=n_actions | ||
self.memory_capacity=memory_capacity | ||
self.rep_frep=rep_frep | ||
|
||
self.learn_step_counter = 0 # count the steps of learning process | ||
self.memory_counter = 0 # counter used for experience replay buffer | ||
|
||
# of columns depends on 4 elements, s, a, r, s_, the total is N_STATES*2 + 2---# | ||
self.memory = np.zeros((memory_capacity, n_states * 2 + 2)) | ||
|
||
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr) | ||
self.loss_func = nn.MSELoss() | ||
|
||
def choose_action(self, x): | ||
# This function is used to make decision based upon epsilon greedy | ||
x = torch.FloatTensor(x).unsqueeze(0) # add 1 dimension to input state x | ||
# input only one sample | ||
if np.random.uniform() < self.epsilon: # greedy | ||
# use epsilon-greedy approach to take action | ||
actions_value = self.eval_net.forward(x) | ||
# torch.max() returns a tensor composed of max value along the axis=dim and corresponding index | ||
# what we need is the index in this function, representing the action of cart. | ||
action = torch.argmax(actions_value, dim=1).numpy() # return the argmax index | ||
else: # random | ||
action = np.random.randint(0, self.n_actions) | ||
return action | ||
|
||
def store_transition(self, s, a, r, s_): | ||
# This function acts as experience replay buffer | ||
transition = np.hstack((s, [a, r], s_)) # horizontally stack these vectors | ||
# if the capacity is full, then use index to replace the old memory with new one | ||
index = self.memory_counter % self.memory_capacity | ||
self.memory[index, :] = transition | ||
self.memory_counter += 1 | ||
|
||
def train_step(self): | ||
# Define how the whole DQN works including sampling batch of experiences, | ||
# when and how to update parameters of target network, and how to implement | ||
# backward propagation. | ||
|
||
# update the target network every fixed steps | ||
if self.learn_step_counter % self.rep_frep == 0: | ||
# Assign the parameters of eval_net to target_net | ||
self.target_net.load_state_dict(self.eval_net.state_dict()) | ||
self.learn_step_counter += 1 | ||
|
||
# Determine the index of Sampled batch from buffer | ||
sample_index = np.random.choice(self.memory_capacity, self.batch_size) # randomly select some data from buffer | ||
# extract experiences of batch size from buffer. | ||
b_memory = self.memory[sample_index, :] | ||
# extract vectors or matrices s,a,r,s_ from batch memory and convert these to torch Variables | ||
# that are convenient to back propagation | ||
b_s = torch.FloatTensor(b_memory[:, :self.n_states]) | ||
# convert long int type to tensor | ||
b_a = torch.LongTensor(b_memory[:, self.n_states:self.n_states + 1].astype(int)) | ||
b_r = torch.FloatTensor(b_memory[:, self.n_states + 1:self.n_states + 2]) | ||
b_s_ = torch.FloatTensor(b_memory[:, -self.n_states:]) | ||
|
||
# calculate the Q value of state-action pair | ||
q_eval = self.eval_net(b_s).gather(1, b_a) # (batch_size, 1) | ||
# print(q_eval) | ||
# calculate the q value of next state | ||
q_next = self.target_net(b_s_).detach() # detach from computational graph, don't back propagate | ||
# select the maximum q value | ||
# print(q_next) | ||
# q_next.max(1) returns the max value along the axis=1 and its corresponding index | ||
q_target = b_r + self.gamma * q_next.max(dim=1)[0].view(self.batch_size, 1) # (batch_size, 1) | ||
loss = self.loss_func(q_eval, q_target) | ||
|
||
self.optimizer.zero_grad() # reset the gradient to zero | ||
loss.backward() | ||
self.optimizer.step() # execute back propagation for one step |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import pyautogui | ||
import cv2 | ||
import numpy as np | ||
import time | ||
|
||
def match_img(img, target): | ||
h, w = target.shape[:2] | ||
res = cv2.matchTemplate(img, target, cv2.TM_CCOEFF) | ||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) | ||
return (*max_loc, max_loc[0] + w, max_loc[1] + h, max_loc[0] + w//2, max_loc[1] + h//2) | ||
|
||
img = cv2.imread('imgs/a.png')[94:94 + 103, 712:712 + 496, :] | ||
t_l = cv2.imread('imgs/target_left.png') | ||
t_r = cv2.imread('imgs/target_right.png') | ||
t_n = cv2.imread('imgs/target_now.png') | ||
|
||
start=time.time() | ||
img2 = pyautogui.screenshot(region=[712, 94, 496, 103]) | ||
bbox_l=match_img(img, t_l) | ||
cv2.rectangle(img, bbox_l[0:2], bbox_l[2:4], (255,0,0), 2) # 画出矩形位置 | ||
bbox_r=match_img(img, t_r) | ||
cv2.rectangle(img, bbox_r[0:2], bbox_r[2:4], (0,255,0), 2) # 画出矩形位置 | ||
bbox_n=match_img(img, t_n) | ||
cv2.rectangle(img, bbox_n[0:2], bbox_n[2:4], (0,0,255), 2) # 画出矩形位置 | ||
end=time.time() | ||
print(end-start) | ||
cv2.imshow('a',img) | ||
cv2.waitKey() | ||
|
||
''' | ||
#959,166 r=21 ring center | ||
img = pyautogui.screenshot(region=[712, 94, 496, 103]) # x,y,w,h | ||
img.save('bar.png')''' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import numpy as np | ||
import sys | ||
import cv2 | ||
from pymouse import * | ||
import pyautogui | ||
import time | ||
from copy import deepcopy | ||
|
||
class Fishing: | ||
def __init__(self, delay=0.1, max_step=100): | ||
self.mosue = PyMouse() | ||
self.t_l = cv2.imread('imgs/target_left.png') | ||
self.t_r = cv2.imread('imgs/target_right.png') | ||
self.t_n = cv2.imread('imgs/target_now.png') | ||
self.std_color=np.array([192,255,255]) | ||
self.r_ring=21 | ||
self.delay=delay | ||
self.max_step=max_step | ||
self.count=0 | ||
|
||
def reset(self): | ||
self.step_count=0 | ||
self.last_score=0 | ||
|
||
img = pyautogui.screenshot(region=[712-10, 94, 496+20, 103]) | ||
self.img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) | ||
return self.get_state() | ||
|
||
def match_img(self, img, target): | ||
h, w = target.shape[:2] | ||
res = cv2.matchTemplate(img, target, cv2.TM_CCOEFF) | ||
min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) | ||
return (*max_loc, max_loc[0] + w, max_loc[1] + h, max_loc[0] + w // 2, max_loc[1] + h // 2) | ||
|
||
def drag(self): | ||
self.mosue.click(1630,995) | ||
|
||
def do_action(self, action): | ||
if action==1: | ||
self.drag() | ||
|
||
def scale(self, x): | ||
return (x-5-10)/484 | ||
|
||
def get_state(self): | ||
bbox_l = self.match_img(self.img, self.t_l) | ||
bbox_r = self.match_img(self.img, self.t_r) | ||
bbox_n = self.match_img(self.img, self.t_n) | ||
|
||
img=deepcopy(self.img) | ||
cv2.rectangle(img, bbox_l[0:2], bbox_l[2:4], (255, 0, 0), 2) # 画出矩形位置 | ||
cv2.rectangle(img, bbox_r[0:2], bbox_r[2:4], (0, 255, 0), 2) # 画出矩形位置 | ||
cv2.rectangle(img, bbox_n[0:2], bbox_n[2:4], (0, 0, 255), 2) # 画出矩形位置 | ||
cv2.imwrite(f'./img_tmp/{self.count}.jpg',img) | ||
self.count+=1 | ||
|
||
return self.scale(bbox_l[4]),self.scale(bbox_r[4]),self.scale(bbox_n[4]) | ||
|
||
def check_done(self): | ||
cx,cy=247+10,72 | ||
for x in range(2,360): | ||
px=int(cx+self.r_ring*np.sin(np.deg2rad(x))) | ||
py=int(cy+self.r_ring*np.cos(np.deg2rad(x))) | ||
if np.mean(np.abs(self.img[py,px,:]-self.std_color))>3: | ||
return x | ||
return 360 | ||
|
||
def step(self, action): | ||
self.do_action(action) | ||
|
||
time.sleep(self.delay-0.05) | ||
img = pyautogui.screenshot(region=[712-10, 94, 496+20, 103]) | ||
self.img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) | ||
self.step_count+=1 | ||
|
||
score=self.check_done() | ||
#print(score) | ||
reward=score-self.last_score | ||
self.last_score=score | ||
|
||
return self.get_state(), reward, (self.step_count>self.max_step or score>350) | ||
|
||
def render(self): | ||
pass | ||
|
||
class Fishing_sim: | ||
def __init__(self, bar_range=(0.14, 0.4), move_range=(30,60*2), resize_freq_range=(15,60*5), | ||
move_speed_range=(-0.3,0.3), tick_count=60, step_tick=15, stop_tick=60*15, | ||
drag_force=0.4, down_speed=0.015, stable_speed=-0.32, drawer=None): | ||
self.bar_range=bar_range | ||
self.move_range=move_range | ||
self.resize_freq_range=resize_freq_range | ||
self.move_speed_range=(move_speed_range[0]/tick_count, move_speed_range[1]/tick_count) | ||
self.tick_count=tick_count | ||
|
||
self.step_tick=step_tick | ||
self.stop_tick=stop_tick | ||
self.drag_force=drag_force/tick_count | ||
self.down_speed=down_speed/tick_count | ||
self.stable_speed=stable_speed/tick_count | ||
|
||
self.drawer=drawer | ||
|
||
self.reset() | ||
|
||
def reset(self): | ||
self.len = np.random.uniform(*self.bar_range) | ||
self.low = np.random.uniform(0,1-self.len) | ||
self.pointer = np.random.uniform(0,1) | ||
self.v=0 | ||
|
||
self.resize_tick = 0 | ||
self.move_tick = 0 | ||
self.move_speed = 0 | ||
|
||
self.score = 100 | ||
self.ticks = 0 | ||
|
||
return (self.low,self.low+self.len,self.pointer) | ||
|
||
def drag(self): | ||
self.v=self.drag_force | ||
|
||
def move_bar(self): | ||
if self.move_tick<=0: | ||
self.move_tick=np.random.uniform(*self.move_range) | ||
self.move_speed=np.random.uniform(*self.move_speed_range) | ||
self.low=np.clip(self.low+self.move_speed, a_min=0, a_max=1-self.len) | ||
self.move_tick-=1 | ||
|
||
def resize_bar(self): | ||
if self.resize_tick<=0: | ||
self.resize_tick=np.random.uniform(*self.resize_freq_range) | ||
self.len=min(np.random.uniform(*self.bar_range),1-self.low) | ||
self.resize_tick-=1 | ||
|
||
def tick(self): | ||
self.ticks+=1 | ||
if self.pointer>self.low and self.pointer<self.low+self.len: | ||
self.score+=1 | ||
else: | ||
self.score-=1 | ||
|
||
if self.ticks>self.stop_tick or self.score<=-10000: | ||
return True | ||
|
||
self.pointer+=self.v | ||
self.pointer=np.clip(self.pointer, a_min=0, a_max=1) | ||
self.v=max(self.v-self.down_speed, self.stable_speed) | ||
|
||
self.move_bar() | ||
self.resize_bar() | ||
return False | ||
|
||
def do_action(self, action): | ||
if action==1: | ||
self.drag() | ||
|
||
def get_state(self): | ||
return self.low,self.low+self.len,self.pointer | ||
|
||
def step(self, action): | ||
self.do_action(action) | ||
|
||
done=False | ||
score_before=self.score | ||
for x in range(self.step_tick): | ||
if self.tick(): | ||
done=True | ||
return (self.low,self.low+self.len,self.pointer), (self.score-score_before)/self.step_tick, done | ||
|
||
def render(self): | ||
if self.drawer: | ||
self.drawer.draw(self.low, self.low+self.len,self.pointer,self.ticks) | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# This is a sample Python script. | ||
|
||
# Press Shift+F10 to execute it or replace it with your code. | ||
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. | ||
|
||
|
||
def print_hi(name): | ||
# Use a breakpoint in the code line below to debug your script. | ||
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. | ||
|
||
|
||
# Press the green button in the gutter to run the script. | ||
if __name__ == '__main__': | ||
print_hi('PyCharm') | ||
|
||
# See PyCharm help at https://www.jetbrains.com/help/pycharm/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
from torch import nn | ||
|
||
class FishNet(nn.Sequential): | ||
def __init__(self, in_ch, out_ch): | ||
layers=[ | ||
nn.Linear(in_ch, 10), | ||
nn.ReLU(), | ||
nn.Linear(10, out_ch) | ||
] | ||
super(FishNet, self).__init__(*layers) | ||
self.apply(weight_init) | ||
|
||
def weight_init(m): | ||
if isinstance(m, nn.Linear): | ||
nn.init.normal_(m.weight, 0, 0.1) | ||
if m.bias is not None: | ||
nn.init.constant_(m.bias, 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from environment import Fishing_sim | ||
import time | ||
from tkinter import * | ||
import threading | ||
import numpy as np | ||
|
||
env=Fishing_sim() | ||
|
||
# mouse callback function | ||
def drag(event): | ||
env.drag() | ||
|
||
w,h=500,100 | ||
root = Tk() | ||
root.geometry('500x100') | ||
cv = Canvas(root, bg='black', width=w, height=h) | ||
root.bind("<Button-1>", drag) | ||
|
||
low=cv.create_rectangle(int(env.low*w)-3, 0, int(env.low*w)+3, h, fill='blue') | ||
high=cv.create_rectangle(int((env.low+env.len)*w)-3, 0, int((env.low+env.len)*w)+3, h, fill='green') | ||
pointer=cv.create_rectangle(int(env.pointer*w)-3, 0, int(env.pointer*w)+3, h, fill='red') | ||
cv.pack() | ||
|
||
def update(): | ||
env.tick() | ||
cv.coords(low, int(env.low*w)-3, 0, int(env.low*w)+3, h) | ||
cv.coords(high, int((env.low+env.len)*w)-3, 0, int((env.low+env.len)*w)+3, h) | ||
cv.coords(pointer, int(env.pointer*w)-3, 0, int(env.pointer*w)+3, h) | ||
root.after(int(np.round(1000/env.tick_count)), update) | ||
root.after(int(np.round(1000/env.tick_count)), update) | ||
root.mainloop() |
Oops, something went wrong.