Skip to content

Commit

Permalink
User driven getitem and construction (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Aug 29, 2024
1 parent 6d74551 commit 60c3bd2
Show file tree
Hide file tree
Showing 13 changed files with 352 additions and 74 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ Changelog
0.9.0 (unreleased)
------------------

**New feature**
**New features**

- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function.
- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function.
- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx.

**Bug fixes**
Expand Down
19 changes: 7 additions & 12 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional._additional import _getitem as getitem
from ndonnx.additional._additional import _static_shape as static_shape

from ._corearray import _CoreArray
Expand Down Expand Up @@ -47,7 +48,11 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return Array._construct(shape=shape, dtype=dtype)
if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented:
return out
raise ndx.UnsupportedOperationError(
f"No implementation of `make_array` for {dtype}"
)


def from_spox_var(
Expand Down Expand Up @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array:
return ndx.astype(self, to)

def __getitem__(self, index: IndexType) -> Array:
if isinstance(index, Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, Array):
index = index._core()

return self._transmute(lambda corearray: corearray[index])
return getitem(self, index)

def __setitem__(
self, index: IndexType | Self, updates: int | bool | float | Array
Expand Down
24 changes: 10 additions & 14 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core
Expand All @@ -22,7 +23,7 @@
from ndonnx import Array


class BooleanOperationsImpl(UniformShapeOperations):
class _BooleanOperationsImpl(OperationsBlock):
@validate_core
def equal(self, x, y) -> Array:
return binary_op(x, y, opx.equal)
Expand Down Expand Up @@ -163,17 +164,12 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...

class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented

class NullableBooleanOperationsImpl(
NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
50 changes: 50 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from spox import Tensor, argument

import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx.additional as nda
from ndonnx._corearray import _CoreArray

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import Dtype


class CoreOperationsImpl(OperationsBlock):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> Array:
if not isinstance(dtype, dtypes.CoreType):
return NotImplemented
return ndx.Array._from_fields(
dtype,
data=_CoreArray(
dtype._parse_input(eager_value)["data"]
if eager_value is not None
else argument(Tensor(dtype.to_numpy_dtype(), shape))
),
)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
if null.dtype != ndx.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)
39 changes: 29 additions & 10 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes

if TYPE_CHECKING:
from ndonnx._array import IndexType
from ndonnx._data_types import Dtype


class OperationsBlock:
"""Interface for data types to implement top-level functions exported by ndonnx."""
Expand Down Expand Up @@ -251,7 +257,7 @@ def cumulative_sum(
x,
*,
axis: int | None = None,
dtype: ndx.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
include_initial: bool = False,
):
return NotImplemented
Expand All @@ -270,7 +276,7 @@ def prod(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -293,7 +299,7 @@ def sum(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -305,7 +311,7 @@ def var(
axis=None,
keepdims: bool = False,
correction=0.0,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
) -> ndx.Array:
return NotImplemented

Expand Down Expand Up @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array:
def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented
Expand All @@ -365,14 +371,12 @@ def ones_like(
def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
def zeros_like(self, x, dtype: Dtype | None = None, device=None):
return NotImplemented

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
Expand Down Expand Up @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool:

def static_shape(self, x) -> tuple[int | None, ...]:
return NotImplemented

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented

def getitem(
self,
x: ndx.Array,
index: IndexType,
) -> ndx.Array:
return NotImplemented
15 changes: 14 additions & 1 deletion ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import ndonnx as ndx

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import CoreType, StructType

Dtype = Union[CoreType, StructType]


class NullableOperationsImpl(OperationsBlock):
@validate_core
def fill_null(self, x, value):
def fill_null(self, x: Array, value) -> Array:
value = ndx.asarray(value)
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented
26 changes: 11 additions & 15 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ndonnx.additional as nda
from ndonnx._utility import promote

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import (
Expand All @@ -36,7 +38,7 @@
from ndonnx._corearray import _CoreArray


class NumericOperationsImpl(UniformShapeOperations):
class _NumericOperationsImpl(OperationsBlock):
# elementwise.py

@validate_core
Expand Down Expand Up @@ -837,17 +839,6 @@ def var(
- correction
)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def can_cast(self, from_, to) -> bool:
if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType):
Expand Down Expand Up @@ -980,9 +971,14 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


class NullableNumericOperationsImpl(
NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


def _via_i64_f64(
Expand Down
Loading

0 comments on commit 60c3bd2

Please sign in to comment.