From 0ae8b87b2c9b0a505d58fc3d47de2d7f1e538b97 Mon Sep 17 00:00:00 2001 From: mhostetter Date: Sat, 22 Jun 2024 12:04:32 -0400 Subject: [PATCH] Ensure integer dtypes are returned from `log()` --- src/galois/_domains/_ufunc.py | 2 ++ src/galois/_fields/_array.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/galois/_domains/_ufunc.py b/src/galois/_domains/_ufunc.py index 6f85b633d..72168b6da 100644 --- a/src/galois/_domains/_ufunc.py +++ b/src/galois/_domains/_ufunc.py @@ -516,6 +516,8 @@ def __call__(self, ufunc, method, inputs, kwargs, meta): inputs = list(inputs) + [int(self.field.primitive_element)] inputs, kwargs = self._view_inputs_as_ndarray(inputs, kwargs) output = getattr(self.ufunc, method)(*inputs, **kwargs) + if output.dtype == np.object_: + output = output.astype(int) return output diff --git a/src/galois/_fields/_array.py b/src/galois/_fields/_array.py index 61b590727..5cf847d38 100644 --- a/src/galois/_fields/_array.py +++ b/src/galois/_fields/_array.py @@ -1730,6 +1730,8 @@ def log(self, base: ElementLike | ArrayLike | None = None) -> int | np.ndarray: if np.isscalar(output): output = int(output) + if output.dtype == np.object_: + output = output.astype(int) return output