From 87b8a9c7651daa9830c6e66c7d012b55730e987c Mon Sep 17 00:00:00 2001 From: Yue Pan Date: Thu, 19 Dec 2024 11:29:28 +0100 Subject: [PATCH] [MINOR] more documents --- README.md | 2 -- utils/loss.py | 6 ++--- utils/mesher.py | 38 ++++++++++++++++++++++++++---- utils/tools.py | 60 ++++++++++++++++++++++++++++++++++++++---------- utils/tracker.py | 30 +++++++++++++++++++++--- 5 files changed, 111 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index a7f1b05..51914d7 100644 --- a/README.md +++ b/README.md @@ -352,10 +352,8 @@ After building the container, configure the storage path in `start_docker.sh` an ``` sudo chmod +x ./start_docker.sh ./start_docker.sh - ``` - ## Visualizer Instructions We provide a PIN-SLAM visualizer based on [lidar-visualizer](https://github.com/PRBonn/lidar-visualizer) to monitor the SLAM process. You can use `-v` flag to turn on it. diff --git a/utils/loss.py b/utils/loss.py index 78faa16..99ff145 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -100,7 +100,7 @@ def smooth_sdf_loss(pred, label, delta=20.0, weight=None, weighted=False): final_loss = ((2.0 / delta) * final_loss * weight).mean() return final_loss - +# deprecated def ray_estimation_loss(x, y, d_meas): # for each ray # x as depth # y as sdf prediction @@ -120,7 +120,7 @@ def ray_estimation_loss(x, y, d_meas): # for each ray return d_error - +# deprecated def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch] # x as depth # y as occ.prob. 
prediction @@ -140,7 +140,7 @@ def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch] return d_error - +# deprecated def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch # x as depth [ray number * sample number] # y as prediction (the alpha in volume rendering) [ray number * sample number] diff --git a/utils/mesher.py b/utils/mesher.py index fabe907..66da0c8 100644 --- a/utils/mesher.py +++ b/utils/mesher.py @@ -213,7 +213,9 @@ def get_query_from_bbx(self, bbx, voxel_size, pad_voxel=0, skip_top_voxel=0): return coord, voxel_num_xyz, voxel_origin def get_query_from_hor_slice(self, bbx, slice_z, voxel_size): - """get grid query points inside a given bounding box (bbx) at slice height (slice_z)""" + """ + get grid query points inside a given bounding box (bbx) at slice height (slice_z) + """ # bbx and voxel_size are all in the world coordinate system min_bound = bbx.get_min_bound() max_bound = bbx.get_max_bound() @@ -246,7 +248,9 @@ def get_query_from_hor_slice(self, bbx, slice_z, voxel_size): return coord, voxel_num_xyz, voxel_origin def get_query_from_ver_slice(self, bbx, slice_x, voxel_size): - """get grid query points inside a given bounding box (bbx) at slice position (slice_x)""" + """ + get grid query points inside a given bounding box (bbx) at slice position (slice_x) + """ # bbx and voxel_size are all in the world coordinate system min_bound = bbx.get_min_bound() max_bound = bbx.get_max_bound() @@ -279,6 +283,9 @@ def get_query_from_ver_slice(self, bbx, slice_x, voxel_size): return coord, voxel_num_xyz, voxel_origin def generate_sdf_map(self, coord, sdf_pred, mc_mask): + """ + Generate the SDF map for saving + """ device = o3d.core.Device("CPU:0") dtype = o3d.core.float32 sdf_map_pc = o3d.t.geometry.PointCloud(device) @@ -305,7 +312,9 @@ def generate_sdf_map(self, coord, sdf_pred, mc_mask): def generate_sdf_map_for_vis( self, coord, sdf_pred, mc_mask, min_sdf=-1.0, max_sdf=1.0, cmap="bwr" ): # 
'jet','bwr','viridis' - + """ + Generate the SDF map for visualization + """ # do the masking or not if mc_mask is not None: coord = coord[mc_mask > 0] @@ -392,6 +401,9 @@ def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin): return verts, faces def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True): + """ + Predict the semantic label of the vertices + """ if len(verts) == 0: return mesh @@ -413,6 +425,9 @@ def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True): return mesh def estimate_vertices_color(self, mesh, verts): + """ + Predict the color of the vertices + """ if len(verts) == 0: return mesh @@ -430,7 +445,9 @@ def estimate_vertices_color(self, mesh, verts): return mesh def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300): - # print("Cluster connected triangles") + """ + Cluster connected triangles and remove the small clusters + """ triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles() triangle_clusters = np.asarray(triangle_clusters) cluster_n_triangles = np.asarray(cluster_n_triangles) @@ -445,6 +462,9 @@ def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300): def generate_bbx_sdf_hor_slice( self, bbx, slice_z, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0 ): + """ + Generate the SDF slice at height (slice_z) + """ # print("Generate the SDF slice at heright %.2f (m)" % (slice_z)) coord, _, _ = self.get_query_from_hor_slice(bbx, slice_z, voxel_size) sdf_pred, _, _, mc_mask = self.query_points( @@ -466,6 +486,9 @@ def generate_bbx_sdf_hor_slice( def generate_bbx_sdf_ver_slice( self, bbx, slice_x, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0 ): + """ + Generate the SDF slice at x position (slice_x) + """ # print("Generate the SDF slice at x position %.2f (m)" % (slice_x)) coord, _, _ = self.get_query_from_ver_slice(bbx, slice_x, voxel_size) sdf_pred, _, _, mc_mask = self.query_points( @@ -499,6 +522,9 @@ def 
recon_aabb_collections_mesh( mesh_min_nn=10, use_torch_mc=False, ): + """ + Reconstruct the mesh from a collection of bounding boxes + """ if not self.silence: print("# Chunk for meshing: ", len(aabbs)) @@ -545,7 +571,9 @@ def recon_aabb_mesh( mesh_min_nn=10, use_torch_mc=False, ): - + """ + Reconstruct the mesh from a given bounding box + """ # reconstruct and save the (semantic) mesh from the feature octree the decoders within a # given bounding box. bbx and voxel_size all with unit m, in world coordinate system coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx( diff --git a/utils/tools.py b/utils/tools.py index 3468b60..b970cab 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -227,8 +227,10 @@ def step_lr_decay( return learning_rate -# calculate the analytical gradient by pytorch auto diff def get_gradient(inputs, outputs): + """ + Calculate the analytical gradient by pytorch auto diff + """ d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) points_grad = grad( outputs=outputs, @@ -387,8 +389,10 @@ def create_axis_aligned_bounding_box(center, size): def apply_quaternion_rotation(quat: torch.tensor, points: torch.tensor) -> torch.tensor: - # apply passive rotation: coordinate system rotation w.r.t. the points - # p' = qpq^-1 + """ + Apply passive rotation: coordinate system rotation w.r.t. the points + p' = qpq^-1 + """ quat_w = quat[..., 0].unsqueeze(-1) quat_xyz = -quat[..., 1:] t = 2 * torch.linalg.cross(quat_xyz, points) @@ -416,6 +420,11 @@ def rotmat_to_quat(rot_matrix: torch.tensor): def quat_to_rotmat(quaternions: torch.tensor): + """ + Convert a batch of quaternions to rotation matrices. 
+ quaternions: N,4 + return N,3,3 + """ # Ensure quaternions are normalized quaternions /= torch.norm(quaternions, dim=1, keepdim=True) @@ -469,6 +478,9 @@ def quat_multiply(q1: torch.tensor, q2: torch.tensor): def torch2o3d(points_torch): + """ + Convert a batch of points from torch to o3d + """ pc_o3d = o3d.geometry.PointCloud() points_np = points_torch.cpu().detach().numpy().astype(np.float64) pc_o3d.points = o3d.utility.Vector3dVector(points_np) @@ -476,12 +488,22 @@ def torch2o3d(points_torch): def o3d2torch(o3d, device="cpu", dtype=torch.float32): + """ + Convert a batch of points from o3d to torch + """ return torch.tensor(np.asarray(o3d.points), dtype=dtype, device=device) def transform_torch(points: torch.tensor, transformation: torch.tensor): - # points [N, 3] - # transformation [4, 4] + """ + Transform a batch of points by a transformation matrix + Args: + points: N,3 torch tensor, the coordinates of the N points to be + transformed + transformation: 4,4 torch tensor, the transformation matrix + Returns: + transformed_points: N,3 torch tensor, the transformed coordinates + """ # Add a homogeneous coordinate to each point in the point cloud points_homo = torch.cat([points, torch.ones(points.shape[0], 1).to(points)], dim=1) @@ -495,9 +517,15 @@ def transform_torch(points: torch.tensor, transformation: torch.tensor): def transform_batch_torch(points: torch.tensor, transformation: torch.tensor): - # points [N, 3] - # transformation [N, 4, 4] - # N,3,3 @ N,3,1 -> N,3,1 + N,3,1 -> N,3,1 -> N,3 + """ + Transform a batch of points by a batch of transformation matrices + Args: + points: N,3 torch tensor, the coordinates of the N points to be + transformed, one per transformation matrix + transformation: N,4,4 torch tensor, the transformation matrices + Returns: + transformed_points: N,3 torch tensor, the transformed coordinates + """ # Extract rotation and translation components rotation = transformation[:, :3, 
:3].to(points) @@ -609,7 +637,9 @@ def split_chunks( aabb: o3d.geometry.AxisAlignedBoundingBox(), chunk_m: float = 100.0 ): - + """ + Split a large point cloud into bounding box chunks + """ if not pc.has_points(): return None @@ -680,7 +710,9 @@ def split_chunks( def deskewing( points: torch.tensor, ts: torch.tensor, pose: torch.tensor, ts_mid_pose=0.5 ): - + """ + Deskew a batch of points at timestamp ts by a relative transformation matrix + """ if ts is None: return points # no deskewing @@ -711,7 +743,9 @@ def deskewing( def tranmat_close_to_identity(mats: np.ndarray, rot_thre: float, tran_thre: float): - + """ + Check if a transformation matrix is close to identity + """ rot_diff = np.abs(mats[:3, :3] - np.identity(3)) rot_close_to_identity = np.all(rot_diff < rot_thre) @@ -781,7 +815,9 @@ def feature_pca_torch(data, principal_components = None, return data_pca, principal_components def plot_timing_detail(time_table: np.ndarray, saving_path: str, with_loop=False): - + """ + Plot the timing detail for processing per frame + """ frame_count = time_table.shape[0] time_table_ms = time_table * 1e3 diff --git a/utils/tracker.py b/utils/tracker.py index 32c6198..95650a0 100644 --- a/utils/tracker.py +++ b/utils/tracker.py @@ -54,7 +54,21 @@ def tracking( loop_reg: bool = False, vis_result: bool = False, ): - + """ + Perform tracking + Args: + source_points: N,3 torch tensor, the coordinates of all N query points + init_pose: 4,4 torch tensor, the initial pose + source_colors: N,3 torch tensor, the colors of all N query points + source_normals: N,3 torch tensor, the normals of all N query points + source_sdf: N torch tensor, the SDF values of all N query points + cur_ts: float, the timestamp of the current frame + loop_reg: bool, whether this is a registration for loop closure + vis_result: bool, whether to visualize the result + Returns: + T: 4,4 torch tensor, the final pose + cov_mat: 6,6 torch tensor, the covariance matrix + """ if init_pose is None: T = 
torch.eye(4, dtype=torch.float64, device=self.device) else: @@ -365,7 +379,9 @@ def registration_step( lm_lambda=0.0, vis_weight_pc=False, ): # if lm_lambda = 0, then it's Gaussian Newton Optimization - + """ + Perform one step of registration + """ T0 = get_time() colors_on = colors is not None and self.config.color_on @@ -757,6 +773,9 @@ def ct_registration_step( # math tools def skew(v): + """ + Compute the skew-symmetric matrix of a 3D vector + """ S = torch.zeros(3, 3, device=v.device, dtype=v.dtype) S[0, 1] = -v[2] S[0, 2] = v[1] @@ -765,7 +784,9 @@ def skew(v): def expmap(axis_angle: torch.Tensor): - + """ + Convert an axis-angle representation to a rotation matrix + """ angle = axis_angle.norm() axis = axis_angle / angle eye = torch.eye(3, device=axis_angle.device, dtype=axis_angle.dtype) @@ -777,6 +798,9 @@ def expmap(axis_angle: torch.Tensor): def rotation_matrix_to_axis_angle(R): + """ + Convert a rotation matrix to an axis-angle representation + """ # epsilon = 1e-8 # A small value to handle numerical precision issues # Ensure the input matrix is a valid rotation matrix assert torch.is_tensor(R) and R.shape == (3, 3), "Invalid rotation matrix"