Add sentence_augmentation_mapper #401

Draft · wants to merge 6 commits into main
8 changes: 8 additions & 0 deletions configs/config_all.yaml
@@ -179,6 +179,14 @@ process:
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove
- sentence_augmentation_mapper: # augment sentences using LLMs.
hf_model: 'Qwen/Qwen2-7B-Instruct' # model name of the LLM on huggingface
    system_prompt: null # system prompt given to the LLM
    task_sentence: null # the instruction for the current task
max_new_tokens: 256 # the maximum number of new tokens generated by the model
temperature: 0.2 # used to control the randomness of generated text
    top_p: null # randomly select the next word from the group of words whose cumulative probability reaches p
num_beams: 1 # the larger the beam search size, the higher the quality of the generated text
- sentence_split_mapper: # split text to multiple sentences and join them with '\n'
lang: 'en' # split text in what language
- video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model
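Not part of the diff, but for readers who want to try the new op quickly: below is a minimal sketch of driving the mapper directly in Python, mirroring the constructor arguments documented in the config entry above and the unit test at the bottom of this PR. The `task_sentence` is only an illustrative instruction; running it downloads the Qwen2 model and benefits from a CUDA device.

```python
from data_juicer.ops.mapper.sentence_augmentation_mapper import \
    SentenceAugmentationMapper

# Arguments mirror the config defaults above; task_sentence is illustrative.
op = SentenceAugmentationMapper(
    hf_model='Qwen/Qwen2-7B-Instruct',
    task_sentence='Please paraphrase this sentence and answer only with '
                  'the paraphrased sentence.',
    max_new_tokens=256,
    temperature=0.2,
    num_beams=1,
)

sample = {'text': 'a book is near a cat and a dog'}
result = op.process(sample)  # overwrites sample['text'] with the LLM output
print(result['text'])
```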
6 changes: 4 additions & 2 deletions data_juicer/ops/mapper/__init__.py
@@ -13,8 +13,8 @@
remove_repeat_sentences_mapper, remove_specific_chars_mapper,
remove_table_text_mapper,
remove_words_with_incorrect_substrings_mapper,
-                        replace_content_mapper, sentence_split_mapper,
-                        video_captioning_from_audio_mapper,
+                        replace_content_mapper, sentence_augmentation_mapper,
+                        sentence_split_mapper, video_captioning_from_audio_mapper,
video_captioning_from_frames_mapper,
video_captioning_from_summarizer_mapper,
video_captioning_from_video_mapper, video_face_blur_mapper,
@@ -57,6 +57,7 @@
from .remove_words_with_incorrect_substrings_mapper import \
RemoveWordsWithIncorrectSubstringsMapper
from .replace_content_mapper import ReplaceContentMapper
from .sentence_augmentation_mapper import SentenceAugmentationMapper
from .sentence_split_mapper import SentenceSplitMapper
from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
from .video_captioning_from_frames_mapper import \
@@ -123,6 +124,7 @@
'AudioFFmpegWrappedMapper',
'VideoSplitByDurationMapper',
'VideoFaceBlurMapper',
'SentenceAugmentationMapper'
]

# yapf: enable
117 changes: 117 additions & 0 deletions data_juicer/ops/mapper/sentence_augmentation_mapper.py
@@ -0,0 +1,117 @@
from typing import Optional

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

DEFAULT_SYSTEM_PROMPT = ("A chat between a curious user and an artificial "
                         "intelligence assistant. The assistant gives "
                         "helpful, detailed, and polite answers to the "
                         "user's questions.")

OP_NAME = 'sentence_augmentation_mapper'


@OPERATORS.register_module(OP_NAME)
class SentenceAugmentationMapper(Mapper):
"""Mapper to augment sentences.
The purpose of this operation is to enhance sentences.
If the input text is at the document level, the enhancement
effect may not be optimal. Therefore, please consider the
length of the input text carefully.

Recommended model list: [
lmsys/vicuna-13b-v1.5
Qwen/Qwen2-7B-Instruct
]
"""
_accelerator = 'cuda'

def __init__(self,
hf_model: str = 'Qwen/Qwen2-7B-Instruct',
                 system_prompt: Optional[str] = None,
                 task_sentence: Optional[str] = None,
max_new_tokens=256,
temperature=0.2,
top_p=None,
num_beams=1,
*args,
**kwargs):
"""
Initialization method.
        :param hf_model: Hugging Face model ID.
:param system_prompt: System prompt.
:param task_sentence: The instruction for the current task.
:param max_new_tokens: the maximum number of new tokens
generated by the model.
:param temperature: used to control the randomness of
generated text. The higher the temperature, the more
random and creative the generated text will be.
:param top_p: randomly select the next word from the group
of words whose cumulative probability reaches p.
:param num_beams: the larger the beam search size, the higher
the quality of the generated text.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, num_proc=1, **kwargs)

if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
self.system_prompt = system_prompt
self.hf_model = hf_model
self.max_new_tokens = max_new_tokens

self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model)
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
self.task_sentence = task_sentence

def process(self, sample=None, rank=None):

if self.task_sentence is None:
print('[Warning] task_sentence is None!')
sample[self.text_key] = ''
return sample

model, processor = get_model(model_key=self.model_key,
rank=rank,
use_cuda=self.use_cuda())

        if 'vicuna' in self.hf_model:
            input_prompt = (self.system_prompt + ' USER: Here is a sentence: ' +
                            '"' + sample[self.text_key] + '". ' +
                            self.task_sentence + ' ASSISTANT:')

else:
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role':
'user',
'content':
"Here is a sentence: \"" + sample[self.text_key] + "\". " +
self.task_sentence
}]
input_prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)

inputs = processor(input_prompt, return_tensors='pt').to(model.device)
response = model.generate(**inputs,
max_new_tokens=self.max_new_tokens,
eos_token_id=processor.eos_token_id,
top_p=self.top_p,
temperature=self.temperature,
num_beams=self.num_beams)
input_token_len = inputs.input_ids.shape[1]
n_diff_input_output = (inputs.input_ids !=
response[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the '
                  'same as the input_ids')
output = processor.batch_decode(response[:, input_token_len:],
skip_special_tokens=True)[0]
output = output.strip().strip("\"")

sample[self.text_key] = output

return sample
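A note on the decoding step above: `generate` returns the prompt ids followed by the completion, so the `response[:, input_token_len:]` slice keeps only the newly generated tokens before decoding. The same pattern is shown below as a standalone sketch in plain `transformers`; the model id and prompt are placeholders, not part of this PR.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'Qwen/Qwen2-7B-Instruct'  # placeholder; any chat model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [{'role': 'user',
             'content': 'Here is a sentence: "a book is near a cat". '
                        'Replace one entity and answer only with the sentence.'}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
response = model.generate(**inputs, max_new_tokens=64)

# generate() echoes the prompt ids first, so slice them off before decoding.
new_tokens = response[:, inputs.input_ids.shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip())
```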
3 changes: 2 additions & 1 deletion docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
-| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples |
+| [ Mapper ]( #mapper ) | 47 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
@@ -80,6 +80,7 @@ All the specific operators are listed below, each featured with several capabilities
| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile)|
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | Removes words containing specified substrings |
| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string |
| sentence_augmentation_mapper | General | en, zh | Augment sentences using LLMs (Large Language Models) |
| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics |
| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model |
| video_captioning_from_frames_mapper | Multimodal | - | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string |
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
@@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
-| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 |
+| [ Mapper ]( #mapper ) | 47 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
@@ -79,6 +79,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) |
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 |
| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 |
| sentence_augmentation_mapper | General | en, zh | 使用大语言模型来给句子做增强 |
| sentence_split_mapper | General | en | 根据语义拆分和重组句子 |
| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 |
| video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 |
34 changes: 34 additions & 0 deletions tests/ops/mapper/test_sentence_augmentation_mapper.py
@@ -0,0 +1,34 @@
import unittest
from data_juicer.ops.mapper.sentence_augmentation_mapper import SentenceAugmentationMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class SentenceAugmentationMapperTest(DataJuicerTestCaseBase):

text_key = 'text'

def _run_sentence_augmentation_mapper(self):
op = SentenceAugmentationMapper(
hf_model='Qwen/Qwen2-7B-Instruct',
task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence.",
max_new_tokens=512,
temperature=0.9,
top_p=0.95,
num_beams=1,
)

samples = [
{self.text_key: 'a book is near a cat and a dog'}
]

for sample in samples:
result = op.process(sample)
print(f'Output results: {result}')

def test_sentence_augmentation_mapper(self):
self._run_sentence_augmentation_mapper()



if __name__ == '__main__':
unittest.main()