From 12f8946b38aff7f606284dd1c554e3cf35dc8408 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 19 Dec 2024 16:47:45 +0800 Subject: [PATCH] meta tags aggregator --- .../ops/aggregator/meta_tags_aggregator.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py index 12518e255..ec9f5201d 100644 --- a/data_juicer/ops/aggregator/meta_tags_aggregator.py +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -70,9 +70,11 @@ class MetaTagsAggregator(Aggregator): '** 养生归类为健康 **\n' '** 科学创新归类为科技 **\n') - DEFAULT_INPUT_TEMPLATE = ('| 合并前标签 | 频次 |\n' + DEFAULT_INPUT_TEMPLATE = ('{target_tag_str}' + '| 合并前标签 | 频次 |\n' '| ------ | ------ |\n' '{tag_strs}') + DEFAULT_TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n' DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*' @@ -86,6 +88,7 @@ def __init__(self, response_path: Optional[str] = None, system_prompt: Optional[str] = None, input_template: Optional[str] = None, + target_tag_template: Optional[str] = None, tag_template: Optional[str] = None, output_pattern: Optional[str] = None, try_num: PositiveInt = 3, @@ -102,6 +105,7 @@ def __init__(self, Defaults to 'choices.0.message.content'. :param system_prompt: The system prompt. :param input_template: The input template. + :param target_tag_template: The tag template for target tags. :param tag_template: The tap template for each tag and its frequency. :param output_pattern: The output pattern. 
@@ -115,13 +119,19 @@ def __init__(self, super().__init__(**kwargs) self.meta_tag_key = meta_tag_key - self.target_tags = target_tags self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + target_tag_template = target_tag_template or \ + self.DEFAULT_TARGET_TAG_TEMPLATE self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + self.target_tag_str = '' + if target_tags: + self.target_tag_str = target_tag_template.format( + target_tags=', '.join(target_tags)) + self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', model=api_model, @@ -143,12 +153,13 @@ def parse_output(self, response): def meta_map(self, meta_cnts, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) + tag_strs = [ self.tag_template.format(tag=k, cnt=meta_cnts[k]) for k in meta_cnts ] - - input_prompt = self.input_template.format(tag_strs='\n'.join(tag_strs)) + input_prompt = self.input_template.format( + target_tag_str=self.target_tag_str, tag_strs='\n'.join(tag_strs)) messages = [{ 'role': 'system',