From 4f78a264ef3d6bb555e5a11659bcb6a822f17886 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Fri, 15 Nov 2024 18:17:23 -0500 Subject: [PATCH 1/3] TST (string dtype): resolve xfails in interchange --- pandas/core/interchange/from_dataframe.py | 12 ++++++++---- pandas/tests/interchange/test_impl.py | 7 +------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index 0e5776ae8cdd9..3511506278d46 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -9,6 +9,8 @@ import numpy as np +from pandas._config import using_string_dtype + from pandas.compat._optional import import_optional_dependency import pandas as pd @@ -147,8 +149,6 @@ def protocol_df_chunk_to_pandas(df: DataFrameXchg) -> pd.DataFrame: ------- pd.DataFrame """ - # We need a dict of columns here, with each column being a NumPy array (at - # least for now, deal with non-NumPy dtypes later). columns: dict[str, Any] = {} buffers = [] # hold on to buffers, keeps memory alive for name in df.column_names(): @@ -347,8 +347,12 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]: # Add to our list of strings str_list[i] = string - # Convert the string list to a NumPy array - return np.asarray(str_list, dtype="object"), buffers + if using_string_dtype(): + res = pd.Series(str_list, dtype="str") + else: + res = np.asarray(str_list, dtype="object") + + return res, buffers def parse_datetime_format_str(format_str, data) -> pd.Series | np.ndarray: diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py index 29ce9d0c03111..10824e388e073 100644 --- a/pandas/tests/interchange/test_impl.py +++ b/pandas/tests/interchange/test_impl.py @@ -6,8 +6,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import iNaT from pandas.compat import ( is_ci_environment, @@ -401,7 +399,6 @@ def test_interchange_from_corrected_buffer_dtypes(monkeypatch) -> None: pd.api.interchange.from_dataframe(df) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_empty_string_column(): # https://github.com/pandas-dev/pandas/issues/56703 df = pd.DataFrame({"a": []}, dtype=str) @@ -410,7 +407,6 @@ def test_empty_string_column(): tm.assert_frame_equal(df, result) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_large_string(): # GH#56702 pytest.importorskip("pyarrow") @@ -427,7 +423,6 @@ def test_non_str_names(): assert names == ["0"] -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_non_str_names_w_duplicates(): # https://github.com/pandas-dev/pandas/issues/56701 df = pd.DataFrame({"0": [1, 2, 3], 0: [4, 5, 6]}) @@ -438,7 +433,7 @@ def test_non_str_names_w_duplicates(): "Expected a Series, got a DataFrame. This likely happened because you " "called __dataframe__ on a DataFrame which, after converting column " r"names to string, resulted in duplicated names: Index\(\['0', '0'\], " - r"dtype='object'\). Please rename these columns before using the " + r"dtype='(str|object)'\). Please rename these columns before using the " "interchange protocol." ), ): From 11b226fb049a6adb9be5419c4085123a9def3cb4 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Fri, 15 Nov 2024 18:54:52 -0500 Subject: [PATCH 2/3] Hack mypy --- pandas/core/interchange/from_dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index 3511506278d46..5c9b8ac8ea085 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -350,9 +350,9 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]: if using_string_dtype(): res = pd.Series(str_list, dtype="str") else: - res = np.asarray(str_list, dtype="object") + res = np.asarray(str_list, dtype="object") # type: ignore[assignment] - return res, buffers + return res, buffers # type: ignore[return-value] def parse_datetime_format_str(format_str, data) -> pd.Series | np.ndarray: From ad45ebc359b3c3f1d91e55500ebd236bcc820a54 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Sat, 16 Nov 2024 09:02:22 -0500 Subject: [PATCH 3/3] Fix test failure --- pandas/tests/interchange/test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py index 10824e388e073..b80b4b923c247 100644 --- a/pandas/tests/interchange/test_impl.py +++ b/pandas/tests/interchange/test_impl.py @@ -412,7 +412,7 @@ def test_large_string(): pytest.importorskip("pyarrow") df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]") result = pd.api.interchange.from_dataframe(df.__dataframe__()) - expected = pd.DataFrame({"a": ["x"]}, dtype="object") + expected = pd.DataFrame({"a": ["x"]}, dtype="str") tm.assert_frame_equal(result, expected)