Skip to content

Commit

Permalink
meta tags aggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 19, 2024
1 parent e4c6ff1 commit 12f8946
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions data_juicer/ops/aggregator/meta_tags_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ class MetaTagsAggregator(Aggregator):
'** 养生归类为健康 **\n'
'** 科学创新归类为科技 **\n')

DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n'
DEFAULT_INPUT_TEMPLATE = ('{target_tag_str}'
'| 合并前标签 | 频次 |\n'
'| ------ | ------ |\n'
'{tag_strs}')
DEFAULT_TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n'
DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |'

DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*'
Expand All @@ -86,6 +88,7 @@ def __init__(self,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
target_tag_template: Optional[str] = None,
tag_template: Optional[str] = None,
output_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
Expand All @@ -102,6 +105,7 @@ def __init__(self,
Defaults to 'choices.0.message.content'.
:param system_prompt: The system prompt.
:param input_template: The input template.
:param target_tag_template: The tag template for target tags.
:param tag_template: The tag template for each tag and its
frequency.
:param output_pattern: The output pattern.
Expand All @@ -115,13 +119,19 @@ def __init__(self,
super().__init__(**kwargs)

self.meta_tag_key = meta_tag_key
self.target_tags = target_tags

self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
target_tag_template = target_tag_template or \
self.DEFAULT_TARGET_TAG_TEMPLATE
self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

# Pre-render the optional target-tag constraint once at construction time.
# It stays an empty string when no target tags are supplied, so the
# '{target_tag_str}' slot of the input template is filled with nothing.
self.target_tag_str = ''
if target_tags:
    # The template is a plain string, so it has to be filled in with
    # str.format(); calling the string directly raises
    # "TypeError: 'str' object is not callable".
    self.target_tag_str = target_tag_template.format(
        target_tags=', '.join(target_tags))

self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
Expand All @@ -143,12 +153,13 @@ def parse_output(self, response):
def meta_map(self, meta_cnts, rank=None):

model, _ = get_model(self.model_key, rank, self.use_cuda())

tag_strs = [
self.tag_template.format(tag=k, cnt=meta_cnts[k])
for k in meta_cnts
]

input_prompt = self.input_template.format(tag_strs='\n'.join(tag_strs))
input_prompt = self.input_template.format(
target_tag_str=self.target_tag_str, tag_strs='\n'.join(tag_strs))

messages = [{
'role': 'system',
Expand Down

0 comments on commit 12f8946

Please sign in to comment.