Skip to content

Commit

Permalink
[field] Fix Mesh padding with Field as boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed May 22, 2024
1 parent edca661 commit ea321fa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions phi/field/_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from phi.geom import UniformGrid, Box
from phi.math import Tensor, spatial, Extrapolation, Shape, stack
from phi.math.extrapolation import Undefined, ConstantExtrapolation, ZERO
from phiml.math import unstack
from phiml import math
from phiml.math import unstack, rename_dims, instance, dual
from ._field import Field
from ._resample import sample

Expand Down Expand Up @@ -76,7 +77,14 @@ def is_flexible(self) -> bool:
return False

def sparse_pad_values(self, value: Tensor, connectivity: Tensor, dim: str, **kwargs) -> Tensor:
raise NotImplementedError
assert 'mesh' in kwargs, f"sparse padding with Field as boundary only supported for meshes"
from ..geom import Mesh
mesh: Mesh = kwargs['mesh']
boundary_slice = mesh.boundary_faces[dim]
face_pos = mesh.face_centers[boundary_slice]
face_pos = math.stored_values(face_pos) # does this always preserve the order?
sampled = sample(self.field, face_pos)
return rename_dims(sampled, instance, dual(value))

def __eq__(self, other):
if not isinstance(other, FieldEmbedding):
Expand Down
2 changes: 1 addition & 1 deletion phi/geom/_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def pad_boundary(self, value: Tensor, widths: Dict[str, Dict[str, slice]] = None
for name, b_slice in widths.items():
if b_slice[dim].stop - b_slice[dim].start > 0:
slices.append(b_slice[dim])
values.append(mode.sparse_pad_values(value, connectivity[b_slice], name, **kwargs))
values.append(mode.sparse_pad_values(value, connectivity[b_slice], name, mesh=self, **kwargs))
perm = np.argsort([s.start for s in slices])
ordered_pieces = [values[i] for i in perm]
return concat(ordered_pieces, dim, expand_values=True)
Expand Down

0 comments on commit ea321fa

Please sign in to comment.