Skip to content

Commit

Permalink
refactor: Tidy up changes from last 4 commits
Browse files Browse the repository at this point in the history
- Rename and properly document "file-like object" handling
  - Also made a bit clearer what is being called and when
- Use a more granular approach to skipping in `@backends`
  - Previously, everything was skipped regardless of whether it required `pyarrow`
  - Now, `polars` and `pandas` **always** run - with `pandas` expected to fail when `pyarrow` is not installed
- I had to clean up `skip_requires_pyarrow` to make it compatible with `pytest.param`
  - `pytest.param` checks at runtime that `marks` is a `MarkDecorator`, instead of accepting any callable

bb7bc17, ebc1bfa, fe0ae88,
7089f2a
  • Loading branch information
dangotbanned committed Nov 11, 2024
1 parent 7089f2a commit e1290d4
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 27 deletions.
33 changes: 22 additions & 11 deletions altair/datasets/_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import urllib.request
from functools import partial
from http.client import HTTPResponse
from importlib import import_module
from importlib.util import find_spec
from itertools import chain, islice
Expand Down Expand Up @@ -76,6 +77,10 @@
__all__ = ["get_backend"]


def _identity(_: _T, /) -> _T:
    """
    Return the single positional argument unchanged.

    Serves as a no-op conversion, e.g. the default for ``_Reader._response``,
    where the received object can be passed along as-is.
    """
    return _


class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol):
"""
Common functionality between backends.
Expand All @@ -88,6 +93,18 @@ class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol):
_name: LiteralString
_ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR"
_opener: ClassVar[OpenerDirector] = urllib.request.build_opener()
_response: ClassVar[staticmethod[[HTTPResponse], Any]] = staticmethod(_identity)
"""
Backends that do not support `file-like objects`_, must override with conversion.
Used only for **remote** files, as *cached* files use a `pathlib.Path`_.
.. _file-like objects:
https://docs.python.org/3/glossary.html#term-file-object
.. _pathlib.Path:
https://docs.python.org/3/library/pathlib.html#pathlib.Path
"""

_metadata: Path = Path(__file__).parent / "_metadata" / "metadata.parquet"

def read_fn(self, source: StrPath, /) -> Callable[..., IntoDataFrameT]:
Expand All @@ -98,10 +115,6 @@ def scan_fn(self, source: StrPath, /) -> Callable[..., IntoFrameT]:
suffix = validate_suffix(source, is_ext_scan)
return self._scan_fn[suffix]

def _response_hook(self, f):
# HACK: `pyarrow` + `pandas` wants the file obj
return f

def dataset(
self,
name: DatasetName | LiteralString,
Expand Down Expand Up @@ -137,7 +150,7 @@ def dataset(
return fn(fp, **kwds)
else:
with self._opener.open(url) as f:
return fn(self._response_hook(f), **kwds)
return fn(self._response(f), **kwds)

def url(
self,
Expand Down Expand Up @@ -261,6 +274,8 @@ def __init__(self, name: Literal["pandas[pyarrow]"], /) -> None:


class _PolarsReader(_Reader["pl.DataFrame", "pl.LazyFrame"]):
_response = staticmethod(HTTPResponse.read)

def __init__(self, name: _Polars, /) -> None:
self._name = _requirements(name)
if not TYPE_CHECKING:
Expand All @@ -273,11 +288,10 @@ def __init__(self, name: _Polars, /) -> None:
}
self._scan_fn = {".parquet": pl.scan_parquet}

def _response_hook(self, f):
return f.read()


class _PolarsPyArrowReader(_Reader["pl.DataFrame", "pl.LazyFrame"]):
_response = staticmethod(HTTPResponse.read)

def __init__(self, name: Literal["polars[pyarrow]"], /) -> None:
_pl, _pa = _requirements(name)
self._name = name
Expand All @@ -292,9 +306,6 @@ def __init__(self, name: Literal["polars[pyarrow]"], /) -> None:
}
self._scan_fn = {".parquet": pl.scan_parquet}

def _response_hook(self, f):
return f.read()


class _PyArrowReader(_Reader["pa.Table", "pa.Table"]):
"""
Expand Down
31 changes: 19 additions & 12 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@
import sys
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, overload

import pytest

from tests import examples_arguments_syntax, examples_methods_syntax

if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Mapping
from collections.abc import Collection, Iterator, Mapping
from re import Pattern

if sys.version_info >= (3, 11):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from _pytest.mark import ParameterSet
from _pytest.mark.structures import Markable

MarksType: TypeAlias = (
"pytest.MarkDecorator | Collection[pytest.MarkDecorator | pytest.Mark]"
Expand Down Expand Up @@ -96,9 +97,21 @@ def windows_has_tzdata() -> bool:
"""


@overload
def skip_requires_pyarrow(
fn: Callable[..., Any] | None = None, /, *, requires_tzdata: bool = False
) -> Callable[..., Any]:
fn: None = ..., /, *, requires_tzdata: bool = ...
) -> pytest.MarkDecorator: ...


@overload
def skip_requires_pyarrow(
fn: Markable, /, *, requires_tzdata: bool = ...
) -> Markable: ...


def skip_requires_pyarrow(
fn: Markable | None = None, /, *, requires_tzdata: bool = False
) -> pytest.MarkDecorator | Markable:
"""
``pytest.mark.skipif`` decorator.
Expand All @@ -109,7 +122,7 @@ def skip_requires_pyarrow(
https://github.com/vega/altair/issues/3050
.. _pyarrow:
https://pypi.org/project/pyarrow/
https://pypi.org/project/pyarrow/
"""
composed = pytest.mark.skipif(
find_spec("pyarrow") is None, reason="`pyarrow` not installed."
Expand All @@ -120,13 +133,7 @@ def skip_requires_pyarrow(
reason="Timezone database is not installed on Windows",
)(composed)

def wrap(test_fn: Callable[..., Any], /) -> Callable[..., Any]:
return composed(test_fn)

if fn is None:
return wrap
else:
return wrap(fn)
return composed if fn is None else composed(fn)


def id_func_str_only(val) -> str:
Expand Down
26 changes: 22 additions & 4 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
from importlib.util import find_spec
from typing import TYPE_CHECKING

import pytest
Expand All @@ -14,10 +15,27 @@
if TYPE_CHECKING:
from altair.datasets._readers import _Backend

backends = skip_requires_pyarrow(
pytest.mark.parametrize(
"backend", ["polars", "polars[pyarrow]", "pandas", "pandas[pyarrow]", "pyarrow"]
)

# Mark that skips a single parameter when `pyarrow` is not installed.
requires_pyarrow = skip_requires_pyarrow()

# Parametrizes a test over every ``backend`` name, marking each one individually:
# - "polars" carries no mark, so it always runs.
# - "pandas" always runs, but is expected to fail (xfail) when `pyarrow` is missing.
# - The `pyarrow`-dependent backends are skipped via ``requires_pyarrow``.
backends = pytest.mark.parametrize(
    "backend",
    [
        "polars",
        pytest.param(
            "pandas",
            marks=pytest.mark.xfail(
                find_spec("pyarrow") is None,
                reason=(
                    "`pandas` supports backends other than `pyarrow` for `.parquet`.\n"
                    "However, none of these are currently an `altair` dependency."
                ),
            ),
        ),
        pytest.param("polars[pyarrow]", marks=requires_pyarrow),
        pytest.param("pandas[pyarrow]", marks=requires_pyarrow),
        pytest.param("pyarrow", marks=requires_pyarrow),
    ],
)


Expand Down

0 comments on commit e1290d4

Please sign in to comment.