Skip to content

Commit

Permalink
batch_by_band fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Mar 18, 2024
1 parent 9a2e7ef commit da39b86
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 39 deletions.
23 changes: 18 additions & 5 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,23 +1212,36 @@ def _standardize_batch(self, batch, on, by_band):

# Need to overwrite the meta manually as the multiindex will be
# interpreted by dask as a single "index" column
batch._meta = TapeFrame(columns=on + res_cols)

# [expr] added map_partitions meta assignment
#batch._meta = TapeFrame(columns=on + res_cols)
batch = batch.map_partitions(TapeFrame, meta = TapeFrame(columns=on + res_cols))

# Further reformatting for per-band results
# Pivots on the band column to generate a result column for each
# photometric band.
if by_band:
batch = batch.categorize(self._band_col)
batch = batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum")

# [expr] added values
#import pdb;pdb.set_trace()
col_values = [col for col in batch.columns if col not in [on[0], self._band_col]]
batch = batch.pivot_table(index=on[0], columns=self._band_col, values=col_values, aggfunc="sum")

# Need to once again reestablish meta for the pivot
band_labels = batch.columns.values
out_cols = []
# To align with pandas pivot_table results, the columns should be generated in reverse order
for col in res_cols[::-1]:
for band in band_labels:
out_cols += [(str(col), str(band))]
batch._meta = TapeFrame(columns=out_cols) # apply new meta
# [expr] adjusted labeling
#out_cols += [(str(col), str(band))]
out_cols += [(str(band[0]), str(band[1]))]

#import pdb; pdb.set_trace()
# [expr] added map_partitions meta assignment
#batch._meta = TapeFrame(columns=out_cols) # apply new meta
#apply new meta
batch = batch.map_partitions(TapeFrame, meta = TapeFrame(columns=band_labels))

# Flatten the columns to a new column per band
batch.columns = ["_".join(col) for col in batch.columns.values]
Expand Down
File renamed without changes.
64 changes: 33 additions & 31 deletions src/tape/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
elemwise,
from_graph,
get_collection_type,
from_dict,
#from_dict,
)
from dask_expr._collection import new_collection
from dask_expr._collection import new_collection, from_dict
from dask_expr._expr import _emulate, ApplyConcatApply

from .ensemble_frame import TapeFrame, TapeSeries
#from .ensemble_frame import TapeFrame, TapeSeries

SOURCE_FRAME_LABEL = "source" # Reserved label for source table
OBJECT_FRAME_LABEL = "object" # Reserved label for object table.
Expand All @@ -53,6 +53,35 @@
ArrowDatasetEngine as DaskArrowDatasetEngine,
)

class TapeSeries(pd.Series):
    """A barebones extension of a Pandas series to be used for underlying Ensemble data.

    The constructor hooks below tell pandas which class to build whenever an
    operation produces a new object, so results keep the Tape types instead of
    decaying to plain pandas ones.

    See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    @property
    def _constructor(self):
        """Class used for results of the same dimensionality (a series)."""
        return TapeSeries

    @property
    def _constructor_expanddim(self):
        """Class used when a series is promoted to a frame, e.g. ``to_frame()``."""
        return TapeFrame

    @property
    def _constructor_sliced(self):
        # pandas documents this hook for frames only; it is kept here so any
        # existing attribute access on a series keeps returning TapeSeries.
        return TapeSeries


class TapeFrame(pd.DataFrame):
    """A barebones extension of a Pandas frame to be used for underlying Ensemble data.

    The constructor hooks below tell pandas which class to build whenever an
    operation produces a new object, so results keep the Tape types instead of
    decaying to plain pandas ones.

    See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    @property
    def _constructor(self):
        """Class used for results of the same dimensionality (a frame)."""
        return TapeFrame

    @property
    def _constructor_sliced(self):
        """Class used when a frame is reduced to a series, e.g. ``frame["col"]``."""
        return TapeSeries

    @property
    def _constructor_expanddim(self):
        # pandas documents this hook for series only; it is kept here so any
        # existing attribute access on a frame keeps returning TapeFrame.
        return TapeFrame


class TapeArrowEngine(DaskArrowDatasetEngine):
"""
Expand Down Expand Up @@ -840,34 +869,7 @@ def repartition(
return self._propagate_metadata(result)


class TapeSeries(pd.Series):
    """A barebones extension of a Pandas series to be used for underlying Ensemble data.

    Each constructor hook names the class pandas should instantiate for the
    result of an operation, which is what lets Tape types survive slicing,
    filtering and reshaping.

    See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    @property
    def _constructor(self):
        """Return the class for same-dimensional (series) results."""
        return TapeSeries

    @property
    def _constructor_expanddim(self):
        """Return the class for series-to-frame promotion such as ``to_frame()``."""
        return TapeFrame

    @property
    def _constructor_sliced(self):
        # Documented by pandas as a frame-only hook; retained on the series so
        # that code reading this attribute still gets TapeSeries.
        return TapeSeries


class TapeFrame(pd.DataFrame):
    """A barebones extension of a Pandas frame to be used for underlying Ensemble data.

    Each constructor hook names the class pandas should instantiate for the
    result of an operation, which is what lets Tape types survive slicing,
    filtering and reshaping.

    See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    @property
    def _constructor(self):
        """Return the class for same-dimensional (frame) results."""
        return TapeFrame

    @property
    def _constructor_sliced(self):
        """Return the class for frame-to-series reduction such as column selection."""
        return TapeSeries

    @property
    def _constructor_expanddim(self):
        # Documented by pandas as a series-only hook; retained on the frame so
        # that code reading this attribute still gets TapeFrame.
        return TapeFrame


class EnsembleSeries(_Frame, dd.Series):
Expand Down Expand Up @@ -976,7 +978,7 @@ def from_dict(
cl, data, npartitions, orient="columns", dtype=None, columns=None, label=None, ensemble=None
):
""""""
result = from_dict(data, npartitions=npartitions, orient=orient, dtype=dtype, columns=columns)
result = from_dict(data, npartitions=npartitions, orient=orient, dtype=dtype, columns=columns, constructor=TapeFrame)
result.label = label
result.ensemble = ensemble
return result
Expand Down
9 changes: 9 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,10 +1986,15 @@ def my_mean(flux):
# An EnsembleFrame should be returned
assert isinstance(res, EnsembleFrame)


#import pdb; pdb.set_trace()

# Make sure we get all the expected columns
assert all([col in res.columns for col in ["result_g", "result_r"]])

# These should be equivalent
# [expr] need this TODO: investigate typing issue
filter_res.index = filter_res.index.astype("int")
assert (
res.loc[88472935274829959]["result_g"]
.compute()
Expand Down Expand Up @@ -2018,6 +2023,10 @@ def my_bounds(flux):
assert all([col in res.columns for col in ["max_g", "max_r", "min_g", "min_r"]])

# These should be equivalent

# [expr] need this TODO: investigate typing issue
filter_res.index = filter_res.index.astype("int")

assert (
res.loc[88472935274829959]["max_g"]
.compute()
Expand Down
3 changes: 0 additions & 3 deletions tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,6 @@ def test_convert_flux_to_mag(data_fixture, request, err_col, zp_form, out_col_na
ens_frame.label = TEST_LABEL
ens_frame.ensemble = ens

print(type(ens_frame))
assert False

if zp_form == "flux":
ens_frame = ens_frame.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name)

Expand Down

0 comments on commit da39b86

Please sign in to comment.