Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arrow cast #962

Open
wants to merge 33 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1f434cc
feat: add data_type parameter to expr_fn macro for arrow_cast function
kosiew Nov 26, 2024
a576cb7
feat: add arrow_cast function to cast expressions to specified data t…
kosiew Nov 26, 2024
1914a0b
docs: add casting section to user guide with examples for arrow_cast …
kosiew Nov 26, 2024
e623ae3
test: add unit test for arrow_cast function to validate casting to Fl…
kosiew Nov 26, 2024
61115b3
fix: update arrow_cast function to accept Expr type for data_type par…
kosiew Nov 26, 2024
11071e6
fix: update test_arrow_cast to use literal casting for data types
kosiew Nov 26, 2024
20cc781
fix: update arrow_cast function to accept string type for data_type p…
kosiew Nov 26, 2024
1e4e350
fix: update arrow_cast function to accept Expr type for data_type par…
kosiew Nov 26, 2024
8c7e2f8
fix: update test_arrow_cast to use literal for data type parameters
kosiew Nov 26, 2024
b80ae94
fix: update arrow_cast function to use arg_1 for datatype parameter
kosiew Nov 27, 2024
eba0d32
fix: update arrow_cast function to accept string type for data_type p…
kosiew Nov 27, 2024
3a5e210
Revert "fix: update arrow_cast function to accept string type for dat…
kosiew Nov 27, 2024
856ff8c
fix: update test_arrow_cast to cast literals to string type for arrow…
kosiew Nov 27, 2024
dcaf0d6
Revert "fix: update test_arrow_cast to cast literals to string type f…
kosiew Nov 27, 2024
9e1ced7
fix: update arrow_cast function to accept string type for data_type p…
kosiew Nov 27, 2024
8e96e8e
Revert "fix: update arrow_cast function to accept string type for dat…
kosiew Nov 27, 2024
11ed674
fix: add utf8_literal function to create UTF8 literal expressions in …
kosiew Nov 27, 2024
193d21c
Revert "fix: add utf8_literal function to create UTF8 literal express…
kosiew Nov 27, 2024
ba53bd1
feat: add utf8_literal function to create UTF8 literal expressions
kosiew Nov 27, 2024
3b83a96
fix: update test_arrow_cast to use column 'b'
kosiew Nov 27, 2024
cdf32cd
fix: enhance utf8_literal function to handle non-string values
kosiew Dec 2, 2024
187e077
Add description for utf8_literal vs literal
kosiew Dec 2, 2024
d801567
Merge branch 'main' into add-arrow-cast-merge-main
kosiew Dec 3, 2024
0106cb7
docs: clarify utf8_literal function documentation to explain use case
kosiew Dec 3, 2024
74cbd3b
docs: add clarification comments for utf8_literal usage in arrow_cast…
kosiew Dec 3, 2024
1c5b91e
docs: implement ruff recommendation
kosiew Dec 3, 2024
4aa6c7e
fix ruff errors
kosiew Dec 3, 2024
f9814dd
docs: update examples to use utf8_literal in arrow_cast function
kosiew Dec 3, 2024
8eb0ed1
docs: correct typo in comment for utf8_literal usage in test_arrow_cast
kosiew Dec 10, 2024
9216389
docs: remove redundant comment in test_arrow_cast for clarity
kosiew Dec 10, 2024
5e03c3a
refactor: rename utf8_literal to string_literal and add alias str_lit
kosiew Dec 12, 2024
7e28012
docs: improve docstring for string_literal function for clarity
kosiew Dec 12, 2024
5eced8b
docs: update import statement to include str_lit alias for string_lit…
kosiew Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/source/user-guide/common-operations/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ DataFusion offers mathematical functions such as :py:func:`~datafusion.functions

.. ipython:: python

from datafusion import col, literal
from datafusion import col, literal, string_literal, str_lit
from datafusion import functions as f

df.select(
Expand Down Expand Up @@ -104,6 +104,17 @@ This also includes the functions for regular expressions like :py:func:`~datafus
f.regexp_replace(col('"Name"'), literal("saur"), literal("fleur")).alias("flowers")
)

Casting
-------

Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast`

.. ipython:: python

df.select(
f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
)

Other
-----
Expand Down
13 changes: 13 additions & 0 deletions python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ def literal(value):
return Expr.literal(value)


def string_literal(value):
"""Create a UTF8 literal expression.

It differs from `literal` which creates a UTF8view literal.
"""
return Expr.string_literal(value)


def str_lit(value):
"""Alias for `string_literal`."""
return string_literal(value)


def lit(value):
"""Create a literal expression."""
return Expr.literal(value)
Expand Down
16 changes: 16 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,22 @@ def literal(value: Any) -> Expr:
value = pa.scalar(value)
return Expr(expr_internal.Expr.literal(value))

@staticmethod
def string_literal(value: str) -> Expr:
"""Creates a new expression representing a UTF8 literal value.

It is different from `literal` because it is pa.string() instead of
pa.string_view()

This is needed for cases where DataFusion is expecting a UTF8 instead of
UTF8View literal, like in:
https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string())
return Expr(expr_internal.Expr.literal(value))
return Expr.literal(value)

@staticmethod
def column(value: str) -> Expr:
"""Creates a new expression representing a column."""
Expand Down
6 changes: 6 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"array_to_string",
"array_union",
"arrow_typeof",
"arrow_cast",
"ascii",
"asin",
"asinh",
Expand Down Expand Up @@ -1108,6 +1109,11 @@ def arrow_typeof(arg: Expr) -> Expr:
return Expr(f.arrow_typeof(arg.expr))


def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
"""Casts an expression to a specified data type."""
return Expr(f.arrow_cast(expr.expr, data_type.expr))


kosiew marked this conversation as resolved.
Show resolved Hide resolved
def random() -> Expr:
"""Returns a random value in the range ``0.0 <= x < 1.0``."""
return Expr(f.random())
Expand Down
18 changes: 17 additions & 1 deletion python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datafusion import SessionContext, column
from datafusion import functions as f
from datafusion import literal
from datafusion import literal, string_literal

np.seterr(invalid="ignore")

Expand Down Expand Up @@ -907,6 +907,22 @@ def test_temporal_functions(df):
assert result.column(10) == pa.array([31, 26, 2], type=pa.float64())


def test_arrow_cast(df):
df = df.select(
# we use `string_literal` to return utf8 instead of `literal` which returns
# utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
# https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
)
result = df.collect()
assert len(result) == 1
result = result[0]

assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())


def test_case(df):
df = df.select(
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),
Expand Down
3 changes: 2 additions & 1 deletion src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ macro_rules! expr_fn {
}
};
}

/// Generates a [pyo3] wrapper for [datafusion::functions::expr_fn]
///
/// These functions take a single `Vec<PyExpr>` argument using `pyo3(signature = (*args))`.
Expand Down Expand Up @@ -564,6 +563,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
expr_fn_vec!(named_struct);
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
expr_fn!(arrow_cast, arg_1 datatype);
expr_fn!(random);

// Array Functions
Expand Down Expand Up @@ -856,6 +856,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(range))?;
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
Expand Down
Loading