Use set containment instead of perfect subsets (#291)
* Use set containment instead of perfect subsets

xref #180

Containment = |Q & S| / |Q|
where
- |X| is the cardinality of set X
- Q is the query set being tested
- S is the existing set

https://ekzhu.com/datasketch/lshensemble.html#containment
dcherian authored Nov 30, 2023
1 parent 769db63 commit 666d45e
Showing 4 changed files with 53 additions and 27 deletions.
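
As a quick illustration of the metric described in the commit message, here is a minimal Python sketch. The `containment` helper and the example sets are hypothetical; the 0.75 threshold mirrors the check added in flox/core.py below.

```python
def containment(query: set, existing: set) -> float:
    """Containment = |Q & S| / |Q|, with Q the query set being tested and S the existing set."""
    if not query:
        return 0.0
    return len(query & existing) / len(query)

# A near-subset: the old perfect-subset check rejects the merge,
# while the containment check (threshold 0.75, per the diff below) accepts it.
query = {1, 2, 3, 4, 5}        # chunks occupied by the candidate cohort
existing = {1, 2, 3, 4, 6, 7}  # chunks occupied by the cohort it could merge into
print(query.issubset(existing))             # False
print(containment(query, existing))         # 0.8
print(containment(query, existing) > 0.75)  # True -> merge
```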
6 changes: 5 additions & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -45,6 +45,10 @@ def track_num_layers(self):
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_layers.unit = "layers"  # type: ignore[attr-defined] # Lazy
+    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers]:
+        f.repeat = 1  # type: ignore[attr-defined] # Lazy
+        f.rounds = 1  # type: ignore[attr-defined] # Lazy
+        f.number = 1  # type: ignore[attr-defined] # Lazy
 
 
 class NWMMidwest(Cohorts):
@@ -83,7 +87,7 @@ def setup(self, *args, **kwargs):
 class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
     def setup(self, *args, **kwargs):
         super().setup()
-        super().rechunk()
+        self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
 
 
 class ERA5MonthHour(ERA5Dataset, Cohorts):
22 changes: 11 additions & 11 deletions asv_bench/benchmarks/reduce.py
@@ -59,17 +59,17 @@ def time_reduce(self, func, expected_name, engine):
             expected_groups=expected_groups[expected_name],
         )
 
-    @skip_for_params(numbagg_skip)
-    @parameterize({"func": funcs, "expected_name": expected_names, "engine": engines})
-    def peakmem_reduce(self, func, expected_name, engine):
-        flox.groupby_reduce(
-            self.array,
-            self.labels,
-            func=func,
-            engine=engine,
-            axis=self.axis,
-            expected_groups=expected_groups[expected_name],
-        )
+    # @skip_for_params(numbagg_skip)
+    # @parameterize({"func": funcs, "expected_name": expected_names, "engine": engines})
+    # def peakmem_reduce(self, func, expected_name, engine):
+    #     flox.groupby_reduce(
+    #         self.array,
+    #         self.labels,
+    #         func=func,
+    #         engine=engine,
+    #         axis=self.axis,
+    #         expected_groups=expected_groups[expected_name],
+    #     )
 
 
 class ChunkReduce1D(ChunkReduce):
35 changes: 21 additions & 14 deletions flox/core.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import copy
 import itertools
 import math
 import operator
@@ -304,14 +303,15 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-
-    if not single_chunks and merge:
+    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
+    if not one_group_per_chunk and not single_chunks and merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
 
-        items = tuple((k, set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
+        # precompute needed metrics for the quadratic loop below.
+        items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
 
         merged_cohorts = {}
         merged_keys: set[tuple] = set()
@@ -320,21 +320,28 @@ def invert(x) -> tuple[np.ndarray, ...]:
         # and then merge in cohorts that are present in a subset of those chunks
         # I think this is suboptimal and must fail at some point.
         # But it might work for most cases. There must be a better way...
-        for idx, (k1, set_k1, v1) in enumerate(items):
+        for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
             if k1 in merged_keys:
                 continue
-            merged_cohorts[k1] = copy.deepcopy(v1)
-            for k2, set_k2, v2 in items[idx + 1 :]:
-                if k2 not in merged_keys and set_k2.issubset(set_k1):
-                    merged_cohorts[k1].extend(v2)
-                    merged_keys.update((k2,))
-
-        # make sure each cohort is sorted after merging
-        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+            new_key = set_k1
+            new_value = v1
+            # iterate in reverse since we expect small cohorts
+            # to be most likely merged in to larger ones
+            for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
+                if k2 not in merged_keys:
+                    if (len(set_k2 & new_key) / len_k2) > 0.75:
+                        new_key |= set_k2
+                        new_value += v2
+                        merged_keys.update((k2,))
+            sorted_ = sorted(new_value)
+            merged_cohorts[tuple(sorted(new_key))] = sorted_
+            if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
+                break
+
         # sort by first label in cohort
         # This will help when sort=True (default)
         # and we have to resort the dask array
-        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+        return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))
 
     else:
         return chunks_cohorts
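
To make the new heuristic concrete, below is a simplified standalone sketch of the containment-based greedy merge shown in the diff above. It is not flox's implementation: `merge_cohorts` and the toy `chunks_cohorts` mapping are made up, and the early-exit check on `nlabels`/`ilabels` is omitted.

```python
def merge_cohorts(chunks_cohorts: dict[tuple, list]) -> dict[tuple, list]:
    # Sort cohorts by how many chunks they occupy, largest first.
    items = sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
    items = [(k, len(k), set(k), v) for k, v in items if k]
    merged: dict[tuple, list] = {}
    merged_keys: set[tuple] = set()
    for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
        if k1 in merged_keys:
            continue
        new_key, new_value = set(set_k1), list(v1)
        # Check smaller cohorts (in reverse) against the current, growing key.
        for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
            if k2 not in merged_keys and len(set_k2 & new_key) / len_k2 > 0.75:
                new_key |= set_k2
                new_value += v2
                merged_keys.add(k2)
        merged[tuple(sorted(new_key))] = sorted(new_value)
    # Sort cohorts by their first label, as the real code does.
    return dict(sorted(merged.items(), key=lambda kv: kv[1][0]))


toy = {
    (0, 1, 2, 3): [10, 11],  # large cohort
    (0, 1, 2): [12],         # containment 3/3 = 1.0 -> merged in
    (1, 2, 4): [13],         # containment 2/3 ~ 0.67 -> kept separate
}
print(merge_cohorts(toy))
# {(0, 1, 2, 3): [10, 11, 12], (1, 2, 4): [13]}
```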
17 changes: 16 additions & 1 deletion tests/test_core.py
@@ -848,11 +848,26 @@ def test_rechunk_for_blockwise(inchunks, expected):
         ],
     ],
 )
-def test_find_group_cohorts(expected, labels, chunks, merge):
+def test_find_group_cohorts(expected, labels, chunks: tuple[int], merge: bool) -> None:
     actual = list(find_group_cohorts(labels, (chunks,), merge).values())
     assert actual == expected, (actual, expected)
 
 
+@pytest.mark.parametrize("chunksize", [12, 13, 14, 24, 36, 48, 72, 71])
+def test_verify_complex_cohorts(chunksize: int) -> None:
+    time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+    chunks = (chunksize,) * (len(time) // chunksize)
+    by = np.array(time.dt.dayofyear.values)
+
+    if len(by) != sum(chunks):
+        chunks += (len(by) - sum(chunks),)
+    chunk_cohorts = find_group_cohorts(by - 1, (chunks,))
+    chunks_ = np.sort(np.concatenate(tuple(chunk_cohorts.keys())))
+    groups = np.sort(np.concatenate(tuple(chunk_cohorts.values())))
+    assert_equal(np.unique(chunks_), np.arange(len(chunks), dtype=int))
+    assert_equal(groups, np.arange(366, dtype=int))
+
+
 @requires_dask
 @pytest.mark.parametrize(
     "chunk_at,expected",
