Skip to content

Commit

Permalink
[geom] Fix Mesh.vertex_connectivity, simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Nov 25, 2024
1 parent 0084480 commit ba28002
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 60 deletions.
10 changes: 4 additions & 6 deletions phi/geom/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def sdf_and_grad(x: Tensor):
def surface_mesh(geo: Geometry,
rel_dx: float = None,
abs_dx: float = None,
method='auto',
build_vertex_connectivity=False,
build_normals=False) -> Mesh:
method='auto') -> Mesh:
"""
Create a surface `Mesh` from a Geometry.
Expand All @@ -101,7 +99,7 @@ def surface_mesh(geo: Geometry,
if geo.spatial_rank != 3:
raise NotImplementedError("Only 3D SDF currently supported")
if isinstance(geo, NoGeometry):
return mesh_from_numpy([], [], build_faces=False, element_rank=2, build_normals=False)
return mesh_from_numpy([], [], element_rank=2)
if method == 'auto' and isinstance(geo, BaseBox):
assert rel_dx is None and abs_dx is None, f"When method='auto', boxes will always use their corners as vertices. Leave rel_dx,abs_dx unspecified or pass 'lewiner' or 'lorensen' as method"
vertices = pack_dims(geo.corners, dual, instance('vertices'))
Expand All @@ -113,7 +111,7 @@ def surface_mesh(geo: Geometry,
instance_offset = math.range_tensor(instance(geo)) * corner_count
faces = wrap([v1, v2, v3], spatial('vertices'), instance('faces')) + instance_offset
faces = pack_dims(faces, instance, instance('faces'))
return mesh(vertices, faces, element_rank=2, build_faces=False, build_vertex_connectivity=build_vertex_connectivity, build_normals=build_normals)
return mesh(vertices, faces, element_rank=2)
elif method == 'auto' and isinstance(geo, Sphere):
pass # ToDo analytic solution
if isinstance(geo, SDFGrid):
Expand All @@ -139,5 +137,5 @@ def generate_mesh(sdf_grid: SDFGrid) -> Mesh:
vertices, faces, v_normals, _ = marching_cubes(sdf_numpy, level=0.0, spacing=dx, allow_degenerate=False, method=method)
vertices += sdf_grid.bounds.lower.numpy() + .5 * dx
with math.NUMPY:
return mesh_from_numpy(vertices, faces, element_rank=2, build_faces=False, build_vertex_connectivity=build_vertex_connectivity, build_normals=build_normals, cell_dim=instance('faces'))
return mesh_from_numpy(vertices, faces, element_rank=2, cell_dim=instance('faces'))
return math.map(generate_mesh, sdf_grid, dims=batch)
90 changes: 36 additions & 54 deletions phi/geom/_mesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import warnings
from dataclasses import dataclass
from functools import cached_property
from numbers import Number
Expand All @@ -8,50 +7,26 @@
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix

from phiml import math
from phiml.math import to_format, is_sparse, non_channel, non_batch, batch, pack_dims, unstack, tensor, si2d, non_dual, nonzero, stored_indices, stored_values, scatter, \
find_closest, sqrt, where, vec_normalize, argmax, broadcast, to_int32, cross_product, zeros, random_normal, EMPTY_SHAPE, meshgrid, mean, reshaped_numpy, range_tensor, convolve, \
assert_close, shift, pad, extrapolation, NUMPY, sum as sum_, with_diagonal, flatten, ones_like, dim_mask, math
find_closest, sqrt, where, vec_normalize, argmax, broadcast, cross_product, zeros, EMPTY_SHAPE, meshgrid, mean, reshaped_numpy, range_tensor, convolve, \
assert_close, shift, pad, extrapolation, sum as sum_, flatten, dim_mask, math, cumulative_sum, arange
from phiml.math._magic_ops import getitem_dataclass
from phiml.math._sparse import CompactSparseTensor
from phiml.math.extrapolation import as_extrapolation, PERIODIC
from phiml.math.magic import slicing_dict
from . import bounding_box
from ._box import Box, BaseBox
from ._functions import plane_sgn_dist
from ._geom import Geometry, Point, NoGeometry
from ._transform import scale
from ._box import Box, BaseBox
from ._graph import Graph, graph
from ._transform import scale
from ..math import Tensor, Shape, channel, shape, instance, dual, rename_dims, expand, spatial, wrap, sparse_tensor, stack, vec_length, tensor_like, \
pairwise_distances, concat, Extrapolation


class _MeshType(type):
"""Metaclass containing the user-friendly (legacy) Mesh constructor."""
def __call__(cls,
vertices: Union[Geometry, Tensor],
elements: Tensor,
element_rank: int,
boundaries: Dict[str, Dict[str, slice]],
periodic: Sequence[str],
face_format: str = 'csc',
max_cell_walk: int = None,
variables=('vertices',),
values=()):
if spatial(elements):
assert elements.dtype.kind == int, f"elements listing vertices must be integer lists but got dtype {elements.dtype}"
else:
assert elements.dtype.kind == bool, f"element matrices must be of type bool but got {elements.dtype}"
if not isinstance(vertices, Geometry):
vertices = Point(vertices)
if max_cell_walk is None:
max_cell_walk = 2 if instance(elements).volume > 1 else 1
result = cls.__new__(cls, vertices, elements, element_rank, boundaries, periodic, face_format, max_cell_walk, variables, values)
result.__init__(vertices, elements, element_rank, boundaries, periodic, face_format, max_cell_walk, variables, values) # also calls __post_init__()
return result


@dataclass(frozen=True)
class Mesh(Geometry, metaclass=_MeshType):
class Mesh(Geometry):
"""
Unstructured mesh, consisting of vertices and elements.
Expand All @@ -73,9 +48,16 @@ class Mesh(Geometry, metaclass=_MeshType):
face_format: str = 'csc'
"""Sparse matrix format for storing quantities that depend on a pair of neighboring elements, e.g. `face_area`, `face_normal`, `face_center`."""
max_cell_walk: int = None
""" Maximum number of steps to walk along the element connectivity in order to find a cell, e.g. for sampling at an arbitrary point."""

variable_attrs: Tuple[str, ...] = ('vertices',) # PhiML keyword
value_attrs: Tuple[str, ...] = () # PhiML keyword

variable_attrs: Tuple[str, ...] = ('vertices',)
value_attrs: Tuple[str, ...] = ()
def __post_init__(self):
if spatial(self.elements):
assert self.elements.dtype.kind == int, f"elements listing vertices must be integer lists but got dtype {self.elements.dtype}"
else:
assert self.elements.dtype.kind == bool, f"element matrices must be of type bool but got {self.elements.dtype}"

@cached_property
def shape(self) -> Shape:
Expand Down Expand Up @@ -117,15 +99,14 @@ def face_normals(self) -> Tensor:
raise NotImplementedError

@cached_property
def _faces(self) -> Dict[str, Tensor]:
def _faces(self) -> Dict[str, Any]:
if self.element_rank == 2:
centers, normals, areas, boundary_slices, vertex_connectivity = build_faces_2d(self.vertices.center, self.elements, self.boundaries, self.periodic, self._vertex_mean, self.face_format)
centers, normals, areas, boundary_slices = build_faces_2d(self.vertices.center, self.elements, self.boundaries, self.periodic, self._vertex_mean, self.face_format)
return {
'center': centers,
'normal': normals,
'area': areas,
'boundary_slices': boundary_slices,
'vertex_connectivity': vertex_connectivity,
}
return None

Expand Down Expand Up @@ -283,24 +264,23 @@ def element_connectivity(self) -> Tensor:
def vertex_connectivity(self) -> Tensor:
if isinstance(self.vertices, Graph):
return self.vertices.connectivity
if self.element_rank == self.spatial_rank:
return self._faces['vertex_connectivity']
elif self.element_rank <= 2:
coo = to_format(self.elements, 'coo').numpy()
connected_points = coo.T @ coo # ToDo this also counts vertices not connected by a single line/face as long as they are part of the same element
if not np.all(connected_points.sum_(axis=1) > 0):
warnings.warn("some vertices have no element connection at all", RuntimeWarning)
connected_points.data = np.ones_like(connected_points.data)
vertex_connectivity = wrap(connected_points, instance(self.vertices), dual(self.elements))
return vertex_connectivity
def single_vertex_connectivity(elements: Tensor):
indices = stored_indices(elements).index[dual(elements).name]
idx1 = indices.numpy()
v_count = sum_(elements, dual).numpy()
ptr_end = np.cumsum(v_count)
roll = np.arange(idx1.size) + 1
roll[ptr_end-1] = ptr_end - v_count
idx2 = idx1[roll]
v_conn = coo_matrix((np.ones(idx1.size, dtype=bool), (idx1, idx2)), shape=(dual(elements).size,)*2).tocsr()
return wrap(v_conn, dual(elements).as_instance(), dual(elements))
return math.map(single_vertex_connectivity, self.elements, dims=batch)
raise NotImplementedError

@property
@cached_property
def vertex_graph(self) -> Graph:
if isinstance(self.vertices, Graph):
return self.vertices
assert self._vertex_connectivity is not None, f"vertex_graph not available because vertex_connectivity has not been computed"
return graph(self.vertices, self._vertex_connectivity)
return self.vertices if isinstance(self.vertices, Graph) else graph(self.vertices, self.vertex_connectivity)

def filter_unused_vertices(self) -> 'Mesh':
coo = to_format(self.elements, 'coo').numpy()
Expand Down Expand Up @@ -343,8 +323,9 @@ def volume(self) -> Tensor:
@property
def normals(self) -> Tensor:
"""Extrinsic element normal space. This is a 0D vector for solid elements and 1D for surface elements."""
if isinstance(self.elements, CompactSparseTensor) and self.element_rank == 2:
corners = self.vertices[self.elements._indices]
if self.element_rank == 2:
three_vertices = nonzero(self.elements, 3, list_dims=dual)
corners = self.vertices.center[{instance: three_vertices}]
assert dual(corners).size == 3, f"signed distance currently only supports triangles"
v1, v2, v3 = unstack(corners, dual)
return vec_normalize(cross_product(v2 - v1, v3 - v1))
Expand Down Expand Up @@ -651,6 +632,8 @@ def mesh(vertices: Geometry | Tensor,
element_rank = 2 if min_vertices <= 4 else 3 # assume tri or quad mesh
else:
raise ValueError(vertices.vector.size)
if max_cell_walk is None:
max_cell_walk = 2 if instance(elements).volume > 1 else 1
# --- build faces ---
periodic_dims = []
if periodic is not None:
Expand Down Expand Up @@ -751,8 +734,7 @@ def build_faces_2d(vertices: Tensor, # (vertices:i, vector)
edge_len = sparse_tensor(indices, edge_len, element_connectivity.shape, format='coo' if face_format == 'dense' else face_format, indices_constant=True)
normal = tensor_like(edge_len, normal, value_order='original')
edge_center = tensor_like(edge_len, edge_center, value_order='original')
vertex_connectivity = None
return edge_center, normal, edge_len, boundary_slices, vertex_connectivity
return edge_center, normal, edge_len, boundary_slices


def build_mesh(bounds: Box = None,
Expand Down

0 comments on commit ba28002

Please sign in to comment.