Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add string.repeats API to pylibcudf #16834

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ strings
find
regex_flags
regex_program
repeat
replace
slice
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
======
repeat
======

.. automodule:: pylibcudf.strings.repeat
:members:
40 changes: 12 additions & 28 deletions python/cudf/cudf/_lib/strings/repeat.pyx
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

from cudf.core.buffer import acquire_spill_lock

from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.strings cimport repeat as cpp_repeat
from pylibcudf.libcudf.types cimport size_type

from cudf._lib.column cimport Column

import pylibcudf as plc


@acquire_spill_lock()
def repeat_scalar(Column source_strings,
Expand All @@ -21,16 +16,11 @@ def repeat_scalar(Column source_strings,
each string in `source_strings`
`repeats` number of times.
"""
cdef unique_ptr[column] c_result
cdef column_view source_view = source_strings.view()

with nogil:
c_result = move(cpp_repeat.repeat_strings(
source_view,
repeats
))

return Column.from_unique_ptr(move(c_result))
plc_result = plc.strings.repeat.repeat_strings(
source_strings.to_pylibcudf(mode="read"),
repeats
)
return Column.from_pylibcudf(plc_result)


@acquire_spill_lock()
Expand All @@ -41,14 +31,8 @@ def repeat_sequence(Column source_strings,
each string in `source_strings`
`repeats` number of times.
"""
cdef unique_ptr[column] c_result
cdef column_view source_view = source_strings.view()
cdef column_view repeats_view = repeats.view()

with nogil:
c_result = move(cpp_repeat.repeat_strings(
source_view,
repeats_view
))

return Column.from_unique_ptr(move(c_result))
plc_result = plc.strings.repeat.repeat_strings(
source_strings.to_pylibcudf(mode="read"),
repeats.to_pylibcudf(mode="read")
)
return Column.from_pylibcudf(plc_result)
8 changes: 4 additions & 4 deletions python/pylibcudf/pylibcudf/libcudf/strings/repeat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ cdef extern from "cudf/strings/repeat_strings.hpp" namespace "cudf::strings" \
nogil:

cdef unique_ptr[column] repeat_strings(
column_view strings,
size_type repeat) except +
column_view input,
size_type repeat_times) except +

cdef unique_ptr[column] repeat_strings(
column_view strings,
column_view repeats) except +
column_view input,
column_view repeat_times) except +
2 changes: 1 addition & 1 deletion python/pylibcudf/pylibcudf/strings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

set(cython_sources capitalize.pyx case.pyx char_types.pyx contains.pyx find.pyx regex_flags.pyx
regex_program.pyx replace.pyx slice.pyx
regex_program.pyx repeat.pyx replace.pyx slice.pyx
)

set(linked_libraries cudf::cudf)
Expand Down
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
find,
regex_flags,
regex_program,
repeat,
replace,
slice,
)
10 changes: 10 additions & 0 deletions python/pylibcudf/pylibcudf/strings/repeat.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from pylibcudf.column cimport Column
from pylibcudf.libcudf.types cimport size_type

ctypedef fused ColumnorSizeType:
Column
size_type

cpdef Column repeat_strings(Column input, ColumnorSizeType repeat_times)
51 changes: 51 additions & 0 deletions python/pylibcudf/pylibcudf/strings/repeat.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
from pylibcudf.column cimport Column
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.strings cimport repeat as cpp_repeat
from pylibcudf.libcudf.types cimport size_type


cpdef Column repeat_strings(Column input, ColumnorSizeType repeat_times):
"""
Repeat each string in the given strings column by the numbers
of times given in another numeric column.

For details, see :cpp:func:`cudf::strings::repeat`.

Parameters
----------
input : Column
The column containing strings to repeat.
repeat_times : Column or int
Number(s) of times that the corresponding input strings
for each row are repeated.

Returns
-------
Column
New column containing the repeated strings.
"""
cdef unique_ptr[column] c_result

if ColumnorSizeType is Column:
with nogil:
c_result = move(
cpp_repeat.repeat_strings(
input.view(),
repeat_times.view()
)
)
elif ColumnorSizeType is size_type:
with nogil:
c_result = move(
cpp_repeat.repeat_strings(
input.view(),
repeat_times
)
)
else:
raise ValueError("repeat_times must be size_type or integer")

return Column.from_libcudf(move(c_result))
20 changes: 20 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_string_repeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pyarrow.compute as pc
import pylibcudf as plc
import pytest


@pytest.mark.parametrize("repeats", [pa.array([2, 2]), 2])
def test_repeat_strings(repeats):
arr = pa.array(["1", None])
plc_result = plc.strings.repeat.repeat_strings(
plc.interop.from_arrow(arr),
plc.interop.from_arrow(repeats)
if not isinstance(repeats, int)
else repeats,
)
result = plc.interop.to_arrow(plc_result)
expected = pa.chunked_array(pc.binary_repeat(arr, repeats))
assert result.equals(expected)
Loading