Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make NullableCore public #69

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Changelog
0.9.0 (unreleased)
------------------

**New feature**

- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx.

**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
Expand Down
2 changes: 2 additions & 0 deletions ndonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Floating,
Integral,
Nullable,
NullableCore,
NullableFloating,
NullableIntegral,
NullableNumerical,
Expand Down Expand Up @@ -323,6 +324,7 @@
"Floating",
"NullableIntegral",
"Nullable",
"NullableCore",
"Integral",
"CoreType",
"CastError",
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def can_cast(self, from_, to) -> bool:

@validate_core
def all(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
Expand All @@ -110,7 +110,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):

@validate_core
def any(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
Expand Down
10 changes: 5 additions & 5 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def clip(
and isinstance(x.dtype, dtypes.Numerical)
):
x, min, max = promote(x, min, max)
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
out_null = x.null
x_values = x.values._core()
clipped = from_corearray(opx.clip(x_values, min._core(), max._core()))
Expand Down Expand Up @@ -856,7 +856,7 @@ def can_cast(self, from_, to) -> bool:

@validate_core
def all(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, True, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(True, dtype=ndx.bool)
Expand All @@ -866,7 +866,7 @@ def all(self, x, *, axis=None, keepdims: bool = False):

@validate_core
def any(self, x, *, axis=None, keepdims: bool = False):
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
x = ndx.where(x.null, False, x.values)
if functools.reduce(operator.mul, x._static_shape, 1) == 0:
return ndx.asarray(False, dtype=ndx.bool)
Expand Down Expand Up @@ -898,7 +898,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar

@validate_core
def tril(self, x, k=0) -> ndx.Array:
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
# NumPy appears to just ignore the mask so we do the same
x = x.values
return x._transmute(
Expand All @@ -909,7 +909,7 @@ def tril(self, x, k=0) -> ndx.Array:

@validate_core
def triu(self, x, k=0) -> ndx.Array:
if isinstance(x.dtype, dtypes._NullableCore):
if isinstance(x.dtype, dtypes.NullableCore):
# NumPy appears to just ignore the mask so we do the same
x = x.values
return x._transmute(
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
if dtype is not None and not isinstance(
dtype, (dtypes.CoreType, dtypes._NullableCore)
dtype, (dtypes.CoreType, dtypes.NullableCore)
):
raise TypeError("'dtype' must be a CoreType or NullableCoreType")
if dtype in (None, dtypes.utf8, dtypes.nutf8):
Expand Down
6 changes: 3 additions & 3 deletions ndonnx/_core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def variadic_op(
):
args = promote(*args)
out_dtype = args[0].dtype
if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)):
if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)):
raise TypeError(
f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ def _via_dtype(
promoted = promote(*arrays)
out_dtype = promoted[0].dtype

if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype:
if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype:
dtype = out_dtype

values, nulls = split_nulls_and_values(
Expand Down Expand Up @@ -203,7 +203,7 @@ def validate_core(func):
def wrapper(*args, **kwargs):
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, ndx.Array) and not isinstance(
arg.dtype, (dtypes.CoreType, dtypes._NullableCore)
arg.dtype, (dtypes.CoreType, dtypes.NullableCore)
):
return NotImplemented
return func(*args, **kwargs)
Expand Down
12 changes: 6 additions & 6 deletions ndonnx/_data_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
NullableUnsigned,
Numerical,
Unsigned,
_NullableCore,
NullableCore,
from_numpy_dtype,
get_finfo,
get_iinfo,
Expand All @@ -51,7 +51,7 @@
from .structtype import StructType


def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
def into_nullable(dtype: StructType | CoreType) -> NullableCore:
"""Return nullable counterpart, if present.

Parameters
Expand All @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:

Returns
-------
out : _NullableCore
out : NullableCore
The nullable counterpart of the input type.

Raises
Expand Down Expand Up @@ -93,7 +93,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
return nuint64
elif dtype == utf8:
return nutf8
elif isinstance(dtype, _NullableCore):
elif isinstance(dtype, NullableCore):
return dtype
else:
raise ValueError(f"Cannot promote {dtype} to nullable")
Expand All @@ -103,14 +103,14 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
"Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. "
"To create nullable array, use 'ndonnx.additional.make_nullable' instead."
)
def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
def promote_nullable(dtype: StructType | CoreType) -> NullableCore:
return into_nullable(dtype)


__all__ = [
"CoreType",
"StructType",
"_NullableCore",
"NullableCore",
"NullableFloating",
"NullableIntegral",
"NullableUnsigned",
Expand Down
20 changes: 10 additions & 10 deletions ndonnx/_data_types/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]:
}


class _NullableCore(Nullable[CoreType], CastMixin):
class NullableCore(Nullable[CoreType], CastMixin):
def copy(self) -> Self:
return self

Expand All @@ -213,7 +213,7 @@ def _schema(self) -> Schema:
return Schema(type_name=type(self).__name__, author="ndonnx")

def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
return ndx.Array._from_fields(
dtype,
values=self.values._cast_to(array.values, dtype.values),
Expand All @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array:
values=self.values._cast_from(array),
null=ndx.zeros_like(array, dtype=Boolean()),
)
elif isinstance(array.dtype, _NullableCore):
elif isinstance(array.dtype, NullableCore):
return ndx.Array._from_fields(
self,
values=self.values._cast_from(array.values),
Expand All @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array:
raise CastError(f"Cannot cast from {array.dtype} to {self}")


class NullableNumerical(_NullableCore):
class NullableNumerical(NullableCore):
"""Base class for nullable numerical data types."""

_ops: OperationsBlock = NullableNumericOperationsImpl()
Expand Down Expand Up @@ -312,14 +312,14 @@ class NFloat64(NullableFloating):
null = Boolean()


class NBoolean(_NullableCore):
class NBoolean(NullableCore):
values = Boolean()
null = Boolean()

_ops: OperationsBlock = NullableBooleanOperationsImpl()


class NUtf8(_NullableCore):
class NUtf8(NullableCore):
values = Utf8()
null = Boolean()

Expand Down Expand Up @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo:
)


def get_finfo(dtype: _NullableCore | CoreType) -> Finfo:
def get_finfo(dtype: NullableCore | CoreType) -> Finfo:
try:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
dtype = dtype.values
return Finfo._from_dtype(dtype)
except KeyError:
raise TypeError(f"'{dtype}' is not a floating point data type.")


def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo:
def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo:
try:
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
dtype = dtype.values
return Iinfo._from_dtype(dtype)
except KeyError:
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy.typing as npt

import ndonnx._data_types as dtypes
from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore
from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore
from ndonnx._data_types.structtype import StructType
from ndonnx.additional import shape

Expand Down Expand Up @@ -291,7 +291,7 @@ def result_type(
np_dtypes = []
for dtype in observed_dtypes:
if isinstance(dtype, dtypes.StructType):
if isinstance(dtype, _NullableCore):
if isinstance(dtype, NullableCore):
nullable = True
np_dtypes.append(dtype.values.to_numpy_dtype())
else:
Expand Down