diff --git a/cpp/include/nvtext/subword_tokenize.hpp b/cpp/include/nvtext/subword_tokenize.hpp
index c4210699975..4d06aa5d4bc 100644
--- a/cpp/include/nvtext/subword_tokenize.hpp
+++ b/cpp/include/nvtext/subword_tokenize.hpp
@@ -62,11 +62,13 @@ struct hashed_vocabulary {
  * @param filename_hashed_vocabulary A path to the preprocessed vocab.txt file.
  *        Note that this is the file AFTER python/perfect_hash.py has been used
  *        for preprocessing.
+ * @param stream CUDA stream used for device memory operations and kernel launches
  * @param mr Memory resource to allocate any returned objects.
  * @return vocabulary hash-table elements
  */
 std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
   std::string const& filename_hashed_vocabulary,
+  rmm::cuda_stream_view stream      = cudf::get_default_stream(),
   rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
 
 /**
@@ -147,6 +149,7 @@ struct tokenizer_result {
  * @param do_truncate If true, the tokenizer will discard all the token-ids after
  *        `max_sequence_length` for each input string. If false, it will use a new row
  *        in the output token-ids to continue generating the output.
+ * @param stream CUDA stream used for device memory operations and kernel launches
  * @param mr Memory resource to allocate any returned objects.
 * @return token-ids, attention-mask, and metadata
 */
@@ -157,6 +160,7 @@ tokenizer_result subword_tokenize(
   uint32_t stride,
   bool do_lower_case,
   bool do_truncate,
+  rmm::cuda_stream_view stream      = cudf::get_default_stream(),
   rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
 
 /** @} */  // end of group
diff --git a/cpp/src/text/subword/load_hash_file.cu b/cpp/src/text/subword/load_hash_file.cu
index eca703e2604..b13ad0a7de8 100644
--- a/cpp/src/text/subword/load_hash_file.cu
+++ b/cpp/src/text/subword/load_hash_file.cu
@@ -289,10 +289,12 @@ std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
 }  // namespace detail
 
 std::unique_ptr<hashed_vocabulary> load_vocabulary_file(
-  std::string const& filename_hashed_vocabulary, rmm::device_async_resource_ref mr)
+  std::string const& filename_hashed_vocabulary,
+  rmm::cuda_stream_view stream,
+  rmm::device_async_resource_ref mr)
 {
   CUDF_FUNC_RANGE();
-  return detail::load_vocabulary_file(filename_hashed_vocabulary, cudf::get_default_stream(), mr);
+  return detail::load_vocabulary_file(filename_hashed_vocabulary, stream, mr);
 }
 
 }  // namespace nvtext
diff --git a/cpp/src/text/subword/subword_tokenize.cu b/cpp/src/text/subword/subword_tokenize.cu
index d7e04a0c208..dee589d6daf 100644
--- a/cpp/src/text/subword/subword_tokenize.cu
+++ b/cpp/src/text/subword/subword_tokenize.cu
@@ -293,17 +293,12 @@ tokenizer_result subword_tokenize(cudf::strings_column_view const& strings,
                                   uint32_t stride,
                                   bool do_lower_case,
                                   bool do_truncate,
+                                  rmm::cuda_stream_view stream,
                                   rmm::device_async_resource_ref mr)
 {
   CUDF_FUNC_RANGE();
-  return detail::subword_tokenize(strings,
-                                  vocabulary_table,
-                                  max_sequence_length,
-                                  stride,
-                                  do_lower_case,
-                                  do_truncate,
-                                  cudf::get_default_stream(),
-                                  mr);
+  return detail::subword_tokenize(
+    strings, vocabulary_table, max_sequence_length, stride, do_lower_case, do_truncate, stream, mr);
 }
 
 }  // namespace nvtext
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index 6d3d1454462..23632f6fbba 100644
---
a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -743,6 +743,7 @@ ConfigureTest(
   streams/text/ngrams_test.cpp
   streams/text/replace_test.cpp
   streams/text/stemmer_test.cpp
+  streams/text/subword_tokenize_test.cpp
   streams/text/tokenize_test.cpp
   STREAM_MODE
   testing
diff --git a/cpp/tests/streams/text/subword_tokenize_test.cpp b/cpp/tests/streams/text/subword_tokenize_test.cpp
new file mode 100644
index 00000000000..9474e6b269c
--- /dev/null
+++ b/cpp/tests/streams/text/subword_tokenize_test.cpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cudf_test/base_fixture.hpp>
+#include <cudf_test/column_utilities.hpp>
+#include <cudf_test/column_wrapper.hpp>
+#include <cudf_test/default_stream.hpp>
+
+#include <cudf/column/column.hpp>
+#include <cudf/strings/strings_column_view.hpp>
+
+#include <nvtext/subword_tokenize.hpp>
+
+#include <fstream>
+#include <vector>
+
+// Global environment for temporary files
+auto const temp_env = static_cast<cudf::test::TempDirTestEnvironment*>(
+  ::testing::AddGlobalTestEnvironment(new cudf::test::TempDirTestEnvironment));
+
+struct TextSubwordTest : public cudf::test::BaseFixture {};
+
+// Create a fake hashed vocab text file for the tests in this source file.
+// The vocab only includes the following words:
+// 'this', 'is', 'a', 'test', 'tést'
+// The period '.' character also has a token id.
+// Write a minimal perfect-hash vocabulary file that load_vocabulary_file can parse:
+// first-token-id / sep-token-id lines, the hash coefficients, the hash table, and
+// the ids for [UNK]/[CLS]/[SEP] (100/101/102).
+void create_hashed_vocab(std::string const& hash_file)
+{
+  constexpr size_t coefsize = 23;
+  // NOTE(review): element types below were lost in extraction and reconstructed
+  // from the literal values; hash_table entries exceed 2^32 so uint64_t is required.
+  std::vector<std::pair<int, int>> coefficients(coefsize, {65559, 0});
+  std::ofstream outfile(hash_file, std::ofstream::out);
+  outfile << "1\n0\n" << coefficients.size() << "\n";
+  for (auto c : coefficients) {
+    outfile << c.first << " " << c.second << "\n";
+  }
+  std::vector<uint64_t> hash_table(coefsize, 0);
+  outfile << hash_table.size() << "\n";
+  hash_table[0]  = 3015668L;              // based on values
+  hash_table[1]  = 6205475701751155871L;  // from the
+  hash_table[5]  = 6358029;               // bert_hash_table.txt
+  hash_table[16] = 451412625363L;         // file for the test
+  hash_table[20] = 6206321707968235495L;  // words above
+  for (auto h : hash_table) {
+    outfile << h << "\n";
+  }
+  outfile << "100\n101\n102\n\n";
+}
+
+// Stream test: exercise load_vocabulary_file and subword_tokenize on the
+// test-specific default stream; the result itself is not inspected here.
+TEST(TextSubwordTest, Tokenize)
+{
+  uint32_t const nrows = 100;
+  std::vector<char const*> h_strings(nrows, "This is a test. A test this is.");
+  cudf::test::strings_column_wrapper strings(h_strings.cbegin(), h_strings.cend());
+  std::string const hash_file = temp_env->get_temp_filepath("hashed_vocab.txt");
+  create_hashed_vocab(hash_file);
+  auto vocab = nvtext::load_vocabulary_file(hash_file, cudf::test::get_default_stream());
+
+  uint32_t const max_sequence_length = 16;
+  uint32_t const stride              = 16;
+
+  auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
+                                         *vocab,
+                                         max_sequence_length,
+                                         stride,
+                                         true,   // do_lower_case
+                                         false,  // do_truncate
+                                         cudf::test::get_default_stream());
+}