diff --git a/src/nested_dask/core.py b/src/nested_dask/core.py
index 886700c..df4f032 100644
--- a/src/nested_dask/core.py
+++ b/src/nested_dask/core.py
@@ -46,6 +46,18 @@ def _rebuild(self, graph, func, args):  # type: ignore
         return collection
 
 
+def _nested_meta_from_flat(flat, name):
+    """Construct the meta (an empty nested-dtype series named `name`) for packing `flat`."""
+    pd_fields = flat.dtypes.to_dict()  # grabbing pandas dtypes
+    pyarrow_fields = {}  # grab underlying pyarrow dtypes
+    for field, dtype in pd_fields.items():
+        if hasattr(dtype, "pyarrow_dtype"):
+            pyarrow_fields[field] = dtype.pyarrow_dtype
+        else:  # or convert from numpy types
+            pyarrow_fields[field] = pa.from_numpy_dtype(dtype)
+    return pd.Series(name=name, dtype=NestedDtype.from_fields(pyarrow_fields))
+
+
 class NestedFrame(
     _Frame, dd.DataFrame
 ):  # can use dd.DataFrame instead of dx.DataFrame if the config is set true (default in >=2024.3.0)
@@ -70,17 +82,6 @@ def __getitem__(self, item):
         else:
             return super().__getitem__(item)
 
-    def _nested_meta_from_flat(self, flat, name):
-        """construct meta for a packed series from a flat dataframe"""
-        pd_fields = flat.dtypes.to_dict()  # grabbing pandas dtypes
-        pyarrow_fields = {}  # grab underlying pyarrow dtypes
-        for field, dtype in pd_fields.items():
-            if hasattr(dtype, "pyarrow_dtype"):
-                pyarrow_fields[field] = dtype.pyarrow_dtype
-            else:  # or convert from numpy types
-                pyarrow_fields[field] = pa.from_numpy_dtype(dtype)
-        return pd.Series(name=name, dtype=NestedDtype.from_fields(pyarrow_fields))
-
     def __setitem__(self, key, value):
         """Adds custom __setitem__ behavior for nested columns"""
 
@@ -102,7 +103,7 @@ def __setitem__(self, key, value):
                     new_flat = new_flat.astype({col: pd.ArrowDtype(pa.string())})
 
             # pack the modified df back into a nested column
-            meta = self._nested_meta_from_flat(new_flat, nested)
+            meta = _nested_meta_from_flat(new_flat, nested)
             packed = new_flat.map_partitions(lambda x: pack(x), meta=meta)
             return super().__setitem__(nested, packed)
 
@@ -114,7 +115,7 @@ def __setitem__(self, key, value):
                 value.name = col
                 value = value.to_frame()
 
-            meta = self._nested_meta_from_flat(value, new_nested)
+            meta = _nested_meta_from_flat(value, new_nested)
             packed = value.map_partitions(lambda x: pack(x), meta=meta)
             return super().__setitem__(new_nested, packed)
 
@@ -280,6 +281,59 @@ def from_map(
         )
         return NestedFrame.from_dask_dataframe(nf)
 
+    @classmethod
+    def from_flat(cls, df, base_columns, nested_columns=None, index=None, name="nested"):
+        """Creates a NestedFrame with base and nested columns from a flat
+        dataframe.
+
+        Parameters
+        ----------
+        df: dd.DataFrame or NestedFrame
+            A flat (dask) dataframe; it is processed partition by partition.
+        base_columns: list-like
+            The columns that should be used as base (flat) columns in the
+            output dataframe.
+        nested_columns: list-like, or None
+            The columns that should be packed into a nested column. All columns
+            in the list will attempt to be packed into a single nested column
+            with the name provided in `name`. If None, is defined as all
+            columns not in `base_columns` (excluding `index`, if given).
+        index: str, or None
+            The name of a column to use as the new index. Typically, the index
+            should have a unique value per row for base columns, and should
+            repeat for nested columns. For example, a dataframe with two
+            columns; a=[1,1,1,2,2,2] and b=[5,10,15,20,25,30] would want an
+            index like [0,0,0,1,1,1] if a is chosen as a base column. If not
+            provided the current index will be used.
+        name: str
+            The name of the output column the `nested_columns` are packed into.
+
+        Returns
+        -------
+        NestedFrame
+            A NestedFrame with the specified nesting structure.
+        """
+        # Handle the meta (built up front so dask knows the output layout)
+        # Pathway 1: Some base columns and one nested column -> nestedframe
+        # Pathway 2: Only a single nested column -> nestedframe as defined in npd
+        # Pathway 3: Only a set of base columns, technically possible -> nestedframe
+
+        if nested_columns is None:
+            nested_columns = [col for col in df.columns if (col not in base_columns) and col != index]
+
+        meta = npd.NestedFrame(df[base_columns]._meta)
+
+        if len(nested_columns) > 0:
+            nested_meta = _nested_meta_from_flat(df[nested_columns], name)
+            meta = meta.join(nested_meta)
+
+        return df.map_partitions(
+            lambda x: npd.NestedFrame.from_flat(
+                df=x, base_columns=base_columns, nested_columns=nested_columns, index=index, name=name
+            ),
+            meta=meta,
+        )
+
     def compute(self, **kwargs):
         """Compute this Dask collection, returning the underlying dataframe or series."""
         return npd.NestedFrame(super().compute(**kwargs))
diff --git a/tests/nested_dask/test_nestedframe.py b/tests/nested_dask/test_nestedframe.py
index 5dead59..0546a08 100644
--- a/tests/nested_dask/test_nestedframe.py
+++ b/tests/nested_dask/test_nestedframe.py
@@ -108,6 +108,76 @@ def test_add_nested(test_dataset_no_add_nested):
     assert len(base_with_nested.compute()) == 50
 
 
+def test_from_flat():
+    """Test the from_flat wrapping, make sure meta is assigned correctly"""
+
+    nf = nd.NestedFrame.from_pandas(
+        npd.NestedFrame(
+            {
+                "a": [1, 1, 1, 2, 2, 2],
+                "b": [2, 2, 2, 4, 4, 4],
+                "c": [1, 2, 3, 4, 5, 6],
+                "d": [2, 4, 6, 8, 10, 12],
+            },
+            index=[0, 0, 0, 1, 1, 1],
+        )
+    )
+
+    # Check full inputs
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"], nested_columns=["c", "d"])
+    assert list(ndf.columns) == ["a", "b", "nested"]
+    assert list(ndf["nested"].nest.fields) == ["c", "d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+    # Check omitting a base column
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=["a"], nested_columns=["c", "d"])
+    assert list(ndf.columns) == ["a", "nested"]
+    assert list(ndf["nested"].nest.fields) == ["c", "d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+    # Check omitting a nested column
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"], nested_columns=["d"])
+    assert list(ndf.columns) == ["a", "b", "nested"]
+    assert list(ndf["nested"].nest.fields) == ["d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+    # Check no base columns
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=[], nested_columns=["c", "d"])
+    assert list(ndf.columns) == ["nested"]
+    assert list(ndf["nested"].nest.fields) == ["c", "d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+    # Check inferred nested columns
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=["a", "b"])
+    assert list(ndf.columns) == ["a", "b", "nested"]
+    assert list(ndf["nested"].nest.fields) == ["c", "d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+    # Check using an index
+    ndf = nd.NestedFrame.from_flat(nf, base_columns=["b"], index="a")
+    assert list(ndf.columns) == ["b", "nested"]
+    assert list(ndf["nested"].nest.fields) == ["c", "d"]
+    ndf_comp = ndf.compute()
+    assert list(ndf.columns) == list(ndf_comp.columns)
+    assert list(ndf["nested"].nest.fields) == list(ndf_comp["nested"].nest.fields)
+    assert len(ndf_comp) == 2
+
+
 def test_query_on_base(test_dataset):
     """test the query function on base columns"""
 