Commit

Merge branch 'main' into jonmmease/new_show
mattijn authored Mar 24, 2024
2 parents 3783dfc + c7c4149 commit 9dfb9c6
Showing 9 changed files with 96 additions and 33 deletions.
2 changes: 2 additions & 0 deletions altair/utils/__init__.py
@@ -2,6 +2,7 @@
infer_vegalite_type,
infer_encoding_types,
sanitize_dataframe,
sanitize_arrow_table,
parse_shorthand,
use_signature,
update_nested,
@@ -18,6 +19,7 @@
"infer_vegalite_type",
"infer_encoding_types",
"sanitize_dataframe",
"sanitize_arrow_table",
"spec_to_html",
"parse_shorthand",
"use_signature",
2 changes: 1 addition & 1 deletion altair/utils/_vegafusion_data.py
@@ -45,7 +45,7 @@ def vegafusion_data_transformer(
# Use default transformer for geo interface objects
# (e.g. a geopandas GeoDataFrame)
return default_data_transformer(data)
elif hasattr(data, "__dataframe__"):
elif isinstance(data, DataFrameLike):
table_name = f"table_{uuid.uuid4()}".replace("-", "_")
extracted_inline_tables[table_name] = data
return {"url": VEGAFUSION_PREFIX + table_name}
13 changes: 7 additions & 6 deletions altair/utils/core.py
@@ -37,7 +37,7 @@
else:
from typing_extensions import ParamSpec

from typing import Literal, Protocol, TYPE_CHECKING
from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable

if TYPE_CHECKING:
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
@@ -46,6 +46,7 @@
P = ParamSpec("P")


@runtime_checkable
class DataFrameLike(Protocol):
def __dataframe__(
self, nan_as_null: bool = False, allow_copy: bool = True
@@ -429,15 +430,15 @@ def sanitize_arrow_table(pa_table):
schema = pa_table.schema
for name in schema.names:
array = pa_table[name]
dtype = schema.field(name).type
if str(dtype).startswith("timestamp"):
dtype_name = str(schema.field(name).type)
if dtype_name.startswith("timestamp") or dtype_name.startswith("date"):
arrays.append(pc.strftime(array))
elif str(dtype).startswith("duration"):
elif dtype_name.startswith("duration"):
raise ValueError(
'Field "{col_name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
"".format(col_name=name, dtype=dtype)
"".format(col_name=name, dtype=dtype_name)
)
else:
arrays.append(array)
@@ -588,7 +589,7 @@ def parse_shorthand(

# if data is specified and type is not, infer type from data
if "type" not in attrs:
if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"):
if pyarrow_available() and data is not None and isinstance(data, DataFrameLike):
dfi = data.__dataframe__()
if "field" in attrs:
unescaped_field = attrs["field"].replace("\\", "")
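
The key change in core.py is the @runtime_checkable decorator on DataFrameLike, which is what allows the hasattr(data, "__dataframe__") checks elsewhere in this commit to be replaced with isinstance(data, DataFrameLike). A minimal sketch (not part of the commit) of how such a check behaves; FakeInterchangeFrame is a hypothetical stand-in for any object implementing the DataFrame Interchange Protocol:

# Minimal sketch (not part of this commit): with @runtime_checkable, an
# isinstance() check passes for any object that defines __dataframe__,
# with no inheritance required. Only the presence of the method is
# verified, not its signature.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DataFrameLike(Protocol):
    def __dataframe__(
        self, nan_as_null: bool = False, allow_copy: bool = True
    ) -> Any: ...


class FakeInterchangeFrame:
    """Hypothetical object implementing the DataFrame Interchange Protocol."""

    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
        return object()  # a real implementation returns an interchange object


print(isinstance(FakeInterchangeFrame(), DataFrameLike))  # True
print(isinstance({"values": [1, 2]}, DataFrameLike))      # False
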
54 changes: 33 additions & 21 deletions altair/utils/data.py
@@ -105,13 +105,12 @@ def raise_max_rows_error():
# mypy gets confused as it doesn't see Dict[Any, Any]
# as equivalent to TDataType
return data # type: ignore[return-value]
elif hasattr(data, "__dataframe__"):
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
elif isinstance(data, DataFrameLike):
pa_table = arrow_table_from_dfi_dataframe(data)
if max_rows is not None and pa_table.num_rows > max_rows:
raise_max_rows_error()
# Return pyarrow Table instead of input since the
# `from_dataframe` call may be expensive
# `arrow_table_from_dfi_dataframe` call above may be expensive
return pa_table

if max_rows is not None and len(values) > max_rows:
@@ -142,10 +141,8 @@ def sample(
else:
# Maybe this should raise an error or return something useful?
return None
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
elif isinstance(data, DataFrameLike):
pa_table = arrow_table_from_dfi_dataframe(data)
if not n:
if frac is None:
raise ValueError(
@@ -232,19 +229,17 @@ def to_values(data: DataType) -> ToValuesReturnType:
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return data
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = sanitize_arrow_table(pi.from_dataframe(data))
elif isinstance(data, DataFrameLike):
pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
return {"values": pa_table.to_pylist()}
else:
# Should never reach this state as tested by check_data_type
raise ValueError("Unrecognized data type: {}".format(type(data)))


def check_data_type(data: DataType) -> None:
if not isinstance(data, (dict, pd.DataFrame)) and not any(
hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
if not isinstance(data, (dict, pd.DataFrame, DataFrameLike)) and not any(
hasattr(data, attr) for attr in ["__geo_interface__"]
):
raise TypeError(
"Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format(
@@ -277,10 +272,8 @@ def _data_to_json_string(data: DataType) -> str:
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return json.dumps(data["values"], sort_keys=True)
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
elif isinstance(data, DataFrameLike):
pa_table = arrow_table_from_dfi_dataframe(data)
return json.dumps(pa_table.to_pylist())
else:
raise NotImplementedError(
@@ -303,13 +296,12 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
if "values" not in data:
raise KeyError("values expected in data dict, but not present")
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
elif hasattr(data, "__dataframe__"):
elif isinstance(data, DataFrameLike):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
import pyarrow as pa
import pyarrow.csv as pa_csv

pa_table = pi.from_dataframe(data)
pa_table = arrow_table_from_dfi_dataframe(data)
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(pa_table, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
@@ -346,3 +338,23 @@ def curry(*args, **kwargs):
stacklevel=1,
)
return curried.curry(*args, **kwargs)


def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table":
"""Convert a DataFrame Interchange Protocol compatible object to an Arrow Table"""
import pyarrow as pa

# First check if the dataframe object has a method to convert to arrow.
# Give this preference over the pyarrow from_dataframe function since the object
# has more control over the conversion, and may have broader compatibility.
# This is the case for Polars, which supports Date32 columns in direct conversion
# while pyarrow does not yet support this type in from_dataframe
for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"):
convert_method = getattr(dfi_df, convert_method_name, None)
if callable(convert_method):
result = convert_method()
if isinstance(result, pa.Table):
return result

pi = import_pyarrow_interchange()
return pi.from_dataframe(dfi_df)
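
The new arrow_table_from_dfi_dataframe helper prefers a native Arrow conversion method (arrow, to_arrow, or to_arrow_table) when the input object provides one, and only falls back to pyarrow's interchange-protocol converter otherwise. A usage sketch, assuming pyarrow >= 11 is installed and using a hypothetical FakeArrowBacked class in place of a library such as Polars:

# Usage sketch (assumptions: pyarrow >= 11 installed; FakeArrowBacked is a
# hypothetical stand-in for a library like Polars with a native to_arrow()).
from datetime import date

import pyarrow as pa

from altair.utils.data import arrow_table_from_dfi_dataframe


class FakeArrowBacked:
    """Hypothetical DataFrame-like object with a native Arrow conversion."""

    def __init__(self, table: pa.Table):
        self._table = table

    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
        return self._table.__dataframe__()

    def to_arrow(self) -> pa.Table:
        # Preferred path: the table is returned as-is, so a date32 column
        # survives; pyarrow.interchange.from_dataframe is never called.
        return self._table


table = pa.table({"d": pa.array([date(2024, 1, 1)], type=pa.date32())})
converted = arrow_table_from_dfi_dataframe(FakeArrowBacked(table))
assert converted.schema.field("d").type == pa.date32()
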
4 changes: 2 additions & 2 deletions altair/vegalite/v5/api.py
@@ -57,7 +57,7 @@ def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str:
values = values.to_dict()
if values == [{}]:
return "empty"
values_json = json.dumps(values, sort_keys=True)
values_json = json.dumps(values, sort_keys=True, default=str)
hsh = hashlib.sha256(values_json.encode()).hexdigest()[:32]
return "data-" + hsh

@@ -115,7 +115,7 @@ def _prepare_data(data, context=None):
elif isinstance(data, str):
data = core.UrlData(data)

elif hasattr(data, "__dataframe__"):
elif isinstance(data, DataFrameLike):
data = _pipe(data, data_transformers.get())

# consolidate inline data to top-level datasets
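
In _dataset_name, passing default=str to json.dumps keeps the dataset-hashing step from raising when the inline values contain objects that are not natively JSON-serializable, such as the datetime.date values that date32 columns can produce. A small illustration (not Altair code) of the difference:

# Illustration (not Altair code): json.dumps raises TypeError on
# datetime.date values unless a fallback serializer is supplied.
import json
from datetime import date

values = [{"d": date(2024, 1, 1), "v": 1}]

try:
    json.dumps(values, sort_keys=True)
except TypeError as err:
    print(f"without default=str: {err}")

print(json.dumps(values, sort_keys=True, default=str))
# [{"d": "2024-01-01", "v": 1}]
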
1 change: 1 addition & 0 deletions doc/releases/changes.rst
@@ -29,6 +29,7 @@ Bug Fixes
- Fix error when embed_options are None (#3376)
- Fix type hints for libraries such as Polars where Altair uses the dataframe interchange protocol (#3297)
- Fix anywidget deprecation warning (#3364)
- Fix handling of Date32 columns in arrow tables and Polars DataFrames (#3377)

Backward-Incompatible Changes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ all = [
"vega_datasets>=0.9.0",
"vl-convert-python>=1.3.0",
"pyarrow>=11",
"vegafusion[embed]>=1.5.0",
"vegafusion[embed]>=1.6.6",
"anywidget>=0.9.0",
"altair_tiles>=0.3.0"
]
2 changes: 1 addition & 1 deletion tests/utils/test_mimebundle.py
@@ -241,7 +241,7 @@ def check_pre_transformed_vega_spec(vega_spec):

# Check that the bin transform has been applied
row0 = data_0["values"][0]
assert row0 == {"a": "A", "b": 28, "b_end": 28.0, "b_start": 0.0}
assert row0 == {"a": "A", "b_end": 28.0, "b_start": 0.0}

# And no transforms remain
assert len(data_0.get("transform", [])) == 0
49 changes: 48 additions & 1 deletion tests/utils/test_utils.py
@@ -6,7 +6,7 @@
import pandas as pd
import pytest

from altair.utils import infer_vegalite_type, sanitize_dataframe
from altair.utils import infer_vegalite_type, sanitize_dataframe, sanitize_arrow_table

try:
import pyarrow as pa
@@ -120,6 +120,53 @@ def test_sanitize_dataframe_arrow_columns():
json.dumps(records)


@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
def test_sanitize_pyarrow_table_columns():
# create a dataframe with various types
df = pd.DataFrame(
{
"s": list("abcde"),
"f": np.arange(5, dtype=float),
"i": np.arange(5, dtype=int),
"b": np.array([True, False, True, True, False]),
"d": pd.date_range("2012-01-01", periods=5, freq="H"),
"c": pd.Series(list("ababc"), dtype="category"),
"p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"),
}
)

# Create pyarrow table with explicit schema so that date32 type is preserved
pa_table = pa.Table.from_pandas(
df,
pa.schema(
[
pa.field("s", pa.string()),
pa.field("f", pa.float64()),
pa.field("i", pa.int64()),
pa.field("b", pa.bool_()),
pa.field("d", pa.date32()),
pa.field("c", pa.dictionary(pa.int8(), pa.string())),
pa.field("p", pa.timestamp("ns", tz="UTC")),
]
),
)
sanitized = sanitize_arrow_table(pa_table)
values = sanitized.to_pylist()

assert values[0] == {
"s": "a",
"f": 0.0,
"i": 0,
"b": True,
"d": "2012-01-01T00:00:00",
"c": "a",
"p": "2012-01-01T00:00:00.000000000",
}

# Make sure we can serialize to JSON without error
json.dumps(values)


def test_sanitize_dataframe_colnames():
df = pd.DataFrame(np.arange(12).reshape(4, 3))

