doc: Update README
hugoycj committed Aug 9, 2024
1 parent d41a19f commit a5680e9
Showing 2 changed files with 31 additions and 5 deletions.
28 changes: 27 additions & 1 deletion README.md
@@ -42,7 +42,7 @@ pip install -e .
```

## 🚀 Inference

### Using Command-line Scripts
Download the pre-trained model package ```genpercept_ckpt_v1.zip``` from [BaiduNetDisk](https://pan.baidu.com/s/1n6FlqrOTZqHX-F6OhcvNyA?pwd=g2cm) (extraction code: g2cm), [HuggingFace](https://huggingface.co/guangkaixu/GenPercept), or [Rec Cloud Disk (To be uploaded)](). Unzip the package and place the checkpoints under ```./weights/v1/```.
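If you prefer to script the download, the same checkpoints are hosted in the [guangkaixu/GenPercept](https://huggingface.co/guangkaixu/GenPercept) Hugging Face repository (the one ```hubconf.py``` pulls from). A minimal sketch using ```huggingface_hub```; whether the repository layout matches the ```./weights/v1/``` layout expected by the scripts is an assumption, so double-check the folder structure against the unzipped package:

```python
from huggingface_hub import snapshot_download

# Mirror the GenPercept checkpoint folders into ./weights/v1/.
# allow_patterns limits the download to the sub-folders used by the
# normal-estimation pipeline (folder names taken from hubconf.py).
snapshot_download(
    repo_id="guangkaixu/GenPercept",
    local_dir="./weights/v1",
    allow_patterns=["unet_normal_v1/*", "vae/*"],
)
```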

Then, place images in the ```./input/$TASK_TYPE``` directory and run the following script. The outputs will be saved in ```./output/$TASK_TYPE```. ```$TASK_TYPE``` can be ```depth```, ```normal```, or ```dis```.
@@ -59,6 +59,32 @@ bash scripts/inference_dis.sh

Thanks to our one-step perception paradigm, inference runs fast: around 0.4 s per image on an A800 GPU.
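If you want to verify the per-image latency on your own hardware, a rough benchmark might look like this (a sketch using the ```torch.hub``` entry point described in the next subsection; the warm-up call and ```torch.cuda.synchronize()``` are needed for honest GPU timings, and ```example.jpg``` is a placeholder path):

```python
import time
import cv2
import torch

predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)
image = cv2.imread("example.jpg", cv2.IMREAD_COLOR)  # placeholder input image

with torch.inference_mode():
    predictor.infer_cv2(image)  # warm-up: the first call pays one-off setup costs
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        predictor.infer_cv2(image)
    torch.cuda.synchronize()
print(f"{(time.perf_counter() - start) / 10:.3f} s per image")
```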

### Using torch.hub
```python
import torch
import cv2
import numpy as np

# Load the normal predictor model from torch.hub
normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)

# Load the input image using OpenCV (replace the paths with your own)
input_path, output_path = "input.png", "output_normal.png"
image = cv2.imread(input_path, cv2.IMREAD_COLOR)
h, w = image.shape[:2]

# Use the model to infer the normal map from the input image
with torch.inference_mode():
    normal = normal_predictor.infer_cv2(image)[0]  # tensor of shape (3, H, W)
    normal = (normal + 1) / 2  # map values from [-1, 1] to [0, 1]

# Convert the normal map to a displayable 8-bit image: (3, H, W) -> (H, W, 3)
normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)

# Save the output normal map to a file
cv2.imwrite(output_path, normal)
```
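The predictor also exposes an ```infer_pil``` method (see the ```hubconf.py``` changes below), which resizes internally with a default ```image_resolution=768``` and, per ```hubconf.py```, returns a BGR array ready for ```cv2.imwrite```. A small sketch, with the file paths as placeholders:

```python
import cv2
import torch
from PIL import Image

normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)
image = Image.open("input/normal/example.jpg")  # placeholder path

with torch.inference_mode():
    # infer_pil handles resizing internally (default image_resolution=768)
    normal = normal_predictor.infer_pil(image)

cv2.imwrite("output/normal/example_normal.png", normal)  # already BGR, per hubconf.py
```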

## 📖 Recommended Works

- Marigold: Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation. [arXiv](https://arxiv.org/abs/2312.02145), [GitHub](https://github.com/prs-eth/marigold).
8 changes: 4 additions & 4 deletions hubconf.py
@@ -43,7 +43,7 @@ def resize_image(input_image, resolution):
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img

-class Predictor:
+class NormalPredictor:
def __init__(self, model, device="cuda"):
self.model = model
self.device = device
@@ -68,7 +68,7 @@ def infer_pil(self, image, image_resolution=768):
pred_normal = cv2.cvtColor(pred_normal, cv2.COLOR_RGB2BGR)
return pred_normal

-def GenPercept(local_dir: Optional[str] = None, device="cuda", repo_id = "guangkaixu/GenPercept"):
+def GenPercept_Normal(local_dir: Optional[str] = None, device="cuda", repo_id = "guangkaixu/GenPercept"):
unet_ckpt_path = hf_hub_download(repo_id=repo_id, filename='unet_normal_v1/diffusion_pytorch_model.safetensors',
local_dir=local_dir)
vae_ckpt_path = hf_hub_download(repo_id=repo_id, filename='vae/diffusion_pytorch_model.safetensors',
@@ -105,7 +105,7 @@ def GenPercept(local_dir: Optional[str] = None, device="cuda", repo_id = "guangk
normal_predictor = GenPerceptPipeline(**genpercept_params_ckpt)
normal_predictor = normal_predictor.to(device)

-return Predictor(normal_predictor, device)
+return NormalPredictor(normal_predictor, device)

def _test_run():
import argparse
@@ -117,7 +117,7 @@ def _test_run():
parser.add_argument("--pil", action="store_true", help="use PIL instead of OpenCV")
args = parser.parse_args()

-predictor = torch.hub.load(".", "GenPercept", local_dir=args.local_dir,
+predictor = torch.hub.load(".", "GenPercept_Normal", local_dir=args.local_dir,
source="local", trust_repo=True)

if args.pil:
