Ensemble Methods to retrieve a data subset #361

Merged
merged 15 commits on Jan 31, 2024
65 changes: 63 additions & 2 deletions docs/tutorials/working_with_the_ensemble.ipynb
@@ -90,6 +90,7 @@
" err_col=\"error\",\n",
" band_col=\"band\",\n",
" npartitions=1,\n",
" sort=True,\n",
")"
]
},
@@ -130,7 +131,9 @@
")\n",
"\n",
"# Pass the ColumnMapper along to from_pandas\n",
"ens.from_pandas(source_frame=source_table, object_frame=object_table, column_mapper=col_map, npartitions=1)"
"ens.from_pandas(\n",
" source_frame=source_table, object_frame=object_table, column_mapper=col_map, npartitions=1, sort=True\n",
")"
]
},
{
@@ -201,10 +204,11 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspection, Filtering, and Selecting\n",
"## Inspection and Filtering\n",
"\n",
"The `Ensemble` contains an assortment of functions for inspecting and filtering your data."
]
@@ -290,6 +294,40 @@
"ens.source.compute()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Several methods exist to access individual lightcurves within the `Ensemble`. First of which is the `to_timeseries` function. This allows you to supply a given object ID, and returns a `TimeSeries` object (see <working_with_the_timeseries>)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ens.to_timeseries(8003).data"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, if you aren't interested in a particular lightcurve, you can draw a random one from the `Ensemble` using `Ensemble.select_random_timeseries`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ens.select_random_timeseries(seed=1).data"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -397,6 +435,29 @@
"In the above operations, we remove any rows that have at least 1 NaN value present. And then filter such that only lightcurves which have at least 50 measurements are retained."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sampling\n",
"\n",
"In addition to filtering by specific constraints, it's possible to select a subset of your data to work with. `Ensemble.sample` will randomly select a fraction of objects from the full object list, returning a new `Ensemble` that contains only the sampled objects and their associated sources."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subset_ens = ens.sample(frac=0.5) # select ~half of the objects\n",
"\n",
"print(\"Number of pre-sampled objects: \", len(ens.object))\n",
"print(\"Number of post-sampled objects: \", len(subset_ens.object))"
]
},
{
"cell_type": "markdown",
"metadata": {},
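Taken together, the notebook changes above exercise three ways of retrieving a data subset from an `Ensemble`. A condensed sketch of the tutorial usage (the object ID 8003 and the seed are taken from the tutorial's toy dataset):

```python
# Condensed from the tutorial cells above; assumes `ens` is the tutorial Ensemble.
ts = ens.to_timeseries(8003)                    # one lightcurve, by object ID
rand_ts = ens.select_random_timeseries(seed=1)  # one lightcurve, drawn at random
subset_ens = ens.sample(frac=0.5)               # a new Ensemble with ~half the objects
```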
59 changes: 59 additions & 0 deletions src/tape/ensemble.py
@@ -472,6 +472,65 @@ def persist(self, **kwargs):
self.update_frame(self.object.persist(**kwargs))
self.update_frame(self.source.persist(**kwargs))

def sample(self, frac=None, replace=False, random_state=None):
"""Selects a random sample of objects (sampling each partition).

This sampling will be lazily applied to the SourceFrame as well. A new
Ensemble object is created, and no additional EnsembleFrames will be
carried into the new Ensemble object. Most of this docstring is copied from
https://docs.dask.org/en/latest/generated/dask.dataframe.DataFrame.sample.html.

Parameters
----------
frac: float, optional
Approximate fraction of objects to return. This sampling fraction
is applied to all partitions equally. Note that this is an
approximate fraction. You should not expect exactly len(df) * frac
items to be returned, as the exact number of elements selected will
depend on how your data is partitioned (but should be pretty close
in practice).
replace: boolean, optional
Sample with or without replacement. Default = False.
random_state: int or np.random.RandomState
If an int, we create a new RandomState with this as the seed;
Otherwise we draw from the passed RandomState.

Returns
-------
ensemble: `tape.ensemble.Ensemble`
A new ensemble with the subset of data selected

"""

# first do an object sync, ensure object table is up to date
self._lazy_sync_tables(table="object")

# sample on the object table
object_subset = self.object.sample(frac=frac, replace=replace, random_state=random_state)

# make a new ensemble
if self.client is not None:
new_ens = Ensemble(client=self.client)

# turn off cleanups -- in the case where multiple ensembles are
# using a client, an individual ensemble should not close the
# client during an __exit__ or __del__ event. This means that
# the client will not be closed without an explicit client.close()
# call, which is unfortunate... not sure of an alternative way
# forward.
Comment on lines +515 to +520

wilsonbb (Collaborator) commented on Jan 31, 2024:

I feel there is a path for some ugly answers to this (implement our own client manager with reference counting, each ensemble keeps track of its parents/children in a tree-like structure that gets updated on-exit, etc), but not sure we need to solve this now.

dougbrn (Collaborator, Author) replied:

yeah, agree that this is something to keep an eye on, but feels out of scope for this PR, I'll make an issue
self.cleanup_client = False
new_ens.cleanup_client = False
else:
new_ens = Ensemble(client=False)

new_ens.update_frame(object_subset)
new_ens.update_frame(self.source.copy())

# sync to source, removes all tied sources
new_ens._lazy_sync_tables(table="source")

return new_ens

def columns(self, table="object"):
"""Retrieve columns from dask dataframe"""
if table == "object":
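For reference, a minimal usage sketch of the new `Ensemble.sample` (assuming an already-populated ensemble `ens`; the client-handling caveat mirrors the comments in the diff above):

```python
# Sample ~30% of objects into a new Ensemble; the source table is pruned
# lazily on the next sync to keep only sources tied to sampled objects.
subset_ens = ens.sample(frac=0.3, random_state=42)

print("objects before:", len(ens.object), "objects after:", len(subset_ens.object))

# Both ensembles now share one Dask client with cleanup disabled, so the
# client must be closed explicitly once you're done.
if ens.client is not None:
    ens.client.close()
```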
55 changes: 54 additions & 1 deletion src/tape/ensemble_frame.py
@@ -7,7 +7,7 @@
import dask
from dask.dataframe.dispatch import make_meta_dispatch
from dask.dataframe.backends import _nonempty_index, meta_nonempty, meta_nonempty_dataframe, _nonempty_series

from tape.utils import IndexCallable
from dask.dataframe.core import get_parallel_type
from dask.dataframe.extensions import make_array_nonempty

@@ -119,6 +119,29 @@ def _args(self):
# See https://github.com/geopandas/dask-geopandas/issues/237
return super()._args + (self.label, self.ensemble)

@property
def partitions(self):
"""Slice dataframe by partitions

This allows partitionwise slicing of a TAPE EnsembleFrame. You can perform
normal NumPy-style slicing, but rather than slicing elements of the array
you slice along partitions; for example, ``df.partitions[:5]`` produces a new
Dask DataFrame of the first five partitions. Valid indexers are integers,
sequences of integers, slices, or boolean masks.

Examples
--------
>>> df.partitions[0] # doctest: +SKIP
>>> df.partitions[:3] # doctest: +SKIP
>>> df.partitions[::10] # doctest: +SKIP

Returns
-------
A TAPE EnsembleFrame Object
"""
self.set_dirty(True)
return IndexCallable(self._partitions, self.is_dirty(), self.ensemble)

def _propagate_metadata(self, new_frame):
"""Propagates any relevant metadata to a new frame.

@@ -208,6 +231,36 @@ def query(self, expr, **kwargs):
result.set_dirty(True)
return result

def sample(self, **kwargs):
"""Random sample of items from a Dataframe.

Doc string below derived from dask.dataframe.core

Parameters
----------
frac: float, optional
Approximate fraction of objects to return. This sampling fraction
is applied to all partitions equally. Note that this is an
approximate fraction. You should not expect exactly len(df) * frac
items to be returned, as the exact number of elements selected will
depend on how your data is partitioned (but should be pretty close
in practice).
replace: boolean, optional
Sample with or without replacement. Default = False.
random_state: int or np.random.RandomState
If an int, we create a new RandomState with this as the seed;
Otherwise we draw from the passed RandomState.

Returns
-------
result: `tape._Frame`
The modified frame

"""
result = self._propagate_metadata(super().sample(**kwargs))
result.set_dirty(True)
return result

def merge(self, right, **kwargs):
"""Merge the Dataframe with another DataFrame

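A short sketch of how the new `partitions` property composes with the ensemble's lazy sync, mirroring the tests added below (assumes a populated ensemble `ens` with multiple partitions):

```python
# Keep only the first three object partitions and push the sliced frame
# back into the ensemble; the slice marks the frame as dirty.
ens.object.partitions[0:3].update_ensemble()

# The next sync then drops any sources whose objects were sliced away.
ens._lazy_sync_tables(table="all")
```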
1 change: 1 addition & 0 deletions src/tape/ensemble_readers.py
@@ -2,6 +2,7 @@
The following package-level methods can be used to create a new Ensemble object
by reading in the given data source.
"""

import requests

import dask.dataframe as dd
1 change: 1 addition & 0 deletions src/tape/utils/__init__.py
@@ -1 +1,2 @@
from .column_mapper import ColumnMapper # noqa
from .utils import IndexCallable # noqa
23 changes: 23 additions & 0 deletions src/tape/utils/utils.py
@@ -0,0 +1,23 @@
class IndexCallable:
"""Provide getitem syntax for functions

>>> def inc(x):
... return x + 1

>>> I = IndexCallable(inc)
>>> I[3]
4
"""

__slots__ = ("fn", "dirty", "ensemble")

def __init__(self, fn, dirty, ensemble):
self.fn = fn
self.dirty = dirty # propagate metadata
self.ensemble = ensemble # propagate ensemble metadata

def __getitem__(self, key):
result = self.fn(key)
result.dirty = self.dirty # propagate metadata
result.ensemble = self.ensemble # propagate ensemble metadata
return result
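For context, a sketch of how the frame code uses this helper (mirroring the `partitions` property in ensemble_frame.py; `frame` here stands in for any EnsembleFrame):

```python
# Wrap dask's internal partition indexer so that sliced results keep
# TAPE's dirty flag and ensemble back-reference.
indexer = IndexCallable(frame._partitions, frame.is_dirty(), frame.ensemble)
sliced = indexer[0:2]  # result carries .dirty and .ensemble from `frame`
```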
32 changes: 32 additions & 0 deletions tests/tape_tests/test_ensemble.py
@@ -832,6 +832,38 @@ def test_persist(dask_client):
assert new_graph_size < old_graph_size


@pytest.mark.parametrize(
"data_fixture",
[
"parquet_ensemble_with_divisions",
"parquet_ensemble_without_client",
],
)
def test_sample(data_fixture, request):
"""
Test Ensemble.sample
"""

ens = request.getfixturevalue(data_fixture)
ens.source.repartition(npartitions=10).update_ensemble()
ens.object.repartition(npartitions=5).update_ensemble()

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

new_ens = ens.sample(frac=0.3)

assert len(new_ens.object) < prior_obj_len # frac is not exact
assert len(new_ens.source) < prior_src_len # should affect source

# ens should not have been affected
assert len(ens.object) == prior_obj_len
assert len(ens.source) == prior_src_len

if data_fixture == "parquet_ensemble_with_divisions":
ens.client.close() # sample disables client cleanup, so the client must be closed manually


def test_update_column_map(dask_client):
"""
Test that we can update the column maps in an Ensemble.
30 changes: 30 additions & 0 deletions tests/tape_tests/test_ensemble_frame.py
@@ -409,3 +409,33 @@ def test_coalesce(dask_client, drop_inputs):
# The input columns should still be present
for col in ["flux1", "flux2", "flux3"]:
assert col in ens.source.columns


def test_partition_slicing(parquet_ensemble_with_divisions):
"""
Test that partition slicing propagates EnsembleFrame metadata
"""
ens = parquet_ensemble_with_divisions

ens.source.repartition(npartitions=10).update_ensemble()
ens.object.repartition(npartitions=5).update_ensemble()

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

# slice on object
ens.object.partitions[0:3].update_ensemble()
ens._lazy_sync_tables("all") # sync needed as len() won't trigger one

assert ens.object.npartitions == 3 # should return exactly 3 partitions
assert len(ens.source) < prior_src_len # should affect source

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

# slice on source
ens.source.partitions[0:2].update_ensemble()
ens._lazy_sync_tables("all") # sync needed as len() won't trigger one

assert ens.source.npartitions == 2 # should return exactly 2 partitions
assert len(ens.object) < prior_obj_len # should affect objects
19 changes: 18 additions & 1 deletion tests/tape_tests/test_utils.py
@@ -1,5 +1,22 @@
import pytest
from tape.utils import ColumnMapper
from tape.utils import ColumnMapper, IndexCallable


def test_index_callable(parquet_ensemble):
"""
Test the basic function of the IndexCallable object
"""

ens = parquet_ensemble

source_ic = IndexCallable(ens.source._partitions, True, "ensemble")

# grab the first (and only) source partition
sliced_source_frame = source_ic[0]

# ensure that the metadata propagates to the result
assert sliced_source_frame.dirty is True
assert sliced_source_frame.ensemble == "ensemble"


def test_column_mapper():