diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 7fda04f49..731c5080b 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -215,7 +215,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): base_indexer = key output = super().__getitem__(base_indexer) if not isinstance(output, AxesArray): - return output # why? + return output # return an element from the array in_dim = self.shape key, adv_inds = standardize_indexer(self, key) bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) @@ -386,17 +386,10 @@ def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: """Replace ellipsis in indexers with the appropriate amount of slice(None)""" # [...].index errors if list contains numpy array ellind = [ind for ind, val in enumerate(key) if val is ...][0] - new_key = [] n_new_dims = sum(ax_key is None or isinstance(ax_key, str) for ax_key in key) n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1) - new_key = ( - key[:ellind] - + n_ellipsis_dims - * [ - slice(None), - ] - + key[ellind + 1 + n_ellipsis_dims :] - ) + new_key = key[:ellind] + key[ellind + 1 :] + new_key = new_key[:ellind] + (n_ellipsis_dims * [slice(None)]) + new_key[ellind:] return new_key diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index f5576f48b..e3910e29e 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -174,6 +174,15 @@ def test_simple_slice(): assert arr[0] == 1 +# @pytest.mark.skip # TODO: make this pass +def test_0d_indexer(): + arr = AxesArray(np.ones(2), {"ax_coord": 0}) + arr_out = arr[1, ...] + assert arr_out.ndim == 0 + assert arr_out.axes == {} + assert arr_out[()] == 1 + + def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) @@ -428,3 +437,22 @@ def test_determine_adv_broadcasting(): res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) assert res_nd == 0 assert res_start is None + + +def test_replace_ellipsis(): + key = [..., 0] + result = axes._expand_indexer_ellipsis(key, 2) + expected = [slice(None), 0] + assert result == expected + + +def test_strip_ellipsis(): + key = [1, ...] + result = axes._expand_indexer_ellipsis(key, 1) + expected = [1] + assert result == expected + + key = [..., 1] + result = axes._expand_indexer_ellipsis(key, 1) + expected = [1] + assert result == expected