diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 0bf83b91..5506e676 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1095,7 +1095,6 @@ def s2n_inter_quartile_range(flux, err): ``` """ - print("In Batch") self._lazy_sync_tables(table="all") # Convert light-curve package feature into analysis function @@ -1163,52 +1162,44 @@ def s2n_inter_quartile_range(flux, err): # Output standardization if isinstance(batch, EnsembleSeries): - print("EnsembleSeries") if batch.name == self._id_col: batch = batch.rename("result") batch = EnsembleFrame.from_dask_dataframe(batch.to_frame()) - print(type(batch)) if len(on) > 1: batch = batch.reset_index() # Need to overwrite the meta manually as the multiindex will be # interpretted by dask as a single "index" column batch._meta = TapeFrame(columns=on + ["result"]) + if by_band: - batch = EnsembleFrame.from_dask_dataframe( - batch.categorize("band").pivot_table( - index=on[0], columns=self._band_col, aggfunc="sum" - ) + batch = batch.categorize(self._band_col).pivot_table( + index=on[0], columns=self._band_col, aggfunc="sum" ) # Need to once again reestablish meta for the pivot band_labels = batch.columns.values - out_cols = [] for col in ["result"]: for band in band_labels: out_cols += [(str(col), str(band))] batch._meta = TapeFrame(columns=out_cols) - # Flatten the columns to a new column per band batch.columns = ["_".join(col).strip() for col in batch.columns.values] + + # The pivot returns a dask dataframe, need to convert back + batch = EnsembleFrame.from_dask_dataframe(batch) + else: batch = batch.set_index(on[0], sort=False) elif isinstance(batch, EnsembleFrame): - print("EnsembleFrame") if len(on) > 1: res_cols = list(batch._meta.columns) - print(isinstance(batch, EnsembleFrame)) batch = batch.reset_index() batch._meta = TapeFrame(columns=on + res_cols) - print(isinstance(batch, EnsembleFrame)) if by_band: - batch = batch.categorize("band") - print(isinstance(batch, EnsembleFrame)) - batch = EnsembleFrame.from_dask_dataframe( - batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum") - ) - print(isinstance(batch, EnsembleFrame)) + batch = batch.categorize(self._band_col) + batch = batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum") # Need to once again reestablish meta for the pivot band_labels = batch.columns.values @@ -1218,15 +1209,15 @@ def s2n_inter_quartile_range(flux, err): for band in band_labels: out_cols += [(str(col), str(band))] batch._meta = TapeFrame(columns=out_cols) - print(isinstance(batch, EnsembleFrame)) # Flatten the columns to a new column per band batch.columns = ["_".join(col).strip() for col in batch.columns.values] - print(isinstance(batch, EnsembleFrame)) + + # The pivot returns a dask dataframe, need to convert back + batch = EnsembleFrame.from_dask_dataframe(batch) else: batch = batch.set_index(on[0], sort=False) - print(isinstance(batch, EnsembleFrame)) # Inherit divisions if known from source and the resulting index is the id # Groupby on index should always return a subset that adheres to the same divisions criteria @@ -1240,7 +1231,6 @@ def s2n_inter_quartile_range(flux, err): # Track the result frame under the provided label self.add_frame(batch, label) - print(isinstance(batch, EnsembleFrame)) return batch def from_pandas( diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index ef151271..a576828f 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -1653,6 +1653,75 @@ def test_batch(data_fixture, request, use_map, on): assert pytest.approx(result.values[1]["r"], 0.001) == -0.49639028 +@pytest.mark.parametrize("on", [None, ["ps1_objid", "filterName"]]) +@pytest.mark.parametrize("func_label", ["mean", "bounds"]) +def test_batch_by_band(parquet_ensemble, func_label, on): + """ + Test that ensemble.batch(by_band=True) works as intended. + """ + + if func_label == "mean": + + def my_mean(flux): + """returns a single value""" + return np.mean(flux) + + res = parquet_ensemble.batch(my_mean, parquet_ensemble._flux_col, on=on, by_band=True) + + parquet_ensemble.source.query(f"{parquet_ensemble._band_col}=='g'").update_ensemble() + filter_res = parquet_ensemble.batch(my_mean, parquet_ensemble._flux_col, on=on, by_band=False) + + # An EnsembleFrame should be returned + assert isinstance(res, EnsembleFrame) + + # Make sure we get all the expected columns + assert all([col in res.columns for col in ["result_g", "result_r"]]) + + # These should be equivalent + assert ( + res.loc[88472935274829959]["result_g"] + .compute() + .equals(filter_res.loc[88472935274829959]["result"].compute()) + ) + + elif func_label == "bounds": + + def my_bounds(flux): + """returns a series""" + return pd.Series({"min": np.min(flux), "max": np.max(flux)}) + + res = parquet_ensemble.batch( + my_bounds, "psFlux", on=on, by_band=True, meta={"min": float, "max": float} + ) + + parquet_ensemble.source.query(f"{parquet_ensemble._band_col}=='g'").update_ensemble() + filter_res = parquet_ensemble.batch( + my_bounds, "psFlux", on=on, by_band=False, meta={"min": float, "max": float} + ) + + # An EnsembleFrame should be returned + assert isinstance(res, EnsembleFrame) + + # Make sure we get all the expected columns + assert all([col in res.columns for col in ["max_g", "max_r", "min_g", "min_r"]]) + + # These should be equivalent + assert ( + res.loc[88472935274829959]["max_g"] + .compute() + .equals(filter_res.loc[88472935274829959]["max"].compute()) + ) + assert ( + res.loc[88472935274829959]["min_g"] + .compute() + .equals(filter_res.loc[88472935274829959]["min"].compute()) + ) + + # Meta should reflect the actual columns, this can get out of sync + # whenever multi-indexes are involved, which batch tries to handle + assert all([col in res.columns for col in res.compute().columns]) + + def test_batch_labels(parquet_ensemble): """ Test that ensemble.batch() generates unique labels for result frames when none are provided.