Skip to content

Commit

Permalink
from_flat and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Aug 23, 2024
1 parent 24fb348 commit d81f5fd
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 13 deletions.
80 changes: 67 additions & 13 deletions src/nested_dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def _rebuild(self, graph, func, args): # type: ignore
return collection


def _nested_meta_from_flat(flat, name):
"""construct meta for a packed series from a flat dataframe"""
pd_fields = flat.dtypes.to_dict() # grabbing pandas dtypes
pyarrow_fields = {} # grab underlying pyarrow dtypes
for field, dtype in pd_fields.items():
if hasattr(dtype, "pyarrow_dtype"):
pyarrow_fields[field] = dtype.pyarrow_dtype
else: # or convert from numpy types
pyarrow_fields[field] = pa.from_numpy_dtype(dtype)
return pd.Series(name=name, dtype=NestedDtype.from_fields(pyarrow_fields))


class NestedFrame(
_Frame, dd.DataFrame
): # can use dd.DataFrame instead of dx.DataFrame if the config is set true (default in >=2024.3.0)
Expand All @@ -70,17 +82,6 @@ def __getitem__(self, item):
else:
return super().__getitem__(item)

def _nested_meta_from_flat(self, flat, name):
"""construct meta for a packed series from a flat dataframe"""
pd_fields = flat.dtypes.to_dict() # grabbing pandas dtypes
pyarrow_fields = {} # grab underlying pyarrow dtypes
for field, dtype in pd_fields.items():
if hasattr(dtype, "pyarrow_dtype"):
pyarrow_fields[field] = dtype.pyarrow_dtype
else: # or convert from numpy types
pyarrow_fields[field] = pa.from_numpy_dtype(dtype)
return pd.Series(name=name, dtype=NestedDtype.from_fields(pyarrow_fields))

def __setitem__(self, key, value):
"""Adds custom __setitem__ behavior for nested columns"""

Expand All @@ -102,7 +103,7 @@ def __setitem__(self, key, value):
new_flat = new_flat.astype({col: pd.ArrowDtype(pa.string())})

# pack the modified df back into a nested column
meta = self._nested_meta_from_flat(new_flat, nested)
meta = _nested_meta_from_flat(new_flat, nested)
packed = new_flat.map_partitions(lambda x: pack(x), meta=meta)
return super().__setitem__(nested, packed)

Expand All @@ -114,7 +115,7 @@ def __setitem__(self, key, value):
value.name = col
value = value.to_frame()

meta = self._nested_meta_from_flat(value, new_nested)
meta = _nested_meta_from_flat(value, new_nested)
packed = value.map_partitions(lambda x: pack(x), meta=meta)
return super().__setitem__(new_nested, packed)

Expand Down Expand Up @@ -280,6 +281,59 @@ def from_map(
)
return NestedFrame.from_dask_dataframe(nf)

@classmethod
def from_flat(cls, df, base_columns, nested_columns=None, index=None, name="nested"):
"""Creates a NestedFrame with base and nested columns from a flat
dataframe.
Parameters
----------
df: pd.DataFrame or NestedFrame
A flat dataframe.
base_columns: list-like
The columns that should be used as base (flat) columns in the
output dataframe.
nested_columns: list-like, or None
The columns that should be packed into a nested column. All columns
in the list will attempt to be packed into a single nested column
with the name provided in `nested_name`. If None, is defined as all
columns not in `base_columns`.
index: str, or None
The name of a column to use as the new index. Typically, the index
should have a unique value per row for base columns, and should
repeat for nested columns. For example, a dataframe with two
columns; a=[1,1,1,2,2,2] and b=[5,10,15,20,25,30] would want an
index like [0,0,0,1,1,1] if a is chosen as a base column. If not
provided the current index will be used.
name:
The name of the output column the `nested_columns` are packed into.
Returns
-------
NestedFrame
A NestedFrame with the specified nesting structure.
"""
# Handle the meta
# Pathway 1: Some base columns and one nested column -> nestedframe
# Pathway 2: Only a single nested column -> nestedframe as defined in npd
# Pathway 3: Only a set of base columns, technically possible -> nestedframe

if nested_columns is None:
nested_columns = [col for col in df.columns if (col not in base_columns) and col != index]

meta = npd.NestedFrame(df[base_columns]._meta)

if len(nested_columns) > 0:
nested_meta = _nested_meta_from_flat(df[nested_columns], name)
meta = meta.join(nested_meta)

return df.map_partitions(
lambda x: npd.NestedFrame.from_flat(
df=x, base_columns=base_columns, nested_columns=nested_columns, index=index, name=name
),
meta=meta,
)

def compute(self, **kwargs):
"""Compute this Dask collection, returning the underlying dataframe or series."""
return npd.NestedFrame(super().compute(**kwargs))
Expand Down
70 changes: 70 additions & 0 deletions tests/nested_dask/test_nestedframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,76 @@ def test_add_nested(test_dataset_no_add_nested):
assert len(base_with_nested.compute()) == 50


def test_from_flat():
"""Test the from_flat wrapping, make sure meta is assigned correctly"""

nf = nd.NestedFrame.from_pandas(
npd.NestedFrame(
{
"a": [1, 1, 1, 2, 2, 2],
"b": [2, 2, 2, 4, 4, 4],
"c": [1, 2, 3, 4, 5, 6],
"d": [2, 4, 6, 8, 10, 12],
},
index=[0, 0, 0, 1, 1, 1],
)
)

# Check full inputs
ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"], nested_columns=["c", "d"])
assert list(ndf.columns) == ["a", "b", "nested"]
assert list(ndf["nested"].nest.fields) == ["c", "d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2

# Check omitting a base column
ndf = nd.NestedFrame.from_flat(nf, base_columns=["a"], nested_columns=["c", "d"])
assert list(ndf.columns) == ["a", "nested"]
assert list(ndf["nested"].nest.fields) == ["c", "d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2

# Check omitting a nested column
ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"], nested_columns=["d"])
assert list(ndf.columns) == ["a", "b", "nested"]
assert list(ndf["nested"].nest.fields) == ["d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2

# Check no base columns
ndf = nd.NestedFrame.from_flat(nf, base_columns=[], nested_columns=["c", "d"])
assert list(ndf.columns) == ["nested"]
assert list(ndf["nested"].nest.fields) == ["c", "d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2

# Check inferred nested columns
ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"])
assert list(ndf.columns) == ["a", "b", "nested"]
assert list(ndf["nested"].nest.fields) == ["c", "d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2

# Check using an index
ndf = nd.NestedFrame.from_flat(nf, base_columns=["b"], index="a")
assert list(ndf.columns) == ["b", "nested"]
assert list(ndf["nested"].nest.fields) == ["c", "d"]
ndf_comp = ndf.compute()
assert list(ndf.columns) == list(ndf_comp.columns)
assert list(ndf["nested"].nest.fields) == list(ndf["nested"].nest.fields)
assert len(ndf_comp) == 2


def test_query_on_base(test_dataset):
"""test the query function on base columns"""

Expand Down

0 comments on commit d81f5fd

Please sign in to comment.