diff --git a/src/gpytoolbox/particle_swarm.py b/src/gpytoolbox/particle_swarm.py index d4e46f96..a767c2cc 100644 --- a/src/gpytoolbox/particle_swarm.py +++ b/src/gpytoolbox/particle_swarm.py @@ -39,7 +39,7 @@ def particle_swarm(fun,lb,ub,n_particles=100,max_iter=100,momentum=0.9,phi=0.1,v f = fun(xi) # print(xi) best_xi[i,:] = xi.copy() - best_fi[i] = f.copy() + best_fi[i] = np.squeeze(f.copy()) # if verbose: # print("Particle %d: f = %f" % (i,f)) if f < current_best_f: @@ -77,7 +77,7 @@ def particle_swarm(fun,lb,ub,n_particles=100,max_iter=100,momentum=0.9,phi=0.1,v # Update best position if f < best_fi[i]: best_xi[i,:] = x[i,:].copy() - best_fi[i] = f.copy() + best_fi[i] = np.squeeze(f.copy()) if f < current_best_f: current_best_x = x[i,:].copy() current_best_f = f.copy() diff --git a/src/gpytoolbox/ray_mesh_intersect.py b/src/gpytoolbox/ray_mesh_intersect.py index 2da79fda..0c9e2ad4 100644 --- a/src/gpytoolbox/ray_mesh_intersect.py +++ b/src/gpytoolbox/ray_mesh_intersect.py @@ -114,7 +114,7 @@ def ray_mesh_intersect(cam_pos,cam_dir,V,F,use_embree=True,C=None,W=None,CH=None add_to_queue_fun = trav.add_to_queue _ = traverse_aabbtree(C,W,CH,tri_ind,None,traverse_fun,add_to_queue=add_to_queue_fun) ts[i] = trav.t - ids[i] = trav.id + ids[i] = np.squeeze(trav.id) lambdas[i,:] = trav.lmbd # print("computed distances") return ts, ids, lambdas diff --git a/src/gpytoolbox/squared_distance.py b/src/gpytoolbox/squared_distance.py index 13658bcd..f9f6ef21 100644 --- a/src/gpytoolbox/squared_distance.py +++ b/src/gpytoolbox/squared_distance.py @@ -176,8 +176,8 @@ def squared_distance(P,V,F=None,use_cpp=False,use_aabb=False,C=None,W=None,CH=No # print(tri_ind) _ = traverse_aabbtree(C,W,CH,tri_ind,split_dir,traverse_fun,add_to_queue=add_to_queue_fun) # print(t.num_traversal) - indices[j] = t.current_best_element - squared_distances[j] = t.current_best_guess + indices[j] = np.squeeze(t.current_best_element) + squared_distances[j] = np.squeeze(t.current_best_guess) lmbs[j,:] = t.current_best_lmb else: # Loop over every element diff --git a/test/test_particle_swarm.py b/test/test_particle_swarm.py index d5b76d37..4ca32a3b 100644 --- a/test/test_particle_swarm.py +++ b/test/test_particle_swarm.py @@ -68,8 +68,8 @@ def dropwave_function(x): return -(1 + np.cos(12*np.sqrt(np.sum(x**2))))/(0.5*np.sum(x**2) + 2) lb = np.array([-5,-5]) ub = np.array([5,5]) - x,f = gpy.particle_swarm(dropwave_function,lb,ub,verbose=True,max_iter=100,topology='full') - xring,fring = gpy.particle_swarm(dropwave_function,lb,ub,verbose=True,max_iter=100,topology='ring') + x,f = gpy.particle_swarm(dropwave_function,lb,ub,verbose=False,max_iter=100,topology='full') + xring,fring = gpy.particle_swarm(dropwave_function,lb,ub,verbose=False,max_iter=100,topology='ring') # print(x) self.assertTrue(np.isclose(x,random_center,atol=1e-3).all())