-
Notifications
You must be signed in to change notification settings - Fork 718
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
22d70ca
commit 5338bd6
Showing
82 changed files
with
8,092 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
## Introduction | ||
原神自动钓鱼AI由[YOLOX](https://github.com/Megvii-BaseDetection/YOLOX), DQN两部分模型组成。使用迁移学习,半监督学习进行训练。 | ||
模型也包含一些使用opencv等传统数字图像处理方法实现的不可学习部分。 | ||
|
||
其中YOLOX用于鱼的定位和类型的识别以及鱼竿落点的定位。DQN用于自适应控制钓鱼过程的点击,让力度落在最佳区域内。 | ||
|
||
## 准备 | ||
安装yolox | ||
|
||
```shell | ||
python setup.py develop | ||
``` | ||
|
||
|
||
## YOLOX训练工作流程 | ||
YOLOX部分因为打标签太累所以用半监督学习。标注少量样本后训练模型生成其余样本伪标签再人工修正,不断迭代提高精度。 | ||
样本量较少所以使用迁移学习,在COCO预训练的模型上进行fine-tuning. | ||
|
||
训练代码: | ||
```shell | ||
python yolox_tools/train.py -f yolox/exp/yolox_tiny_fish.py -d 1 -b 8 --fp16 -o -c weights/yolox_tiny.pth | ||
``` | ||
|
||
## DQN训练工作流程 | ||
控制力度使用强化学习模型DQN进行训练。两次进度的差值作为reward为模型提供学习方向。模型与环境间交互式学习。 | ||
|
||
直接在原神内训练耗时较长,太累了。首先制作一个仿真环境,大概模拟钓鱼力度控制操作。在仿真环境内预训练一个模型。 | ||
随后将这一模型迁移至原神内,实现域间迁移。 | ||
|
||
仿真环境预训练代码: | ||
```shell | ||
python train_sim.py | ||
``` | ||
原神游戏内训练: | ||
```shell | ||
python train.py | ||
``` | ||
|
||
## 运行钓鱼AI | ||
```shell | ||
python fishing.py image -f yolox/exp/yolox_tiny_fish.py -c weights/best_tiny3.pth --conf 0.25 --nms 0.45 --tsize 640 --device gpu | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,20 @@ | ||
import time | ||
|
||
import pyautogui | ||
import keyboard | ||
import winsound | ||
|
||
i=0 | ||
'''i=0 | ||
while True: | ||
keyboard.wait('t') | ||
img = pyautogui.screenshot() | ||
img.save(f'img_tmp/{i}.png') | ||
i+=1 | ||
i+=1''' | ||
|
||
print('ok') | ||
keyboard.wait('t') | ||
for i in range(56,56+20): | ||
img = pyautogui.screenshot() | ||
img.save(f'fish_dataset/{i}.png') | ||
time.sleep(0.5) | ||
winsound.Beep(500, 500) |
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import time | ||
from loguru import logger | ||
|
||
import os | ||
import torch | ||
import cv2 | ||
|
||
from yolox.data.data_augment import ValTransform | ||
from yolox.data.datasets import FISH_CLASSES | ||
from yolox.utils import postprocess, vis | ||
|
||
class Predictor(object): | ||
def __init__( | ||
self, | ||
model, | ||
exp, | ||
cls_names=FISH_CLASSES, | ||
trt_file=None, | ||
decoder=None, | ||
device="cpu", | ||
fp16=False, | ||
legacy=False, | ||
): | ||
self.model = model | ||
self.cls_names = cls_names | ||
self.decoder = decoder | ||
self.num_classes = exp.num_classes | ||
self.confthre = exp.test_conf | ||
self.nmsthre = exp.nmsthre | ||
self.test_size = exp.test_size | ||
self.device = device | ||
self.fp16 = fp16 | ||
self.preproc = ValTransform(legacy=legacy) | ||
if trt_file is not None: | ||
from torch2trt import TRTModule | ||
|
||
model_trt = TRTModule() | ||
model_trt.load_state_dict(torch.load(trt_file)) | ||
|
||
x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda() | ||
self.model(x) | ||
self.model = model_trt | ||
|
||
def inference(self, img): | ||
img_info = {"id": 0} | ||
if isinstance(img, str): | ||
img_info["file_name"] = os.path.basename(img) | ||
img = cv2.imread(img) | ||
else: | ||
img_info["file_name"] = None | ||
|
||
height, width = img.shape[:2] | ||
img_info["height"] = height | ||
img_info["width"] = width | ||
img_info["raw_img"] = img | ||
|
||
ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1]) | ||
img_info["ratio"] = ratio | ||
|
||
img, _ = self.preproc(img, None, self.test_size) | ||
img = torch.from_numpy(img).unsqueeze(0) | ||
img = img.float() | ||
if self.device == "gpu": | ||
img = img.cuda() | ||
if self.fp16: | ||
img = img.half() # to FP16 | ||
|
||
with torch.no_grad(): | ||
t0 = time.time() | ||
outputs = self.model(img) | ||
if self.decoder is not None: | ||
outputs = self.decoder(outputs, dtype=outputs.type()) | ||
outputs = postprocess( | ||
outputs, self.num_classes, self.confthre, | ||
self.nmsthre, class_agnostic=True | ||
) | ||
logger.info("Infer time: {:.4f}s".format(time.time() - t0)) | ||
return outputs, img_info | ||
|
||
def image_det(self, img, with_info=False): | ||
outputs, img_info = self.inference(img) | ||
ratio = img_info["ratio"] | ||
obj_list = [] | ||
if outputs[0] is None: | ||
return None | ||
for item in outputs[0].cpu(): | ||
bboxes = item[:4] | ||
# preprocessing: resize | ||
bboxes /= ratio | ||
scores = item[4] * item[5] | ||
obj_list.append([self.cls_names[int(item[6])], scores, [bboxes[0], bboxes[1], bboxes[2], bboxes[3]]]) | ||
if with_info: | ||
return obj_list, outputs, img_info | ||
else: | ||
return obj_list | ||
|
||
|
||
def visual(self, output, img_info, cls_conf=0.35): | ||
ratio = img_info["ratio"] | ||
img = img_info["raw_img"] | ||
if output is None: | ||
return img | ||
output = output.cpu() | ||
|
||
bboxes = output[:, 0:4] | ||
|
||
# preprocessing: resize | ||
bboxes /= ratio | ||
|
||
cls = output[:, 6] | ||
scores = output[:, 4] * output[:, 5] | ||
|
||
vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names) | ||
return vis_res |
Oops, something went wrong.