Skip to content
This repository has been archived by the owner on Jan 14, 2025. It is now read-only.

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Dec 15, 2023
1 parent 35fc80f commit 1a56e64
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 22 deletions.
34 changes: 12 additions & 22 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,6 @@ def s2n_inter_quartile_range(flux, err):
```
"""

print("In Batch")
self._lazy_sync_tables(table="all")

# Convert light-curve package feature into analysis function
Expand Down Expand Up @@ -1163,52 +1162,44 @@ def s2n_inter_quartile_range(flux, err):

# Output standardization
if isinstance(batch, EnsembleSeries):
print("EnsembleSeries")
if batch.name == self._id_col:
batch = batch.rename("result")
batch = EnsembleFrame.from_dask_dataframe(batch.to_frame())
print(type(batch))
if len(on) > 1:
batch = batch.reset_index()
# Need to overwrite the meta manually as the multiindex will be
# interpretted by dask as a single "index" column
batch._meta = TapeFrame(columns=on + ["result"])

if by_band:
batch = EnsembleFrame.from_dask_dataframe(
batch.categorize("band").pivot_table(
index=on[0], columns=self._band_col, aggfunc="sum"
)
batch = batch.categorize(self._band_col).pivot_table(
index=on[0], columns=self._band_col, aggfunc="sum"
)

# Need to once again reestablish meta for the pivot
band_labels = batch.columns.values

out_cols = []
for col in ["result"]:
for band in band_labels:
out_cols += [(str(col), str(band))]
batch._meta = TapeFrame(columns=out_cols)

# Flatten the columns to a new column per band
batch.columns = ["_".join(col).strip() for col in batch.columns.values]

# The pivot returns a dask dataframe, need to convert back
batch = EnsembleFrame.from_dask_dataframe(batch)

else:
batch = batch.set_index(on[0], sort=False)

elif isinstance(batch, EnsembleFrame):
print("EnsembleFrame")
if len(on) > 1:
res_cols = list(batch._meta.columns)
print(isinstance(batch, EnsembleFrame))
batch = batch.reset_index()
batch._meta = TapeFrame(columns=on + res_cols)
print(isinstance(batch, EnsembleFrame))
if by_band:
batch = batch.categorize("band")
print(isinstance(batch, EnsembleFrame))
batch = EnsembleFrame.from_dask_dataframe(
batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum")
)
print(isinstance(batch, EnsembleFrame))
batch = batch.categorize(self._band_col)
batch = batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum")

# Need to once again reestablish meta for the pivot
band_labels = batch.columns.values
Expand All @@ -1218,15 +1209,15 @@ def s2n_inter_quartile_range(flux, err):
for band in band_labels:
out_cols += [(str(col), str(band))]
batch._meta = TapeFrame(columns=out_cols)
print(isinstance(batch, EnsembleFrame))

# Flatten the columns to a new column per band
batch.columns = ["_".join(col).strip() for col in batch.columns.values]
print(isinstance(batch, EnsembleFrame))

# The pivot returns a dask dataframe, need to convert back
batch = EnsembleFrame.from_dask_dataframe(batch)

else:
batch = batch.set_index(on[0], sort=False)
print(isinstance(batch, EnsembleFrame))

# Inherit divisions if known from source and the resulting index is the id
# Groupby on index should always return a subset that adheres to the same divisions criteria
Expand All @@ -1240,7 +1231,6 @@ def s2n_inter_quartile_range(flux, err):
# Track the result frame under the provided label
self.add_frame(batch, label)

print(isinstance(batch, EnsembleFrame))
return batch

def from_pandas(
Expand Down
69 changes: 69 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,75 @@ def test_batch(data_fixture, request, use_map, on):
assert pytest.approx(result.values[1]["r"], 0.001) == -0.49639028


@pytest.mark.parametrize("on", [None, ["ps1_objid", "filterName"]])
@pytest.mark.parametrize("func_label", ["mean", "bounds"])
def test_batch_by_band(parquet_ensemble, func_label, on):
"""
Test that ensemble.batch(by_band=True) works as intended.
"""

if func_label == "mean":

def my_mean(flux):
"""returns a single value"""
return np.mean(flux)

res = parquet_ensemble.batch(my_mean, parquet_ensemble._flux_col, on=on, by_band=True)

parquet_ensemble.source.query(f"{parquet_ensemble._band_col}=='g'").update_ensemble()
filter_res = parquet_ensemble.batch(my_mean, parquet_ensemble._flux_col, on=on, by_band=False)

# An EnsembleFrame should be returned
assert isinstance(res, EnsembleFrame)

# Make sure we get all the expected columns
assert all([col in res.columns for col in ["result_g", "result_r"]])

# These should be equivalent
assert (
res.loc[88472935274829959]["result_g"]
.compute()
.equals(filter_res.loc[88472935274829959]["result"].compute())
)

elif func_label == "bounds":

def my_bounds(flux):
"""returns a series"""
return pd.Series({"min": np.min(flux), "max": np.max(flux)})

res = parquet_ensemble.batch(
my_bounds, "psFlux", on=on, by_band=True, meta={"min": float, "max": float}
)

parquet_ensemble.source.query(f"{parquet_ensemble._band_col}=='g'").update_ensemble()
filter_res = parquet_ensemble.batch(
my_bounds, "psFlux", on=on, by_band=False, meta={"min": float, "max": float}
)

# An EnsembleFrame should be returned
assert isinstance(res, EnsembleFrame)

# Make sure we get all the expected columns
assert all([col in res.columns for col in ["max_g", "max_r", "min_g", "min_r"]])

# These should be equivalent
assert (
res.loc[88472935274829959]["max_g"]
.compute()
.equals(filter_res.loc[88472935274829959]["max"].compute())
)
assert (
res.loc[88472935274829959]["min_g"]
.compute()
.equals(filter_res.loc[88472935274829959]["min"].compute())
)

# Meta should reflect the actual columns, this can get out of sync
# whenever multi-indexes are involved, which batch tries to handle
assert all([col in res.columns for col in res.compute().columns])


def test_batch_labels(parquet_ensemble):
"""
Test that ensemble.batch() generates unique labels for result frames when none are provided.
Expand Down

0 comments on commit 1a56e64

Please sign in to comment.