Skip to content

Commit

Permalink
Add string.findall APIs to pylibcudf (#16825)
Browse files Browse the repository at this point in the history
Contributes to #15162

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - Matthew Murray (https://github.com/Matt711)

URL: #16825
  • Loading branch information
mroeschke authored Sep 25, 2024
1 parent 9316309 commit 03c77c2
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
====
find
====

.. automodule:: pylibcudf.strings.findall
:members:
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ strings
contains
extract
find
findall
regex_flags
regex_program
repeat
Expand Down
35 changes: 10 additions & 25 deletions python/cudf/cudf/_lib/strings/findall.pyx
Original file line number Diff line number Diff line change
@@ -1,40 +1,25 @@
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

from cython.operator cimport dereference
from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move

from cudf.core.buffer import acquire_spill_lock

from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
from pylibcudf.libcudf.strings.findall cimport findall as cpp_findall
from pylibcudf.libcudf.strings.regex_flags cimport regex_flags
from pylibcudf.libcudf.strings.regex_program cimport regex_program

from cudf._lib.column cimport Column

import pylibcudf as plc


@acquire_spill_lock()
def findall(Column source_strings, object pattern, uint32_t flags):
"""
Returns data with all non-overlapping matches of `pattern`
in each string of `source_strings` as a lists column.
"""
cdef unique_ptr[column] c_result
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags
cdef unique_ptr[regex_program] c_prog

with nogil:
c_prog = move(regex_program.create(pattern_string, c_flags))
c_result = move(cpp_findall(
source_view,
dereference(c_prog)
))

return Column.from_unique_ptr(move(c_result))
prog = plc.strings.regex_program.RegexProgram.create(
str(pattern), flags
)
plc_result = plc.strings.findall.findall(
source_strings.to_pylibcudf(mode="read"),
prog,
)
return Column.from_pylibcudf(plc_result)
4 changes: 2 additions & 2 deletions python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ from pylibcudf.libcudf.strings.regex_program cimport regex_program
cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[column] findall(
column_view source_strings,
regex_program) except +
column_view input,
regex_program prog) except +
4 changes: 2 additions & 2 deletions python/pylibcudf/pylibcudf/strings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# =============================================================================

set(cython_sources
capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx regex_flags.pyx
regex_program.pyx repeat.pyx replace.pyx side_type.pyx slice.pyx strip.pyx
capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx findall.pyx
regex_flags.pyx regex_program.pyx repeat.pyx replace.pyx side_type.pyx slice.pyx strip.pyx
)

set(linked_libraries cudf::cudf)
Expand Down
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from . cimport (
convert,
extract,
find,
findall,
regex_flags,
regex_program,
replace,
Expand Down
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
convert,
extract,
find,
findall,
regex_flags,
regex_program,
repeat,
Expand Down
7 changes: 7 additions & 0 deletions python/pylibcudf/pylibcudf/strings/findall.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from pylibcudf.column cimport Column
from pylibcudf.strings.regex_program cimport RegexProgram


cpdef Column findall(Column input, RegexProgram pattern)
40 changes: 40 additions & 0 deletions python/pylibcudf/pylibcudf/strings/findall.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
from pylibcudf.column cimport Column
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.strings cimport findall as cpp_findall
from pylibcudf.strings.regex_program cimport RegexProgram


cpdef Column findall(Column input, RegexProgram pattern):
"""
Returns a lists column of strings for each matching occurrence using
the regex_program pattern within each string.
For details, see For details, see :cpp:func:`cudf::strings::findall`.
Parameters
----------
input : Column
Strings instance for this operation
pattern : RegexProgram
Regex pattern
Returns
-------
Column
New lists column of strings
"""
cdef unique_ptr[column] c_result

with nogil:
c_result = move(
cpp_findall.findall(
input.view(),
pattern.c_obj.get()[0]
)
)

return Column.from_libcudf(move(c_result))
23 changes: 23 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_string_findall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import re

import pyarrow as pa
import pylibcudf as plc
from utils import assert_column_eq


def test_findall():
arr = pa.array(["bunny", "rabbit", "hare", "dog"])
pattern = "[ab]"
result = plc.strings.findall.findall(
plc.interop.from_arrow(arr),
plc.strings.regex_program.RegexProgram.create(
pattern, plc.strings.regex_flags.RegexFlags.DEFAULT
),
)
pa_result = plc.interop.to_arrow(result)
expected = pa.array(
[re.findall(pattern, elem) for elem in arr.to_pylist()],
type=pa_result.type,
)
assert_column_eq(result, expected)

0 comments on commit 03c77c2

Please sign in to comment.