Skip to content
This repository has been archived by the owner on Jan 14, 2025. It is now read-only.

Commit

Permalink
WIP: Ensemble.sample, _Frame.partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Jan 29, 2024
1 parent 6a694c4 commit 2bcb5a3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
51 changes: 51 additions & 0 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,57 @@ def persist(self, **kwargs):
self.update_frame(self.object.persist(**kwargs))
self.update_frame(self.source.persist(**kwargs))

def sample(self, overwrite=False, **kwargs):
    """Draw a sample of objects from the ensemble.

    The object table is synced, sampled via
    `dask.dataframe.DataFrame.sample`, and the sampled frame is flagged
    dirty before the source table is synced.

    Parameters
    ----------
    overwrite: boolean, optional
        If True, the sampled object table is registered on this ensemble
        and this ensemble is returned. If False (default), a new ensemble
        is created from the sampled object table and this ensemble's
        source table.
    **kwargs:
        keyword arguments passed along to
        `dask.dataframe.DataFrame.sample`

    Returns
    -------
    ensemble: `tape.ensemble.Ensemble`
        This ensemble when `overwrite` is True, otherwise a new ensemble
        holding the sampled subset.
    """

    # Make sure the object table is up to date before sampling from it.
    self._lazy_sync_tables(table="object")

    sampled_objects = self.object.sample(**kwargs)
    sampled_objects.set_dirty(True)

    # Decide which ensemble receives the sampled object table.
    if overwrite:
        target = self  # current in-place implementation
    else:
        # TODO: Investigate shared client warning
        target = Ensemble(client=False if self.client is None else self.client)

    target.update_frame(sampled_objects)
    if not overwrite:
        # NOTE(review): the source frame is handed over as-is, not copied;
        # confirm update_frame does not rebind it away from this ensemble.
        target.update_frame(self.source)

        # TODO: Add other frames? Sync against them?

    # Sync to source, removes all tied sources.
    target._lazy_sync_tables(table="source")

    return target

def columns(self, table="object"):
"""Retrieve columns from dask dataframe"""
if table == "object":
Expand Down
26 changes: 25 additions & 1 deletion src/tape/ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import dask
from dask.dataframe.dispatch import make_meta_dispatch
from dask.dataframe.backends import _nonempty_index, meta_nonempty, meta_nonempty_dataframe, _nonempty_series

from tape.utils import IndexCallable
from dask.dataframe.core import get_parallel_type
from dask.dataframe.extensions import make_array_nonempty

Expand Down Expand Up @@ -119,6 +119,29 @@ def _args(self):
# See https://github.com/geopandas/dask-geopandas/issues/237
return super()._args + (self.label, self.ensemble)

@property
def partitions(self):
    """Partition-wise indexer for a TAPE EnsembleFrame.

    Indexing the returned object follows Numpy-style slicing rules, but the
    slice is applied along partitions rather than elements; for example,
    ``df.partitions[:5]`` produces a new Dask Dataframe of the first five
    partitions. Valid indexers are integers, sequences of integers, slices,
    or boolean masks.

    Note that accessing this property marks the frame as dirty.

    Examples
    --------
    >>> df.partitions[0]  # doctest: +SKIP
    >>> df.partitions[:3]  # doctest: +SKIP
    >>> df.partitions[::10]  # doctest: +SKIP

    Returns
    -------
    A TAPE EnsembleFrame Object
    """
    # Flag the frame as dirty first, then hand the current dirty state and
    # the owning ensemble to the indexer along with the slicing function.
    self.set_dirty(True)
    dirty_flag = self.is_dirty()
    return IndexCallable(self._partitions, dirty_flag, self.ensemble)

def _propagate_metadata(self, new_frame):
"""Propagates any relevant metadata to a new frame.
Expand Down Expand Up @@ -1173,6 +1196,7 @@ def __init__(self, dsk, name, meta, divisions, ensemble=None):
self.label = OBJECT_FRAME_LABEL # A label used by the Ensemble to identify this frame.
self.ensemble = ensemble # The Ensemble object containing this frame.


@classmethod
def from_parquet(
cl,
Expand Down

0 comments on commit 2bcb5a3

Please sign in to comment.