From 3785a48eb81be23b44b895624f21acbfc1a828c5 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Tue, 3 Dec 2024 11:17:04 -0600 Subject: [PATCH 1/2] Add multi-partition `DataFrameScan` support to cuDF-Polars (#17441) Follow-up to https://github.com/rapidsai/cudf/pull/17262 Adds support for parallel `DataFrameScan` operations. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/17441 --- python/cudf_polars/cudf_polars/callback.py | 14 +- python/cudf_polars/cudf_polars/dsl/ir.py | 17 +- .../cudf_polars/cudf_polars/dsl/translate.py | 1 + .../cudf_polars/experimental/base.py | 43 +++ .../cudf_polars/experimental/dispatch.py | 84 ++++++ .../cudf_polars/experimental/io.py | 49 ++++ .../cudf_polars/experimental/parallel.py | 245 +++++++++--------- .../cudf_polars/tests/dsl/test_traversal.py | 12 +- .../tests/experimental/test_dataframescan.py | 53 ++++ python/cudf_polars/tests/test_executors.py | 16 ++ 10 files changed, 411 insertions(+), 123 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/experimental/base.py create mode 100644 python/cudf_polars/cudf_polars/experimental/dispatch.py create mode 100644 python/cudf_polars/cudf_polars/experimental/io.py create mode 100644 python/cudf_polars/tests/experimental/test_dataframescan.py diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 95527028aa9..29d3dc4ae79 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -217,7 +217,8 @@ def validate_config_options(config: dict) -> None: If the configuration contains unsupported options. """ if unsupported := ( - config.keys() - {"raise_on_fail", "parquet_options", "executor"} + config.keys() + - {"raise_on_fail", "parquet_options", "executor", "executor_options"} ): raise ValueError( f"Engine configuration contains unsupported settings: {unsupported}" @@ -226,6 +227,17 @@ def validate_config_options(config: dict) -> None: config.get("parquet_options", {}) ) + # Validate executor_options + executor = config.get("executor", "pylibcudf") + if executor == "dask-experimental": + unsupported = config.get("executor_options", {}).keys() - { + "max_rows_per_partition" + } + else: + unsupported = config.get("executor_options", {}).keys() + if unsupported: + raise ValueError(f"Unsupported executor_options for {executor}: {unsupported}") + def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: """ diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index a28b4cf25b2..1faa778ccf6 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -688,14 +688,16 @@ class DataFrameScan(IR): This typically arises from ``q.collect().lazy()`` """ - __slots__ = ("df", "predicate", "projection") - _non_child = ("schema", "df", "projection", "predicate") + __slots__ = ("config_options", "df", "predicate", "projection") + _non_child = ("schema", "df", "projection", "predicate", "config_options") df: Any """Polars LazyFrame object.""" projection: tuple[str, ...] | None """List of columns to project out.""" predicate: expr.NamedExpr | None """Mask to apply.""" + config_options: dict[str, Any] + """GPU-specific configuration options""" def __init__( self, @@ -703,11 +705,13 @@ def __init__( df: Any, projection: Sequence[str] | None, predicate: expr.NamedExpr | None, + config_options: dict[str, Any], ): self.schema = schema self.df = df self.projection = tuple(projection) if projection is not None else None self.predicate = predicate + self.config_options = config_options self._non_child_args = (schema, df, self.projection, predicate) self.children = () @@ -719,7 +723,14 @@ def get_hashable(self) -> Hashable: not stable across runs, or repeat instances of the same equal dataframes. """ schema_hash = tuple(self.schema.items()) - return (type(self), schema_hash, id(self.df), self.projection, self.predicate) + return ( + type(self), + schema_hash, + id(self.df), + self.projection, + self.predicate, + json.dumps(self.config_options), + ) @classmethod def do_evaluate( diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index b1e2de63ba6..37cf36dc4dd 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -263,6 +263,7 @@ def _( translate_named_expr(translator, n=node.selection) if node.selection is not None else None, + translator.config.config.copy(), ) diff --git a/python/cudf_polars/cudf_polars/experimental/base.py b/python/cudf_polars/cudf_polars/experimental/base.py new file mode 100644 index 00000000000..8f660632df2 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/base.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition base classes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import Union + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from cudf_polars.containers import DataFrame + from cudf_polars.dsl.nodebase import Node + + +class PartitionInfo: + """ + Partitioning information. + + This class only tracks the partition count (for now). + """ + + __slots__ = ("count",) + + def __init__(self, count: int): + self.count = count + + def keys(self, node: Node) -> Iterator[tuple[str, int]]: + """Return the partitioned keys for a given node.""" + name = get_key_name(node) + yield from ((name, i) for i in range(self.count)) + + +def get_key_name(node: Node) -> str: + """Generate the key name for a Node.""" + return f"{type(node).__name__.lower()}-{hash(node)}" + + +def _concat(dfs: Sequence[DataFrame]) -> DataFrame: + # Concatenate a sequence of DataFrames vertically + return Union.do_evaluate(None, *dfs) diff --git a/python/cudf_polars/cudf_polars/experimental/dispatch.py b/python/cudf_polars/cudf_polars/experimental/dispatch.py new file mode 100644 index 00000000000..79a52ff3cde --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/dispatch.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition dispatch functions.""" + +from __future__ import annotations + +from functools import singledispatch +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import TypeAlias + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.base import PartitionInfo + from cudf_polars.typing import GenericTransformer + + +LowerIRTransformer: TypeAlias = ( + "GenericTransformer[IR, tuple[IR, MutableMapping[IR, PartitionInfo]]]" +) +"""Protocol for Lowering IR nodes.""" + + +@singledispatch +def lower_ir_node( + ir: IR, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + """ + Rewrite an IR node and extract partitioning information. + + Parameters + ---------- + ir + IR node to rewrite. + rec + Recursive LowerIRTransformer callable. + + Returns + ------- + new_ir, partition_info + The rewritten node, and a mapping from unique nodes in + the full IR graph to associated partitioning information. + + Notes + ----- + This function is used by `lower_ir_graph`. + + See Also + -------- + lower_ir_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + + +@singledispatch +def generate_ir_tasks( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + """ + Generate a task graph for evaluation of an IR node. + + Parameters + ---------- + ir + IR node to generate tasks for. + partition_info + Partitioning information, obtained from :func:`lower_ir_graph`. + + Returns + ------- + mapping + A (partial) dask task graph for the evaluation of an ir node. + + Notes + ----- + Task generation should only produce the tasks for the current node, + referring to child tasks by name. + + See Also + -------- + task_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py new file mode 100644 index 00000000000..3a1fec36079 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition IO Logic.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import DataFrameScan, Union +from cudf_polars.experimental.base import PartitionInfo +from cudf_polars.experimental.dispatch import lower_ir_node + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.dispatch import LowerIRTransformer + + +@lower_ir_node.register(DataFrameScan) +def _( + ir: DataFrameScan, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + rows_per_partition = ir.config_options.get("executor_options", {}).get( + "max_rows_per_partition", 1_000_000 + ) + + nrows = max(ir.df.shape()[0], 1) + count = math.ceil(nrows / rows_per_partition) + + if count > 1: + length = math.ceil(nrows / count) + slices = [ + DataFrameScan( + ir.schema, + ir.df.slice(offset, length), + ir.projection, + ir.predicate, + ir.config_options, + ) + for offset in range(0, nrows, length) + ] + new_node = Union(ir.schema, None, *slices) + return new_node, {slice: PartitionInfo(count=1) for slice in slices} | { + new_node: PartitionInfo(count=count) + } + + return ir, {ir: PartitionInfo(count=1)} diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 6518dd60c7d..e5884f1c574 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -1,93 +1,46 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 -"""Partitioned LogicalPlan nodes.""" +"""Multi-partition Dask execution.""" from __future__ import annotations +import itertools import operator -from functools import reduce, singledispatch +from functools import reduce from typing import TYPE_CHECKING, Any -from cudf_polars.dsl.ir import IR -from cudf_polars.dsl.traversal import traversal +import cudf_polars.experimental.io # noqa: F401 +from cudf_polars.dsl.ir import IR, Cache, Projection, Union +from cudf_polars.dsl.traversal import CachingVisitor, traversal +from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name +from cudf_polars.experimental.dispatch import ( + generate_ir_tasks, + lower_ir_node, +) if TYPE_CHECKING: from collections.abc import MutableMapping - from typing import TypeAlias from cudf_polars.containers import DataFrame - from cudf_polars.dsl.nodebase import Node - from cudf_polars.typing import GenericTransformer - - -class PartitionInfo: - """ - Partitioning information. - - This class only tracks the partition count (for now). - """ - - __slots__ = ("count",) - - def __init__(self, count: int): - self.count = count - - -LowerIRTransformer: TypeAlias = ( - "GenericTransformer[IR, MutableMapping[IR, PartitionInfo]]" -) -"""Protocol for Lowering IR nodes.""" - - -def get_key_name(node: Node) -> str: - """Generate the key name for a Node.""" - return f"{type(node).__name__.lower()}-{hash(node)}" - - -@singledispatch -def lower_ir_node( - ir: IR, rec: LowerIRTransformer -) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - """ - Rewrite an IR node and extract partitioning information. - - Parameters - ---------- - ir - IR node to rewrite. - rec - Recursive LowerIRTransformer callable. - - Returns - ------- - new_ir, partition_info - The rewritten node, and a mapping from unique nodes in - the full IR graph to associated partitioning information. - - Notes - ----- - This function is used by `lower_ir_graph`. - - See Also - -------- - lower_ir_graph - """ - raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + from cudf_polars.experimental.dispatch import LowerIRTransformer @lower_ir_node.register(IR) def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Default logic - Requires single partition + if len(ir.children) == 0: # Default leaf node has single partition - return ir, {ir: PartitionInfo(count=1)} + return ir, { + ir: PartitionInfo(count=1) + } # pragma: no cover; Missed by pylibcudf executor # Lower children - children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False) + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) partition_info = reduce(operator.or_, _partition_info) # Check that child partitioning is supported - count = max(partition_info[c].count for c in children) - if count > 1: + if any(partition_info[c].count > 1 for c in children): raise NotImplementedError( f"Class {type(ir)} does not support multiple partitions." ) # pragma: no cover @@ -123,41 +76,62 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: -------- lower_ir_node """ - from cudf_polars.dsl.traversal import CachingVisitor - mapper = CachingVisitor(lower_ir_node) return mapper(ir) -@singledispatch -def generate_ir_tasks( +def task_graph( ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> MutableMapping[Any, Any]: +) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]: """ - Generate a task graph for evaluation of an IR node. + Construct a task graph for evaluation of an IR graph. Parameters ---------- ir - IR node to generate tasks for. + Root of the graph to rewrite. partition_info - Partitioning information, obtained from :func:`lower_ir_graph`. + A mapping from all unique IR nodes to the + associated partitioning information. Returns ------- - mapping - A (partial) dask task graph for the evaluation of an ir node. + graph + A Dask-compatible task graph for the entire + IR graph with root `ir`. Notes ----- - Task generation should only produce the tasks for the current node, - referring to child tasks by name. + This function traverses the unique nodes of the + graph with root `ir`, and extracts the tasks for + each node with :func:`generate_ir_tasks`. See Also -------- - task_graph + generate_ir_tasks """ - raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + graph = reduce( + operator.or_, + (generate_ir_tasks(node, partition_info) for node in traversal(ir)), + ) + + key_name = get_key_name(ir) + partition_count = partition_info[ir].count + if partition_count > 1: + graph[key_name] = (_concat, list(partition_info[ir].keys(ir))) + return graph, key_name + else: + return graph, (key_name, 0) + + +def evaluate_dask(ir: IR) -> DataFrame: + """Evaluate an IR graph with Dask.""" + from dask import get + + ir, partition_info = lower_ir_graph(ir) + + graph, key = task_graph(ir, partition_info) + return get(graph, key) @generate_ir_tasks.register(IR) @@ -189,48 +163,85 @@ def _( } -def task_graph( - ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]: - """ - Construct a task graph for evaluation of an IR graph. +@lower_ir_node.register(Union) +def _( + ir: Union, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) + partition_info = reduce(operator.or_, _partition_info) - Parameters - ---------- - ir - Root of the graph to rewrite. - partition_info - A mapping from all unique IR nodes to the - associated partitioning information. + # Check zlice + if ir.zlice is not None: # pragma: no cover + if any(p[c].count > 1 for p, c in zip(children, _partition_info, strict=False)): + raise NotImplementedError("zlice is not supported for multiple partitions.") + new_node = ir.reconstruct(children) + partition_info[new_node] = PartitionInfo(count=1) + return new_node, partition_info - Returns - ------- - graph - A Dask-compatible task graph for the entire - IR graph with root `ir`. + # Partition count is the sum of all child partitions + count = sum(partition_info[c].count for c in children) - Notes - ----- - This function traverses the unique nodes of the - graph with root `ir`, and extracts the tasks for - each node with :func:`generate_ir_tasks`. + # Return reconstructed node and partition-info dict + new_node = ir.reconstruct(children) + partition_info[new_node] = PartitionInfo(count=count) + return new_node, partition_info - See Also - -------- - generate_ir_tasks - """ - graph = reduce( - operator.or_, - (generate_ir_tasks(node, partition_info) for node in traversal(ir)), - ) - return graph, (get_key_name(ir), 0) +@generate_ir_tasks.register(Union) +def _( + ir: Union, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + key_name = get_key_name(ir) + partition = itertools.count() + return { + (key_name, next(partition)): child_key + for child in ir.children + for child_key in partition_info[child].keys(child) + } -def evaluate_dask(ir: IR) -> DataFrame: - """Evaluate an IR graph with Dask.""" - from dask import get - ir, partition_info = lower_ir_graph(ir) +def _lower_ir_pwise( + ir: IR, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Lower a partition-wise (i.e. embarrassingly-parallel) IR node - graph, key = task_graph(ir, partition_info) - return get(graph, key) + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) + partition_info = reduce(operator.or_, _partition_info) + counts = {partition_info[c].count for c in children} + + # Check that child partitioning is supported + if len(counts) > 1: + raise NotImplementedError( + f"Class {type(ir)} does not support unbalanced partitions." + ) # pragma: no cover + + # Return reconstructed node and partition-info dict + partition = PartitionInfo(count=max(counts)) + new_node = ir.reconstruct(children) + partition_info[new_node] = partition + return new_node, partition_info + + +lower_ir_node.register(Projection, _lower_ir_pwise) +lower_ir_node.register(Cache, _lower_ir_pwise) + + +def _generate_ir_tasks_pwise( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + # Generate partition-wise (i.e. embarrassingly-parallel) tasks + child_names = [get_key_name(c) for c in ir.children] + return { + key: ( + ir.do_evaluate, + *ir._non_child_args, + *[(child_name, i) for child_name in child_names], + ) + for i, key in enumerate(partition_info[ir].keys(ir)) + } + + +generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise) +generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise) diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index 2f4df9289f8..9755994c419 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -116,7 +116,11 @@ def test_rewrite_ir_node(): def replace_df(node, rec): if isinstance(node, ir.DataFrameScan): return ir.DataFrameScan( - node.schema, new_df._df, node.projection, node.predicate + node.schema, + new_df._df, + node.projection, + node.predicate, + node.config_options, ) return reuse_if_unchanged(node, rec) @@ -144,7 +148,11 @@ def test_rewrite_scan_node(tmp_path): def replace_scan(node, rec): if isinstance(node, ir.Scan): return ir.DataFrameScan( - node.schema, right._df, node.with_columns, node.predicate + node.schema, + right._df, + node.with_columns, + node.predicate, + node.config_options, ) return reuse_if_unchanged(node, rec) diff --git a/python/cudf_polars/tests/experimental/test_dataframescan.py b/python/cudf_polars/tests/experimental/test_dataframescan.py new file mode 100644 index 00000000000..77c7bf0c503 --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_dataframescan.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars import Translator +from cudf_polars.experimental.parallel import lower_ir_graph +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "x": range(30_000), + "y": ["cat", "dog", "fish"] * 10_000, + "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 6_000, + } + ) + + +@pytest.mark.parametrize("max_rows_per_partition", [1_000, 1_000_000]) +def test_parallel_dataframescan(df, max_rows_per_partition): + total_row_count = len(df.collect()) + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": max_rows_per_partition}, + ) + assert_gpu_result_equal(df, engine=engine) + + # Check partitioning + qir = Translator(df._ldf.visit(), engine).translate_ir() + ir, info = lower_ir_graph(qir) + count = info[ir].count + if max_rows_per_partition < total_row_count: + assert count > 1 + else: + assert count == 1 + + +def test_dataframescan_concat(df): + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": 1_000}, + ) + df2 = pl.concat([df, df]) + assert_gpu_result_equal(df2, engine=engine) diff --git a/python/cudf_polars/tests/test_executors.py b/python/cudf_polars/tests/test_executors.py index 3eaea2ec9ea..b8c0bb926ab 100644 --- a/python/cudf_polars/tests/test_executors.py +++ b/python/cudf_polars/tests/test_executors.py @@ -66,3 +66,19 @@ def test_unknown_executor(): match="ValueError: Unknown executor 'unknown-executor'", ): assert_gpu_result_equal(df, executor="unknown-executor") + + +@pytest.mark.parametrize("executor", [None, "pylibcudf", "dask-experimental"]) +def test_unknown_executor_options(executor): + df = pl.LazyFrame({}) + + with pytest.raises( + pl.exceptions.ComputeError, + match="Unsupported executor_options", + ): + df.collect( + engine=pl.GPUEngine( + executor=executor, + executor_options={"foo": None}, + ) + ) From 4696bbf91ca37ab6960b606d1f7763487ee03ef6 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Tue, 3 Dec 2024 12:58:35 -0500 Subject: [PATCH 2/2] Revert "Temporarily skip tests due to dask/distributed#8953" (#17492) Reverts rapidsai/cudf#17472 The new dask nightly has resolved https://github.com/dask/distributed/issues/8953 --- .../custreamz/tests/test_dataframes.py | 56 +++---------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/python/custreamz/custreamz/tests/test_dataframes.py b/python/custreamz/custreamz/tests/test_dataframes.py index 6905044039c..8c0130d2818 100644 --- a/python/custreamz/custreamz/tests/test_dataframes.py +++ b/python/custreamz/custreamz/tests/test_dataframes.py @@ -216,13 +216,7 @@ def test_set_index(): assert_eq(b[0], df.set_index(df.y + 1)) -def test_binary_stream_operators(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_binary_stream_operators(stream): df = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) expected = df.x + df.y @@ -248,13 +242,7 @@ def test_index(stream): assert_eq(L[1], df.index + 5) -def test_pair_arithmetic(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_pair_arithmetic(stream): df = cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) a = DataFrame(example=df.iloc[:0], stream=stream) @@ -267,13 +255,7 @@ def test_pair_arithmetic(request, stream): assert_eq(cudf.concat(L), (df.x + df.y) * 2) -def test_getitem(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_getitem(stream): df = cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) a = DataFrame(example=df.iloc[:0], stream=stream) @@ -350,13 +332,7 @@ def test_repr_html(stream): assert "1" in html -def test_setitem(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_setitem(stream): df = cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) sdf = DataFrame(example=df.iloc[:0], stream=stream) @@ -380,13 +356,7 @@ def test_setitem(request, stream): assert_eq(L[-1], df.mean()) -def test_setitem_overwrites(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_setitem_overwrites(stream): df = cudf.DataFrame({"x": list(range(10))}) sdf = DataFrame(example=df.iloc[:0], stream=stream) stream = sdf.stream @@ -443,14 +413,8 @@ def test_setitem_overwrites(request, stream): ], ) def test_rolling_count_aggregations( - request, op, window, m, pre_get, post_get, kwargs, stream + op, window, m, pre_get, post_get, kwargs, stream ): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream) and len(kwargs) == 0, - reason="https://github.com/dask/distributed/issues/8953", - ) - ) index = pd.DatetimeIndex( pd.date_range("2000-01-01", "2000-01-03", freq="1h") ) @@ -844,13 +808,7 @@ def test_reductions_with_start_state(stream): assert output2[0] == 360 -def test_rolling_aggs_with_start_state(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_rolling_aggs_with_start_state(stream): example = cudf.DataFrame({"name": [], "amount": []}, dtype="float64") sdf = DataFrame(stream, example=example) output0 = (