Skip to content

Commit

Permalink
Typing & CI fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Oct 11, 2024
1 parent e6ce3c1 commit fc023f2
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 61 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: false
matrix:
version: ["3.9", "3.10", "3.11"]
version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest]
arch: [x64]
experimental: [false]
Expand Down Expand Up @@ -47,9 +47,9 @@ jobs:
- name: Install python
uses: actions/setup-python@v4
with:
python-version: 3.11
python-version: 3.12
cache: 'pip'
cache-dependency-path: setup.cfg
cache-dependency-path: pyproject.toml
- name: Install dependencies
id: deps
run: |
Expand Down
12 changes: 6 additions & 6 deletions atomlib/atomcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@

def _fwd_atoms_get(f: t.Callable[P, T]) -> t.Callable[P, T]:
"""Forward getter method on `HasAtomCell` to method on `HasAtoms`"""
def inner(self, *args, frame: t.Optional[CoordinateFrame] = None, **kwargs):
def inner(self, *args, frame: t.Optional[CoordinateFrame] = None, **kwargs): # type: ignore
return getattr(self.get_atoms(frame), f.__name__)(*args, **kwargs)

return t.cast(t.Callable[P, T], inner)


def _fwd_atoms_transform(f: t.Callable[P, T]) -> t.Callable[P, T]:
"""Forward transformation method on `HasAtomCell` to method on `HasAtoms`"""
def inner(self, *args, frame: t.Optional[CoordinateFrame] = None, **kwargs):
def inner(self, *args, frame: t.Optional[CoordinateFrame] = None, **kwargs): # type: ignore
return self.with_atoms(self._transform_atoms_in_frame(frame, lambda atoms: getattr(atoms, f.__name__)(*args, **kwargs)))

return t.cast(t.Callable[P, T], inner)
Expand Down Expand Up @@ -186,7 +186,7 @@ def repeat(self, n: t.Union[int, VecLike]) -> Self:
"""Tile the cell"""
ns = numpy.broadcast_to(n, 3)
if not numpy.issubdtype(ns.dtype, numpy.integer):
raise ValueError(f"repeat() argument must be an integer or integer array.")
raise ValueError("repeat() argument must be an integer or integer array.")

cells = numpy.stack(numpy.meshgrid(*map(numpy.arange, ns))) \
.reshape(3, -1).T.astype(float)
Expand Down Expand Up @@ -660,10 +660,10 @@ def _combine_metadata(cls: t.Type[AtomCellT], *atoms: HasAtoms, n: t.Optional[in
else:
atom_cells = [a for a in atoms if isinstance(a, AtomCell)]
if len(atom_cells) == 0:
raise TypeError(f"No AtomCells to combine")
raise TypeError("No AtomCells to combine")
rep = atom_cells[0]
if not all(a.cell == rep.cell for a in atom_cells[1:]):
raise TypeError(f"Can't combine AtomCells with different cells")
raise TypeError("Can't combine AtomCells with different cells")

return cls(Atoms.empty(), frame=rep.frame, cell=rep.cell)

Expand Down Expand Up @@ -740,7 +740,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.atoms!r}, cell={self.cell!r}, frame={self.frame})"

def _repr_pretty_(self, p, cycle: bool) -> None:
def _repr_pretty_(self, p: t.Any, cycle: bool) -> None:
p.text(f'{self.__class__.__name__}(...)') if cycle else p.text(str(self))


Expand Down
8 changes: 4 additions & 4 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@ def __repr__(self) -> str:
buf.write("])\n")
return buf.getvalue()

def _repr_pretty_(self, p, cycle: bool) -> None:
def _repr_pretty_(self, p: t.Any, cycle: bool) -> None:
p.text('Atoms(...)') if cycle else p.text(str(self))


Expand All @@ -1093,18 +1093,18 @@ def _repr_pretty_(self, p, cycle: bool) -> None:
RollingInterpolationMethod: TypeAlias = polars._typing.RollingInterpolationMethod
ConcatMethod: TypeAlias = t.Literal['horizontal', 'vertical', 'diagonal', 'inner', 'align']

IntoAtoms = t.Union[t.Dict[str, t.Sequence[t.Any]], t.Sequence[t.Any], numpy.ndarray, polars.DataFrame, 'Atoms']
IntoAtoms: TypeAlias = t.Union[t.Dict[str, t.Sequence[t.Any]], t.Sequence[t.Any], numpy.ndarray, polars.DataFrame, 'Atoms']
"""
A type convertible into an [`Atoms`][atomlib.atoms.Atoms].
"""

AtomSelection = t.Union[IntoExprColumn, NDArray[numpy.bool_], ArrayLike, t.Mapping[str, t.Any]]
AtomSelection: TypeAlias = t.Union[IntoExprColumn, NDArray[numpy.bool_], ArrayLike, t.Mapping[str, t.Any]]
"""
Polars expression selecting a subset of atoms.
Can be used with many [`Atoms`][atomlib.atoms.Atoms] methods.
"""

AtomValues = t.Union[IntoExprColumn, NDArray[numpy.generic], ArrayLike, t.Mapping[str, t.Any]]
AtomValues: TypeAlias = t.Union[IntoExprColumn, NDArray[numpy.generic], ArrayLike, t.Mapping[str, t.Any]]
"""
Array, value, or polars expression mapping atom symbols to values.
Can be used with `with_*` methods on Atoms
Expand Down
12 changes: 7 additions & 5 deletions atomlib/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import abc
import typing as t

from typing_extensions import TypeAlias
import numpy
from numpy.typing import NDArray

Expand All @@ -23,7 +24,7 @@
from .bbox import BBox3D


CoordinateFrame = t.Literal[
CoordinateFrame: TypeAlias = t.Literal[
'cell', 'cell_frac', 'cell_box',
'ortho', 'ortho_frac', 'ortho_box',
'linear', 'local', 'global',
Expand Down Expand Up @@ -248,7 +249,7 @@ def repeat(self: HasCellT, n: t.Union[int, VecLike]) -> HasCellT:
"""Tile the cell by `n` in each dimension."""
ns = numpy.broadcast_to(n, 3)
if not numpy.issubdtype(ns.dtype, numpy.integer):
raise ValueError(f"repeat() argument must be an integer or integer array.")
raise ValueError("repeat() argument must be an integer or integer array.")
return self.with_cell(Cell(
affine=self.affine,
ortho=self.ortho,
Expand Down Expand Up @@ -330,7 +331,7 @@ def change_transform(self, transform: Transform3D,
return coord_change @ transform @ coord_change.inverse()

def assert_equal(self, other: t.Any):
assert isinstance(other, HasCell) and type(self) == type(other)
assert isinstance(other, HasCell) and type(self) is type(other)
numpy.testing.assert_array_almost_equal(self.affine.inner, other.affine.inner, 6)
numpy.testing.assert_array_almost_equal(self.ortho.inner, other.ortho.inner, 6)
numpy.testing.assert_array_almost_equal(self.cell_size, other.cell_size, 6)
Expand Down Expand Up @@ -410,7 +411,8 @@ def from_ortho(ortho: AffineTransform3D, n_cells: t.Optional[VecLike] = None, pb
# flip QR decomposition so R has positive diagonals
signs = numpy.sign(numpy.diagonal(r))
# multiply flips to columns of Q, rows of R
q = q * signs; r = r * signs[:, None]
q = q * signs
r = r * signs[:, None]
#numpy.testing.assert_allclose(q @ r, lin.inner)
if numpy.linalg.det(q) < 0:
warn("Crystal is left-handed. This is currently unsupported, and may cause errors.")
Expand Down Expand Up @@ -441,7 +443,7 @@ def __repr__(self) -> str:
f"cell_angle={self.cell_angle}, n_cells={self.n_cells}, pbc={self.pbc})"
)

def _repr_pretty_(self, p, cycle: bool) -> None:
def _repr_pretty_(self, p: t.Any, cycle: bool) -> None:
p.text(f"{self.__class__.__name__}(...)") if cycle else p.text(str(self))


Expand Down
47 changes: 23 additions & 24 deletions atomlib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import logging

from typing_extensions import ParamSpec, Concatenate
import numpy
import click

from . import CoordinateFrame, HasAtoms, Atoms, AtomCell, AtomSelection
from . import CoordinateFrame, HasAtoms, HasAtomCell, AtomSelection
from . import io
from .transform import LinearTransform3D, AffineTransform3D
from .mixins import AtomsIOMixin, AtomCellIOMixin


frame_type = click.Choice(('global', 'local', 'frac'), case_sensitive=False)
Expand Down Expand Up @@ -171,10 +173,10 @@ def input_cfg(file: t.Optional[Path] = None):
yield io.read_cfg(file or sys.stdin)


@cli.command('in_cfg')
@cli.command('in_mslice')
@click.argument('file', type=file_type, required=False)
@lazy_append
def input_cfg(file: t.Optional[Path] = None):
def input_mslice(file: t.Optional[Path] = None):
"""Input a pyMultislicer mslice file. If `file` is not specified, use stdin."""
yield io.read_mslice(file or sys.stdin)

Expand All @@ -200,7 +202,7 @@ def loop(state: State, n: int) -> t.Iterable[State]:
@lazy_map
def show(state: State,
zone: t.Optional[t.Tuple[float, float, float]] = None,
plane: t.Optional[t.Tuple[float, float, float]] = None) -> State:
plane: t.Optional[t.Tuple[float, float, float]] = None) -> t.Iterable[State]:
"""Show the current structure. Doesn't affect the stream of structures."""
from matplotlib import pyplot
from .visualize import show_atoms_mpl_3d
Expand All @@ -215,7 +217,7 @@ def show(state: State,
@lazy_map
def show_2d(state: State,
zone: t.Optional[t.Tuple[float, float, float]] = None,
plane: t.Optional[t.Tuple[float, float, float]] = None) -> State:
plane: t.Optional[t.Tuple[float, float, float]] = None) -> t.Iterable[State]:
"""Show the current structure. Doesn't affect the stream of structures."""
from matplotlib import pyplot
from .visualize import show_atoms_mpl_2d
Expand All @@ -224,24 +226,24 @@ def show_2d(state: State,
yield state


@cli.command('union')
@cli.command('concat')
@lazy
def union(states: t.Iterable[State]) -> t.Iterable[State]:
"""Combine structures. Symmetry is discarded, but """
def concat(states: t.Iterable[State]) -> t.Iterable[State]:
"""Combine structures. Symmetry is discarded"""
last_index = None
collect: t.List[HasAtoms] = []
state = None
for state in states:
if last_index is None:
last_index = state.indices[-1]
elif last_index != state.indices[-1]:
state.structure = HasAtoms.union(collect)
state.structure = HasAtoms.concat(collect)
state.indices.pop()
yield state
last_index = state.indices[-1]
collect.append(state.structure)
if state is not None:
state.structure = HasAtoms.union(collect)
state.structure = HasAtoms.concat(collect)
state.indices.pop()
yield state

Expand All @@ -265,7 +267,12 @@ def crop(state: State,
If none are specified, refers to the global (cartesian) coordinate system.
Currently, does not update the structure box (but probably should)
"""
state.structure = state.structure.crop(x_min, x_max, y_min, y_max, z_min, z_max, frame=frame)
state.structure = t.cast(HasAtomCell, state.structure).crop(
x_min or -numpy.inf, x_max or numpy.inf,
y_min or -numpy.inf, y_max or numpy.inf,
z_min or -numpy.inf, z_max or numpy.inf,
frame=frame
)
yield state


Expand All @@ -285,7 +292,7 @@ def rotate(state: State,
transform = LinearTransform3D().rotate_euler(x, y, z)
else:
transform = LinearTransform3D().rotate([x, y, z], theta)
state.structure = state.structure.transform_atoms(transform, frame=frame)
state.structure = t.cast(HasAtomCell, state.structure).transform_atoms(transform, frame=frame)
yield state


Expand Down Expand Up @@ -315,29 +322,21 @@ def print_(state: State):
@lazy_map
def out(state: State, file: Path):
path = state.deduplicated_output_path(file)
state.structure.write(path)
t.cast(AtomsIOMixin, state.structure).write(path)
yield state


@cli.command('out_mslice')
@click.argument('file', type=out_file_type, required=False)
@lazy_map
def out_mslice(state: State, file: t.Optional[Path] = None):
state.structure.write_mslice(sys.stdout if file is None or file == '-' else file)
t.cast(AtomCellIOMixin, state.structure).write_mslice(sys.stdout if file is None or str(file) == '-' else file)
yield state


@cli.command('out_xyz')
@click.argument('file', type=out_file_type, required=False)
@click.option('-f', '--frame', type=frame_type, default='global',
help="Frame of reference to output coordinates in.")
@click.option('--ext/--no-ext', default=True, help="Write extended format")
@click.option('-c', '--comment', type=str)
@lazy_map
def out_xyz(state: State, file: t.Optional[Path] = None, frame: CoordinateFrame = 'global', ext: bool = True, comment: t.Optional[str] = None):
if file is None or file == '-':
state.structure.write_xyz(sys.stdout, frame=frame, ext=ext, comment=comment)
else:
with open(state.deduplicated_output_path(file), 'w') as f:
state.structure.write_xyz(f, frame=frame, ext=ext, comment=comment)
def out_xyz(state: State, file: t.Optional[Path] = None):
t.cast(AtomsIOMixin, state.structure).write_xyz(sys.stdout if file is None or str(file) == '-' else file)
yield state
4 changes: 2 additions & 2 deletions atomlib/io/mslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def default_template() -> ElementTree:
return deepcopy(DEFAULT_TEMPLATE)


def convert_xml_value(val, ty):
def convert_xml_value(val: str, ty: str):
"""Convert an XML value `val` to a Python type determined by the XML type name `ty`."""
if ty == 'string':
ty = 'str'
elif ty == 'int16' or ty == 'int32':
val = val.split('.')[0]
ty = 'int'
elif ty == 'bool':
val = int(val)
return bool(int(val))

return getattr(builtins, ty)(val)

Expand Down
2 changes: 1 addition & 1 deletion atomlib/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _cast_atoms(atoms: _HasAtoms, ty: t.Type[HasAtomsT]) -> HasAtomsT:
if isinstance(atoms, ty):
return atoms
if issubclass(ty, HasAtomCell) and not isinstance(atoms, HasAtomCell):
raise TypeError(f"File contains no cell information.")
raise TypeError("File contains no cell information.")

if ty is AtomCell and isinstance(atoms, HasAtomCell):
return atoms.get_atomcell() # type: ignore
Expand Down
7 changes: 3 additions & 4 deletions atomlib/test_alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from .testing import check_equals_structure

from . import make
from . import alter
from . import make, alter, AtomCell
from .transform import AffineTransform3D


Expand All @@ -21,8 +20,8 @@ def test_unbunch():


@check_equals_structure('ZnSe_contaminated.xsf')
def test_contaminated(znse_supercell):
def test_contaminated(znse_supercell: AtomCell):
cell = alter.contaminate(znse_supercell, (20., 10.), seed='test_znse_cont')
assert_array_equal(cell.n_cells, [6, 6, 1])
assert_array_equal(cell.affine, AffineTransform3D.translate([0., 0., -10.]).inner)
assert_array_equal(cell.affine.inner, AffineTransform3D.translate([0., 0., -10.]).inner)
return cell.explode().transform(cell.get_cell().affine.inverse())
6 changes: 3 additions & 3 deletions atomlib/test_elem.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_get_elems_fail():
(numpy.array([1, 47, 82]), numpy.array([1.008, 107.8682, 207.2])),
(polars.Series([1, 47, 82]), polars.Series([1.008, 107.8682, 207.2])),
))
def test_get_mass(elem, mass):
def test_get_mass(elem: t.Any, mass: t.Any):
result = get_mass(elem)

if isinstance(mass, polars.Series):
Expand All @@ -135,7 +135,7 @@ def test_get_mass(elem, mass):
(55, 2.98),
(70, 2.22),
))
def test_get_radius(elem, radius):
def test_get_radius(elem: int, radius: float):
assert get_radius(elem) == pytest.approx(radius)


Expand All @@ -144,5 +144,5 @@ def test_get_radius(elem, radius):
(1, -1, 2.08),
(34, +6, 0.56),
))
def test_get_ionic_radius(elem, charge, radius):
def test_get_ionic_radius(elem: int, charge: int, radius: float):
assert get_ionic_radius(elem, charge) == pytest.approx(radius)
2 changes: 1 addition & 1 deletion atomlib/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def wrapper(*args, **kwargs): # type: ignore


def check_figure_draw(name: t.Union[str, Path, t.Sequence[t.Union[str, Path]]],
savefig_kwarg=None) -> t.Callable[[t.Callable[..., None]], t.Callable[..., None]]:
savefig_kwarg: t.Optional[t.Dict[str, t.Any]] = None) -> t.Callable[[t.Callable[..., None]], t.Callable[..., None]]:
"""Test that the wrapped function draws an identical figure to `name` in `baseline_images`."""

if isinstance(name, (str, Path)):
Expand Down
Loading

0 comments on commit fc023f2

Please sign in to comment.