diff --git a/src/gpytoolbox/reach_for_the_spheres.py b/src/gpytoolbox/reach_for_the_spheres.py index e9e3fd08..d1ff5303 100644 --- a/src/gpytoolbox/reach_for_the_spheres.py +++ b/src/gpytoolbox/reach_for_the_spheres.py @@ -664,7 +664,19 @@ def reach_for_the_spheres_iteration(state, state.V = sp.sparse.linalg.spsolve(Q,b) # catching flow singularities so we fail gracefully - if np.any((np.isnan(state.V))) or np.any(doublearea(state.V, state.F)==0) or len(non_manifold_edges(state.F))>0 : + + there_are_non_manifold_edges = False + if dim==3: + there_are_non_manifold_edges = len(non_manifold_edges(state.F))>0 + elif dim==2: + he_nm = np.sort(state.F, axis=1) + # print(he) + he_u_nm = np.unique(he_nm, axis=0, return_counts=True) + # print(he) + ne_nm = he_u_nm[0][he_u_nm[1]>2] + there_are_non_manifold_edges = len(ne_nm)>0 + + if np.any((np.isnan(state.V))) or np.any(doublearea(state.V, state.F)==0) or there_are_non_manifold_edges : if verbose: print("we found a flow singularity. Returning the last converged solution.") diff --git a/test/test_reach_for_the_spheres.py b/test/test_reach_for_the_spheres.py index 50d0dd38..6806129f 100644 --- a/test/test_reach_for_the_spheres.py +++ b/test/test_reach_for_the_spheres.py @@ -1,7 +1,7 @@ from .context import gpytoolbox as gpy from .context import numpy as np from .context import unittest - +import matplotlib.pyplot as plt class TestReachForTheSpheres(unittest.TestCase): def test_beat_marching_cubes_low_res(self): @@ -26,6 +26,40 @@ def test_beat_marching_cubes_low_res(self): # print(f"reach_for_the_spheres h: {h_ours}, MC h: {h_mc} for {mesh} with n={n}") self.assertTrue(h_ours < h_mc) + def test_beat_marching_cubes_2d(self): + png_paths = ["test/unit_tests_data/switzerland.png"] + ns = [10, 20, 30, 50] + for png_path in png_paths: + vv = gpy.png2poly(png_path)[0] + vv = gpy.normalize_points(vv) + vv = 1.0*vv + ec = gpy.edge_indices(vv.shape[0], closed=True) + for n in ns: + gx, gy = np.meshgrid(np.linspace(-1.0, 1.0, n+1), np.linspace(-1.0, 1.0, n+1)) + GV = np.vstack((gx.flatten(), gy.flatten())).T + S = gpy.signed_distance(GV, vv, ec)[0] + plt.scatter(GV[:,0], GV[:,1], c=S) + plt.plot(vv[:,0], vv[:,1], 'r-') + plt.colorbar() + plt.show() + vv_mc, ee_mc = gpy.marching_squares(S, GV, n+1, n+1) + + # plot vv, ee edge by edge + # for i in range(ee_mc.shape[0]): + # plt.plot([vv_mc[ee_mc[i,0],0], vv_mc[ee_mc[i,1],0]], [vv_mc[ee_mc[i,0],1], vv_mc[ee_mc[i,1],1]], 'k-') + + # now run rfts + vv_rfts, ee_rfts = gpy.regular_circle_polyline(10) + sdf = lambda x: gpy.signed_distance(x, vv, ec)[0] + vv_rfts, ee_rfts = gpy.reach_for_the_spheres(GV, sdf, V=vv_rfts, F=ee_rfts, S=S) + # plot vv, ee edge by edge + # for i in range(ee_rfts.shape[0]): + # plt.plot([vv_rfts[ee_rfts[i,0],0], vv_rfts[ee_rfts[i,1],0]], [vv_rfts[ee_rfts[i,0],1], vv_rfts[ee_rfts[i,1],1]], 'g-') + + # plt.axis('equal') + # plt.show() + + def test_noop(self): meshes = ["bunny_oded.obj", "spot.obj", "teddy.obj"] for mesh in meshes: diff --git a/test/unit_tests_data/switzerland.png b/test/unit_tests_data/switzerland.png new file mode 100644 index 00000000..65438c4f Binary files /dev/null and b/test/unit_tests_data/switzerland.png differ