Skip to content

Commit

Permalink
Add unit tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Oct 9, 2024
1 parent e129ca2 commit 2dbcd92
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ strings
regex_flags
regex_program
repeat
replace_re
replace
side_type
slice
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
==========
replace_re
==========

.. automodule:: pylibcudf.strings.replace_re
:members:
1 change: 1 addition & 0 deletions python/pylibcudf/pylibcudf/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
regex_program,
repeat,
replace,
replace_re,
side_type,
slice,
split,
Expand Down
8 changes: 4 additions & 4 deletions python/pylibcudf/pylibcudf/strings/replace_re.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,27 +60,27 @@ cpdef Column replace_re(
"""
cdef unique_ptr[column] c_result
cdef vector[string] c_lst_patterns
cdef regex_program c_regex_pattern
cdef regex_program* c_regex_pattern
cdef column_view c_replacement_col
cdef string_scalar* c_replacement_scalar

if patterns is RegexProgram and replacement is Scalar:
c_replacement_scalar = <string_scalar*>((<Scalar>replacement).get())
c_regex_pattern = (<RegexProgram>patterns).c_obj.get()[0]
c_regex_pattern = (<RegexProgram>patterns).c_obj.get()

with nogil:
c_result = move(
cpp_replace_re.replace_re(
input.view(),
c_regex_pattern,
dereference(c_regex_pattern),
dereference(c_replacement_scalar),
max_replace_count,
)
)

return Column.from_libcudf(move(c_result))
elif patterns is list and replacement is Column:
c_replacement_col = replacement.view()
c_replacement_col = (<Column>replacement).view()
for pattern in patterns:
c_lst_patterns.push_back(pattern.encode())

Expand Down
71 changes: 71 additions & 0 deletions python/pylibcudf/pylibcudf/tests/test_string_replace_re.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pyarrow.compute as pc
import pylibcudf as plc
import pytest
from utils import assert_column_eq


@pytest.mark.parametrize("max_replace_count", [-1, 1])
def test_replace_re_regex_program_scalar(max_replace_count):
arr = pa.array(["foo", "fuz", None])
pat = "f."
repl = "ba"
result = plc.strings.replace_re.replace_re(
plc.interop.from_arrow(arr),
plc.strings.regex_program.RegexProgram.create(
pat, plc.strings.regex_flags.RegexFlags.DEFAULT
),
plc.interop.from_arrow(pa.scalar(repl)),
max_replace_count=max_replace_count,
)
expected = pc.replace_substring_regex(
arr,
pat,
repl,
max_replacements=max_replace_count
if max_replace_count != -1
else None,
)
assert_column_eq(result, expected)


@pytest.mark.parametrize(
"flags",
[
plc.strings.regex_flags.RegexFlags.DEFAULT,
plc.strings.regex_flags.RegexFlags.DOTALL,
],
)
def test_replace_re_list_str_columns(flags):
arr = pa.array(["foo", "fuz", None])
pats = ["oo", "uz"]
repls = ["a", "b"]
result = plc.strings.replace_re.replace_re(
plc.interop.from_arrow(arr),
pats,
plc.interop.from_arrow(pa.array(repls)),
flags=flags,
)
expected = arr
for pat, repl in zip(pats, repls):
expected = pc.replace_substring_regex(
expected,
pat,
repl,
)
assert_column_eq(result, expected)


def test_replace_with_backrefs():
arr = pa.array(["Z756", None])
result = plc.strings.replace_re.replace_with_backrefs(
plc.interop.from_arrow(arr),
plc.strings.regex_program.RegexProgram.create(
"(\\d)(\\d)", plc.strings.regex_flags.RegexFlags.DEFAULT
),
"V\\2\\1",
)
expected = pa.array(["ZV576", None])
assert_column_eq(result, expected)

0 comments on commit 2dbcd92

Please sign in to comment.