Skip to content

Commit

Permalink
[MINOR] more documents
Browse files Browse the repository at this point in the history
  • Loading branch information
Yue Pan committed Dec 19, 2024
1 parent 77be121 commit 87b8a9c
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 25 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,8 @@ After building the container, configure the storage path in `start_docker.sh` an
```
sudo chmod +x ./start_docker.sh
./start_docker.sh
```


## Visualizer Instructions

We provide a PIN-SLAM visualizer based on [lidar-visualizer](https://github.com/PRBonn/lidar-visualizer) to monitor the SLAM process. You can use `-v` flag to turn on it.
Expand Down
6 changes: 3 additions & 3 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def smooth_sdf_loss(pred, label, delta=20.0, weight=None, weighted=False):
final_loss = ((2.0 / delta) * final_loss * weight).mean()
return final_loss


# deprecated
def ray_estimation_loss(x, y, d_meas): # for each ray
# x as depth
# y as sdf prediction
Expand All @@ -120,7 +120,7 @@ def ray_estimation_loss(x, y, d_meas): # for each ray

return d_error


# deprecated
def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch]
# x as depth
# y as occ.prob. prediction
Expand All @@ -140,7 +140,7 @@ def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch]

return d_error


# deprecated
def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch
# x as depth [ray number * sample number]
# y as prediction (the alpha in volume rendering) [ray number * sample number]
Expand Down
38 changes: 33 additions & 5 deletions utils/mesher.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def get_query_from_bbx(self, bbx, voxel_size, pad_voxel=0, skip_top_voxel=0):
return coord, voxel_num_xyz, voxel_origin

def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
"""get grid query points inside a given bounding box (bbx) at slice height (slice_z)"""
"""
get grid query points inside a given bounding box (bbx) at slice height (slice_z)
"""
# bbx and voxel_size are all in the world coordinate system
min_bound = bbx.get_min_bound()
max_bound = bbx.get_max_bound()
Expand Down Expand Up @@ -246,7 +248,9 @@ def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
return coord, voxel_num_xyz, voxel_origin

def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
"""get grid query points inside a given bounding box (bbx) at slice position (slice_x)"""
"""
get grid query points inside a given bounding box (bbx) at slice position (slice_x)
"""
# bbx and voxel_size are all in the world coordinate system
min_bound = bbx.get_min_bound()
max_bound = bbx.get_max_bound()
Expand Down Expand Up @@ -279,6 +283,9 @@ def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
return coord, voxel_num_xyz, voxel_origin

def generate_sdf_map(self, coord, sdf_pred, mc_mask):
"""
Generate the SDF map for saving
"""
device = o3d.core.Device("CPU:0")
dtype = o3d.core.float32
sdf_map_pc = o3d.t.geometry.PointCloud(device)
Expand All @@ -305,7 +312,9 @@ def generate_sdf_map(self, coord, sdf_pred, mc_mask):
def generate_sdf_map_for_vis(
self, coord, sdf_pred, mc_mask, min_sdf=-1.0, max_sdf=1.0, cmap="bwr"
): # 'jet','bwr','viridis'

"""
Generate the SDF map for visualization
"""
# do the masking or not
if mc_mask is not None:
coord = coord[mc_mask > 0]
Expand Down Expand Up @@ -392,6 +401,9 @@ def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin):
return verts, faces

def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
"""
Predict the semantic label of the vertices
"""
if len(verts) == 0:
return mesh

Expand All @@ -413,6 +425,9 @@ def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
return mesh

def estimate_vertices_color(self, mesh, verts):
"""
Predict the color of the vertices
"""
if len(verts) == 0:
return mesh

Expand All @@ -430,7 +445,9 @@ def estimate_vertices_color(self, mesh, verts):
return mesh

def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
# print("Cluster connected triangles")
"""
Cluster connected triangles and remove the small clusters
"""
triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
triangle_clusters = np.asarray(triangle_clusters)
cluster_n_triangles = np.asarray(cluster_n_triangles)
Expand All @@ -445,6 +462,9 @@ def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
def generate_bbx_sdf_hor_slice(
self, bbx, slice_z, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0
):
"""
Generate the SDF slice at height (slice_z)
"""
# print("Generate the SDF slice at heright %.2f (m)" % (slice_z))
coord, _, _ = self.get_query_from_hor_slice(bbx, slice_z, voxel_size)
sdf_pred, _, _, mc_mask = self.query_points(
Expand All @@ -466,6 +486,9 @@ def generate_bbx_sdf_hor_slice(
def generate_bbx_sdf_ver_slice(
self, bbx, slice_x, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0
):
"""
Generate the SDF slice at x position (slice_x)
"""
# print("Generate the SDF slice at x position %.2f (m)" % (slice_x))
coord, _, _ = self.get_query_from_ver_slice(bbx, slice_x, voxel_size)
sdf_pred, _, _, mc_mask = self.query_points(
Expand Down Expand Up @@ -499,6 +522,9 @@ def recon_aabb_collections_mesh(
mesh_min_nn=10,
use_torch_mc=False,
):
"""
Reconstruct the mesh from a collection of bounding boxes
"""
if not self.silence:
print("# Chunk for meshing: ", len(aabbs))

Expand Down Expand Up @@ -545,7 +571,9 @@ def recon_aabb_mesh(
mesh_min_nn=10,
use_torch_mc=False,
):

"""
Reconstruct the mesh from a given bounding box
"""
# reconstruct and save the (semantic) mesh from the feature octree the decoders within a
# given bounding box. bbx and voxel_size all with unit m, in world coordinate system
coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx(
Expand Down
60 changes: 48 additions & 12 deletions utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,10 @@ def step_lr_decay(
return learning_rate


# calculate the analytical gradient by pytorch auto diff
def get_gradient(inputs, outputs):
"""
Calculate the analytical gradient by pytorch auto diff
"""
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
points_grad = grad(
outputs=outputs,
Expand Down Expand Up @@ -387,8 +389,10 @@ def create_axis_aligned_bounding_box(center, size):


def apply_quaternion_rotation(quat: torch.tensor, points: torch.tensor) -> torch.tensor:
# apply passive rotation: coordinate system rotation w.r.t. the points
# p' = qpq^-1
"""
Apply passive rotation: coordinate system rotation w.r.t. the points
p' = qpq^-1
"""
quat_w = quat[..., 0].unsqueeze(-1)
quat_xyz = -quat[..., 1:]
t = 2 * torch.linalg.cross(quat_xyz, points)
Expand Down Expand Up @@ -416,6 +420,11 @@ def rotmat_to_quat(rot_matrix: torch.tensor):


def quat_to_rotmat(quaternions: torch.tensor):
"""
Convert a batch of quaternions to rotation matrices.
quaternions: N,4
return N,3,3
"""
# Ensure quaternions are normalized
quaternions /= torch.norm(quaternions, dim=1, keepdim=True)

Expand Down Expand Up @@ -469,19 +478,32 @@ def quat_multiply(q1: torch.tensor, q2: torch.tensor):


def torch2o3d(points_torch):
"""
Convert a batch of points from torch to o3d
"""
pc_o3d = o3d.geometry.PointCloud()
points_np = points_torch.cpu().detach().numpy().astype(np.float64)
pc_o3d.points = o3d.utility.Vector3dVector(points_np)
return pc_o3d


def o3d2torch(o3d, device="cpu", dtype=torch.float32):
"""
Convert a batch of points from o3d to torch
"""
return torch.tensor(np.asarray(o3d.points), dtype=dtype, device=device)


def transform_torch(points: torch.tensor, transformation: torch.tensor):
# points [N, 3]
# transformation [4, 4]
"""
Transform a batch of points by a transformation matrix
Args:
points: N,3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
kaolin coordinate system [-1,1]
transformation: 4,4 torch tensor, the transformation matrix
Returns:
transformed_points: N,3 torch tensor, the transformed coordinates
"""
# Add a homogeneous coordinate to each point in the point cloud
points_homo = torch.cat([points, torch.ones(points.shape[0], 1).to(points)], dim=1)

Expand All @@ -495,9 +517,15 @@ def transform_torch(points: torch.tensor, transformation: torch.tensor):


def transform_batch_torch(points: torch.tensor, transformation: torch.tensor):
# points [N, 3]
# transformation [N, 4, 4]
# N,3,3 @ N,3,1 -> N,3,1 + N,3,1 -> N,3,1 -> N,3
"""
Transform a batch of points by a batch of transformation matrices
Args:
points: N,3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
kaolin coordinate system [-1,1]
transformation: N,4,4 torch tensor, the transformation matrices
Returns:
transformed_points: N,3 torch tensor, the transformed coordinates
"""

# Extract rotation and translation components
rotation = transformation[:, :3, :3].to(points)
Expand Down Expand Up @@ -609,7 +637,9 @@ def split_chunks(
aabb: o3d.geometry.AxisAlignedBoundingBox(),
chunk_m: float = 100.0
):

"""
Split a large point cloud into bounding box chunks
"""
if not pc.has_points():
return None

Expand Down Expand Up @@ -680,7 +710,9 @@ def split_chunks(
def deskewing(
points: torch.tensor, ts: torch.tensor, pose: torch.tensor, ts_mid_pose=0.5
):

"""
Deskew a batch of points at timestamp ts by a relative transformation matrix
"""
if ts is None:
return points # no deskewing

Expand Down Expand Up @@ -711,7 +743,9 @@ def deskewing(


def tranmat_close_to_identity(mats: np.ndarray, rot_thre: float, tran_thre: float):

"""
Check if a batch of transformation matrices is close to identity
"""
rot_diff = np.abs(mats[:3, :3] - np.identity(3))

rot_close_to_identity = np.all(rot_diff < rot_thre)
Expand Down Expand Up @@ -781,7 +815,9 @@ def feature_pca_torch(data, principal_components = None,
return data_pca, principal_components

def plot_timing_detail(time_table: np.ndarray, saving_path: str, with_loop=False):

"""
Plot the timing detail for processing per frame
"""
frame_count = time_table.shape[0]
time_table_ms = time_table * 1e3

Expand Down
30 changes: 27 additions & 3 deletions utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,21 @@ def tracking(
loop_reg: bool = False,
vis_result: bool = False,
):

"""
Perform tracking
Args:
source_points: N,3 torch tensor, the coordinates of all N query points
init_pose: 4,4 torch tensor, the initial pose
source_colors: N,3 torch tensor, the colors of all N query points
source_normals: N,3 torch tensor, the normals of all N query points
source_sdf: N torch tensor, the SDF values of all N query points
cur_ts: float, the timestamp of the current frame
loop_reg: bool, whether this is a registration for loop closure
vis_result: bool, whether to visualize the result
Returns:
T: 4,4 torch tensor, the final pose
cov_mat: 6,6 torch tensor, the covariance matrix
"""
if init_pose is None:
T = torch.eye(4, dtype=torch.float64, device=self.device)
else:
Expand Down Expand Up @@ -365,7 +379,9 @@ def registration_step(
lm_lambda=0.0,
vis_weight_pc=False,
): # if lm_lambda = 0, then it's Gaussian Newton Optimization

"""
Perform one step of registration
"""
T0 = get_time()

colors_on = colors is not None and self.config.color_on
Expand Down Expand Up @@ -757,6 +773,9 @@ def ct_registration_step(

# math tools
def skew(v):
"""
Compute the skew-symmetric matrix of a 3D vector
"""
S = torch.zeros(3, 3, device=v.device, dtype=v.dtype)
S[0, 1] = -v[2]
S[0, 2] = v[1]
Expand All @@ -765,7 +784,9 @@ def skew(v):


def expmap(axis_angle: torch.Tensor):

"""
Convert an axis-angle representation to a rotation matrix
"""
angle = axis_angle.norm()
axis = axis_angle / angle
eye = torch.eye(3, device=axis_angle.device, dtype=axis_angle.dtype)
Expand All @@ -777,6 +798,9 @@ def expmap(axis_angle: torch.Tensor):


def rotation_matrix_to_axis_angle(R):
"""
Convert a rotation matrix to an axis-angle representation
"""
# epsilon = 1e-8 # A small value to handle numerical precision issues
# Ensure the input matrix is a valid rotation matrix
assert torch.is_tensor(R) and R.shape == (3, 3), "Invalid rotation matrix"
Expand Down

0 comments on commit 87b8a9c

Please sign in to comment.