Skip to content

Commit

Permalink
Merge pull request #288 from lincc-frameworks/calc_nobs_divisions
Browse files Browse the repository at this point in the history
Calc nobs divisions
  • Loading branch information
dougbrn authored Nov 15, 2023
2 parents 1082730 + 55ee6b7 commit 1d97c55
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
67 changes: 47 additions & 20 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,41 +607,68 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
"""

if by_band:
band_counts = (
self._source.groupby([self._id_col])[self._band_col] # group by each object
.value_counts() # count occurence of each band
.to_frame() # convert series to dataframe
.rename(columns={self._band_col: "counts"}) # rename column
.reset_index() # break up the multiindex
.categorize(columns=[self._band_col]) # retype the band labels as categories
.pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum")
) # the pivot_table call makes each band_count a column of the id_col row

# repartition the result to align with object
if self._object.known_divisions:
band_counts.divisions = self._source.divisions
band_counts = band_counts.repartition(divisions=self._object.divisions)
# Grab these up front to help out the task graph
id_col = self._id_col
band_col = self._band_col

# Get the band metadata
unq_bands = np.unique(self._source[band_col])
meta = {band: float for band in unq_bands}

# Map the groupby to each partition
band_counts = self._source.map_partitions(
lambda x: x.groupby(id_col)[[band_col]]
.value_counts()
.to_frame()
.reset_index()
.pivot_table(values=band_col, index=id_col, columns=band_col, aggfunc="sum"),
meta=meta,
).repartition(divisions=self._object.divisions)
else:
band_counts = (
self._source.groupby([self._id_col])[self._band_col] # group by each object
.value_counts() # count occurence of each band
.to_frame() # convert series to dataframe
.rename(columns={self._band_col: "counts"}) # rename column
.reset_index() # break up the multiindex
.categorize(columns=[self._band_col]) # retype the band labels as categories
.pivot_table(
values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum"
)
) # the pivot_table call makes each band_count a column of the id_col row

band_counts = band_counts.repartition(npartitions=self._object.npartitions)

# short-hand for calculating nobs_total
band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1)

bands = band_counts.columns.values
self._object = self._object.assign(**{label + "_" + band: band_counts[band] for band in bands})
self._object = self._object.assign(
**{label + "_" + str(band): band_counts[band] for band in bands}
)

if temporary:
self._object_temp.extend(label + "_" + band for band in bands)
self._object_temp.extend(label + "_" + str(band) for band in bands)

else:
counts = self._source.groupby([self._id_col])[[self._band_col]].aggregate("count")

# repartition the result to align with object
if self._object.known_divisions and self._source.known_divisions:
counts.divisions = self._source.divisions
counts = counts.repartition(divisions=self._object.divisions)
# Grab these up front to help out the task graph
id_col = self._id_col
band_col = self._band_col

# Map the groupby to each partition
counts = self._source.map_partitions(
lambda x: x.groupby([id_col])[[band_col]].aggregate("count")
).repartition(divisions=self._object.divisions)
else:
counts = counts.repartition(npartitions=self._object.npartitions)
# Just do a groupby on all source
counts = (
self._source.groupby([self._id_col])[[self._band_col]]
.aggregate("count")
.repartition(npartitions=self._object.npartitions)
)

self._object = self._object.assign(**{label + "_total": counts[self._band_col]})

Expand Down
6 changes: 5 additions & 1 deletion tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,14 @@ def test_keep_zeros(parquet_ensemble):
],
)
@pytest.mark.parametrize("by_band", [True, False])
def test_calc_nobs(data_fixture, request, by_band):
@pytest.mark.parametrize("multi_partition", [True, False])
def test_calc_nobs(data_fixture, request, by_band, multi_partition):
# Get the Ensemble from a fixture
ens = request.getfixturevalue(data_fixture)

if multi_partition:
ens._source = ens._source.repartition(3)

# Drop the existing nobs columns
ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)

Expand Down

0 comments on commit 1d97c55

Please sign in to comment.