From 873d813771eafa78d3e1e5d376b191a77a1aea2a Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 11 Oct 2024 15:34:04 +0200 Subject: [PATCH] Impl. Dataframe serialization Depend on #17012 --- .../cudf_polars/containers/dataframe.py | 113 ++++++++++++++++++ .../tests/containers/test_dataframe.py | 46 +++++++ 2 files changed, 159 insertions(+) diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index 2c195f6637c..b4bc2af2d59 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -5,6 +5,7 @@ from __future__ import annotations +import pickle from functools import cached_property from typing import TYPE_CHECKING, cast @@ -146,6 +147,69 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self: Column(c, name=name) for c, name in zip(table.columns(), names, strict=True) ) + @classmethod + def deserialize(cls, header: dict, frames: list[memoryview | plc.gpumemoryview]): + """ + Create a DataFrame from a serialized representation returned by `.serialize()`. + + Parameters + ---------- + header + The (unpickled) metadata required to reconstruct the object. + frames + List of contiguous buffers, which is a mixture of memoryview and gpumemoryviews. + + Returns + ------- + DataFrame + The deserialized DataFrame. + """ + packed_metadata, packed_gpu_data = frames + table = plc.contiguous_split.unpack_from_memoryviews( + packed_metadata, packed_gpu_data + ) + columns_kwargs = header["columns_kwargs"] + + if table.num_columns() != len(columns_kwargs): + raise ValueError("Mismatching columns_kwargs and table length.") + return cls( + Column(c, **kw) + for c, kw in zip(table.columns(), columns_kwargs, strict=True) + ) + + def serialize(self): + """ + Serialize the table into header and frames. + + Follows the Dask serialization scheme with a picklable header (dict) and + a list of frames (contiguous buffers). 
+ + Returns + ------- + header + A dict containing any picklable metadata required to reconstruct the object. + frames + List of frames, which is a mixture of memoryview and gpumemoryviews. + """ + packed = plc.contiguous_split.pack(self.table) + + # Keyword arguments for `Column.__init__`. + columns_kwargs = [ + { + "is_sorted": col.is_sorted, + "order": col.order, + "name": col.name, + } + for col in self.columns + ] + header = { + "columns_kwargs": columns_kwargs, + # Dask Distributed uses "type-serialized" to dispatch deserialization + "type-serialized": pickle.dumps(type(self)), + "frame_count": 2, + } + return header, list(packed.release()) + def sorted_like( self, like: DataFrame, /, *, subset: Set[str] | None = None ) -> Self: @@ -252,3 +316,52 @@ def slice(self, zlice: tuple[int, int] | None) -> Self: end = max(min(end, self.num_rows), 0) (table,) = plc.copying.slice(self.table, [start, end]) return type(self).from_table(table, self.column_names).sorted_like(self) + + +try: + import cupy + from distributed.protocol import ( + dask_deserialize, + dask_serialize, + ) + from distributed.protocol.cuda import ( + cuda_deserialize, + cuda_serialize, + ) + from distributed.utils import log_errors + + @cuda_serialize.register(DataFrame) + def _(x): + with log_errors(): + return x.serialize() + + @cuda_deserialize.register(DataFrame) + def _(header, frames): + with log_errors(): + return DataFrame.deserialize(header, frames) + + @dask_serialize.register(DataFrame) + def _(x): + with log_errors(): + header, frames = x.serialize() + # Copy GPU buffers to host and record their indices in the header + gpu_frames = [ + i + for i in range(len(frames)) + if isinstance(frames[i], plc.gpumemoryview) + ] + for i in gpu_frames: + frames[i] = memoryview(cupy.asnumpy(frames[i])) + header["gpu_frames"] = gpu_frames + return header, frames + + @dask_deserialize.register(DataFrame) + def _(header, frames): + with log_errors(): + # Copy GPU buffers back to device memory + for i in 
header.pop("gpu_frames"): + frames[i] = plc.gpumemoryview(cupy.asarray(frames[i])) + return DataFrame.deserialize(header, frames) + +except ImportError: + pass # cupy or distributed is probably not installed on the system diff --git a/python/cudf_polars/tests/containers/test_dataframe.py b/python/cudf_polars/tests/containers/test_dataframe.py index 5c68fb8f0aa..3addeaa2381 100644 --- a/python/cudf_polars/tests/containers/test_dataframe.py +++ b/python/cudf_polars/tests/containers/test_dataframe.py @@ -3,6 +3,7 @@ from __future__ import annotations +import pyarrow as pa import pylibcudf as plc import pytest @@ -160,3 +161,48 @@ def test_empty_name_roundtrips_overlap(): def test_empty_name_roundtrips_no_overlap(): df = pl.LazyFrame({"": [1, 2, 3], "b": [4, 5, 6]}) assert_gpu_result_equal(df) + + +@pytest.mark.parametrize( + "arrow_tbl", + [ + pa.table([]), + pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + pa.table({"a": [1, 2, 3]}), + pa.table({"a": [1], "b": [2], "c": [3]}), + pa.table({"a": ["a", "bb", "ccc"]}), + pa.table({"a": [1, 2, None], "b": [None, 3, 4]}), + ], +) +def test_serialize(arrow_tbl): + plc_tbl = plc.interop.from_arrow(arrow_tbl) + df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names) + + header, frames = df.serialize() + res = DataFrame.deserialize(header, frames) + + pl.testing.asserts.assert_frame_equal(df.to_polars(), res.to_polars()) + + +@pytest.mark.parametrize( + "arrow_tbl", + [ + # pa.table([]), + pa.table({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + pa.table({"a": [1, 2, 3]}), + pa.table({"a": [1], "b": [2], "c": [3]}), + pa.table({"a": ["a", "bb", "ccc"]}), + pa.table({"a": [1, 2, None], "b": [None, 3, 4]}), + ], +) +@pytest.mark.parametrize("protocol", ["cuda", "dask"]) +def test_dask_serialize(arrow_tbl, protocol): + from distributed.protocol import deserialize, serialize + + plc_tbl = plc.interop.from_arrow(arrow_tbl) + df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names) + + header, frames = 
serialize(df, on_error="raises", serializers=[protocol]) + res = deserialize(header, frames, deserializers=[protocol]) + + pl.testing.asserts.assert_frame_equal(df.to_polars(), res.to_polars())