From ee6ac291e152155de737d501e127ec9adfa568e7 Mon Sep 17 00:00:00 2001
From: Dhanunjaya Elluri
Date: Mon, 23 Dec 2024 13:50:11 +0530
Subject: [PATCH] Feat/add collect schema to interchange dfs (#1646)

---
 narwhals/_duckdb/dataframe.py          | 8 ++++++++
 narwhals/_ibis/dataframe.py            | 6 ++++++
 tests/frame/interchange_schema_test.py | 2 ++
 3 files changed, 16 insertions(+)

diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py
index fe4a5856e..339fca137 100644
--- a/narwhals/_duckdb/dataframe.py
+++ b/narwhals/_duckdb/dataframe.py
@@ -151,3 +151,11 @@ def _change_version(self: Self, version: Version) -> Self:
 
     def _from_native_frame(self: Self, df: Any) -> Self:
         return self.__class__(df, version=self._version)
+
+    def collect_schema(self) -> dict[str, DType]:
+        return {
+            column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version)
+            for column_name, duckdb_dtype in zip(
+                self._native_frame.columns, self._native_frame.types
+            )
+        }
diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py
index fe9bb0349..f62a31e8b 100644
--- a/narwhals/_ibis/dataframe.py
+++ b/narwhals/_ibis/dataframe.py
@@ -129,3 +129,9 @@ def _change_version(self: Self, version: Version) -> Self:
 
     def _from_native_frame(self: Self, df: Any) -> Self:
         return self.__class__(df, version=self._version)
+
+    def collect_schema(self) -> dict[str, DType]:
+        return {
+            column_name: native_to_narwhals_dtype(ibis_dtype, self._version)
+            for column_name, ibis_dtype in self._native_frame.schema().items()
+        }
diff --git a/tests/frame/interchange_schema_test.py b/tests/frame/interchange_schema_test.py
index 588c92597..e06a482db 100644
--- a/tests/frame/interchange_schema_test.py
+++ b/tests/frame/interchange_schema_test.py
@@ -156,6 +156,7 @@ def test_interchange_schema_ibis(
     assert result == expected
     assert df["a"].dtype == nw.Int64
     assert df.columns == list(expected.keys())
+    assert df.collect_schema() == expected
 
 
 def test_interchange_schema_duckdb() -> None:
@@ -221,6 +222,7 @@ def test_interchange_schema_duckdb() -> None:
     assert result == expected
     assert df["a"].dtype == nw.Int64
     assert df.columns == list(expected.keys())
+    assert df.collect_schema() == expected
 
 
 def test_invalid() -> None: