From da0de55275f00f94eb84831f12e7dbe66c213734 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 30 Sep 2023 13:33:34 -0400 Subject: [PATCH] Fix encoding type inference for boolean columns when pyarrow is installed (#3210) * Work around https://github.com/pandas-dev/pandas/issues/55332 * Only test boolean column with pandas >= 1.0.0 --- altair/utils/core.py | 6 ++++-- tests/utils/test_core.py | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index fa83665f5..45033572a 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -588,8 +588,10 @@ def parse_shorthand( column = dfi.get_column_by_name(unescaped_field) try: attrs["type"] = infer_vegalite_type_for_dfi_column(column) - except NotImplementedError: - # Fall back to pandas-based inference + except (NotImplementedError, AttributeError): + # Fall back to pandas-based inference. + # Note: The AttributeError catch is a workaround for + # https://github.com/pandas-dev/pandas/issues/55332 if isinstance(data, pd.DataFrame): attrs["type"] = infer_vegalite_type(data[unescaped_field]) else: diff --git a/tests/utils/test_core.py b/tests/utils/test_core.py index 75db18769..27cd3b7ee 100644 --- a/tests/utils/test_core.py +++ b/tests/utils/test_core.py @@ -1,4 +1,6 @@ import types +from packaging.version import Version +from importlib.metadata import version as importlib_version import numpy as np import pandas as pd @@ -16,6 +18,8 @@ except ImportError: pa = None +PANDAS_VERSION = Version(importlib_version("pandas")) + FAKE_CHANNELS_MODULE = f''' """Fake channels module for utility tests.""" @@ -160,6 +164,10 @@ def check(s, data, **kwargs): check("month(z)", data, timeUnit="month", field="z", type="temporal") check("month(t)", data, timeUnit="month", field="t", type="temporal") + if PANDAS_VERSION >= Version("1.0.0"): + data["b"] = pd.Series([True, False, True, False, None], dtype="boolean") + check("b", data, field="b", type="nominal") + @pytest.mark.skipif(pa is None, reason="pyarrow not installed") def test_parse_shorthand_for_arrow_timestamp():