Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A minimal Dask Dataframe subclass for the Ensemble #209

Merged
merged 2 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tape/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .analysis import * # noqa
from .ensemble import * # noqa
from .ensemble_frame import * # noqa
from .timeseries import * # noqa
173 changes: 173 additions & 0 deletions src/tape/ensemble_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import dask.dataframe as dd

import dask
from dask.dataframe.dispatch import make_meta_dispatch
from dask.dataframe.backends import _nonempty_index, meta_nonempty, meta_nonempty_dataframe

from dask.dataframe.core import get_parallel_type
from dask.dataframe.extensions import make_array_nonempty

import pandas as pd

class _Frame(dd.core._Frame):
    """Shared base for Dask DataFrame extensions that carry Ensemble-related metadata."""

    def __init__(self, dsk, name, meta, divisions, label=None, ensemble=None):
        super().__init__(dsk, name, meta, divisions)
        # Label used by the Ensemble to identify this frame.
        self.label = label
        # The Ensemble object containing this frame.
        self.ensemble = ensemble

    @property
    def _args(self):
        # Extend the parent's argument tuple with our extra constructor
        # arguments so that the extension can correctly be used by pickle.
        # See https://github.com/geopandas/dask-geopandas/issues/237
        base_args = super()._args
        return base_args + (self.label, self.ensemble)

    def _propagate_metadata(self, new_frame):
        """Propagate any relevant metadata to a new frame.

        Parameters
        ----------
        new_frame: `_Frame`
            A frame to propagate metadata to. It is modified in place.

        Returns
        ----------
        new_frame: `_Frame`
            The same frame object, now carrying this frame's label and ensemble.
        """
        new_frame.label = self.label
        new_frame.ensemble = self.ensemble
        return new_frame

    def copy(self):
        """Return a copy of this frame that keeps the same label and ensemble."""
        return self._propagate_metadata(super().copy())

class TapeSeries(pd.Series):
    """A barebones extension of a Pandas series to be used for underlying Ensemble data.

    See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    @property
    def _constructor(self):
        # Same-dimension results (slices, arithmetic, ...) stay TapeSeries.
        return TapeSeries

    @property
    def _constructor_expanddim(self):
        # Per the pandas subclassing guide, this is the hook a Series subclass
        # uses when an operation produces a two-dimensional result (for example
        # `to_frame()`), so that the result is a TapeFrame.
        return TapeFrame

    @property
    def _constructor_sliced(self):
        # Kept for backward compatibility with the original class definition.
        # NOTE(review): pandas documents this hook for DataFrame subclasses only;
        # it is presumably never consulted on a Series - confirm before removing.
        return TapeSeries

class TapeFrame(pd.DataFrame):
    """Minimal Pandas DataFrame subclass used as the underlying container for Ensemble data.

    Follows the pandas subclassing guide:
    https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures
    """

    # NOTE(review): the pandas guide names `_constructor_sliced` as the DataFrame
    # hook for one-dimensional results; this class does not define it, so column
    # selection presumably yields a plain pandas Series - confirm that is intended.

    @property
    def _constructor(self):
        # Results with the same dimensionality are rebuilt as TapeFrame.
        return TapeFrame

    @property
    def _constructor_expanddim(self):
        # Results with expanded dimensionality are also built as TapeFrame.
        return TapeFrame


class EnsembleSeries(_Frame, dd.core.Series):
    """Minimal Dask Series extension for Ensemble data."""

    # Type of the underlying (non-parallel) data held by this series.
    _partition_type = TapeSeries

class EnsembleFrame(_Frame, dd.core.DataFrame):
    """A Dask Dataframe extension for data used by a lightcurve Ensemble.

    The underlying non-parallel dataframes are TapeFrames, which extend Pandas frames.

    Example
    ----------
    import tape
    ens = tape.Ensemble()
    frame = tape.TapeFrame({...})  # Some data you want tracked by the Ensemble
    ensemble_frame = tape.EnsembleFrame.from_tapeframe(
        frame, npartitions=1, label="my_frame", ensemble=ens
    )
    """

    # Type of the underlying (non-parallel) data held by this frame.
    _partition_type = TapeFrame

    def __getitem__(self, key):
        """Select from the frame, carrying label/ensemble over to frame-like results."""
        selection = super().__getitem__(key)
        if not isinstance(selection, _Frame):
            return selection
        # Ensure the selection keeps this frame's label and ensemble.
        return self._propagate_metadata(selection)

    @classmethod
    def from_tapeframe(
        cls, data, npartitions=None, chunksize=None, sort=True, label=None, ensemble=None
    ):
        """Return an EnsembleFrame constructed from a TapeFrame.

        Parameters
        ----------
        data: `TapeFrame`
            Frame containing the underlying data for the EnsembleFrame.
        npartitions: `int`, optional
            The number of partitions of the index to create. Note that depending on
            the size and index of the dataframe, the output may have fewer
            partitions than requested.
        chunksize: `int`, optional
            Size of the individual chunks of data in non-parallel objects that make up Dask frames.
        sort: `bool`, optional
            Whether to sort the frame by a default index. Forwarded to `dd.from_pandas`.
        label: `str`, optional
            The label used by the Ensemble to identify the frame.
        ensemble: `tape.Ensemble`, optional
            A link to the Ensemble object that owns this frame.

        Returns
        ----------
        result: `tape.EnsembleFrame`
            The constructed EnsembleFrame object.
        """
        frame = dd.from_pandas(data, npartitions=npartitions, chunksize=chunksize, sort=sort)
        frame.label = label
        frame.ensemble = ensemble
        return frame
"""
Dask Dataframes are constructed indirectly using method dispatching and inference on the
underlying data. So to ensure our subclasses behave correctly, we register the methods
below.

For more information, see https://docs.dask.org/en/latest/dataframe-extend.html

The following should ensure that any Dask Dataframes which use TapeSeries or TapeFrames as their
underlying data will be resolved as EnsembleFrames or EnsembleSeries as their parallel
counterparts. The underlying Dask Dataframe _meta will be a TapeSeries or TapeFrame.
"""
@get_parallel_type.register(TapeSeries)
def _parallel_type_for_tapeseries(_):
    # A TapeSeries is parallelized as an EnsembleSeries.
    return EnsembleSeries


@get_parallel_type.register(TapeFrame)
def _parallel_type_for_tapeframe(_):
    # A TapeFrame is parallelized as an EnsembleFrame.
    return EnsembleFrame

@make_meta_dispatch.register(TapeSeries)
def make_meta_series(x, index=None):
    """Return an empty TapeSeries derived from `x` for use as Dask's meta object.

    Parameters
    ----------
    x: `TapeSeries`
        The series to derive the meta from.
    index: optional
        If given, the empty result is re-indexed against an empty slice of it.

    Returns
    ----------
    result: `TapeSeries`
        A zero-length series obtained from `x.head(0)`.
    """
    empty = x.head(0)
    if index is None:
        return empty
    # Re-index against a zero-length slice of the requested index.
    return empty.reindex(index[:0])

@make_meta_dispatch.register(TapeFrame)
def make_meta_frame(x, index=None):
    """Return an empty TapeFrame derived from `x` for use as Dask's meta object.

    Parameters
    ----------
    x: `TapeFrame`
        The frame to derive the meta from.
    index: optional
        If given, the empty result is re-indexed against an empty slice of it.

    Returns
    ----------
    result: `TapeFrame`
        A zero-length frame obtained from `x.head(0)`.
    """
    empty = x.head(0)
    if index is None:
        return empty
    # Re-index against a zero-length slice of the requested index.
    return empty.reindex(index[:0])

@meta_nonempty.register(TapeSeries)
def _nonempty_tapeseries(x, index=None):
    """Construct a non-empty TapeSeries with the same dtype and name as `x`.

    Parameters
    ----------
    x: `TapeSeries`
        The (typically empty) meta series to mimic.
    index: optional
        Index for the result. When omitted, a non-empty index is derived
        from `x.index`.

    Returns
    ----------
    result: `TapeSeries`
        A series of placeholder values built from `x.dtype`.
    """
    if index is None:
        index = _nonempty_index(x.index)
    data = make_array_nonempty(x.dtype)
    # TapeSeries defines no extra attributes (there is no GeoSeries-style `crs`),
    # so only the name and the index are forwarded to the constructor.
    return TapeSeries(data, name=x.name, index=index)

@meta_nonempty.register(TapeFrame)
def _nonempty_tapeframe(x, index=None):
    """Construct a non-empty TapeFrame with the same structure as `x`.

    Parameters
    ----------
    x: `TapeFrame`
        The (typically empty) meta frame to mimic.
    index: optional
        Accepted for signature compatibility; not used.

    Returns
    ----------
    result: `TapeFrame`
        The output of `meta_nonempty_dataframe(x)` wrapped as a TapeFrame.
    """
    # Named distinctly from the TapeSeries handler so that the two module-level
    # functions do not shadow each other.
    return TapeFrame(meta_nonempty_dataframe(x))
20 changes: 20 additions & 0 deletions tests/tape_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,23 @@ def parquet_ensemble_from_hipscat(dask_client):
)

return ens

# pylint: disable=redefined-outer-name
@pytest.fixture
def ensemble_from_source_dict(dask_client):
    """Build an Ensemble from a source dict and return both the ensemble and the dict."""
    ensemble = Ensemble(client=dask_client)

    # Fake data: two object IDs (8001, 8002), two bands ("g", "b"),
    # a handful of time steps, plus flux and error values.
    rows = {
        "id": [8001, 8001, 8001, 8001, 8002, 8002, 8002, 8002, 8002],
        "time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
        "band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
        "err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0],
    }
    column_mapper = ColumnMapper(
        id_col="id",
        time_col="time",
        flux_col="flux",
        err_col="err",
        band_col="band",
    )
    ensemble.from_source_dict(rows, column_mapper=column_mapper)

    return ensemble, rows
105 changes: 105 additions & 0 deletions tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
""" Test EnsembleFrame (inherited from Dask.DataFrame) creation and manipulations. """
import pandas as pd
from tape import Ensemble, EnsembleFrame, TapeFrame

import pytest

TEST_LABEL = "test_frame"

# pylint: disable=protected-access
@pytest.mark.parametrize("data_fixture", ["ensemble_from_source_dict"])
def test_from_dict(data_fixture, request):
    """
    Build an EnsembleFrame from a dictionary and check that Dask's lazy evaluation is inherited.
    """
    fixture_value = request.getfixturevalue(data_fixture)
    source_data = fixture_value[1]  # The fixture yields (ensemble, source dict).
    frame = EnsembleFrame.from_dict(source_data, npartitions=1)

    assert isinstance(frame, EnsembleFrame)
    assert isinstance(frame._meta, TapeFrame)

    # Finding the max flux is lazy: the inherited dask `compute` method
    # must be called to obtain the actual value.
    max_flux = frame.flux.max()
    assert max_flux.compute() == 5.0

@pytest.mark.parametrize("data_fixture", ["ensemble_from_source_dict"])
def test_from_pandas(data_fixture, request):
    """
    Build an EnsembleFrame from a TapeFrame and check that Dask's lazy evaluation is inherited.
    """
    ensemble, source_data = request.getfixturevalue(data_fixture)
    tape_frame = TapeFrame(source_data)
    frame = EnsembleFrame.from_tapeframe(
        tape_frame,
        label=TEST_LABEL,
        ensemble=ensemble,
        npartitions=1,
    )

    assert isinstance(frame, EnsembleFrame)
    assert isinstance(frame._meta, TapeFrame)
    assert frame.label == TEST_LABEL
    assert frame.ensemble is ensemble

    # Finding the max flux is lazy: the inherited dask `compute` method
    # must be called to obtain the actual value.
    max_flux = frame.flux.max()
    assert max_flux.compute() == 5.0


@pytest.mark.parametrize(
    "data_fixture",
    [
        "ensemble_from_source_dict",
    ],
)
def test_frame_propagation(data_fixture, request):
    """
    Test ensuring that slices and copies of an EnsembleFrame are still the same class.
    """
    ens, data = request.getfixturevalue(data_fixture)
    ens_frame = EnsembleFrame.from_dict(data, npartitions=1)
    # Set a label and ensemble for the frame so we can check that
    # copies/selections retain them.
    ens_frame.label = TEST_LABEL
    ens_frame.ensemble = ens

    # Create a copy of an EnsembleFrame and verify that it's still a proper
    # EnsembleFrame with appropriate metadata propagated.
    copied_frame = ens_frame.copy()
    assert isinstance(copied_frame, EnsembleFrame)
    assert isinstance(copied_frame._meta, TapeFrame)
    assert copied_frame.label == TEST_LABEL
    assert copied_frame.ensemble is ens

    # Test that a filtered EnsembleFrame is still an EnsembleFrame.
    filtered_frame = ens_frame[["id", "time"]]
    assert isinstance(filtered_frame, EnsembleFrame)
    assert isinstance(filtered_frame._meta, TapeFrame)
    assert filtered_frame.label == TEST_LABEL
    assert filtered_frame.ensemble is ens

    # Test that the output of an EnsembleFrame query is still an EnsembleFrame.
    # The meta check targets the query result itself rather than the
    # previously filtered frame.
    queried_rows = ens_frame.query("flux > 3.0")
    assert isinstance(queried_rows, EnsembleFrame)
    assert isinstance(queried_rows._meta, TapeFrame)
    # NOTE(review): label/ensemble are not asserted on the query result; whether
    # `query` is expected to carry that metadata over should be confirmed.

    # Test that head returns a subset of the underlying TapeFrame.
    h = ens_frame.head(5)
    assert isinstance(h, TapeFrame)
    assert len(h) == 5

    # Test that the inherited dask.DataFrame.compute method returns
    # the underlying TapeFrame.
    assert isinstance(ens_frame.compute(), TapeFrame)
    assert len(ens_frame) == len(ens_frame.compute())