Default to ZSTD compression when writing Parquet #981

Merged: 16 commits, Jan 11, 2025
94 changes: 88 additions & 6 deletions python/datafusion/dataframe.py
@@ -21,7 +21,16 @@

from __future__ import annotations
import warnings
from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
from typing import (
    Any,
    Iterable,
    List,
    TYPE_CHECKING,
    Literal,
    overload,
    Optional,
    Union,
)
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -35,6 +44,60 @@

from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr, SortExpr, sort_or_default
from enum import Enum


# excerpt from deltalake
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
class Compression(Enum):
    """Enum representing the available compression types for Parquet files."""

    UNCOMPRESSED = "uncompressed"
    SNAPPY = "snappy"
    GZIP = "gzip"
    BROTLI = "brotli"
    LZ4 = "lz4"
    LZO = "lzo"
    ZSTD = "zstd"
    LZ4_RAW = "lz4_raw"

    @classmethod
    def from_str(cls, value: str) -> "Compression":
        """Convert a string to a Compression enum value.

        Args:
            value: The string representation of the compression type.

        Returns:
            The matching Compression enum value.

        Raises:
            ValueError: If the string does not match any Compression enum value.
        """
        try:
            return cls(value.lower())
        except ValueError:
            raise ValueError(
                f"{value} is not a valid Compression. Valid values are: "
                f"{[item.value for item in Compression]}"
            )

    def get_default_level(self) -> Optional[int]:
        """Get the default compression level for the compression type.

        Returns:
            The default compression level for the compression type.
        """
        # GZIP, BROTLI default values from deltalake repo
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
        # ZSTD default value from delta-rs
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
        if self == Compression.GZIP:
            return 6
        elif self == Compression.BROTLI:
            return 1
        elif self == Compression.ZSTD:
            return 4
        return None
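
# Illustrative sketch (editor's addition, not part of the diff): expected
# behavior of the helpers above, assuming the enum is defined as shown.
#     Compression.from_str("GZIP")            -> Compression.GZIP
#     Compression.from_str("foo")             -> ValueError listing valid values
#     Compression.ZSTD.get_default_level()    -> 4
#     Compression.SNAPPY.get_default_level()  -> None (no level concept)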


class DataFrame:
@@ -620,17 +683,36 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
    def write_parquet(
        self,
        path: str | pathlib.Path,
        compression: str = "uncompressed",
        compression: Union[str, Compression] = Compression.ZSTD,
        compression_level: int | None = None,
    ) -> None:
"""Execute the :py:class:`DataFrame` and write the results to a Parquet file.

Args:
path: Path of the Parquet file to write.
compression: Compression type to use.
compression_level: Compression level to use.
"""
self.df.write_parquet(str(path), compression, compression_level)
compression: Compression type to use. Default is "ZSTD".
Available compression types are:
- "uncompressed": No compression.
- "snappy": Snappy compression.
- "gzip": Gzip compression.
- "brotli": Brotli compression.
- "lz0": LZ0 compression.
- "lz4": LZ4 compression.
- "lz4_raw": LZ4_RAW compression.
- "zstd": Zstandard compression.
compression_level: Compression level to use. For ZSTD, the
recommended range is 1 to 22, with the default being 4. Higher levels
provide better compression but slower speed.
"""
# Convert string to Compression enum if necessary
if isinstance(compression, str):
compression = Compression.from_str(compression)

if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
if compression_level is None:
compression_level = compression.get_default_level()

self.df.write_parquet(str(path), compression.value, compression_level)

    def write_json(self, path: str | pathlib.Path) -> None:
        """Execute the :py:class:`DataFrame` and write the results to a JSON file.
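To make the new behavior concrete, here is a minimal usage sketch (not part of the diff). It assumes a datafusion SessionContext, that Compression is importable from datafusion.dataframe as added above, and hypothetical output paths:

from datafusion import SessionContext
from datafusion.dataframe import Compression

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})

# With no compression argument, the file is now written with ZSTD at
# level 4 (filled in via Compression.get_default_level()).
df.write_parquet("out_default.parquet")

# Strings are normalized through Compression.from_str, so these two calls
# are equivalent.
df.write_parquet("out_zstd.parquet", compression="zstd")
df.write_parquet("out_zstd_enum.parquet", compression=Compression.ZSTD)

# Codecs with a level concept also accept an explicit compression_level.
df.write_parquet("out_gzip.parquet", compression="gzip", compression_level=9)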
14 changes: 12 additions & 2 deletions python/tests/test_dataframe.py
@@ -1107,14 +1107,24 @@ def test_write_compressed_parquet_wrong_compression_level(
)


@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
@pytest.mark.parametrize("compression", ["wrong"])
def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
path = tmp_path

with pytest.raises(ValueError):
df.write_parquet(str(path), compression=compression)


@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
# Test write_parquet with zstd, brotli, gzip default compression level,
# ie don't specify compression level
# should complete without error
path = tmp_path

df.write_parquet(str(path), compression=compression)


def test_dataframe_export(df) -> None:
    # Guarantees that we have the canonical implementation
    # reading our dataframe export
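A companion case one might add (a sketch only, reusing the df and tmp_path fixtures above): an explicit level for a codec that supports one should be forwarded unchanged.

def test_write_compressed_parquet_explicit_compression_level(df, tmp_path):
    # An explicit ZSTD level inside the recommended 1-22 range should pass
    # through and complete without error.
    df.write_parquet(str(tmp_path), compression="zstd", compression_level=10)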
2 changes: 1 addition & 1 deletion src/dataframe.rs
@@ -463,7 +463,7 @@ impl PyDataFrame
    /// Write a `DataFrame` to a Parquet file.
    #[pyo3(signature = (
        path,
        compression="uncompressed",
        compression="zstd",
        compression_level=None
    ))]
    fn write_parquet(
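The Rust change above only swaps the default string inside the #[pyo3(signature = ...)] attribute. A minimal sketch of the pattern (simplified; the real method carries the DataFrame's state and performs the actual write):

use pyo3::prelude::*;

#[pyclass]
struct PyDataFrame;

#[pymethods]
impl PyDataFrame {
    /// `compression` defaults to "zstd" when the Python caller omits it,
    /// and a missing `compression_level` arrives as `Option::None` in Rust.
    #[pyo3(signature = (path, compression="zstd", compression_level=None))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
    ) -> PyResult<()> {
        // A real implementation maps compression/compression_level onto
        // DataFusion's Parquet writer options and executes the write.
        let _ = (path, compression, compression_level);
        Ok(())
    }
}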