diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py
index 0bd8ec5e3..dba1e1f81 100644
--- a/altair/utils/__init__.py
+++ b/altair/utils/__init__.py
@@ -2,6 +2,7 @@
     infer_vegalite_type,
     infer_encoding_types,
     sanitize_dataframe,
+    sanitize_arrow_table,
     parse_shorthand,
     use_signature,
     update_nested,
@@ -18,6 +19,7 @@
     "infer_vegalite_type",
     "infer_encoding_types",
     "sanitize_dataframe",
+    "sanitize_arrow_table",
     "spec_to_html",
     "parse_shorthand",
     "use_signature",
diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py
index 8b46bab78..ce30e8d6d 100644
--- a/altair/utils/_vegafusion_data.py
+++ b/altair/utils/_vegafusion_data.py
@@ -45,7 +45,7 @@ def vegafusion_data_transformer(
         # Use default transformer for geo interface objects
         # (e.g. a geopandas GeoDataFrame)
         return default_data_transformer(data)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
        table_name = f"table_{uuid.uuid4()}".replace("-", "_")
        extracted_inline_tables[table_name] = data
        return {"url": VEGAFUSION_PREFIX + table_name}
diff --git a/altair/utils/core.py b/altair/utils/core.py
index a382ac787..baf1013f7 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -37,7 +37,7 @@
 else:
     from typing_extensions import ParamSpec
 
-from typing import Literal, Protocol, TYPE_CHECKING
+from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable
 
 if TYPE_CHECKING:
     from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
@@ -46,6 +46,7 @@
 P = ParamSpec("P")
 
 
+@runtime_checkable
 class DataFrameLike(Protocol):
     def __dataframe__(
         self, nan_as_null: bool = False, allow_copy: bool = True
@@ -429,15 +430,15 @@ def sanitize_arrow_table(pa_table):
     schema = pa_table.schema
     for name in schema.names:
         array = pa_table[name]
-        dtype = schema.field(name).type
-        if str(dtype).startswith("timestamp"):
+        dtype_name = str(schema.field(name).type)
+        if dtype_name.startswith("timestamp") or dtype_name.startswith("date"):
             arrays.append(pc.strftime(array))
-        elif str(dtype).startswith("duration"):
+        elif dtype_name.startswith("duration"):
             raise ValueError(
                 'Field "{col_name}" has type "{dtype}" which is '
                 "not supported by Altair. Please convert to "
                 "either a timestamp or a numerical value."
-                "".format(col_name=name, dtype=dtype)
+                "".format(col_name=name, dtype=dtype_name)
             )
         else:
             arrays.append(array)
@@ -588,7 +589,7 @@ def parse_shorthand(
 
     # if data is specified and type is not, infer type from data
     if "type" not in attrs:
-        if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"):
+        if pyarrow_available() and data is not None and isinstance(data, DataFrameLike):
             dfi = data.__dataframe__()
             if "field" in attrs:
                 unescaped_field = attrs["field"].replace("\\", "")
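A note on the `@runtime_checkable` change in `altair/utils/core.py` above: without the decorator, `isinstance(data, DataFrameLike)` raises `TypeError` at runtime, because plain `Protocol` classes support only static type checks. With it, `isinstance()` performs a structural check for a `__dataframe__` method, which is exactly what the `hasattr(data, "__dataframe__")` calls being replaced throughout this diff did by hand. A minimal standalone sketch (the `FakeDF` class is hypothetical, for illustration only):

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class DataFrameLike(Protocol):
        def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): ...

    class FakeDF:
        # Any object exposing a __dataframe__ method satisfies the protocol
        def __dataframe__(self, nan_as_null=False, allow_copy=True):
            return self

    assert isinstance(FakeDF(), DataFrameLike)
    assert not isinstance(object(), DataFrameLike)

`runtime_checkable` only checks for the presence of the method, not its signature, so this is behaviorally equivalent to the old `hasattr` test while giving type checkers a single named type.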
- "".format(col_name=name, dtype=dtype) + "".format(col_name=name, dtype=dtype_name) ) else: arrays.append(array) @@ -588,7 +589,7 @@ def parse_shorthand( # if data is specified and type is not, infer type from data if "type" not in attrs: - if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"): + if pyarrow_available() and data is not None and isinstance(data, DataFrameLike): dfi = data.__dataframe__() if "field" in attrs: unescaped_field = attrs["field"].replace("\\", "") diff --git a/altair/utils/data.py b/altair/utils/data.py index 7fba7adaa..871b43092 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -105,13 +105,12 @@ def raise_max_rows_error(): # mypy gets confused as it doesn't see Dict[Any, Any] # as equivalent to TDataType return data # type: ignore[return-value] - elif hasattr(data, "__dataframe__"): - pi = import_pyarrow_interchange() - pa_table = pi.from_dataframe(data) + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) if max_rows is not None and pa_table.num_rows > max_rows: raise_max_rows_error() # Return pyarrow Table instead of input since the - # `from_dataframe` call may be expensive + # `arrow_table_from_dfi_dataframe` call above may be expensive return pa_table if max_rows is not None and len(values) > max_rows: @@ -142,10 +141,8 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - elif hasattr(data, "__dataframe__"): - # experimental interchange dataframe support - pi = import_pyarrow_interchange() - pa_table = pi.from_dataframe(data) + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) if not n: if frac is None: raise ValueError( @@ -232,10 +229,8 @@ def to_values(data: DataType) -> ToValuesReturnType: if "values" not in data: raise KeyError("values expected in data dict, but not present.") return data - elif hasattr(data, "__dataframe__"): - # experimental interchange dataframe support - pi = import_pyarrow_interchange() - pa_table = sanitize_arrow_table(pi.from_dataframe(data)) + elif isinstance(data, DataFrameLike): + pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data)) return {"values": pa_table.to_pylist()} else: # Should never reach this state as tested by check_data_type @@ -243,8 +238,8 @@ def to_values(data: DataType) -> ToValuesReturnType: def check_data_type(data: DataType) -> None: - if not isinstance(data, (dict, pd.DataFrame)) and not any( - hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"] + if not isinstance(data, (dict, pd.DataFrame, DataFrameLike)) and not any( + hasattr(data, attr) for attr in ["__geo_interface__"] ): raise TypeError( "Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format( @@ -277,10 +272,8 @@ def _data_to_json_string(data: DataType) -> str: if "values" not in data: raise KeyError("values expected in data dict, but not present.") return json.dumps(data["values"], sort_keys=True) - elif hasattr(data, "__dataframe__"): - # experimental interchange dataframe support - pi = import_pyarrow_interchange() - pa_table = pi.from_dataframe(data) + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) return json.dumps(pa_table.to_pylist()) else: raise NotImplementedError( @@ -303,13 +296,12 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str: if "values" not in data: raise KeyError("values expected in data dict, but not present") return 
pd.DataFrame.from_dict(data["values"]).to_csv(index=False) - elif hasattr(data, "__dataframe__"): + elif isinstance(data, DataFrameLike): # experimental interchange dataframe support - pi = import_pyarrow_interchange() import pyarrow as pa import pyarrow.csv as pa_csv - pa_table = pi.from_dataframe(data) + pa_table = arrow_table_from_dfi_dataframe(data) csv_buffer = pa.BufferOutputStream() pa_csv.write_csv(pa_table, csv_buffer) return csv_buffer.getvalue().to_pybytes().decode() @@ -346,3 +338,23 @@ def curry(*args, **kwargs): stacklevel=1, ) return curried.curry(*args, **kwargs) + + +def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table": + """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table""" + import pyarrow as pa + + # First check if the dataframe object has a method to convert to arrow. + # Give this preference over the pyarrow from_dataframe function since the object + # has more control over the conversion, and may have broader compatibility. + # This is the case for Polars, which supports Date32 columns in direct conversion + # while pyarrow does not yet support this type in from_dataframe + for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"): + convert_method = getattr(dfi_df, convert_method_name, None) + if callable(convert_method): + result = convert_method() + if isinstance(result, pa.Table): + return result + + pi = import_pyarrow_interchange() + return pi.from_dataframe(dfi_df) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index c697fccc1..647ee3fea 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -57,7 +57,7 @@ def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str: values = values.to_dict() if values == [{}]: return "empty" - values_json = json.dumps(values, sort_keys=True) + values_json = json.dumps(values, sort_keys=True, default=str) hsh = hashlib.sha256(values_json.encode()).hexdigest()[:32] return "data-" + hsh @@ -115,7 +115,7 @@ def _prepare_data(data, context=None): elif isinstance(data, str): data = core.UrlData(data) - elif hasattr(data, "__dataframe__"): + elif isinstance(data, DataFrameLike): data = _pipe(data, data_transformers.get()) # consolidate inline data to top-level datasets diff --git a/doc/releases/changes.rst b/doc/releases/changes.rst index 816901550..32b9f25c7 100644 --- a/doc/releases/changes.rst +++ b/doc/releases/changes.rst @@ -29,6 +29,7 @@ Bug Fixes - Fix error when embed_options are None (#3376) - Fix type hints for libraries such as Polars where Altair uses the dataframe interchange protocol (#3297) - Fix anywidget deprecation warning (#3364) +- Fix handling of Date32 columns in arrow tables and Polars DataFrames (#3377) Backward-Incompatible Changes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 7cc710894..f03b1b721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ all = [ "vega_datasets>=0.9.0", "vl-convert-python>=1.3.0", "pyarrow>=11", - "vegafusion[embed]>=1.5.0", + "vegafusion[embed]>=1.6.6", "anywidget>=0.9.0", "altair_tiles>=0.3.0" ] diff --git a/tests/utils/test_mimebundle.py b/tests/utils/test_mimebundle.py index 541ac483f..97c353c56 100644 --- a/tests/utils/test_mimebundle.py +++ b/tests/utils/test_mimebundle.py @@ -241,7 +241,7 @@ def check_pre_transformed_vega_spec(vega_spec): # Check that the bin transform has been applied row0 = data_0["values"][0] - assert row0 == {"a": "A", "b": 28, "b_end": 28.0, "b_start": 0.0} + assert row0 == {"a": "A", 
"b_end": 28.0, "b_start": 0.0} # And no transforms remain assert len(data_0.get("transform", [])) == 0 diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index c0334533a..9cf5bda37 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from altair.utils import infer_vegalite_type, sanitize_dataframe +from altair.utils import infer_vegalite_type, sanitize_dataframe, sanitize_arrow_table try: import pyarrow as pa @@ -120,6 +120,53 @@ def test_sanitize_dataframe_arrow_columns(): json.dumps(records) +@pytest.mark.skipif(pa is None, reason="pyarrow not installed") +def test_sanitize_pyarrow_table_columns(): + # create a dataframe with various types + df = pd.DataFrame( + { + "s": list("abcde"), + "f": np.arange(5, dtype=float), + "i": np.arange(5, dtype=int), + "b": np.array([True, False, True, True, False]), + "d": pd.date_range("2012-01-01", periods=5, freq="H"), + "c": pd.Series(list("ababc"), dtype="category"), + "p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"), + } + ) + + # Create pyarrow table with explicit schema so that date32 type is preserved + pa_table = pa.Table.from_pandas( + df, + pa.schema( + [ + pa.field("s", pa.string()), + pa.field("f", pa.float64()), + pa.field("i", pa.int64()), + pa.field("b", pa.bool_()), + pa.field("d", pa.date32()), + pa.field("c", pa.dictionary(pa.int8(), pa.string())), + pa.field("p", pa.timestamp("ns", tz="UTC")), + ] + ), + ) + sanitized = sanitize_arrow_table(pa_table) + values = sanitized.to_pylist() + + assert values[0] == { + "s": "a", + "f": 0.0, + "i": 0, + "b": True, + "d": "2012-01-01T00:00:00", + "c": "a", + "p": "2012-01-01T00:00:00.000000000", + } + + # Make sure we can serialize to JSON without error + json.dumps(values) + + def test_sanitize_dataframe_colnames(): df = pd.DataFrame(np.arange(12).reshape(4, 3))