refactor: Update README and hubconf.py
hugoycj committed Aug 9, 2024
1 parent a5680e9 commit d5a9c12
Showing 2 changed files with 98 additions and 19 deletions.
54 changes: 44 additions & 10 deletions README.md
Thanks to our one-step perception paradigm, inference runs much faster: around 0.4 s per image on an A800 GPU (a timing sketch follows the normal-estimation example below).

### Using torch.hub
GenPercept models can be loaded through torch.hub for quick integration into your Python projects. Here is how to use them for normal estimation, depth estimation, and segmentation:

#### Normal Estimation
```python
import torch
import cv2
import numpy as np
normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the normal map from the input image
with torch.inference_mode():
    normal = normal_predictor.infer_cv2(image)

# Save the output normal map to a file
cv2.imwrite("output_normal_map.png", normal)
```
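
To sanity-check the quoted speed on your own hardware, here is a minimal timing sketch (the file path is hypothetical; the warm-up call and `torch.cuda.synchronize()` are there so the measurement reflects steady-state GPU inference rather than one-time setup):

```python
import time

import cv2
import torch

normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

with torch.inference_mode():
    normal_predictor.infer_cv2(image)  # warm-up: the first call pays one-time setup costs
    torch.cuda.synchronize()           # drain queued GPU work before starting the clock
    start = time.perf_counter()
    normal = normal_predictor.infer_cv2(image)
    torch.cuda.synchronize()
print(f"inference took {time.perf_counter() - start:.2f} s")
```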

#### Depth Estimation
```python
import torch
import cv2

# Load the depth predictor model from torch hub
depth_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Depth", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the depth map from the input image
with torch.inference_mode():
    depth = depth_predictor.infer_cv2(image)

# Save the output depth map to a file
cv2.imwrite("output_depth_map.png", depth)
```

#### Segmentation
```python
import torch
import cv2

# Load the segmentation predictor model from torch hub
seg_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Segmentation", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the segmentation map from the input image
with torch.inference_mode():
    segmentation = seg_predictor.infer_cv2(image)

# Save the output segmentation map to a file
cv2.imwrite("output_segmentation_map.png", segmentation)
```
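
The same pattern works with PIL instead of OpenCV: each predictor also exposes `infer_pil`, which takes an RGB PIL image plus an optional `image_resolution` (default 768) and, per hubconf.py, returns the same BGR array as `infer_cv2`. A sketch under those assumptions, with a hypothetical input path:

```python
import torch
from PIL import Image

seg_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Segmentation", trust_repo=True)

# infer_pil expects RGB input; infer_cv2 handles the BGR->RGB conversion for you.
image = Image.open("path/to/your/image.jpg").convert("RGB")
with torch.inference_mode():
    segmentation = seg_predictor.infer_pil(image, image_resolution=768)
```

Each entry point also accepts `device`, `local_dir` (to reuse downloaded checkpoints), and `repo_id` keyword arguments, which `torch.hub.load` forwards to the entry point; see hubconf.py below for the defaults.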

## 📖 Recommended Works
63 changes: 54 additions & 9 deletions hubconf.py
def resize_image(input_image, resolution):
    ...
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img

class BasePredictor:
    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
    def infer_cv2(self, image, image_resolution=768):
        ...
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.infer_pil(Image.fromarray(img))

    def infer_pil(self, image, image_resolution=768):
        raise NotImplementedError("Subclasses must implement this method")

class NormalPredictor(BasePredictor):
    def infer_pil(self, image, image_resolution=768):
        with torch.no_grad():
            pipe_out = self.model(image, ...)
        ...
        pred_normal = cv2.cvtColor(pred_normal, cv2.COLOR_RGB2BGR)
        return pred_normal

class DepthPredictor(BasePredictor):
    def infer_pil(self, image, image_resolution=768):
        with torch.no_grad():
            pipe_out = self.model(image,
                                  processing_res=image_resolution,
                                  match_input_res=True,
                                  batch_size=1,
                                  color_map="Spectral",
                                  show_progress_bar=True,
                                  mode='depth',
                                  )
        pred_depth = np.asarray(pipe_out.pred_colored)
        pred_depth = cv2.cvtColor(pred_depth, cv2.COLOR_RGB2BGR)
        return pred_depth

class SegmentationPredictor(BasePredictor):
    def infer_pil(self, image, image_resolution=768):
        with torch.no_grad():
            pipe_out = self.model(image,
                                  processing_res=image_resolution,
                                  match_input_res=True,
                                  batch_size=1,
                                  color_map="Spectral",
                                  show_progress_bar=True,
                                  mode='seg',
                                  )
        pred_seg = np.asarray(pipe_out.pred_colored)
        pred_seg = cv2.cvtColor(pred_seg, cv2.COLOR_RGB2BGR)
        return pred_seg

def load_model(repo_id, unet_subfolder, device="cuda", local_dir=None):
    unet_ckpt_path = hf_hub_download(repo_id=repo_id, filename=f'{unet_subfolder}/diffusion_pytorch_model.safetensors',
                                     local_dir=local_dir)
    vae_ckpt_path = hf_hub_download(repo_id=repo_id, filename='vae/diffusion_pytorch_model.safetensors',
                                    local_dir=local_dir)

    # Load UNet
    unet = CustomUNet2DConditionModel.from_config(repo_id, subfolder=unet_subfolder)
    load_ckpt_unet = safetensors.torch.load_file(unet_ckpt_path)
    if not any('conv_out' in key for key in load_ckpt_unet.keys()):
        unet.conv_out = None
    ...
        customized_head=None,
    )

    predictor = GenPerceptPipeline(**genpercept_params_ckpt)
    predictor = predictor.to(device)

    return predictor

def GenPercept_Normal(local_dir: Optional[str] = None, device="cuda", repo_id="guangkaixu/GenPercept"):
    model = load_model(repo_id, "unet_normal_v1", device, local_dir)
    return NormalPredictor(model, device)

def GenPercept_Depth(local_dir: Optional[str] = None, device="cuda", repo_id="guangkaixu/GenPercept"):
    model = load_model(repo_id, "unet_depth_v1", device, local_dir)
    return DepthPredictor(model, device)

def GenPercept_Segmentation(local_dir: Optional[str] = None, device="cuda", repo_id="guangkaixu/GenPercept"):
    model = load_model(repo_id, "unet_dis_v1", device, local_dir)
    return SegmentationPredictor(model, device)

def _test_run():
    import argparse

    ...
    parser.add_argument("--pil", action="store_true", help="use PIL instead of OpenCV")
    args = parser.parse_args()

    predictor = torch.hub.load(".", "GenPercept_Segmentation", local_dir=args.local_dir,
                               source="local", trust_repo=True)

    if args.pil:
        ...
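
For a quick local check of these entry points without going through the GitHub hub path, a model can be loaded from a working-copy checkout, mirroring the `torch.hub.load(..., source="local")` call in `_test_run` (a sketch with hypothetical paths):

```python
import cv2
import torch

# "." points at a local checkout of this repository containing hubconf.py.
predictor = torch.hub.load(".", "GenPercept_Depth", source="local", trust_repo=True)

image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)
with torch.inference_mode():
    depth = predictor.infer_cv2(image)
cv2.imwrite("output_depth_map.png", depth)
```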
