add accessor to_lists and to_flat + tests
dougbrn committed May 13, 2024
1 parent fe39cd4 commit f43551a
Showing 3 changed files with 125 additions and 2 deletions.
32 changes: 31 additions & 1 deletion src/dask_nested/accessor.py
@@ -1,3 +1,4 @@
import dask.dataframe as dd
import nested_pandas as npd
from dask.dataframe.extensions import register_series_accessor
from nested_pandas import NestedDtype
@@ -13,7 +14,6 @@ class DaskNestSeriesAccessor(npd.NestSeriesAccessor):
----------
series: dd.series
A series to tie to the accessor
"""

def __init__(self, series):
@@ -33,3 +33,33 @@ def fields(self) -> list[str]:
"""Names of the nested columns"""

return self._series.head(0).nest.fields # hacky

    def to_lists(self, fields: list[str] | None = None) -> dd.DataFrame:
        """Convert nested series into dataframe of list-array columns

        Parameters
        ----------
        fields : list[str] or None, optional
            Names of the fields to include. Default is None, which means all fields.

        Returns
        -------
        dd.DataFrame
            Dataframe of list-arrays.
        """
        return self._series.map_partitions(lambda x: x.nest.to_lists(fields=fields))

    def to_flat(self, fields: list[str] | None = None) -> dd.DataFrame:
        """Convert nested series into dataframe of flat arrays

        Parameters
        ----------
        fields : list[str] or None, optional
            Names of the fields to include. Default is None, which means all fields.

        Returns
        -------
        dd.DataFrame
            Dataframe of flat arrays.
        """
        return self._series.map_partitions(lambda x: x.nest.to_flat(fields=fields))
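
For orientation (not part of the commit): both accessors simply delegate to the nested-pandas accessor on each partition via map_partitions, so they return lazy dask DataFrames. Below is a minimal usage sketch; the frame `ndf`, its "nested" column, and the helper name are hypothetical, while the accessor calls mirror the tests added in this commit.

import dask.dataframe as dd

def nested_views(nested_series) -> tuple[dd.DataFrame, dd.DataFrame]:
    """Return the flat and list-array views of a nested dask series."""
    # to_flat: one row per nested element, columns are flat arrays.
    flat = nested_series.nest.to_flat()
    # to_lists: one row per outer index value, columns are list-arrays.
    # fields is optional; by default all nested fields are included.
    lists = nested_series.nest.to_lists()
    return flat, lists

# Hypothetical usage, assuming `ndf` is a dask_nested.NestedFrame with a
# nested column named "nested":
# flat, lists = nested_views(ndf["nested"])
# flat.compute(); lists.compute()  # materialize the lazy results
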
2 changes: 1 addition & 1 deletion tests/dask_nested/conftest.py
@@ -21,7 +21,7 @@ def test_dataset():
"band": randomstate.choice(["r", "g"], size=layer_size * n_base),
"index": np.arange(layer_size * n_base) % n_base,
}
layer_nf = npd.NestedFrame(data=layer_data).set_index("index")
layer_nf = npd.NestedFrame(data=layer_data).set_index("index").sort_index()

base_dn = dn.NestedFrame.from_nestedpandas(base_nf, npartitions=5)
layer_dn = dn.NestedFrame.from_nestedpandas(layer_nf, npartitions=10)
93 changes: 93 additions & 0 deletions tests/dask_nested/test_accessor.py
@@ -0,0 +1,93 @@
import pandas as pd
import pyarrow as pa
import pytest


def test_nest_accessor(test_dataset):
"""test that the nest accessor is correctly tied to columns"""

# Make sure that nested columns have the accessor available
assert hasattr(test_dataset.nested, "nest")

# Make sure we get an attribute error when trying to use the wrong column
with pytest.raises(AttributeError):
test_dataset.ra.nest


def test_fields(test_dataset):
"""test the fields accessor property"""
assert test_dataset.nested.nest.fields == ["t", "flux", "band"]


def test_to_flat(test_dataset):
"""test the to_flat function"""
flat_ztf = test_dataset.nested.nest.to_flat()

# check dtypes
assert flat_ztf.dtypes["t"] == pd.ArrowDtype(pa.float64())
assert flat_ztf.dtypes["flux"] == pd.ArrowDtype(pa.float64())
assert flat_ztf.dtypes["band"] == pd.ArrowDtype(pa.large_string())

# Make sure we retain all rows
assert len(flat_ztf.loc[1]) == 500

one_row = flat_ztf.loc[1].compute().iloc[1]
assert pytest.approx(one_row["t"], 0.01) == 5.4584
assert pytest.approx(one_row["flux"], 0.01) == 84.1573
assert one_row["band"] == "r"


def test_to_flat_with_fields(test_dataset):
"""test the to_flat function"""
flat_ztf = test_dataset.nested.nest.to_flat(fields=["t", "flux"])

# check dtypes
assert flat_ztf.dtypes["t"] == pd.ArrowDtype(pa.float64())
assert flat_ztf.dtypes["flux"] == pd.ArrowDtype(pa.float64())

# Make sure we retain all rows
assert len(flat_ztf.loc[1]) == 500

one_row = flat_ztf.loc[1].compute().iloc[1]
assert pytest.approx(one_row["t"], 0.01) == 5.4584
assert pytest.approx(one_row["flux"], 0.01) == 84.1573


def test_to_lists(test_dataset):
"""test the to_lists function"""
list_ztf = test_dataset.nested.nest.to_lists()

# check dtypes
assert list_ztf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64()))
assert list_ztf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64()))
assert list_ztf.dtypes["band"] == pd.ArrowDtype(pa.list_(pa.large_string()))

# Make sure we have a single row for an id
assert len(list_ztf.loc[1]) == 1

# Make sure we retain all rows -- double loc for speed and pandas get_item
assert len(list_ztf.loc[1].compute().loc[1]["t"]) == 500

# spot-check values
assert pytest.approx(list_ztf.loc[1].compute().loc[1]["t"][0], 0.01) == 7.5690279
assert pytest.approx(list_ztf.loc[1].compute().loc[1]["flux"][0], 0.01) == 79.6886
assert list_ztf.loc[1].compute().loc[1]["band"][0] == "g"


def test_to_lists_with_fields(test_dataset):
"""test the to_lists function"""
list_ztf = test_dataset.nested.nest.to_lists(fields=["t", "flux"])

# check dtypes
assert list_ztf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64()))
assert list_ztf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64()))

# Make sure we have a single row for an id
assert len(list_ztf.loc[1]) == 1

# Make sure we retain all rows -- double loc for speed and pandas get_item
assert len(list_ztf.loc[1].compute().loc[1]["t"]) == 500

# spot-check values
assert pytest.approx(list_ztf.loc[1].compute().loc[1]["t"][0], 0.01) == 7.5690279
assert pytest.approx(list_ztf.loc[1].compute().loc[1]["flux"][0], 0.01) == 79.6886
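
For reference, a small sketch of the layout difference the two accessors imply, reusing the `test_dataset` fixture above (each outer index holds 500 nested elements). This is illustrative only and not part of the committed tests.

def check_layouts(test_dataset):
    # to_flat: the 500 nested elements under outer index 1 become 500 flat rows.
    flat = test_dataset.nested.nest.to_flat().loc[1].compute()
    assert len(flat) == 500

    # to_lists: the same elements stay packed into one list-valued row per index.
    lists = test_dataset.nested.nest.to_lists().loc[1].compute()
    assert len(lists) == 1
    assert len(lists.loc[1]["t"]) == 500
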
