Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that int data is used for ElementIdentifiers on init, append, extend #1009

Merged
merged 2 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
### Bug fixes
- Fixed issue with custom class generation when a spec has a `name`. @rly [#1006](https://github.com/hdmf-dev/hdmf/pull/1006)

- Fixed issue where `ElementIdentifiers` data could be set to non-integer values. @rly [#1009](https://github.com/hdmf-dev/hdmf/pull/1009)

## HDMF 3.11.0 (October 30, 2023)

### Enhancements
Expand Down
20 changes: 17 additions & 3 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from . import register_class, EXP_NAMESPACE
from ..container import Container, Data
from ..data_utils import DataIO, AbstractDataChunkIterator
from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional
from ..utils import docval, getargs, ExtenderMeta, popargs, pystr, AllowPositional, check_type
from ..term_set import TermSetWrapper


Expand Down Expand Up @@ -211,8 +211,8 @@ class ElementIdentifiers(Data):
"""

@docval({'name': 'name', 'type': str, 'doc': 'the name of this ElementIdentifiers'},
{'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing identifiers',
'default': list()},
{'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing integer identifiers',
'default': list(), 'shape': (None,)},
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -237,6 +237,20 @@ def __eq__(self, other):
# Find all matching locations
return np.in1d(self.data, search_ids).nonzero()[0]

def _validate_new_data(self, data):
# NOTE this may not cover all the many AbstractDataChunkIterator edge cases
if (isinstance(data, AbstractDataChunkIterator) or
(hasattr(data, "data") and isinstance(data.data, AbstractDataChunkIterator))):
if not np.issubdtype(data.dtype, np.integer):
raise ValueError("ElementIdentifiers must contain integers")
elif hasattr(data, "__len__") and len(data):
self._validate_new_data_element(data[0])

def _validate_new_data_element(self, arg):
if not check_type(arg, int):
raise ValueError("ElementIdentifiers must contain integers")
super()._validate_new_data_element(arg)


@register_class('DynamicTable')
class DynamicTable(Container):
Expand Down
18 changes: 18 additions & 0 deletions src/hdmf/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ class Data(AbstractContainer):
def __init__(self, **kwargs):
data = popargs('data', kwargs)
super().__init__(**kwargs)

self._validate_new_data(data)
self.__data = data

@property
Expand Down Expand Up @@ -822,6 +824,7 @@ def get(self, args):
return self.data[args]

def append(self, arg):
self._validate_new_data_element(arg)
self.__data = append_data(self.__data, arg)

def extend(self, arg):
Expand All @@ -831,8 +834,23 @@ def extend(self, arg):

:param arg: The iterable to add to the end of this VectorData
"""
self._validate_new_data(arg)
self.__data = extend_data(self.__data, arg)

def _validate_new_data(self, data):
"""Function to validate a new array that will be set or added to data. Raises an error if the data is invalid.

Subclasses should override this function to perform class-specific validation.
"""
pass

def _validate_new_data_element(self, arg):
"""Function to validate a new value that will be added to the data. Raises an error if the data is invalid.

Subclasses should override this function to perform class-specific validation.
"""
pass


class DataRegion(Data):

Expand Down
2 changes: 2 additions & 0 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,8 @@
return self.__shape[0]
if not self.valid:
raise InvalidDataIOError("Cannot get length of data. Data is not valid.")
if isinstance(self.data, AbstractDataChunkIterator):
return self.data.maxshape[0]

Check warning on line 1065 in src/hdmf/data_utils.py

View check run for this annotation

Codecov / codecov/patch

src/hdmf/data_utils.py#L1065

Added line #L1065 was not covered by tests
return len(self.data)

def __bool__(self):
Expand Down
12 changes: 6 additions & 6 deletions src/hdmf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
return tuple(__macros[key])


def __type_okay(value, argtype, allow_none=False):
def check_type(value, argtype, allow_none=False):
"""Check a value against a type

The difference between this function and :py:func:`isinstance` is that
Expand All @@ -87,7 +87,7 @@
return allow_none
if isinstance(argtype, str):
if argtype in __macros:
return __type_okay(value, __macros[argtype], allow_none=allow_none)
return check_type(value, __macros[argtype], allow_none=allow_none)

Check warning on line 90 in src/hdmf/utils.py

View check run for this annotation

Codecov / codecov/patch

src/hdmf/utils.py#L90

Added line #L90 was not covered by tests
elif argtype == 'uint':
return __is_uint(value)
elif argtype == 'int':
Expand All @@ -106,7 +106,7 @@
return __is_bool(value)
return isinstance(value, argtype)
elif isinstance(argtype, tuple) or isinstance(argtype, list):
return any(__type_okay(value, i) for i in argtype)
return any(check_type(value, i) for i in argtype)
else: # argtype is None
return True

Expand Down Expand Up @@ -279,7 +279,7 @@
# we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type.
argval = argval.value
if enforce_type:
if not __type_okay(argval, arg['type']):
if not check_type(argval, arg['type']):
if argval is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
Expand Down Expand Up @@ -336,7 +336,7 @@
# we can use this to unwrap the dataset/attribute to use the "item" for docval to validate the type.
argval = argval.value
if enforce_type:
if not __type_okay(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)):
if not check_type(argval, arg['type'], arg['default'] is None or arg.get('allow_none', False)):
if argval is None and arg['default'] is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
Expand Down Expand Up @@ -613,7 +613,7 @@
msg = 'docval for {}: enum checking cannot be used with arg type {}'.format(a['name'], a['type'])
raise Exception(msg)
# check that enum allowed values are allowed by arg type
if any([not __type_okay(x, a['type']) for x in a['enum']]):
if any([not check_type(x, a['type']) for x in a['enum']]):
msg = ('docval for {}: enum values are of types not allowed by arg type (got {}, '
'expected {})'.format(a['name'], [type(x) for x in a['enum']], a['type']))
raise Exception(msg)
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/common/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,38 @@ def test_identifier_search_with_bad_ids(self):
_ = (self.e == 'test')


class TestBadElementIdentifiers(TestCase):

def test_bad_dtype(self):
with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=["1", "2"])

with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=np.array(["1", "2"]))

with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=[1.0, 2.0])

def test_dci_int_ok(self):
a = np.arange(30)
dci = DataChunkIterator(data=a, buffer_size=1)
e = ElementIdentifiers(name='ids', data=dci) # test that no error is raised
self.assertIs(e.data, dci)

def test_dci_float_bad(self):
a = np.arange(30.0)
dci = DataChunkIterator(data=a, buffer_size=1)
with self.assertRaisesWith(ValueError, "ElementIdentifiers must contain integers"):
ElementIdentifiers(name='ids', data=dci)

def test_dataio_dci_ok(self):
a = np.arange(30)
dci = DataChunkIterator(data=a, buffer_size=1)
dio = H5DataIO(dci)
e = ElementIdentifiers(name='ids', data=dio) # test that no error is raised
self.assertIs(e.data, dio)


class SubTable(DynamicTable):

__columns__ = (
Expand Down
6 changes: 0 additions & 6 deletions tests/unit/utils_test/test_core_DataIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@

class DataIOTests(TestCase):

def setUp(self):
pass

def tearDown(self):
pass

def test_copy(self):
obj = DataIO(data=[1., 2., 3.])
obj_copy = copy(obj)
Expand Down