Skip to content

Commit

Permalink
Fix various types with cast()
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Nov 4, 2023
1 parent 3d4e314 commit d3aaf2a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
5 changes: 4 additions & 1 deletion src/galois/_codes/_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from __future__ import annotations

from typing import Any, overload
from typing import Any, cast, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -133,6 +133,7 @@ def _convert_codeword_to_message(self, codeword: FieldArray) -> FieldArray:
message = codeword[..., 0:ks]
else:
message, _ = divmod_jit(self.field)(codeword, self.generator_poly.coeffs)
message = cast(FieldArray, message)

return message

Expand All @@ -141,6 +142,7 @@ def _convert_codeword_to_parity(self, codeword: FieldArray) -> FieldArray:
parity = codeword[..., -(self.n - self.k) :]
else:
_, parity = divmod_jit(self.field)(codeword, self.generator_poly.coeffs)
parity = cast(FieldArray, parity)

return parity

Expand Down Expand Up @@ -218,6 +220,7 @@ def _poly_to_generator_matrix(n: int, generator_poly: Poly, systematic: bool) ->
G = GF.Zeros((k, n))
for i in range(k):
G[i, i : i + generator_poly.degree + 1] = generator_poly.coeffs
G = cast(FieldArray, G)

return G

Expand Down
6 changes: 3 additions & 3 deletions src/galois/_codes/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from __future__ import annotations

from typing import Any, Type, cast, overload
from typing import Any, cast, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -216,11 +216,11 @@ def _check_and_convert_message(self, message: ArrayLike) -> tuple[FieldArray, bo

# Record if the original message was 1-D and then convert to 2-D
is_message_1d = message.ndim == 1
message = np.atleast_2d(message)
message = cast(FieldArray, np.atleast_2d(message))

return message, is_message_1d

def _check_and_convert_codeword(self, codeword: FieldArray) -> FieldArray:
def _check_and_convert_codeword(self, codeword: ArrayLike) -> tuple[FieldArray, bool]:
"""
Converts the array-like codeword into a 2-D FieldArray with shape (N, ns).
"""
Expand Down
15 changes: 9 additions & 6 deletions src/galois/_lfsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from __future__ import annotations

from typing import Any, Callable, Type, cast, overload
from typing import Any, Callable, cast, overload

import numba
import numpy as np
Expand Down Expand Up @@ -44,19 +44,20 @@ def __init__(
if not feedback_poly.coeffs[-1] == 1:
raise ValueError(f"Argument 'feedback_poly' must have a 0-th degree term of 1, not {feedback_poly}.")

self._field = feedback_poly.field
self._field = cast(type[FieldArray], feedback_poly.field)
self._feedback_poly = feedback_poly
self._characteristic_poly = feedback_poly.reverse()
self._order = feedback_poly.degree

if self._type == "fibonacci":
# T = [c_n-1, c_n-2, ..., c_1, c_0]
# c(x) = x^{n} - c_{n-1}x^{n-1} - c_{n-2}x^{n-2} - \dots - c_{1}x - c_{0}
self._taps = -self.characteristic_poly.coeffs[1:]
taps = -self.characteristic_poly.coeffs[1:]
else:
# T = [c_0, c_1, ..., c_n-2, c_n-1]
# c(x) = x^{n} - c_{n-1}x^{n-1} - c_{n-2}x^{n-2} - \dots - c_{1}x - c_{0}
self._taps = -self.characteristic_poly.coeffs[1:][::-1]
taps = -self.characteristic_poly.coeffs[1:][::-1]
self._taps = cast(FieldArray, taps)

if state is None:
state = self.field.Ones(self.order)
Expand Down Expand Up @@ -113,21 +114,22 @@ def step(self, steps: int = 1) -> FieldArray:

return y

def _step_forward(self, steps):
def _step_forward(self, steps: int) -> FieldArray:
assert steps > 0

if self._type == "fibonacci":
y, state = fibonacci_lfsr_step_forward_jit(self.field)(self.taps, self.state, steps)
else:
y, state = galois_lfsr_step_forward_jit(self.field)(self.taps, self.state, steps)
y = cast(FieldArray, y)

self._state[:] = state[:]
if y.size == 1:
y = y[0]

return y

def _step_backward(self, steps):
def _step_backward(self, steps: int) -> FieldArray:
assert steps > 0

if not self.characteristic_poly.coeffs[-1] > 0:
Expand All @@ -140,6 +142,7 @@ def _step_backward(self, steps):
y, state = fibonacci_lfsr_step_backward_jit(self.field)(self.taps, self.state, steps)
else:
y, state = galois_lfsr_step_backward_jit(self.field)(self.taps, self.state, steps)
y = cast(FieldArray, y)

self._state[:] = state[:]
if y.size == 1:
Expand Down
12 changes: 11 additions & 1 deletion src/galois/_ntt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from __future__ import annotations

from typing import cast

import numpy as np

from ._fields import Field, FieldArray
Expand Down Expand Up @@ -235,7 +237,14 @@ def intt(
return _ntt(X, size=size, modulus=modulus, forward=False, scaled=scaled)


def _ntt(x, size=None, modulus=None, forward=True, scaled=True):
def _ntt(
x: ArrayLike,
size: int | None = None,
modulus: int | None = None,
forward: bool = True,
scaled: bool = True,
) -> FieldArray:
x = np.asarray(x)
verify_isinstance(size, int, optional=True)
verify_isinstance(modulus, int, optional=True)
verify_isinstance(forward, bool)
Expand Down Expand Up @@ -273,5 +282,6 @@ def _ntt(x, size=None, modulus=None, forward=True, scaled=True):
else:
norm = "backward" if scaled else "forward"
y = np.fft.ifft(x, n=size, norm=norm)
y = cast(FieldArray, y)

return y

0 comments on commit d3aaf2a

Please sign in to comment.