diff --git a/phi/field/_resample.py b/phi/field/_resample.py index 3d12d0501..ef05aff45 100644 --- a/phi/field/_resample.py +++ b/phi/field/_resample.py @@ -196,16 +196,16 @@ def scatter_to_centers(self: Field, geometry: Geometry, soft=False, scatter=Fals assert not soft, "Cannot soft-sample when scatter=True" return grid_scatter(self, geometry.bounds, geometry.resolution, outside_handling) else: - assert not isinstance(self._geometry, Point), "Cannot sample Point-like elements with scatter=False" - if may_vary_along(self._values, instance(self._values) & spatial(self._values)): + assert not isinstance(self.geometry, Point), "Cannot sample Point-like elements with scatter=False" + if may_vary_along(self.values, instance(self.values) & spatial(self.values)): raise NotImplementedError("Non-scatter resampling not yet supported for varying values") - idx0 = (instance(self._values) & spatial(self._values)).first_index() + idx0 = (instance(self.values) & spatial(self.values)).first_index() outside = self.boundary.value if isinstance(self.boundary, ConstantExtrapolation) else 0 if soft: frac_inside = self.geometry.approximate_fraction_inside(geometry, balance) - return frac_inside * self._values[idx0] + (1 - frac_inside) * outside + return frac_inside * self.values[idx0] + (1 - frac_inside) * outside else: - return math.where(self.geometry.lies_inside(geometry.center), self._values[idx0], outside) + return math.where(self.geometry.lies_inside(geometry.center), self.values[idx0], outside) def scatter_to_faces(field: Field, geometry: Geometry, extrapolation: Extrapolation, **kwargs) -> Tensor: diff --git a/phi/flow.py b/phi/flow.py index 1ed7476a9..eedd2d388 100644 --- a/phi/flow.py +++ b/phi/flow.py @@ -40,12 +40,12 @@ dsum, isum, ssum, csum, mean, dmean, imean, smean, cmean, median, sign, round, ceil, floor, sqrt, exp, erf, log, log2, log10, sigmoid, soft_plus, sin, cos, tan, sinh, cosh, tanh, arcsin, arccos, arctan, arcsinh, arccosh, arctanh, log_gamma, factorial, incomplete_gamma, scatter, gather, where, nonzero, - rotate_vector as rotate, cross_product as cross, dot, convolve, vec_normalize as normalize, length, maximum, minimum, clip, # vector math + cross_product as cross, dot, convolve, vec_normalize as normalize, length, maximum, minimum, clip, # vector math safe_div, length, is_finite, is_nan, is_inf, # Basic functions jit_compile, jit_compile_linear, minimize, gradient as functional_gradient, gradient, solve_linear, solve_nonlinear, iterate, identity, # jacobian, hessian, custom_gradient # Functional magic assert_close, always_close, equal, close ) -from .geom import union +from .geom import union, rotate, scale from .vis import show, control, plot # Exceptions diff --git a/phi/geom/__init__.py b/phi/geom/__init__.py index 0b32be2f1..574ccafe0 100644 --- a/phi/geom/__init__.py +++ b/phi/geom/__init__.py @@ -10,18 +10,25 @@ See the `phi.geom` module documentation at https://tum-pbs.github.io/PhiFlow/Geometry.html """ from ..math import stack, concat, pack_dims # for compatibility + +# --- Low-level functions --- +from ._geom import Geometry, GeometryException, Point, assert_same_rank, invert, sample_function from ._functions import normal_from_slope -from ._geom import Geometry, GeometryException, Point, assert_same_rank, invert, rotate, sample_function +from ._transform import scale, rotate, rotation_matrix, rotation_angles, rotation_matrix_from_axis_and_angle, rotation_matrix_from_directions + +# --- Geometry types --- from ._box import Box, BaseBox, Cuboid, bounding_box from ._sphere import Sphere from ._cylinder import Cylinder, cylinder from ._grid import UniformGrid, enclosing_grid from ._graph import Graph, graph from ._mesh import Mesh, mesh, load_su2, load_gmsh, load_stl, mesh_from_numpy, build_mesh -from ._transform import embed, infinite_cylinder from ._heightmap import Heightmap from ._sdf_grid import SDFGrid, sample_sdf from ._sdf import SDF, numpy_sdf +from ._embed import embed, infinite_cylinder + +# --- Top-level functions --- from ._geom_ops import union, intersection from ._convert import surface_mesh, as_sdf from ._geom_functions import line_trace diff --git a/phi/geom/_box.py b/phi/geom/_box.py index 84cc55f94..6f21cc6c5 100644 --- a/phi/geom/_box.py +++ b/phi/geom/_box.py @@ -7,6 +7,7 @@ from phi.math import DimFilter from phiml.math import rename_dims, vec, stack, expand, instance from phiml.math._shape import parse_dim_order, dual, non_channel, non_batch +from . import rotate, rotation_matrix from ._geom import Geometry, _keep_vector from ..math import wrap, INF, Shape, channel, Tensor from ..math.magic import slicing_dict @@ -74,7 +75,7 @@ def global_to_local(self, global_position: Tensor, scale=True, origin='lower') - assert origin in ['lower', 'center', 'upper'] origin_loc = getattr(self, origin) pos = global_position if math.always_close(origin_loc, 0) else global_position - origin_loc - pos = math.rotate_vector(pos, self.rotation_matrix, invert=True) + pos = rotate(pos, self.rotation_matrix, invert=True) if scale: pos /= (self.half_size if origin == 'center' else self.size) return pos @@ -83,7 +84,7 @@ def local_to_global(self, local_position, scale=True, origin='lower'): assert origin in ['lower', 'center', 'upper'] origin_loc = getattr(self, origin) pos = local_position * (self.half_size if origin == 'center' else self.size) if scale else local_position - return math.rotate_vector(pos, self.rotation_matrix) + origin_loc + return rotate(pos, self.rotation_matrix) + origin_loc def largest(self, dim: DimFilter) -> 'BaseBox': dim = self.shape.without('vector').only(dim) @@ -137,7 +138,7 @@ def push(self, positions: Tensor, outward: bool = True, shift_amount: float = 0) if instance(self): shift, loc_to_center, rotation_matrix = math.at_min((shift, loc_to_center, rotation_matrix), key=math.vec_length(shift), dim=instance) shift = math.where(abs(shift) > abs(loc_to_center), abs(loc_to_center), shift) # ensure inward shift ends at center - shift = math.rotate_vector(shift, rotation_matrix) + shift = rotate(shift, rotation_matrix) return positions + math.where(loc_to_center < 0, 1, -1) * shift def approximate_closest_surface(self, location: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: @@ -221,7 +222,7 @@ def face_centers(self) -> Tensor: @property def face_normals(self) -> Tensor: unit_vectors = math.to_float(math.range(self.shape['vector']) == math.range(dual(**self.shape['vector'].untyped_dict))) - vectors = math.rotate_vector(unit_vectors, self.rotation_matrix) + vectors = rotate(unit_vectors, self.rotation_matrix) return vectors * math.vec(dual('side'), lower=-1, upper=1) @property @@ -442,7 +443,7 @@ def __init__(self, if 'vector' not in center.shape or center.shape.get_item_names('vector') is None: center = math.expand(center, channel(self._half_size)) self._center = center - self._rotation_matrix = None if rotation is None else math.rotation_matrix(rotation) + self._rotation_matrix = None if rotation is None else rotation_matrix(rotation) self._size_variable = size_variable def __repr__(self): diff --git a/phi/geom/_cylinder.py b/phi/geom/_cylinder.py index f519b36bb..f9c5fbaeb 100644 --- a/phi/geom/_cylinder.py +++ b/phi/geom/_cylinder.py @@ -3,11 +3,10 @@ from typing import Union, Dict, Tuple, Optional, Sequence from phiml import math -from phiml.math import Shape, dual, wrap, Tensor, expand, vec, where, ncat, clip, length, normalize, rotate_vector, minimum, vec_squared, rotation_matrix, channel, instance, stack, maximum, PI, linspace, sin, cos, \ - rotation_matrix_from_directions, sqrt, batch +from phiml.math import (Shape, dual, wrap, Tensor, expand, vec, where, ncat, clip, length, normalize, minimum, vec_squared, channel, instance, stack, maximum, PI, linspace, sin, cos, sqrt, batch) from phiml.math._magic_ops import all_attributes, getitem_dataclass -from phiml.math.magic import slicing_dict -from ._geom import Geometry, _keep_vector +from ._geom import Geometry +from ._transform import rotate, rotation_matrix, rotation_matrix_from_directions from ._sphere import Sphere @@ -47,7 +46,7 @@ def volume(self) -> math.Tensor: @cached_property def up(self): - return math.rotate_vector(vec(**{d: 1 if d == self.axis else 0 for d in self._center.vector.item_names}), self.rotation) + return math.rotate(vec(**{d: 1 if d == self.axis else 0 for d in self._center.vector.item_names}), self.rotation) def with_radius(self, radius: Tensor) -> 'Cylinder': return Cylinder(self._center, wrap(radius), self.depth, self.rotation, self.axis, self.variable_attrs, self.value_attrs) @@ -56,14 +55,14 @@ def with_depth(self, depth: Tensor) -> 'Cylinder': return Cylinder(self._center, self.radius, wrap(depth), self.rotation, self.axis, self.variable_attrs, self.value_attrs) def lies_inside(self, location): - pos = rotate_vector(location - self._center, self.rotation, invert=True) + pos = rotate(location - self._center, self.rotation, invert=True) r = pos.vector[self.radial_axes] h = pos.vector[self.axis] inside = (vec_squared(r) <= self.radius**2) & (h >= -.5*self.depth) & (h <= .5*self.depth) return math.any(inside, instance(self)) # union for instance dimensions def approximate_signed_distance(self, location: Union[Tensor, tuple]): - location = math.rotate_vector(location - self._center, self.rotation, invert=True) + location = math.rotate(location - self._center, self.rotation, invert=True) r = location.vector[self.radial_axes] h = location.vector[self.axis] top_h = .5*self.depth @@ -83,7 +82,7 @@ def approximate_signed_distance(self, location: Union[Tensor, tuple]): return math.min(sgn_dist, instance(self)) def approximate_closest_surface(self, location: Tensor): - location = math.rotate_vector(location - self._center, self.rotation, invert=True) + location = math.rotate(location - self._center, self.rotation, invert=True) r = location.vector[self.radial_axes] h = location.vector[self.axis] top_h = .5*self.depth @@ -112,8 +111,8 @@ def approximate_closest_surface(self, location: Tensor): sgn_dist = minimum(d_flat, d_cyl) * where(inside, -1, 1) delta = surf_point - location normal = where(flat_closer, normal_flat, normal_cyl) - delta = rotate_vector(delta, self.rotation) - normal = rotate_vector(normal, self.rotation) + delta = rotate(delta, self.rotation) + normal = rotate(normal, self.rotation) idx = None if instance(self): sgn_dist, delta, normal, idx = math.min((sgn_dist, delta, normal, range), instance(self), key=sgn_dist) @@ -123,7 +122,7 @@ def sample_uniform(self, *shape: math.Shape): r = Sphere(self._center[self.radial_axes], self.radius).sample_uniform(*shape) h = math.random_uniform(*shape, -.5*self.depth, .5*self.depth) rh = ncat([r, h], self._center.shape['vector']) - return rotate_vector(rh, self.rotation) + return rotate(rh, self.rotation) def bounding_radius(self): return length(vec(rad=self.radius, dep=.5*self.depth), 'vector') @@ -215,7 +214,7 @@ def vertex_rings(self, count: Shape) -> Tensor: c = cos(angle) * self.radius r = stack([s, c], channel(vector=self.radial_axes)) x = ncat([h, r], self._center.shape['vector'], expand_values=True) - return math.rotate_vector(x, self.rotation) + self._center + return math.rotate(x, self.rotation) + self._center raise NotImplementedError diff --git a/phi/geom/_embed.py b/phi/geom/_embed.py new file mode 100644 index 000000000..675cb20dc --- /dev/null +++ b/phi/geom/_embed.py @@ -0,0 +1,158 @@ +from numbers import Number +from typing import Tuple, Union, Dict, Any + +from phiml.math import spatial, channel, stack, expand, INF + +from phi import math +from phi.math import Tensor, Shape +from phiml.math.magic import slicing_dict +from . import BaseBox, Box, Cuboid +from ._geom import Geometry +from ._sphere import Sphere +from phiml.math._shape import parse_dim_order + + +class _EmbeddedGeometry(Geometry): + + def __init__(self, geometry, axes: Tuple[str]): + self.geometry = geometry + self.axes = axes # spatial axis order + + @property + def spatial_rank(self) -> int: + return len(self.axes) + + @property + def center(self) -> Tensor: + g_cen = dict(**self.geometry.bounding_half_extent().vector) + return stack({dim: g_cen.get(dim, 0) for dim in self.vector.item_names}, channel('vector')) + + @property + def shape(self) -> Shape: + return self.geometry.shape.with_dim_size('vector', self.axes) + + @property + def volume(self) -> Tensor: + raise NotImplementedError() + + def unstack(self, dimension: str) -> tuple: + raise NotImplementedError() + + def _down_project(self, location: Tensor): + item_names = list(location.shape.get_item_names('vector')) + for dim in self.axes: + if dim not in self.geometry.shape.get_item_names('vector'): + item_names.remove(dim) + projected_loc = location.vector[item_names] + return projected_loc + + def __getitem__(self, item): + item = slicing_dict(self, item) + if 'vector' in item: + axes = channel(vector=self.axes).after_gather(item).item_names[0] + if all(a in self.geometry.vector.item_names for a in axes): + return self.geometry[item] + item['vector'] = [a for a in axes if a in self.geometry.vector.item_names] + else: + axes = self.axes + projected = self.geometry[item] + if projected.spatial_rank == 0: + return Box(**{a: None for a in axes}) + assert not isinstance(projected, BaseBox), f"_EmbeddedGeometry reduced to a Box but should already have been a box. Was {self.geometry}" + if isinstance(projected, Sphere) and projected.spatial_rank: # 1D spheres are just boxes + box1d = Cuboid(projected.center, expand(projected.radius, projected.center.shape['vector'])) + emb = _EmbeddedGeometry(box1d, axes) + return Cuboid(emb.center, emb.bounding_half_extent()) + return _EmbeddedGeometry(projected, axes) + + def lies_inside(self, location: Tensor) -> Tensor: + return self.geometry.lies_inside(self._down_project(location)) + + def approximate_signed_distance(self, location: Tensor) -> Tensor: + return self.geometry.approximate_signed_distance(self._down_project(location)) + + def sample_uniform(self, *shape: math.Shape) -> Tensor: + raise NotImplementedError() + + def bounding_radius(self) -> Tensor: + raise NotImplementedError() + + def bounding_half_extent(self) -> Tensor: + g_ext = dict(**self.geometry.bounding_half_extent().vector) + return stack({dim: g_ext.get(dim, INF) for dim in self.vector.item_names}, channel('vector')) + + def shifted(self, delta: Tensor) -> 'Geometry': + raise NotImplementedError() + + def at(self, center: Tensor) -> 'Geometry': + raise NotImplementedError() + + def rotated(self, angle: Union[float, Tensor]) -> 'Geometry': + raise NotImplementedError() + + def scaled(self, factor: Union[float, Tensor]) -> 'Geometry': + raise NotImplementedError() + + def __hash__(self): + return hash(self.geometry) + hash(self.axes) + + @property + def boundary_elements(self) -> Dict[Any, Dict[str, slice]]: + return self.geometry.boundary_elements + + @property + def boundary_faces(self) -> Dict[Any, Dict[str, slice]]: + return self.geometry.boundary_faces + + +def embed(geometry: Geometry, projected_dims: Union[math.Shape, str, tuple, list, None]) -> Geometry: + """ + Adds fake spatial dimensions to a geometry. + The geometry value will be constant along the added dimensions, as if it had infinite length in these directions. + + Dimensions that are already present with `geometry` are ignored. + + Args: + geometry: `Geometry` + projected_dims: Additional dimensions + + Returns: + `Geometry` with spatial rank `geometry.spatial_rank + projected_dims.rank`. + """ + if projected_dims is None: + return geometry + axes = parse_dim_order(projected_dims) + embedded_axes = [a for a in axes if a not in geometry.shape.get_item_names('vector')] + if not embedded_axes: + return geometry[axes] + # --- add dims from geometry to axes --- + for name in reversed(geometry.shape.get_item_names('vector')): + if name not in projected_dims: + axes = (name,) + axes + if isinstance(geometry, BaseBox): + box = geometry.corner_representation() + embedded = box * Box(**{dim: None for dim in embedded_axes}) + return embedded[axes] + return _EmbeddedGeometry(geometry, axes) + + +def infinite_cylinder(center=None, radius=None, inf_dim: Union[str, Shape, tuple, list] = None, **center_) -> Geometry: + """ + Creates an infinite cylinder. + This is equal to embedding an `n`-dimensional `Sphere` in `n+1` dimensions. + + See Also: + `Sphere`, `embed` + + Args: + center: Center coordinates without `inf_dim`. Alternatively use keyword arguments. + radius: Cylinder radius. + inf_dim: Dimension along which the cylinder is infinite. + Use `Geometry.rotated()` if the direction does not align with an axis. + **center_: Alternatively specify center coordinates without `inf_dim` as keyword arguments. + + Returns: + `Geometry` + """ + sphere = Sphere(center, radius, **center_) + return embed(sphere, inf_dim) diff --git a/phi/geom/_functions.py b/phi/geom/_functions.py index 2bbf68bca..f4265ae76 100644 --- a/phi/geom/_functions.py +++ b/phi/geom/_functions.py @@ -137,3 +137,5 @@ def distance_line_point(line_offset: Tensor, line_direction: Tensor, point: Tens if not is_direction_normalized: c /= vec_length(line_direction) return c + + diff --git a/phi/geom/_geom.py b/phi/geom/_geom.py index b26faa639..54811e077 100644 --- a/phi/geom/_geom.py +++ b/phi/geom/_geom.py @@ -1,11 +1,10 @@ import warnings from numbers import Number -from typing import Union, Dict, Any, Tuple, Callable +from typing import Union, Dict, Any, Tuple, Callable, TypeVar -from phi import math -from phi.math import Tensor, Shape, non_channel, wrap, shape, Extrapolation -from phi.math.magic import BoundDim, slicing_dict -from phiml.math import non_batch, tensor_like +from phiml import math +from phiml.math import Tensor, Shape, non_channel, wrap, shape, Extrapolation, non_batch, tensor_like +from phiml.math.magic import BoundDim, slicing_dict from phiml.math._magic_ops import variable_attributes, expand, find_differences @@ -747,6 +746,9 @@ def __getitem__(self, item): return Point(self._location[_keep_vector(slicing_dict(self, item))]) +GeometricType = TypeVar("GeometricType", Tensor, Geometry) + + class GeometryException(BaseException): """ Raised when an operation is fundamentally not possible for a `Geometry`. @@ -789,42 +791,6 @@ def _keep_vector(dim_selection: dict) -> dict: return item -def rotate(geometry: Geometry, rot: Union[float, Tensor], pivot: Tensor = None) -> Geometry: - """ - Rotate a `Geometry` about an axis given by `rot` and `pivot`. - - Args: - geometry: `Geometry` to rotate - rot: Rotation, either as Euler angles or rotation matrix. - pivot: Any point lying on the rotation axis. Defaults to the bounding box center. - - Returns: - Rotated `Geometry` - """ - if pivot is None: - pivot = geometry.bounding_box().center - center = pivot + math.rotate_vector(geometry.center - pivot, rot) - return geometry.rotated(rot).at(center) - - -def scale(geometry: Geometry, scale: float | Tensor, pivot: Tensor = None) -> Geometry: - """ - Scale a `Geometry` about a pivot point. - - Args: - geometry: `Geometry` to scale. - scale: Scaling factor. - pivot: Point that stays fixed under the scaling operation. Defaults to the bounding box center. - - Returns: - Rotated `Geometry` - """ - if pivot is None: - pivot = geometry.bounding_box().center - center = pivot + scale * (geometry.center - pivot) - return geometry.scaled(scale).at(center) - - def slice_off_constant_faces(obj, boundary_slices: Dict[Any, Dict[str, slice]], boundary: Extrapolation): """ Removes slices of `obj` where the boundary conditions fully determine the values. diff --git a/phi/geom/_geom_functions.py b/phi/geom/_geom_functions.py index 6b9519134..04e255438 100644 --- a/phi/geom/_geom_functions.py +++ b/phi/geom/_geom_functions.py @@ -2,11 +2,32 @@ from typing import Tuple, Optional from phiml import math -from phiml.math import Tensor, stack, instance, wrap +from phiml.math import Tensor, stack, instance, wrap, shape +from . import Cylinder from ._geom import Geometry +def length(obj: Geometry | Tensor, eps=1e-5) -> Tensor: + """ + Returns the length of a vector `Tensor` or geometric object with a length-like property. + + Args: + obj: + eps: Minimum valid vector length. Use to avoid `inf` gradients for zero-length vectors. + Lengths shorter than `eps` are set to 0. + + Returns: + Length as `Tensor` + """ + if isinstance(obj, Tensor): + assert 'vector' in obj.shape, f"length() requires 'vector' dim but got {type(obj)} with shape {shape(obj)}." + return math.length(obj, 'vector', eps) + elif isinstance(obj, Cylinder): + return obj.depth + raise ValueError(obj) + + def line_trace(geo: Geometry, origin: Tensor, direction: Tensor, side='both', tolerance=None, max_iter=64, step_size=.9, max_line_length=None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]: """ Trace a line until it hits the surface of `geo`. diff --git a/phi/geom/_geom_ops.py b/phi/geom/_geom_ops.py index e00fd71d4..0ff2f22c1 100644 --- a/phi/geom/_geom_ops.py +++ b/phi/geom/_geom_ops.py @@ -11,7 +11,8 @@ from phiml.math.magic import PhiTreeNode from ._box import bounding_box, Box -from ._geom import Geometry, NoGeometry, rotate +from ._geom import Geometry, NoGeometry +from ._transform import rotate from ._geom import InvertedGeometry from ..math import Tensor, instance from ..math.magic import slicing_dict diff --git a/phi/geom/_mesh.py b/phi/geom/_mesh.py index f1a233a66..93b890012 100644 --- a/phi/geom/_mesh.py +++ b/phi/geom/_mesh.py @@ -10,14 +10,15 @@ from phiml.math import to_format, is_sparse, non_channel, non_batch, batch, pack_dims, unstack, tensor, si2d, non_dual, nonzero, stored_indices, stored_values, scatter, \ find_closest, sqrt, where, vec_normalize, argmax, broadcast, to_int32, cross_product, zeros, random_normal, EMPTY_SHAPE, meshgrid, mean, reshaped_numpy, range_tensor, convolve, \ - assert_close, shift, pad, extrapolation, NUMPY, sum as sum_, with_diagonal, flatten, ones_like, dim_mask + assert_close, shift, pad, extrapolation, NUMPY, sum as sum_, with_diagonal, flatten, ones_like, dim_mask, math from phiml.math._magic_ops import getitem_dataclass from phiml.math._sparse import CompactSparseTensor from phiml.math.extrapolation import as_extrapolation, PERIODIC from phiml.math.magic import slicing_dict from . import bounding_box from ._functions import plane_sgn_dist -from ._geom import Geometry, Point, scale, NoGeometry +from ._geom import Geometry, Point, NoGeometry +from ._transform import scale from ._box import Box, BaseBox from ._graph import Graph, graph from ..math import Tensor, Shape, channel, shape, instance, dual, rename_dims, expand, spatial, wrap, sparse_tensor, stack, vec_length, tensor_like, \ @@ -31,6 +32,8 @@ def __call__(cls, elements: Tensor, element_rank: int, boundaries: Dict[str, Dict[str, slice]], + periodic: Sequence[str], + face_format: str = 'csc', max_cell_walk: int = None, variables=('vertices',), values=()): @@ -42,8 +45,8 @@ def __call__(cls, vertices = Point(vertices) if max_cell_walk is None: max_cell_walk = 2 if instance(elements).volume > 1 else 1 - result = cls.__new__(cls, vertices, elements, element_rank, boundaries, max_cell_walk, variables, values) - result.__init__(vertices, elements, element_rank, boundaries, max_cell_walk, variables, values) # also calls __post_init__() + result = cls.__new__(cls, vertices, elements, element_rank, boundaries, periodic, face_format, max_cell_walk, variables, values) + result.__init__(vertices, elements, element_rank, boundaries, periodic, face_format, max_cell_walk, variables, values) # also calls __post_init__() return result @@ -319,7 +322,7 @@ def filter_unused_vertices(self) -> 'Mesh': else: filtered_coo = coo_matrix((coo.data, (coo.row, new_index)), shape=(instance(self.elements).volume, instance(vertices).volume)) # ToDo keep sparse format elements = wrap(filtered_coo, self.elements.shape.without_sizes()) - return Mesh(vertices, elements, self.element_rank, self.boundaries, self._center, self._volume, self._normals, self.face_centers, self.face_normals, self.face_areas, None, v_normals, vertex_connectivity, self._element_connectivity, self._max_cell_walk) + return Mesh(vertices, elements, self.element_rank, self.boundaries, self.center, self._volume, self.normals, self.face_centers, self.face_normals, self.face_areas, None, v_normals, vertex_connectivity, self._element_connectivity, self.max_cell_walk) @property def volume(self) -> Tensor: @@ -358,37 +361,35 @@ def vertex_positions(self) -> Tensor: return si2d(self.vertices.center) def lies_inside(self, location: Tensor) -> Tensor: - idx = find_closest(self._center, location) - for i in range(self._max_cell_walk): - idx, leaves_mesh, is_outside, *_ = self.cell_walk_towards(location, idx, allow_exit=i == self._max_cell_walk - 1) + idx = find_closest(self.center, location) + for i in range(self.max_cell_walk): + idx, leaves_mesh, is_outside, *_ = self.cell_walk_towards(location, idx, allow_exit=i == self.max_cell_walk - 1) return ~(leaves_mesh & is_outside) def approximate_signed_distance(self, location: Union[Tensor, tuple]) -> Tensor: if self.element_rank == 2 and self.spatial_rank == 3: - closest_elem = find_closest(self._center, location) - center = self._center[closest_elem] - normal = self._normals[closest_elem] + closest_elem = find_closest(self.center, location) + center = self.center[closest_elem] + normal = self.normals[closest_elem] return plane_sgn_dist(center, normal, location) - if self._center is None: - raise NotImplementedError("Mesh.approximate_signed_distance only available when faces are built.") - idx = find_closest(self._center, location) - for i in range(self._max_cell_walk): + idx = find_closest(self.center, location) + for i in range(self.max_cell_walk): idx, leaves_mesh, is_outside, distances, nb_idx = self.cell_walk_towards(location, idx, allow_exit=False) - return max(distances, dual) + return math.max(distances, dual) def approximate_closest_surface(self, location: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: if self.element_rank == 2 and self.spatial_rank == 3: - closest_elem = find_closest(self._center, location) - center = self._center[closest_elem] - normal = self._normals[closest_elem] + closest_elem = find_closest(self.center, location) + center = self.center[closest_elem] + normal = self.normals[closest_elem] face_size = sqrt(self._volume) * 4 size = face_size[closest_elem] sgn_dist = plane_sgn_dist(center, normal, location) delta = center - location # this is not accurate... outward = where(abs(sgn_dist) < size, normal, vec_normalize(delta)) return sgn_dist, delta, outward, None, closest_elem - # idx = find_closest(self._center, location) - # for i in range(self._max_cell_walk): + # idx = find_closest(self.center, location) + # for i in range(self.max_cell_walk): # idx, leaves_mesh, is_outside, distances, nb_idx = self.cell_walk_towards(location, idx, allow_exit=False) # sgn_dist = max(distances, dual) # cell_normals = self.face_normals[idx] @@ -414,7 +415,7 @@ def cell_walk_towards(self, location: Tensor, start_cell_idx: Tensor, allow_exit closest_face_centers = self.face_centers[start_cell_idx] offsets = closest_normals.vector @ closest_face_centers.vector # this dot product could be cashed in the mesh distances = closest_normals.vector @ location.vector - offsets - is_outside = any(distances > 0, dual) + is_outside = math.any(distances > 0, dual) nb_idx = argmax(distances, dual).index[0] # cell index or boundary face index leaves_mesh = nb_idx >= instance(self).volume next_idx = where(is_outside & (~leaves_mesh | allow_exit), nb_idx, start_cell_idx) @@ -427,13 +428,13 @@ def bounding_radius(self) -> Tensor: center = self.elements * self.center vert_pos = rename_dims(self.vertices.center, instance, dual) dist_to_vert = vec_length(vert_pos - center) - max_dist = max(dist_to_vert, dual) + max_dist = math.max(dist_to_vert, dual) return max_dist def bounding_half_extent(self) -> Tensor: center = self.elements * self.center vert_pos = rename_dims(self.vertices.center, instance, dual) - max_delta = max(abs(vert_pos - center), dual) + max_delta = math.max(abs(vert_pos - center), dual) return max_delta def bounding_box(self) -> 'BaseBox': @@ -441,7 +442,7 @@ def bounding_box(self) -> 'BaseBox': @property def bounds(self): - return Box(min(self.vertices.center, instance), max(self.vertices.center, instance)) + return Box(math.min(self.vertices.center, instance), math.max(self.vertices.center, instance)) def at(self, center: Tensor) -> 'Mesh': if instance(self.elements) in center.shape: @@ -466,8 +467,8 @@ def shifted(self, delta: Tensor) -> 'Mesh': else: # shift everything # ToDo transfer cached properties vertices = self.vertices.shifted(delta) - center = self._center + delta - return Mesh(vertices, self.elements, self.element_rank, self.boundaries, center, self._volume, self._normals, self.face_centers, self.face_normals, self.face_areas, self.face_vertices, self._vertex_normals, self._vertex_connectivity, self._element_connectivity, self._max_cell_walk) + center = self.center + delta + return Mesh(vertices, self.elements, self.element_rank, self.boundaries, center, self._volume, self.normals, self.face_centers, self.face_normals, self.face_areas, self.face_vertices, self._vertex_normals, self._vertex_connectivity, self._element_connectivity, self.max_cell_walk) def rotated(self, angle: Union[float, Tensor]) -> 'Geometry': raise NotImplementedError @@ -475,10 +476,10 @@ def rotated(self, angle: Union[float, Tensor]) -> 'Geometry': def scaled(self, factor: float | Tensor) -> 'Geometry': pivot = self.bounds.center vertices = scale(self.vertices, factor, pivot) - center = scale(Point(self._center), factor, pivot).center + center = scale(Point(self.center), factor, pivot).center volume = self._volume * factor**self.element_rank if self._volume is not None else None face_areas = None - return Mesh(vertices, self.elements, self.element_rank, self.boundaries, center, volume, self._normals, self.face_centers, self.face_normals, face_areas, self.face_vertices, self._vertex_normals, self._vertex_connectivity, self._element_connectivity, self._max_cell_walk) + return Mesh(vertices, self.elements, self.element_rank, self.boundaries, center, volume, self.normals, self.face_centers, self.face_normals, face_areas, self.face_vertices, self._vertex_normals, self._vertex_connectivity, self._element_connectivity, self.max_cell_walk) def __getitem__(self, item): item: dict = slicing_dict(self, item) @@ -657,7 +658,7 @@ def mesh(vertices: Geometry | Tensor, assert all(p in vertices.vector.item_names for p in periodic_dims), f"Periodic boundaries must be named after axes, e.g. {vertices.vector.item_names} but got {periodic}" for base in periodic_dims: assert base+'+' in boundaries and base+'-' in boundaries, f"Missing boundaries for periodicity '{base}'. Make sure '{base}+' and '{base}-' are keys in boundaries dict, got {tuple(boundaries)}" - return Mesh(vertices, elements, element_rank, boundaries, periodic_dims, face_format, max_cell_walk) + return Mesh(vertices, elements, element_rank, boundaries, periodic_dims, face_format=face_format, max_cell_walk=max_cell_walk) def build_faces_2d(vertices: Tensor, # (vertices:i, vector) @@ -717,12 +718,11 @@ def build_faces_2d(vertices: Tensor, # (vertices:i, vector) bnd_el_coo_v_idx = coo_matrix((bnd_coo_vert+1, (bnd_coo_idx, bnd_coo_vert)), shape=(end, instance(vertices).size)) ptr = np.cumsum(np.asarray(el_coo.sum(1))) first_ptr = np.pad(ptr, (1, 0))[:-1] - last_ptr = ptr - 1 alt1 = np.arange(el_coo.data.size) % 2 alt2 = (1 - alt1) alt2[first_ptr] = alt1[first_ptr] alt3 = (1 - alt1) - alt3[last_ptr] = alt1[last_ptr] + alt3[ptr - 1] = alt1[ptr - 1] v_indices = [] for alt in [alt1, (1-alt1), alt2, alt3]: el_coo.data = alt + 1e-10 @@ -799,21 +799,21 @@ def build_mesh(bounds: Box = None, vert_pos = meshgrid(resolution + 1) / resolution * bounds.size + bounds.lower # centroids = UniformGrid(resolution, bounds).center dx = bounds.size / resolution - regular_size = min(dx, channel) + regular_size = math.min(dx, channel) vert_pos, polygons, boundaries = build_quadrilaterals(vert_pos, resolution, obstacles, bounds, regular_size * max_squish) if max_squish is not None: lin_vert_pos = pack_dims(vert_pos, spatial, instance('polygon')) corner_pos = lin_vert_pos[polygons] - min_pos = min(corner_pos, '~polygon') - max_pos = max(corner_pos, '~polygon') - cell_sizes = min(max_pos - min_pos, 'vector') + min_pos = math.min(corner_pos, '~polygon') + max_pos = math.max(corner_pos, '~polygon') + cell_sizes = math.min(max_pos - min_pos, 'vector') too_small = cell_sizes < regular_size * max_squish # --- remove too small cells --- removed = polygons[too_small] removed_centers = mean(lin_vert_pos[removed], '~polygon') kept_vert = removed[{'~polygon': 0}] vert_pos = scatter(lin_vert_pos, kept_vert, removed_centers) - vertex_map = range(non_channel(lin_vert_pos)) + vertex_map = math.range(non_channel(lin_vert_pos)) vertex_map = scatter(vertex_map, rename_dims(removed, '~polygon', instance('poly_list')), expand(kept_vert, instance(poly_list=4))) polygons = polygons[~too_small] polygons = vertex_map[polygons] @@ -825,7 +825,7 @@ def build_single_mesh(vert_pos, polygons, boundaries): polygon_list = reshaped_numpy(polygons, [..., dual]) boundaries = {b: edges.numpy('edges,~vert') for b, edges in boundaries.items()} return mesh_from_numpy(points_np, polygon_list, boundaries, cell_dim=cell_dim, face_format=face_format) - return map(build_single_mesh, vert_pos, polygons, boundaries, dims=batch) + return math.map(build_single_mesh, vert_pos, polygons, boundaries, dims=batch) def build_quadrilaterals(vert_pos, resolution: Shape, obstacles: Dict[str, Geometry], bounds: Box, min_size) -> Tuple[Tensor, Tensor, dict]: @@ -904,7 +904,7 @@ def face_curvature(mesh: Mesh): curvature_tensor = .5 / mesh.volume * (e1 * dn1 + e2 * dn2 + e3 * dn3) scalar_curvature = sum_([curvature_tensor[{'vector': d, '~vector': d}] for d in mesh.vector.item_names], '0') return curvature_tensor, scalar_curvature - # vec_curvature = max(v_normals, dual) - min(v_normals, dual) # positive / negative + # vec_curvature = math.max(v_normals, dual) - math.min(v_normals, dual) # positive / negative def save_tri_mesh(file: str, mesh: Mesh, **extra_data): diff --git a/phi/geom/_transform.py b/phi/geom/_transform.py index 675cb20dc..c70b42a52 100644 --- a/phi/geom/_transform.py +++ b/phi/geom/_transform.py @@ -1,158 +1,188 @@ -from numbers import Number -from typing import Tuple, Union, Dict, Any - -from phiml.math import spatial, channel, stack, expand, INF - -from phi import math -from phi.math import Tensor, Shape -from phiml.math.magic import slicing_dict -from . import BaseBox, Box, Cuboid -from ._geom import Geometry -from ._sphere import Sphere -from phiml.math._shape import parse_dim_order - - -class _EmbeddedGeometry(Geometry): - - def __init__(self, geometry, axes: Tuple[str]): - self.geometry = geometry - self.axes = axes # spatial axis order - - @property - def spatial_rank(self) -> int: - return len(self.axes) - - @property - def center(self) -> Tensor: - g_cen = dict(**self.geometry.bounding_half_extent().vector) - return stack({dim: g_cen.get(dim, 0) for dim in self.vector.item_names}, channel('vector')) - - @property - def shape(self) -> Shape: - return self.geometry.shape.with_dim_size('vector', self.axes) - - @property - def volume(self) -> Tensor: - raise NotImplementedError() - - def unstack(self, dimension: str) -> tuple: - raise NotImplementedError() - - def _down_project(self, location: Tensor): - item_names = list(location.shape.get_item_names('vector')) - for dim in self.axes: - if dim not in self.geometry.shape.get_item_names('vector'): - item_names.remove(dim) - projected_loc = location.vector[item_names] - return projected_loc - - def __getitem__(self, item): - item = slicing_dict(self, item) - if 'vector' in item: - axes = channel(vector=self.axes).after_gather(item).item_names[0] - if all(a in self.geometry.vector.item_names for a in axes): - return self.geometry[item] - item['vector'] = [a for a in axes if a in self.geometry.vector.item_names] - else: - axes = self.axes - projected = self.geometry[item] - if projected.spatial_rank == 0: - return Box(**{a: None for a in axes}) - assert not isinstance(projected, BaseBox), f"_EmbeddedGeometry reduced to a Box but should already have been a box. Was {self.geometry}" - if isinstance(projected, Sphere) and projected.spatial_rank: # 1D spheres are just boxes - box1d = Cuboid(projected.center, expand(projected.radius, projected.center.shape['vector'])) - emb = _EmbeddedGeometry(box1d, axes) - return Cuboid(emb.center, emb.bounding_half_extent()) - return _EmbeddedGeometry(projected, axes) - - def lies_inside(self, location: Tensor) -> Tensor: - return self.geometry.lies_inside(self._down_project(location)) - - def approximate_signed_distance(self, location: Tensor) -> Tensor: - return self.geometry.approximate_signed_distance(self._down_project(location)) - - def sample_uniform(self, *shape: math.Shape) -> Tensor: - raise NotImplementedError() - - def bounding_radius(self) -> Tensor: - raise NotImplementedError() - - def bounding_half_extent(self) -> Tensor: - g_ext = dict(**self.geometry.bounding_half_extent().vector) - return stack({dim: g_ext.get(dim, INF) for dim in self.vector.item_names}, channel('vector')) - - def shifted(self, delta: Tensor) -> 'Geometry': - raise NotImplementedError() - - def at(self, center: Tensor) -> 'Geometry': - raise NotImplementedError() - - def rotated(self, angle: Union[float, Tensor]) -> 'Geometry': - raise NotImplementedError() - - def scaled(self, factor: Union[float, Tensor]) -> 'Geometry': - raise NotImplementedError() - - def __hash__(self): - return hash(self.geometry) + hash(self.axes) - - @property - def boundary_elements(self) -> Dict[Any, Dict[str, slice]]: - return self.geometry.boundary_elements +from typing import Optional - @property - def boundary_faces(self) -> Dict[Any, Dict[str, slice]]: - return self.geometry.boundary_faces - - -def embed(geometry: Geometry, projected_dims: Union[math.Shape, str, tuple, list, None]) -> Geometry: +from phiml import math +from phiml.math import Tensor, channel, rename_dims, wrap, shape, normalize, cross_product, dual, stack, length + +from ._geom import Geometry, GeometricType + + +def scale(obj: GeometricType, scale: float | Tensor, pivot: Tensor = None, dim='vector') -> GeometricType: + """ + Scale a `Geometry` or vector `Tensor` about a pivot point. + + Args: + obj: `Geometry` to scale. + scale: Scaling factor. + pivot: Point that stays fixed under the scaling operation. Defaults to the bounding box center. + + Returns: + Rotated `Geometry` + """ + if scale is None: + return obj + if isinstance(obj, Geometry): + if pivot is None: + pivot = obj.bounding_box().center + center = pivot + scale * (obj.center - pivot) + return obj.scaled(scale).at(center) + elif isinstance(obj, Tensor): + assert 'vector' in obj.shape, f"vector must have exactly a channel dimension named 'vector'" + if pivot is None: + return obj * scale + raise NotImplementedError + raise ValueError(obj) + + +def rotate(obj: GeometricType, rot: float | Tensor | None, invert=False, pivot: Tensor | str = 'bounds') -> GeometricType: """ - Adds fake spatial dimensions to a geometry. - The geometry value will be constant along the added dimensions, as if it had infinite length in these directions. + Rotate a vector or `Geometry` about the `pivot`. + + Args: + obj: n-dimensional vector `Tensor` or `Geometry`. + rot: Euler angle(s) or rotation matrix. + `None` is interpreted as no rotation. + invert: Whether to apply the inverse rotation. + pivot: Either a point (`Tensor`) lying on the rotation axis or one of the following strings: 'bounds', 'individual'. + Vector tensors are rotated about the origin if `pivot` is not given as a `Tensor`. - Dimensions that are already present with `geometry` are ignored. + Returns: + Rotated vector as `Tensor` + """ + if rot is None: + return obj + if isinstance(obj, Geometry): + if pivot is None: + pivot = obj.bounding_box().center + center = pivot + rotate(obj.center - pivot, rot) + return obj.rotated(rot).at(center) + elif isinstance(obj, Tensor): + assert 'vector' in obj.shape, f"vector must have exactly a channel dimension named 'vector'" + matrix = rotation_matrix(rot) + if invert: + matrix = rename_dims(matrix, '~vector,vector', matrix.shape['vector'] + matrix.shape['~vector']) + assert matrix.vector.dual.size == obj.vector.size, f"Rotation matrix from {rot.shape} is {matrix.vector.dual.size}D but vector {obj.shape} is {obj.vector.size}D." + return math.dot(matrix, '~vector', obj, 'vector') + + +def rotation_matrix(x: float | math.Tensor | None, matrix_dim=channel('vector')) -> Optional[Tensor]: + """ + Create a 2D or 3D rotation matrix from the corresponding angle(s). Args: - geometry: `Geometry` - projected_dims: Additional dimensions + x: + 2D: scalar angle + 3D: Either vector pointing along the rotation axis with rotation angle as length or Euler angles. + Euler angles need to be laid out along a `angle` channel dimension with dimension names listing the spatial dimensions. + E.g. a 90° rotation about the z-axis is represented by `vec('angles', x=0, y=0, z=PI/2)`. + If a rotation matrix is passed for `angle`, it is returned without modification. + matrix_dim: Matrix dimension for 2D rotations. In 3D, the channel dimension of angle is used. Returns: - `Geometry` with spatial rank `geometry.spatial_rank + projected_dims.rank`. + Matrix containing `matrix_dim` in primal and dual form as well as all non-channel dimensions of `x`. """ - if projected_dims is None: - return geometry - axes = parse_dim_order(projected_dims) - embedded_axes = [a for a in axes if a not in geometry.shape.get_item_names('vector')] - if not embedded_axes: - return geometry[axes] - # --- add dims from geometry to axes --- - for name in reversed(geometry.shape.get_item_names('vector')): - if name not in projected_dims: - axes = (name,) + axes - if isinstance(geometry, BaseBox): - box = geometry.corner_representation() - embedded = box * Box(**{dim: None for dim in embedded_axes}) - return embedded[axes] - return _EmbeddedGeometry(geometry, axes) - - -def infinite_cylinder(center=None, radius=None, inf_dim: Union[str, Shape, tuple, list] = None, **center_) -> Geometry: + if x is None: + return None + if isinstance(x, Tensor) and '~vector' in x.shape and 'vector' in x.shape.channel and x.shape.get_size('~vector') == x.shape.get_size('vector'): + return x # already a rotation matrix + elif 'angle' in shape(x) and shape(x).get_size('angle') == 3: # 3D Euler angles + assert channel(x).rank == 1 and channel(x).size == 3, f"x for 3D rotations needs to be a 3-vector but got {x}" + s1, s2, s3 = math.sin(x).angle # x, y, z + c1, c2, c3 = math.cos(x).angle + matrix_dim = matrix_dim.with_size(shape(x).get_item_names('angle')) + return wrap([[c3 * c2, c3 * s2 * s1 - s3 * c1, c3 * s2 * c1 + s3 * s1], + [s3 * c2, s3 * s2 * s1 + c3 * c1, s3 * s2 * c1 - c3 * s1], + [-s2, c2 * s1, c2 * c1]], matrix_dim, matrix_dim.as_dual()) # Rz * Ry * Rx (1. rotate about X by first angle) + elif 'vector' in shape(x) and shape(x).get_size('vector') == 3: # 3D axis + x + angle = length(x) + s, c = math.sin(angle), math.cos(angle) + t = 1 - c + k1, k2, k3 = normalize(x, epsilon=1e-12).vector + matrix_dim = matrix_dim.with_size(shape(x).get_item_names('vector')) + return wrap([[c + k1**2 * t, k1 * k2 * t - k3 * s, k1 * k3 * t + k2 * s], + [k2 * k1 * t + k3 * s, c + k2**2 * t, k2 * k3 * t - k1 * s], + [k3 * k1 * t - k2 * s, k3 * k2 * t + k1 * s, c + k3**2 * t]], matrix_dim, matrix_dim.as_dual()) + else: # 2D rotation + sin = wrap(math.sin(x)) + cos = wrap(math.cos(x)) + return wrap([[cos, -sin], [sin, cos]], matrix_dim, matrix_dim.as_dual()) + + +def rotation_angles(rot: Tensor): """ - Creates an infinite cylinder. - This is equal to embedding an `n`-dimensional `Sphere` in `n+1` dimensions. + Compute the scalar x in 2D or the Euler angles in 3D from a given rotation matrix. + This function returns one valid solution but often, there are multiple solutions. - See Also: - `Sphere`, `embed` + Args: + rot: Rotation matrix as created by `phi.math.rotation_matrix()`. + Must have exactly one channel and one dual dimension with equally-ordered elements. + + Returns: + Scalar x in 2D, Euler angles + """ + assert channel(rot).rank == 1 and dual(rot).rank == 1, f"Rotation matrix must have one channel and one dual dimension but got {rot.shape}" + if channel(rot).size == 2: + cos = rot[{channel: 0, dual: 0}] + sin = rot[{channel: 1, dual: 0}] + return math.arctan(sin, divide_by=cos) + elif channel(rot).size == 3: + a2 = -math.arcsin(rot[{channel: 2, dual: 0}]) # ToDo handle [2, 0] == 1 (i.e. cos_theta == 0) + cos2 = math.cos(a2) + a1 = math.arctan(rot[{channel: 2, dual: 1}] / cos2, divide_by=rot[{channel: 2, dual: 2}] / cos2) + a3 = math.arctan(rot[{channel: 1, dual: 0}] / cos2, divide_by=rot[{channel: 0, dual: 0}] / cos2) + regular_sol = stack([a1, a2, a3], channel(angle=channel(rot).item_names[0])) + # --- pole case cos(theta) == 1 --- + a3_pole = 0 # unconstrained + bottom_pole = rot[{channel: 2, dual: 0}] < 0 + a2_pole = math.where(bottom_pole, 1.57079632679, -1.57079632679) + a1_pole = math.where(bottom_pole, math.arctan(rot[{channel: 0, dual: 1}], divide_by=rot[{channel: 0, dual: 2}]), math.arctan(-rot[{channel: 0, dual: 1}], divide_by=-rot[{channel: 0, dual: 2}])) + pole_sol = stack([a1_pole, a2_pole, a3_pole], channel(regular_sol)) + return math.where(abs(rot[{channel: 2, dual: 0}]) >= 1, pole_sol, regular_sol) + else: + raise ValueError(f"") + + +def rotation_matrix_from_directions(source_dir: Tensor, target_dir: Tensor, vec_dim: str = 'vector', epsilon=None) -> Tensor: + """ + Computes a rotation matrix A, such that `target_dir = A @ source_dir` + + Args: + source_dir: Two or three-dimensional vector. `Tensor` with channel dim called 'vector'. + target_dir: Two or three-dimensional vector. `Tensor` with channel dim called 'vector'. + + Returns: + Rotation matrix as `Tensor` with 'vector' dim and its dual counterpart. + """ + if source_dir.vector.size == 3: + source_dir = normalize(source_dir, vec_dim, epsilon=epsilon) + target_dir = normalize(target_dir, vec_dim, epsilon=epsilon) + axis = cross_product(source_dir, target_dir) + lim = 1-epsilon if epsilon is not None else 1 + angle = math.arccos(math.clip(source_dir.vector @ target_dir.vector, -lim, lim)) + return rotation_matrix_from_axis_and_angle(axis, angle, is_axis_normalized=False, epsilon=epsilon) + raise NotImplementedError + + +def rotation_matrix_from_axis_and_angle(axis: Tensor, angle: float | Tensor, vec_dim='vector', is_axis_normalized=False, epsilon=1e-5) -> Tensor: + """ + Computes a rotation matrix that rotates by `angle` around `axis`. Args: - center: Center coordinates without `inf_dim`. Alternatively use keyword arguments. - radius: Cylinder radius. - inf_dim: Dimension along which the cylinder is infinite. - Use `Geometry.rotated()` if the direction does not align with an axis. - **center_: Alternatively specify center coordinates without `inf_dim` as keyword arguments. + axis: 3D vector. `Tensor` with channel dim called 'vector'. + angle: Rotation angle. + is_axis_normalized: Whether `axis` has length 1. + epsilon: Minimum axis length. For shorter axes, the unit matrix is returned. Returns: - `Geometry` + Rotation matrix as `Tensor` with 'vector' dim and its dual counterpart. """ - sphere = Sphere(center, radius, **center_) - return embed(sphere, inf_dim) + if axis.vector.size == 3: # Rodrigues' rotation formula + axis = normalize(axis, vec_dim, epsilon=epsilon, allow_zero=False) if not is_axis_normalized else axis + kx, ky, kz = axis.vector + s = math.sin(angle) + c = 1 - math.cos(angle) + return wrap([ + (1 - c*(ky*ky+kz*kz), -kz*s + c*(kx*ky), ky*s + c*(kx*kz)), + ( kz*s + c*(kx*ky), 1 - c*(kx*kx+kz*kz), -kx*s + c*(ky * kz)), + ( -ky*s + c*(kx*kz), kx*s + c*(ky * kz), 1 - c*(kx*kx+ky*ky)), + ], axis.shape['vector'], axis.shape['vector'].as_dual()) + raise NotImplementedError \ No newline at end of file diff --git a/phi/vis/_matplotlib/_matplotlib_plots.py b/phi/vis/_matplotlib/_matplotlib_plots.py index 9bd5a80b4..4be31fc28 100644 --- a/phi/vis/_matplotlib/_matplotlib_plots.py +++ b/phi/vis/_matplotlib/_matplotlib_plots.py @@ -16,10 +16,10 @@ from phi import math from phi.field import StaggeredGrid, Field, CenteredGrid -from phi.geom import Sphere, BaseBox, Point, Box, Mesh, Graph, SDFGrid, SDF, UniformGrid +from phi.geom import Sphere, BaseBox, Point, Box, Mesh, Graph, SDFGrid, SDF, UniformGrid, rotate from phi.geom._heightmap import Heightmap from phi.geom._geom_ops import GeometryStack -from phi.geom._transform import _EmbeddedGeometry +from phi.geom._embed import _EmbeddedGeometry from phi.math import Tensor, channel, spatial, instance, non_channel, Shape, reshaped_numpy, shape from phi.vis._vis_base import display_name, PlottingLibrary, Recipe, index_label, only_stored_elements, to_field from phiml.math import wrap @@ -662,7 +662,7 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten lower_y = y - h2 else: angles = reshaped_numpy(math.rotation_angles(data.geometry.rotation_matrix), [data.shape.non_channel]) - lower_x, lower_y = reshaped_numpy(data.geometry.center - math.rotate_vector(data.geometry.half_size, data.geometry.rotation_matrix), ['vector', data.shape.non_channel]) + lower_x, lower_y = reshaped_numpy(data.geometry.center - rotate(data.geometry.half_size, data.geometry.rotation_matrix), ['vector', data.shape.non_channel]) shapes = [plt.Rectangle((lxi, lyi), w2i * 2, h2i * 2, angle=ang*180/np.pi, linewidth=1, edgecolor='white', alpha=a, facecolor=ci) for lxi, lyi, w2i, h2i, ang, ci, a in zip(lower_x, lower_y, w2, h2, angles, mpl_colors, alphas)] axis.add_collection(matplotlib.collections.PatchCollection(shapes, match_original=True)) elif isinstance(data.geometry, Mesh):