diff --git a/api.py b/api.py index bd08984..204b390 100644 --- a/api.py +++ b/api.py @@ -6,15 +6,13 @@ from predictor import PosPrediction - class PRN: ''' Joint 3D Face Reconstruction and Dense Alignment with Position Map Regression Network Args: is_dlib(bool, optional): If true, dlib is used for detecting faces. - is_opencv(bool, optional): If true, opencv is used for extracting texture. prefix(str, optional): If run at another folder, the absolute path is needed to load the data. ''' - def __init__(self, is_dlib = False, is_opencv = False, prefix = '.'): + def __init__(self, is_dlib = False, prefix = '.'): # resolution of input and output image size. self.resolution_inp = 256 @@ -27,9 +25,6 @@ def __init__(self, is_dlib = False, is_opencv = False, prefix = '.'): self.face_detector = dlib.cnn_face_detection_model_v1( detector_path) - if is_opencv: - import cv2 - #---- load PRN self.pos_predictor = PosPrediction(self.resolution_inp, self.resolution_op) prn_path = os.path.join(prefix, 'Data/net-data/256_256_resfcn256_weight') @@ -42,6 +37,16 @@ def __init__(self, is_dlib = False, is_opencv = False, prefix = '.'): self.uv_kpt_ind = np.loadtxt(prefix + '/Data/uv-data/uv_kpt_ind.txt').astype(np.int32) # 2 x 68 get kpt self.face_ind = np.loadtxt(prefix + '/Data/uv-data/face_ind.txt').astype(np.int32) # get valid vertices in the pos map self.triangles = np.loadtxt(prefix + '/Data/uv-data/triangles.txt').astype(np.int32) # ntri x 3 + self.uv_coords = self.generate_uv_coords() + + def generate_uv_coords(self): + resolution = self.resolution_op + uv_coords = np.meshgrid(range(resolution),range(resolution)) + uv_coords = np.transpose(np.array(uv_coords), [1,2,0]) + uv_coords = np.reshape(uv_coords, [resolution**2, -1]); + uv_coords = uv_coords[self.face_ind, :] + uv_coords = np.hstack((uv_coords[:,:2], np.zeros([uv_coords.shape[0], 1]))) + return uv_coords def dlib_detect(self, image): return self.face_detector(image, 1) @@ -147,18 +152,6 @@ def get_vertices(self, pos): return vertices - def get_texture(self, image, pos): - ''' extract uv texture from image. opencv is needed here. - Args: - image: input image. - pos: the 3D position map. shape = (256, 256, 3). - Returns: - texture: the corresponding colors of vertices. shape = (num of points, 3). n is 45128 here. - ''' - texture = cv2.remap(image, pos[:,:,:2].astype(np.float32), None, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) - return texture - - def get_colors(self, image, vertices): ''' Args: