Commit: yolox
IrisRainbowNeko committed Sep 20, 2021
1 parent 22d70ca commit 5338bd6
Showing 82 changed files with 8,092 additions and 84 deletions.
42 changes: 42 additions & 0 deletions README.md
@@ -0,0 +1,42 @@
## Introduction
The Genshin Impact auto-fishing AI consists of two models, [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) and DQN, trained with transfer learning and semi-supervised learning.
The pipeline also contains some non-learnable parts implemented with traditional digital image processing methods such as OpenCV.

YOLOX locates fish, identifies their species, and locates the rod's landing point. DQN adaptively controls the clicks during the fishing minigame so that the force gauge lands in the optimal region.

## Preparation
Install YOLOX:

```shell
python setup.py develop
```


## YOLOX Training Workflow
Since hand-labeling is tedious, the YOLOX part is trained with semi-supervised learning: label a small number of samples, train a model on them, use it to generate pseudo-labels for the remaining samples, correct those by hand, and iterate to raise accuracy.
Because the sample count is small, transfer learning is used as well: the model is fine-tuned from COCO-pretrained weights.
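A minimal sketch of one pseudo-labeling round; `train`, `predict_boxes`, and `save_labels` are hypothetical helpers used only for illustration, not functions from this repository:

```python
# A sketch only: train(), predict_boxes() and save_labels() are hypothetical
# helpers, not part of this repository.
def pseudo_label_round(model, labeled_set, unlabeled_imgs, conf_thres=0.6):
    model = train(model, labeled_set)        # fine-tune on all corrected labels so far
    for img in unlabeled_imgs:
        # keep only confident detections as pseudo-labels
        boxes = [b for b in predict_boxes(model, img) if b.score >= conf_thres]
        save_labels(img, boxes)              # written out for manual correction
    # after manual correction, these samples join labeled_set for the next round
    return model
```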

Training command:
```shell
python yolox_tools/train.py -f yolox/exp/yolox_tiny_fish.py -d 1 -b 8 --fp16 -o -c weights/yolox_tiny.pth
```

## DQN Training Workflow
The force control is trained with DQN, a reinforcement learning model. The difference between two successive progress readings serves as the reward, giving the model its learning signal; the model learns by interacting with the environment.
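A minimal sketch of this reward in a gym-style `step`; `read_progress()` and `self.last_progress` are hypothetical stand-ins for the real screen-reading code in `fisher/environment.py`:

```python
# A sketch only: read_progress() and last_progress are hypothetical names.
def step(self, action):
    self.do_action(action)                   # click (or not) this tick
    progress = self.read_progress()          # current catch progress, e.g. in [0, 1]
    reward = progress - self.last_progress   # positive while progress grows
    self.last_progress = progress
    done = progress >= 1.0 or progress <= 0.0
    return self.get_state(), reward, done
```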

Training directly inside Genshin Impact takes too long and is tedious, so a simulator that roughly mimics the fishing force control is built first, and a model is pre-trained in that simulated environment.
That model is then transferred into Genshin Impact, i.e. transfer across domains.
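The transfer step amounts to reusing the simulator-trained weights as the starting point for in-game training. A sketch, assuming a standard PyTorch DQN; the network shape and checkpoint path are illustrative, not the repository's actual ones:

```python
import torch
import torch.nn as nn

# illustrative Q-network: 3 state values in, Q-values for 2 actions out
q_net = nn.Sequential(
    nn.Linear(3, 64), nn.ReLU(),
    nn.Linear(64, 2),
)
# load the weights pre-trained in the simulator (hypothetical path),
# then continue training against the real game environment
q_net.load_state_dict(torch.load('weights/dqn_sim_pretrain.pth'))
```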

Simulator pre-training command:
```shell
python train_sim.py
```
In-game training command:
```shell
python train.py
```

## Run the Fishing AI
```shell
python fishing.py image -f yolox/exp/yolox_tiny_fish.py -c weights/best_tiny3.pth --conf 0.25 --nms 0.45 --tsize 640 --device gpu
```
14 changes: 12 additions & 2 deletions capture.py
@@ -1,10 +1,20 @@
import time

import pyautogui
import keyboard
import winsound

-i=0
+'''i=0
while True:
keyboard.wait('t')
img = pyautogui.screenshot()
img.save(f'img_tmp/{i}.png')
-    i+=1
+    i+=1'''

print('ok')
keyboard.wait('t')
for i in range(56,56+20):
img = pyautogui.screenshot()
img.save(f'fish_dataset/{i}.png')
time.sleep(0.5)
winsound.Beep(500, 500)
Empty file added fisher/__init__.py
File renamed without changes.
126 changes: 114 additions & 12 deletions environment.py → fisher/environment.py
@@ -3,17 +3,117 @@
import cv2
from pymouse import *
import pyautogui
import win32api, win32con
import time
from copy import deepcopy
from collections import Counter
import traceback

class FishFind:
def __init__(self, predictor, show_det=True):
self.predictor = predictor
self.food_imgs = [
cv2.imread('./imgs/food_gn.png'),
cv2.imread('./imgs/food_cm.png'),
cv2.imread('./imgs/food_bug.png'),
cv2.imread('./imgs/food_fy.png'),
]
self.ff_dict={'hua jiang':0, 'ji yu':1, 'die yu':2, 'jia long':3, 'pao yu':3}
self.dist_dict={'hua jiang':130, 'ji yu':80, 'die yu':80, 'jia long':80, 'pao yu':80}
self.food_rgn=[580,400,740,220]
self.last_fish_type='hua jiang'
self.show_det=show_det

def get_fish_types(self, n=12, rate=0.6):
counter = Counter()
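        # fx(i) is +1 for the first half of the sweep and -1 for the second,
        # so the camera pans one way and then back while detections accumulate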
fx = lambda x: int(np.sign(np.cos(np.pi * (x / (n // 2)) + 1e-4)))
for i in range(n):
obj_list = self.predictor.image_det(cap())
if obj_list is None:
win32api.mouse_event(win32con.MOUSEEVENTF_MOVE, 70 * fx(i), 0, 0, 0)
time.sleep(0.2)
continue
cls_list = set([x[0] for x in obj_list])
counter.update(cls_list)
win32api.mouse_event(win32con.MOUSEEVENTF_MOVE, 70 * fx(i), 0, 0, 0)
time.sleep(0.2)
# pyautogui.moveRel(50, 0, duration=0.5)
# for u in range(1,51):
# mosue.move(sx + u, sy)
# time.sleep(0.003)
fish_list = [k for k, v in dict(counter).items() if v / n >= rate]
return fish_list

def throw_rod(self, fish_type):
win32api.mouse_event(win32con.MOUSEEVENTF_LEFTDOWN, 0, 0)
time.sleep(1)

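        # step size toward the target: a large fixed step when far away,
        # proportionally smaller steps when close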
def move_func(dist):
if dist>100:
return 50 * np.sign(dist)
else:
return (abs(dist)/2.5+10) * np.sign(dist)

for i in range(50):
try:
obj_list, outputs, img_info = self.predictor.image_det(cap(), with_info=True)
if self.show_det:
cv2.imwrite(f'img_tmp/det{i}.png', self.predictor.visual(outputs[0],img_info))

                rod_info = sorted(list(filter(lambda x: x[0] == 'rod', obj_list)), key=lambda x: x[1], reverse=True)
                if len(rod_info) <= 0:
                    win32api.mouse_event(win32con.MOUSEEVENTF_MOVE, np.random.randint(-50, 50), np.random.randint(-50, 50), 0, 0)
                    continue  # no rod detected yet: nudge the view randomly and retry
                rod_info = rod_info[0]
rod_cx = (rod_info[2][0] + rod_info[2][2]) / 2
rod_cy = (rod_info[2][1] + rod_info[2][3]) / 2

fish_info = min(list(filter(lambda x: x[0] == fish_type, obj_list)),
key=lambda x: distance((x[2][0]+x[2][2])/2, (x[2][1]+x[2][3])/2, rod_cx, rod_cy))

if (fish_info[2][0] + fish_info[2][2]) > (rod_info[2][0] + rod_info[2][2]):
#dist = -self.dist_dict[fish_type] * np.sign(fish_info[2][2] - (rod_info[2][0] + rod_info[2][2]) / 2)
x_dist = fish_info[2][0] - self.dist_dict[fish_type] - rod_cx
else:
x_dist = fish_info[2][2] + self.dist_dict[fish_type] - rod_cx

print(x_dist, (fish_info[2][3] + fish_info[2][1]) / 2 - rod_info[2][3])
if abs(x_dist)<30 and abs((fish_info[2][3] + fish_info[2][1]) / 2 - rod_info[2][3])<30:
break

dx = move_func(x_dist)
#win32api.mouse_event(win32con.MOUSEEVENTF_MOVE, fish_info[2][2] - rod_info[2][2] + 50, (fish_info[2][3] + fish_info[2][1]) / 2 - rod_info[2][3], 0, 0)
win32api.mouse_event(win32con.MOUSEEVENTF_MOVE, dx, move_func((fish_info[2][3] + fish_info[2][1]) / 2 - rod_info[2][3]), 0, 0)
except Exception as e:
traceback.print_exc()
#time.sleep(0.3)
win32api.mouse_event(win32con.MOUSEEVENTF_LEFTUP, 0, 0)

def select_food(self, fish_type):
pyautogui.press('f')
time.sleep(1)
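        # note: the hard-coded click coordinates below appear to assume a 1920x1080 window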
pyautogui.click(1650, 790, button=pyautogui.SECONDARY)
time.sleep(0.5)
bbox_food = match_img(cap(self.food_rgn), self.food_imgs[self.ff_dict[fish_type]], type=cv2.TM_CCOEFF_NORMED)
pyautogui.click(bbox_food[4]+self.food_rgn[0], bbox_food[5]+self.food_rgn[1])
time.sleep(0.1)
pyautogui.click(1183, 756)

def do_fish(self, fish_init=True):
if fish_init:
self.fish_list = self.get_fish_types()
if self.fish_list[0]!=self.last_fish_type:
self.select_food(self.fish_list[0])
self.last_fish_type = self.fish_list[0]
self.throw_rod(self.fish_list[0])

class Fishing:
-    def __init__(self, delay=0.1, max_step=100, show_det=True):
+    def __init__(self, delay=0.1, max_step=100, show_det=True, predictor=None):
self.mosue = PyMouse()
-        self.t_l = cv2.imread('imgs/target_left.png')
-        self.t_r = cv2.imread('imgs/target_right.png')
-        self.t_n = cv2.imread('imgs/target_now.png')
-        self.im_bar = cv2.imread('imgs/bar2.png')
-        self.bite = cv2.imread('imgs/bite.png', cv2.IMREAD_GRAYSCALE)
+        self.t_l = cv2.imread('./imgs/target_left.png')
+        self.t_r = cv2.imread('./imgs/target_right.png')
+        self.t_n = cv2.imread('./imgs/target_now.png')
+        self.im_bar = cv2.imread('./imgs/bar2.png')
+        self.bite = cv2.imread('./imgs/bite.png', cv2.IMREAD_GRAYSCALE)
self.std_color=np.array([192,255,255])
self.r_ring=21
self.delay=delay
@@ -45,13 +145,13 @@ def do_action(self, action):
def scale(self, x):
return (x-5-10)/484

-    def find_bar(self):
-        img = cap(region=[700, 0, 520, 300])
+    def find_bar(self, img=None):
+        img = cap(region=[700, 0, 520, 300]) if img is None else img[:300, 700:700+520, :]
bbox_bar = match_img(img, self.im_bar)
if self.show_det:
img=deepcopy(img)
            cv2.rectangle(img, bbox_bar[:2], bbox_bar[2:4], (0, 0, 255), 1)  # draw the matched bar region
-            cv2.imwrite(f'./img_tmp/bar.jpg',img)
+            cv2.imwrite(f'../img_tmp/bar.jpg', img)
return bbox_bar[1]-9, bbox_bar

def is_bite(self):
@@ -60,7 +160,7 @@ def is_bite(self):
edge_output = cv2.Canny(gray, 50, 150)
return psnr(self.bite, edge_output)>10

-    def get_state(self):
+    def get_state(self, all_box=False):
bar_img=self.img[2:34,:,:]
bbox_l = match_img(bar_img, self.t_l)
bbox_r = match_img(bar_img, self.t_r)
Expand All @@ -87,8 +187,10 @@ def get_state(self):
'''cv2.imwrite(f'./bar_dataset/{self.count}.jpg', self.img)
with open(f'./bar_dataset/{self.count}.xml', 'w', encoding='utf-8') as f:
f.write(self.voc_tmp.format(self.count, *bbox_l[:4], *bbox_r[:4], *bbox_n[:4]))'''

-        return self.scale(bbox_l[4]),self.scale(bbox_r[4]),self.scale(bbox_n[4])
+        if all_box:
+            return bbox_l, bbox_r, bbox_n
+        else:
+            return self.scale(bbox_l[4]), self.scale(bbox_r[4]), self.scale(bbox_n[4])

def get_score(self):
cx,cy=247+10,72
File renamed without changes.
114 changes: 114 additions & 0 deletions fisher/predictor.py
@@ -0,0 +1,114 @@
import time
from loguru import logger

import os
import torch
import cv2

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import FISH_CLASSES
from yolox.utils import postprocess, vis

class Predictor(object):
def __init__(
self,
model,
exp,
cls_names=FISH_CLASSES,
trt_file=None,
decoder=None,
device="cpu",
fp16=False,
legacy=False,
):
self.model = model
self.cls_names = cls_names
self.decoder = decoder
self.num_classes = exp.num_classes
self.confthre = exp.test_conf
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.fp16 = fp16
self.preproc = ValTransform(legacy=legacy)
if trt_file is not None:
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_file))

x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
self.model(x)
self.model = model_trt

def inference(self, img):
img_info = {"id": 0}
if isinstance(img, str):
img_info["file_name"] = os.path.basename(img)
img = cv2.imread(img)
else:
img_info["file_name"] = None

height, width = img.shape[:2]
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img

ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
img_info["ratio"] = ratio

img, _ = self.preproc(img, None, self.test_size)
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()
if self.device == "gpu":
img = img.cuda()
if self.fp16:
img = img.half() # to FP16

with torch.no_grad():
t0 = time.time()
outputs = self.model(img)
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre,
self.nmsthre, class_agnostic=True
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info

def image_det(self, img, with_info=False):
outputs, img_info = self.inference(img)
ratio = img_info["ratio"]
obj_list = []
if outputs[0] is None:
return None
for item in outputs[0].cpu():
bboxes = item[:4]
# preprocessing: resize
bboxes /= ratio
scores = item[4] * item[5]
obj_list.append([self.cls_names[int(item[6])], scores, [bboxes[0], bboxes[1], bboxes[2], bboxes[3]]])
if with_info:
return obj_list, outputs, img_info
else:
return obj_list


def visual(self, output, img_info, cls_conf=0.35):
ratio = img_info["ratio"]
img = img_info["raw_img"]
if output is None:
return img
output = output.cpu()

bboxes = output[:, 0:4]

# preprocessing: resize
bboxes /= ratio

cls = output[:, 6]
scores = output[:, 4] * output[:, 5]

vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
return vis_res
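
A hedged usage sketch of this class, assuming the standard YOLOX `get_exp` API and checkpoint layout; the experiment file and weight path are taken from the README, and the input image path is illustrative:

```python
import cv2
import torch
from yolox.exp import get_exp
from fisher.predictor import Predictor

# build the model from the experiment file used in the README commands
exp = get_exp('yolox/exp/yolox_tiny_fish.py', None)
model = exp.get_model()
model.eval()
# assuming the standard YOLOX checkpoint layout {"model": state_dict, ...}
ckpt = torch.load('weights/best_tiny3.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])

predictor = Predictor(model, exp, device='cpu')
# obj_list entries are [class_name, score, [x1, y1, x2, y2]]
obj_list = predictor.image_det(cv2.imread('fish_dataset/56.png'))
```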
