diff --git a/phi/field/_field.py b/phi/field/_field.py index 87f2799c9..fd48a87ff 100644 --- a/phi/field/_field.py +++ b/phi/field/_field.py @@ -205,7 +205,7 @@ def shape(self) -> Shape: * The batch dimensions match the batch dimensions of this Field * The channel dimensions match the channels of this Field """ - if self.is_staggered and self.is_grid: + if self.is_grid and '~vector' in self._values.shape: return batch(self._geometry) & self.resolution & non_dual(self._values).without(self.resolution) & self._geometry.shape['vector'] set_shape = self._geometry.sets[self.sampled_at] return batch(self._geometry) & (channel(self._geometry) - 'vector') & set_shape & self._values diff --git a/phi/geom/_grid.py b/phi/geom/_grid.py index fe503abc8..e97c60c78 100644 --- a/phi/geom/_grid.py +++ b/phi/geom/_grid.py @@ -34,6 +34,8 @@ def __init__(self, resolution: Shape = None, bounds: BaseBox = None, **resolutio self._resolution = resolution.only(bounds.vector.item_names, reorder=True) # reorder only self._bounds = bounds self._shape = self._resolution & bounds.shape.non_spatial + staggered_shapes = [self._shape.spatial.with_dim_size(dim, self._shape.get_size(dim) + 1) for dim in self.vector.item_names] + self._face_shape = shape_stack(dual(vector=self.vector.item_names), *staggered_shapes) @property def resolution(self): @@ -87,8 +89,7 @@ def face_areas(self) -> Tensor: @property def face_shape(self) -> Shape: - shapes = [self._shape.spatial.with_dim_size(dim, self._shape.get_size(dim) + 1) for dim in self.vector.item_names] - return shape_stack(dual(vector=self.vector.item_names), *shapes) + return self._face_shape def interior(self) -> 'Geometry': raise GeometryException("Regular grid does not have an interior")