-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy pathmmseg_mask_generator.py
50 lines (41 loc) · 1.7 KB
/
mmseg_mask_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Tuple
import cv2
import modules.shared as shared
import numpy as np
from face_editor.use_cases.mask_generator import MaskGenerator
from huggingface_hub import hf_hub_download
from mmseg.apis import inference_model, init_model
class MMSegMaskGenerator(MaskGenerator):
    """Mask generator backed by an MMSegmentation face-occlusion model.

    The DeepLabV3+ (ResNet-101) checkpoint and its config file are obtained via
    ``hf_hub_download`` from the ``ototadana/occlusion-aware-face-segmentation``
    repository and loaded with ``init_model`` on ``modules.shared.device``.
    """

    _REPO_ID = "ototadana/occlusion-aware-face-segmentation"
    _CHECKPOINT_FILENAME = "deeplabv3plus_r101_512x512_face-occlusion-93ec6695.pth"
    _CONFIG_FILENAME = "deeplabv3plus_r101_512x512_face-occlusion.py"

    def __init__(self):
        # Resolve the model weights and config, then build the model once so
        # every generate_mask() call reuses it.
        checkpoint_file = hf_hub_download(
            repo_id=self._REPO_ID,
            filename=self._CHECKPOINT_FILENAME,
        )
        config_file = hf_hub_download(
            repo_id=self._REPO_ID,
            filename=self._CONFIG_FILENAME,
        )
        self.model = init_model(config_file, checkpoint_file, device=shared.device)

    def name(self) -> str:
        """Return the name identifying this mask generator ("MMSeg")."""
        return "MMSeg"

    def generate_mask(
        self,
        face_image: np.ndarray,
        face_area_on_image: Tuple[int, int, int, int],
        mask_size: int,
        use_minimal_area: bool,
        **kwargs,
    ) -> np.ndarray:
        """Segment ``face_image`` and return the result as a 3-channel mask.

        Args:
            face_image: Image to segment, indexed as ``[rows, cols, channels]``.
                Its channel order is reversed before inference (presumably
                RGB -> BGR for MMSeg; TODO confirm the caller's channel order).
            face_area_on_image: Face box forwarded to
                ``MaskGenerator.mask_non_face_areas`` when ``use_minimal_area``
                is true.
            mask_size: Number of dilation iterations (5x5 kernel) applied to
                the mask; ``0`` or less disables dilation.
            use_minimal_area: If true, regions outside ``face_area_on_image``
                are masked out before running the model.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``uint8`` mask where pixels with a predicted label > 0 are 255 and
            all others are 0. A single-channel prediction is expanded to
            3 channels.
        """
        # Reverse the channel order and copy in one step. The caller's array is
        # never modified, and the result is C-contiguous (a bare ``[..., ::-1]``
        # view has a negative stride, which some OpenCV/torch paths reject).
        face_image = face_image[:, :, ::-1].copy()

        if use_minimal_area:
            face_image = MaskGenerator.mask_non_face_areas(face_image, face_area_on_image)

        result = inference_model(self.model, face_image)
        # NOTE(review): assumes pred_sem_seg.data has a leading singleton axis,
        # e.g. (1, H, W); squeeze(0) drops it to give a per-pixel label map.
        labels = result.pred_sem_seg.data.squeeze(0).cpu().numpy()
        mask = self._labels_to_mask(labels)

        if mask_size > 0:
            mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=mask_size)
        return mask

    @staticmethod
    def _labels_to_mask(labels: np.ndarray) -> np.ndarray:
        """Convert a per-pixel label map into a 0/255 ``uint8`` mask with 3 channels."""
        # Threshold instead of ``(labels * 255).astype(np.uint8)``, which wraps
        # around for label values > 1 (e.g. 2 -> 254, 256 -> 0). For 0/1 labels
        # the result is identical.
        binary = np.where(labels > 0, 255, 0).astype(np.uint8)
        if binary.ndim == 2 or binary.shape[2] == 1:
            # COLOR_BGR2RGB only accepts 3/4-channel input and raises on a
            # single-channel map, so expand grayscale to 3 channels instead.
            return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
        return cv2.cvtColor(binary, cv2.COLOR_BGR2RGB)