diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 406a5881..35eb7ead 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -607,41 +607,68 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True): """ if by_band: - band_counts = ( - self._source.groupby([self._id_col])[self._band_col] # group by each object - .value_counts() # count occurence of each band - .to_frame() # convert series to dataframe - .rename(columns={self._band_col: "counts"}) # rename column - .reset_index() # break up the multiindex - .categorize(columns=[self._band_col]) # retype the band labels as categories - .pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum") - ) # the pivot_table call makes each band_count a column of the id_col row - # repartition the result to align with object if self._object.known_divisions: - band_counts.divisions = self._source.divisions - band_counts = band_counts.repartition(divisions=self._object.divisions) + # Grab these up front to help out the task graph + id_col = self._id_col + band_col = self._band_col + + # Get the band metadata + unq_bands = np.unique(self._source[band_col]) + meta = {band: float for band in unq_bands} + + # Map the groupby to each partition + band_counts = self._source.map_partitions( + lambda x: x.groupby(id_col)[[band_col]] + .value_counts() + .to_frame() + .reset_index() + .pivot_table(values=band_col, index=id_col, columns=band_col, aggfunc="sum"), + meta=meta, + ).repartition(divisions=self._object.divisions) else: + band_counts = ( + self._source.groupby([self._id_col])[self._band_col] # group by each object + .value_counts() # count occurence of each band + .to_frame() # convert series to dataframe + .rename(columns={self._band_col: "counts"}) # rename column + .reset_index() # break up the multiindex + .categorize(columns=[self._band_col]) # retype the band labels as categories + .pivot_table( + values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum" + ) + ) # the pivot_table call makes each band_count a column of the id_col row + band_counts = band_counts.repartition(npartitions=self._object.npartitions) # short-hand for calculating nobs_total band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1) bands = band_counts.columns.values - self._object = self._object.assign(**{label + "_" + band: band_counts[band] for band in bands}) + self._object = self._object.assign( + **{label + "_" + str(band): band_counts[band] for band in bands} + ) if temporary: - self._object_temp.extend(label + "_" + band for band in bands) + self._object_temp.extend(label + "_" + str(band) for band in bands) else: - counts = self._source.groupby([self._id_col])[[self._band_col]].aggregate("count") - - # repartition the result to align with object if self._object.known_divisions and self._source.known_divisions: - counts.divisions = self._source.divisions - counts = counts.repartition(divisions=self._object.divisions) + # Grab these up front to help out the task graph + id_col = self._id_col + band_col = self._band_col + + # Map the groupby to each partition + counts = self._source.map_partitions( + lambda x: x.groupby([id_col])[[band_col]].aggregate("count") + ).repartition(divisions=self._object.divisions) else: - counts = counts.repartition(npartitions=self._object.npartitions) + # Just do a groupby on all source + counts = ( + self._source.groupby([self._id_col])[[self._band_col]] + .aggregate("count") + .repartition(npartitions=self._object.npartitions) + ) self._object = self._object.assign(**{label + "_total": counts[self._band_col]}) diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 9fbf1d53..f4bb7df2 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -823,10 +823,14 @@ def test_keep_zeros(parquet_ensemble): ], ) @pytest.mark.parametrize("by_band", [True, False]) -def test_calc_nobs(data_fixture, request, by_band): +@pytest.mark.parametrize("multi_partition", [True, False]) +def test_calc_nobs(data_fixture, request, by_band, multi_partition): # Get the Ensemble from a fixture ens = request.getfixturevalue(data_fixture) + if multi_partition: + ens._source = ens._source.repartition(3) + # Drop the existing nobs columns ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)