Skip to content

Commit

Permalink
improved eval in _FEMResults for serialized calls
Browse files Browse the repository at this point in the history
  • Loading branch information
louisreg committed Sep 25, 2024
1 parent 8e0767a commit 6ffd4dc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 56 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ All notable changes to NRV are sumed up in this file.

## [1.1.3] - XXXX-XX-XX
### Added

- improved `eval` in `_FEMResults` for serialized calls (added state variables)

### Fixed



### Removed
- `myelinated_results.find_central_node_index`-method replaced by `axon_results.find_central_index` with, for `myelinated_results`, the argument `node` to obtain former results

Expand Down
91 changes: 36 additions & 55 deletions nrv/fmod/FEM/fenics_utils/_FEMResults.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import gmsh
import numpy as np
import scipy
from time import perf_counter

from dolfinx.fem import Expression, Function, functionspace
from dolfinx.io.gmshio import model_to_mesh
from dolfinx.io.utils import XDMFFile, VTXWriter, VTKFile
Expand Down Expand Up @@ -197,6 +199,15 @@ def __init__(
self.elem = elem
self.vout = vout
self.comm = comm

#For eval() serialized calls acceleration
self.is_evaluated = False
self.tdim = None
self.n_entities_local = None
self.entities = None
self.midpoint_tree = None
self.tree = None


def set_sim_result(
self, mesh_file="", domain=None, V=None, elem=None, vout=None, comm=None
Expand Down Expand Up @@ -297,67 +308,37 @@ def eval(self, X, is_multi_proc=False):
N = len(X)
cells = []
points_on_proc = []
tdim = self.domain.geometry.dim
n_entities_local = (
self.domain.topology.index_map(tdim).size_local
+ self.domain.topology.index_map(tdim).num_ghosts
)
entities = np.arange(n_entities_local, dtype=np.int32)
midpoint_tree = create_midpoint_tree(self.domain, tdim, entities)
tree = bb_tree(self.domain, tdim)
if self.is_evaluated is False:
self.is_evaluated = True
self.tdim = self.domain.geometry.dim
self.n_entities_local = (
self.domain.topology.index_map(self.tdim).size_local
+ self.domain.topology.index_map(self.tdim).num_ghosts
)
self.entities = np.arange(self.n_entities_local, dtype=np.int32)
self.midpoint_tree = create_midpoint_tree(self.domain, self.tdim, self.entities)
self.tree = bb_tree(self.domain, self.tdim)
# Find cells whose bounding-box collide with the the points
cells_candidates = compute_collisions_points(tree, X)

cells_candidates = compute_collisions_points(self.tree, X)
# Choose one of the cells that contains the point
cells_colliding = compute_colliding_cells(self.domain, cells_candidates, X)
for i in range(N):
cell = cells_colliding.links(i)
if is_multi_proc:
if len(cell) > 0:
points_on_proc.append(X[i])
cells.append(cells_colliding.links(i)[0])
else:
# point not in the mesh
if len(cell) == 0:
cell, x_closest = closest_point_in_mesh(
self.domain, X[i], tree, tdim, midpoint_tree
)
rise_warning(
X[i], " not found in mesh, value of ", x_closest, " reused"
)
# compute_colliding_cells(self.domain, cells_candidates, X)
cells += [cell[0]]
else:
cells += [cell[0]]
if is_multi_proc:
points_on_proc = np.array(points_on_proc, dtype=np.float64)
s_values = self.vout.eval(points_on_proc, cells)
if len(points_on_proc) > 0:
m_values = np.concatenate((points_on_proc.T, s_values.T)).T
else:
m_values = np.array([])
m_values = self.comm.gather(m_values.tolist(), root=0)
if MCH.do_master_only_work():
val = []
for i in range(len(m_values)):
for j in range(len(m_values[i])):
val += [m_values[i][j]]
val = np.array(val)
values = np.empty((N), dtype="float64")
for i, p in enumerate(X):
i_p = np.where((np.isclose(val[:, :3], p)).all(axis=1))[0]
if len(i_p > 0):
values[i] = val[i_p[0], 3]
else:
values[i] = values[i - 1]
# point not in the mesh
if len(cell) == 0:
cell, x_closest = closest_point_in_mesh(
self.domain, X[i], self.tree, self.tdim, self.midpoint_tree
)
rise_warning(
X[i], " not found in mesh, value of ", x_closest, " reused"
)
# compute_colliding_cells(self.domain, cells_candidates, X)
cells += [cell[0]]
else:
values = np.empty((N), dtype="float64")
synchronize_processes()
self.comm.Bcast(values, root=0)
else:
values = self.vout.eval(X, cells)
if N > 1:
values = values[:, 0]
cells += [cell[0]]
values = self.vout.eval(X, cells)
if N > 1:
values = values[:, 0]
return values

#####################
Expand Down

0 comments on commit 6ffd4dc

Please sign in to comment.