From 3808fc481c528ff5377257235fd9a8db7a9c5b1a Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Thu, 22 Aug 2024 20:50:53 +0800 Subject: [PATCH 1/4] upload sentence_augmentation_mapper --- configs/config_all.yaml | 5 + data_juicer/ops/mapper/__init__.py | 6 +- .../mapper/sentence_augmentation_mapper.py | 119 ++++++++++++++++++ docs/Operators.md | 3 +- docs/Operators_ZH.md | 3 +- .../test_sentence_augmentation_mapper.py | 34 +++++ 6 files changed, 166 insertions(+), 4 deletions(-) create mode 100644 data_juicer/ops/mapper/sentence_augmentation_mapper.py create mode 100644 tests/ops/mapper/test_sentence_augmentation_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 855d45731..29b92f3bf 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -151,6 +151,11 @@ process: lang: en # sample in which language tokenization: false # whether to use model to tokenize documents substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove + - sentence_augmentation_mapper: # augment sentences using LLMs. + system_prompt: None # system prompt + task_sentence: None # the instruction for the current task + max_new_tokens: 256 # the maximum number of new tokens generated by the model + sampling_params: {} # sampling parameters for text generation - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5213498e9..e8b72eee9 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -12,8 +12,8 @@ remove_repeat_sentences_mapper, remove_specific_chars_mapper, remove_table_text_mapper, remove_words_with_incorrect_substrings_mapper, - replace_content_mapper, sentence_split_mapper, - video_captioning_from_audio_mapper, + replace_content_mapper, sentence_augmentation_mapper, + sentence_split_mapper, video_captioning_from_audio_mapper, video_captioning_from_frames_mapper, video_captioning_from_summarizer_mapper, video_captioning_from_video_mapper, video_face_blur_mapper, @@ -54,6 +54,7 @@ from .remove_words_with_incorrect_substrings_mapper import \ RemoveWordsWithIncorrectSubstringsMapper from .replace_content_mapper import ReplaceContentMapper +from .sentence_augmentation_mapper import SentenceAugmentationMapper from .sentence_split_mapper import SentenceSplitMapper from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper from .video_captioning_from_frames_mapper import \ @@ -118,6 +119,7 @@ 'AudioFFmpegWrappedMapper', 'VideoSplitByDurationMapper', 'VideoFaceBlurMapper', + 'SentenceAugmentationMapper' ] # yapf: enable diff --git a/data_juicer/ops/mapper/sentence_augmentation_mapper.py b/data_juicer/ops/mapper/sentence_augmentation_mapper.py new file mode 100644 index 000000000..e8246b4a1 --- /dev/null +++ b/data_juicer/ops/mapper/sentence_augmentation_mapper.py @@ -0,0 +1,119 @@ +from typing import Dict + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +DEFAULT_SYSTEM_PROMPT = "A chat between a curious user and an artificial \ + intelligence assistant. The assistant gives helpful, detailed, and \ + polite answers to the user's questions." + +OP_NAME = 'sentence_augmentation_mapper' + + +@OPERATORS.register_module(OP_NAME) +class SentenceAugmentationMapper(Mapper): + """Mapper to optimize instruction. + Recommended model list: [ + lmsys/vicuna-13b-v1.5 + Qwen/Qwen2-7B-Instruct + ] + """ + _accelerator = 'cuda' + + def __init__(self, + hf_model: str = 'Qwen/Qwen2-7B-Instruct', + system_prompt: str = None, + task_sentence: str = None, + max_new_tokens=256, + sampling_params: Dict = {}, + *args, + **kwargs): + """ + Initialization method. + :param hf_model: Hugginface model id. + :param system_prompt: System prompt. + :param task_sentence: The instruction for the current task. + :param max_new_tokens: the maximum number of new tokens + generated by the model. + :param sampling_params: Sampling parameters for text generation. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, num_proc=1, **kwargs) + + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + self.system_prompt = system_prompt + self.hf_model = hf_model + self.max_new_tokens = max_new_tokens + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model) + self.sampling_params = sampling_params + self.task_sentence = task_sentence + + def process(self, sample=None, rank=None): + + if 'vicuna' in self.hf_model: + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=True) + + input_prompt = self.system_prompt + " USER: Here \ + is a sentence: \"" + sample[ + self.text_key] + "\". " + self.task_sentence + + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + response = model.generate(**inputs, + max_new_tokens=self.max_new_tokens, + eos_token_id=processor.eos_token_id, + **self.sampling_params) + input_token_len = inputs.input_ids.shape[1] + n_diff_input_output = (inputs.input_ids != + response[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are \ + not the same as the input_ids') + output = processor.batch_decode(response[:, input_token_len:], + skip_special_tokens=True)[0] + output = output.strip().strip("\"") + + else: + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=True) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': + 'user', + 'content': + "Here is a sentence: \"" + sample[self.text_key] + "\". " + + self.task_sentence + }] + input_prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + + inputs = processor(input_prompt, + return_tensors='pt').to(model.device) + response = model.generate(**inputs, + max_new_tokens=self.max_new_tokens, + eos_token_id=processor.eos_token_id, + **self.sampling_params) + input_token_len = inputs.input_ids.shape[1] + n_diff_input_output = (inputs.input_ids != + response[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are not \ + the same as the input_ids') + output = processor.batch_decode(response[:, input_token_len:], + skip_special_tokens=True)[0] + output = output.strip().strip("\"") + + sample[self.text_key] = output + + return sample diff --git a/docs/Operators.md b/docs/Operators.md index a35210161..49e1801ca 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 44 | Edits and transforms samples | | [ Filter ]( #filter ) | 41 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -77,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili | remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile)| | remove_words_with_incorrect_
substrings_mapper | General | en, zh | Removes words containing specified substrings | | replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | +| sentence_augmentation_mapper | General | en, zh | Augment sentences using LLMs (Large Language Models) | | sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics | | video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model | | video_captioning_from_frames_mapper | Multimodal | - | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 855d109a7..4c404cf4b 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 44 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 41 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -76,6 +76,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | | remove_words_with_incorrect_
substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 | | replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | +| sentence_augmentation_mapper | General | en, zh | 使用大语言模型来给句子做增强 | | sentence_split_mapper | General | en | 根据语义拆分和重组句子 | | video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | | video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | diff --git a/tests/ops/mapper/test_sentence_augmentation_mapper.py b/tests/ops/mapper/test_sentence_augmentation_mapper.py new file mode 100644 index 000000000..be26d01c3 --- /dev/null +++ b/tests/ops/mapper/test_sentence_augmentation_mapper.py @@ -0,0 +1,34 @@ +import unittest +from data_juicer.ops.mapper.sentence_augmentation_mapper import SentenceAugmentationMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class SentenceAugmentationMapperTest(DataJuicerTestCaseBase): + + text_key = 'text' + + def _run_sentence_augmentation_mapper(self): + op = SentenceAugmentationMapper( + hf_model='Qwen2-7B-Instruct', + task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence. ASSISTANT:", + max_new_tokens=512, + sampling_params={'temperature': 0.9, 'top_p': 0.95} + ) + + samples = [ + {self.text_key: 'a book is near a cat and a dog'} + ] + + for sample in samples: + result = op.process(sample) + print(f'Output results: {result}') + self.assertIn(self.text_key, result) + + def test_sentence_augmentation_mapper(self): + self._run_sentence_augmentation_mapper() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From cc88377a861364410f4bc045b7ab6ef408049a29 Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Sat, 24 Aug 2024 11:12:31 +0800 Subject: [PATCH 2/4] update --- configs/config_all.yaml | 5 ++- .../mapper/sentence_augmentation_mapper.py | 45 +++++++++++++------ .../test_sentence_augmentation_mapper.py | 5 ++- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 29b92f3bf..62bad6437 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -152,10 +152,13 @@ process: tokenization: false # whether to use model to tokenize documents substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove - sentence_augmentation_mapper: # augment sentences using LLMs. + hf_model: 'Qwen/Qwen2-7B-Instruct' # model name of the LLM on huggingface system_prompt: None # system prompt task_sentence: None # the instruction for the current task max_new_tokens: 256 # the maximum number of new tokens generated by the model - sampling_params: {} # sampling parameters for text generation + temperature: 0.2 # used to control the randomness of generated text + top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p + num_beams: 1 # the larger the beam search size, the higher the quality of the generated text - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model diff --git a/data_juicer/ops/mapper/sentence_augmentation_mapper.py b/data_juicer/ops/mapper/sentence_augmentation_mapper.py index e8246b4a1..355b3b5d5 100644 --- a/data_juicer/ops/mapper/sentence_augmentation_mapper.py +++ b/data_juicer/ops/mapper/sentence_augmentation_mapper.py @@ -1,4 +1,4 @@ -from typing import Dict +import torch from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.model_utils import get_model, prepare_model @@ -25,7 +25,9 @@ def __init__(self, system_prompt: str = None, task_sentence: str = None, max_new_tokens=256, - sampling_params: Dict = {}, + temperature=0.2, + top_p=None, + num_beams=1, *args, **kwargs): """ @@ -35,8 +37,13 @@ def __init__(self, :param task_sentence: The instruction for the current task. :param max_new_tokens: the maximum number of new tokens generated by the model. - :param sampling_params: Sampling parameters for text generation. - e.g {'temperature': 0.9, 'top_p': 0.95} + :param temperature: used to control the randomness of \ + generated text. The higher the temperature, the more \ + random and creative the generated text will be. + :param top_p: randomly select the next word from the group \ + of words whose cumulative probability reaches p. + :param num_beams: the larger the beam search size, the higher \ + the quality of the generated text. :param args: extra args :param kwargs: extra args """ @@ -50,15 +57,20 @@ def __init__(self, self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_model) - self.sampling_params = sampling_params + self.temperature = temperature + self.top_p = top_p + self.num_beams = num_beams self.task_sentence = task_sentence def process(self, sample=None, rank=None): if 'vicuna' in self.hf_model: - model, processor = get_model(self.model_key, - rank=rank, - use_cuda=True) + if torch.cuda.is_available(): + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=True) + else: + model, processor = get_model(self.model_key, rank=rank) input_prompt = self.system_prompt + " USER: Here \ is a sentence: \"" + sample[ @@ -69,7 +81,9 @@ def process(self, sample=None, rank=None): response = model.generate(**inputs, max_new_tokens=self.max_new_tokens, eos_token_id=processor.eos_token_id, - **self.sampling_params) + top_p=self.top_p, + temperature=self.temperature, + num_beams=self.num_beams) input_token_len = inputs.input_ids.shape[1] n_diff_input_output = (inputs.input_ids != response[:, :input_token_len]).sum().item() @@ -81,9 +95,12 @@ def process(self, sample=None, rank=None): output = output.strip().strip("\"") else: - model, processor = get_model(self.model_key, - rank=rank, - use_cuda=True) + if torch.cuda.is_available(): + model, processor = get_model(self.model_key, + rank=rank, + use_cuda=True) + else: + model, processor = get_model(self.model_key, rank=rank) messages = [{ 'role': 'system', @@ -103,7 +120,9 @@ def process(self, sample=None, rank=None): response = model.generate(**inputs, max_new_tokens=self.max_new_tokens, eos_token_id=processor.eos_token_id, - **self.sampling_params) + top_p=self.top_p, + temperature=self.temperature, + num_beams=self.num_beams) input_token_len = inputs.input_ids.shape[1] n_diff_input_output = (inputs.input_ids != response[:, :input_token_len]).sum().item() diff --git a/tests/ops/mapper/test_sentence_augmentation_mapper.py b/tests/ops/mapper/test_sentence_augmentation_mapper.py index be26d01c3..ac1e6a9c1 100644 --- a/tests/ops/mapper/test_sentence_augmentation_mapper.py +++ b/tests/ops/mapper/test_sentence_augmentation_mapper.py @@ -14,7 +14,9 @@ def _run_sentence_augmentation_mapper(self): hf_model='Qwen2-7B-Instruct', task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence. ASSISTANT:", max_new_tokens=512, - sampling_params={'temperature': 0.9, 'top_p': 0.95} + temperature=0.9, + top_p=0.95, + num_beams=1, ) samples = [ @@ -30,5 +32,6 @@ def test_sentence_augmentation_mapper(self): self._run_sentence_augmentation_mapper() + if __name__ == '__main__': unittest.main() \ No newline at end of file From 255ceb7cc463f953faa47bb5cc66b24c001648fb Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Tue, 27 Aug 2024 11:29:47 +0800 Subject: [PATCH 3/4] update --- .../mapper/sentence_augmentation_mapper.py | 95 ++++++++----------- .../test_sentence_augmentation_mapper.py | 5 +- 2 files changed, 39 insertions(+), 61 deletions(-) diff --git a/data_juicer/ops/mapper/sentence_augmentation_mapper.py b/data_juicer/ops/mapper/sentence_augmentation_mapper.py index 355b3b5d5..e75e1ac6e 100644 --- a/data_juicer/ops/mapper/sentence_augmentation_mapper.py +++ b/data_juicer/ops/mapper/sentence_augmentation_mapper.py @@ -1,5 +1,3 @@ -import torch - from data_juicer.ops.base_op import OPERATORS, Mapper from data_juicer.utils.model_utils import get_model, prepare_model @@ -12,7 +10,12 @@ @OPERATORS.register_module(OP_NAME) class SentenceAugmentationMapper(Mapper): - """Mapper to optimize instruction. + """Mapper to augment sentences. + The purpose of this operation is to enhance sentences. + If the input text is at the document level, the enhancement + effect may not be optimal. Therefore, please consider the + length of the input text carefully. + Recommended model list: [ lmsys/vicuna-13b-v1.5 Qwen/Qwen2-7B-Instruct @@ -37,12 +40,12 @@ def __init__(self, :param task_sentence: The instruction for the current task. :param max_new_tokens: the maximum number of new tokens generated by the model. - :param temperature: used to control the randomness of \ - generated text. The higher the temperature, the more \ - random and creative the generated text will be. - :param top_p: randomly select the next word from the group \ + :param temperature: used to control the randomness of + generated text. The higher the temperature, the more + random and creative the generated text will be. + :param top_p: randomly select the next word from the group of words whose cumulative probability reaches p. - :param num_beams: the larger the beam search size, the higher \ + :param num_beams: the larger the beam search size, the higher the quality of the generated text. :param args: extra args :param kwargs: extra args @@ -64,44 +67,21 @@ def __init__(self, def process(self, sample=None, rank=None): - if 'vicuna' in self.hf_model: - if torch.cuda.is_available(): - model, processor = get_model(self.model_key, - rank=rank, - use_cuda=True) - else: - model, processor = get_model(self.model_key, rank=rank) + if self.task_sentence is None: + print('[Warning] task_sentence is None!') + sample[self.text_key] = '' + return sample + model, processor = get_model(model_key=self.model_key, + rank=rank, + use_cuda=self.use_cuda()) + + if 'vicuna' in self.hf_model: input_prompt = self.system_prompt + " USER: Here \ is a sentence: \"" + sample[ - self.text_key] + "\". " + self.task_sentence - - inputs = processor(input_prompt, - return_tensors='pt').to(model.device) - response = model.generate(**inputs, - max_new_tokens=self.max_new_tokens, - eos_token_id=processor.eos_token_id, - top_p=self.top_p, - temperature=self.temperature, - num_beams=self.num_beams) - input_token_len = inputs.input_ids.shape[1] - n_diff_input_output = (inputs.input_ids != - response[:, :input_token_len]).sum().item() - if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are \ - not the same as the input_ids') - output = processor.batch_decode(response[:, input_token_len:], - skip_special_tokens=True)[0] - output = output.strip().strip("\"") + self.text_key] + "\". " + self.task_sentence + ' ASSISTANT:' else: - if torch.cuda.is_available(): - model, processor = get_model(self.model_key, - rank=rank, - use_cuda=True) - else: - model, processor = get_model(self.model_key, rank=rank) - messages = [{ 'role': 'system', 'content': self.system_prompt @@ -115,23 +95,22 @@ def process(self, sample=None, rank=None): input_prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True) - inputs = processor(input_prompt, - return_tensors='pt').to(model.device) - response = model.generate(**inputs, - max_new_tokens=self.max_new_tokens, - eos_token_id=processor.eos_token_id, - top_p=self.top_p, - temperature=self.temperature, - num_beams=self.num_beams) - input_token_len = inputs.input_ids.shape[1] - n_diff_input_output = (inputs.input_ids != - response[:, :input_token_len]).sum().item() - if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not \ - the same as the input_ids') - output = processor.batch_decode(response[:, input_token_len:], - skip_special_tokens=True)[0] - output = output.strip().strip("\"") + inputs = processor(input_prompt, return_tensors='pt').to(model.device) + response = model.generate(**inputs, + max_new_tokens=self.max_new_tokens, + eos_token_id=processor.eos_token_id, + top_p=self.top_p, + temperature=self.temperature, + num_beams=self.num_beams) + input_token_len = inputs.input_ids.shape[1] + n_diff_input_output = (inputs.input_ids != + response[:, :input_token_len]).sum().item() + if n_diff_input_output > 0: + print(f'[Warning] {n_diff_input_output} output_ids are \ + not the same as the input_ids') + output = processor.batch_decode(response[:, input_token_len:], + skip_special_tokens=True)[0] + output = output.strip().strip("\"") sample[self.text_key] = output diff --git a/tests/ops/mapper/test_sentence_augmentation_mapper.py b/tests/ops/mapper/test_sentence_augmentation_mapper.py index ac1e6a9c1..f795e3317 100644 --- a/tests/ops/mapper/test_sentence_augmentation_mapper.py +++ b/tests/ops/mapper/test_sentence_augmentation_mapper.py @@ -11,8 +11,8 @@ class SentenceAugmentationMapperTest(DataJuicerTestCaseBase): def _run_sentence_augmentation_mapper(self): op = SentenceAugmentationMapper( - hf_model='Qwen2-7B-Instruct', - task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence. ASSISTANT:", + hf_model='Qwen/Qwen2-7B-Instruct', + task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence.", max_new_tokens=512, temperature=0.9, top_p=0.95, @@ -26,7 +26,6 @@ def _run_sentence_augmentation_mapper(self): for sample in samples: result = op.process(sample) print(f'Output results: {result}') - self.assertIn(self.text_key, result) def test_sentence_augmentation_mapper(self): self._run_sentence_augmentation_mapper() From 5496ad84540dd9a5c32c423116329f9ac323010c Mon Sep 17 00:00:00 2001 From: Qirui-jiao <156628817+Qirui-jiao@users.noreply.github.com> Date: Fri, 30 Aug 2024 22:53:37 +0800 Subject: [PATCH 4/4] Update test_sentence_augmentation_mapper.py --- tests/ops/mapper/test_sentence_augmentation_mapper.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/ops/mapper/test_sentence_augmentation_mapper.py b/tests/ops/mapper/test_sentence_augmentation_mapper.py index f795e3317..5cba22b93 100644 --- a/tests/ops/mapper/test_sentence_augmentation_mapper.py +++ b/tests/ops/mapper/test_sentence_augmentation_mapper.py @@ -1,10 +1,8 @@ import unittest from data_juicer.ops.mapper.sentence_augmentation_mapper import SentenceAugmentationMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + -# These tests have been tested locally. -@SKIPPED_TESTS.register_module() class SentenceAugmentationMapperTest(DataJuicerTestCaseBase): text_key = 'text' @@ -33,4 +31,4 @@ def test_sentence_augmentation_mapper(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()