Skip to content

Commit

Permalink
add mesh texture
Browse files Browse the repository at this point in the history
  • Loading branch information
niujinshuchong committed May 23, 2024
1 parent a01227e commit 98d0858
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions extract_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,27 @@
from utils.tetmesh import marching_tetrahedra

@torch.no_grad()
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size):
def evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size, return_color=False):
final_alpha = torch.ones((points.shape[0]), dtype=torch.float32, device="cuda")
if return_color:
final_color = torch.ones((points.shape[0], 3), dtype=torch.float32, device="cuda")

with torch.no_grad():
for _, view in enumerate(tqdm(views, desc="Rendering progress")):
ret = integrate(points, view, gaussians, pipeline, background, kernel_size=kernel_size)
alpha_integrated = ret["alpha_integrated"]
if return_color:
color_integrated = ret["color_integrated"]
final_color = torch.where((alpha_integrated < final_alpha).reshape(-1, 1), color_integrated, final_color)
final_alpha = torch.min(final_alpha, alpha_integrated)

alpha = 1 - final_alpha
if return_color:
return alpha, final_color
return alpha

@torch.no_grad()
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians, pipeline, background, kernel_size):
def marching_tetrahedra_with_binary_search(model_path, name, iteration, views, gaussians, pipeline, background, kernel_size, filter_mesh : bool, texture_mesh : bool):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "fusion")

makedirs(render_path, exist_ok=True)
Expand Down Expand Up @@ -95,13 +103,19 @@ def alpha_to_sdf(alpha):
if step not in [7]:
continue

mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, process=False)
if texture_mesh:
_, color = evaluage_alpha(points, views, gaussians, pipeline, background, kernel_size, return_color=True)
vertex_colors=(color.cpu().numpy() * 255).astype(np.uint8)
else:
vertex_colors=None
mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, vertex_colors=vertex_colors, process=False)

# filter
mask = (distance <= scale).cpu().numpy()
face_mask = mask[faces].all(axis=1)
mesh.update_vertices(mask)
mesh.update_faces(face_mask)
if filter_mesh:
mask = (distance <= scale).cpu().numpy()
face_mask = mask[faces].all(axis=1)
mesh.update_vertices(mask)
mesh.update_faces(face_mask)

mesh.export(os.path.join(render_path, f"mesh_binary_search_{step}.ply"))

Expand All @@ -112,7 +126,7 @@ def alpha_to_sdf(alpha):
# mesh.export(os.path.join(render_path, f"mesh_binary_search_interp.ply"))


def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelineParams):
def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelineParams, filter_mesh : bool, texture_mesh : bool):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
Expand All @@ -124,7 +138,7 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
kernel_size = dataset.kernel_size

cams = scene.getTrainCameras()
marching_tetrahedra_with_binary_search(dataset.model_path, "test", iteration, cams, gaussians, pipeline, background, kernel_size)
marching_tetrahedra_with_binary_search(dataset.model_path, "test", iteration, cams, gaussians, pipeline, background, kernel_size, filter_mesh, texture_mesh)

if __name__ == "__main__":
# Set up command line argument parser
Expand All @@ -133,6 +147,8 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
pipeline = PipelineParams(parser)
parser.add_argument("--iteration", default=30000, type=int)
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--filter_mesh", action="store_true")
parser.add_argument("--texture_mesh", action="store_true")
args = get_combined_args(parser)
print("Rendering " + args.model_path)

Expand All @@ -141,4 +157,4 @@ def extract_mesh(dataset : ModelParams, iteration : int, pipeline : PipelinePara
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))

extract_mesh(model.extract(args), args.iteration, pipeline.extract(args))
extract_mesh(model.extract(args), args.iteration, pipeline.extract(args), args.filter_mesh, args.texture_mesh)

0 comments on commit 98d0858

Please sign in to comment.