diff --git a/altair/datasets/_readers.py b/altair/datasets/_readers.py index a3435d231..b2f41af89 100644 --- a/altair/datasets/_readers.py +++ b/altair/datasets/_readers.py @@ -22,6 +22,7 @@ Any, Callable, ClassVar, + Final, Generic, Literal, Protocol, @@ -76,6 +77,8 @@ __all__ = ["get_backend"] +_METADATA: Final[Path] = Path(__file__).parent / "_metadata" / "metadata.parquet" + def _identity(_: _T, /) -> _T: return _ @@ -105,8 +108,6 @@ class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol): https://docs.python.org/3/library/pathlib.html#pathlib.Path """ - _metadata: Path = Path(__file__).parent / "_metadata" / "metadata.parquet" - def read_fn(self, source: StrPath, /) -> Callable[..., IntoDataFrameT]: suffix = validate_suffix(source, is_ext_read) return self._read_fn[suffix] @@ -159,20 +160,13 @@ def url( /, tag: VersionTag | None = None, ) -> str: - df = self.query(**validate_constraints(name, suffix, tag)) - url = df.item(0, "url_npm") + frame = self.query(**validate_constraints(name, suffix, tag)) + url = nw.to_py_scalar(frame.item(0, "url_npm")) if isinstance(url, str): return url else: - converted = nw.to_py_scalar(url) - if isinstance(converted, str): - return converted - else: - msg = ( - f"Expected 'str' but got {type(converted).__name__!r}\n" - f"from {converted!r}." - ) - raise TypeError(msg) + msg = f"Expected 'str' but got {type(url).__name__!r}\n" f"from {url!r}." + raise TypeError(msg) def query( self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata] @@ -188,15 +182,14 @@ def query( .. _pl.LazyFrame.filter: https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.filter.html """ - source = self._metadata - fn = self.scan_fn(source) - frame = nw.from_native(fn(source)) - result = frame.filter(_filter_reduce(predicates, constraints)) - df: nw.DataFrame[Any] = ( - result.collect() if isinstance(result, nw.LazyFrame) else result + frame = ( + nw.from_native(self.scan_fn(_METADATA)(_METADATA)) + .filter(_parse_predicates_constraints(predicates, constraints)) + .lazy() + .collect() ) - if not df.is_empty(): - return df + if not frame.is_empty(): + return frame else: terms = "\n".join(f"{t!r}" for t in (predicates, constraints) if t) msg = f"Found no results for:\n{terms}" @@ -208,12 +201,12 @@ def _read_metadata(self) -> IntoDataFrameT: Effectively an eager read, no filters. """ - fn = self.scan_fn(self._metadata) - frame = nw.from_native(fn(self._metadata)) - df: nw.DataFrame[Any] = ( - frame.collect() if isinstance(frame, nw.LazyFrame) else frame + return ( + nw.from_native(self.scan_fn(_METADATA)(_METADATA)) + .lazy() + .collect() + .to_native() ) - return df.to_native() @property def _cache(self) -> Path | None: # type: ignore[return] @@ -351,11 +344,15 @@ def __init__(self, name: _PyArrow, /) -> None: self._scan_fn = {".parquet": pa_read_parquet} -def _filter_reduce(predicates: tuple[Any, ...], constraints: Metadata, /) -> nw.Expr: +def _parse_predicates_constraints( + predicates: tuple[Any, ...], constraints: Metadata, / +) -> nw.Expr: """ - ``narwhals`` only accepts ``filter(*predicates)`. + ``narwhals`` only accepts ``filter(*predicates)``. + + So we convert each item in ``**constraints`` here as:: - Manually converts the constraints into ``==`` + col("column_name") == literal_value """ return nw.all_horizontal( chain(predicates, (nw.col(name) == v for name, v in constraints.items()))