Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Mar 27, 2024
1 parent dbd5b4c commit 4ee4856
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 14 deletions.
1 change: 0 additions & 1 deletion src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,6 @@ def bin_sources(

# Group the columns by id, band, and time bucket and aggregate.
result = self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)

# Fix the indices and remove the temporary column.
result = self.source._propagate_metadata(
result.reset_index().set_index(self._id_col).drop(columns=[tmp_time_col])
Expand Down
49 changes: 44 additions & 5 deletions src/tape/ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def map_partitions(self, func, *args, **kwargs):
return self._propagate_metadata(result)
elif isinstance(result, ObjectFrame):
result = self._propagate_metadata(result)
result.label = "object" # override the label
result.label = OBJECT_FRAME_LABEL # override the label
return result
elif isinstance(result, SourceFrame):
return self._propagate_metadata(result)

Check warning on line 741 in src/tape/ensemble_frame.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble_frame.py#L741

Added line #L741 was not covered by tests
Expand Down Expand Up @@ -949,11 +949,47 @@ def update_ensemble(self):

@classmethod
def from_dict(
cl, data, npartitions, orient="columns", dtype=None, columns=None, label=None, ensemble=None
cls, data, npartitions, orient="columns", dtype=None, columns=None, label=None, ensemble=None
):
""""""
"""
Construct a Tape EnsembleFrame from a Python Dictionary
Parameters
----------
data : dict
Of the form {field : array-like} or {field : dict}.
npartitions : int
The number of partitions of the index to create. Note that depending on
the size and index of the dataframe, the output may have fewer
partitions than requested.
orient : {'columns', 'index', 'tight'}, default 'columns'
The "orientation" of the data. If the keys of the passed dict
should be the columns of the resulting DataFrame, pass 'columns'
(default). Otherwise if the keys should be rows, pass 'index'.
If 'tight', assume a dict with keys
['index', 'columns', 'data', 'index_names', 'column_names'].
dtype : dtype, optional
Data type to force, otherwise infer.
columns : list, optional
Column labels to use when ``orient='index'``. Raises a ValueError
if used with ``orient='columns'`` or ``orient='tight'``.
label: `str`, optional
The label used by the Ensemble to identify the frame.
ensemble: `tape.ensemble.Ensemble`, optional
A link to the Ensemble object that owns this frame.
Returns
----------
result: `tape.EnsembleFrame`
The constructed EnsembleFrame object.
"""
result = from_dict(
data, npartitions=npartitions, orient=orient, dtype=dtype, columns=columns, constructor=TapeFrame
data,
npartitions=npartitions,
orient=orient,
dtype=dtype,
columns=columns,
constructor=cls._partition_type,
)
result.label = label
result.ensemble = ensemble
Expand Down Expand Up @@ -1298,6 +1334,9 @@ def from_dask_dataframe(cl, df, ensemble=None):
# The following should ensure that any Dask Dataframes which use TapeSeries or TapeFrames as their
# underlying data will be resolved as EnsembleFrames or EnsembleSeries as their parallel
# counterparts. The underlying Dask Dataframe _meta will be a TapeSeries or TapeFrame.
#
# Note that with the change to the dask-expr backend, the `get_collection_type` method
# is used to register instead of the previously used `get_parallel_type`.


get_collection_type.register(TapeSeries, lambda _: EnsembleSeries)
Expand Down Expand Up @@ -1342,7 +1381,7 @@ def make_meta_frame(x, index=None):


@meta_nonempty.register(TapeObjectFrame)
def _nonempty_tapesourceframe(x, index=None):
def _nonempty_tapeobjectframe(x, index=None):
# Construct a new TapeObjectFrame with the same underlying data.
df = meta_nonempty_dataframe(x)
return TapeObjectFrame(df)
Expand Down
2 changes: 1 addition & 1 deletion tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def test_insert(parquet_ensemble):
assert new_source.shape[0] == old_size + 10


def test_insert_paritioned():
def test_insert_partitioned():
ens = Ensemble()

# Create all fake source data with known divisions.
Expand Down
32 changes: 25 additions & 7 deletions tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,8 @@
from tape import (
Ensemble,
ColumnMapper,
# EnsembleFrame,
# ObjectFrame,
# SourceFrame,
# TapeObjectFrame,
# TapeSourceFrame,
# TapeFrame,
)

# from .ensemble_frame import (
from tape.ensemble_frame import (
EnsembleFrame,
EnsembleSeries,
Expand Down Expand Up @@ -53,6 +46,17 @@ def test_from_dict(data_fixture, request):
# inherited dask compute method must be called to obtain the result.
assert ens_frame.flux.max().compute() == 80.6

# Test SourceFrame Meta/Partition Typing
src_frame = SourceFrame.from_dict(data, npartitions=1)
assert isinstance(src_frame, SourceFrame)
assert isinstance(src_frame._meta, TapeSourceFrame)

# Test ObjectFrame Meta/Partition Typing
# use a dummy dict as the above is source data
obj_frame = ObjectFrame.from_dict({"a": [1]}, npartitions=1)
assert isinstance(obj_frame, ObjectFrame)
assert isinstance(obj_frame._meta, TapeObjectFrame)


@pytest.mark.parametrize(
"data_fixture",
Expand Down Expand Up @@ -344,6 +348,20 @@ def test_object_and_source_frame_propagation(data_fixture, request):
assert merged_frame.is_dirty()


def test_map_partitions_metadata(parquet_ensemble):
    """Check that ``map_partitions`` results carry the frame label and ensemble link."""
    # Work on copies of the ensemble's source and object frames.
    src = parquet_ensemble.source.copy()
    obj = parquet_ensemble.object.copy()

    # An identity function returns each partition as-is, so the assertions
    # below only concern the metadata attached to the result.
    src_res = src.map_partitions(lambda part: part)
    obj_res = obj.map_partitions(lambda part: part)

    # Each result must report the label expected for its frame type.
    assert src_res.label == SOURCE_LABEL
    assert obj_res.label == OBJECT_LABEL

    # Each result must still reference the ensemble it came from.
    assert src_res.ensemble == parquet_ensemble
    assert obj_res.ensemble == parquet_ensemble


def test_object_and_source_joins(parquet_ensemble):
"""
Test that SourceFrame and ObjectFrame metadata and class type are correctly propagated across
Expand Down

0 comments on commit 4ee4856

Please sign in to comment.