diff --git a/onepose/models/factory.py b/onepose/models/factory.py
index 1baeb10..f3b3b39 100644
--- a/onepose/models/factory.py
+++ b/onepose/models/factory.py
@@ -1,7 +1,11 @@
 import numpy as np
+import cv2
 import torch
 import torch.nn as nn
 import onepose.models.vitpose as vitpose
+from PIL import Image
+from typing import Dict, List, Union
+
 from onepose.utils import read_cfg, download_weights
 from onepose.transforms import ComposeTransforms, BGR2RGB, TopDownAffine, ToTensor, NormalizeTensor, _box2cs
 from onepose.functional import keypoints_from_heatmaps
@@ -73,12 +77,12 @@
 }
 
 
 class Model(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                  model_name: str = 'ViTPose_huge_simple_coco') -> None:
         super().__init__()
         file_path = pathlib.Path(os.path.abspath(__file__)).parent
-        
+
         self.model_cfg = read_cfg(os.path.join(file_path, 'configs', model_config[model_name]['model_cfg']))
         self.model = vitpose.ViTPose(self.model_cfg.model)
 
@@ -93,53 +97,93 @@ def __init__(self,
         weights_folder = os.path.join(file_path, 'weights')
         os.makedirs(weights_folder, exist_ok=True)
         ckpt = os.path.join(weights_folder, model_config[model_name]['url'].split('/')[-1])
-        download_weights(model_config[model_name]['url'], 
-                         ckpt, 
+        download_weights(model_config[model_name]['url'],
+                         ckpt,
                          model_config[model_name]['hash'])
         self.model.load_state_dict(torch.load(ckpt, map_location='cpu'))
         self.model.eval()
-        
+
         dataset_cfg = read_cfg(os.path.join(file_path.parent, 'datasets', model_config[model_name]['dataset_cfg']))
         self.keypoint_info = dataset_cfg.dataset_info['keypoint_info']
         self.skeleton_info = dataset_cfg.dataset_info['skeleton_info']
 
     @torch.no_grad()
     @torch.inference_mode()
-    def forward(self, x: np.ndarray) -> np.ndarray:
+    def forward(self, x: Union[np.ndarray, Image.Image, List[Union[np.ndarray, Image.Image]]]) -> Union[Dict, List[Dict]]:
         if self.training:
             self.eval()
 
         device = next(self.parameters()).device
-        img_height, img_width = x.shape[:2]
-        center, scale = _box2cs(self.model_cfg.data_cfg['image_size'], [0, 0, img_width, img_height])
+        single_image = False
+        # Input validation and conversion
+        if isinstance(x, list):
+            if not x:  # empty list check
+                raise ValueError("Input list cannot be empty")
+            if not all(isinstance(img, (np.ndarray, Image.Image)) for img in x):
+                raise TypeError("All elements in the list must be either numpy arrays or PIL Images")
+            # Convert PIL images to numpy arrays with BGR color space
+            x = [cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) if isinstance(img, Image.Image) else img for img in x]
+        elif isinstance(x, (np.ndarray, Image.Image)):
+            if isinstance(x, Image.Image):
+                x = cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR)
+            x = [x]
+            single_image = True
+        else:
+            raise TypeError("Input must be either a numpy array, PIL Image, or a list of them")
+
+        # Convert grayscale images to BGR
+        for i, img in enumerate(x):
+            if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
+                x[i] = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
-        results = {'img': x,
-                   'rotation': 0,
-                   'center': center,
-                   'scale': scale,
-                   'image_size': np.array(self.model_cfg.data_cfg['image_size']),
-                   }
+        # Preprocess every image, collecting the affine parameters for postprocessing
+        batch_results = []
+        centers, scales = [], []
+        for img in x:
+            img_height, img_width = img.shape[:2]
+            center, scale = _box2cs(self.model_cfg.data_cfg['image_size'],
+                                    [0, 0, img_width, img_height])
+            centers.append(center)
+            scales.append(scale)
 
-        results = self.transforms(results)
-        results['img'] = results['img'].to(device)
+            results = {
+                'img': img,
+                'rotation': 0,
+                'center': center,
+                'scale': scale,
+                'image_size': np.array(self.model_cfg.data_cfg['image_size']),
+            }
+
+            results = self.transforms(results)
+            batch_results.append(results['img'])
+
+        # Stack transformed images into a batch
+        batch_tensor = torch.stack(batch_results).to(device)
 
-        out = self.model(results['img'][None, ...])
+        # Forward pass
+        out = self.model(batch_tensor)
         out = out.cpu().numpy()
 
-        out, maxvals = keypoints_from_heatmaps(out, 
-                                               center=[center], 
-                                               scale=[scale], 
-                                               unbiased=False, 
-                                               post_process='default', 
-                                               kernel=11, 
-                                               valid_radius_factor=0.0546875, 
-                                               use_udp=self.use_udp, 
-                                               target_type='GaussianHeatmap')
-        out = out[0]
-        maxvals = maxvals[0]
-        out = {'points': out, 'confidence': maxvals}
-        return out
+        # Map heatmap peaks back to the original image coordinates
+        points, maxvals = keypoints_from_heatmaps(
+            out,
+            center=centers,
+            scale=scales,
+            unbiased=False,
+            post_process='default',
+            kernel=11,
+            valid_radius_factor=0.0546875,
+            use_udp=self.use_udp,
+            target_type='GaussianHeatmap'
+        )
+
+        outputs = [{'points': p, 'confidence': c} for p, c in zip(points, maxvals)]
+
+        # Return single result for single image input
+        if single_image:
+            return outputs[0]
+        return outputs
 
 
 def create_model(model_name: str = 'ViTPose_huge_simple_coco') -> Model:
     model = Model(model_name=model_name)
diff --git a/setup.py b/setup.py
index f24514e..b0fd601 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,6 @@ setup(
     name='onepose',
     version='1.0',
-    install_requires=['opencv-python', 'torch', 'torchvision', 'tqdm', 'numpy'],
+    install_requires=['opencv-python', 'torch', 'torchvision', 'tqdm', 'numpy', 'Pillow'],
     packages=find_packages(exclude='notebooks')
-)
\ No newline at end of file
+)
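
For reference, a minimal usage sketch of the batched interface this diff introduces. It assumes `create_model` is re-exported at the `onepose` package root; the image paths are placeholders, not part of the change:

```python
import cv2
import onepose
from PIL import Image

# Placeholder paths; any BGR numpy array or RGB PIL image works.
model = onepose.create_model('ViTPose_huge_simple_coco')

# Single image in: single dict out, as before.
frame = cv2.imread('person.jpg')  # BGR numpy array
result = model(frame)
print(result['points'].shape, result['confidence'].shape)

# List in: list of dicts out, one per image; numpy and PIL inputs can be mixed.
results = model([frame, Image.open('person2.jpg')])
for r in results:
    print(r['points'].shape, r['confidence'].shape)
```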