Skip to content

Commit

Permalink
CLN: Simplify advanced indexing broadcast calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent 4ac044f commit 48634df
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
18 changes: 8 additions & 10 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
return output
in_dim = self.shape
key, adv_inds = standardize_indexer(self, key)
if adv_inds:
adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds)
else:
adjacent, bcast_nd, bcast_start_axis = True, 0, 0
adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
# Handle moving around non-adjacent advanced axes
old_index = OldIndex(0)
pindexers: list[PartialReIndexer | list[PartialReIndexer]] = []
for key_ind, indexer in enumerate(key):
Expand Down Expand Up @@ -253,15 +251,15 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa
bcast_nd = np.broadcast(*adv_indexers).nd
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
bcast_start_axis = 0 if not adjacent else min(adv_inds)
bcast_start_ax = 0 if not adjacent else min(adv_inds)
adv_map = {}

for idx_id, idxer in zip(adv_inds, adv_indexers):
base_idxer_ax_name = self._reverse_map[ # count non-None keys
len([id for id in range(idx_id) if key[id] is not None])
]
adv_map[base_idxer_ax_name] = [
bcast_start_axis + shp
bcast_start_ax + shp
for shp in _compare_bcast_shapes(bcast_nd, idxer.shape)
]

Expand All @@ -273,9 +271,9 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
[]
if len(ax_names) == 0:
if "ax_unk" not in adv_map.keys():
adv_map["ax_unk"] = [bcast_ax + bcast_start_axis]
adv_map["ax_unk"] = [bcast_ax + bcast_start_ax]
else:
adv_map["ax_unk"].append(bcast_ax + bcast_start_axis)
adv_map["ax_unk"].append(bcast_ax + bcast_start_ax)

for conflict_axis, conflict_names in conflicts.items():
new_name = "ax_"
Expand Down Expand Up @@ -493,11 +491,11 @@ def _move_idxs_to_front(li: list, idxs: Sequence) -> None:
def _determine_adv_broadcasting(
key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
) -> tuple:
"""Calculate the shape and location for the result of advanced indexing"""
"""Calculate the shape and location for the result of advanced indexing."""
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
adv_indexers = [np.array(key[i]) for i in adv_inds]
bcast_nd = np.broadcast(*adv_indexers).nd
bcast_start_axis = 0 if not adjacent else min(adv_inds)
bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None
return adjacent, bcast_nd, bcast_start_axis


Expand Down
11 changes: 8 additions & 3 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,14 +420,19 @@ def test_squeeze_to_sublist():


def test_determine_adv_broadcasting():
indexers = (np.ones(1), np.ones((4, 1)), np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [0, 1, 2])
indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3])
assert res_adj is True
assert res_nd == 2
assert res_start == 0
assert res_start == 1

indexers = (None, np.ones(1), 2, np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3])
assert res_adj is False
assert res_nd == 1
assert res_start == 0

res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [])
assert res_adj is True
assert res_nd == 0
assert res_start is None

0 comments on commit 48634df

Please sign in to comment.