From e1290d4384d4926c24f22a3a23f103e284cfbe1e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 11 Nov 2024 13:50:54 +0000 Subject: [PATCH] refactor: Tidy up changes from last 4 commits - Rename and properly document "file-like object" handling - Also made a bit clearer what is being called and when - Use a more granular approach to skipping in `@backends` - Previously, everything was skipped regardless of whether it required `pyarrow` - Now, `polars`, `pandas` **always** run - with `pandas` expected to fail - I had to clean up `skip_requires_pyarrow` to make it compatible with `pytest.param` - It has a runtime check for if `MarkDecorator`, instead of just a callable https://github.com/vega/altair/pull/3631/commits/bb7bc171a7005fd63f39b3d949902f4d553801f0, https://github.com/vega/altair/pull/3631/commits/ebc1bfaa0b35e554da15bab7dd7d7e2a95f17e63, https://github.com/vega/altair/pull/3631/commits/fe0ae88201cc699b32ee1e9c07b602d9d7a8d439, https://github.com/vega/altair/pull/3631/commits/7089f2af693c6db2025ee265f31ec4ef228dd8c3 --- altair/datasets/_readers.py | 33 ++++++++++++++++++++++----------- tests/__init__.py | 31 +++++++++++++++++++------------ tests/test_datasets.py | 26 ++++++++++++++++++++++---- 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/altair/datasets/_readers.py b/altair/datasets/_readers.py index eea9f18db..a3435d231 100644 --- a/altair/datasets/_readers.py +++ b/altair/datasets/_readers.py @@ -12,6 +12,7 @@ import os import urllib.request from functools import partial +from http.client import HTTPResponse from importlib import import_module from importlib.util import find_spec from itertools import chain, islice @@ -76,6 +77,10 @@ __all__ = ["get_backend"] +def _identity(_: _T, /) -> _T: + return _ + + class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol): """ Common functionality between backends. @@ -88,6 +93,18 @@ class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol): _name: LiteralString _ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR" _opener: ClassVar[OpenerDirector] = urllib.request.build_opener() + _response: ClassVar[staticmethod[[HTTPResponse], Any]] = staticmethod(_identity) + """ + Backends that do not support `file-like objects`_, must override with conversion. + + Used only for **remote** files, as *cached* files use a `pathlib.Path`_. + + .. _file-like objects: + https://docs.python.org/3/glossary.html#term-file-object + .. _pathlib.Path: + https://docs.python.org/3/library/pathlib.html#pathlib.Path + """ + _metadata: Path = Path(__file__).parent / "_metadata" / "metadata.parquet" def read_fn(self, source: StrPath, /) -> Callable[..., IntoDataFrameT]: @@ -98,10 +115,6 @@ def scan_fn(self, source: StrPath, /) -> Callable[..., IntoFrameT]: suffix = validate_suffix(source, is_ext_scan) return self._scan_fn[suffix] - def _response_hook(self, f): - # HACK: `pyarrow` + `pandas` wants the file obj - return f - def dataset( self, name: DatasetName | LiteralString, @@ -137,7 +150,7 @@ def dataset( return fn(fp, **kwds) else: with self._opener.open(url) as f: - return fn(self._response_hook(f), **kwds) + return fn(self._response(f), **kwds) def url( self, @@ -261,6 +274,8 @@ def __init__(self, name: Literal["pandas[pyarrow]"], /) -> None: class _PolarsReader(_Reader["pl.DataFrame", "pl.LazyFrame"]): + _response = staticmethod(HTTPResponse.read) + def __init__(self, name: _Polars, /) -> None: self._name = _requirements(name) if not TYPE_CHECKING: @@ -273,11 +288,10 @@ def __init__(self, name: _Polars, /) -> None: } self._scan_fn = {".parquet": pl.scan_parquet} - def _response_hook(self, f): - return f.read() - class _PolarsPyArrowReader(_Reader["pl.DataFrame", "pl.LazyFrame"]): + _response = staticmethod(HTTPResponse.read) + def __init__(self, name: Literal["polars[pyarrow]"], /) -> None: _pl, _pa = _requirements(name) self._name = name @@ -292,9 +306,6 @@ def __init__(self, name: Literal["polars[pyarrow]"], /) -> None: } self._scan_fn = {".parquet": pl.scan_parquet} - def _response_hook(self, f): - return f.read() - class _PyArrowReader(_Reader["pa.Table", "pa.Table"]): """ diff --git a/tests/__init__.py b/tests/__init__.py index 617cfca80..17a33e91e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,14 +5,14 @@ import sys from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, overload import pytest from tests import examples_arguments_syntax, examples_methods_syntax if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterator, Mapping + from collections.abc import Collection, Iterator, Mapping from re import Pattern if sys.version_info >= (3, 11): @@ -20,6 +20,7 @@ else: from typing_extensions import TypeAlias from _pytest.mark import ParameterSet + from _pytest.mark.structures import Markable MarksType: TypeAlias = ( "pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark]" @@ -96,9 +97,21 @@ def windows_has_tzdata() -> bool: """ +@overload def skip_requires_pyarrow( - fn: Callable[..., Any] | None = None, /, *, requires_tzdata: bool = False -) -> Callable[..., Any]: + fn: None = ..., /, *, requires_tzdata: bool = ... +) -> pytest.MarkDecorator: ... + + +@overload +def skip_requires_pyarrow( + fn: Markable, /, *, requires_tzdata: bool = ... +) -> Markable: ... + + +def skip_requires_pyarrow( + fn: Markable | None = None, /, *, requires_tzdata: bool = False +) -> pytest.MarkDecorator | Markable: """ ``pytest.mark.skipif`` decorator. @@ -109,7 +122,7 @@ def skip_requires_pyarrow( https://github.com/vega/altair/issues/3050 .. _pyarrow: - https://pypi.org/project/pyarrow/ + https://pypi.org/project/pyarrow/ """ composed = pytest.mark.skipif( find_spec("pyarrow") is None, reason="`pyarrow` not installed." @@ -120,13 +133,7 @@ def skip_requires_pyarrow( reason="Timezone database is not installed on Windows", )(composed) - def wrap(test_fn: Callable[..., Any], /) -> Callable[..., Any]: - return composed(test_fn) - - if fn is None: - return wrap - else: - return wrap(fn) + return composed if fn is None else composed(fn) def id_func_str_only(val) -> str: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ec2f9014f..7a4ab51f1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from importlib.util import find_spec from typing import TYPE_CHECKING import pytest @@ -14,10 +15,27 @@ if TYPE_CHECKING: from altair.datasets._readers import _Backend -backends = skip_requires_pyarrow( - pytest.mark.parametrize( - "backend", ["polars", "polars[pyarrow]", "pandas", "pandas[pyarrow]", "pyarrow"] - ) + +requires_pyarrow = skip_requires_pyarrow() + +backends = pytest.mark.parametrize( + "backend", + [ + "polars", + pytest.param( + "pandas", + marks=pytest.mark.xfail( + find_spec("pyarrow") is None, + reason=( + "`pandas` supports backends other than `pyarrow` for `.parquet`.\n" + "However, none of these are currently an `altair` dependency." + ), + ), + ), + pytest.param("polars[pyarrow]", marks=requires_pyarrow), + pytest.param("pandas[pyarrow]", marks=requires_pyarrow), + pytest.param("pyarrow", marks=requires_pyarrow), + ], )