Skip to content

Commit

Permalink
More typing/CI fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Oct 11, 2024
1 parent afd6a8e commit 888af9d
Show file tree
Hide file tree
Showing 13 changed files with 56 additions and 52 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,16 @@ jobs:
python -m pip install --upgrade pip
pip install -e '.[dev]'
- name: Type check 3.9
if: ${{ steps.deps.outcome == 'success' && success() || failure() }}
uses: jakebailey/pyright-action@v1
with:
python-version: "3.9"
- name: Type check 3.10
if: ${{ steps.deps.outcome == 'success' && success() || failure() }}
uses: jakebailey/pyright-action@v1
with:
python-version: "3.10"
- name: Type check 3.11
- name: Type check 3.12
if: ${{ steps.deps.outcome == 'success' && success() || failure() }}
uses: jakebailey/pyright-action@v1
with:
python-version: "3.11"
python-version: "3.12"

success:
name: Success
Expand Down
2 changes: 1 addition & 1 deletion atomlib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def cli(verbose: int = 0):
def run_chain(cmds: t.Sequence[CmdType], verbose: int = 0):
states: t.Iterable[State] = ()
for cmd in cmds:
if cmd is None:
if cmd is None: # type: ignore reportUnnecessaryComparison
raise RuntimeError("'cmd' is None. Did a command forget to return a wrapper function?")
states = cmd(states)
for _ in states:
Expand Down
3 changes: 2 additions & 1 deletion atomlib/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import typing as t
from typing import cast

from typing_extensions import TypeAlias
import numpy
from numpy.typing import NDArray, ArrayLike
import polars
Expand Down Expand Up @@ -36,7 +37,7 @@ def ellip_pi(n: NDArray[numpy.float64], m: NDArray[numpy.float64]) -> NDArray[nu
return rf + rj * n / 3


CutType = t.Literal['shift', 'add', 'rm']
CutType: TypeAlias = t.Literal['shift', 'add', 'rm']
"""Cut plane to use when creating a (non-screw) dislocation."""


Expand Down
4 changes: 2 additions & 2 deletions atomlib/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def parse_nary(self, lhs: Expr[T_co, V], level: t.Optional[int] = None) -> Expr[
logging.debug(f"parse_nary({lhs}, level={level})")
token = self.peek()
logging.debug(f"token: '{token!r}'")

while token is not None:
if not isinstance(token, OpToken) or \
not isinstance(token.op, (NaryOp, BinaryOp)):
Expand All @@ -435,7 +435,7 @@ def parse_nary(self, lhs: Expr[T_co, V], level: t.Optional[int] = None) -> Expr[
logging.debug(f"rhs: '{rhs}'")

inner = self.peek()
if not inner is None and isinstance(inner, OpToken):
if inner is not None and isinstance(inner, OpToken):
inner_op = inner.op
if isinstance(inner_op, (NaryOp, BinaryOp)) and \
inner_op.precedes(token.op):
Expand Down
3 changes: 2 additions & 1 deletion atomlib/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import typing as t

from typing_extensions import TypeAlias
import numpy
import polars

Expand All @@ -23,7 +24,7 @@
from ..elem import get_sym, get_elem
from ..util import FileOrPath

FileType = t.Literal['cif', 'xyz', 'xsf', 'cfg', 'lmp', 'mslice', 'qe']
FileType: TypeAlias = t.Literal['cif', 'xyz', 'xsf', 'cfg', 'lmp', 'mslice', 'qe']


def read_cif(f: t.Union[FileOrPath, CIF, CIFDataBlock], block: t.Union[int, str, None] = None) -> HasAtoms:
Expand Down
33 changes: 17 additions & 16 deletions atomlib/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import typing as t

from typing_extensions import TypeAlias
import numpy
import polars
from numpy.typing import NDArray
Expand All @@ -26,7 +27,7 @@
from ..atomcell import HasAtomCell


Value = t.Union[int, float, str, None]
Value: TypeAlias = t.Union[int, float, str, None]
_INT_RE = re.compile(r'[-+]?\d+')
# float regex with uncertainty (e.g. '3.14159(3)')
_FLOAT_RE = re.compile(r'([-+]?\d*(\.\d*)?(e[-+]?\d+)?)(\(\d+\))?', re.I)
Expand Down Expand Up @@ -59,7 +60,7 @@ def get_block(self, block: t.Union[int, str]) -> CIFDataBlock:

def write(self, file: FileOrPath):
with open_file(file, 'w') as f:
print(f"# generated by atomlib", file=f, end=None)
print("# generated by atomlib", file=f, end=None)
for data_block in self.data_blocks:
print(file=f)
data_block._write(f)
Expand Down Expand Up @@ -195,7 +196,7 @@ def stack_tags(self, *tags: str, dtype: t.Union[str, numpy.dtype, t.Iterable[t.U
else:
dtypes = tuple(map(lambda ty: numpy.dtype(ty), dtype))
if len(dtypes) != len(tags):
raise ValueError(f"dtype list of invalid length")
raise ValueError("dtype list of invalid length")

if isinstance(required, bool):
required = repeat(required)
Expand All @@ -218,8 +219,8 @@ def stack_tags(self, *tags: str, dtype: t.Union[str, numpy.dtype, t.Iterable[t.U
if len(d) == 0:
return polars.DataFrame({})

l = len(next(iter(d.values())))
if any(len(arr) != l for arr in d.values()):
tag_len = len(next(iter(d.values())))
if any(len(arr) != tag_len for arr in d.values()):
raise ValueError(f"Tags of mismatching lengths: {tuple(map(len, d.values()))}")

return polars.DataFrame(d)
Expand Down Expand Up @@ -264,7 +265,7 @@ class CIFTable:
data: t.Dict[str, t.List[Value]]

def _write(self, f: TextIOBase):
print(f"\nloop_", file=f)
print("\nloop_", file=f)
for tag in self.data.keys():
print(f" _{tag}", file=f)

Expand Down Expand Up @@ -535,7 +536,7 @@ def parse_loop(self) -> CIFTable:
return CIFTable(dict(zip(tags, vals)))

def parse_value(self) -> Value:
logging.debug(f"parse_value")
logging.debug("parse_value")
w = self.peek_word()
assert w is not None
if w in ('.', '?'):
Expand All @@ -551,17 +552,17 @@ def parse_value(self) -> Value:
return self.parse_bare()

def parse_text_field(self) -> str:
line = self.line
l = self.next_line()
assert l is not None
s = l.lstrip().removeprefix(';').lstrip()
start_line = self.line
line = self.next_line()
assert line is not None
s = line.lstrip().removeprefix(';').lstrip()
while True:
l = self.next_line()
if l is None:
raise ValueError(f"While parsing text field at line {line}: Unexpected EOF")
if l.strip() == ';':
line = self.next_line()
if line is None:
raise ValueError(f"While parsing text field at line {start_line}: Unexpected EOF")
if line.strip() == ';':
break
s += l
s += line
return s.rstrip()

def parse_quoted(self) -> str:
Expand Down
21 changes: 12 additions & 9 deletions atomlib/io/lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ def write(self, file: FileOrPath):

# print sections
for section in self.sections:
l = section.name
line = section.name
if section.style is not None:
l += f' # {section.style}'
print(f"\n{l}\n", file=f)
line += f' # {section.style}'
print(f"\n{line}\n", file=f)

f.writelines(section.body)

Expand Down Expand Up @@ -361,9 +361,9 @@ def collect_lines(self, n: int) -> t.Optional[t.List[str]]:
try:
for _ in range(n):
while True:
l = next(self._file)
if not l.isspace():
lines.append(l)
line = next(self._file)
if not line.isspace():
lines.append(line)
break
except StopIteration:
return None
Expand All @@ -385,11 +385,14 @@ def inner(s: str) -> t.Tuple[t.Any, ...]:

return inner


_parse_2float = _parse_seq(float, 2)
_parse_3float = _parse_seq(float, 3)
_fmt_2float = lambda vals: f"{vals[0]:16.7f} {vals[1]:14.7f}"
_fmt_3float = lambda vals: f"{vals[0]:16.7f} {vals[1]:14.7f} {vals[2]:14.7f}"

def _fmt_2float(vals: t.Sequence[float]):
return f"{vals[0]:16.7f} {vals[1]:14.7f}"

def _fmt_3float(vals: t.Sequence[float]):
return f"{vals[0]:16.7f} {vals[1]:14.7f} {vals[2]:14.7f}"


_HEADER_KWS = {
Expand Down
5 changes: 3 additions & 2 deletions atomlib/io/xsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import logging
import typing as t

from typing_extensions import TypeAlias
import numpy
from numpy.typing import NDArray
import polars

from ..transform import LinearTransform3D
from ..util import open_file, FileOrPath

Periodicity = t.Literal['crystal', 'slab', 'polymer', 'molecule']
Periodicity: TypeAlias = t.Literal['crystal', 'slab', 'polymer', 'molecule']

if t.TYPE_CHECKING:
from ..atoms import HasAtoms
Expand Down Expand Up @@ -204,7 +205,7 @@ def parse_atoms(self, expected_length: t.Optional[int] = None) -> polars.DataFra
return polars.DataFrame({}, schema=['elem', 'x', 'y', 'z']) # type: ignore

coord_lens = list(map(len, coords))
if not all(l == coord_lens[0] for l in coord_lens[1:]):
if not all(coord_len == coord_lens[0] for coord_len in coord_lens[1:]):
raise ValueError("Mismatched atom dimensions.")
if coord_lens[0] < 3:
raise ValueError("Expected at least 3 coordinates per atom.")
Expand Down
5 changes: 3 additions & 2 deletions atomlib/io/xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import typing as t

from typing_extensions import TypeAlias
import numpy
from numpy.typing import NDArray
import polars
Expand Down Expand Up @@ -46,7 +47,7 @@
}


XYZFormat = t.Literal['xyz', 'exyz']
XYZFormat: TypeAlias = t.Literal['xyz', 'exyz']

T = t.TypeVar('T')

Expand Down Expand Up @@ -90,7 +91,7 @@ def from_file(file: FileOrPath) -> XYZ:
# TODO be more gracious about whitespace here
length = int(f.readline())
except ValueError:
raise ValueError(f"Error parsing XYZ file: Invalid length") from None
raise ValueError("Error parsing XYZ file: Invalid length") from None
except IOError as e:
raise IOError(f"Error parsing XYZ file: {e}") from None

Expand Down
3 changes: 2 additions & 1 deletion atomlib/make/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import string
import typing as t

from typing_extensions import TypeAlias
import numpy

from ..atomcell import Atoms, AtomCell, HasAtomCellT, IntoAtoms
Expand All @@ -20,7 +21,7 @@
from ..util import proc_seed


CellType = t.Literal['conv', 'prim', 'ortho']
CellType: TypeAlias = t.Literal['conv', 'prim', 'ortho']


def fcc(elem: ElemLike, a: Num, *, cell: CellType = 'conv', additional: t.Optional[IntoAtoms] = None) -> AtomCell:
Expand Down
6 changes: 3 additions & 3 deletions atomlib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
import typing as t

from typing_extensions import ParamSpec, Concatenate
from typing_extensions import ParamSpec, Concatenate, TypeAlias
import numpy
from numpy.typing import NDArray
import polars
Expand All @@ -27,9 +27,9 @@ def map_some(f: t.Callable[[T], U], val: t.Optional[T]) -> t.Optional[U]:
return None if val is None else f(val)


FileOrPath = t.Union[str, Path, TextIOBase, t.TextIO]
FileOrPath: TypeAlias = t.Union[str, Path, TextIOBase, t.TextIO]
"""Open text file or path to a file. Use with [open_file][atomlib.util.open_file]."""
BinaryFileOrPath = t.Union[str, Path, t.TextIO, t.BinaryIO, IOBase]
BinaryFileOrPath: TypeAlias = t.Union[str, Path, t.TextIO, t.BinaryIO, IOBase]
"""Open binary file or path to a file. Use with [open_file_binary][atomlib.util.open_file_binary]."""


Expand Down
6 changes: 3 additions & 3 deletions atomlib/vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def miller_4_to_3_plane(a: NDArray[numpy.number], reduce: bool = True, max_denom
"""Convert a plane in 4-axis Miller-Bravais notation to 3-axis Miller notation."""
a = numpy.atleast_1d(a)
assert a.shape[-1] == 4
h, k, i, l = numpy.split(a, 4, axis=-1)
h, k, i, l = numpy.split(a, 4, axis=-1) # noqa: E741
assert numpy.allclose(-i, h + k, equal_nan=True)
out = numpy.concatenate((h, k, l), axis=-1)
return reduce_vec(out, max_denom) if reduce else out
Expand All @@ -235,8 +235,8 @@ def miller_3_to_4_plane(a: NDArray[numpy.number], reduce: bool = True, max_denom
"""Convert a plane in 3-axis Miller notation to 4-axis Miller-Bravais notation."""
a = numpy.atleast_1d(a)
assert a.shape[-1] == 3
h, k, l = numpy.split(a, 3, axis=-1)
out = numpy.concatenate((h, k, -(h + k), l), axis=-1) # type: ignore
h, k, l = numpy.split(a, 3, axis=-1) # noqa: E741
out = numpy.concatenate((h, k, -(h + k), l), axis=-1)
return reduce_vec(out, max_denom) if reduce else out


Expand Down
8 changes: 4 additions & 4 deletions atomlib/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import abstractmethod, ABC
import typing as t

from typing_extensions import TypeAlias
import numpy
from numpy.typing import NDArray
from matplotlib import pyplot
Expand All @@ -25,8 +26,8 @@
from ..elem import get_radius


BackendName = t.Literal['mpl', 'ase']
AtomStyle = t.Literal['spacefill', 'ballstick', 'small']
BackendName: TypeAlias = t.Literal['mpl', 'ase']
AtomStyle: TypeAlias = t.Literal['spacefill', 'ballstick', 'small']


class AtomImage(ABC):
Expand Down Expand Up @@ -140,9 +141,8 @@ def get_plot_radii(atoms: HasAtoms, min_r: t.Optional[float] = 1.0, style: AtomS

def get_azim_elev(zone: VecLike) -> t.Tuple[float, float]:
(a, b, c) = -to_vec3(zone) # look down zone
l = numpy.sqrt(a**2 + b**2)
# todo: aren't these just arctan2s?
return (numpy.angle(a + b*1.j, deg=True), numpy.angle(l + c*1.j, deg=True)) # type: ignore
return (numpy.angle(a + b*1.j, deg=True), numpy.angle(numpy.sqrt(a**2 + b**2) + c*1.j, deg=True)) # type: ignore


def show_atoms_mpl_3d(atoms: HasAtoms, *, fig: t.Optional[Figure] = None,
Expand Down

0 comments on commit 888af9d

Please sign in to comment.