diff --git a/cpp/src/replace/clamp.cu b/cpp/src/replace/clamp.cu index fe5a9cfbd71..31ffc76a4a5 100644 --- a/cpp/src/replace/clamp.cu +++ b/cpp/src/replace/clamp.cu @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include #include @@ -52,26 +52,22 @@ namespace { template struct clamp_strings_fn { + using string_index_pair = cudf::strings::detail::string_index_pair; column_device_view const d_strings; OptionalScalarIterator lo_itr; ReplaceScalarIterator lo_replace_itr; OptionalScalarIterator hi_itr; ReplaceScalarIterator hi_replace_itr; - size_type* d_offsets{}; - char* d_chars{}; - __device__ void operator()(size_type idx) const + __device__ string_index_pair operator()(size_type idx) const { - if (d_strings.is_null(idx)) { - if (!d_chars) { d_offsets[idx] = 0; } - return; - } + if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; } + auto const element = d_strings.element(idx); auto const d_lo = (*lo_itr).value_or(element); auto const d_hi = (*hi_itr).value_or(element); auto const d_lo_replace = *(*lo_replace_itr); auto const d_hi_replace = *(*hi_replace_itr); - auto d_output = d_chars ? d_chars + d_offsets[idx] : nullptr; auto d_str = [d_lo, d_lo_replace, d_hi, d_hi_replace, element] { if (element < d_lo) { return d_lo_replace; } @@ -79,11 +75,9 @@ struct clamp_strings_fn { return element; }(); - if (d_output) { - cudf::strings::detail::copy_string(d_output, d_str); - } else { - d_offsets[idx] = d_str.size_bytes(); - } + // ensures an empty string is not converted to a null row + return !d_str.empty() ? string_index_pair{d_str.data(), d_str.size_bytes()} + : string_index_pair{"", 0}; } }; @@ -101,14 +95,14 @@ std::unique_ptr clamp_string_column(strings_column_view const& inp auto fn = clamp_strings_fn{ d_input, lo_itr, lo_replace_itr, hi_itr, hi_replace_itr}; - auto [offsets_column, chars] = - cudf::strings::detail::make_strings_children(fn, input.size(), stream, mr); - - return make_strings_column(input.size(), - std::move(offsets_column), - chars.release(), - input.null_count(), - std::move(cudf::detail::copy_bitmask(input.parent(), stream, mr))); + rmm::device_uvector indices(input.size(), stream); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(input.size()), + indices.begin(), + fn); + + return cudf::strings::detail::make_strings_column(indices.begin(), indices.end(), stream, mr); } template