diff --git a/src/nested_pandas/nestedframe/core.py b/src/nested_pandas/nestedframe/core.py index c7f2b38..3f28a5e 100644 --- a/src/nested_pandas/nestedframe/core.py +++ b/src/nested_pandas/nestedframe/core.py @@ -210,6 +210,10 @@ def from_flat(cls, df, base_columns, nested_columns=None, index=None, name="nest # drop duplicates on index out_df = df[base_columns][~df.index.duplicated(keep="first")] + # Convert df to NestedFrame if needed + if not isinstance(out_df, NestedFrame): + out_df = NestedFrame(out_df) + # add nested if nested_columns is None: nested_columns = [col for col in df.columns if col not in base_columns] diff --git a/tests/nested_pandas/nestedframe/test_nestedframe.py b/tests/nested_pandas/nestedframe/test_nestedframe.py index 097ce1b..95421b3 100644 --- a/tests/nested_pandas/nestedframe/test_nestedframe.py +++ b/tests/nested_pandas/nestedframe/test_nestedframe.py @@ -285,13 +285,21 @@ def test_add_nested_for_empty_df(): assert_frame_equal(new_base.nested.nest.to_flat(), nested.astype(pd.ArrowDtype(pa.float64()))) +@pytest.mark.parametrize("pandas", [False, True]) @pytest.mark.parametrize("index", [None, "a", "c"]) -def test_from_flat(index): +def test_from_flat(index, pandas): """Test the NestedFrame.from_flat functionality""" - nf = NestedFrame( - {"a": [1, 1, 1, 2, 2], "b": [2, 2, 2, 4, 4], "c": [1, 2, 3, 4, 5], "d": [2, 4, 6, 8, 10]}, - index=[0, 0, 0, 1, 1], - ) + + if pandas: + nf = pd.DataFrame( + {"a": [1, 1, 1, 2, 2], "b": [2, 2, 2, 4, 4], "c": [1, 2, 3, 4, 5], "d": [2, 4, 6, 8, 10]}, + index=[0, 0, 0, 1, 1], + ) + else: + nf = NestedFrame( + {"a": [1, 1, 1, 2, 2], "b": [2, 2, 2, 4, 4], "c": [1, 2, 3, 4, 5], "d": [2, 4, 6, 8, 10]}, + index=[0, 0, 0, 1, 1], + ) out_nf = NestedFrame.from_flat(nf, base_columns=["a", "b"], index=index, name="new_nested")