From 03f090af112dd119c23d738c1f9dd045b73123f5 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sat, 19 Oct 2024 14:40:44 +0200 Subject: [PATCH] [geom] Refactor Cylinder as @dataclass --- phi/flow.py | 2 +- phi/geom/__init__.py | 2 +- phi/geom/_cylinder.py | 203 ++++++++++++++++++++---------------------- 3 files changed, 99 insertions(+), 108 deletions(-) diff --git a/phi/flow.py b/phi/flow.py index 6cfea35f3..040c6899b 100644 --- a/phi/flow.py +++ b/phi/flow.py @@ -22,7 +22,7 @@ # Classes from phiml.math import Shape, Tensor, DType, Solve -from .geom import Geometry, Point, Sphere, Box, Cuboid, Cylinder, UniformGrid, Mesh, Graph +from .geom import Geometry, Point, Sphere, Box, Cuboid, cylinder, UniformGrid, Mesh, Graph from .field import Field, Grid, CenteredGrid, StaggeredGrid, mask, Noise, PointCloud, Scene, resample, GeometryMask, SoftGeometryMask, HardGeometryMask from .physics.fluid import Obstacle diff --git a/phi/geom/__init__.py b/phi/geom/__init__.py index d409731cc..0b32be2f1 100644 --- a/phi/geom/__init__.py +++ b/phi/geom/__init__.py @@ -14,7 +14,7 @@ from ._geom import Geometry, GeometryException, Point, assert_same_rank, invert, rotate, sample_function from ._box import Box, BaseBox, Cuboid, bounding_box from ._sphere import Sphere -from ._cylinder import Cylinder +from ._cylinder import Cylinder, cylinder from ._grid import UniformGrid, enclosing_grid from ._graph import Graph, graph from ._mesh import Mesh, mesh, load_su2, load_gmsh, load_stl, mesh_from_numpy, build_mesh diff --git a/phi/geom/_cylinder.py b/phi/geom/_cylinder.py index 665af8618..d76dc3593 100644 --- a/phi/geom/_cylinder.py +++ b/phi/geom/_cylinder.py @@ -1,13 +1,16 @@ +from dataclasses import dataclass +from functools import cached_property from typing import Union, Dict, Tuple, Optional, Sequence from phiml import math -from phiml.math import Shape, dual, wrap, Tensor, expand, vec, where, ccat, clip, length, normalize, rotate_vector, minimum, vec_squared, rotation_matrix, channel, instance, stack, \ - maximum +from phiml.math import Shape, dual, wrap, Tensor, expand, vec, where, ccat, clip, length, normalize, rotate_vector, minimum, vec_squared, rotation_matrix, channel, instance, stack, maximum +from phiml.math._magic_ops import all_attributes from phiml.math.magic import slicing_dict from ._geom import Geometry, _keep_vector from ._sphere import Sphere +@dataclass(frozen=True) class Cylinder(Geometry): """ N-dimensional cylinder. @@ -16,127 +19,73 @@ class Cylinder(Geometry): For cylinders whose bottom and top lie outside the domain or are otherwise not needed, you may use `infinite_cylinder` instead, which simplifies computations. """ - def __init__(self, - center: Tensor = None, - radius: Union[float, Tensor] = None, - depth: Union[float, Tensor] = None, - rotation: Optional[Tensor] = None, - axis=-1, - variables=('center', 'radius', 'depth', 'rotation'), - **center_: Union[float, Tensor]): - """ - Args: - center: Cylinder center as `Tensor` with `vector` dimension. - The spatial dimension order should be specified in the `vector` dimension via item names. - Can be left empty to specify dimensions via kwargs. - radius: Cylinder radius as `float` or `Tensor`. - depth: Cylinder length as `float` or `Tensor`. - rotation: Rotation angle(s) or rotation matrix. - axis: The cylinder is aligned along this axis, perturbed by `rotation`. - variables: Which properties of the cylinder are variable, i.e. traced and optimizable. All by default. - **center_: Specifies center when the `center` argument is not given. Center position by dimension, e.g. `x=0.5, y=0.2`. - """ - if center is not None: - assert isinstance(center, Tensor), f"center must be a Tensor but got {type(center).__name__}" - assert 'vector' in center.shape, f"Sphere center must have a 'vector' dimension." - assert center.shape.get_item_names('vector') is not None, f"Vector dimension must list spatial dimensions as item names. Use the syntax Sphere(x=x, y=y) to assign names." - self._center = center - else: - self._center = wrap(tuple(center_.values()), channel(vector=tuple(center_.keys()))) - self._radius = wrap(radius) - self._depth = wrap(depth) - self._rotation = None if rotation is None else rotation_matrix(rotation) - self._variables = tuple([v if v.startswith('_') else '_' + v for v in variables]) - self._axis = self._center.vector.item_names[axis] if isinstance(axis, int) else axis - assert 'vector' not in self._radius.shape, f"Cylinder radius must not vary along vector but got {radius}" - assert set(self._variables).issubset(set(self.__all_attrs__())), f"Invalid variables: {self._variables}" - assert self._axis in self._center.vector.item_names, f"Cylinder axis {self._axis} not part of vector dim {self._center.vector}" - - def __all_attrs__(self) -> tuple: - return '_center', '_radius', '_depth', '_rotation' - - def __variable_attrs__(self) -> tuple: - return self._variables + _center: Tensor + radius: Tensor + depth: Tensor + rotation: Tensor # rotation matrix + axis: str - def __value_attrs__(self) -> tuple: - return () - - @property - def shape(self) -> Shape: - if self._center is None or self._radius is None or self._depth is None: - raise RuntimeError - return self._center.shape & self._radius.shape & self._depth.shape - - @property - def radius(self) -> Tensor: - return self._radius + variables: Tuple[str, ...] = ('_center', 'radius', 'depth', 'rotation') + values: Tuple[str, ...] = () @property def center(self) -> Tensor: return self._center - @property - def depth(self) -> Tensor: - return self._depth - - @property - def axis(self) -> str: - return self._axis + @cached_property + def shape(self) -> Shape: + return self._center.shape & self.radius.shape & self.depth.shape - @property + @cached_property def radial_axes(self) -> Sequence[str]: - return [d for d in self._center.vector.item_names if d != self._axis] - - @property - def rotation_matrix(self): - return self._rotation + return [d for d in self._center.vector.item_names if d != self.axis] - @property + @cached_property def volume(self) -> math.Tensor: - return Sphere.volume_from_radius(self._radius, self.spatial_rank - 1) * self._depth + return Sphere.volume_from_radius(self.radius, self.spatial_rank - 1) * self.depth - @property + @cached_property def up(self): - return math.rotate_vector(vec(**{d: 1 if d == self._axis else 0 for d in self._center.vector.item_names}), self._rotation) + return math.rotate_vector(vec(**{d: 1 if d == self.axis else 0 for d in self._center.vector.item_names}), self.rotation) def lies_inside(self, location): - pos = rotate_vector(location - self._center, self._rotation, invert=True) + pos = rotate_vector(location - self._center, self.rotation, invert=True) r = pos.vector[self.radial_axes] - h = pos.vector[self._axis] - inside = (vec_squared(r) <= self._radius**2) & (h >= -.5*self._depth) & (h <= .5*self._depth) + h = pos.vector[self.axis] + inside = (vec_squared(r) <= self.radius**2) & (h >= -.5*self.depth) & (h <= .5*self.depth) return math.any(inside, instance(self)) # union for instance dimensions def approximate_signed_distance(self, location: Union[Tensor, tuple]): - location = math.rotate_vector(location - self._center, self._rotation, invert=True) + location = math.rotate_vector(location - self._center, self.rotation, invert=True) r = location.vector[self.radial_axes] - h = location.vector[self._axis] - top_h = .5*self._depth - bot_h = -.5*self._depth + h = location.vector[self.axis] + top_h = .5*self.depth + bot_h = -.5*self.depth # --- Compute distances --- radial_outward = normalize(r, epsilon=1e-5) - surf_r = radial_outward * self._radius + surf_r = radial_outward * self.radius radial_dist2 = vec_squared(r) - inside_cyl = radial_dist2 <= self._radius**2 + inside_cyl = radial_dist2 <= self.radius**2 clamped_r = where(inside_cyl, r, surf_r) # --- Closest point on bottom / top --- sgn_dist_side = abs(h) - top_h # --- Closest point on cylinder --- - sgn_dist_cyl = length(r) - self._radius + sgn_dist_cyl = length(r) - self.radius # inside (all <= 0) -> largest SDF, outside (any > 0) -> largest positive SDF sgn_dist = maximum(sgn_dist_cyl, sgn_dist_side) return math.min(sgn_dist, instance(self)) def approximate_closest_surface(self, location: Tensor): - location = math.rotate_vector(location - self._center, self._rotation, invert=True) + location = math.rotate_vector(location - self._center, self.rotation, invert=True) r = location.vector[self.radial_axes] - h = location.vector[self._axis] - top_h = .5*self._depth - bot_h = -.5*self._depth + h = location.vector[self.axis] + top_h = .5*self.depth + bot_h = -.5*self.depth # --- Compute distances --- radial_outward = normalize(r, epsilon=1e-5) - surf_r = radial_outward * self._radius + surf_r = radial_outward * self.radius radial_dist2 = vec_squared(r) - inside_cyl = radial_dist2 <= self._radius**2 + inside_cyl = radial_dist2 <= self.radius**2 clamped_r = where(inside_cyl, r, surf_r) # --- Closest point on bottom / top --- above = h >= 0 @@ -156,8 +105,8 @@ def approximate_closest_surface(self, location: Tensor): sgn_dist = minimum(d_flat, d_cyl) * where(inside, -1, 1) delta = surf_point - location normal = where(flat_closer, normal_flat, normal_cyl) - delta = rotate_vector(delta, self._rotation) - normal = rotate_vector(normal, self._rotation) + delta = rotate_vector(delta, self.rotation) + normal = rotate_vector(normal, self.rotation) if instance(self): sgn_dist, delta, normal = math.at_min((sgn_dist, delta, normal), key=sgn_dist, dim=instance(self)) return sgn_dist, delta, normal, None, None @@ -166,37 +115,39 @@ def sample_uniform(self, *shape: math.Shape): raise NotImplementedError def bounding_radius(self): - return math.length(vec(rad=self._radius, dep=.5*self._depth)) + return math.length(vec(rad=self.radius, dep=.5*self.depth)) def bounding_half_extent(self): - if self._rotation is not None: + if self.rotation is not None: return expand(self.bounding_radius(), self._center.shape.only('vector')) - return ccat([.5*self._depth, expand(self._radius, channel(vector=self.radial_axes))], self._center.shape['vector']) + return ccat([.5*self.depth, expand(self.radius, channel(vector=self.radial_axes))], self._center.shape['vector']) def at(self, center: Tensor) -> 'Geometry': - return Cylinder(center, self._radius, self._depth, self._rotation, self._axis, self._variables) + return Cylinder(center, self.radius, self.depth, self.rotation, self.axis, self.variables, self.values) def rotated(self, angle): - if self._rotation is None: - return Cylinder(self._center, self._radius, self._depth, angle, self._axis, self._variables) + if self.rotation is None: + return Cylinder(self._center, self.radius, self.depth, angle, self.axis, self.variables, self.values) else: - matrix = self._rotation @ (angle if dual(angle) else math.rotation_matrix(angle)) - return Cylinder(self._center, self._radius, self._depth, matrix, self._axis, self._variables) + matrix = self.rotation @ (angle if dual(angle) else math.rotation_matrix(angle)) + return Cylinder(self._center, self.radius, self.depth, matrix, self.axis, self.variables, self.values) def scaled(self, factor: Union[float, Tensor]) -> 'Geometry': - return Cylinder(self._center, self._radius * factor, self._depth * factor, self._rotation, self._axis, self._variables) + return Cylinder(self._center, self.radius * factor, self.depth * factor, self.rotation, self.axis, self.variables, self.values) def __getitem__(self, item): item = slicing_dict(self, item) - return Cylinder(self._center[_keep_vector(item)], self._radius[item], self._depth[item], math.slice(self._rotation, item), self._axis, self._variables) + return Cylinder(self._center[_keep_vector(item)], self.radius[item], self.depth[item], math.slice(self.rotation, item), self.axis, self._variables, self.values) @staticmethod def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry': - if all(isinstance(v, Cylinder) for v in values) and all(v._axis == values[0]._axis for v in values): - variables = set() - variables.update(*[set(v._variables) for v in values]) - if any(v._rotation is not None for v in values): - matrices = [v._rotation for v in values] + if all(isinstance(v, Cylinder) for v in values) and all(v.axis == values[0].axis for v in values): + var_attrs = set() + var_attrs.update(*[set(v.variables) for v in values]) + val_attrs = set() + val_attrs.update(*[set(v.values) for v in values]) + if any(v.rotation is not None for v in values): + matrices = [v.rotation for v in values] if any(m is None for m in matrices): any_angle = math.rotation_angles([m for m in matrices if m is not None][0]) unit_matrix = math.rotation_matrix(any_angle * 0) @@ -207,7 +158,7 @@ def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry': center = stack([v.center for v in values], dim, simplify=True, **kwargs) radius = stack([v.radius for v in values], dim, simplify=True, **kwargs) depth = stack([v.depth for v in values], dim, simplify=True, **kwargs) - return Cylinder(center, radius, depth, rotation, values[0]._axis, variables) + return Cylinder(center, radius, depth, rotation, values[0].axis, tuple(var_attrs), tuple(val_attrs)) else: return Geometry.__stack__(values, dim, **kwargs) @@ -242,3 +193,43 @@ def face_shape(self) -> Shape: @property def corners(self) -> Tensor: return math.zeros(self.shape & dual(corners=0)) + + def __eq__(self, other): + return Geometry.__eq__(self, other) + + +def cylinder(center: Tensor = None, + radius: Union[float, Tensor] = None, + depth: Union[float, Tensor] = None, + rotation: Optional[Tensor] = None, + axis=-1, + variables=('center', 'radius', 'depth', 'rotation'), + **center_: Union[float, Tensor]): + """ + Args: + center: Cylinder center as `Tensor` with `vector` dimension. + The spatial dimension order should be specified in the `vector` dimension via item names. + Can be left empty to specify dimensions via kwargs. + radius: Cylinder radius as `float` or `Tensor`. + depth: Cylinder length as `float` or `Tensor`. + rotation: Rotation angle(s) or rotation matrix. + axis: The cylinder is aligned along this axis, perturbed by `rotation`. + variables: Which properties of the cylinder are variable, i.e. traced and optimizable. All by default. + **center_: Specifies center when the `center` argument is not given. Center position by dimension, e.g. `x=0.5, y=0.2`. + """ + if center is not None: + assert isinstance(center, Tensor), f"center must be a Tensor but got {type(center).__name__}" + assert 'vector' in center.shape, f"Sphere center must have a 'vector' dimension." + assert center.shape.get_item_names('vector') is not None, f"Vector dimension must list spatial dimensions as item names. Use the syntax Sphere(x=x, y=y) to assign names." + center = center + else: + center = wrap(tuple(center_.values()), channel(vector=tuple(center_.keys()))) + radius = wrap(radius) + depth = wrap(depth) + rotation = rotation_matrix(rotation) + axis = center.vector.item_names[axis] if isinstance(axis, int) else axis + variables = [{'center': '_center'}.get(v, v) for v in variables] + assert 'vector' not in radius.shape, f"Cylinder radius must not vary along vector but got {radius}" + assert set(variables).issubset(set(all_attributes(Cylinder))), f"Invalid variables: {variables}" + assert axis in center.vector.item_names, f"Cylinder axis {axis} not part of vector dim {center.vector}" + return Cylinder(center, radius, depth, rotation, axis, tuple(variables), ())