Skip to content

Commit

Permalink
Polars version update, typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Sep 21, 2024
1 parent 2ba42bc commit 994910b
Show file tree
Hide file tree
Showing 16 changed files with 427 additions and 92 deletions.
4 changes: 3 additions & 1 deletion atomlib/atomcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
import copy
import typing as t

from typing_extensions import Self
import numpy
from numpy.typing import NDArray, ArrayLike
import polars
import polars.dataframe.group_by

from typing_extensions import ParamSpec, Concatenate
from .bbox import BBox3D
from .types import VecLike, to_vec3, ParamSpec, Concatenate, Self
from .types import VecLike, to_vec3
from .transform import LinearTransform3D, AffineTransform3D, Transform3D, IntoTransform3D
from .cell import CoordinateFrame, HasCell, Cell
from .atoms import HasAtoms, Atoms, IntoAtoms, AtomSelection, AtomValues
Expand Down
62 changes: 32 additions & 30 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
from io import StringIO
import typing as t

from typing_extensions import Self, ParamSpec, Concatenate, TypeAlias
import numpy
from numpy.typing import ArrayLike, NDArray
import polars
import polars.dataframe.group_by
import polars.datatypes
import polars.interchange.dataframe
import polars.testing
import polars.type_aliases
import polars._typing
from polars.schema import Schema

from .types import to_vec3, VecLike, ParamSpec, Concatenate, TypeAlias, Self
from .types import to_vec3, VecLike
from .bbox import BBox3D
from .elem import get_elem, get_sym, get_mass
from .transform import Transform3D, IntoTransform3D, AffineTransform3D
Expand Down Expand Up @@ -59,6 +61,7 @@ def _is_abstract(cls: t.Type) -> bool:
return bool(getattr(cls, "__abstractmethods__", False))


"""
def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
from polars.datatypes import dtype_to_ctype
if dtype == polars.Boolean:
Expand All @@ -67,14 +70,15 @@ def _polars_to_numpy_dtype(dtype: t.Type[polars.DataType]) -> numpy.dtype:
return numpy.dtype(dtype_to_ctype(dtype))
except NotImplementedError:
return numpy.dtype(object)
"""


def _get_symbol_mapping(df: t.Union[polars.DataFrame, HasAtoms], mapping: t.Mapping[str, t.Any], ty: t.Type[polars.DataType]) -> polars.Expr:
syms = df['symbol'].unique()
if (missing := set(syms) - set(mapping.keys())):
raise ValueError(f"Could not remap symbols {', '.join(map(repr, missing))}")

return polars.col('symbol').replace(mapping, default=None, return_dtype=ty)
return polars.col('symbol').replace_strict(mapping, default=None, return_dtype=ty)


def _values_to_expr(df: t.Union[polars.DataFrame, HasAtoms], values: AtomValues, ty: t.Type[polars.DataType]) -> polars.Expr:
Expand All @@ -97,12 +101,8 @@ def _values_to_numpy(df: t.Union[polars.DataFrame, HasAtoms], values: AtomValues
# syms = df.select(polars.col('symbol').filter(values.is_null())).unique().to_series().to_list()
# raise ValueError(f"Could not remap symbols {', '.join(map(repr, syms))}")
if isinstance(values, polars.Series):
if ty == polars.Boolean:
# force conversion to numpy (unpacked) bool
return values.cast(polars.UInt8).to_numpy().astype(numpy.bool_)
return numpy.broadcast_to(values.cast(ty).to_numpy(), len(df))

return numpy.broadcast_to(values.to_numpy(), len(df))
values = values.cast(ty)
return numpy.broadcast_to(values, len(df))


def _selection_to_expr(df: t.Union[polars.DataFrame, HasAtoms], selection: t.Optional[AtomSelection] = None) -> polars.Expr:
Expand All @@ -126,7 +126,7 @@ def _select_schema(df: t.Union[polars.DataFrame, HasAtoms], schema: SchemaDict)
polars.col(col).cast(ty, strict=True)
for (col, ty) in schema.items()
])
except (polars.ComputeError, polars.ColumnNotFoundError):
except (polars.exceptions.ComputeError, polars.exceptions.ColumnNotFoundError):
raise TypeError(f"Failed to cast '{df.__class__.__name__}' with schema '{df.schema}' to schema '{schema}'.")


Expand All @@ -137,9 +137,10 @@ def _with_columns_stacked(df: polars.DataFrame, cols: t.Sequence[str], out_col:
i = df.get_column_index(cols[0])
dtype = df[cols[0]].dtype

arr = numpy.array(tuple(df[c].to_numpy() for c in cols)).T
# https://github.com/pola-rs/polars/issues/18369
arr = [] if len(df) == 0 else numpy.array(tuple(df[c].to_numpy() for c in cols)).T

return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, arr.shape[-1])))
return df.drop(cols).insert_column(i, polars.Series(out_col, arr, polars.Array(dtype, len(cols))))


HasAtomsT = t.TypeVar('HasAtomsT', bound='HasAtoms')
Expand Down Expand Up @@ -249,8 +250,8 @@ def dtypes(self) -> t.List[polars.DataType]:
...

@property
@_fwd_frame(lambda df: df.schema)
def schema(self) -> SchemaDict:
@_fwd_frame(lambda df: df.schema) # type: ignore
def schema(self) -> Schema:
"""
Return the schema of `self`.
Expand Down Expand Up @@ -287,7 +288,7 @@ def with_columns(self,
def insert_column(self, index: int, column: polars.Series) -> polars.DataFrame:
return self._get_frame().insert_column(index, column)

@_fwd_frame(polars.DataFrame.get_column)
@_fwd_frame(lambda df, name: df.get_column(name))
def get_column(self, name: str) -> polars.Series:
"""
Get the specified column from `self`, raising [`polars.exceptions.ColumnNotFoundError`][polars.exceptions.ColumnNotFoundError] if it's not present.
Expand Down Expand Up @@ -327,9 +328,9 @@ def clone(self) -> polars.DataFrame:
"""Return a copy of `self`."""
return self._get_frame().clone()

def drop(self, *columns: t.Union[str, t.Iterable[str]]) -> polars.DataFrame:
def drop(self, *columns: t.Union[str, t.Iterable[str]], strict: bool = True) -> polars.DataFrame:
"""Return `self` with the specified columns removed."""
return self._get_frame().drop(*columns)
return self._get_frame().drop(*columns, strict=strict)

# row-wise operations

Expand All @@ -340,11 +341,10 @@ def filter(
) -> Self:
"""Filter `self`, removing rows which evaluate to `False`."""
# TODO clean up
preds_not_none: t.Tuple[t.Union[IntoExprColumn, t.Iterable[IntoExprColumn], bool, t.List[bool], numpy.ndarray], ...]
preds_not_none = tuple(filter(lambda p: p is not None, predicates)) # type: ignore
preds_not_none = tuple(filter(lambda p: p is not None, predicates))
if not len(preds_not_none) and not len(constraints):
return self
return self.with_atoms(Atoms(self._get_frame().filter(*preds_not_none, **constraints), _unchecked=True))
return self.with_atoms(Atoms(self._get_frame().filter(*preds_not_none, **constraints), _unchecked=True)) # type: ignore

@_fwd_frame_map
def sort(
Expand Down Expand Up @@ -497,7 +497,7 @@ def select_props(
A [`HasAtoms`][atomlib.atoms.HasAtoms] filtered to contain the
specified properties (as well as required columns).
"""
props = self._get_frame().lazy().select(*exprs, **named_exprs).drop(_REQUIRED_COLUMNS).collect(_eager=True)
props = self._get_frame().lazy().select(*exprs, **named_exprs).drop(_REQUIRED_COLUMNS, strict=False).collect(_eager=True)
return self.with_atoms(
Atoms(self._get_frame().select(_REQUIRED_COLUMNS).hstack(props), _unchecked=False)
)
Expand All @@ -524,7 +524,7 @@ def try_get_column(self, name: str) -> t.Optional[polars.Series]:
"""Try to get a column from `self`, returning `None` if it doesn't exist."""
try:
return self.get_column(name)
except polars.ColumnNotFoundError:
except polars.exceptions.ColumnNotFoundError:
return None

def assert_equal(self, other: t.Any):
Expand Down Expand Up @@ -554,7 +554,7 @@ def __radd__(self, other: IntoAtoms) -> HasAtoms:
def __getitem__(self, column: str) -> polars.Series:
try:
return self.get_column(column)
except polars.ColumnNotFoundError:
except polars.exceptions.ColumnNotFoundError:
if column in ('x', 'y', 'z'):
return self.select(_coord_expr(column)).to_series()
raise
Expand Down Expand Up @@ -937,7 +937,8 @@ def with_coords(self, pts: ArrayLike, selection: t.Optional[AtomSelection] = Non
new_pts[selection] = pts
pts = new_pts

pts = numpy.broadcast_to(pts, (len(self), 3))
# https://github.com/pola-rs/polars/issues/18369
pts = numpy.broadcast_to(pts, (len(self), 3)) if len(self) else []
return self.with_columns(polars.Series('coords', pts, polars.Array(polars.Float64, 3)))

def with_velocity(self, pts: t.Optional[ArrayLike] = None,
Expand All @@ -962,7 +963,8 @@ def with_velocity(self, pts: t.Optional[ArrayLike] = None,
assert pts.shape[-1] == 3
all_pts[selection] = pts

all_pts = numpy.broadcast_to(all_pts, (len(self), 3))
# https://github.com/pola-rs/polars/issues/18369
all_pts = numpy.broadcast_to(all_pts, (len(self), 3)) if len(self) else []
return self.with_columns(polars.Series('velocity', all_pts, polars.Array(polars.Float64, 3)))


Expand Down Expand Up @@ -1087,11 +1089,11 @@ def _repr_pretty_(self, p, cycle: bool) -> None:


SchemaDict: TypeAlias = OrderedDict[str, polars.DataType]
IntoExprColumn: TypeAlias = polars.type_aliases.IntoExprColumn
IntoExpr: TypeAlias = polars.type_aliases.IntoExpr
UniqueKeepStrategy: TypeAlias = polars.type_aliases.UniqueKeepStrategy
FillNullStrategy: TypeAlias = polars.type_aliases.FillNullStrategy
RollingInterpolationMethod: TypeAlias = polars.type_aliases.RollingInterpolationMethod
IntoExprColumn: TypeAlias = polars._typing.IntoExprColumn
IntoExpr: TypeAlias = polars._typing.IntoExpr
UniqueKeepStrategy: TypeAlias = polars._typing.UniqueKeepStrategy
FillNullStrategy: TypeAlias = polars._typing.FillNullStrategy
RollingInterpolationMethod: TypeAlias = polars._typing.RollingInterpolationMethod
ConcatMethod: TypeAlias = t.Literal['horizontal', 'vertical', 'diagonal', 'inner', 'align']

IntoAtoms = t.Union[t.Dict[str, t.Sequence[t.Any]], t.Sequence[t.Any], numpy.ndarray, polars.DataFrame, 'Atoms']
Expand Down
2 changes: 1 addition & 1 deletion atomlib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import typing as t
import logging

from typing_extensions import ParamSpec, Concatenate
import click

from . import CoordinateFrame, HasAtoms, Atoms, AtomCell, AtomSelection
from . import io
from .types import ParamSpec, Concatenate
from .transform import LinearTransform3D, AffineTransform3D


Expand Down
2 changes: 1 addition & 1 deletion atomlib/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def ellip_pi(n: NDArray[numpy.float64], m: NDArray[numpy.float64]) -> NDArray[nu
[wolfram_ellip_pi]: https://mathworld.wolfram.com/EllipticIntegraloftheThirdKind.html
"""
from scipy.special import elliprf, elliprj
from scipy.special import elliprf, elliprj # type: ignore

y = 1 - m
assert numpy.all(y > 0)
Expand Down
8 changes: 4 additions & 4 deletions atomlib/elem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy

try:
from polars.exceptions import PanicException
from polars.exceptions import PanicException
except ImportError:
from polars.exceptions import PolarsPanicError as PanicException
from polars.exceptions import PolarsPanicError as PanicException # type: ignore

from .types import ElemLike

Expand Down Expand Up @@ -83,7 +83,7 @@ def get_elem(sym: t.Union[int, str, polars.Series]):

if isinstance(sym, polars.Series):
elem = sym.str.extract(_SYM_RE, 0).str.to_lowercase() \
.replace(ELEMENTS, return_dtype=polars.UInt8, default=255) \
.replace_strict(ELEMENTS, default=255, return_dtype=polars.UInt8) \
.alias('elem')

if (invalid := sym.filter(sym.is_not_null() & (elem > 118)).to_list()):
Expand Down Expand Up @@ -127,7 +127,7 @@ def get_sym(elem: t.Union[int, polars.Series]):
try:
return elem.map_elements(_get_sym, return_dtype=polars.Utf8, skip_nulls=True) \
.alias('symbol')
except PolarsPanicError:
except PanicException:
# attempt to recreate the error in Python
_ = [_get_sym(t.cast(int, e)) for e in elem.to_list() if e is not None]
raise
Expand Down
2 changes: 1 addition & 1 deletion atomlib/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def read_cif(f: t.Union[FileOrPath, CIF, CIFDataBlock], block: t.Union[int, str,
raise ValueError("No data present in CIF file.")
if block is None:
if len(cif) > 1:
logging.warn("Multiple blocks present in CIF file. Defaulting to reading first block.")
logging.warning("Multiple blocks present in CIF file. Defaulting to reading first block.")
cif = cif.data_blocks[0]
else:
cif = cif.get_block(block)
Expand Down
2 changes: 1 addition & 1 deletion atomlib/io/lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_atoms(self, type_map: t.Optional[t.Dict[int, t.Union[str, int]]] = None)
def _apply_type_labels(df: polars.DataFrame, section_name: str, labels: t.Optional[polars.DataFrame] = None) -> polars.DataFrame:
if labels is not None:
#df = df.with_columns(polars.col('type').replace(d, default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
df = df.with_columns(polars.col('type').replace(labels['symbol'], labels['type'], default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
df = df.with_columns(polars.col('type').replace_strict(labels['symbol'], labels['type'], default=polars.col('type').cast(polars.Int32, strict=False), return_dtype=polars.Int32))
if df['type'].is_null().any():
raise ValueError(f"While parsing section {section_name}: Unknown atom label or invalid atom type")
try:
Expand Down
7 changes: 4 additions & 3 deletions atomlib/io/mslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from warnings import warn
import typing as t

from typing_extensions import TypeAlias
from importlib_resources import files
import numpy
from numpy.typing import ArrayLike
Expand All @@ -25,9 +26,9 @@
from ..transform import AffineTransform3D, LinearTransform3D


ElementTree = et._ElementTree
Element = et._Element
MSliceFile = t.Union[ElementTree, FileOrPath]
ElementTree: TypeAlias = et._ElementTree
Element: TypeAlias = et._Element
MSliceFile: TypeAlias = t.Union[ElementTree, FileOrPath]


DEFAULT_TEMPLATE_PATH = files('atomlib.data') / 'template.mslice'
Expand Down
6 changes: 3 additions & 3 deletions atomlib/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing as t

import polars
from polars.type_aliases import SchemaDict, PolarsDataType
from polars._typing import SchemaDict, PolarsDataType


class LineBuffer:
Expand Down Expand Up @@ -144,15 +144,15 @@ def _parse_rows_whitespace_separated(
assert inner_ty is not None

elem_base_name = _ARRAY_ELEM_NAMES.get(col, col)
suffixes = ('x', 'y', 'z') if ty.width == 3 else range(ty.width)
suffixes = ('x', 'y', 'z') if ty.size == 3 else range(ty.size)
elem_cols = [f"{elem_base_name}_{s}".lstrip('_') for s in suffixes]

expanded_schema.update({elem_col: inner_ty for elem_col in elem_cols})

exprs.append(polars.concat_list(
polars.col('s').struct.field(elem_col)
for elem_col in elem_cols
).list.to_array(ty.width).alias(col))
).list.to_array(ty.size).alias(col))

regex = "".join((
"^",
Expand Down
4 changes: 2 additions & 2 deletions atomlib/io/xsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __post_init__(self):
raise ValueError("Error: No coordinates are specified (atoms, primitive, or conventional).")

if self.prim_coords is not None and self.conv_coords is not None:
logging.warn("Warning: Both 'primcoord' and 'convcoord' are specified. 'convcoord' will be ignored.")
logging.warning("Warning: Both 'primcoord' and 'convcoord' are specified. 'convcoord' will be ignored.")
elif self.conv_coords is not None and self.conventional_cell is None:
raise ValueError("If 'convcoord' is specified, 'convvec' must be specified as well.")

Expand Down Expand Up @@ -196,7 +196,7 @@ def parse_atoms(self, expected_length: t.Optional[int] = None) -> polars.DataFra

if expected_length is not None:
if not expected_length == len(zs):
logging.warn(f"Warning: List length {len(zs)} doesn't match declared length {expected_length}")
logging.warning(f"Warning: List length {len(zs)} doesn't match declared length {expected_length}")
elif len(zs) == 0:
raise ValueError(f"Expected atom list after keyword 'ATOMS'. Got '{line or 'EOF'}' instead.")

Expand Down
13 changes: 8 additions & 5 deletions atomlib/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
import typing as t

from typing_extensions import TypeAlias
import numpy
from numpy.typing import ArrayLike, NDArray

Expand All @@ -16,7 +17,7 @@
NumT = t.TypeVar('NumT', bound=t.Union[float, int])

Affine3DSelf = t.TypeVar('Affine3DSelf', bound='AffineTransform3D')
IntoTransform3D = t.Union['Transform3D', t.Callable[[NDArray[numpy.floating]], numpy.ndarray], numpy.ndarray]
IntoTransform3D: TypeAlias = t.Union['Transform3D', t.Callable[[NDArray[numpy.floating]], numpy.ndarray], numpy.ndarray]
"""
Type which is coercible into a [`Transform3D`][atomlib.transform.Transform3D].
Expand Down Expand Up @@ -249,7 +250,7 @@ def scale(self, x: t.Union[Num, VecLike] = 1., y: Num = 1., z: Num = 1., *,
Can be called as a classmethod or instance method.
"""
return self.compose(LinearTransform3D.scale(x, y, z, all=all))
return self.compose(LinearTransform3D.scale(x, y, z, all=all)) # type: ignore

@opt_classmethod
def rotate(self, v: VecLike, theta: Num) -> AffineTransform3D:
Expand Down Expand Up @@ -291,7 +292,7 @@ def mirror(self, a: t.Union[Num, VecLike],
Can be called as a classmethod or instance method.
"""
return self.compose(LinearTransform3D.mirror(a, b, c))
return self.compose(LinearTransform3D.mirror(a, b, c)) # type: ignore

@opt_classmethod
def strain(self, strain: float, v: VecLike = (0, 0, 1), poisson: float = 0.) -> AffineTransform3D:
Expand Down Expand Up @@ -607,11 +608,13 @@ def align(self, v1: VecLike, horz: t.Optional[VecLike] = None) -> LinearTransfor
return self.align_to(v1, [0., 0., 1.], horz, [1., 0., 0.])

@t.overload
def align_to(self, v1: VecLike, v2: VecLike, p1: t.Literal[None] = None, p2: t.Literal[None] = None) -> LinearTransform3D:
@classmethod
def align_to(cls, v1: VecLike, v2: VecLike, p1: t.Literal[None] = None, p2: t.Literal[None] = None) -> LinearTransform3D:
...

@t.overload
def align_to(self, v1: VecLike, v2: VecLike, p1: VecLike, p2: VecLike) -> LinearTransform3D:
@classmethod
def align_to(cls, v1: VecLike, v2: VecLike, p1: VecLike, p2: VecLike) -> LinearTransform3D:
...

@opt_classmethod
Expand Down
Loading

0 comments on commit 994910b

Please sign in to comment.