From fc9ca415d8ff454dcaff189b67dafb2dd2214411 Mon Sep 17 00:00:00 2001 From: Doug Branton Date: Mon, 2 Oct 2023 14:09:20 -0700 Subject: [PATCH 1/3] add calc_nobs --- src/tape/ensemble.py | 53 +++++++++++++++++++++++++++++++ tests/tape_tests/test_ensemble.py | 21 +++++++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 0ba059b8..8b5269af 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -510,6 +510,59 @@ def coalesce_partition(df, input_cols, output_col): return self + def calc_nobs(self, by_band=False, label="nobs"): + """Calculates the number of observations per lightcurve. + + Parameters + ---------- + by_band: `bool`, optional + If True, also calculates the number of observations for each band + in addition to providing the number of observations in total + label: `str`, optional + The label used to generate output columns. "_total" and the band + labels (e.g. "_g") are appended. + + Returns + ------- + ensemble: `tape.ensemble.Ensemble` + The ensemble object with nobs columns added to the object table. + """ + + obj_npartitions = self._object.npartitions # to repartition output columns + + if by_band: + counts = self._source.groupby([self._id_col, self._band_col])[self._time_col].aggregate("count") + + band_counts = ( + self._source.groupby([self._id_col])[self._band_col] # group by each object + .value_counts() # count occurence of each band + .to_frame() # convert series to dataframe + .reset_index() # break up the multiindex + .categorize(columns=[self._band_col]) # retype the band labels as categories + .pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum") + .repartition(obj_npartitions) # counts inherits the source partitions + ) # the pivot_table call makes each band_count a column of the id_col row + + # short-hand for calculating nobs_total + band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1) + + bands = band_counts.columns.values + cols_to_add = {label + "_" + band: band_counts[band] for band in bands} + + else: + counts = self._source.groupby([self._id_col])[self._band_col].aggregate("count") + counts = counts.repartition(obj_npartitions) # counts inherits the source partitions + cols_to_add = {label + "_total": counts} + + self._object = self._object.assign(**cols_to_add) # assign new columns + + # remove the dataframes from memory + if by_band: + del band_counts + del counts + + return self + def prune(self, threshold=50, col_name=None): """remove objects with less observations than a given threshold diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 5ea9d225..30932c09 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -74,7 +74,7 @@ def test_from_parquet(data_fixture, request): "dask_dataframe_ensemble", "dask_dataframe_with_object_ensemble", "pandas_ensemble", - "pandas_with_object_ensemble" + "pandas_with_object_ensemble", ], ) def test_from_dataframe(data_fixture, request): @@ -109,6 +109,7 @@ def test_from_dataframe(data_fixture, request): amplitude = ens.batch(calc_stetson_J) assert len(amplitude) == 5 + def test_available_datasets(dask_client): """ Test that the ensemble is able to successfully read in the list of available TAPE datasets @@ -573,6 +574,24 @@ def test_keep_zeros(parquet_ensemble): assert new_objects_pdf.loc[i, c] == old_objects_pdf.loc[i, c] +@pytest.mark.parametrize("by_band", [True, False]) +def test_calc_nobs(parquet_ensemble, by_band): + ens = parquet_ensemble + ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1) + + ens.calc_nobs(by_band) + + lc = ens._object.loc[88472935274829959].compute() + + if by_band: + assert np.all([col in ens._object.columns for col in ["nobs_g", "nobs_r"]]) + assert lc["nobs_g"].values[0] == 98 + assert lc["nobs_r"].values[0] == 401 + + assert "nobs_total" in ens._object.columns + assert lc["nobs_total"].values[0] == 499 + + def test_prune(parquet_ensemble): """ Test that ensemble.prune() appropriately filters the dataframe From 8b399cf781cb93c224e85e2c068b2071bd058856 Mon Sep 17 00:00:00 2001 From: Doug Branton Date: Mon, 2 Oct 2023 14:30:25 -0700 Subject: [PATCH 2/3] add calc_nobs --- src/tape/ensemble.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 8b5269af..8aae3c8b 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -547,19 +547,12 @@ def calc_nobs(self, by_band=False, label="nobs"): band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1) bands = band_counts.columns.values - cols_to_add = {label + "_" + band: band_counts[band] for band in bands} + self._object = self._object.assign(**{label + "_" + band: band_counts[band] for band in bands}) else: counts = self._source.groupby([self._id_col])[self._band_col].aggregate("count") counts = counts.repartition(obj_npartitions) # counts inherits the source partitions - cols_to_add = {label + "_total": counts} - - self._object = self._object.assign(**cols_to_add) # assign new columns - - # remove the dataframes from memory - if by_band: - del band_counts - del counts + self._object = self._object.assign(**{label + "_total": counts}) # assign new columns return self From 1e51dac15aea962f2b2440b446635a0e8eefab28 Mon Sep 17 00:00:00 2001 From: Doug Branton Date: Mon, 2 Oct 2023 14:34:46 -0700 Subject: [PATCH 3/3] add calc_nobs --- src/tape/ensemble.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 8aae3c8b..5d871853 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -531,8 +531,6 @@ def calc_nobs(self, by_band=False, label="nobs"): obj_npartitions = self._object.npartitions # to repartition output columns if by_band: - counts = self._source.groupby([self._id_col, self._band_col])[self._time_col].aggregate("count") - band_counts = ( self._source.groupby([self._id_col])[self._band_col] # group by each object .value_counts() # count occurence of each band