From ea321fa1b6834d8611aa22736703760c9cd37b79 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Wed, 22 May 2024 16:27:34 +0200 Subject: [PATCH] [field] Fix Mesh padding with Field as boundary --- phi/field/_embed.py | 12 ++++++++++-- phi/geom/_mesh.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/phi/field/_embed.py b/phi/field/_embed.py index 653a6f5cf..b7ab04094 100644 --- a/phi/field/_embed.py +++ b/phi/field/_embed.py @@ -3,7 +3,8 @@ from phi.geom import UniformGrid, Box from phi.math import Tensor, spatial, Extrapolation, Shape, stack from phi.math.extrapolation import Undefined, ConstantExtrapolation, ZERO -from phiml.math import unstack +from phiml import math +from phiml.math import unstack, rename_dims, instance, dual from ._field import Field from ._resample import sample @@ -76,7 +77,14 @@ def is_flexible(self) -> bool: return False def sparse_pad_values(self, value: Tensor, connectivity: Tensor, dim: str, **kwargs) -> Tensor: - raise NotImplementedError + assert 'mesh' in kwargs, f"sparse padding with Field as boundary only supported for meshes" + from ..geom import Mesh + mesh: Mesh = kwargs['mesh'] + boundary_slice = mesh.boundary_faces[dim] + face_pos = mesh.face_centers[boundary_slice] + face_pos = math.stored_values(face_pos) # does this always preserve the order? + sampled = sample(self.field, face_pos) + return rename_dims(sampled, instance, dual(value)) def __eq__(self, other): if not isinstance(other, FieldEmbedding): diff --git a/phi/geom/_mesh.py b/phi/geom/_mesh.py index 5cd7332ca..dcf1857f3 100644 --- a/phi/geom/_mesh.py +++ b/phi/geom/_mesh.py @@ -141,7 +141,7 @@ def pad_boundary(self, value: Tensor, widths: Dict[str, Dict[str, slice]] = None for name, b_slice in widths.items(): if b_slice[dim].stop - b_slice[dim].start > 0: slices.append(b_slice[dim]) - values.append(mode.sparse_pad_values(value, connectivity[b_slice], name, **kwargs)) + values.append(mode.sparse_pad_values(value, connectivity[b_slice], name, mesh=self, **kwargs)) perm = np.argsort([s.start for s in slices]) ordered_pieces = [values[i] for i in perm] return concat(ordered_pieces, dim, expand_values=True)