diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..b7c4487
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include LICENSE
+include README.md
+include requirements.txt
+recursive-include anime_face_detector/configs *.py
diff --git a/README.md b/README.md
index a17b4cf..4a9fd7f 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,5 @@
 # Anime Face Detector
+This is an anime face detector using
+[mmdetection](https://github.com/open-mmlab/mmdetection)
+and [mmpose](https://github.com/open-mmlab/mmpose).
 
diff --git a/anime_face_detector/__init__.py b/anime_face_detector/__init__.py
new file mode 100644
index 0000000..8862231
--- /dev/null
+++ b/anime_face_detector/__init__.py
@@ -0,0 +1,14 @@
+import pathlib
+
+from .detector import LandmarkDetector
+
+
+def get_config_path(model_name: str) -> pathlib.Path:
+    assert model_name in ['faster-rcnn', 'yolov3', 'hrnetv2']
+
+    package_path = pathlib.Path(__file__).parent.resolve()
+    if model_name in ['faster-rcnn', 'yolov3']:
+        config_dir = package_path / 'configs' / 'mmdet'
+    else:
+        config_dir = package_path / 'configs' / 'mmpose'
+    return config_dir / f'{model_name}.py'
diff --git a/anime_face_detector/configs/mmdet/faster-rcnn.py b/anime_face_detector/configs/mmdet/faster-rcnn.py
new file mode 100644
index 0000000..2ccf8d2
--- /dev/null
+++ b/anime_face_detector/configs/mmdet/faster-rcnn.py
@@ -0,0 +1,66 @@
+model = dict(type='FasterRCNN',
+             backbone=dict(type='ResNet',
+                           depth=50,
+                           num_stages=4,
+                           out_indices=(0, 1, 2, 3),
+                           frozen_stages=1,
+                           norm_cfg=dict(type='BN', requires_grad=True),
+                           norm_eval=True,
+                           style='pytorch'),
+             neck=dict(type='FPN',
+                       in_channels=[256, 512, 1024, 2048],
+                       out_channels=256,
+                       num_outs=5),
+             rpn_head=dict(type='RPNHead',
+                           in_channels=256,
+                           feat_channels=256,
+                           anchor_generator=dict(type='AnchorGenerator',
+                                                 scales=[8],
+                                                 ratios=[0.5, 1.0, 2.0],
+                                                 strides=[4, 8, 16, 32, 64]),
+                           bbox_coder=dict(type='DeltaXYWHBBoxCoder',
+                                           target_means=[0.0, 0.0, 0.0, 0.0],
+                                           target_stds=[1.0, 1.0, 1.0, 1.0])),
+             roi_head=dict(
+                 type='StandardRoIHead',
+                 bbox_roi_extractor=dict(type='SingleRoIExtractor',
+                                         roi_layer=dict(type='RoIAlign',
+                                                        output_size=7,
+                                                        sampling_ratio=0),
+                                         out_channels=256,
+                                         featmap_strides=[4, 8, 16, 32]),
+                 bbox_head=dict(type='Shared2FCBBoxHead',
+                                in_channels=256,
+                                fc_out_channels=1024,
+                                roi_feat_size=7,
+                                num_classes=1,
+                                bbox_coder=dict(
+                                    type='DeltaXYWHBBoxCoder',
+                                    target_means=[0.0, 0.0, 0.0, 0.0],
+                                    target_stds=[0.1, 0.1, 0.2, 0.2]),
+                                reg_class_agnostic=False)),
+             test_cfg=dict(rpn=dict(nms_pre=1000,
+                                    max_per_img=1000,
+                                    nms=dict(type='nms', iou_threshold=0.7),
+                                    min_bbox_size=0),
+                           rcnn=dict(score_thr=0.05,
+                                     nms=dict(type='nms', iou_threshold=0.5),
+                                     max_per_img=100)))
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize',
+                  mean=[123.675, 116.28, 103.53],
+                  std=[58.395, 57.12, 57.375],
+                  to_rgb=True),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img'])
+         ])
+]
+data = dict(test=dict(pipeline=test_pipeline))
diff --git a/anime_face_detector/configs/mmdet/yolov3.py b/anime_face_detector/configs/mmdet/yolov3.py
new file mode 100644
index 0000000..c644633
--- /dev/null
+++ b/anime_face_detector/configs/mmdet/yolov3.py
@@ -0,0 +1,47 @@
+model = dict(type='YOLOV3',
+             backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)),
+             neck=dict(type='YOLOV3Neck',
+                       num_scales=3,
+                       in_channels=[1024, 512, 256],
+                       out_channels=[512, 256, 128]),
+             bbox_head=dict(type='YOLOV3Head',
+                            num_classes=1,
+                            in_channels=[512, 256, 128],
+                            out_channels=[1024, 512, 256],
+                            anchor_generator=dict(type='YOLOAnchorGenerator',
+                                                  base_sizes=[[(116, 90),
+                                                               (156, 198),
+                                                               (373, 326)],
+                                                              [(30, 61),
+                                                               (62, 45),
+                                                               (59, 119)],
+                                                              [(10, 13),
+                                                               (16, 30),
+                                                               (33, 23)]],
+                                                  strides=[32, 16, 8]),
+                            bbox_coder=dict(type='YOLOBBoxCoder'),
+                            featmap_strides=[32, 16, 8]),
+             test_cfg=dict(nms_pre=1000,
+                           min_bbox_size=0,
+                           score_thr=0.05,
+                           conf_thr=0.005,
+                           nms=dict(type='nms', iou_threshold=0.45),
+                           max_per_img=100))
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='MultiScaleFlipAug',
+         img_scale=(608, 608),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize',
+                  mean=[0, 0, 0],
+                  std=[255.0, 255.0, 255.0],
+                  to_rgb=True),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img'])
+         ])
+]
+data = dict(test=dict(pipeline=test_pipeline))
diff --git a/anime_face_detector/configs/mmpose/hrnetv2.py b/anime_face_detector/configs/mmpose/hrnetv2.py
new file mode 100644
index 0000000..b113ef4
--- /dev/null
+++ b/anime_face_detector/configs/mmpose/hrnetv2.py
@@ -0,0 +1,250 @@
+channel_cfg = dict(num_output_channels=28,
+                   dataset_joints=28,
+                   dataset_channel=[
+                       list(range(28)),
+                   ],
+                   inference_channel=list(range(28)))
+
+model = dict(
+    type='TopDown',
+    backbone=dict(type='HRNet',
+                  in_channels=3,
+                  extra=dict(stage1=dict(num_modules=1,
+                                         num_branches=1,
+                                         block='BOTTLENECK',
+                                         num_blocks=(4, ),
+                                         num_channels=(64, )),
+                             stage2=dict(num_modules=1,
+                                         num_branches=2,
+                                         block='BASIC',
+                                         num_blocks=(4, 4),
+                                         num_channels=(18, 36)),
+                             stage3=dict(num_modules=4,
+                                         num_branches=3,
+                                         block='BASIC',
+                                         num_blocks=(4, 4, 4),
+                                         num_channels=(18, 36, 72)),
+                             stage4=dict(num_modules=3,
+                                         num_branches=4,
+                                         block='BASIC',
+                                         num_blocks=(4, 4, 4, 4),
+                                         num_channels=(18, 36, 72, 144),
+                                         multiscale_output=True),
+                             upsample=dict(mode='bilinear',
+                                           align_corners=False))),
+    keypoint_head=dict(type='TopdownHeatmapSimpleHead',
+                       in_channels=[18, 36, 72, 144],
+                       in_index=(0, 1, 2, 3),
+                       input_transform='resize_concat',
+                       out_channels=channel_cfg['num_output_channels'],
+                       num_deconv_layers=0,
+                       extra=dict(final_conv_kernel=1,
+                                  num_conv_layers=1,
+                                  num_conv_kernels=(1, )),
+                       loss_keypoint=dict(type='JointsMSELoss',
+                                          use_target_weight=True)),
+    test_cfg=dict(flip_test=True,
+                  post_process='unbiased',
+                  shift_heatmap=True,
+                  modulate_kernel=11))
+
+data_cfg = dict(image_size=[256, 256],
+                heatmap_size=[64, 64],
+                num_output_channels=channel_cfg['num_output_channels'],
+                num_joints=channel_cfg['dataset_joints'],
+                dataset_channel=channel_cfg['dataset_channel'],
+                inference_channel=channel_cfg['inference_channel'])
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='TopDownAffine'),
+    dict(type='ToTensor'),
+    dict(type='NormalizeTensor',
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]),
+    dict(type='Collect',
+         keys=['img'],
+         meta_keys=['image_file', 'center', 'scale', 'rotation',
+                    'flip_pairs']),
+]
+
+dataset_info = dict(dataset_name='anime_face',
+                    paper_info=dict(),
+                    keypoint_info={
+                        0:
+                        dict(name='kpt-0',
+                             id=0,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-4'),
+                        1:
+                        dict(name='kpt-1',
+                             id=1,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-3'),
+                        2:
+                        dict(name='kpt-2',
+                             id=2,
+                             color=[255, 255, 255],
+                             type='',
+                             swap=''),
+                        3:
+                        dict(name='kpt-3',
+                             id=3,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-1'),
+                        4:
+                        dict(name='kpt-4',
+                             id=4,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-0'),
+                        5:
+                        dict(name='kpt-5',
+                             id=5,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-10'),
+                        6:
+                        dict(name='kpt-6',
+                             id=6,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-9'),
+                        7:
+                        dict(name='kpt-7',
+                             id=7,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-8'),
+                        8:
+                        dict(name='kpt-8',
+                             id=8,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-7'),
+                        9:
+                        dict(name='kpt-9',
+                             id=9,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-6'),
+                        10:
+                        dict(name='kpt-10',
+                             id=10,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-5'),
+                        11:
+                        dict(name='kpt-11',
+                             id=11,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-19'),
+                        12:
+                        dict(name='kpt-12',
+                             id=12,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-18'),
+                        13:
+                        dict(name='kpt-13',
+                             id=13,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-17'),
+                        14:
+                        dict(name='kpt-14',
+                             id=14,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-22'),
+                        15:
+                        dict(name='kpt-15',
+                             id=15,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-21'),
+                        16:
+                        dict(name='kpt-16',
+                             id=16,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-20'),
+                        17:
+                        dict(name='kpt-17',
+                             id=17,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-13'),
+                        18:
+                        dict(name='kpt-18',
+                             id=18,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-12'),
+                        19:
+                        dict(name='kpt-19',
+                             id=19,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-11'),
+                        20:
+                        dict(name='kpt-20',
+                             id=20,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-16'),
+                        21:
+                        dict(name='kpt-21',
+                             id=21,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-15'),
+                        22:
+                        dict(name='kpt-22',
+                             id=22,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-14'),
+                        23:
+                        dict(name='kpt-23',
+                             id=23,
+                             color=[255, 255, 255],
+                             type='',
+                             swap=''),
+                        24:
+                        dict(name='kpt-24',
+                             id=24,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-26'),
+                        25:
+                        dict(name='kpt-25',
+                             id=25,
+                             color=[255, 255, 255],
+                             type='',
+                             swap=''),
+                        26:
+                        dict(name='kpt-26',
+                             id=26,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='kpt-24'),
+                        27:
+                        dict(name='kpt-27',
+                             id=27,
+                             color=[255, 255, 255],
+                             type='',
+                             swap='')
+                    },
+                    skeleton_info={},
+                    joint_weights=[1.] * 28,
+                    sigmas=[])
+
+data = dict(test=dict(type='',
+                      data_cfg=data_cfg,
+                      pipeline=test_pipeline,
+                      dataset_info=dataset_info), )
diff --git a/anime_face_detector/detector.py b/anime_face_detector/detector.py
new file mode 100644
index 0000000..82df6eb
--- /dev/null
+++ b/anime_face_detector/detector.py
@@ -0,0 +1,141 @@
+import pathlib
+import warnings
+from typing import Optional, Union
+
+import cv2
+import mmcv
+import numpy as np
+import torch.nn as nn
+from mmdet.apis import inference_detector, init_detector
+from mmpose.apis import inference_top_down_pose_model, init_pose_model
+from mmpose.datasets import DatasetInfo
+
+
+class LandmarkDetector:
+    def __init__(
+            self,
+            landmark_detector_config_or_path: Union[mmcv.Config, str,
+                                                    pathlib.Path],
+            landmark_detector_checkpoint_path: Union[str, pathlib.Path],
+            face_detector_config_or_path: Optional[Union[mmcv.Config, str,
+                                                         pathlib.Path]] = None,
+            face_detector_checkpoint_path: Optional[Union[
+                str, pathlib.Path]] = None,
+            device: str = 'cuda:0',
+            flip_test: bool = True,
+            box_scale_factor: float = 1.1):
+        landmark_config = self._load_config(landmark_detector_config_or_path)
+        self.dataset_info = DatasetInfo(
+            landmark_config.dataset_info)  # type: ignore
+        face_detector_config = self._load_config(face_detector_config_or_path)
+
+        self.landmark_detector = self._init_pose_model(
+            landmark_config, landmark_detector_checkpoint_path, device,
+            flip_test)
+        self.face_detector = self._init_face_detector(
+            face_detector_config, face_detector_checkpoint_path, device)
+
+        self.box_scale_factor = box_scale_factor
+
+    @staticmethod
+    def _load_config(
+        config_or_path: Optional[Union[mmcv.Config, str, pathlib.Path]]
+    ) -> Optional[mmcv.Config]:
+        if config_or_path is None or isinstance(config_or_path, mmcv.Config):
+            return config_or_path
+        return mmcv.Config.fromfile(config_or_path)
+
+    @staticmethod
+    def _init_pose_model(config: mmcv.Config,
+                         checkpoint_path: Union[str, pathlib.Path],
+                         device: str, flip_test: bool) -> nn.Module:
+        model = init_pose_model(config, checkpoint_path, device=device)
+        model.cfg.model.test_cfg.flip_test = flip_test
+        return model
+
+    @staticmethod
+    def _init_face_detector(config: Optional[mmcv.Config],
+                            checkpoint_path: Optional[Union[str,
+                                                            pathlib.Path]],
+                            device: str) -> Optional[nn.Module]:
+        if config is not None:
+            model = init_detector(config, checkpoint_path, device=device)
+        else:
+            model = None
+        return model
+
+    def _detect_faces(self, image: np.ndarray) -> list[np.ndarray]:
+        # Predicted boxes from the mmdet model have the format
+        # [x0, y0, x1, y1, score].
+        boxes = inference_detector(self.face_detector, image)[0]
+        # Scale the boxes by `self.box_scale_factor`.
+        boxes = self._update_pred_box(boxes)
+        return boxes
+
+    def _update_pred_box(self, pred_boxes: np.ndarray) -> list[np.ndarray]:
+        boxes = []
+        for pred_box in pred_boxes:
+            box = pred_box[:4]
+            size = box[2:] - box[:2] + 1
+            new_size = size * self.box_scale_factor
+            center = (box[:2] + box[2:]) / 2
+            tl = center - new_size / 2
+            br = tl + new_size
+            pred_box[:4] = np.concatenate([tl, br])
+            boxes.append(pred_box)
+        return boxes
+
+    def _detect_landmarks(
+            self, image: np.ndarray,
+            boxes: list[dict[str, np.ndarray]]) -> list[dict[str, np.ndarray]]:
+        preds, _ = inference_top_down_pose_model(
+            self.landmark_detector,
+            image,
+            boxes,
+            format='xyxy',
+            dataset_info=self.dataset_info,
+            return_heatmap=False)
+        return preds
+
+    @staticmethod
+    def _load_image(
+        image_or_path: Union[np.ndarray, str, pathlib.Path]) -> np.ndarray:
+        if isinstance(image_or_path, np.ndarray):
+            image = image_or_path
+        elif isinstance(image_or_path, str):
+            image = cv2.imread(image_or_path)
+        elif isinstance(image_or_path, pathlib.Path):
+            image = cv2.imread(image_or_path.as_posix())
+        else:
+            raise ValueError(f'Unsupported input type: {type(image_or_path)}')
+        return image
+
+    def __call__(
+            self,
+            image_or_path: Union[np.ndarray, str, pathlib.Path],
+            boxes: Optional[list[np.ndarray]] = None
+    ) -> list[dict[str, np.ndarray]]:
+        """Detect face landmarks.
+
+        Args:
+            image_or_path: An image with BGR channel order or an image path.
+            boxes: A list of bounding boxes for faces. Each bounding box
+                should be of the form [x0, y0, x1, y1, [score]].
+
+        Returns: A list of detection results, one dict per detected face.
+            Each dict has a 'bbox' of the form [x0, y0, x1, y1, [score]]
+            and 'keypoints' of the form [x, y, score] for 28 landmarks.
+        """
+        image = self._load_image(image_or_path)
+        if boxes is None:
+            if self.face_detector is not None:
+                boxes = self._detect_faces(image)
+            else:
+                warnings.warn(
+                    'Neither a face detector nor bounding boxes were '
+                    'provided, so the entire image is treated as the '
+                    'face region.')
+                h, w = image.shape[:2]
+                boxes = [np.array([0, 0, w - 1, h - 1])]
+        box_list = [{'bbox': box} for box in boxes]
+        return self._detect_landmarks(image, box_list)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..24a0d2a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+mmcv-full==1.3.15
+mmdet==2.17.0
+mmpose==0.19.0
+numpy==1.21.2
+opencv-python-headless==4.5.3.56
+torch==1.9.1
+torchvision==0.10.1
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..e670337
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,31 @@
+import pathlib
+
+from setuptools import find_packages, setup
+
+
+def _get_long_description():
+    path = pathlib.Path(__file__).parent / 'README.md'
+    with open(path, encoding='utf-8') as f:
+        long_description = f.read()
+    return long_description
+
+
+def _get_requirements(path):
+    with open(path) as f:
+        data = f.readlines()
+    return data
+
+
+setup(
+    name='anime-face-detector',
+    version='0.0.1',
+    author='hysts',
+    url='https://github.com/hysts/anime_face_detector',
+    python_requires='>=3.9',
+    install_requires=_get_requirements('requirements.txt'),
+    packages=find_packages(exclude=('tests', )),
+    include_package_data=True,
+    description='Anime Face Detector using mmdet and mmpose',
+    long_description=_get_long_description(),
+    long_description_content_type='text/markdown',
+)
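
For reviewers, a minimal usage sketch of the API this patch adds. `get_config_path` and `LandmarkDetector` are defined above; the checkpoint paths are hypothetical placeholders (the patch ships configs only, not trained weights), and `device='cuda:0'` assumes a CUDA build of PyTorch.

```python
import cv2

from anime_face_detector import LandmarkDetector, get_config_path

# NOTE: the checkpoint paths below are placeholders; substitute your own
# trained weights, as this patch does not include any.
detector = LandmarkDetector(
    landmark_detector_config_or_path=get_config_path('hrnetv2'),
    landmark_detector_checkpoint_path='checkpoints/hrnetv2.pth',
    face_detector_config_or_path=get_config_path('yolov3'),
    face_detector_checkpoint_path='checkpoints/yolov3.pth',
    device='cuda:0')

image = cv2.imread('input.jpg')  # BGR order, as the detector expects
preds = detector(image)
for pred in preds:
    print(pred['bbox'])       # [x0, y0, x1, y1, score]
    print(pred['keypoints'])  # 28 rows of [x, y, score]
```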
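To eyeball the detections from the sketch above, a small OpenCV-only drawing loop is enough; the 0.3 keypoint-score threshold here is an arbitrary choice for visualization, not a value taken from this patch.

```python
import numpy as np

# Draw each face box and its confident landmarks onto `image`.
for pred in preds:
    x0, y0, x1, y1 = np.round(pred['bbox'][:4]).astype(int).tolist()
    cv2.rectangle(image, (x0, y0), (x1, y1), (0, 255, 0), 2)
    for x, y, score in pred['keypoints']:
        if score > 0.3:  # arbitrary visualization threshold
            cv2.circle(image, (int(x), int(y)), 2, (0, 0, 255), -1)
cv2.imwrite('result.jpg', image)
```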