From 6ffd4dcec37f2f2c4e898d8b0d07a1708f0a1f87 Mon Sep 17 00:00:00 2001 From: louisreg Date: Wed, 25 Sep 2024 15:10:13 +0200 Subject: [PATCH] improved `eval` in `_FEMResults` for serialized calls --- CHANGELOG.md | 3 +- nrv/fmod/FEM/fenics_utils/_FEMResults.py | 91 ++++++++++-------------- 2 files changed, 38 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 659a17b4..2958a130 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,12 @@ All notable changes to NRV are sumed up in this file. ## [1.1.3] - XXXX-XX-XX ### Added - +- improved `eval` in `_FEMResults` for serialized calls (added state variables) ### Fixed + ### Removed - `myelinated_results.find_central_node_index`-method replaced by `axon_results.find_central_index` with, for `myelinated_results`, the argument `node` to obtain former results diff --git a/nrv/fmod/FEM/fenics_utils/_FEMResults.py b/nrv/fmod/FEM/fenics_utils/_FEMResults.py index 87d1e13a..7213aeb7 100644 --- a/nrv/fmod/FEM/fenics_utils/_FEMResults.py +++ b/nrv/fmod/FEM/fenics_utils/_FEMResults.py @@ -5,6 +5,8 @@ import gmsh import numpy as np import scipy +from time import perf_counter + from dolfinx.fem import Expression, Function, functionspace from dolfinx.io.gmshio import model_to_mesh from dolfinx.io.utils import XDMFFile, VTXWriter, VTKFile @@ -197,6 +199,15 @@ def __init__( self.elem = elem self.vout = vout self.comm = comm + + #For eval() serialized calls acceleration + self.is_evaluated = False + self.tdim = None + self.n_entities_local = None + self.entities = None + self.midpoint_tree = None + self.tree = None + def set_sim_result( self, mesh_file="", domain=None, V=None, elem=None, vout=None, comm=None @@ -297,67 +308,37 @@ def eval(self, X, is_multi_proc=False): N = len(X) cells = [] points_on_proc = [] - tdim = self.domain.geometry.dim - n_entities_local = ( - self.domain.topology.index_map(tdim).size_local - + self.domain.topology.index_map(tdim).num_ghosts - ) - entities = np.arange(n_entities_local, dtype=np.int32) - midpoint_tree = create_midpoint_tree(self.domain, tdim, entities) - tree = bb_tree(self.domain, tdim) + if self.is_evaluated is False: + self.is_evaluated = True + self.tdim = self.domain.geometry.dim + self.n_entities_local = ( + self.domain.topology.index_map(self.tdim).size_local + + self.domain.topology.index_map(self.tdim).num_ghosts + ) + self.entities = np.arange(self.n_entities_local, dtype=np.int32) + self.midpoint_tree = create_midpoint_tree(self.domain, self.tdim, self.entities) + self.tree = bb_tree(self.domain, self.tdim) # Find cells whose bounding-box collide with the the points - cells_candidates = compute_collisions_points(tree, X) - + cells_candidates = compute_collisions_points(self.tree, X) # Choose one of the cells that contains the point cells_colliding = compute_colliding_cells(self.domain, cells_candidates, X) for i in range(N): cell = cells_colliding.links(i) - if is_multi_proc: - if len(cell) > 0: - points_on_proc.append(X[i]) - cells.append(cells_colliding.links(i)[0]) - else: - # point not in the mesh - if len(cell) == 0: - cell, x_closest = closest_point_in_mesh( - self.domain, X[i], tree, tdim, midpoint_tree - ) - rise_warning( - X[i], " not found in mesh, value of ", x_closest, " reused" - ) - # compute_colliding_cells(self.domain, cells_candidates, X) - cells += [cell[0]] - else: - cells += [cell[0]] - if is_multi_proc: - points_on_proc = np.array(points_on_proc, dtype=np.float64) - s_values = self.vout.eval(points_on_proc, cells) - if len(points_on_proc) > 0: - m_values = np.concatenate((points_on_proc.T, s_values.T)).T - else: - m_values = np.array([]) - m_values = self.comm.gather(m_values.tolist(), root=0) - if MCH.do_master_only_work(): - val = [] - for i in range(len(m_values)): - for j in range(len(m_values[i])): - val += [m_values[i][j]] - val = np.array(val) - values = np.empty((N), dtype="float64") - for i, p in enumerate(X): - i_p = np.where((np.isclose(val[:, :3], p)).all(axis=1))[0] - if len(i_p > 0): - values[i] = val[i_p[0], 3] - else: - values[i] = values[i - 1] + # point not in the mesh + if len(cell) == 0: + cell, x_closest = closest_point_in_mesh( + self.domain, X[i], self.tree, self.tdim, self.midpoint_tree + ) + rise_warning( + X[i], " not found in mesh, value of ", x_closest, " reused" + ) + # compute_colliding_cells(self.domain, cells_candidates, X) + cells += [cell[0]] else: - values = np.empty((N), dtype="float64") - synchronize_processes() - self.comm.Bcast(values, root=0) - else: - values = self.vout.eval(X, cells) - if N > 1: - values = values[:, 0] + cells += [cell[0]] + values = self.vout.eval(X, cells) + if N > 1: + values = values[:, 0] return values #####################