Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sentence_augmentation_mapper #401

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ process:
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove
- sentence_augmentation_mapper: # augment sentences using LLMs.
hf_model: 'Qwen/Qwen2-7B-Instruct' # model name of the LLM on huggingface
system_prompt: None # system prompt
task_sentence: None # the instruction for the current task
max_new_tokens: 256 # the maximum number of new tokens generated by the model
temperature: 0.2 # used to control the randomness of generated text
top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p
num_beams: 1 # the larger the beam search size, the higher the quality of the generated text
- sentence_split_mapper: # split text to multiple sentences and join them with '\n'
lang: 'en' # split text in what language
- video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model
Expand Down
6 changes: 4 additions & 2 deletions data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
remove_repeat_sentences_mapper, remove_specific_chars_mapper,
remove_table_text_mapper,
remove_words_with_incorrect_substrings_mapper,
replace_content_mapper, sentence_split_mapper,
video_captioning_from_audio_mapper,
replace_content_mapper, sentence_augmentation_mapper,
sentence_split_mapper, video_captioning_from_audio_mapper,
video_captioning_from_frames_mapper,
video_captioning_from_summarizer_mapper,
video_captioning_from_video_mapper, video_face_blur_mapper,
Expand Down Expand Up @@ -54,6 +54,7 @@
from .remove_words_with_incorrect_substrings_mapper import \
RemoveWordsWithIncorrectSubstringsMapper
from .replace_content_mapper import ReplaceContentMapper
from .sentence_augmentation_mapper import SentenceAugmentationMapper
from .sentence_split_mapper import SentenceSplitMapper
from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
from .video_captioning_from_frames_mapper import \
Expand Down Expand Up @@ -118,6 +119,7 @@
'AudioFFmpegWrappedMapper',
'VideoSplitByDurationMapper',
'VideoFaceBlurMapper',
'SentenceAugmentationMapper'
]

# yapf: enable
138 changes: 138 additions & 0 deletions data_juicer/ops/mapper/sentence_augmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import torch

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

DEFAULT_SYSTEM_PROMPT = "A chat between a curious user and an artificial \
intelligence assistant. The assistant gives helpful, detailed, and \
polite answers to the user's questions."

OP_NAME = 'sentence_augmentation_mapper'


@OPERATORS.register_module(OP_NAME)
class SentenceAugmentationMapper(Mapper):
"""Mapper to optimize instruction.
Recommended model list: [
lmsys/vicuna-13b-v1.5
Qwen/Qwen2-7B-Instruct
]
"""
_accelerator = 'cuda'

def __init__(self,
hf_model: str = 'Qwen/Qwen2-7B-Instruct',
system_prompt: str = None,
task_sentence: str = None,
max_new_tokens=256,
temperature=0.2,
top_p=None,
num_beams=1,
*args,
**kwargs):
"""
Initialization method.
:param hf_model: Hugginface model id.
:param system_prompt: System prompt.
:param task_sentence: The instruction for the current task.
:param max_new_tokens: the maximum number of new tokens
generated by the model.
:param temperature: used to control the randomness of \
generated text. The higher the temperature, the more \
random and creative the generated text will be.
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved
:param top_p: randomly select the next word from the group \
of words whose cumulative probability reaches p.
:param num_beams: the larger the beam search size, the higher \
the quality of the generated text.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, num_proc=1, **kwargs)

if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
self.system_prompt = system_prompt
self.hf_model = hf_model
self.max_new_tokens = max_new_tokens

self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model)
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
self.task_sentence = task_sentence

def process(self, sample=None, rank=None):

if 'vicuna' in self.hf_model:
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved
if torch.cuda.is_available():
model, processor = get_model(self.model_key,
rank=rank,
use_cuda=True)
else:
model, processor = get_model(self.model_key, rank=rank)

Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved
input_prompt = self.system_prompt + " USER: Here \
is a sentence: \"" + sample[
self.text_key] + "\". " + self.task_sentence
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved

inputs = processor(input_prompt,
return_tensors='pt').to(model.device)
response = model.generate(**inputs,
max_new_tokens=self.max_new_tokens,
eos_token_id=processor.eos_token_id,
top_p=self.top_p,
temperature=self.temperature,
num_beams=self.num_beams)
input_token_len = inputs.input_ids.shape[1]
n_diff_input_output = (inputs.input_ids !=
response[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are \
not the same as the input_ids')
output = processor.batch_decode(response[:, input_token_len:],
skip_special_tokens=True)[0]
output = output.strip().strip("\"")

else:
if torch.cuda.is_available():
model, processor = get_model(self.model_key,
rank=rank,
use_cuda=True)
else:
model, processor = get_model(self.model_key, rank=rank)

messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role':
'user',
'content':
"Here is a sentence: \"" + sample[self.text_key] + "\". " +
self.task_sentence
}]
input_prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)

inputs = processor(input_prompt,
return_tensors='pt').to(model.device)
response = model.generate(**inputs,
max_new_tokens=self.max_new_tokens,
eos_token_id=processor.eos_token_id,
top_p=self.top_p,
temperature=self.temperature,
num_beams=self.num_beams)
input_token_len = inputs.input_ids.shape[1]
n_diff_input_output = (inputs.input_ids !=
response[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved
print(f'[Warning] {n_diff_input_output} output_ids are not \
the same as the input_ids')
output = processor.batch_decode(response[:, input_token_len:],
skip_special_tokens=True)[0]
output = output.strip().strip("\"")

sample[self.text_key] = output

return sample
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples |
| [ Mapper ]( #mapper ) | 44 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
Expand Down Expand Up @@ -77,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili
| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile)|
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | Removes words containing specified substrings |
| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string |
| sentence_augmentation_mapper | General | en, zh | Augment sentences using LLMs (Large Language Models) |
| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics |
| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model |
| video_captioning_from_frames_mapper | Multimodal | - | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 |
| [ Mapper ]( #mapper ) | 44 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
Expand Down Expand Up @@ -76,6 +76,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) |
| remove_words_with_incorrect_<br />substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 |
| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 |
| sentence_augmentation_mapper | General | en, zh | 使用大语言模型来给句子做增强 |
| sentence_split_mapper | General | en | 根据语义拆分和重组句子 |
| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 |
| video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 |
Expand Down
37 changes: 37 additions & 0 deletions tests/ops/mapper/test_sentence_augmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest
from data_juicer.ops.mapper.sentence_augmentation_mapper import SentenceAugmentationMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)

# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class SentenceAugmentationMapperTest(DataJuicerTestCaseBase):

text_key = 'text'

def _run_sentence_augmentation_mapper(self):
op = SentenceAugmentationMapper(
hf_model='Qwen2-7B-Instruct',
task_sentence="Please replace one entity in this sentence with another entity, such as an animal, a vehicle, or a piece of furniture. Please only answer with the replaced sentence. ASSISTANT:",
max_new_tokens=512,
temperature=0.9,
top_p=0.95,
num_beams=1,
)

samples = [
{self.text_key: 'a book is near a cat and a dog'}
]

for sample in samples:
result = op.process(sample)
print(f'Output results: {result}')
self.assertIn(self.text_key, result)
Qirui-jiao marked this conversation as resolved.
Show resolved Hide resolved

def test_sentence_augmentation_mapper(self):
self._run_sentence_augmentation_mapper()



if __name__ == '__main__':
unittest.main()
Loading