diff --git a/README.md b/README.md
index 12e6e9c..2178c1f 100755
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ pip install -e .
 ```
 
 ## 🚀 Inference
-
+### Using Command-line Scripts
 Download the pre-trained models ```genpercept_ckpt_v1.zip``` from [BaiduNetDisk](https://pan.baidu.com/s/1n6FlqrOTZqHX-F6OhcvNyA?pwd=g2cm) (Extract code: g2cm), [HuggingFace](https://huggingface.co/guangkaixu/GenPercept), or [Rec Cloud Disk (To be uploaded)](). Please unzip the package and put the checkpoints under ```./weights/v1/```. Then, place images in the ```./input/$TASK_TYPE``` directory, and run the following script. The outputs will be saved in ```./output/$TASK_TYPE```. The ```$TASK_TYPE``` can be chosen from ```depth```, ```normal```, and ```dis```.
 
@@ -59,6 +59,32 @@ bash scripts/inference_dis.sh
 
 Thanks to our one-step perception paradigm, inference runs much faster than multi-step diffusion pipelines (around 0.4 s per image on an A800 GPU).
 
+### Using torch.hub
+```python
+import torch
+import cv2
+import numpy as np
+
+# Load the normal predictor model from torch hub
+normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)
+
+# Load the input image with OpenCV (replace with your own image path)
+image = cv2.imread("input.jpg", cv2.IMREAD_COLOR)
+h, w = image.shape[:2]  # Original resolution, kept for reference
+
+# Use the model to infer the normal map from the input image
+with torch.inference_mode():
+    normal = normal_predictor.infer_cv2(image)[0]  # Predicted normal map with values in [-1, 1]
+    normal = (normal + 1) / 2  # Map values to the range [0, 1]
+
+# Convert the normal map to an 8-bit image and reorder channels for OpenCV (RGB -> BGR)
+normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
+normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
+
+# Save the output normal map (replace with your own output path)
+cv2.imwrite("output_normal.png", normal)
+```
+
 ## 📖 Recommended Works
 
 - Marigold: Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation. [arXiv](https://github.com/prs-eth/marigold), [GitHub](https://github.com/prs-eth/marigold).
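The README snippet above covers the OpenCV path. Since `hubconf.py` (diff below) also exposes an `infer_pil` method with an `image_resolution` argument, a PIL-based variant might look like the sketch below. The input/output paths are placeholders, and the assumption that `infer_pil` returns an 8-bit BGR array ready for `cv2.imwrite` follows from the `cv2.cvtColor(..., cv2.COLOR_RGB2BGR)` call visible in `hubconf.py`.

```python
import cv2
import torch
from PIL import Image

# Load the same hub entrypoint as in the README example
normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)

# Open the image with PIL instead of OpenCV (placeholder path)
image = Image.open("input.jpg").convert("RGB")

with torch.inference_mode():
    # Assumption: infer_pil accepts a PIL.Image and returns a BGR normal map
    normal_bgr = normal_predictor.infer_pil(image, image_resolution=768)

# Save the result (placeholder path)
cv2.imwrite("output_normal.png", normal_bgr)
```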
diff --git a/hubconf.py b/hubconf.py
index d9101ac..8826786 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -43,7 +43,7 @@ def resize_image(input_image, resolution):
     img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
     return img
 
-class Predictor:
+class NormalPredictor:
     def __init__(self, model, device="cuda"):
         self.model = model
         self.device = device
@@ -68,7 +68,7 @@ def infer_pil(self, image, image_resolution=768):
         pred_normal = cv2.cvtColor(pred_normal, cv2.COLOR_RGB2BGR)
         return pred_normal
 
-def GenPercept(local_dir: Optional[str] = None, device="cuda", repo_id = "guangkaixu/GenPercept"):
+def GenPercept_Normal(local_dir: Optional[str] = None, device="cuda", repo_id = "guangkaixu/GenPercept"):
     unet_ckpt_path = hf_hub_download(repo_id=repo_id, filename='unet_normal_v1/diffusion_pytorch_model.safetensors', local_dir=local_dir)
     vae_ckpt_path = hf_hub_download(repo_id=repo_id, filename='vae/diffusion_pytorch_model.safetensors',
@@ -105,7 +105,7 @@ def GenPercept(local_dir: Optional[str] = None, device="cuda", repo_id = "guangk
     normal_predictor = GenPerceptPipeline(**genpercept_params_ckpt)
     normal_predictor = normal_predictor.to(device)
 
-    return Predictor(normal_predictor, device)
+    return NormalPredictor(normal_predictor, device)
 
 def _test_run():
     import argparse
@@ -117,7 +117,7 @@ def _test_run():
     parser.add_argument("--pil", action="store_true", help="use PIL instead of OpenCV")
     args = parser.parse_args()
 
-    predictor = torch.hub.load(".", "GenPercept", local_dir=args.local_dir,
+    predictor = torch.hub.load(".", "GenPercept_Normal", local_dir=args.local_dir,
                                source="local", trust_repo=True)
 
     if args.pil:
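For local development, `_test_run` shows that the renamed entrypoint can also be loaded from a checkout with `source="local"`, and that `local_dir` is forwarded through to `hf_hub_download` for checkpoint caching. A minimal smoke-test sketch under those assumptions (the cache directory and image path are hypothetical, and it must be run from the repository root where `hubconf.py` lives):

```python
import cv2
import torch

# Load the GenPercept_Normal entrypoint from the local checkout instead of GitHub;
# local_dir is passed on to hf_hub_download, so checkpoints land in that directory.
predictor = torch.hub.load(".", "GenPercept_Normal", local_dir="./weights_cache",
                           source="local", trust_repo=True)

# Hypothetical test image
image = cv2.imread("example.jpg", cv2.IMREAD_COLOR)

with torch.inference_mode():
    # Post-process and save the prediction as shown in the README snippet above
    normal = predictor.infer_cv2(image)
```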