From b5abd3e5c29b0f0cdbcb73a2c406ab4180869fe8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=E1=BA=BF=20H=E1=BA=A3i=20Nguy=E1=BB=85n?= Date: Thu, 8 Dec 2022 21:16:21 +0700 Subject: [PATCH 1/3] fix bug for HSV augumentation in dataset.py --- dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dataset.py b/dataset.py index e2b2b8b1..9ea1c382 100644 --- a/dataset.py +++ b/dataset.py @@ -136,11 +136,11 @@ def image_data_augmentation(mat, w, h, pleft, ptop, swidth, sheight, flip, dhue, if dsat != 1 or dexp != 1 or dhue != 0: if img.shape[2] >= 3: hsv_src = cv2.cvtColor(sized.astype(np.float32), cv2.COLOR_RGB2HSV) # RGB to HSV - hsv = cv2.split(hsv_src) - hsv[1] *= dsat - hsv[2] *= dexp - hsv[0] += 179 * dhue - hsv_src = cv2.merge(hsv) + h_channel, s_channel, v_channel = cv2.split(hsv_src) + s_channel *= dsat + v_channel *= dexp + h_channel += 179 * dhue + hsv_src = cv2.merge((h_channel, s_channel, v_channel)) sized = np.clip(cv2.cvtColor(hsv_src, cv2.COLOR_HSV2RGB), 0, 255) # HSV to RGB (the same as previous) else: sized *= dexp From 5f7f9ad784abffd0d3cd8d07f4b06794a56b3aa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=E1=BA=BF=20H=E1=BA=A3i=20Nguy=E1=BB=85n?= Date: Thu, 8 Dec 2022 23:04:15 +0700 Subject: [PATCH 2/3] add prediction proccess --- models.py | 54 ++++++++++++++++++++++++--------------------------- tool/utils.py | 25 ++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/models.py b/models.py index 3ef6787d..f898e10f 100644 --- a/models.py +++ b/models.py @@ -1,6 +1,7 @@ import torch from torch import nn import torch.nn.functional as F +from tqdm import tqdm class Mish(torch.nn.Module): @@ -421,44 +422,39 @@ def forward(self, input): if __name__ == "__main__": import sys + import os from PIL import Image - namesfile = None - if len(sys.argv) == 4: - n_classes = int(sys.argv[1]) - weightfile = sys.argv[2] - imgfile = sys.argv[3] - elif len(sys.argv) == 5: - n_classes = int(sys.argv[1]) - weightfile = sys.argv[2] - imgfile = sys.argv[3] - namesfile = sys.argv[4] - else: - print('Usage: ') - print(' python models.py num_classes weightfile imgfile namefile') + n_classes = int(sys.argv[1]) + weightfile = sys.argv[2] + img_dir = sys.argv[3] + pred_save_dir = sys.argv[4] + namesfile = sys.argv[5] model = Yolov4(n_classes=n_classes) pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda')) model.load_state_dict(pretrained_dict) - if namesfile == None: - if n_classes == 20: - namesfile = 'data/voc.names' - elif n_classes == 80: - namesfile = 'data/coco.names' - else: - print("please give namefile") + model.cuda() + + pred_save_dir = os.path.dirname(pred_save_dir) + + for imgfile in tqdm(os.listdir(img_dir)): + file_extension = imgfile.split('.')[-1] + if file_extension not in ['jpg', 'png', 'jpeg']: + continue - use_cuda = 1 - if use_cuda: - model.cuda() + file_name = imgfile - img = Image.open(imgfile).convert('RGB') - sized = img.resize((608, 608)) - from tool.utils import * + imgfile = os.path.join(img_dir, imgfile) + img = Image.open(imgfile).convert('RGB') + sized = img.resize((608, 608)) + from tool.utils import * - boxes = do_detect(model, sized, 0.5, n_classes,0.4, use_cuda) + boxes = do_detect(model, sized, 0.5, n_classes,0.4, 1) - class_names = load_class_names(namesfile) - plot_boxes(img, boxes, 'predictions.jpg', class_names) + file_name_split = file_name.split('.') + file_name_split[-1] = 'txt' + file_name = '.'.join(file_name_split) + save_prediction(img, boxes, file_name, pred_save_dir) diff --git a/tool/utils.py b/tool/utils.py index d62354f6..ac460124 100644 --- a/tool/utils.py +++ b/tool/utils.py @@ -372,6 +372,31 @@ def get_color(c, x, max_val): return img +def save_prediction(img, boxes, file_name, pred_save_dir): + width = img.width + height = img.height + + content = "" + for i in range(len(boxes)): + box = boxes[i] + x1 = (box[0] - box[2] / 2.0) * width + y1 = (box[1] - box[3] / 2.0) * height + x2 = (box[0] + box[2] / 2.0) * width + y2 = (box[1] + box[3] / 2.0) * height + + cls_conf = box[5] + cls_id = box[6] + + pred = (cls_id, cls_conf, x1, y1, x2, y2) + pred = [str(i) for i in pred] + content += " ".join(pred) + "\n" + + dir = os.path.join(pred_save_dir, file_name) + f = open(dir, "w") + f.write(content) + f.close() + + def read_truths(lab_path): if not os.path.exists(lab_path): return np.array([]) From faf626d35c240c632527d2aaf8da858ced5b667c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=E1=BA=BF=20H=E1=BA=A3i=20Nguy=E1=BB=85n?= Date: Thu, 8 Dec 2022 23:30:52 +0700 Subject: [PATCH 3/3] fix bug for models.py --- .DS_Store | Bin 0 -> 6148 bytes models.py | 4 +++- tool/utils.py | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..3c06728a6b4e07f70d6a0c95adb629fb64efeb56 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O0Nhm@N3Oz1(Em&JG#Y>3#0!H+pQWH`%7_+5G&7l->))(?gd>&_Z zH)63~MeI!3{pNQ!`$6}IF~);=a>$s+7;~T@a#Yp`y4QwUCK-|A7;#lh<4ne2znR!y z2mE%6l}y4f`}+OyG)s%J-+$+=wzaiw+pgWQZ^H*!g=JLEiy)d^qje!=7B_knU#634 z?CqY&yo}O(GE)U6sk7V0oQfhvDaNJhJ$`j42Pp-Pb@~LAB_gf zrR(hN9~_^LAG4=izGy-@@U3LWU<0q9e6Hs?oaLF!@4?sPH+h7_05L!e5CfabfH@PK z&gN1;D<=ksfgc#a{XsxO^bOV;)z$$WUY{{;A) z<=_`4&o@|W)a8t;nPD6=bMtuNYIg7omCm@Yk$Pf)7+7VXtxXTl|1