method heuristics: Avoid dot product as much as possible (#347)
* Another `method` detection optimization

* fix

* silence warnings

* silence one more warning

* Even better shortcut

* Update docs
dcherian authored Mar 27, 2024
1 parent 307899a commit 4952fe9
Showing 4 changed files with 49 additions and 19 deletions.
15 changes: 12 additions & 3 deletions asv_bench/benchmarks/cohorts.py
@@ -95,7 +95,7 @@ class ERA5Dataset:
     """ERA5"""
 
     def __init__(self, *args, **kwargs):
-        self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+        self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="h"))
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))
 
@@ -143,7 +143,7 @@ class PerfectMonthly(Cohorts):
     """Perfectly chunked for a "cohorts" monthly mean climatology"""
 
     def setup(self, *args, **kwargs):
-        self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
+        self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="ME"))
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
         self.by = self.time.dt.month.values - 1
@@ -164,7 +164,7 @@ def rechunk(self):
 class ERA5Google(Cohorts):
     def setup(self, *args, **kwargs):
         TIME = 900  # 92044 in Google ARCO ERA5
-        self.time = pd.Series(pd.date_range("1959-01-01", freq="6H", periods=TIME))
+        self.time = pd.Series(pd.date_range("1959-01-01", freq="6h", periods=TIME))
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
         self.by = self.time.dt.day.values - 1
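
The frequency-string changes above ("H" to "h", "M" to "ME", "6H" to "6h") track pandas' renamed offset aliases; recent pandas releases emit a FutureWarning for the old uppercase forms, which is presumably what the "silence warnings" commits address. A minimal check, assuming pandas >= 2.2:

```python
import pandas as pd

# Lowercase "h" replaces the deprecated hourly alias "H",
# and "ME" (month-end) replaces the deprecated "M".
hourly = pd.date_range("2016-01-01", "2016-01-02", freq="h")
month_end = pd.date_range("1961-01-31", "1962-01-31", freq="ME")
print(len(hourly), len(month_end))  # 25 13
```
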
@@ -201,3 +201,12 @@ def setup(self, *args, **kwargs):
         self.time = pd.Series(index)
         self.by = self.time.dt.dayofyear.values - 1
         self.expected = pd.RangeIndex(self.by.max() + 1)
+
+
+class RandomBigArray(Cohorts):
+    def setup(self, *args, **kwargs):
+        M, N = 100_000, 20_000
+        self.array = dask.array.random.normal(size=(M, N), chunks=(10_000, N // 5)).T
+        self.by = np.random.choice(5_000, size=M)
+        self.expected = pd.RangeIndex(5000)
+        self.axis = (1,)
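
The new `RandomBigArray` case stresses method detection on a large, randomly grouped dask array. As a rough sketch of what a `Cohorts` benchmark presumably exercises (assuming the base class ultimately calls the public `flox.groupby_reduce` API with the default automatic method selection; that code is not shown in this diff):

```python
import dask.array
import numpy as np

import flox

M, N = 100_000, 20_000
# Transposed to (N, M); the grouped dimension is the last axis.
array = dask.array.random.normal(size=(M, N), chunks=(10_000, N // 5)).T
by = np.random.choice(5_000, size=M)

# With no explicit method=, flox chooses between "map-reduce" and "cohorts"
# using the sparsity heuristic this commit optimizes.
result, groups = flox.groupby_reduce(array, by, func="sum", axis=-1)
```
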
1 change: 1 addition & 0 deletions ci/benchmark.yml
@@ -3,6 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - asv
+  - build
   - cachey
   - dask-core
   - numpy>=1.22
6 changes: 4 additions & 2 deletions docs/source/implementation.md
@@ -300,8 +300,10 @@ label overlaps with all other labels. The algorithm is as follows.
 ![cohorts-schematic](/../diagrams/containment.png)
 
 1. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with
-   each other. We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
-   When sparsity > 0.6, we choose `"map-reduce"` since there is decent overlap between (any) cohorts. Otherwise we use `"cohorts"`.
+   each other. We can use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
+   We use sparsity(`S`) as an approximation for the sparsity(`C`) to avoid a potentially expensive sparse matrix dot product when `S`
+   isn't particularly sparse. When sparsity(`S`) > 0.4 (arbitrary), we choose `"map-reduce"` since there is decent overlap between
+   (any) cohorts. Otherwise we use `"cohorts"`.
 
 Cool, isn't it?!
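
To make the updated text concrete, here is a toy illustration of the heuristic, with a made-up chunk-by-label bitmask `S` and the 0.4 threshold quoted above (an illustrative sketch assuming scipy >= 1.8, not flox's actual implementation):

```python
import numpy as np
from scipy.sparse import csr_array

# S[i, j] = 1 if label j occurs in chunk i (4 chunks, 6 labels, toy data).
S = csr_array(
    np.array(
        [
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 1, 1],
            [1, 0, 0, 0, 0, 1],
        ],
        dtype=float,
    )
)

sparsity_S = S.nnz / np.prod(S.shape)  # 8 / 24 ~= 0.33
if sparsity_S > 0.4:  # the (arbitrary) threshold mentioned above
    method = "map-reduce"
else:
    # Only now pay for the dot product: containment C, as described above.
    chunks_per_label = S.sum(axis=0)  # number of chunks each label touches
    C = csr_array((S.T @ S) / chunks_per_label)
    method = "cohorts"
print(method, round(sparsity_S, 2))  # cohorts 0.33
```
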

46 changes: 32 additions & 14 deletions flox/core.py
@@ -363,37 +363,55 @@ def invert(x) -> tuple[np.ndarray, ...]:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
         return "cohorts", chunks_cohorts
 
-    # Containment = |Q & S| / |Q|
+    # We'll use containment to measure degree of overlap between labels.
+    # Containment C = |Q & S| / |Q|
     #  - |X| is the cardinality of set X
     #  - Q is the query set being tested
     #  - S is the existing set
-    # We'll use containment to measure degree of overlap between labels. The bitmask
-    # matrix allows us to calculate this pretty efficiently.
-    asfloat = bitmask.astype(float)
-    # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
-    # makes it non-symmetric.
-    containment = csr_array((asfloat.T @ asfloat) / chunks_per_label)
-
-    # The containment matrix is a measure of how much the labels overlap
-    # with each other. We treat the sparsity = (nnz/size) as a summary measure of the net overlap.
+    # The bitmask matrix S allows us to calculate this pretty efficiently using a dot product.
+    #     S.T @ S / chunks_per_label
+    #
+    # We treat the sparsity(C) = (nnz/size) as a summary measure of the net overlap.
     # 1. For high enough sparsity, there is a lot of overlap and we should use "map-reduce".
     # 2. When labels are uniformly distributed amongst all chunks
     #    (and number of labels < chunk size), sparsity is 1.
     # 3. Time grouping cohorts (e.g. dayofyear) appear as lines in this matrix.
     # 4. When there are no overlaps at all between labels, containment is a block diagonal matrix
     #    (approximately).
-    MAX_SPARSITY_FOR_COHORTS = 0.6  # arbitrary
-    sparsity = containment.nnz / math.prod(containment.shape)
+    #
+    # However computing S.T @ S can still be the slowest step, especially if S
+    # is not particularly sparse. Empirically the sparsity( S.T @ S ) > min(1, 2 x sparsity(S)).
+    # So we use sparsity(S) as a shortcut.
+    MAX_SPARSITY_FOR_COHORTS = 0.4  # arbitrary
+    sparsity = bitmask.nnz / math.prod(bitmask.shape)
     preferred_method: Literal["map-reduce"] | Literal["cohorts"]
+    logger.debug(
+        "sparsity of bitmask is {}, threshold is {}".format(  # noqa
+            sparsity, MAX_SPARSITY_FOR_COHORTS
+        )
+    )
     if sparsity > MAX_SPARSITY_FOR_COHORTS:
-        logger.info("sparsity is {}".format(sparsity))  # noqa
         if not merge:
-            logger.info("find_group_cohorts: merge=False, choosing 'map-reduce'")
+            logger.info(
+                "find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'".format(  # noqa
+                    sparsity
+                )
+            )
             return "map-reduce", {}
         preferred_method = "map-reduce"
     else:
         preferred_method = "cohorts"
 
+    # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
+    # makes it non-symmetric.
+    asfloat = bitmask.astype(float)
+    containment = csr_array(asfloat.T @ asfloat / chunks_per_label)
+
+    logger.debug(
+        "sparsity of containment matrix is {}".format(  # noqa
+            containment.nnz / math.prod(containment.shape)
+        )
+    )
     # Use a threshold to force some merging. We do not use the filtered
     # containment matrix for estimating "sparsity" because it is a bit
     # hard to reason about.
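
A quick way to see why the shortcut described in the comments above matters: for a not-so-sparse bitmask, reading `nnz` is essentially free, while `S.T @ S` takes real time and comes out much denser than `S`. The sizes and density below are made up for illustration:

```python
import time

import numpy as np
import scipy.sparse

# A random bitmask-like matrix: 200 chunks x 2,000 labels, 30% non-zero.
S = scipy.sparse.random(200, 2_000, density=0.3, format="csr", dtype=float)

# sparsity(S) costs nothing extra: nnz is already stored on the array.
sparsity_S = S.nnz / np.prod(S.shape)

tic = time.perf_counter()
C = S.T @ S  # the dot product the new heuristic skips when S isn't sparse
elapsed = time.perf_counter() - tic

# For this density, C is nearly fully dense, consistent with the
# empirical observation quoted in the comment above.
sparsity_C = C.nnz / np.prod(C.shape)
print(f"sparsity(S)={sparsity_S:.2f}  sparsity(S.T @ S)={sparsity_C:.2f}  took {elapsed:.3f}s")
```
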
