-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
038c2dc
commit 6fdf44e
Showing
25 changed files
with
4,450 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,74 @@ | ||
# OMAD: Object Model with Articulated Deformations | ||
|
||
This repository is the official implementation of the paper "OMAD: Object Model with Articulated | ||
Deformations for Pose Estimation and | ||
Retrieval" (BMVC2021). | ||
This repository is the official implementation of the paper | ||
[OMAD: Object Model with Articulated Deformations for Pose Estimation and Retrieval](https://sites.google.com/view/omad-bmvc//). This paper has been accepted to BMVC 2021. | ||
|
||
The code and the dataset will be released in a few days! | ||
![Overview](assets/OMAD.png) | ||
![Visulization](assets/qualitative%20results.png) | ||
|
||
## Datasets | ||
[ArtImage](https://drive.google.com/file/d/1Gp3muPrSY7BPrePhbO1M4U0DQVd_4OqV/view?usp=sharing) dataset contains the synthetic images generated from Unity along with the following annotations: | ||
|
||
- RGB image | ||
- depth map | ||
- part mask | ||
- part pose | ||
|
||
This dataset also contains **URDF** articulated object models of five categories from PartNet-Mobility, | ||
which is re-annotated by us to align the rest state in the same category. | ||
|
||
## Usage | ||
### Installation | ||
|
||
Environments: | ||
|
||
- Python >= 3.7 | ||
- CUDA >= 10.0 | ||
|
||
```bash | ||
git clone https://github.com/xiaoxiaoxh/OMAD.git | ||
cd OMAD | ||
``` | ||
|
||
Install the dependencies listed in ``requirements.txt`` | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
Then, compile CUDA module - index_max: | ||
|
||
```bash | ||
cd models/index_max_ext | ||
python setup.py install | ||
``` | ||
|
||
Finally, download [ArtImage](https://drive.google.com/file/d/1Gp3muPrSY7BPrePhbO1M4U0DQVd_4OqV/view?usp=sharing) Dataset and put it in `OMAD/data` folder. | ||
|
||
Now you are ready to go! | ||
|
||
### Training of OMAD-PriorNet | ||
|
||
```bash | ||
python train_omad_priornet.py --num_kp 24 --work_dir work_dir/omad_priornet_laptop --category 1 --num_parts 2 --use_relative_coverage --symtype shape | ||
``` | ||
|
||
### Testing of OMAD-PriorNet | ||
|
||
```bash | ||
python test_omad_priornet.py --num_kp 24 --checkpoint model_current_laptop.pth --work_dir work_dir/omad_priornet_laptop --bs 16 --workers 0 --use_gpu --symtype shape --out --mode train | ||
|
||
python test_omad_priornet.py --num_kp 24 --checkpoint model_current_laptop.pth --work_dir work_dir/omad_priornet_laptop --bs 16 --workers 0 --use_gpu --symtype shape --out --mode val | ||
``` | ||
|
||
### Training of OMADNet | ||
|
||
```bash | ||
python train_omad_net.py --num_kp 24 --work_dir work_dir/omad-net_laptop --params_dir work_dir/omad_priornet_laptop --num_basis 10 --symtype shape | ||
``` | ||
|
||
### Testing of OMADNet | ||
|
||
```bash | ||
python test_omad_net.py --num_kp 24 --checkpoint model_current_laptop.pth --work_dir work_dir/omad-net_laptop --params_dir work_dir/omad_priornet_laptop --category 1 --num_basis 10 --num_parts 2 --symtype shape --kp_thr 0.1 --reg_weight 0 --out raw_results.pkl --num_process 8 --use_gpu --data_postfix final_test --shuffle | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class PartChamferLoss_Brute(nn.Module): | ||
def __init__(self, num_parts): | ||
super(PartChamferLoss_Brute, self).__init__() | ||
self.dimension = 3 | ||
self.num_parts = num_parts | ||
|
||
def forward(self, part_pc_src_input, part_pc_dst_input): | ||
''' | ||
:param part_pc_src_input: BxKX3x(M/K) Tensor in GPU | ||
:param part_pc_dst_input: BxKX3x(N/k) Tensor in GPU | ||
:return: | ||
''' | ||
chamfer_pure_list = [] | ||
for part_idx in range(self.num_parts): | ||
pc_src_input = part_pc_src_input[:, part_idx, :, :] | ||
pc_dst_input = part_pc_dst_input[:, part_idx, :, :] | ||
B, M = pc_src_input.size()[0], pc_src_input.size()[2] | ||
N = pc_dst_input.size()[2] | ||
|
||
pc_src_input_expanded = pc_src_input.unsqueeze(3).expand(B, 3, M, N) | ||
pc_dst_input_expanded = pc_dst_input.unsqueeze(2).expand(B, 3, M, N) | ||
|
||
# the gradient of norm is set to 0 at zero-input. There is no need to use custom norm anymore. | ||
diff = torch.norm(pc_src_input_expanded - pc_dst_input_expanded, dim=1, keepdim=False) # BxMxN | ||
|
||
# pc_src vs selected pc_dst, M | ||
src_dst_min_dist, _ = torch.min(diff, dim=2, keepdim=False) # BxM | ||
forward_loss = src_dst_min_dist.mean() | ||
|
||
# pc_dst vs selected pc_src, N | ||
dst_src_min_dist, _ = torch.min(diff, dim=1, keepdim=False) # BxN | ||
backward_loss = dst_src_min_dist.mean() | ||
|
||
chamfer_pure = forward_loss + backward_loss | ||
chamfer_pure_list.append(chamfer_pure) | ||
|
||
chamfer_all = torch.mean(torch.stack(chamfer_pure_list)) | ||
return chamfer_all | ||
|
||
|
||
def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'): | ||
assert reduction in ('mean', 'none') | ||
assert beta > 0 | ||
assert pred.size() == target.size() | ||
if target.numel() == 0: | ||
return pred * 0. | ||
diff = torch.abs(pred - target) | ||
loss = torch.where(diff < beta, 0.5 * diff * diff / beta, | ||
diff - 0.5 * beta) | ||
if reduction == 'mean': | ||
return loss.mean() | ||
elif reduction == 'none': | ||
return loss | ||
|
||
|
||
class CoverageLoss(nn.Module): | ||
"""Directly calculate loss based on max and min, not scale""" | ||
def __init__(self, beta=0.1, use_relative_coverage=False): | ||
super(CoverageLoss, self).__init__() | ||
self.beta = beta | ||
self.use_relative_coverage = use_relative_coverage | ||
|
||
def forward(self, kp, pc): | ||
if self.use_relative_coverage: | ||
# volume | ||
val_max_pc, _ = torch.max(pc, 2) | ||
val_min_pc, _ = torch.min(pc, 2) | ||
|
||
val_max_kp, _ = torch.max(kp, 2) | ||
val_min_kp, _ = torch.min(kp, 2) | ||
|
||
scale_pc = val_max_pc - val_min_pc | ||
scale_kp = val_max_kp - val_min_kp | ||
|
||
cov_loss = (smooth_l1_loss(val_max_kp / scale_pc, val_max_pc / scale_pc, beta=self.beta) + | ||
smooth_l1_loss(val_min_kp / scale_pc, val_min_pc / scale_pc, beta=self.beta) + | ||
smooth_l1_loss(torch.log(scale_kp), torch.log(scale_pc), beta=self.beta) | ||
) / 3 | ||
else: | ||
# volume | ||
val_max_pc, _ = torch.max(pc, 2) | ||
val_min_pc, _ = torch.min(pc, 2) | ||
|
||
val_max_kp, _ = torch.max(kp, 2) | ||
val_min_kp, _ = torch.min(kp, 2) | ||
|
||
cov_loss = (smooth_l1_loss(val_max_kp, val_max_pc, beta=self.beta) + | ||
smooth_l1_loss(val_min_kp, val_min_pc, beta=self.beta))/2 | ||
return cov_loss | ||
|
||
|
||
class SurfaceLoss(nn.Module): | ||
def __init__(self): | ||
super(SurfaceLoss, self).__init__() | ||
self.single_side_chamfer = SingleSideChamferLoss_Brute() | ||
|
||
def forward(self, keypoint, pc): | ||
loss = self.single_side_chamfer(keypoint, pc) | ||
|
||
return torch.mean(loss) | ||
|
||
|
||
class SingleSideChamferLoss_Brute(nn.Module): | ||
def __init__(self): | ||
super(SingleSideChamferLoss_Brute, self).__init__() | ||
|
||
def forward(self, pc_src_input, pc_dst_input): | ||
''' | ||
:param pc_src_input: Bx3xM Variable in GPU | ||
:param pc_dst_input: Bx3xN Variable in GPU | ||
:return: | ||
''' | ||
|
||
B, M = pc_src_input.size()[0], pc_src_input.size()[2] | ||
N = pc_dst_input.size()[2] | ||
|
||
pc_src_input_expanded = pc_src_input.unsqueeze(3).expand(B, 3, M, N) | ||
pc_dst_input_expanded = pc_dst_input.unsqueeze(2).expand(B, 3, M, N) | ||
|
||
diff = torch.norm(pc_src_input_expanded - pc_dst_input_expanded, dim=1, keepdim=False) # BxMxN | ||
|
||
# pc_src vs selected pc_dst, M | ||
src_dst_min_dist, _ = torch.min(diff, dim=2, keepdim=False) # BxM | ||
|
||
return src_dst_min_dist | ||
|
||
|
||
class JointLoss(nn.Module): | ||
def __init__(self, loc_weight=1.0, axis_weight=0.5): | ||
super(JointLoss, self).__init__() | ||
self.loc_weight = loc_weight | ||
self.axis_weight = axis_weight | ||
|
||
def forward(self, pred_joint_loc, pred_joint_axis, gt_joint_loc, gt_joint_axis): | ||
loss_axis = 1 - F.cosine_similarity(pred_joint_axis, gt_joint_axis, dim=-1).mean(dim=-1).mean(dim=-1) | ||
|
||
norm_gt_joint_axis = gt_joint_axis / torch.norm(gt_joint_axis, dim=-1, keepdim=True) | ||
p = gt_joint_loc | ||
q = gt_joint_loc + norm_gt_joint_axis | ||
r = pred_joint_loc | ||
x = p - q | ||
loss_loc = torch.norm(((r-q)*x).sum(-1, keepdim=True)/(x*x).sum(-1, keepdim=True)*(p-q)+(q-r), dim=-1).mean(-1).mean(-1) | ||
|
||
loss_joint = loss_loc * self.loc_weight + loss_axis * self.axis_weight | ||
return loss_joint | ||
|
||
|
||
class Loss_OMAD_PriorNet(torch.nn.Module): | ||
def __init__(self, | ||
device, | ||
num_kp, num_parts, num_cate, | ||
loss_chamfer_weight=1.0, | ||
loss_coverage_weight=1.0, | ||
loss_surface_weight=5.0, | ||
loss_joint_weight=1.0, | ||
loss_reg_weight=0.01, | ||
loss_sep_weight=2.0, | ||
sep_factor=8, | ||
beta=0.1, | ||
joint_type='revolute', | ||
use_relative_coverage=False, | ||
): | ||
super(Loss_OMAD_PriorNet, self).__init__() | ||
self.num_kp = num_kp | ||
self.num_cate = num_cate | ||
self.num_parts = num_parts | ||
self.num_joints = num_parts - 1 | ||
assert self.num_kp % self.num_parts == 0 | ||
self.num_kp_per_part = self.num_kp // self.num_parts | ||
self.joint_type = joint_type | ||
self.device = device | ||
assert joint_type in ('revolute', 'prismatic') | ||
|
||
self.loss_chamfer_weight = loss_chamfer_weight | ||
self.loss_coverage_weight = loss_coverage_weight | ||
self.loss_surface_weight = loss_surface_weight | ||
self.loss_joint_weight = loss_joint_weight | ||
self.loss_reg_weight = loss_reg_weight | ||
self.loss_sep_weight = loss_sep_weight | ||
self.sep_factor = sep_factor | ||
self.use_relative_coverage = use_relative_coverage | ||
|
||
self.part_chamfer_criteria = PartChamferLoss_Brute(num_parts=num_parts) | ||
self.surface_criteria = SurfaceLoss() | ||
self.coverage_criteria = CoverageLoss(beta=beta, use_relative_coverage=use_relative_coverage) | ||
self.joint_criteria = JointLoss(loc_weight=0. if self.joint_type == 'prismatic' else 1.0) | ||
|
||
self.zeros = torch.tensor( | ||
[0.0 for _ in range(self.num_kp_per_part - 1) for _ in range(self.num_kp_per_part)]).to( | ||
device) | ||
self.select1 = torch.tensor( | ||
[i for _ in range(self.num_kp_per_part - 1) for i in range(self.num_kp_per_part)]).to( | ||
device) | ||
self.select2 = torch.tensor([(i % self.num_kp_per_part) for j in range(1, self.num_kp_per_part) | ||
for i in range(j, j + self.num_kp_per_part)]).to(device) | ||
|
||
def forward(self, coefs, part_pred_kps, part_nodes, cloud, cloud_cls, | ||
pred_joint_loc, pred_joint_axis, gt_joint_loc, gt_joint_axis, gt_part_scale): | ||
loss_dict = dict() | ||
bs = part_pred_kps.size(0) | ||
|
||
part_pred_kps_trans = part_pred_kps.transpose(2, 3) # (B, K, 3, M/K) | ||
part_nodes_trans = part_nodes.transpose(2, 3) # (B, K, 3, M/K) | ||
nodes_trans = part_nodes.reshape(bs, -1, 3).transpose(1, 2) # (B, 3, M) | ||
cloud_trans = cloud.transpose(1, 2) # (B, 3, N) | ||
chf_loss = self.part_chamfer_criteria(part_pred_kps_trans, part_nodes_trans) | ||
|
||
cov_loss_list = [] | ||
sep_loss_list = [] | ||
for part_idx in range(self.num_parts): | ||
for bs_idx in range(bs): | ||
singe_part_nodes_trans = part_nodes_trans[bs_idx, part_idx, :, :].unsqueeze(0) # (1, 3, m') | ||
cloud_inds = (cloud_cls[bs_idx] == part_idx) # (N) | ||
single_part_cloud_trans = cloud_trans[bs_idx, :, cloud_inds].unsqueeze(0) # (1, 3, n') | ||
part_cov_loss = self.coverage_criteria(singe_part_nodes_trans, single_part_cloud_trans) | ||
cov_loss_list.append(part_cov_loss) | ||
|
||
"""Separation Loss""" | ||
max_rad = torch.norm(gt_part_scale[:, part_idx, :], dim=-1) | ||
max_dist = max_rad / self.sep_factor | ||
max_thr = max_dist.unsqueeze(1) # (bs, 1) | ||
|
||
pred_kp_select1 = torch.index_select(part_nodes[:, part_idx, :, :], 1, self.select1).contiguous() | ||
pred_kp_select2 = torch.index_select(part_nodes[:, part_idx, :, :], 1, self.select2).contiguous() | ||
dist_sep = torch.norm((pred_kp_select1 - pred_kp_select2), dim=2) | ||
loss_sep = torch.max(self.zeros.reshape(1, -1).expand(dist_sep.shape), max_thr - dist_sep) | ||
loss_sep = torch.mean(loss_sep) | ||
|
||
sep_loss_list.append(loss_sep) | ||
|
||
cov_loss = torch.mean(torch.stack(cov_loss_list)) | ||
sep_loss = torch.mean(torch.stack(sep_loss_list)) | ||
surf_loss = self.surface_criteria(nodes_trans, cloud_trans) | ||
joint_loss = self.joint_criteria(pred_joint_loc, pred_joint_axis, gt_joint_loc, gt_joint_axis) | ||
reg_loss = (coefs * coefs).mean() | ||
|
||
"""SUM UP""" | ||
loss = chf_loss * self.loss_chamfer_weight + cov_loss * self.loss_coverage_weight + \ | ||
surf_loss * self.loss_surface_weight + joint_loss * self.loss_joint_weight + \ | ||
reg_loss * self.loss_reg_weight + sep_loss * self.loss_sep_weight | ||
score = loss.item() | ||
loss_dict['loss_chamfer'] = chf_loss.item() | ||
loss_dict['loss_coverage'] = cov_loss.item() | ||
loss_dict['loss_surface'] = surf_loss.item() | ||
loss_dict['loss_joint'] = joint_loss.item() | ||
loss_dict['loss_reg'] = reg_loss.item() | ||
loss_dict['loss_sep'] = sep_loss.item() | ||
|
||
return loss, score, loss_dict |
Oops, something went wrong.