diff --git a/Dockerfile.aws b/Dockerfile.aws
index 85618c89..3e12c588 100644
--- a/Dockerfile.aws
+++ b/Dockerfile.aws
@@ -5,5 +5,6 @@ RUN pip install --upgrade pip && \
     pip install --no-cache-dir -U taichi==1.6.0 matplotlib numpy pytorch_msssim dataclass-wizard pillow pyyaml pandas[parquet]==2.0.0 scipy argparse tensorboard
 COPY . /opt/ml/code
 WORKDIR /opt/ml/code
+RUN pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly
 RUN pip install -r requirements.txt
 RUN pip install -e .
diff --git a/requirements.txt b/requirements.txt
index 1dfad486..920d92a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-taichi>=1.6.0
 matplotlib
 numpy
 pytorch_msssim
@@ -8,4 +7,4 @@ pyyaml
 pandas[parquet]>=2.0.0
 scipy
 argparse
-tensorboard
\ No newline at end of file
+tensorboard
diff --git a/scratch/playground.py b/scratch/playground.py
index 75c6960c..cec3f6dc 100644
--- a/scratch/playground.py
+++ b/scratch/playground.py
@@ -1,4 +1,9 @@
 # %%
+from scipy.stats import multivariate_normal
+import open3d as o3d
+from sympy import symbols, Matrix, diff, exp
+import matplotlib.pyplot as plt
+import pandas as pd
 from Camera import CameraInfo
 from utils import get_ray_origin_and_direction_from_camera, get_ray_origin_and_direction_from_camera_by_gpt
 from utils import torch_single_point_alpha_forward
@@ -88,7 +93,7 @@
     homogeneous_translation_camera[1, 0],
     homogeneous_translation_camera[2, 0]])
 D_translation_camrea_D_q = translation_camera.jacobian(q)
-D_translation_camrea_D_t = translation_camera.jacobian(translation)
+D_translation_camrea_D_t = translation_camera.jacobian(t)
 print(latex(D_translation_camrea_D_q))
 pprint(D_translation_camrea_D_q, use_unicode=True)
 print(latex(D_translation_camrea_D_t))
@@ -303,7 +308,6 @@ def rotation_matrix_from_quaternion(q: ti.math.vec4) -> ti.math.mat3:
 print(sympy.python(J))

 # %%
-import sympy
 xy = sympy.MatrixSymbol('xy', 2, 1)
 mu = sympy.Matrix(["mu_x", "mu_y"])
 cov = sympy.MatrixSymbol('cov', 2, 2)
@@ -325,7 +329,6 @@ def rotation_matrix_from_quaternion(q: ti.math.vec4) -> ti.math.mat3:
 print(J.shape)
 print(sympy.python(J))
 # %%
-import numpy as np
 xy = np.array([1, 2])
 x = xy[0]
 y = xy[1]
@@ -377,7 +380,6 @@ def gradient_cov(x, mean, cov):
 print(gradient_mean(xy, mu, cov))
 print(gradient_cov(xy, mu, cov))
 # %%
-import torch
 xy = torch.tensor([1., 2.])
 mu = torch.tensor([3., 1.], requires_grad=True)
 cov = torch.tensor([[0.8, 0.1], [0.1, 0.8]], requires_grad=True)
@@ -585,14 +587,11 @@ def quaternion_to_rotation_matrix_torch(q):
     T_pointcloud_camera=T_pointcloud_camera)

 # %%
-import pandas as pd
-import numpy as np
 path = "logs/sigmoid_on_image_fix_bug/scene_66000.parquet"
 df = pd.read_parquet(path)
 # %%
 df.head()
 # %%
-import matplotlib.pyplot as plt
 plt.hist(np.exp(df.cov_s0), bins=100)
 # %%
 np.exp(df.cov_s0).argmax()
@@ -608,11 +607,11 @@ def quaternion_to_rotation_matrix_torch(q):
     print(col, df[col].isnull().sum())

 # %%
-df = df.dropna() 
+df = df.dropna()

 # %%
-import numpy as np
-from sympy import symbols, Matrix, diff, exp
+
+
 def compute_derivatives(mu, Sigma, x):
     # define the symbolic variables
     mu1, mu2, x1, x2 = symbols('mu1 mu2 x1 x2')
@@ -637,13 +636,14 @@ def compute_derivatives(mu, Sigma, x):
     """
     # substitute the actual values for the symbolic variables
-    subs = {mu1: mu[0], mu2: mu[1], x1: x[0], x2: x[1], 
+    subs = {mu1: mu[0], mu2: mu[1], x1: x[0], x2: x[1],
             s11: Sigma[0, 0], s12: Sigma[0, 1], s21: Sigma[1, 0], s22: Sigma[1, 1]}
     dp_dmu_val = dp_dmu.subs(subs)
     dp_dSigma_val = dp_dSigma.subs(subs)

     return dp_dmu_val, dp_dSigma_val

+
+
 # test the function
 mu = np.array([1, 2])
 Sigma = np.array([[100., 0], [0, 100]])
@@ -653,7 +653,7 @@ def compute_derivatives(mu, Sigma, x):
 print("dp/dmu:", dp_dmu)
 print("dp/dSigma:", dp_dSigma)
 # %%
-import torch
+

 def compute_derivatives_torch(mu, Sigma, x):
     # convert the inputs to PyTorch tensors and set requires_grad=True to enable autodiff
@@ -671,6 +671,7 @@ def compute_derivatives_torch(mu, Sigma, x):
     return mu_torch.grad, Sigma_torch.grad


+
 def my_compute(mu, Sigma, x):
     gaussian_mean = mu
     xy = x
@@ -680,7 +681,7 @@ def my_compute(mu, Sigma, x):
     det_cov = Sigma[0, 0] * Sigma[1, 1] - Sigma[0, 1] * Sigma[1, 0]
     inv_cov = (1. / det_cov) * \
         np.array([[gaussian_covariance[1, 1], -gaussian_covariance[0, 1]],
-                  [-gaussian_covariance[1, 0], gaussian_covariance[0, 0]]])
+                 [-gaussian_covariance[1, 0], gaussian_covariance[0, 0]]])
     cov_inv_xy_mean = inv_cov @ xy_mean
     xy_mean_T_cov_inv_xy_mean = xy_mean @ cov_inv_xy_mean
     exponent = -0.5 * xy_mean_T_cov_inv_xy_mean
@@ -689,9 +690,10 @@ def my_compute(mu, Sigma, x):
     xy_mean_outer_xy_mean = np.array([[xy_mean[0] * xy_mean[0], xy_mean[0] * xy_mean[1]],
                                       [xy_mean[1] * xy_mean[0], xy_mean[1] * xy_mean[1]]])
     d_p_d_cov = 0.5 * p * (inv_cov @
-                           xy_mean_outer_xy_mean @ inv_cov)
+                          xy_mean_outer_xy_mean @ inv_cov)
     return d_p_d_mean, d_p_d_cov

+
 # test the function
 mu = np.array([1.0, 2.0])
 Sigma = np.array([[1.0, 0.0], [0.0, 1.0]])
@@ -711,7 +713,8 @@ def my_compute(mu, Sigma, x):
     tmp = np.random.rand(2, 2)
     Sigma = tmp @ tmp.T
     dp_dmu, dp_dSigma = compute_derivatives(mu, Sigma, x)
-    dp_dmu, dp_dSigma = np.array(dp_dmu, dtype=np.float32), np.array(dp_dSigma, dtype=np.float32)
+    dp_dmu, dp_dSigma = np.array(dp_dmu, dtype=np.float32), np.array(
+        dp_dSigma, dtype=np.float32)
     dp_dmu = dp_dmu.reshape(-1)
     dp_dSigma = dp_dSigma.reshape(2, 2)
     dp_dmu_torch, dp_dSigma_torch = compute_derivatives_torch(mu, Sigma, x)
@@ -723,17 +726,16 @@ def my_compute(mu, Sigma, x):
     print("dp/dSigma:", dp_dSigma)
     print("dp/dSigma (my):", dp_dSigma_my)
     print("dp/dSigma (torch):", dp_dSigma_torch)
-    
+
     # assert np.allclose(dp_dmu, dp_dmu_torch.detach().numpy(), rtol=1e-3), f"dp_dmu: {dp_dmu}, dp_dmu_torch: {dp_dmu_torch}"
     # assert np.allclose(dp_dSigma, dp_dSigma_torch.detach().numpy(), rtol=1e-3), f"dp_dSigma: {dp_dSigma}, dp_dSigma_torch: {dp_dSigma_torch}"
-    assert np.allclose(dp_dmu, dp_dmu_my, rtol=1e-3), f"dp_dmu: {dp_dmu}, dp_dmu_my: {dp_dmu_my}"
-    assert np.allclose(dp_dSigma, dp_dSigma_my, rtol=1e-3), f"dp_dSigma: {dp_dSigma}, dp_dSigma_my: {dp_dSigma_my}"
-
-
+    assert np.allclose(dp_dmu, dp_dmu_my,
+                       rtol=1e-3), f"dp_dmu: {dp_dmu}, dp_dmu_my: {dp_dmu_my}"
+    assert np.allclose(dp_dSigma, dp_dSigma_my,
+                       rtol=1e-3), f"dp_dSigma: {dp_dSigma}, dp_dSigma_my: {dp_dSigma_my}"
+
+
 # %%
-import pandas as pd
-import numpy as np
-import open3d as o3d
 parquet_path = "/home/kuangyuan/hdd/Development/taichi_3d_gaussian_splatting/logs/tat_truck_experiment_more_val/scene_13750.parquet"
 df = pd.read_parquet(parquet_path)
 # %%
@@ -741,7 +743,7 @@ def my_compute(mu, Sigma, x):
 # %%
 point_cloud = df[["x", "y", "z"]].values
 point_cloud_rgb = df[["r_sh0", "g_sh0", "b_sh0"]].values
-# here rgb are actually sh coefficients (-inf, inf), 
+# here rgb are actually sh coefficients (-inf, inf),
 # need to apply sigmoid to get (0, 1) rgb
 point_cloud_rgb = 1.0 / (1.0 + np.exp(-point_cloud_rgb))
 # %%
@@ -789,6 +791,7 @@ def rotation_matrix_from_quaternion(q: ti.math.vec4) -> ti.math.mat3:
     ])
 """

+
 def rotation_matrix_from_quaternion(q: np.ndarray) -> np.ndarray:
     xx = q[0] * q[0]
     yy = q[1] * q[1]
@@ -805,6 +808,7 @@ def rotation_matrix_from_quaternion(q: np.ndarray) -> np.ndarray:
         [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)]
     ])

+
 S = np.exp(s)
 rotated_S = np.zeros((len(q), 3))

@@ -814,8 +818,7 @@ def rotation_matrix_from_quaternion(q: np.ndarray) -> np.ndarray:
     normal[i] = rotation_matrix_from_quaternion(q[i]) @ base_vector[i]
     normal[i] *= np.linalg.norm(rotated_S[i])

-
-
+
 # %%
 point_cloud_o3d = o3d.geometry.PointCloud()
 point_cloud_o3d.points = o3d.utility.Vector3dVector(point_cloud[mask])
@@ -824,8 +827,9 @@ def rotation_matrix_from_quaternion(q: np.ndarray) -> np.ndarray:
 o3d.visualization.draw_geometries([point_cloud_o3d])

 # %%
-import taichi as ti
 ti.init(arch=ti.cpu)
+
+
 @ti.kernel
 def test():
     Cov = ti.Matrix([
@@ -835,11 +839,10 @@ def test():
     eig, V = ti.sym_eig(Cov)
     print(eig)
     print(V)
+
+
 test()
 # %%
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy.stats import multivariate_normal

 # Define the mean and covariance matrix
 mean = np.array([0, 0])
@@ -875,8 +878,10 @@ def test():
 plt.imshow(mask, extent=(-50, 50, -50, 50), origin='lower')

 # plt eigenvectors
-plt.quiver(mean[0], mean[1], eigen_vectors[0, 0], eigen_vectors[1, 0], color='r', scale=10 / np.sqrt(eigen_values[0]))
-plt.quiver(mean[0], mean[1], eigen_vectors[0, 1], eigen_vectors[1, 1], color='r', scale=10 / np.sqrt(eigen_values[1]))
+plt.quiver(mean[0], mean[1], eigen_vectors[0, 0], eigen_vectors[1,
+           0], color='r', scale=10 / np.sqrt(eigen_values[0]))
+plt.quiver(mean[0], mean[1], eigen_vectors[0, 1], eigen_vectors[1,
+           1], color='r', scale=10 / np.sqrt(eigen_values[1]))
 plt.colorbar()
 plt.show()
@@ -885,13 +890,14 @@ def test():
 print(np.sqrt(eigen_values[1]) * 4)

 # %%
-import pandas as pd
 "/home/kuangyuan/hdd/Development/other/taichi_3d_gaussian_splatting/logs/tat_truck_every_8_experiment/camera_poses_6000.parquet"
-df = pd.read_parquet("/home/kuangyuan/hdd/Development/other/taichi_3d_gaussian_splatting/logs/tat_truck_every_8_with_pose_noise_optimization/camera_poses_10000.parquet")
+df = pd.read_parquet(
+    "/home/kuangyuan/hdd/Development/other/taichi_3d_gaussian_splatting/logs/tat_truck_every_8_with_pose_noise_optimization/camera_poses_10000.parquet")
 # %%
 df.head()
 # %%
-df1 = pd.read_parquet("/home/kuangyuan/hdd/Development/other/taichi_3d_gaussian_splatting/logs/tat_truck_every_8_baseline/camera_poses_30000.parquet")
+df1 = pd.read_parquet(
+    "/home/kuangyuan/hdd/Development/other/taichi_3d_gaussian_splatting/logs/tat_truck_every_8_baseline/camera_poses_30000.parquet")
 # %%
 df1.head()
diff --git a/taichi_3d_gaussian_splatting/GaussianPoint3D.py b/taichi_3d_gaussian_splatting/GaussianPoint3D.py
index 20eae4c0..521a58ad 100644
--- a/taichi_3d_gaussian_splatting/GaussianPoint3D.py
+++ b/taichi_3d_gaussian_splatting/GaussianPoint3D.py
@@ -211,7 +211,14 @@ def project_to_camera_position_by_q_t_jacobian(
     d_uv_d_q = d_uv_d_translation_camera @ d_translation_camera_d_q
     # d_uv_d_t = d_uv_d_translation_camera @ d_translation_camera_d_t
     d_uv_d_t = d_uv_d_translation_camera
-    return d_uv_d_translation, d_uv_d_q, d_uv_d_t
+    d_depth_dq = ti.math.vec4([
+        d_translation_camera_d_q[2, 0],
+        d_translation_camera_d_q[2, 1],
+        d_translation_camera_d_q[2, 2],
+        d_translation_camera_d_q[2, 3]
+    ])
+    d_depth_dt = ti.math.vec3([0, 0, 1])
+    return d_uv_d_translation, d_uv_d_q, d_uv_d_t, d_depth_dq, d_depth_dt

 @ti.func
 def project_to_camera_covariance(
diff --git a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
index cf4fdd3d..d1332cf6 100644
--- a/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
+++ b/taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
@@ -515,6 +515,10 @@ def gaussian_point_rasterisation_backward(
     in_camera_grad_uv_cov_buffer: ti.types.ndarray(ti.f32, ndim=2),
     in_camera_grad_color_buffer: ti.types.ndarray(ti.f32, ndim=2),  # (M, 3)
+    enable_grad_depth_cov: ti.template(),
+    in_camera_grad_depth_buffer: ti.types.ndarray(ti.f32, ndim=1),  # (M,)
+    depth_cov_loss_factor: ti.f32,
+    rasterized_depth: ti.types.ndarray(ti.f32, ndim=2),  # (H, W)
     point_uv: ti.types.ndarray(ti.f32, ndim=2),  # (M, 2)
     point_in_camera: ti.types.ndarray(ti.f32, ndim=2),  # (M, 3)
@@ -552,6 +556,8 @@ def gaussian_point_rasterisation_backward(
         (3, ti.static(TILE_HEIGHT * TILE_WIDTH)), dtype=ti.f32)  # 3KB shared memory
     tile_point_alpha = ti.simt.block.SharedArray(
         (ti.static(TILE_HEIGHT * TILE_WIDTH),), dtype=ti.f32)  # 1KB shared memory
+    tile_point_depth = ti.simt.block.SharedArray(
+        (ti.static(TILE_HEIGHT * TILE_WIDTH if enable_grad_depth_cov else 0),), dtype=ti.f32)  # 1KB shared memory
     pixel_offset_in_tile = pixel_offset - \
         tile_id * ti.static(TILE_HEIGHT * TILE_WIDTH)
@@ -570,6 +576,10 @@ def gaussian_point_rasterisation_backward(
     pixel_rgb_grad = ti.math.vec3(
         rasterized_image_grad[pixel_v, pixel_u, 0], rasterized_image_grad[pixel_v, pixel_u, 1], rasterized_image_grad[pixel_v, pixel_u, 2])
+
+    pixel_depth: ti.f32 = 0.0
+    if enable_grad_depth_cov:
+        pixel_depth = rasterized_depth[pixel_v, pixel_u]
     total_magnitude_grad_viewspace_on_image = ti.math.vec2(0.0, 0.0)

     # for inverse_point_offset in range(effective_point_count):
@@ -602,6 +612,8 @@ def gaussian_point_rasterisation_backward(
                     thread_id] = point_color[to_load_point_offset, i]
             tile_point_alpha[thread_id] = point_alpha_after_activation[to_load_point_offset]
+            if enable_grad_depth_cov:
+                tile_point_depth[thread_id] = point_in_camera[to_load_point_offset, 2]
         ti.simt.block.sync()

         max_inverse_point_offset_offset = ti.min(
@@ -628,6 +640,10 @@ def gaussian_point_rasterisation_backward(
                 point_alpha_after_activation_value = tile_point_alpha[
                     idx_point_offset_with_sort_key_in_block]

+                point_depth = 0.0
+                if enable_grad_depth_cov:
+                    point_depth = tile_point_depth[idx_point_offset_with_sort_key_in_block]
+
                 # d_p_d_mean is (2,), d_p_d_cov is (2, 2), needs to be flattened to (4,)
                 gaussian_alpha, d_p_d_mean, d_p_d_cov = grad_point_probability_density_from_conic(
                     xy=ti.math.vec2([pixel_u + 0.5, pixel_v + 0.5]),
@@ -653,11 +669,18 @@ def gaussian_point_rasterisation_backward(
                 # f"({pixel_v}, {pixel_u}, {point_offset}, {point_offset - start_offset}), accumulated_alpha: {accumulated_alpha}")

                 d_pixel_rgb_d_color = alpha * T_i
+                point_grad_color = d_pixel_rgb_d_color * pixel_rgb_grad

                 # \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} w_i
                 alpha_grad_from_rgb = (color * T_i - w_i / (1. - alpha)) \
                     * pixel_rgb_grad
+
+                d_var_depth_d_depth = 0.0
+                if enable_grad_depth_cov:
+                    d_var_depth_d_depth = alpha * T_i * \
+                        (point_depth - pixel_depth) * depth_cov_loss_factor
+
                 # w_{i-1} = w_i + c_i a_i T(i)
                 w_i += color * alpha * T_i
                 alpha_grad: ti.f32 = alpha_grad_from_rgb.sum()
@@ -700,6 +723,9 @@ def gaussian_point_rasterisation_backward(
                         magnitude_grad_viewspace[point_id], magnitude_point_grad_viewspace)
                     ti.atomic_add(
                         in_camera_num_affected_pixels[point_offset], 1)
+                    if enable_grad_depth_cov:
+                        ti.atomic_add(
+                            in_camera_grad_depth_buffer[point_offset], d_var_depth_d_depth)
             # end of the TILE_WIDTH * TILE_HEIGHT block loop
             ti.simt.block.sync()
     # end of the backward traversal loop, from last point to first point
@@ -737,6 +763,10 @@ def gaussian_point_rasterisation_backward(
             in_camera_grad_color_buffer[idx, 1],
             in_camera_grad_color_buffer[idx, 2],
         )
+        point_grad_depth = 0.0
+        if enable_grad_depth_cov:
+            point_grad_depth = in_camera_grad_depth_buffer[idx]
+
         point_q_camera_pointcloud = ti.Vector(
             [q_camera_pointcloud[point_object_id[point_id], idx] for idx in ti.static(range(4))])
         point_t_camera_pointcloud = ti.Vector(
@@ -749,12 +779,12 @@ def gaussian_point_rasterisation_backward(
         )
         translation_camera = ti.Vector([
             point_in_camera[idx, j] for j in ti.static(range(3))])
-        d_uv_d_translation, d_uv_d_q, d_uv_d_t = gaussian_point_3d.project_to_camera_position_by_q_t_jacobian(
+        d_uv_d_translation, d_uv_d_q, d_uv_d_t, d_depth_dq, d_depth_dt = gaussian_point_3d.project_to_camera_position_by_q_t_jacobian(
             q_camera_world=point_q_camera_pointcloud,
             t_camera_world=point_t_camera_pointcloud,
             projective_transform=camera_intrinsics_mat,
         )  # (2, 3), (2, 4), (2, 3)
-        
+
         d_Sigma_prime_d_q, d_Sigma_prime_d_s = gaussian_point_3d.project_to_camera_covariance_jacobian(
             T_camera_world=T_camera_pointcloud_mat,
             projective_transform=camera_intrinsics_mat,
@@ -777,6 +807,10 @@ def gaussian_point_rasterisation_backward(
         point_camera_pose_q_grad = point_grad_uv @ d_uv_d_q
         point_camera_pose_t_grad = point_grad_uv @ d_uv_d_t

+        if enable_grad_depth_cov:
+            point_camera_pose_q_grad += point_grad_depth * d_depth_dq
+            point_camera_pose_t_grad += point_grad_depth * d_depth_dt
+
         for i in ti.static(range(3)):
             grad_pointcloud[point_id, i] = translation_grad[i]
         for i in ti.static(range(4)):
@@ -856,6 +890,8 @@ class GaussianPointCloudRasterisationInput:
         # Kx3, x to the right, y down, z forward, K is the number of objects
         t_camera_pointcloud: torch.Tensor
         color_max_sh_band: int = 2
+        enable_depth_cov_loss: bool = True
+        depth_cov_loss_factor: float = 0.2

     @dataclass
     class BackwardValidPointHookInput:
@@ -894,6 +930,8 @@ def forward(ctx,
                         t_pointcloud_camera,
                         camera_info,
                         color_max_sh_band,
+                        enable_depth_cov_loss,
+                        depth_cov_loss_factor,
                         ):
                 point_in_camera_mask = torch.zeros(
                     size=(pointcloud.shape[0],), dtype=torch.int8, device=pointcloud.device)
@@ -1070,9 +1108,12 @@ def forward(ctx,
                     point_uv_conic,
                     point_alpha_after_activation,
                     point_color,
+                    rasterized_depth,
                 )
                 ctx.camera_info = camera_info
                 ctx.color_max_sh_band = color_max_sh_band
+                ctx.enable_depth_cov_loss = enable_depth_cov_loss
+                ctx.depth_cov_loss_factor = depth_cov_loss_factor
                 # rasterized_image.requires_grad_(True)
                 return rasterized_image, rasterized_depth, pixel_valid_point_count
@@ -1097,9 +1138,12 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
                     point_in_camera, \
                     point_uv_conic, \
                     point_alpha_after_activation, \
-                    point_color = ctx.saved_tensors
+                    point_color, \
+                    rasterized_depth = ctx.saved_tensors
                 camera_info = ctx.camera_info
                 color_max_sh_band = ctx.color_max_sh_band
+                enable_depth_cov_loss = ctx.enable_depth_cov_loss
+                depth_cov_loss_factor = ctx.depth_cov_loss_factor
                 grad_rasterized_image = grad_rasterized_image.contiguous()
                 grad_pointcloud = torch.zeros_like(pointcloud)
                 grad_pointcloud_features = torch.zeros_like(
@@ -1123,11 +1167,17 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
                         q_camera_pointcloud)
                     grad_t_camera_pointcloud = torch.zeros_like(
                         t_camera_pointcloud)
-                else: # torch will report error if we provide None as input, provide empty tensor instead
+                else:  # torch will report error if we provide None as input, provide empty tensor instead
                     grad_q_camera_pointcloud = torch.zeros(
                         size=(0, 4), dtype=torch.float32, device=pointcloud.device)
                     grad_t_camera_pointcloud = torch.zeros(
                         size=(0, 3), dtype=torch.float32, device=pointcloud.device)
+                if enable_depth_cov_loss:
+                    in_camera_grad_depth_buffer = torch.zeros(
+                        size=(point_id_in_camera_list.shape[0],), dtype=torch.float32, device=pointcloud.device)
+                else:
+                    in_camera_grad_depth_buffer = torch.zeros(
+                        size=(0,), dtype=torch.float32, device=pointcloud.device)

                 gaussian_point_rasterisation_backward(
                     camera_height=camera_info.camera_height,
@@ -1154,6 +1204,10 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
                     grad_uv=grad_viewspace.contiguous(),
                     in_camera_grad_uv_cov_buffer=in_camera_grad_uv_cov_buffer.contiguous(),
                     in_camera_grad_color_buffer=in_camera_grad_color_buffer.contiguous(),
+                    enable_grad_depth_cov=enable_depth_cov_loss,
+                    in_camera_grad_depth_buffer=in_camera_grad_depth_buffer.contiguous(),
+                    depth_cov_loss_factor=depth_cov_loss_factor,
+                    rasterized_depth=rasterized_depth.contiguous(),
                     point_uv=point_uv.contiguous(),
                     point_in_camera=point_in_camera.contiguous(),
                     point_uv_conic=point_uv_conic.contiguous(),
@@ -1232,7 +1286,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
                     grad_q_camera_pointcloud, \
                     grad_t_camera_pointcloud, \
                     None, \
-                    None, None
+                    None, None, None, None

         self._module_function = _module_function
@@ -1276,4 +1330,6 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput):
             t_pointcloud_camera,
             camera_info,
             color_max_sh_band,
+            input_data.enable_depth_cov_loss,
+            input_data.depth_cov_loss_factor,
         )
diff --git a/taichi_3d_gaussian_splatting/GaussianPointTrainer.py b/taichi_3d_gaussian_splatting/GaussianPointTrainer.py
index d9e496b2..d7e316ab 100644
--- a/taichi_3d_gaussian_splatting/GaussianPointTrainer.py
+++ b/taichi_3d_gaussian_splatting/GaussianPointTrainer.py
@@ -41,6 +41,11 @@ class TrainConfig(YAMLWizard):
     val_interval: int = 1000
     feature_learning_rate: float = 1e-3
     iteration_start_camera_pose_optimization: int = 30000
+    iteration_start_depth_cov_loss: int = 2000
+    enable_depth_cov_loss: bool = True
+    initial_depth_cov_loss_factor: float = 0.4
+    depth_cov_loss_factor_increase_interval: int = 1000
+    depth_cov_loss_factor_increase_rate: float = 1.05511
     camera_pose_optimization_batch_size: int = 500
     position_learning_rate: float = 1e-5
     position_learning_rate_decay_rate: float = 0.97
@@ -175,15 +180,20 @@ def train(self):
                 camera_info.camera_intrinsics = camera_info.camera_intrinsics.cuda()
                 camera_info.camera_width = int(camera_info.camera_width)
                 camera_info.camera_height = int(camera_info.camera_height)
+                depth_cov_loss_factor = self.config.initial_depth_cov_loss_factor * \
+                    self.config.depth_cov_loss_factor_increase_rate ** (
+                        iteration // self.config.depth_cov_loss_factor_increase_interval)
                 gaussian_point_cloud_rasterisation_input = GaussianPointCloudRasterisation.GaussianPointCloudRasterisationInput(
                     point_cloud=self.scene.point_cloud,
                     point_cloud_features=self.scene.point_cloud_features,
                     point_object_id=self.scene.point_object_id,
                     point_invalid_mask=self.scene.point_invalid_mask,
                     camera_info=camera_info,
-                    q_camera_pointcloud=trained_q_camera_pointcloud, 
+                    q_camera_pointcloud=trained_q_camera_pointcloud,
                     t_camera_pointcloud=trained_t_camera_pointcloud,
                     color_max_sh_band=iteration // self.config.increase_color_max_sh_band_interval,
+                    enable_depth_cov_loss=self.config.enable_depth_cov_loss and iteration > self.config.iteration_start_depth_cov_loss,
+                    depth_cov_loss_factor=depth_cov_loss_factor
                 )
                 image_pred, image_depth, pixel_valid_point_count = self.rasterisation(
                     gaussian_point_cloud_rasterisation_input)
@@ -345,7 +355,7 @@ def _plot_grad_histogram(grad_input: GaussianPointCloudRasterisation.BackwardVal
                               num_overlap_tiles, iteration)
         writer.add_histogram("value/num_affected_pixels",
                              num_affected_pixels, iteration)
-    
+
     @staticmethod
     def _plot_value_histogram(scene: GaussianPointCloudScene, writer, iteration):
         with torch.no_grad():
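
Reviewer note on the new depth-covariance gradient (not part of the patch): in gaussian_point_rasterisation_backward, each contributing point accumulates d_var_depth_d_depth = alpha * T_i * (point_depth - pixel_depth) * depth_cov_loss_factor into in_camera_grad_depth_buffer, which is then chained through the new d_depth_dq / d_depth_dt Jacobians into the camera-pose gradients. This is consistent with a per-pixel loss equal to the alpha-blended depth variance Var = sum_i w_i (z_i - D)^2, with weights w_i = a_i * T_i, per-point depth z_i = point_in_camera[i, 2], and blended depth D = rasterized_depth: when the weights sum to ~1, dVar/dz_i = 2 w_i (z_i - D), with the constant folded into depth_cov_loss_factor. That loss form is an assumption read off the kernel rather than stated in the patch; the standalone PyTorch sketch below checks the closed form against autograd for a single ray.

import torch

torch.manual_seed(0)
a = torch.rand(5) * 0.5                      # alphas a_i of the points hit along one ray
z = (torch.rand(5) * 10.0).requires_grad_()  # camera-space depths z_i
# front-to-back transmittance T_i = prod_{j<i} (1 - a_j)
T = torch.cumprod(torch.cat([torch.ones(1), 1.0 - a[:-1]]), dim=0)
w = a * T            # blending weights w_i = a_i * T_i (the kernel uses these unnormalized)
w = w / w.sum()      # normalized here so the closed form below is exact
D = (w * z).sum()    # blended depth, i.e. the rasterized pixel depth
var = (w * (z - D) ** 2).sum()  # blended per-pixel depth variance
var.backward()

analytic = 2.0 * w * (z - D)    # closed form; the kernel folds the 2 into depth_cov_loss_factor
print(torch.allclose(z.grad, analytic.detach(), atol=1e-5))  # True

If the weights sum to well under 1 (significant residual transmittance), the exact gradient picks up an extra -2 * w_i * sum_j w_j (z_j - D) term that the kernel's per-point accumulation ignores; for mostly opaque pixels that term is near zero, so this looks like a reasonable approximation.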