From 48634dff44ea0bde1d27b96418e2a21f27eef8c7 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 10:30:05 +0000 Subject: [PATCH] CLN: Simplify advanced indexing broadcast calculation --- pysindy/utils/axes.py | 18 ++++++++---------- test/utils/test_axes.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index b30b20d45..f5347146e 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -196,10 +196,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): return output in_dim = self.shape key, adv_inds = standardize_indexer(self, key) - if adv_inds: - adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds) - else: - adjacent, bcast_nd, bcast_start_axis = True, 0, 0 + adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) + # Handle moving around non-adjacent advanced axes old_index = OldIndex(0) pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] for key_ind, indexer in enumerate(key): @@ -253,7 +251,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa bcast_nd = np.broadcast(*adv_indexers).nd adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) - bcast_start_axis = 0 if not adjacent else min(adv_inds) + bcast_start_ax = 0 if not adjacent else min(adv_inds) adv_map = {} for idx_id, idxer in zip(adv_inds, adv_indexers): @@ -261,7 +259,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): len([id for id in range(idx_id) if key[id] is not None]) ] adv_map[base_idxer_ax_name] = [ - bcast_start_axis + shp + bcast_start_ax + shp for shp in _compare_bcast_shapes(bcast_nd, idxer.shape) ] @@ -273,9 +271,9 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): [] if len(ax_names) == 0: if "ax_unk" not in adv_map.keys(): - adv_map["ax_unk"] = [bcast_ax + bcast_start_axis] + adv_map["ax_unk"] = [bcast_ax + bcast_start_ax] else: - adv_map["ax_unk"].append(bcast_ax + bcast_start_axis) + adv_map["ax_unk"].append(bcast_ax + bcast_start_ax) for conflict_axis, conflict_names in conflicts.items(): new_name = "ax_" @@ -493,11 +491,11 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None: def _determine_adv_broadcasting( key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex] ) -> tuple: - """Calculate the shape and location for the result of advanced indexing""" + """Calculate the shape and location for the result of advanced indexing.""" adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:])) adv_indexers = [np.array(key[i]) for i in adv_inds] bcast_nd = np.broadcast(*adv_indexers).nd - bcast_start_axis = 0 if not adjacent else min(adv_inds) + bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None return adjacent, bcast_nd, bcast_start_axis diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index a32879d97..b37b20b83 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -420,14 +420,19 @@ def test_squeeze_to_sublist(): def test_determine_adv_broadcasting(): - indexers = (np.ones(1), np.ones((4, 1)), np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [0, 1, 2]) + indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3)) + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) assert res_adj is True assert res_nd == 2 - assert res_start == 0 + assert res_start == 1 indexers = (None, np.ones(1), 2, np.ones(3)) res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) assert res_adj is False assert res_nd == 1 assert res_start == 0 + + res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) + assert res_adj is True + assert res_nd == 0 + assert res_start is None