Batch inference entry point for models.py, a save_prediction() helper in
tool/utils.py, and a cv2.split() unpacking change in dataset.py.

Review notes for this revision of the patch:
  * The .DS_Store entry (macOS Finder metadata, binary stub without payload)
    is no longer part of the patch.
  * models.py: predictions go to pred_save_dir exactly as given, the argument
    count is checked, image extensions are matched case-insensitively, and
    the class names and tool.utils imports are loaded once, before the loop.
  * tool/utils.py: save_prediction() closes its file via a with-block.
  * Indentation and blank lines of context/removed lines were re-derived from
    a whitespace-collapsed copy; run "git apply --check" (optionally with
    --ignore-whitespace) before applying.

diff --git a/dataset.py b/dataset.py
index e2b2b8b1..9ea1c382 100644
--- a/dataset.py
+++ b/dataset.py
@@ -136,11 +136,11 @@ def image_data_augmentation(mat, w, h, pleft, ptop, swidth, sheight, flip, dhue,
         if dsat != 1 or dexp != 1 or dhue != 0:
             if img.shape[2] >= 3:
                 hsv_src = cv2.cvtColor(sized.astype(np.float32), cv2.COLOR_RGB2HSV)  # RGB to HSV
-                hsv = cv2.split(hsv_src)
-                hsv[1] *= dsat
-                hsv[2] *= dexp
-                hsv[0] += 179 * dhue
-                hsv_src = cv2.merge(hsv)
+                h_channel, s_channel, v_channel = cv2.split(hsv_src)
+                s_channel *= dsat
+                v_channel *= dexp
+                h_channel += 179 * dhue
+                hsv_src = cv2.merge((h_channel, s_channel, v_channel))
                 sized = np.clip(cv2.cvtColor(hsv_src, cv2.COLOR_HSV2RGB), 0, 255)  # HSV to RGB (the same as previous)
             else:
                 sized *= dexp
diff --git a/models.py b/models.py
--- a/models.py
+++ b/models.py
@@ -421,44 +421,46 @@ def forward(self, input):
 
 if __name__ == "__main__":
     import sys
+    import os
     from PIL import Image
+    from tqdm import tqdm
 
-    namesfile = None
-    if len(sys.argv) == 4:
-        n_classes = int(sys.argv[1])
-        weightfile = sys.argv[2]
-        imgfile = sys.argv[3]
-    elif len(sys.argv) == 5:
-        n_classes = int(sys.argv[1])
-        weightfile = sys.argv[2]
-        imgfile = sys.argv[3]
-        namesfile = sys.argv[4]
-    else:
-        print('Usage: ')
-        print(' python models.py num_classes weightfile imgfile namefile')
+    if len(sys.argv) != 6:
+        print('Usage: ')
+        print(' python models.py num_classes weightfile img_dir pred_save_dir namefile')
+        sys.exit(1)
+
+    n_classes = int(sys.argv[1])
+    weightfile = sys.argv[2]
+    img_dir = sys.argv[3]
+    pred_save_dir = sys.argv[4]
+    namesfile = sys.argv[5]
 
     model = Yolov4(n_classes=n_classes)
 
     pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda'))
     model.load_state_dict(pretrained_dict)
 
-    if namesfile == None:
-        if n_classes == 20:
-            namesfile = 'data/voc.names'
-        elif n_classes == 80:
-            namesfile = 'data/coco.names'
-        else:
-            print("please give namefile")
+    model.cuda()
 
-    use_cuda = 1
-    if use_cuda:
-        model.cuda()
+    # Import explicit names once, before the loop, instead of a star import per image.
+    from tool.utils import do_detect, load_class_names, save_prediction
 
-    img = Image.open(imgfile).convert('RGB')
-    sized = img.resize((608, 608))
-    from tool.utils import *
+    class_names = load_class_names(namesfile)
+    # Write into pred_save_dir exactly as given: os.path.dirname() would drop its
+    # last component whenever the argument has no trailing slash.
+    os.makedirs(pred_save_dir, exist_ok=True)
 
-    boxes = do_detect(model, sized, 0.5, n_classes,0.4, use_cuda)
+    for file_name in tqdm(os.listdir(img_dir)):
+        stem, file_extension = os.path.splitext(file_name)
+        # Case-insensitive so that e.g. "IMG_01.JPG" is not skipped.
+        if file_extension.lower() not in ('.jpg', '.png', '.jpeg'):
+            continue
 
-    class_names = load_class_names(namesfile)
-    plot_boxes(img, boxes, 'predictions.jpg', class_names)
+        img = Image.open(os.path.join(img_dir, file_name)).convert('RGB')
+        sized = img.resize((608, 608))
+
+        boxes = do_detect(model, sized, 0.5, n_classes, 0.4, 1)
+
+        # Same base name as the image, with a .txt extension.
+        save_prediction(img, boxes, stem + '.txt', pred_save_dir, class_names)
diff --git a/tool/utils.py b/tool/utils.py
--- a/tool/utils.py
+++ b/tool/utils.py
@@ -372,6 +372,43 @@ def get_color(c, x, max_val):
     return img
 
 
+def save_prediction(img, boxes, file_name, pred_save_dir, class_names):
+    """Write the detections for one image to a text file.
+
+    One line per box: ``class_name cls_conf x1 y1 x2 y2``, space separated.
+    box[0], box[1] are treated as the box center and box[2], box[3] as its
+    width/height; all four are scaled by the size of ``img``.
+
+    Args:
+        img: image exposing ``width`` and ``height``.
+        boxes: sequence of boxes; index 5 is written as the class
+            confidence and index 6 is used as the class id.
+        file_name: name of the output file inside ``pred_save_dir``.
+        pred_save_dir: directory that receives the file.
+        class_names: container indexed by class id, giving the class name.
+    """
+    width = img.width
+    height = img.height
+
+    lines = []
+    for box in boxes:
+        x1 = (box[0] - box[2] / 2.0) * width
+        y1 = (box[1] - box[3] / 2.0) * height
+        x2 = (box[0] + box[2] / 2.0) * width
+        y2 = (box[1] + box[3] / 2.0) * height
+
+        cls_conf = box[5]
+        class_name = class_names[box[6]]
+
+        fields = (class_name, cls_conf, x1, y1, x2, y2)
+        lines.append(" ".join(str(field) for field in fields) + "\n")
+
+    # An image without detections still gets an (empty) file.
+    pred_path = os.path.join(pred_save_dir, file_name)
+    with open(pred_path, "w") as f:
+        f.writelines(lines)
+
+
 def read_truths(lab_path):
     if not os.path.exists(lab_path):
         return np.array([])