diff --git a/ci/run_cudf_polars_polars_tests.sh b/ci/run_cudf_polars_polars_tests.sh
index b1bfac2a1dd..c851f65d4f6 100755
--- a/ci/run_cudf_polars_polars_tests.sh
+++ b/ci/run_cudf_polars_polars_tests.sh
@@ -28,8 +28,11 @@ if [[ $(arch) == "aarch64" ]]; then
DESELECTED_TESTS+=("tests/unit/operations/test_join.py::test_join_4_columns_with_validity")
else
# Ensure that we don't run dbgen when it uses newer symbols than supported by the glibc version in the CI image.
+ # Allow errors since any of these commands could produce empty results that would cause the script to fail.
+ set +e
glibc_minor_version=$(ldd --version | head -1 | grep -o "[0-9]\.[0-9]\+" | tail -1 | cut -d '.' -f2)
latest_glibc_symbol_found=$(nm py-polars/tests/benchmark/data/pdsh/dbgen/dbgen | grep GLIBC | grep -o "[0-9]\.[0-9]\+" | sort --version-sort | tail -1 | cut -d "." -f 2)
+ set -e
if [[ ${glibc_minor_version} -lt ${latest_glibc_symbol_found} ]]; then
DESELECTED_TESTS+=("tests/benchmark/test_pdsh.py::test_pdsh")
fi
diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index ecc490b378b..97c72ec8042 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -71,6 +71,7 @@ dependencies:
- ptxcompiler
- pyarrow>=14.0.0,<19.0.0a0
- pydata-sphinx-theme!=0.14.2
+- pynvml>=11.4.1,<12.0.0a0
- pytest-benchmark
- pytest-cases>=3.8.2
- pytest-cov
diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml
index 4559829ac3a..84b58b6d7a4 100644
--- a/conda/environments/all_cuda-125_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-125_arch-x86_64.yaml
@@ -69,6 +69,7 @@ dependencies:
- pyarrow>=14.0.0,<19.0.0a0
- pydata-sphinx-theme!=0.14.2
- pynvjitlink>=0.0.0a0
+- pynvml>=11.4.1,<12.0.0a0
- pytest-benchmark
- pytest-cases>=3.8.2
- pytest-cov
diff --git a/conda/recipes/dask-cudf/meta.yaml b/conda/recipes/dask-cudf/meta.yaml
index 1e6c0a35a09..74ecded8ead 100644
--- a/conda/recipes/dask-cudf/meta.yaml
+++ b/conda/recipes/dask-cudf/meta.yaml
@@ -43,6 +43,7 @@ requirements:
run:
- python
- cudf ={{ version }}
+ - pynvml >=11.4.1,<12.0.0a0
- rapids-dask-dependency ={{ minor_version }}
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
diff --git a/cpp/include/cudf/interop.hpp b/cpp/include/cudf/interop.hpp
index f789d950e51..810f0377597 100644
--- a/cpp/include/cudf/interop.hpp
+++ b/cpp/include/cudf/interop.hpp
@@ -57,12 +57,14 @@ namespace CUDF_EXPORT cudf {
 * @throw cudf::logic_error if any of the DLTensor fields are unsupported
*
* @param managed_tensor a 1D or 2D column-major (Fortran order) tensor
+ * @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned table's device memory
*
* @return Table with a copy of the tensor data
*/
std::unique_ptr<table> from_dlpack(
DLManagedTensor const* managed_tensor,
+ rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
/**
@@ -79,12 +81,14 @@ std::unique_ptr<table> from_dlpack(
* or if any of columns have non-zero null count
*
* @param input Table to convert to DLPack
+ * @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned DLPack tensor's device memory
*
* @return 1D or 2D DLPack tensor with a copy of the table data, or nullptr
*/
DLManagedTensor* to_dlpack(
table_view const& input,
+ rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
/** @} */ // end of group
diff --git a/cpp/include/cudf/io/csv.hpp b/cpp/include/cudf/io/csv.hpp
index dae056ef157..9b2de7c72ec 100644
--- a/cpp/include/cudf/io/csv.hpp
+++ b/cpp/include/cudf/io/csv.hpp
@@ -1362,7 +1362,7 @@ table_with_metadata read_csv(
*/
/**
- *@brief Builder to build options for `writer_csv()`.
+ *@brief Builder to build options for `write_csv()`.
*/
class csv_writer_options_builder;
diff --git a/cpp/src/interop/dlpack.cpp b/cpp/src/interop/dlpack.cpp
index 4395b741e53..b5cc4cbba0d 100644
--- a/cpp/src/interop/dlpack.cpp
+++ b/cpp/src/interop/dlpack.cpp
@@ -297,16 +297,19 @@ DLManagedTensor* to_dlpack(table_view const& input,
} // namespace detail
std::unique_ptr<table> from_dlpack(DLManagedTensor const* managed_tensor,
+ rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
- return detail::from_dlpack(managed_tensor, cudf::get_default_stream(), mr);
+ return detail::from_dlpack(managed_tensor, stream, mr);
}
-DLManagedTensor* to_dlpack(table_view const& input, rmm::device_async_resource_ref mr)
+DLManagedTensor* to_dlpack(table_view const& input,
+ rmm::cuda_stream_view stream,
+ rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
- return detail::to_dlpack(input, cudf::get_default_stream(), mr);
+ return detail::to_dlpack(input, stream, mr);
}
} // namespace cudf
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index 666a7d4ba4b..91c00d6af34 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -701,6 +701,7 @@ ConfigureTest(STREAM_DICTIONARY_TEST streams/dictionary_test.cpp STREAM_MODE tes
ConfigureTest(STREAM_FILLING_TEST streams/filling_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_GROUPBY_TEST streams/groupby_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_HASHING_TEST streams/hash_test.cpp STREAM_MODE testing)
+ConfigureTest(STREAM_INTEROP streams/interop_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_JOIN_TEST streams/join_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_JSONIO_TEST streams/io/json_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_LABELING_BINS_TEST streams/labeling_bins_test.cpp STREAM_MODE testing)
diff --git a/cpp/tests/streams/interop_test.cpp b/cpp/tests/streams/interop_test.cpp
new file mode 100644
index 00000000000..7133baf6df1
--- /dev/null
+++ b/cpp/tests/streams/interop_test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cudf_test/base_fixture.hpp>
+#include <cudf_test/column_wrapper.hpp>
+#include <cudf_test/default_stream.hpp>
+
+#include <cudf/interop.hpp>
+#include <cudf/table/table_view.hpp>
+
+#include <dlpack/dlpack.h>
+
+struct dlpack_deleter {
+ void operator()(DLManagedTensor* tensor) { tensor->deleter(tensor); }
+};
+
+struct DLPackTest : public cudf::test::BaseFixture {};
+
+TEST_F(DLPackTest, ToDLPack)
+{
+  cudf::table_view empty(std::vector<cudf::column_view>{});
+ cudf::to_dlpack(empty, cudf::test::get_default_stream());
+}
+
+TEST_F(DLPackTest, FromDLPack)
+{
+  using unique_managed_tensor = std::unique_ptr<DLManagedTensor, dlpack_deleter>;
+  cudf::test::fixed_width_column_wrapper<int32_t> col1({});
+  cudf::test::fixed_width_column_wrapper<int32_t> col2({});
+ cudf::table_view input({col1, col2});
+ unique_managed_tensor tensor(cudf::to_dlpack(input, cudf::test::get_default_stream()));
+ auto result = cudf::from_dlpack(tensor.get(), cudf::test::get_default_stream());
+}
diff --git a/dependencies.yaml b/dependencies.yaml
index 631ce12f0b0..3976696a41c 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -758,6 +758,7 @@ dependencies:
common:
- output_types: [conda, requirements, pyproject]
packages:
+ - pynvml>=11.4.1,<12.0.0a0
- rapids-dask-dependency==25.2.*,>=0.0.0a0
run_custreamz:
common:
diff --git a/python/cudf/cudf/_lib/csv.pyx b/python/cudf/cudf/_lib/csv.pyx
index c09e06bfc59..59a970263e0 100644
--- a/python/cudf/cudf/_lib/csv.pyx
+++ b/python/cudf/cudf/_lib/csv.pyx
@@ -1,10 +1,6 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from libcpp cimport bool
-from libcpp.memory cimport unique_ptr
-from libcpp.string cimport string
-from libcpp.utility cimport move
-from libcpp.vector cimport vector
cimport pylibcudf.libcudf.types as libcudf_types
@@ -23,16 +19,7 @@ from cudf.core.buffer import acquire_spill_lock
from libcpp cimport bool
-from pylibcudf.libcudf.io.csv cimport (
- csv_writer_options,
- write_csv as cpp_write_csv,
-)
-from pylibcudf.libcudf.io.data_sink cimport data_sink
-from pylibcudf.libcudf.io.types cimport sink_info
-from pylibcudf.libcudf.table.table_view cimport table_view
-
-from cudf._lib.io.utils cimport make_sink_info
-from cudf._lib.utils cimport data_from_pylibcudf_io, table_view_from_table
+from cudf._lib.utils cimport data_from_pylibcudf_io
import pylibcudf as plc
@@ -318,59 +305,40 @@ def write_csv(
--------
cudf.to_csv
"""
- cdef table_view input_table_view = table_view_from_table(
- table, not index
- )
- cdef bool include_header_c = header
- cdef char delim_c = ord(sep)
- cdef string line_term_c = lineterminator.encode()
- cdef string na_c = na_rep.encode()
- cdef int rows_per_chunk_c = rows_per_chunk
- cdef vector[string] col_names
- cdef string true_value_c = 'True'.encode()
- cdef string false_value_c = 'False'.encode()
- cdef unique_ptr[data_sink] data_sink_c
- cdef sink_info sink_info_c = make_sink_info(path_or_buf, data_sink_c)
-
- if header is True:
- all_names = columns_apply_na_rep(table._column_names, na_rep)
- if index is True:
- all_names = table._index.names + all_names
-
- if len(all_names) > 0:
- col_names.reserve(len(all_names))
- if len(all_names) == 1:
- if all_names[0] in (None, ''):
- col_names.push_back('""'.encode())
- else:
- col_names.push_back(
- str(all_names[0]).encode()
- )
- else:
- for idx, col_name in enumerate(all_names):
- if col_name is None:
- col_names.push_back(''.encode())
- else:
- col_names.push_back(
- str(col_name).encode()
- )
-
- cdef csv_writer_options options = move(
- csv_writer_options.builder(sink_info_c, input_table_view)
- .names(col_names)
- .na_rep(na_c)
- .include_header(include_header_c)
- .rows_per_chunk(rows_per_chunk_c)
- .line_terminator(line_term_c)
- .inter_column_delimiter(delim_c)
- .true_value(true_value_c)
- .false_value(false_value_c)
- .build()
- )
-
+ index_and_not_empty = index is True and table.index is not None
+ columns = [
+ col.to_pylibcudf(mode="read") for col in table.index._columns
+ ] if index_and_not_empty else []
+ columns.extend(col.to_pylibcudf(mode="read") for col in table._columns)
+ col_names = []
+ if header:
+ all_names = list(table.index.names) if index_and_not_empty else []
+ all_names.extend(
+ na_rep if name is None or pd.isnull(name)
+ else name for name in table._column_names
+ )
+ col_names = [
+ '""' if (name in (None, '') and len(all_names) == 1)
+ else (str(name) if name not in (None, '') else '')
+ for name in all_names
+ ]
try:
- with nogil:
- cpp_write_csv(options)
+ plc.io.csv.write_csv(
+ (
+ plc.io.csv.CsvWriterOptions.builder(
+ plc.io.SinkInfo([path_or_buf]), plc.Table(columns)
+ )
+ .names(col_names)
+ .na_rep(na_rep)
+ .include_header(header)
+ .rows_per_chunk(rows_per_chunk)
+ .line_terminator(str(lineterminator))
+ .inter_column_delimiter(str(sep))
+ .true_value("True")
+ .false_value("False")
+ .build()
+ )
+ )
except OverflowError:
raise OverflowError(
f"Writing CSV file with chunksize={rows_per_chunk} failed. "
@@ -419,11 +387,3 @@ cdef DataType _get_plc_data_type_from_dtype(object dtype) except *:
dtype = cudf.dtype(dtype)
return dtype_to_pylibcudf_type(dtype)
-
-
-def columns_apply_na_rep(column_names, na_rep):
- return tuple(
- na_rep if pd.isnull(col_name)
- else col_name
- for col_name in column_names
- )
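
The header-name handling above is now done inline with list comprehensions. As a reading aid, here is a minimal standalone sketch of that logic (the helper name `_normalize_csv_header_names` is hypothetical, not part of the change): null table column names become `na_rep`, a lone unnamed column is written as `""`, and unnamed columns in a multi-column header become empty strings.

```python
import pandas as pd

def _normalize_csv_header_names(column_names, na_rep):
    # Hypothetical helper mirroring the inline logic in write_csv above.
    # Null names are first replaced with na_rep.
    all_names = [
        na_rep if name is None or pd.isnull(name) else name
        for name in column_names
    ]
    # A lone unnamed column is written as '""'; unnamed columns in a
    # multi-column header become empty strings.
    return [
        '""' if (name in (None, "") and len(all_names) == 1)
        else (str(name) if name not in (None, "") else "")
        for name in all_names
    ]

assert _normalize_csv_header_names([None], "") == ['""']
assert _normalize_csv_header_names([None, "b"], "") == ["", "b"]
assert _normalize_csv_header_names([None, "b"], "NA") == ["NA", "b"]
```
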
diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
index f8079234df4..99512e2ef52 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
@@ -69,7 +69,7 @@ def __init__(
*by: Expr,
) -> None:
self.dtype = dtype
- self.options = (options[0], tuple(options[1]), tuple(options[2]))
+ self.options = options
self.children = (column, *by)
def do_evaluate(
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index a83130666b6..6899747f439 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -1599,13 +1599,15 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR):
# polars requires that all to-explode columns have the
# same sub-shapes
raise NotImplementedError("Explode with more than one column")
+ self.options = (tuple(to_explode),)
elif self.name == "rename":
- old, new, _ = self.options
+ old, new, strict = self.options
# TODO: perhaps polars should validate renaming in the IR?
if len(new) != len(set(new)) or (
set(new) & (set(df.schema.keys()) - set(old))
):
raise NotImplementedError("Duplicate new names in rename.")
+ self.options = (tuple(old), tuple(new), strict)
elif self.name == "unpivot":
indices, pivotees, variable_name, value_name = self.options
value_name = "value" if value_name is None else value_name
@@ -1631,7 +1633,7 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR):
def get_hashable(self) -> Hashable: # pragma: no cover; Needed by experimental
"""Hashable representation of the node."""
schema_hash = tuple(self.schema.items())
- return (type(self), schema_hash, self.name, str(self.options), *self.children)
+ return (type(self), schema_hash, self.name, self.options, *self.children)
@classmethod
def do_evaluate(
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 12fc2a196cd..9480ce6e535 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -633,9 +633,10 @@ def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.E
@_translate_expr.register
def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+ options = node.sort_options
return expr.SortBy(
dtype,
- node.sort_options,
+ (options[0], tuple(options[1]), tuple(options[2])),
translator.translate_expr(n=node.expr),
*(translator.translate_expr(n=n) for n in node.by),
)
diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py
index 6574021e592..35829420b95 100644
--- a/python/cudf_polars/cudf_polars/experimental/groupby.py
+++ b/python/cudf_polars/cudf_polars/experimental/groupby.py
@@ -13,18 +13,18 @@
from cudf_polars.experimental.parallel import (
PartitionInfo,
_concat,
- _ir_parts_info,
- _partitionwise_ir_parts_info,
+ _default_lower_ir_node,
+ _lower_children,
_partitionwise_ir_tasks,
generate_ir_tasks,
get_key_name,
- ir_parts_info,
)
if TYPE_CHECKING:
from collections.abc import MutableMapping
from cudf_polars.dsl.ir import IR
+ from cudf_polars.experimental.parallel import LowerIRTransformer
class GroupByPart(GroupBy):
@@ -42,19 +42,22 @@ class GroupByFinalize(Select):
_GB_AGG_SUPPORTED = ("sum", "count", "mean")
-def lower_groupby_node(ir: GroupBy, rec) -> IR:
+def lower_groupby_node(
+ ir: GroupBy, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""Rewrite a GroupBy node with proper partitioning."""
# Lower children first
- children = [rec(child) for child in ir.children]
- if ir_parts_info(children[0]).count == 1:
+ children, partition_info = _lower_children(ir, rec)
+
+ if partition_info[children[0]].count == 1:
# Single partition
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
    # Check that we are grouping on element-wise
# keys (is this already guaranteed?)
for ne in ir.keys:
if not isinstance(ne.value, Col):
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
name_map: MutableMapping[str, Any] = {}
agg_tree: Cast | Agg | None = None
@@ -73,10 +76,10 @@ def lower_groupby_node(ir: GroupBy, rec) -> IR:
elif isinstance(agg, Agg):
# Agg
if agg.name not in _GB_AGG_SUPPORTED:
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
if len(agg.children) > 1:
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
if agg.name == "sum":
# Partwise
@@ -106,7 +109,7 @@ def lower_groupby_node(ir: GroupBy, rec) -> IR:
agg_requests_tree.append(NamedExpr(tmp_name, agg_tree))
else:
# Unsupported
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
gb_pwise = GroupByPart(
ir.schema,
@@ -146,27 +149,21 @@ def lower_groupby_node(ir: GroupBy, rec) -> IR:
)
)
should_broadcast: bool = False
- return GroupByFinalize(
+ new_node = GroupByFinalize(
schema,
output_exprs,
should_broadcast,
gb_tree,
)
-
-
-@_ir_parts_info.register(GroupByPart)
-def _(ir: GroupByPart) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
+ partition_info[new_node] = PartitionInfo(count=1)
+ return new_node, partition_info
@generate_ir_tasks.register(GroupByPart)
-def _(ir: GroupByPart) -> MutableMapping[Any, Any]:
- return _partitionwise_ir_tasks(ir)
-
-
-@_ir_parts_info.register(GroupByTree)
-def _(ir: GroupByTree) -> PartitionInfo:
- return PartitionInfo(count=1)
+def _(
+ ir: GroupByPart, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ return _partitionwise_ir_tasks(ir, partition_info)
def _tree_node(do_evaluate, batch, *args):
@@ -174,9 +171,11 @@ def _tree_node(do_evaluate, batch, *args):
@generate_ir_tasks.register(GroupByTree)
-def _(ir: GroupByTree) -> MutableMapping[Any, Any]:
+def _(
+ ir: GroupByTree, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
child = ir.children[0]
- child_count = ir_parts_info(child).count
+ child_count = partition_info[child].count
child_name = get_key_name(child)
name = get_key_name(ir)
@@ -207,12 +206,9 @@ def _(ir: GroupByTree) -> MutableMapping[Any, Any]:
return graph
-@_ir_parts_info.register(GroupByFinalize)
-def _(ir: GroupByFinalize) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
-
-
@generate_ir_tasks.register(GroupByFinalize)
-def _(ir: GroupByFinalize) -> MutableMapping[Any, Any]:
+def _(
+ ir: GroupByFinalize, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
# TODO: Fuse with GroupByTree child task?
- return _partitionwise_ir_tasks(ir)
+ return _partitionwise_ir_tasks(ir, partition_info)
diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py
index 0f9f3919e17..c6707af9658 100644
--- a/python/cudf_polars/cudf_polars/experimental/io.py
+++ b/python/cudf_polars/cudf_polars/experimental/io.py
@@ -13,7 +13,7 @@
from cudf_polars.dsl.ir import Scan
from cudf_polars.experimental.parallel import (
PartitionInfo,
- _ir_parts_info,
+ _default_lower_ir_node,
generate_ir_tasks,
get_key_name,
)
@@ -22,6 +22,7 @@
from collections.abc import MutableMapping
from cudf_polars.dsl.ir import IR
+ from cudf_polars.experimental.parallel import LowerIRTransformer
class ParFileScan(Scan):
@@ -83,14 +84,13 @@ def _plan(self) -> tuple[int, int]:
return (split, stride)
-def lower_scan_node(ir: Scan, rec) -> IR:
+def lower_scan_node(
+ ir: Scan, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""Rewrite a Scan node with proper partitioning."""
- if (
- # len(ir.paths) > 1 and
- ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0
- ):
+ if ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0:
# TODO: mypy complains: ParFileScan(*ir._ctor_arguments([]))
- return ParFileScan(
+ new_node = ParFileScan(
ir.schema,
ir.typ,
ir.reader_options,
@@ -103,17 +103,14 @@ def lower_scan_node(ir: Scan, rec) -> IR:
ir.row_index,
ir.predicate,
)
- return ir
+ split, stride = new_node._plan
+ if split > 1:
+ count = len(new_node.paths) * split
+ else:
+ count = math.ceil(len(new_node.paths) / stride)
+ return new_node, {new_node: PartitionInfo(count=count)}
-
-@_ir_parts_info.register(ParFileScan)
-def _(ir: ParFileScan) -> PartitionInfo:
- split, stride = ir._plan
- if split > 1:
- count = len(ir.paths) * split
- else:
- count = math.ceil(len(ir.paths) / stride)
- return PartitionInfo(count=count)
+ return _default_lower_ir_node(ir, rec)
def _split_read(
@@ -173,7 +170,9 @@ def _split_read(
@generate_ir_tasks.register(ParFileScan)
-def _(ir: ParFileScan) -> MutableMapping[Any, Any]:
+def _(
+ ir: ParFileScan, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
key_name = get_key_name(ir)
split, stride = ir._plan
paths = list(ir.paths)
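
For context, the partition count of a lowered scan follows directly from the `(split, stride)` plan computed by `_plan`. A tiny sketch with assumed values (not taken from the change itself):

```python
import math

def scan_partition_count(n_paths: int, split: int, stride: int) -> int:
    # Mirrors the count computation in lower_scan_node above:
    # split > 1 means each file becomes `split` partitions; otherwise
    # files are grouped `stride` at a time into one partition each.
    if split > 1:
        return n_paths * split
    return math.ceil(n_paths / stride)

assert scan_partition_count(10, split=4, stride=1) == 40
assert scan_partition_count(10, split=1, stride=3) == 4
```
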
diff --git a/python/cudf_polars/cudf_polars/experimental/join.py b/python/cudf_polars/cudf_polars/experimental/join.py
index 10e876093cb..63d87effd8c 100644
--- a/python/cudf_polars/cudf_polars/experimental/join.py
+++ b/python/cudf_polars/cudf_polars/experimental/join.py
@@ -9,17 +9,17 @@
from cudf_polars.dsl.ir import Join
from cudf_polars.experimental.parallel import (
_concat,
- _ir_parts_info,
+ _default_lower_ir_node,
+ _lower_children,
generate_ir_tasks,
get_key_name,
- ir_parts_info,
)
if TYPE_CHECKING:
from collections.abc import MutableMapping
from cudf_polars.dsl.ir import IR
- from cudf_polars.experimental.parallel import PartitionInfo
+ from cudf_polars.experimental.parallel import LowerIRTransformer, PartitionInfo
class BroadcastJoin(Join):
@@ -34,61 +34,64 @@ class RightBroadcastJoin(BroadcastJoin):
"""Right Broadcast Join operation."""
-def lower_join_node(ir: Join, rec) -> IR:
+def lower_join_node(
+ ir: Join, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""Rewrite a Join node with proper partitioning."""
# TODO: Add shuffle-based join.
# (Currently using broadcast join in all cases)
+ # Lower children first
+ children, partition_info = _lower_children(ir, rec)
+
how = ir.options[0]
if how not in ("inner", "left", "right"):
# Not supported (yet)
- return ir
- children = [rec(child) for child in ir.children]
+ return _default_lower_ir_node(ir, rec)
+
+ assert len(children) == 2
left, right = children
- left_parts = ir_parts_info(left)
- right_parts = ir_parts_info(right)
+ left_parts = partition_info[left]
+ right_parts = partition_info[right]
if left_parts.count == right_parts.count == 1:
# Single-partition case
- return ir
+ return _default_lower_ir_node(ir, rec)
elif left_parts.count >= right_parts.count and how in ("inner", "left"):
# Broadcast right to every partition of left
- return RightBroadcastJoin(
+ new_node = RightBroadcastJoin(
ir.schema,
ir.left_on,
ir.right_on,
ir.options,
*children,
)
+ partition_info[new_node] = partition_info[left]
else:
# Broadcast left to every partition of right
- return LeftBroadcastJoin(
+ new_node = LeftBroadcastJoin(
ir.schema,
ir.left_on,
ir.right_on,
ir.options,
*children,
)
-
-
-@_ir_parts_info.register(LeftBroadcastJoin)
-def _(ir: LeftBroadcastJoin) -> PartitionInfo:
- return ir_parts_info(ir.children[1])
-
-
-@_ir_parts_info.register(RightBroadcastJoin)
-def _(ir: RightBroadcastJoin) -> PartitionInfo:
- return ir_parts_info(ir.children[0])
+ partition_info[new_node] = partition_info[right]
+ return new_node, partition_info
@generate_ir_tasks.register(BroadcastJoin)
-def _(ir: BroadcastJoin) -> MutableMapping[Any, Any]:
+def _(
+ ir: BroadcastJoin, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
left, right = ir.children
bcast_side = "right" if isinstance(ir, RightBroadcastJoin) else "left"
left_name = get_key_name(left)
right_name = get_key_name(right)
key_name = get_key_name(ir)
- parts = ir_parts_info(ir)
- bcast_parts = ir_parts_info(right) if bcast_side == "right" else ir_parts_info(left)
+ parts = partition_info[ir]
+ bcast_parts = (
+ partition_info[right] if bcast_side == "right" else partition_info[left]
+ )
graph: MutableMapping[Any, Any] = {}
for i in range(parts.count):
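
The broadcast-side choice in `lower_join_node` can be summarized as: broadcast the side with fewer partitions and inherit the partition count of the other side. A small sketch with toy counts (hypothetical helper, for illustration only):

```python
def choose_broadcast(left_count: int, right_count: int, how: str) -> tuple[str, int]:
    # Mirrors the branching in lower_join_node above.
    if left_count == right_count == 1:
        return ("none", 1)              # single-partition fallback
    if left_count >= right_count and how in ("inner", "left"):
        return ("right", left_count)    # RightBroadcastJoin, count from left
    return ("left", right_count)        # LeftBroadcastJoin, count from right

assert choose_broadcast(8, 1, "inner") == ("right", 8)
assert choose_broadcast(2, 8, "inner") == ("left", 8)
```
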
diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py
index bd8da2a3b58..74309aba251 100644
--- a/python/cudf_polars/cudf_polars/experimental/parallel.py
+++ b/python/cudf_polars/cudf_polars/experimental/parallel.py
@@ -4,11 +4,12 @@
from __future__ import annotations
-from functools import singledispatch
+import operator
+from functools import reduce, singledispatch
from typing import TYPE_CHECKING, Any
-from cudf_polars.dsl.expr import NamedExpr
from cudf_polars.dsl.ir import (
+ IR,
Filter,
GroupBy,
HStack,
@@ -18,14 +19,15 @@
Select,
Union,
)
-from cudf_polars.dsl.traversal import reuse_if_unchanged, traversal
+from cudf_polars.dsl.traversal import traversal
if TYPE_CHECKING:
from collections.abc import MutableMapping, Sequence
+ from typing import TypeAlias
from cudf_polars.containers import DataFrame
- from cudf_polars.dsl.ir import IR
from cudf_polars.dsl.nodebase import Node
+ from cudf_polars.typing import GenericTransformer
class PartitionInfo:
@@ -41,76 +43,99 @@ def __init__(self, count: int):
self.count = count
-# The hash of an IR object must always map to a
-# unique PartitionInfo object, and we can cache
-# this mapping until evaluation is complete.
-_IR_PARTS_CACHE: MutableMapping[int, PartitionInfo] = {}
-
-
-def _clear_parts_info_cache() -> None:
- """Clear cached partitioning information."""
- _IR_PARTS_CACHE.clear()
+LowerIRTransformer: TypeAlias = (
+ "GenericTransformer[IR, MutableMapping[IR, PartitionInfo]]"
+)
+"""Protocol for Lowering IR nodes."""
-def get_key_name(node: Node | NamedExpr) -> str:
+def get_key_name(node: Node) -> str:
"""Generate the key name for a Node."""
- if isinstance(node, NamedExpr):
- return f"named-{get_key_name(node.value)}" # pragma: no cover
return f"{type(node).__name__.lower()}-{hash(node)}"
@singledispatch
-def lower_ir_node(ir: IR, rec) -> IR:
+def lower_ir_node(
+ ir: IR, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""Rewrite an IR node with proper partitioning."""
- # Return same node by default
- return reuse_if_unchanged(ir, rec)
+ raise AssertionError(f"Unhandled type {type(ir)}")
-def lower_ir_graph(ir: IR) -> IR:
- """Rewrite an IR graph with proper partitioning."""
- from cudf_polars.dsl.traversal import CachingVisitor
+def _lower_children(
+ ir: IR, rec: LowerIRTransformer
+) -> tuple[tuple[IR], MutableMapping[IR, PartitionInfo]]:
+ children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False)
+ partition_info: MutableMapping[IR, PartitionInfo] = reduce(
+ operator.or_, _partition_info
+ )
+ return children, partition_info
- mapper = CachingVisitor(lower_ir_node)
- return mapper(ir)
+@lower_ir_node.register(IR)
+def _default_lower_ir_node(
+ ir: IR, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ if len(ir.children) == 0:
+ # Default leaf node has single partition
+ return ir, {ir: PartitionInfo(count=1)}
-def _default_ir_parts_info(ir: IR) -> PartitionInfo:
- # Single-partition default behavior.
- # This is used by `_ir_parts_info` for all unregistered IR sub-types.
- count = max((ir_parts_info(child).count for child in ir.children), default=1)
+ # Lower children
+ children, partition_info = _lower_children(ir, rec)
+
+ # Check that child partitioning is supported
+ count = max(partition_info[c].count for c in children)
if count > 1:
raise NotImplementedError(
f"Class {type(ir)} does not support multiple partitions."
) # pragma: no cover
- return PartitionInfo(count=count)
+    # Return reconstructed node and partition info
+ partition = PartitionInfo(count=1)
+ new_node = ir.reconstruct(children)
+ partition_info[new_node] = partition
+ return new_node, partition_info
-def _partitionwise_ir_parts_info(ir: IR) -> PartitionInfo:
- # Simple partitionwise behavior.
- count = max((ir_parts_info(child).count for child in ir.children), default=1)
- return PartitionInfo(count=count)
+def _lower_ir_node_partitionwise(
+ ir: IR, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ # Simple partitionwise behavior
+ children, partition_info = _lower_children(ir, rec)
+ partition = PartitionInfo(count=max(partition_info[c].count for c in children))
+ new_node = ir.reconstruct(children)
+ partition_info[new_node] = partition
+ return new_node, partition_info
-@singledispatch
-def _ir_parts_info(ir: IR) -> PartitionInfo:
- """IR partitioning-info dispatch."""
- return _default_ir_parts_info(ir)
+def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ """Rewrite an IR graph with proper partitioning."""
+ from cudf_polars.dsl.traversal import CachingVisitor
-def ir_parts_info(ir: IR) -> PartitionInfo:
- """Return the partitioning info for an IR node."""
- key = hash(ir)
- try:
- return _IR_PARTS_CACHE[key]
- except KeyError:
- _IR_PARTS_CACHE[key] = _ir_parts_info(ir)
- return _IR_PARTS_CACHE[key]
+ mapper = CachingVisitor(lower_ir_node)
+ return mapper(ir)
-def _default_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
+@singledispatch
+def generate_ir_tasks(
+ ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ """
+ Generate tasks for an IR node.
+
+ An IR node only needs to generate the graph for
+ the current IR logic (not including child IRs).
+ """
+ raise AssertionError(f"Unhandled type {type(ir)}")
+
+
+@generate_ir_tasks.register(IR)
+def _default_ir_tasks(
+ ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
# Single-partition default behavior.
# This is used by `generate_ir_tasks` for all unregistered IR sub-types.
- if ir_parts_info(ir).count > 1:
+ if partition_info[ir].count > 1:
raise NotImplementedError(
f"Failed to generate multiple output tasks for {ir}."
) # pragma: no cover
@@ -118,7 +143,7 @@ def _default_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
child_names = []
for child in ir.children:
child_names.append(get_key_name(child))
- if ir_parts_info(child).count > 1:
+ if partition_info[child].count > 1:
raise NotImplementedError(
f"Failed to generate tasks for {ir} with child {child}."
) # pragma: no cover
@@ -133,13 +158,16 @@ def _default_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
}
-def _partitionwise_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
+def _partitionwise_ir_tasks(
+ ir: IR,
+ partition_info: MutableMapping[IR, PartitionInfo],
+) -> MutableMapping[Any, Any]:
# Simple partitionwise behavior.
child_names = []
counts = []
for child in ir.children:
child_names.append(get_key_name(child))
- counts.append(ir_parts_info(child).count)
+ counts.append(partition_info[child].count)
counts = counts or [1]
if len(set(counts)) > 1:
raise NotImplementedError(
@@ -157,34 +185,22 @@ def _partitionwise_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
}
-@singledispatch
-def generate_ir_tasks(ir: IR) -> MutableMapping[Any, Any]:
- """
- Generate tasks for an IR node.
-
- An IR node only needs to generate the graph for
- the current IR logic (not including child IRs).
- """
- return _default_ir_tasks(ir)
-
-
-def task_graph(_ir: IR) -> tuple[MutableMapping[str, Any], str]:
+def task_graph(
+ ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+) -> tuple[MutableMapping[str, Any], str]:
"""Construct a Dask-compatible task graph."""
- ir: IR = lower_ir_graph(_ir)
+ graph = reduce(
+ operator.or_,
+ [generate_ir_tasks(node, partition_info) for node in traversal(ir)],
+ )
- graph = {
- k: v
- for layer in [generate_ir_tasks(n) for n in traversal(ir)]
- for k, v in layer.items()
- }
key_name = get_key_name(ir)
- partition_count = ir_parts_info(ir).count
+ partition_count = partition_info[ir].count
if partition_count:
graph[key_name] = (_concat, [(key_name, i) for i in range(partition_count)])
else:
graph[key_name] = (key_name, 0)
- _clear_parts_info_cache()
return graph, key_name
@@ -192,7 +208,9 @@ def evaluate_dask(ir: IR) -> DataFrame:
"""Evaluate an IR graph with Dask."""
from dask import get
- graph, key = task_graph(ir)
+ ir, partition_info = lower_ir_graph(ir)
+
+ graph, key = task_graph(ir, partition_info)
return get(graph, key)
@@ -207,7 +225,9 @@ def _concat(dfs: Sequence[DataFrame]) -> DataFrame:
@lower_ir_node.register(Scan)
-def _(ir: Scan, rec) -> IR:
+def _(
+ ir: Scan, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
import cudf_polars.experimental.io as _io
return _io.lower_scan_node(ir, rec)
@@ -219,7 +239,9 @@ def _(ir: Scan, rec) -> IR:
@lower_ir_node.register(Select)
-def _(ir: Select, rec) -> IR:
+def _(
+ ir: Select, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
import cudf_polars.experimental.select as _select
return _select.lower_select_node(ir, rec)
@@ -230,14 +252,18 @@ def _(ir: Select, rec) -> IR:
##
-@_ir_parts_info.register(HStack)
-def _(ir: HStack) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
+@lower_ir_node.register(HStack)
+def _(
+ ir: HStack, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ return _lower_ir_node_partitionwise(ir, rec)
@generate_ir_tasks.register(HStack)
-def _(ir: HStack) -> MutableMapping[Any, Any]:
- return _partitionwise_ir_tasks(ir)
+def _(
+ ir: HStack, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ return _partitionwise_ir_tasks(ir, partition_info)
##
@@ -248,14 +274,18 @@ def _(ir: HStack) -> MutableMapping[Any, Any]:
## TODO: Can filter expressions include aggregations?
-@_ir_parts_info.register(Filter)
-def _(ir: Filter) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
+@lower_ir_node.register(Filter)
+def _(
+ ir: Filter, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ return _lower_ir_node_partitionwise(ir, rec)
@generate_ir_tasks.register(Filter)
-def _(ir: Filter) -> MutableMapping[Any, Any]:
- return _partitionwise_ir_tasks(ir)
+def _(
+ ir: Filter, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ return _partitionwise_ir_tasks(ir, partition_info)
##
@@ -263,14 +293,18 @@ def _(ir: Filter) -> MutableMapping[Any, Any]:
##
-@_ir_parts_info.register(Projection)
-def _(ir: Projection) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
+@lower_ir_node.register(Projection)
+def _(
+ ir: Projection, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+ return _lower_ir_node_partitionwise(ir, rec)
@generate_ir_tasks.register(Projection)
-def _(ir: Projection) -> MutableMapping[Any, Any]:
- return _partitionwise_ir_tasks(ir)
+def _(
+ ir: Projection, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ return _partitionwise_ir_tasks(ir, partition_info)
##
@@ -279,7 +313,9 @@ def _(ir: Projection) -> MutableMapping[Any, Any]:
@lower_ir_node.register(GroupBy)
-def _(ir: GroupBy, rec) -> IR:
+def _(
+ ir: GroupBy, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
import cudf_polars.experimental.groupby as _groupby
return _groupby.lower_groupby_node(ir, rec)
@@ -291,7 +327,9 @@ def _(ir: GroupBy, rec) -> IR:
@lower_ir_node.register(Join)
-def _(ir: Join, rec) -> IR:
+def _(
+ ir: Join, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
import cudf_polars.experimental.join as _join
return _join.lower_join_node(ir, rec)
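
To summarize the new protocol in `parallel.py`: each `lower_ir_node` implementation returns a rewritten node together with a mapping from IR nodes to `PartitionInfo`, child mappings are merged with dict union, and `generate_ir_tasks` later reads counts from that mapping instead of a global cache. A toy sketch of the shape of that contract (stand-in classes, not the real IR types):

```python
import operator
from functools import reduce

class ToyNode:
    """Stand-in for an IR node (hashable by identity, like the sketch needs)."""
    def __init__(self, name, children=()):
        self.name, self.children = name, children

class ToyPartitionInfo:
    def __init__(self, count):
        self.count = count

def toy_lower(node):
    # Lower children first and merge their partition-info mappings with
    # dict union, analogous to _lower_children above.
    lowered = [toy_lower(c) for c in node.children]
    children = tuple(n for n, _ in lowered)
    partition_info = reduce(operator.or_, (pi for _, pi in lowered), {})
    # Default partitionwise behaviour: take the max child count (or 1).
    count = max((partition_info[c].count for c in children), default=1)
    new_node = ToyNode(node.name, children)
    partition_info[new_node] = ToyPartitionInfo(count)
    return new_node, partition_info

root, info = toy_lower(ToyNode("select", (ToyNode("scan"),)))
print(info[root].count)  # 1
```
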
diff --git a/python/cudf_polars/cudf_polars/experimental/select.py b/python/cudf_polars/cudf_polars/experimental/select.py
index 57c0aa3e421..1d5cebc8bed 100644
--- a/python/cudf_polars/cudf_polars/experimental/select.py
+++ b/python/cudf_polars/cudf_polars/experimental/select.py
@@ -8,8 +8,9 @@
from cudf_polars.dsl.ir import Select
from cudf_polars.experimental.parallel import (
- _ir_parts_info,
- _partitionwise_ir_parts_info,
+ PartitionInfo,
+ _default_lower_ir_node,
+ _lower_children,
_partitionwise_ir_tasks,
generate_ir_tasks,
)
@@ -18,7 +19,7 @@
from collections.abc import MutableMapping
from cudf_polars.dsl.ir import IR
- from cudf_polars.experimental.parallel import PartitionInfo
+ from cudf_polars.experimental.parallel import LowerIRTransformer
_PARTWISE = (
@@ -41,33 +42,36 @@ class PartwiseSelect(Select):
"""Partitionwise Select operation."""
-def lower_select_node(ir: Select, rec) -> IR:
+def lower_select_node(
+ ir: Select, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""Rewrite a GroupBy node with proper partitioning."""
from cudf_polars.dsl.traversal import traversal
# Lower children first
- children = [rec(child) for child in ir.children]
+ children, partition_info = _lower_children(ir, rec)
# Search the expressions for "complex" operations
for ne in ir.exprs:
for expr in traversal(ne.value):
if type(expr).__name__ not in _PARTWISE:
- return ir.reconstruct(children)
+ return _default_lower_ir_node(ir, rec)
- # Remailing Select ops are partition-wise
- return PartwiseSelect(
+ # Remaining Select ops are partition-wise
+ new_node = PartwiseSelect(
ir.schema,
ir.exprs,
ir.should_broadcast,
*children,
)
-
-
-@_ir_parts_info.register(PartwiseSelect)
-def _(ir: PartwiseSelect) -> PartitionInfo:
- return _partitionwise_ir_parts_info(ir)
+ partition_info[new_node] = PartitionInfo(
+ count=max(partition_info[c].count for c in children)
+ )
+ return new_node, partition_info
@generate_ir_tasks.register(PartwiseSelect)
-def _(ir: PartwiseSelect) -> MutableMapping[Any, Any]:
- return _partitionwise_ir_tasks(ir)
+def _(
+ ir: PartwiseSelect, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+ return _partitionwise_ir_tasks(ir, partition_info)
diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py
index bf8fae552c2..bbedd046760 100644
--- a/python/dask_cudf/dask_cudf/io/parquet.py
+++ b/python/dask_cudf/dask_cudf/io/parquet.py
@@ -55,7 +55,7 @@ def _get_device_size():
handle = pynvml.nvmlDeviceGetHandleByIndex(int(index))
return pynvml.nvmlDeviceGetMemoryInfo(handle).total
- except (ImportError, ValueError):
+ except ValueError:
# Fall back to a conservative 8GiB default
return 8 * 1024**3
diff --git a/python/dask_cudf/pyproject.toml b/python/dask_cudf/pyproject.toml
index 9364cc7647f..33ba8fe083f 100644
--- a/python/dask_cudf/pyproject.toml
+++ b/python/dask_cudf/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"fsspec>=0.6.0",
"numpy>=1.23,<3.0a0",
"pandas>=2.0,<2.2.4dev0",
+ "pynvml>=11.4.1,<12.0.0a0",
"rapids-dask-dependency==25.2.*,>=0.0.0a0",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
classifiers = [
diff --git a/python/pylibcudf/pylibcudf/io/csv.pxd b/python/pylibcudf/pylibcudf/io/csv.pxd
new file mode 100644
index 00000000000..f04edaa316a
--- /dev/null
+++ b/python/pylibcudf/pylibcudf/io/csv.pxd
@@ -0,0 +1,35 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+from libcpp.vector cimport vector
+from libcpp.string cimport string
+from libcpp cimport bool
+from pylibcudf.libcudf.io.csv cimport (
+ csv_writer_options,
+ csv_writer_options_builder,
+)
+from pylibcudf.libcudf.io.types cimport quote_style
+from pylibcudf.io.types cimport SinkInfo
+from pylibcudf.table cimport Table
+
+cdef class CsvWriterOptions:
+ cdef csv_writer_options c_obj
+ cdef Table table
+ cdef SinkInfo sink
+
+
+cdef class CsvWriterOptionsBuilder:
+ cdef csv_writer_options_builder c_obj
+ cdef Table table
+ cdef SinkInfo sink
+ cpdef CsvWriterOptionsBuilder names(self, list names)
+ cpdef CsvWriterOptionsBuilder na_rep(self, str val)
+ cpdef CsvWriterOptionsBuilder include_header(self, bool val)
+ cpdef CsvWriterOptionsBuilder rows_per_chunk(self, int val)
+ cpdef CsvWriterOptionsBuilder line_terminator(self, str term)
+ cpdef CsvWriterOptionsBuilder inter_column_delimiter(self, str delim)
+ cpdef CsvWriterOptionsBuilder true_value(self, str val)
+ cpdef CsvWriterOptionsBuilder false_value(self, str val)
+ cpdef CsvWriterOptions build(self)
+
+
+cpdef void write_csv(CsvWriterOptions options)
diff --git a/python/pylibcudf/pylibcudf/io/csv.pyi b/python/pylibcudf/pylibcudf/io/csv.pyi
index 356825a927d..583b66bc29c 100644
--- a/python/pylibcudf/pylibcudf/io/csv.pyi
+++ b/python/pylibcudf/pylibcudf/io/csv.pyi
@@ -5,9 +5,11 @@ from collections.abc import Mapping
from pylibcudf.io.types import (
CompressionType,
QuoteStyle,
+ SinkInfo,
SourceInfo,
TableWithMetadata,
)
+from pylibcudf.table import Table
from pylibcudf.types import DataType
def read_csv(
@@ -52,3 +54,23 @@ def read_csv(
# detect_whitespace_around_quotes: bool = False,
# timestamp_type: DataType = DataType(type_id.EMPTY),
) -> TableWithMetadata: ...
+def write_csv(options: CsvWriterOptions) -> None: ...
+
+class CsvWriterOptions:
+ def __init__(self): ...
+ @staticmethod
+ def builder(sink: SinkInfo, table: Table) -> CsvWriterOptionsBuilder: ...
+
+class CsvWriterOptionsBuilder:
+ def __init__(self): ...
+ def names(self, names: list) -> CsvWriterOptionsBuilder: ...
+ def na_rep(self, val: str) -> CsvWriterOptionsBuilder: ...
+ def include_header(self, val: bool) -> CsvWriterOptionsBuilder: ...
+ def rows_per_chunk(self, val: int) -> CsvWriterOptionsBuilder: ...
+ def line_terminator(self, term: str) -> CsvWriterOptionsBuilder: ...
+ def inter_column_delimiter(
+ self, delim: str
+ ) -> CsvWriterOptionsBuilder: ...
+ def true_value(self, val: str) -> CsvWriterOptionsBuilder: ...
+ def false_value(self, val: str) -> CsvWriterOptionsBuilder: ...
+ def build(self) -> CsvWriterOptions: ...
diff --git a/python/pylibcudf/pylibcudf/io/csv.pyx b/python/pylibcudf/pylibcudf/io/csv.pyx
index 858e580ab34..8be391de2c2 100644
--- a/python/pylibcudf/pylibcudf/io/csv.pyx
+++ b/python/pylibcudf/pylibcudf/io/csv.pyx
@@ -2,14 +2,18 @@
from libcpp cimport bool
from libcpp.map cimport map
+
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector
-from pylibcudf.io.types cimport SourceInfo, TableWithMetadata
+from pylibcudf.io.types cimport SourceInfo, SinkInfo, TableWithMetadata
from pylibcudf.libcudf.io.csv cimport (
csv_reader_options,
+ csv_writer_options,
read_csv as cpp_read_csv,
+ write_csv as cpp_write_csv,
)
+
from pylibcudf.libcudf.io.types cimport (
compression_type,
quote_style,
@@ -17,9 +21,14 @@ from pylibcudf.libcudf.io.types cimport (
)
from pylibcudf.libcudf.types cimport data_type, size_type
from pylibcudf.types cimport DataType
+from pylibcudf.table cimport Table
-
-__all__ = ["read_csv"]
+__all__ = [
+ "read_csv",
+ "write_csv",
+ "CsvWriterOptions",
+ "CsvWriterOptionsBuilder",
+]
cdef tuple _process_parse_dates_hex(list cols):
cdef vector[string] str_cols
@@ -82,6 +91,8 @@ def read_csv(
):
"""Reads a CSV file into a :py:class:`~.types.TableWithMetadata`.
+ For details, see :cpp:func:`read_csv`.
+
Parameters
----------
source_info : SourceInfo
@@ -263,3 +274,202 @@ def read_csv(
c_result = move(cpp_read_csv(options))
return TableWithMetadata.from_libcudf(c_result)
+
+
+# TODO: Implement the remaining methods
+cdef class CsvWriterOptions:
+ """The settings to use for ``write_csv``
+
+ For details, see :cpp:class:`cudf::io::csv_writer_options`
+ """
+ @staticmethod
+ def builder(SinkInfo sink, Table table):
+ """Create a CsvWriterOptionsBuilder object
+
+ For details, see :cpp:func:`cudf::io::csv_writer_options::builder`
+
+ Parameters
+ ----------
+ sink : SinkInfo
+ The sink used for writer output
+ table : Table
+ Table to be written to output
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ cdef CsvWriterOptionsBuilder csv_builder = CsvWriterOptionsBuilder.__new__(
+ CsvWriterOptionsBuilder
+ )
+ csv_builder.c_obj = csv_writer_options.builder(sink.c_obj, table.view())
+ csv_builder.table = table
+ csv_builder.sink = sink
+ return csv_builder
+
+
+# TODO: Implement the remaining methods
+cdef class CsvWriterOptionsBuilder:
+ """Builder to build options for ``write_csv``
+
+ For details, see :cpp:class:`cudf::io::csv_writer_options_builder`
+ """
+ cpdef CsvWriterOptionsBuilder names(self, list names):
+ """Sets optional column names.
+
+ Parameters
+ ----------
+ names : list[str]
+ Column names
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.names([name.encode() for name in names])
+ return self
+
+ cpdef CsvWriterOptionsBuilder na_rep(self, str val):
+ """Sets string to used for null entries.
+
+ Parameters
+ ----------
+ val : str
+ String to represent null value
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.na_rep(val.encode())
+ return self
+
+ cpdef CsvWriterOptionsBuilder include_header(self, bool val):
+ """Enables/Disables headers being written to csv.
+
+ Parameters
+ ----------
+ val : bool
+ Boolean value to enable/disable
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.include_header(val)
+ return self
+
+ cpdef CsvWriterOptionsBuilder rows_per_chunk(self, int val):
+ """Sets maximum number of rows to process for each file write.
+
+ Parameters
+ ----------
+ val : int
+ Number of rows per chunk
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.rows_per_chunk(val)
+ return self
+
+ cpdef CsvWriterOptionsBuilder line_terminator(self, str term):
+ """Sets character used for separating lines.
+
+ Parameters
+ ----------
+ term : str
+ Character to represent line termination
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.line_terminator(term.encode())
+ return self
+
+ cpdef CsvWriterOptionsBuilder inter_column_delimiter(self, str delim):
+ """Sets character used for separating column values.
+
+ Parameters
+ ----------
+ delim : str
+ Character to delimit column values
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.inter_column_delimiter(ord(delim))
+ return self
+
+ cpdef CsvWriterOptionsBuilder true_value(self, str val):
+ """Sets string used for values != 0
+
+ Parameters
+ ----------
+ val : str
+ String to represent values != 0
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.true_value(val.encode())
+ return self
+
+ cpdef CsvWriterOptionsBuilder false_value(self, str val):
+ """Sets string used for values == 0
+
+ Parameters
+ ----------
+ val : str
+ String to represent values == 0
+
+ Returns
+ -------
+ CsvWriterOptionsBuilder
+ Builder to build CsvWriterOptions
+ """
+ self.c_obj.false_value(val.encode())
+ return self
+
+ cpdef CsvWriterOptions build(self):
+ """Create a CsvWriterOptions object"""
+ cdef CsvWriterOptions csv_options = CsvWriterOptions.__new__(
+ CsvWriterOptions
+ )
+ csv_options.c_obj = move(self.c_obj.build())
+ csv_options.table = self.table
+ csv_options.sink = self.sink
+ return csv_options
+
+
+cpdef void write_csv(
+ CsvWriterOptions options
+):
+ """
+ Write to CSV format.
+
+ The table to write, output paths, and options are encapsulated
+ by the `options` object.
+
+ For details, see :cpp:func:`write_csv`.
+
+ Parameters
+ ----------
+ options: CsvWriterOptions
+ Settings for controlling writing behavior
+ """
+
+ with nogil:
+ cpp_write_csv(move(options.c_obj))
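
A short usage sketch of the new writer API (assuming an in-memory `io.StringIO` sink and a table built via `plc.interop.from_arrow`; the same builder chain is exercised by the tests below):

```python
import io
import pyarrow as pa
import pylibcudf as plc

# Build a small pylibcudf Table from pyarrow.
table = plc.interop.from_arrow(
    pa.table({"a": [1, 2, 3], "b": ["x", "y", None]})
)
sink = io.StringIO()

options = (
    plc.io.csv.CsvWriterOptions.builder(plc.io.SinkInfo([sink]), table)
    .names(["a", "b"])
    .na_rep("NA")
    .include_header(True)
    .rows_per_chunk(8)
    .line_terminator("\n")
    .inter_column_delimiter(",")
    .true_value("True")
    .false_value("False")
    .build()
)
plc.io.csv.write_csv(options)
print(sink.getvalue())
```
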
diff --git a/python/pylibcudf/pylibcudf/io/types.pyx b/python/pylibcudf/pylibcudf/io/types.pyx
index 7a3f16c4c50..51d5bda75c7 100644
--- a/python/pylibcudf/pylibcudf/io/types.pyx
+++ b/python/pylibcudf/pylibcudf/io/types.pyx
@@ -261,18 +261,24 @@ cdef cppclass iobase_data_sink(data_sink):
cdef class SinkInfo:
- """A class containing details on a source to read from.
+ """
+ A class containing details about destinations (sinks) to write data to.
- For details, see :cpp:class:`cudf::io::sink_info`.
+ For more details, see :cpp:class:`cudf::io::sink_info`.
Parameters
----------
- sinks : list of str, PathLike, BytesIO, StringIO
+ sinks : list of str, PathLike, or io.IOBase instances
+ A list of sinks to write data to. Each sink can be:
- A homogeneous list of sinks (this can be a string filename,
- bytes, or one of the Python I/O classes) to read from.
+ - A string representing a filename.
+ - A PathLike object.
+ - An instance of a Python I/O class that is a subclass of io.IOBase
+          (e.g., io.BytesIO, io.StringIO).
- Mixing different types of sinks will raise a `ValueError`.
+ The list must be homogeneous in type unless all sinks are instances
+ of subclasses of io.IOBase. Mixing different types of sinks
+ (that are not all io.IOBase instances) will raise a ValueError.
"""
def __init__(self, list sinks):
@@ -280,32 +286,42 @@ cdef class SinkInfo:
cdef vector[string] paths
if not sinks:
- raise ValueError("Need to pass at least one sink")
+ raise ValueError("At least one sink must be provided.")
if isinstance(sinks[0], os.PathLike):
sinks = [os.path.expanduser(s) for s in sinks]
cdef object initial_sink_cls = type(sinks[0])
- if not all(isinstance(s, initial_sink_cls) for s in sinks):
- raise ValueError("All sinks must be of the same type!")
+ if not all(
+ isinstance(s, initial_sink_cls) or (
+ isinstance(sinks[0], io.IOBase) and isinstance(s, io.IOBase)
+ ) for s in sinks
+ ):
+ raise ValueError(
+ "All sinks must be of the same type unless they are all instances "
+ "of subclasses of io.IOBase."
+ )
- if initial_sink_cls in {io.StringIO, io.BytesIO, io.TextIOBase}:
+ if isinstance(sinks[0], io.IOBase):
data_sinks.reserve(len(sinks))
- if isinstance(sinks[0], (io.StringIO, io.BytesIO)):
- for s in sinks:
+ for s in sinks:
+ if isinstance(s, (io.StringIO, io.BytesIO)):
self.sink_storage.push_back(
unique_ptr[data_sink](new iobase_data_sink(s))
)
- elif isinstance(sinks[0], io.TextIOBase):
- for s in sinks:
- if codecs.lookup(s).name not in ('utf-8', 'ascii'):
+ elif isinstance(s, io.TextIOBase):
+ if codecs.lookup(s.encoding).name not in ('utf-8', 'ascii'):
raise NotImplementedError(f"Unsupported encoding {s.encoding}")
self.sink_storage.push_back(
unique_ptr[data_sink](new iobase_data_sink(s.buffer))
)
- data_sinks.push_back(self.sink_storage.back().get())
- elif initial_sink_cls is str:
+ else:
+ self.sink_storage.push_back(
+ unique_ptr[data_sink](new iobase_data_sink(s))
+ )
+ data_sinks.push_back(self.sink_storage.back().get())
+ elif isinstance(sinks[0], str):
paths.reserve(len(sinks))
for s in sinks:
                paths.push_back(<string> s.encode())
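
In practice, the relaxed rule above means any mix of `io.IOBase` subclasses is accepted, while mixing file paths with in-memory buffers still raises. A brief sketch of the implemented behaviour:

```python
import io
import pylibcudf as plc

# Both sinks are io.IOBase subclasses, so mixing them is now allowed.
plc.io.SinkInfo([io.BytesIO(), io.StringIO()])

# Mixing a filename with an in-memory buffer is still rejected.
try:
    plc.io.SinkInfo(["out.csv", io.BytesIO()])
except ValueError as err:
    print(err)
```
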
diff --git a/python/pylibcudf/pylibcudf/tests/common/utils.py b/python/pylibcudf/pylibcudf/tests/common/utils.py
index d95849ef371..58c94713d09 100644
--- a/python/pylibcudf/pylibcudf/tests/common/utils.py
+++ b/python/pylibcudf/pylibcudf/tests/common/utils.py
@@ -385,12 +385,10 @@ def make_source(path_or_buf, pa_table, format, **kwargs):
NESTED_STRUCT_TESTING_TYPE,
]
+NON_NESTED_PA_TYPES = NUMERIC_PA_TYPES + STRING_PA_TYPES + BOOL_PA_TYPES
+
DEFAULT_PA_TYPES = (
- NUMERIC_PA_TYPES
- + STRING_PA_TYPES
- + BOOL_PA_TYPES
- + LIST_PA_TYPES
- + DEFAULT_PA_STRUCT_TESTING_TYPES
+ NON_NESTED_PA_TYPES + LIST_PA_TYPES + DEFAULT_PA_STRUCT_TESTING_TYPES
)
# Map pylibcudf compression types to pandas ones
diff --git a/python/pylibcudf/pylibcudf/tests/conftest.py b/python/pylibcudf/pylibcudf/tests/conftest.py
index 5265e411c7f..36ab6798d8a 100644
--- a/python/pylibcudf/pylibcudf/tests/conftest.py
+++ b/python/pylibcudf/pylibcudf/tests/conftest.py
@@ -15,7 +15,12 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
-from utils import ALL_PA_TYPES, DEFAULT_PA_TYPES, NUMERIC_PA_TYPES
+from utils import (
+ ALL_PA_TYPES,
+ DEFAULT_PA_TYPES,
+ NON_NESTED_PA_TYPES,
+ NUMERIC_PA_TYPES,
+)
def _type_to_str(typ):
@@ -79,29 +84,13 @@ def _get_vals_of_type(pa_type, length, seed):
)
-# TODO: Consider adding another fixture/adapting this
-# fixture to consider nullability
-@pytest.fixture(scope="session", params=[0, 100])
-def table_data(request):
- """
- Returns (TableWithMetadata, pa_table).
-
- This is the default fixture you should be using for testing
- pylibcudf I/O writers.
-
- Contains one of each category (e.g. int, bool, list, struct)
- of dtypes.
- """
- nrows = request.param
-
+# TODO: Consider adapting this helper function
+# to consider nullability
+def _generate_table_data(types, nrows, seed=42):
table_dict = {}
- # Colnames in the format expected by
- # plc.io.TableWithMetadata
colnames = []
- seed = 42
-
- for typ in ALL_PA_TYPES:
+ for typ in types:
child_colnames = []
def _generate_nested_data(typ):
@@ -151,6 +140,32 @@ def _generate_nested_data(typ):
), pa_table
+@pytest.fixture(scope="session", params=[0, 100])
+def table_data(request):
+ """
+ Returns (TableWithMetadata, pa_table).
+
+ This is the default fixture you should be using for testing
+ pylibcudf I/O writers.
+
+ Contains one of each category (e.g. int, bool, list, struct)
+ of dtypes.
+ """
+ nrows = request.param
+ return _generate_table_data(ALL_PA_TYPES, nrows)
+
+
+@pytest.fixture(scope="session", params=[0, 100])
+def table_data_with_non_nested_pa_types(request):
+ """
+ Returns (TableWithMetadata, pa_table).
+
+ This fixture is for testing with non-nested PyArrow types.
+ """
+ nrows = request.param
+ return _generate_table_data(NON_NESTED_PA_TYPES, nrows)
+
+
@pytest.fixture(params=[(0, 0), ("half", 0), (-1, "half")])
def nrows_skiprows(table_data, request):
"""
diff --git a/python/pylibcudf/pylibcudf/tests/io/test_csv.py b/python/pylibcudf/pylibcudf/tests/io/test_csv.py
index 22c83acc47c..90d2d0896a5 100644
--- a/python/pylibcudf/pylibcudf/tests/io/test_csv.py
+++ b/python/pylibcudf/pylibcudf/tests/io/test_csv.py
@@ -10,6 +10,7 @@
_convert_types,
assert_table_and_meta_eq,
make_source,
+ sink_to_str,
write_source_str,
)
@@ -282,3 +283,87 @@ def test_read_csv_header(csv_table_data, source_or_sink, header):
# list true_values = None,
# list false_values = None,
# bool dayfirst = False,
+
+
+@pytest.mark.parametrize("sep", [",", "*"])
+@pytest.mark.parametrize("lineterminator", ["\n", "\n\n"])
+@pytest.mark.parametrize("header", [True, False])
+@pytest.mark.parametrize("rows_per_chunk", [8, 100])
+def test_write_csv(
+ table_data_with_non_nested_pa_types,
+ source_or_sink,
+ sep,
+ lineterminator,
+ header,
+ rows_per_chunk,
+):
+ plc_tbl_w_meta, pa_table = table_data_with_non_nested_pa_types
+ sink = source_or_sink
+
+ plc.io.csv.write_csv(
+ (
+ plc.io.csv.CsvWriterOptions.builder(
+ plc.io.SinkInfo([sink]), plc_tbl_w_meta.tbl
+ )
+ .names(plc_tbl_w_meta.column_names())
+ .na_rep("")
+ .include_header(header)
+ .rows_per_chunk(rows_per_chunk)
+ .line_terminator(lineterminator)
+ .inter_column_delimiter(sep)
+ .true_value("True")
+ .false_value("False")
+ .build()
+ )
+ )
+
+ # Convert everything to string to make comparisons easier
+ str_result = sink_to_str(sink)
+
+ pd_result = pa_table.to_pandas().to_csv(
+ sep=sep,
+ lineterminator=lineterminator,
+ header=header,
+ index=False,
+ )
+
+ assert str_result == pd_result
+
+
+@pytest.mark.parametrize("na_rep", ["", "NA"])
+def test_write_csv_na_rep(na_rep):
+ names = ["a", "b"]
+ pa_tbl = pa.Table.from_arrays(
+ [pa.array([1.0, 2.0, None]), pa.array([True, None, False])],
+ names=names,
+ )
+ plc_tbl = plc.interop.from_arrow(pa_tbl)
+ plc_tbl_w_meta = plc.io.types.TableWithMetadata(
+ plc_tbl, column_names=[(name, []) for name in names]
+ )
+
+ sink = io.StringIO()
+
+ plc.io.csv.write_csv(
+ (
+ plc.io.csv.CsvWriterOptions.builder(
+ plc.io.SinkInfo([sink]), plc_tbl_w_meta.tbl
+ )
+ .names(plc_tbl_w_meta.column_names())
+ .na_rep(na_rep)
+ .include_header(True)
+ .rows_per_chunk(8)
+ .line_terminator("\n")
+ .inter_column_delimiter(",")
+ .true_value("True")
+ .false_value("False")
+ .build()
+ )
+ )
+
+ # Convert everything to string to make comparisons easier
+ str_result = sink_to_str(sink)
+
+ pd_result = pa_tbl.to_pandas().to_csv(na_rep=na_rep, index=False)
+
+ assert str_result == pd_result