Skip to content

Commit

Permalink
fix in string scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
shrshi committed Nov 27, 2024
1 parent 953a5d5 commit ba54bb6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion cpp/src/text/bpe/load_merge_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ std::unique_ptr<bpe_merge_pairs::bpe_merge_pairs_impl> create_bpe_merge_pairs_im
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
- auto pairs = cudf::strings::split_record(input, cudf::string_scalar(" "), 1, stream, mr);
+ auto pairs =
+ cudf::strings::split_record(input, cudf::string_scalar(" ", true, stream, mr), 1, stream, mr);
auto content = pairs->release();
return create_bpe_merge_pairs_impl(std::move(content.children.back()), stream);
}
Expand Down
8 changes: 4 additions & 4 deletions cpp/tests/streams/text/bpe_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct TextBytePairEncoding : public cudf::test::BaseFixture {};

TEST_F(TextBytePairEncoding, BytePairEncoding)
{
+ auto stream = cudf::test::get_default_stream();
// partial table based on values from https://huggingface.co/gpt2/raw/main/merges.txt
auto mpt = cudf::test::strings_column_wrapper({
"e n", // 14
Expand All @@ -45,15 +46,14 @@ TEST_F(TextBytePairEncoding, BytePairEncoding)
"s ent" // 33832
});

- auto merge_pairs =
- nvtext::load_merge_pairs(cudf::strings_column_view(mpt), cudf::test::get_default_stream());
+ auto merge_pairs = nvtext::load_merge_pairs(cudf::strings_column_view(mpt), stream);

auto validity = cudf::test::iterators::null_at(4);
cudf::test::strings_column_wrapper input(
{"thisisit", "thisis test-sentence-1", "thisistestsentence-2", "this-istestsentence 3", "", ""},
validity);
auto sv = cudf::strings_column_view(input);

- auto results = nvtext::byte_pair_encoding(
- sv, *merge_pairs, cudf::string_scalar(" "), cudf::test::get_default_stream());
+ auto results =
+ nvtext::byte_pair_encoding(sv, *merge_pairs, cudf::string_scalar(" ", true, stream), stream);
}

0 comments on commit ba54bb6

Please sign in to comment.