-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_metric.py
78 lines (64 loc) · 3.06 KB
/
eval_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import pytorch3d
from pytorch3d.loss import chamfer_distance, point_mesh_face_distance
from pytorch3d.ops import knn_points
from models.util import *
def gradient_ascent_denoise(noisy_pc, model, patch_size=1000, denoise_knn=4, num_steps=30, ablation1=False, ablation3=False, step_size=0.2, step_decay=0.95):
    """Denoise a point cloud by patch-wise gradient ascent along model-predicted scores.

    The cloud is split into overlapping patches (k-nearest neighbours around
    farthest-point-sampled centers). For every step, each noisy point queries
    its nearest neighbours among the current (partially denoised) points, the
    score network predicts an update for each of those neighbours, and the
    updates are accumulated per point and applied.

    Args:
        noisy_pc: Noisy point cloud tensor, indexed as (N, 3).
        model: Network exposing ``feat_unit`` (patch -> features) and
            ``score_unit`` (relative offsets, features -> scores). It is
            switched to eval mode as a side effect.
        patch_size: Number of points gathered per patch.
        denoise_knn: Number of neighbours each point contributes a score to.
        num_steps: Maximum number of gradient-ascent iterations.
        ablation1: If True, apply a single unscaled update and stop.
        ablation3: If True, force ``denoise_knn`` to 1.
        step_size: Initial step size of the update (previously hard-coded 0.2).
        step_decay: Multiplicative decay applied per step (previously
            hard-coded 0.95).

    Returns:
        The result of ``farthest_point_sampling`` drawing N points from the
        union of all denoised patches (presumably an (N, 3) tensor -- confirm
        against ``models.util.farthest_point_sampling``).
    """
    num_points = noisy_pc.size()[0]
    # Roughly three patches cover every point, so patches overlap.
    num_patches = int(3 * num_points / patch_size)
    patch_centers = farthest_point_sampling(noisy_pc, num_patches)
    # knn_points returns (dists, idx, nn); take the gathered neighbours of the
    # single batch element: (M, P, 3).
    noisy_patches = knn_points(
        patch_centers.unsqueeze(dim=0),
        noisy_pc.unsqueeze(dim=0),
        K=patch_size,
        return_nn=True,
    )[2][0]

    with torch.no_grad():
        model.eval()
        model.feat_unit.eval()
        model.score_unit.eval()

        feat = model.feat_unit(noisy_patches)
        # The features do not change across steps, so flatten them once.
        flat_feat = feat.reshape(-1, feat.size()[-1])
        num_patch_rows = noisy_patches.size()[0]
        # (M, P, 1, 3): broadcasts against the (M, P, knn, 3) neighbour tensor.
        anchors = noisy_patches.unsqueeze(dim=2)

        iter_patches = noisy_patches.clone()
        if ablation3:
            denoise_knn = 1

        for i in range(num_steps):
            knn_result = knn_points(noisy_patches, iter_patches, K=denoise_knn, return_nn=True)
            idx, neighbours = knn_result[1], knn_result[2]  # idx: (M, P, knn)

            # Offsets of the current neighbours relative to each noisy point.
            offsets = (neighbours - anchors).reshape(-1, denoise_knn, 3)
            # (M*P, knn, 3) -> (M, P*knn, 3)
            score = model.score_unit(offsets, flat_feat).reshape(num_patch_rows, -1, 3)

            # Sum, for every point in a patch, the scores predicted for it by
            # all noisy points that selected it as a neighbour.
            gradients = torch.zeros_like(noisy_patches)  # (M, P, 3)
            scatter_idx = idx.reshape(idx.size()[0], -1, 1).expand_as(score)
            gradients.scatter_add_(dim=1, index=scatter_idx, src=score)

            if ablation1:
                # Ablation: one full, unscaled step only.
                iter_patches += gradients
                break
            iter_patches += step_size * (step_decay ** i) * gradients

    return farthest_point_sampling(iter_patches.reshape(-1, 3), num_points)
def compute_chamfer_distance(denoised_pc, clean_pc):
    """Return the Chamfer distance between two clouds after normalising by the clean one.

    Both clouds are shifted by the bounding-box center of ``clean_pc`` and
    divided by the largest distance of any centered clean point from the
    origin, so the clean cloud fits in the unit sphere.

    Neither input tensor is modified (the previous implementation centered
    ``clean_pc`` in place, which corrupted repeated evaluations against the
    same ground truth).

    Args:
        denoised_pc: Predicted point cloud; assumed batched as (B, N, 3), as
            required by ``chamfer_distance`` -- confirm against callers.
        clean_pc: Ground-truth point cloud with the same layout.

    Returns:
        The Chamfer distance as a Python float (mean over batch and points).
    """
    # Bounding-box center of the clean cloud, reduced over the point axis.
    point_max = clean_pc.max(dim=-2, keepdim=True)[0]
    point_min = clean_pc.min(dim=-2, keepdim=True)[0]
    center = (point_max + point_min) / 2

    # Out-of-place subtraction so the caller's tensor stays untouched.
    centered_clean = clean_pc - center
    # Radius of the centered clean cloud.
    scale = centered_clean.pow(2).sum(dim=-1, keepdim=True).sqrt().max(dim=-2, keepdim=True)[0]

    gt = centered_clean / scale
    pred = (denoised_pc - center) / scale
    return chamfer_distance(pred, gt, batch_reduction='mean', point_reduction='mean')[0].item()
def compute_point_to_mesh(denoised_pc, verts, faces):
    """Return the point-to-mesh face distance after normalising by the mesh extent.

    The mesh vertices are shifted by their bounding-box center and divided by
    the largest distance of any centered vertex from the origin; the point
    cloud is transformed with the same center and scale.

    Neither ``verts`` nor ``denoised_pc`` is modified. The previous
    implementation applied in-place ops to ``verts.unsqueeze(0)``, which is a
    view, so the caller's vertex tensor was silently normalised.

    Args:
        denoised_pc: Predicted point cloud, (N, 3) (a leading batch axis of
            size 1 is also tolerated and squeezed away).
        verts: Mesh vertices, assumed (V, 3) -- confirm against callers.
        faces: Mesh face indices, passed through to ``Meshes`` unchanged.

    Returns:
        The ``point_mesh_face_distance`` value as a Python float.
    """
    # Work with a leading batch axis so the reductions yield (1, 1, 3) / (1, 1, 1).
    batched_verts = verts.unsqueeze(0)
    vertex_max = batched_verts.max(dim=-2, keepdim=True)[0]
    vertex_min = batched_verts.min(dim=-2, keepdim=True)[0]
    center = (vertex_max + vertex_min) / 2

    # Out-of-place arithmetic: batched_verts shares storage with the caller's tensor.
    centered_verts = batched_verts - center
    scale = centered_verts.pow(2).sum(dim=-1, keepdim=True).sqrt().max(dim=-2, keepdim=True)[0]
    normalized_verts = torch.squeeze(centered_verts / scale, dim=0)

    # Broadcasting against center/scale adds the batch axis; drop it again.
    normalized_pc = torch.squeeze((denoised_pc - center) / scale, dim=0)

    pc = pytorch3d.structures.Pointclouds([normalized_pc])
    mesh = pytorch3d.structures.Meshes([normalized_verts], [faces])
    return point_mesh_face_distance(mesh, pc).item()