diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 1003c89af..9811b0e97 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -79,9 +79,9 @@ process:
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
+ api_model: 'gpt-4o' # API model name.
query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried.
query_attributes: ["人物性格"] # Attribute list to be queried.
- api_model: 'gpt-4o' # API model name.
entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction.
entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted.
attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description.
@@ -153,6 +153,18 @@ process:
drop_text: false # If drop the text in the output.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - extract_support_text_mapper: # extract the supporting sub-text for a given summary.
+ api_model: 'gpt-4o' # API model name.
+ summary_key: '__dj__event_description__' # The field name to store the input summary. Support for nested keys such as "__dj__stats__.text_len".
+ support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ input_template: null # Template for building the model input.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ drop_text: false # Whether to drop the text in the output.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
- fix_unicode_mapper: # fix unicode errors in text.
- generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples.
hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs.
@@ -259,12 +271,27 @@ process:
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- - python_python_mapper: # executing Python lambda function defined in a file.
+ - python_file_mapper: # executing Python lambda function defined in a file.
file_path: '' # The path to the Python file containing the function to be executed.
function_name: 'process_single' # The name of the function defined in the file to be executed.
- python_lambda_mapper: # executing Python lambda function on data samples.
lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
batched: False # A boolean indicating whether to process input data in batches.
+ - relation_identity_mapper: # identify the relation between two entities in the text.
+ api_model: 'gpt-4o' # API model name.
+ source_entity: '孙悟空' # The source entity of the relation to be identified.
+ target_entity: '猪八戒' # The target entity of the relation to be identified.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key by default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is input_key by default.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt_template: null # System prompt template for the task. It should contain placeholders for entity1 and entity2.
+ input_template: null # Template for building the model input.
+ output_pattern_template: null # Regular expression template for parsing model output.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ drop_text: false # Whether to drop the text in the output.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
doc_type: tex # comment type you want to remove. Only support 'tex' for now.
@@ -693,3 +720,55 @@ process:
top_ratio: # ratio of selected top samples
topk: # number of selected top sample
reverse: True # determine the sorting rule, if reverse=True, then sort in descending order
+
+# Grouper ops.
+ - naive_grouper: # Group all samples into one batched sample.
+ - key_value_grouper: # Group samples into batched samples according to the values of given keys.
+ group_by_keys: null # Group samples according to the values of the given keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] by default.
+
+# Aggregator ops.
+ - entity_attribute_aggregator: # Summarize a given attribute of a given entity from a set of documents.
+ api_model: 'gpt-4o' # API model name.
+ entity: '孙悟空' # The given entity.
+ attribute: '人物经历' # The given attribute.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key by default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is the same as input_key by default.
+ word_limit: 100 # The expected output length (in words) stated in the prompt.
+ max_token_num: null # The maximum total token number of the sub-documents. No limit if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt_template: null # System prompt template for the task. It should contain placeholders for the given entity and attribute.
+ example_prompt: null # The example part in the system prompt.
+ input_template: null # The input template.
+ output_pattern_template: null # The output pattern template.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
+ - most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance.
+ api_model: 'gpt-4o' # API model name.
+ entity: '孙悟空' # The given entity.
+ query_entity_type: '人物' # The type of the queried relevant entities.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key by default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is the same as input_key by default.
+ max_token_num: null # The maximum total token number of the sub-documents. No limit if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt_template: null # System prompt template for the task. It should contain placeholders for the given entity and entity_type.
+ input_template: null # The input template.
+ output_pattern: null # The output pattern.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
+ - nested_aggregator: # Recursively aggregate contents in chunks to respect the input length limitation.
+ api_model: 'gpt-4o' # API model name.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key by default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is the same as input_key by default.
+ max_token_num: null # The maximum total token number of the sub-documents. No limit if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # The system prompt.
+ sub_doc_template: null # The template for input text in each sample.
+ input_template: null # The input template.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g. {'temperature': 0.9, 'top_p': 0.95}
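For orientation, a minimal Python sketch of how the new grouper and aggregator ops are meant to chain. The op classes and defaults come from this PR; an OpenAI-compatible API key being configured and the output field name are assumptions.

```python
# Hedged sketch: group docs about one entity, then summarize an attribute.
# Assumes an OpenAI-compatible endpoint/API key is configured for 'gpt-4o';
# '__dj__attribute_summary__' is a hypothetical output field name.
from data_juicer.ops.aggregator import EntityAttributeAggregator
from data_juicer.ops.grouper import NaiveGrouper

docs = [
    {'text': '孙悟空拜菩提祖师为师,习得七十二变。'},
    {'text': '孙悟空大闹天宫,后被压在五行山下五百年。'},
]

# NaiveGrouper.process turns the sample list into one batched sample:
# {'text': [doc1, doc2]}
batched = NaiveGrouper().process(docs)[0]

# The aggregator summarizes the attribute across the grouped documents.
agg = EntityAttributeAggregator(entity='孙悟空',
                                attribute='人物经历',
                                output_key='__dj__attribute_summary__',
                                word_limit=100)
result = agg.process_single(batched)
print(result['__dj__attribute_summary__'])
```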
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index 71f871f10..6cf30b5a9 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -565,8 +565,13 @@ def sort_op_by_types_and_names(op_name_classes):
if 'deduplicator' in name]
selector_ops = [(name, c) for (name, c) in op_name_classes
if 'selector' in name]
+ grouper_ops = [(name, c) for (name, c) in op_name_classes
+ if 'grouper' in name]
+ aggregator_ops = [(name, c) for (name, c) in op_name_classes
+ if 'aggregator' in name]
ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted(
- deduplicator_ops) + sorted(selector_ops)
+ deduplicator_ops) + sorted(selector_ops) + sorted(grouper_ops) + \
+ sorted(aggregator_ops)
return ops_sorted_by_types
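A small sketch of the extended ordering; the op names below are real registry names, but the classes are stubbed with None since only the names drive the type buckets.

```python
# Sketch of the extended sort order: mappers, filters, deduplicators,
# selectors, then the new groupers and aggregators.
from data_juicer.config.config import sort_op_by_types_and_names

ops = [('naive_grouper', None), ('nested_aggregator', None),
       ('language_id_score_filter', None), ('clean_html_mapper', None)]
print(sort_op_by_types_and_names(ops))
# [('clean_html_mapper', None), ('language_id_score_filter', None),
#  ('naive_grouper', None), ('nested_aggregator', None)]
```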
diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py
index c7ab44c25..e02e10efa 100644
--- a/data_juicer/ops/__init__.py
+++ b/data_juicer/ops/__init__.py
@@ -1,6 +1,6 @@
-from . import deduplicator, filter, mapper, selector
-from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Mapper,
- Selector)
+from . import aggregator, deduplicator, filter, grouper, mapper, selector
+from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter,
+ Grouper, Mapper, Selector)
from .load import load_ops
__all__ = [
@@ -9,4 +9,6 @@
'Mapper',
'Deduplicator',
'Selector',
+ 'Grouper',
+ 'Aggregator',
]
diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py
new file mode 100644
index 000000000..4afe2974a
--- /dev/null
+++ b/data_juicer/ops/aggregator/__init__.py
@@ -0,0 +1,8 @@
+from .entity_attribute_aggregator import EntityAttributeAggregator
+from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
+from .nested_aggregator import NestedAggregator
+
+__all__ = [
+ 'NestedAggregator', 'EntityAttributeAggregator',
+ 'MostRelavantEntitiesAggregator'
+]
diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
new file mode 100644
index 000000000..96fbbb63f
--- /dev/null
+++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -0,0 +1,200 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
+ is_string_list, nested_access,
+ nested_set)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from .nested_aggregator import NestedAggregator
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'entity_attribute_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class EntityAttributeAggregator(Aggregator):
+ """
+ Summarize a given attribute of a given entity from a set of documents.
+ """
+
+ DEFAULT_SYSTEM_TEMPLATE = (
+ '给定与`{entity}`相关的一些文档,总结`{entity}`的`{attribute}`。\n'
+ '要求:\n'
+ '- 尽量使用原文专有名词\n'
+ '- 联系上下文,自动忽略上下文不一致的细节错误\n'
+ '- 只对文档中与`{entity}`的`{attribute}`有关的内容进行总结\n'
+ '- 字数限制在**{word_limit}字以内**\n'
+ '- 要求输出格式如下:\n'
+ '# {entity}\n'
+ '## {attribute}\n'
+ '...\n'
+ '{example}')
+
+ DEFAULT_EXAMPLE_PROMPT = ('- 例如,根据相关文档总结`孙悟空`的`出身背景`,**100字**以内的样例如下:\n'
+ '`孙悟空`的`出身背景`总结:\n'
+ '# 孙悟空\n'
+ '## 出身背景\n'
+ '号称齐天大圣,花果山水帘洞的美猴王、西行取经队伍中的大师兄。'
+ '师父是唐僧玄奘,曾拜菩提祖师学艺。'
+ '亲生父母未知,自石头中孕育而生。自认斗战胜佛,最怕观世音菩萨和紧箍咒。\n')
+
+ DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n'
+ '{sub_docs}\n\n'
+ '`{entity}`的`{attribute}`总结:\n')
+
+ DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ entity: str = None,
+ attribute: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ word_limit: PositiveInt = 100,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ example_prompt: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param entity: The given entity.
+ :param attribute: The given attribute.
+ :param input_key: The input field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is text_key
+ by default.
+ :param output_key: The output field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is the same as
+ input_key by default.
+ :param word_limit: The expected output length (in words) stated in
+ the prompt.
+ :param max_token_num: The maximum total token number of the
+ sub-documents. No limit if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: The system prompt template.
+ :param example_prompt: The example part in the system prompt.
+ :param input_template: The input template.
+ :param output_pattern_template: The output pattern template.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g. {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if entity is None or attribute is None:
+ raise ValueError('The entity and attribute cannot be None!')
+
+ self.entity = entity
+ self.attribute = attribute
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.word_limit = word_limit
+ self.max_token_num = max_token_num
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_TEMPLATE
+ self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ output_pattern_template = output_pattern_template or \
+ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+ self.system_prompt = system_prompt_template.format(
+ entity=self.entity,
+ attribute=self.attribute,
+ word_limit=self.word_limit,
+ example=self.example_prompt)
+ self.output_pattern = output_pattern_template.format(
+ entity=entity, attribute=attribute)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+ self.nested_sum = NestedAggregator(api_model=api_model,
+ max_token_num=max_token_num,
+ api_endpoint=api_endpoint,
+ response_path=response_path,
+ try_num=try_num,
+ model_params=model_params,
+ sampling_params=sampling_params)
+
+ def parse_output(self, response):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(response)
+ if matches:
+ result = matches[0].strip()
+ else:
+ result = ''
+
+ return result
+
+ def attribute_summary(self, sub_docs, rank=None):
+ if not sub_docs:
+ return ''
+
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
+ self.max_token_num)
+ results = []
+ for docs in group_docs:
+ doc_str = '\n\n'.join(docs)
+ input_prompt = self.input_template.format(entity=self.entity,
+ attribute=self.attribute,
+ sub_docs=doc_str)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = ''
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+ results.append(result)
+
+ return self.nested_sum.recursive_summary(results)
+
+ def process_single(self, sample=None, rank=None):
+
+ # skip if the input is not a batched sample (a list of strings)
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample = nested_set(sample, self.output_key,
+ self.attribute_summary(sub_docs, rank=rank))
+
+ return sample
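The default output pattern expects the model to echo the markdown-style headers before the summary; a standalone sketch of the parsing step (plain `re`, no API call):

```python
# Sketch of what parse_output matches, using the default pattern rendered
# for entity='孙悟空' and attribute='人物经历'.
import re

pattern = r'\#\s*孙悟空\s*\#\#\s*人物经历\s*(.*?)\Z'
reply = '# 孙悟空\n## 人物经历\n自石中孕育而生,拜师学艺,后大闹天宫。'
matches = re.compile(pattern, re.VERBOSE | re.DOTALL).findall(reply)
print(matches[0].strip())  # '自石中孕育而生,拜师学艺,后大闹天宫。'
```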
diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
new file mode 100644
index 000000000..69e1a209c
--- /dev/null
+++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
@@ -0,0 +1,183 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (is_string_list, nested_access,
+ nested_set)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..common import split_text_by_punctuation
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'most_relavant_entities_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class MostRelavantEntitiesAggregator(Aggregator):
+ """
+ Extract entities closely related to a given entity from some texts,
+ and sort them in descending order of importance.
+ """
+
+ DEFAULT_SYSTEM_TEMPLATE = (
+ '给定与`{entity}`相关的一些文档,'
+ '总结一些与`{entity}`最为相关的`{entity_type}`。\n'
+ '要求:\n'
+ '- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n'
+ '- 请按照人物的重要性进行排序,**越重要人物在列表越前面**。\n'
+ '- 你的返回格式如下:\n'
+ '## 分析\n'
+ '你对各个{entity_type}与{entity}关联度的分析\n'
+ '## 列表\n'
+ '人物1, 人物2, 人物3, ...')
+
+ DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n'
+ '{sub_docs}\n\n'
+ '与`{entity}`最相关的一些`{entity_type}`:\n')
+
+ DEFAULT_OUTPUT_PATTERN = r'\#\#\s*列表\s*(.*?)\Z'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ entity: str = None,
+ query_entity_type: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param entity: The given entity.
+ :param query_entity_type: The type of the queried relevant entities.
+ :param input_key: The input field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is text_key
+ by default.
+ :param output_key: The output field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is the same as
+ input_key by default.
+ :param max_token_num: The maximum total token number of the
+ sub-documents. No limit if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: The system prompt template.
+ :param input_template: The input template.
+ :param output_pattern: The output pattern.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g. {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if entity is None or query_entity_type is None:
+ raise ValueError(
+ 'The entity and query_entity_type cannot be None!')
+
+ self.entity = entity
+ self.query_entity_type = query_entity_type
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.max_token_num = max_token_num
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_TEMPLATE
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+ self.system_prompt = system_prompt_template.format(
+ entity=entity, entity_type=query_entity_type)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+
+ def parse_output(self, response):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(response)
+ if matches:
+ result = matches[0].strip()
+ else:
+ result = ''
+ result = split_text_by_punctuation(result)
+
+ return result
+
+ def query_most_relavant_entities(self, sub_docs, rank=None):
+ if not sub_docs:
+ return []
+
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ if self.max_token_num is None:
+ final_docs = sub_docs
+ else:
+ final_docs = []
+ total_num = 0
+ for token_num, doc in zip(token_nums, sub_docs):
+ total_num += token_num
+ if total_num > self.max_token_num:
+ break
+ final_docs.append(doc)
+
+ doc_str = '\n\n'.join(final_docs)
+ input_prompt = self.input_template.format(
+ entity=self.entity,
+ entity_type=self.query_entity_type,
+ sub_docs=doc_str)
+
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = []
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ return result
+
+ def process_single(self, sample=None, rank=None):
+
+ # skip if the input is not a batched sample (a list of strings)
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample = nested_set(
+ sample, self.output_key,
+ self.query_most_relavant_entities(sub_docs, rank=rank))
+
+ return sample
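A usage sketch for this op. It assumes a configured API key; the output key is a hypothetical example, and the batched input is what a grouper produces.

```python
from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator

agg = MostRelavantEntitiesAggregator(entity='孙悟空',
                                     query_entity_type='人物',
                                     output_key='__dj__related_roles__')
batched = {'text': ['孙悟空与猪八戒保护唐僧西行。', '观音菩萨多次点化孙悟空。']}
batched = agg.process_single(batched)
print(batched['__dj__related_roles__'])  # ranked list, e.g. ['唐僧', ...]
```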
diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py
new file mode 100644
index 000000000..124eb1470
--- /dev/null
+++ b/data_juicer/ops/aggregator/nested_aggregator.py
@@ -0,0 +1,179 @@
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
+ is_string_list, nested_access,
+ nested_set)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'nested_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class NestedAggregator(Aggregator):
+ """
+ Recursively aggregate contents in chunks to respect the input
+ length limitation.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('给定一些文档碎片,将这些文档整合成一个文档总结。\n'
+ '要求:\n'
+ '- 总结的长度与文档碎片的平均长度基本一致\n'
+ '- 不要包含主观看法\n'
+ '- 注意要尽可能保留文本的专有名词\n'
+ '- 只输出文档总结不要输出其他内容\n'
+ '- 参考如下样例:\n'
+ '文档碎片:\n'
+ '唐僧师徒四人行至白虎岭,遇上了变化多端的白骨精。\n\n'
+ '文档碎片:\n'
+ '白骨精首次变身少女送斋,被孙悟空识破打死,唐僧责怪悟空。\n\n'
+ '文档碎片:\n'
+ '妖怪再变老妇寻女,又被悟空击毙,师傅更加不满,念紧箍咒惩罚。\n\n'
+ '文档碎片:\n'
+ '不甘心的白骨精第三次化作老公公来诱骗,依旧逃不过金睛火眼。\n\n'
+ '文档碎片:\n'
+ '最终,在观音菩萨的帮助下,真相大白,唐僧明白了自己的误解。\n\n'
+ '\n'
+ '文档总结:\n'
+ '唐僧师徒在白虎岭三遇白骨精变化诱惑,悟空屡次识破击毙妖怪却遭误解,最终观音相助真相大白。')
+
+ DEFAULT_INPUT_TEMPLATE = ('{sub_docs}\n\n'
+ '文档总结:\n')
+
+ DEFAULT_SUB_DOC_TEMPLATE = '文档碎片:\n{text}\n'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ input_key: str = None,
+ output_key: str = None,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ sub_doc_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param input_key: The input field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is text_key
+ by default.
+ :param output_key: The output field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is the same as
+ input_key by default.
+ :param max_token_num: The maximum total token number of the
+ sub-documents. No limit if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: The system prompt.
+ :param sub_doc_template: The template for input text in each sample.
+ :param input_template: The input template.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g. {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.max_token_num = max_token_num
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.sub_doc_template = sub_doc_template or \
+ self.DEFAULT_SUB_DOC_TEMPLATE
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+
+ def parse_output(self, response):
+
+ def if_match(text):
+ quotes = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’'),
+ ('`', '`')]
+ if len(text) < 2:
+ return False
+ if (text[0], text[-1]) in quotes:
+ return True
+ else:
+ return False
+
+ text = response.strip()
+ while if_match(text):
+ text = text[1:-1].strip()
+ return text
+
+ def recursive_summary(self, sub_docs, rank=None):
+ if not sub_docs:
+ return ''
+ if len(sub_docs) == 1:
+ return sub_docs[0]
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
+ self.max_token_num)
+ # if every sub-doc forms its own group, merge adjacent groups
+ # pairwise so each round shrinks the list and the recursion terminates
+ group_num = len(group_docs)
+ if group_num == len(sub_docs):
+ group_docs = [
+ group_docs[i] +
+ group_docs[i + 1] if i + 1 < group_num else group_docs[i]
+ for i in range(0, group_num, 2)
+ ]
+ results = []
+ for docs in group_docs:
+ doc_strs = [self.sub_doc_template.format(text=d) for d in docs]
+ input_prompt = self.input_template.format(
+ sub_docs='\n'.join(doc_strs))
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = ''
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+ results.append(result)
+ return self.recursive_summary(results)
+
+ def process_single(self, sample=None, rank=None):
+
+ # skip if the input is not a batched sample (a list of strings)
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample = nested_set(sample, self.output_key,
+ self.recursive_summary(sub_docs, rank=rank))
+
+ return sample
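The pairwise-merge step above is what guarantees termination; a pure-Python illustration of that one step (no API involved):

```python
# When the token budget leaves every sub-doc in its own group, adjacent
# groups are merged two at a time, so each recursion round at least
# roughly halves the number of groups.
group_docs = [['d1'], ['d2'], ['d3'], ['d4'], ['d5']]
group_num = len(group_docs)
group_docs = [
    group_docs[i] + group_docs[i + 1] if i + 1 < group_num else group_docs[i]
    for i in range(0, group_num, 2)
]
print(group_docs)  # [['d1', 'd2'], ['d3', 'd4'], ['d5']]
```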
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 72618a6bb..2091a867e 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -131,6 +131,11 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
+ :param index_key: index the samples before processing if it is not None
"""
# init data keys
self.text_key = kwargs.get('text_key', 'text')
@@ -142,6 +147,8 @@ def __init__(self, *args, **kwargs):
self.response_key = kwargs.get('response_key', 'response')
self.history_key = kwargs.get('history_key', 'history')
+ self.index_key = kwargs.get('index_key', None)
+
self.batch_size = kwargs.get('batch_size', 1000)
# whether the model can be accelerated using cuda
@@ -216,6 +223,14 @@ def run(self, dataset):
from data_juicer.core.data import NestedDataset
if not isinstance(dataset, NestedDataset):
dataset = NestedDataset(dataset)
+ if self.index_key is not None:
+
+ def add_index(sample, idx):
+ sample[self.index_key] = idx
+ return sample
+
+ dataset = dataset.map(add_index, with_indices=True)
+
return dataset
def empty_history(self):
@@ -236,6 +251,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Mapper, self).__init__(*args, **kwargs)
@@ -316,6 +335,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Filter, self).__init__(*args, **kwargs)
self.stats_export_path = kwargs.get('stats_export_path', None)
@@ -424,6 +447,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Deduplicator, self).__init__(*args, **kwargs)
@@ -483,6 +510,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Selector, self).__init__(*args, **kwargs)
@@ -501,3 +532,90 @@ def run(self, dataset, *, exporter=None, tracer=None):
if tracer:
tracer.trace_filter(self._name, dataset, new_dataset)
return new_dataset
+
+
+class Grouper(OP):
+
+ def __init__(self, *args, **kwargs):
+ """
+ Base class that groups samples into batched samples.
+
+ :param text_key: the key name of field that stores sample texts
+ to be processed
+ :param image_key: the key name of field that stores sample image list
+ to be processed
+ :param audio_key: the key name of field that stores sample audio list
+ to be processed
+ :param video_key: the key name of field that stores sample video list
+ to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
+ """
+ super(Grouper, self).__init__(*args, **kwargs)
+
+ def process(self, dataset):
+ """
+ Dataset --> list of batched samples.
+
+ :param dataset: input dataset
+ :return: a list of batched samples
+ """
+ raise NotImplementedError
+
+ def run(self, dataset, *, exporter=None, tracer=None):
+ dataset = super(Grouper, self).run(dataset)
+ batched_samples = self.process(dataset)
+ from data_juicer.core.data import NestedDataset
+ new_dataset = NestedDataset.from_list(batched_samples)
+ if tracer:
+ tracer.trace_filter(self._name, dataset, new_dataset)
+ return new_dataset
+
+
+class Aggregator(OP):
+
+ def __init__(self, *args, **kwargs):
+ """
+ Base class that aggregates batched samples.
+
+ :param text_key: the key name of field that stores sample texts
+ to be processed
+ :param image_key: the key name of field that stores sample image list
+ to be processed
+ :param audio_key: the key name of field that stores sample audio list
+ to be processed
+ :param video_key: the key name of field that stores sample video list
+ to be processed
+ :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
+ """
+ super(Aggregator, self).__init__(*args, **kwargs)
+ self.process = catch_map_single_exception(self.process_single)
+
+ def process_single(self, sample):
+ """
+ For sample level, batched sample --> sample,
+ the input must be the output of some Grouper OP.
+
+ :param sample: batched sample to aggregate
+ :return: aggregated sample
+ """
+ raise NotImplementedError
+
+ def run(self, dataset, *, exporter=None, tracer=None):
+ dataset = super(Aggregator, self).run(dataset)
+ new_dataset = dataset.map(
+ self.process,
+ num_proc=self.runtime_np(),
+ with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
+ desc=self._name + '_process',
+ )
+ if tracer:
+ tracer.trace_mapper(self._name, dataset, new_dataset,
+ self.text_key)
+ return new_dataset
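Taken together, the two new base classes form one dataset-level pipeline stage; a minimal sketch of the lifecycle (assumes a configured API key for the aggregator):

```python
# Grouper.run: dataset -> dataset of batched samples;
# Aggregator.run: each batched sample -> one aggregated sample.
from data_juicer.core.data import NestedDataset
from data_juicer.ops.aggregator import NestedAggregator
from data_juicer.ops.grouper import NaiveGrouper

dataset = NestedDataset.from_list([
    {'text': '第一章:石猴出世。'},
    {'text': '第二章:拜师学艺。'},
])
dataset = NaiveGrouper().run(dataset)      # one batched sample
dataset = NestedAggregator().run(dataset)  # batched texts -> one summary
print(dataset[0]['text'])
```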
diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py
new file mode 100644
index 000000000..048b305e4
--- /dev/null
+++ b/data_juicer/ops/grouper/__init__.py
@@ -0,0 +1,4 @@
+from .key_value_grouper import KeyValueGrouper
+from .naive_grouper import NaiveGrouper
+
+__all__ = ['NaiveGrouper', 'KeyValueGrouper']
diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py
new file mode 100644
index 000000000..3d786319f
--- /dev/null
+++ b/data_juicer/ops/grouper/key_value_grouper.py
@@ -0,0 +1,51 @@
+from typing import List, Optional
+
+from data_juicer.utils.common_utils import dict_to_hash, nested_access
+
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
+from .naive_grouper import NaiveGrouper
+
+
+@OPERATORS.register_module('key_value_grouper')
+class KeyValueGrouper(Grouper):
+ """Group samples to batched samples according values in given keys. """
+
+ def __init__(self,
+ group_by_keys: Optional[List[str]] = None,
+ *args,
+ **kwargs):
+ """
+ Initialization method.
+
+ :param group_by_keys: group samples according to the values of the
+ given keys. Support for nested keys such as
+ "__dj__stats__.text_len". It is [self.text_key] by default.
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+
+ self.group_by_keys = group_by_keys or [self.text_key]
+ self.naive_grouper = NaiveGrouper()
+
+ def process(self, dataset):
+
+ if len(dataset) == 0:
+ return dataset
+
+ sample_map = {}
+ for sample in dataset:
+ cur_dict = {}
+ for key in self.group_by_keys:
+ cur_dict[key] = nested_access(sample, key)
+ sample_key = dict_to_hash(cur_dict)
+ if sample_key in sample_map:
+ sample_map[sample_key].append(sample)
+ else:
+ sample_map[sample_key] = [sample]
+
+ batched_samples = [
+ convert_list_dict_to_dict_list(sample_map[k]) for k in sample_map
+ ]
+
+ return batched_samples
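A small usage sketch; the nested 'meta.source' key is a hypothetical example:

```python
# Each distinct value under the grouping key becomes one batched sample.
from data_juicer.ops.grouper import KeyValueGrouper

samples = [
    {'text': '文档甲', 'meta': {'source': 'wiki'}},
    {'text': '文档乙', 'meta': {'source': 'news'}},
    {'text': '文档丙', 'meta': {'source': 'wiki'}},
]
batched = KeyValueGrouper(group_by_keys=['meta.source']).process(samples)
print(len(batched))        # 2 groups: wiki and news
print(batched[0]['text'])  # e.g. ['文档甲', '文档丙']
```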
diff --git a/data_juicer/ops/grouper/naive_grouper.py b/data_juicer/ops/grouper/naive_grouper.py
new file mode 100644
index 000000000..4633dc48e
--- /dev/null
+++ b/data_juicer/ops/grouper/naive_grouper.py
@@ -0,0 +1,24 @@
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
+
+
+@OPERATORS.register_module('naive_grouper')
+class NaiveGrouper(Grouper):
+ """Group all samples to one batched sample. """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Initialization method.
+
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+
+ def process(self, dataset):
+
+ if len(dataset) == 0:
+ return dataset
+
+ batched_sample = convert_list_dict_to_dict_list(dataset)
+
+ return [batched_sample]
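For clarity, the conversion this grouper delegates to (inferred from the convert_list_dict_to_dict_list helper's name and its use in the aggregators) looks like:

```python
# List of dicts in, one dict of lists out.
from data_juicer.ops.grouper import NaiveGrouper

samples = [{'text': 'a', 'id': 1}, {'text': 'b', 'id': 2}]
print(NaiveGrouper().process(samples))
# [{'text': ['a', 'b'], 'id': [1, 2]}]
```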
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py
index 3c95d1fe6..9b86b83dc 100644
--- a/data_juicer/ops/mapper/__init__.py
+++ b/data_juicer/ops/mapper/__init__.py
@@ -14,6 +14,7 @@
from .extract_event_mapper import ExtractEventMapper
from .extract_keyword_mapper import ExtractKeywordMapper
from .extract_nickname_mapper import ExtractNicknameMapper
+from .extract_support_text_mapper import ExtractSupportTextMapper
from .fix_unicode_mapper import FixUnicodeMapper
from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper
from .generate_qa_from_text_mapper import GenerateQAFromTextMapper
@@ -32,6 +33,7 @@
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .python_file_mapper import PythonFileMapper
from .python_lambda_mapper import PythonLambdaMapper
+from .relation_identity_mapper import RelationIdentityMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
from .remove_comments_mapper import RemoveCommentsMapper
from .remove_header_mapper import RemoveHeaderMapper
@@ -71,25 +73,26 @@
'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
'ExpandMacroMapper', 'ExtractEntityAttributeMapper',
'ExtractEntityRelationMapper', 'ExtractEventMapper',
- 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper',
+ 'ExtractKeywordMapper', 'ExtractNicknameMapper',
+ 'ExtractSupportTextMapper', 'FixUnicodeMapper',
'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper',
'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper',
'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper',
'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
- 'PythonFileMapper', 'PythonLambdaMapper', 'RemoveBibliographyMapper',
- 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
- 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
- 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
- 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
- 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
- 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
- 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper',
- 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
- 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
- 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
- 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
- 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
- 'WhitespaceNormalizationMapper'
+ 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper',
+ 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
+ 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
+ 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
+ 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
+ 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
+ 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
+ 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
+ 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
+ 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
+ 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
+ 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
+ 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
+ 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
]
diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
index fd93cfe03..0fc76b11f 100644
--- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
+++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
@@ -1,5 +1,4 @@
import re
-from itertools import chain
from typing import Dict, List, Optional
from loguru import logger
@@ -19,34 +18,37 @@ class ExtractEntityAttributeMapper(Mapper):
Extract attributes for given entities from the text
"""
- _batched_op = True
-
DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
'给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n'
'要求:\n'
'- 摘录的示例应该简短。\n'
'- 遵循如下的回复格式:\n'
+ '# {entity}\n'
'## {attribute}:\n'
- '{entity}的{attribute}描述...\n'
- '### 代表性示例1:\n'
- '说明{entity}该{attribute}的原文摘录1...\n'
- '### 代表性示例2:\n'
- '说明{entity}该{attribute}的原文摘录2...\n'
+ '...\n'
+ '### 代表性示例摘录1:\n'
+ '```\n'
+ '...\n'
+ '```\n'
+ '### 代表性示例摘录2:\n'
+ '```\n'
+ '...\n'
+ '```\n'
'...\n')
DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)'
- DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)'
+ DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)' # noqa: E501
def __init__(self,
+ api_model: str = 'gpt-4o',
query_entities: List[str] = [],
query_attributes: List[str] = [],
- api_model: str = 'gpt-4o',
*,
- entity_key: str = Fields.main_entity,
- attribute_key: str = Fields.attribute,
- attribute_desc_key: str = Fields.attribute_description,
- support_text_key: str = Fields.attribute_support_text,
+ entity_key: str = Fields.main_entities,
+ attribute_key: str = Fields.attributes,
+ attribute_desc_key: str = Fields.attribute_descriptions,
+ support_text_key: str = Fields.attribute_support_texts,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt_template: Optional[str] = None,
@@ -60,9 +62,9 @@ def __init__(self,
**kwargs):
"""
Initialization method.
+ :param api_model: API model name.
:param query_entities: Entity list to be queried.
:param query_attributes: Attribute list to be queried.
- :param api_model: API model name.
:param entity_key: The field name to store the given main entity for
attribute extraction. It's "__dj__entity__" in default.
:param entity_attribute_key: The field name to store the given
@@ -135,7 +137,7 @@ def parse_output(self, raw_output, attribute_name):
return attribute, demos
- def _process_single_sample(self, text='', rank=None):
+ def _process_single_text(self, text='', rank=None):
client = get_model(self.model_key, rank=rank)
entities, attributes, descs, demo_lists = [], [], [], []
@@ -168,31 +170,17 @@ def _process_single_sample(self, text='', rank=None):
return entities, attributes, descs, demo_lists
- def process_batched(self, samples, rank=None):
-
- sample_num = len(samples[self.text_key])
+ def process_single(self, sample, rank=None):
- entities, attributes, descs, demo_lists = [], [], [], []
- for text in samples[self.text_key]:
- res = self._process_single_sample(text, rank=rank)
- cur_ents, cur_attrs, cur_descs, cur_demos = res
- entities.append(cur_ents)
- attributes.append(cur_attrs)
- descs.append(cur_descs)
- demo_lists.append(cur_demos)
+ res = self._process_single_text(sample[self.text_key], rank=rank)
+ entities, attributes, descs, demo_lists = res
if self.drop_text:
- samples.pop(self.text_key)
-
- for key in samples:
- samples[key] = [[samples[key][i]] * len(descs[i])
- for i in range(sample_num)]
- samples[self.entity_key] = entities
- samples[self.attribute_key] = attributes
- samples[self.attribute_desc_key] = descs
- samples[self.support_text_key] = demo_lists
+ sample.pop(self.text_key)
- for key in samples:
- samples[key] = list(chain(*samples[key]))
+ sample[self.entity_key] = entities
+ sample[self.attribute_key] = attributes
+ sample[self.attribute_desc_key] = descs
+ sample[self.support_text_key] = demo_lists
- return samples
+ return sample
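A sketch of the op's new sample-level contract (assumes a configured API key; actual field contents depend on the model):

```python
# One input sample now yields one output sample whose new fields hold
# parallel lists (one entry per extracted entity-attribute description)
# instead of being exploded into multiple samples as before.
from data_juicer.ops.mapper import ExtractEntityAttributeMapper

op = ExtractEntityAttributeMapper(query_entities=['孙悟空'],
                                  query_attributes=['人物性格'])
sample = op.process_single({'text': '孙悟空生性桀骜不驯,却忠心护师。'})
# sample[op.entity_key], sample[op.attribute_key],
# sample[op.attribute_desc_key], sample[op.support_text_key]
# are aligned lists of equal length.
```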
diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py
new file mode 100644
index 000000000..34bdbe653
--- /dev/null
+++ b/data_juicer/ops/mapper/extract_support_text_mapper.py
@@ -0,0 +1,132 @@
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_access, nested_set
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'extract_support_text_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class ExtractSupportTextMapper(Mapper):
+ """
+ Extract support sub text for a summary.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('你将扮演一个文本摘录助手的角色。你的主要任务是基于给定'
+ '的文章(称为“原文”)以及对原文某个部分的简短描述或总结'
+ '(称为“总结”),准确地识别并提取出与该总结相对应的原文'
+ '片段。\n'
+ '要求:\n'
+ '- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n'
+ '- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n'
+ '- 下面是一个例子帮助理解这一过程:\n'
+ '### 原文:\n'
+ '《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创'
+ '作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰'
+ '历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突'
+ '。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二'
+ '姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红'
+ '楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,'
+ '也深刻反映了人物的性格特点和命运走向。\n\n'
+ '### 总结:\n'
+ '描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n'
+ '### 原文摘录:\n'
+ '其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐'
+ '之间的争斗,生动描绘了权力争夺下的女性形象。')
+ DEFAULT_INPUT_TEMPLATE = ('### 原文:\n{text}\n\n'
+ '### 总结:\n{summary}\n\n'
+ '### 原文摘录:\n')
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ *,
+ summary_key: str = Fields.event_description,
+ support_text_key: str = Fields.support_text,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ input_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ drop_text: bool = False,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param summary_key: The field name to store the input summary.
+ Support for nested keys such as "__dj__stats__.text_len".
+ It's "__dj__event_description__" by default.
+ :param support_text_key: The field name to store the output
+ support text for the summary. It's "__dj__support_text__" by
+ default.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param input_template: Template for building the model input.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param drop_text: Whether to drop the text in the output.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g. {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.summary_key = summary_key
+ self.support_text_key = support_text_key
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+ self.drop_text = drop_text
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ summary = nested_access(sample, self.summary_key)
+ if not isinstance(summary, str):
+ logger.warning('Invalid input summary!')
+ return sample
+
+ input_prompt = self.input_template.format(text=sample[self.text_key],
+ summary=summary)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+
+ support_text = ''
+ for i in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ support_text = response.strip()
+ if len(support_text) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+ # fall back to the summary if the model returned nothing
+ if not support_text:
+ support_text = summary
+
+ sample = nested_set(sample, self.support_text_key, support_text)
+ return sample
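A usage sketch (assumes a configured API key; the summary field matches the op's default summary_key, e.g. as produced upstream by extract_event_mapper):

```python
from data_juicer.ops.mapper import ExtractSupportTextMapper

op = ExtractSupportTextMapper()
sample = {
    'text': '白骨精三次变化诱骗唐僧,皆被孙悟空识破打杀,唐僧怒而逐徒。',
    '__dj__event_description__': '悟空识破白骨精的变化。',
}
sample = op.process_single(sample)
print(sample['__dj__support_text__'])  # the matching span from 'text'
```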
diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
index 6f5ad7dab..0c0d084b3 100644
--- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
+++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
@@ -9,7 +9,7 @@
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
-from ..base_op import OPERATORS, UNFORKABLE, Mapper
+from ..base_op import OPERATORS, Mapper
torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')
@@ -19,7 +19,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromExamplesMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
index 248dba428..0f3a1cfef 100644
--- a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
+++ b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
@@ -3,7 +3,7 @@
from loguru import logger
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -14,7 +14,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromTextMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py
index 3563a112b..974730ec5 100644
--- a/data_juicer/ops/mapper/optimize_qa_mapper.py
+++ b/data_juicer/ops/mapper/optimize_qa_mapper.py
@@ -3,7 +3,7 @@
from loguru import logger
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -14,7 +14,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeQAMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/optimize_query_mapper.py b/data_juicer/ops/mapper/optimize_query_mapper.py
index dd227b4c1..9ccd84bb1 100644
--- a/data_juicer/ops/mapper/optimize_query_mapper.py
+++ b/data_juicer/ops/mapper/optimize_query_mapper.py
@@ -1,11 +1,10 @@
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper
OP_NAME = 'optimize_query_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeQueryMapper(OptimizeQAMapper):
"""
diff --git a/data_juicer/ops/mapper/optimize_response_mapper.py b/data_juicer/ops/mapper/optimize_response_mapper.py
index 158159a9d..f6026b8dc 100644
--- a/data_juicer/ops/mapper/optimize_response_mapper.py
+++ b/data_juicer/ops/mapper/optimize_response_mapper.py
@@ -1,11 +1,10 @@
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper
OP_NAME = 'optimize_response_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeResponseMapper(OptimizeQAMapper):
"""
diff --git a/data_juicer/ops/mapper/relation_identity_mapper.py b/data_juicer/ops/mapper/relation_identity_mapper.py
new file mode 100644
index 000000000..29994d744
--- /dev/null
+++ b/data_juicer/ops/mapper/relation_identity_mapper.py
@@ -0,0 +1,155 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_access, nested_set
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'relation_identity_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class RelationIdentityMapper(Mapper):
+ """
+ Identify the relation between two entities in the text.
+ """
+
+ DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
+ '给定关于{entity1}和{entity2}的文本信息。'
+ '判断{entity1}和{entity2}之间的关系。\n'
+ '要求:\n'
+ '- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n'
+ '- 输出关系时不要参杂任何标点符号\n'
+ '- 需要你进行合理的推理才能得出结论\n'
+ '- 如果两个人物身份是同一个人,输出关系为:另一个身份\n'
+ '- 输出格式为:\n'
+ '分析推理:...\n'
+ '所以{entity2}是{entity1}的:...\n'
+ '- 注意输出的是{entity2}是{entity1}的什么关系,而不是{entity1}是{entity2}的什么关系')
+ DEFAULT_INPUT_TEMPLATE = '关于{entity1}和{entity2}的文本信息:\n```\n{text}\n```\n'
+ DEFAULT_OUTPUT_PATTERN_TEMPLATE = r"""
+ \s*分析推理:\s*(.*?)\s*
+ \s*所以{entity2}是{entity1}的:\s*(.*?)\Z
+ """
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ source_entity: str = None,
+ target_entity: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ drop_text: bool = False,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param source_entity: The source entity of the relation to be
+ identified.
+ :param target_entity: The target entity of the relation to be
+ identified.
+ :param input_key: The input field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is text_key
+ by default.
+ :param output_key: The output field key in the samples. Support
+ for nested keys such as "__dj__stats__.text_len". It is
+ input_key by default.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: System prompt template for the task.
+ :param input_template: Template for building the model input.
+ :param output_pattern_template: Regular expression template for
+ parsing model output.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param drop_text: Whether to drop the text in the output.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g. {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if source_entity is None or target_entity is None:
+ raise ValueError(
+ 'The source_entity and target_entity cannot be None!')
+
+ self.source_entity = source_entity
+ self.target_entity = target_entity
+
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
+ self.system_prompt = system_prompt_template.format(
+ entity1=source_entity, entity2=target_entity)
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ output_pattern_template = output_pattern_template or \
+ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+ self.output_pattern = output_pattern_template.format(
+ entity1=source_entity, entity2=target_entity)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+ self.drop_text = drop_text
+
+ def parse_output(self, raw_output):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(raw_output)
+
+ relation = ''
+
+ # keep the relation parsed from the last match
+ for match in matches:
+ _, relation = match
+ relation = relation.strip()
+
+ return relation
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ text = nested_access(sample, self.input_key)
+ input_prompt = self.input_template.format(entity1=self.source_entity,
+ entity2=self.target_entity,
+ text=text)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ relation = ''
+ for i in range(self.try_num):
+ try:
+ output = client(messages, **self.sampling_params)
+ relation = self.parse_output(output)
+ if len(relation) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ sample = nested_set(sample, self.output_key, relation)
+ if self.drop_text:
+ sample.pop(self.text_key)
+
+ return sample
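
For reference, a minimal usage sketch of the new mapper (not part of the patch). It assumes an OpenAI-compatible endpoint is configured via the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables, and that `qwen2.5-72b-instruct` is served there; the input text and printed relation are illustrative:

```python
from data_juicer.ops.mapper.relation_identity_mapper import \
    RelationIdentityMapper

op = RelationIdentityMapper(
    api_model='qwen2.5-72b-instruct',
    source_entity='李莲花',
    target_entity='方多病',
    output_key='__dj__relation_identity__',  # avoid overwriting 'text'
)
# the relation string is parsed from the LLM answer via output_pattern
sample = op.process_single({'text': '方多病是李莲花的徒弟,两人一起调查案件。'})
print(sample['__dj__relation_identity__'])  # e.g. '徒弟'
```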
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index 50116ea0c..5ea9091b0 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -79,5 +79,21 @@
['rouge', 'torch', 'transformers', 'vllm'],
'video_tagging_from_frames_mapper': ['ram', 'torch'],
'text_entity_dependency_filter': ['spacy-pkuseg'],
- 'optimize_response_mapper': ['torch', 'transformers', 'vllm']
+ 'optimize_response_mapper': ['torch', 'transformers', 'vllm'],
+ 'text_chunk_mapper': ['transformers', 'dashscope', 'openai'],
+ 'entity_attribute_aggregator': ['transformers', 'dashscope', 'openai'],
+ 'most_relavant_entities_aggregator':
+ ['transformers', 'dashscope', 'openai'],
+ 'nested_aggregator': ['transformers', 'dashscope', 'openai'],
+ 'calibrate_qa_mapper': ['openai'],
+ 'calibrate_query_mapper': ['openai'],
+ 'calibrate_response_mapper': ['openai'],
+ 'extract_entity_attribute_mapper': ['openai'],
+ 'extract_entity_relation_mapper': ['openai'],
+ 'extract_event_mapper': ['openai'],
+ 'extract_keyword_mapper': ['openai'],
+ 'extract_nickname_mapper': ['openai'],
+ 'extract_support_text_mapper': ['openai'],
+ 'pair_preference_mapper': ['openai'],
+ 'relation_identity_mapper': ['openai'],
}
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index 959831c5d..bd649bb96 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -1,6 +1,8 @@
+import hashlib
import sys
import numpy as np
+from loguru import logger
def stats_to_number(s, reverse=True):
@@ -21,6 +23,122 @@ def stats_to_number(s, reverse=True):
return sys.maxsize
+def dict_to_hash(input_dict: dict, hash_length=None):
+ """
+ Hash a dict into a hex string, optionally truncated to hash_length.
+
+ :param input_dict: the given dict.
+ :param hash_length: expected length of the hash value. The full
+ sha256 hex digest is kept if it is None.
+ """
+ sorted_items = sorted(input_dict.items())
+ dict_string = str(sorted_items).encode()
+ hasher = hashlib.sha256()
+ hasher.update(dict_string)
+ hash_value = hasher.hexdigest()
+ if hash_length:
+ hash_value = hash_value[:hash_length]
+ return hash_value
+
+
+def nested_access(data, path, digit_allowed=True):
+ """
+ Access nested data using a dot-separated path.
+
+ :param data: A dictionary or a list to access the nested data from.
+ :param path: A dot-separated string representing the path to access.
+ This can include numeric indices when accessing list
+ elements.
+ :param digit_allowed: Whether to convert numeric string keys into
+ integer indices.
+ :return: The value located at the specified path, or None (with a
+ warning) if the path does not exist.
+ """
+ keys = path.split('.')
+ for key in keys:
+ # Convert string keys to integers if they are numeric
+ key = int(key) if key.isdigit() and digit_allowed else key
+ try:
+ data = data[key]
+ except Exception:
+ logger.warning(f'Inaccessible dot-separated path: {path}!')
+ return None
+ return data
+
+
+def nested_set(data: dict, path: str, val):
+ """
+ Set val into the nested data at the dot-separated path.
+
+ :param data: A dictionary with nested format.
+ :param path: A dot-separated string representing the path to set.
+ Intermediate dictionaries are created when missing.
+ :param val: The value to set.
+ :return: The input data after the value is set.
+ """
+ keys = path.split('.')
+ cur = data
+ for key in keys[:-1]:
+ if key not in cur:
+ cur[key] = {}
+ cur = cur[key]
+ cur[keys[-1]] = val
+ return data
+
+
+def is_string_list(var):
+ """
+ Return whether the var is a list of strings.
+
+ :param var: the input variable.
+ """
+ return isinstance(var, list) and all(isinstance(it, str) for it in var)
+
+
+def avg_split_string_list_under_limit(str_list: list,
+ token_nums: list,
+ max_token_num=None):
+ """
+ Split the string list into several sub-lists, such that the total
+ token num of each sub-list is less than max_token_num, while keeping
+ the total token nums of the sub-lists similar.
+
+ :param str_list: input string list.
+ :param token_nums: token num of each string.
+ :param max_token_num: max token num of each sub string list.
+ """
+ if max_token_num is None:
+ return [str_list]
+
+ if len(str_list) != len(token_nums):
+ logger.warning('The length of str_list and token_nums must be equal!')
+ return [str_list]
+
+ total_num = sum(token_nums)
+ if total_num <= max_token_num:
+ return [str_list]
+
+ group_num = total_num // max_token_num + 1
+ avg_num = total_num / group_num
+ res = []
+ cur_list = []
+ cur_sum = 0
+ for text, token_num in zip(str_list, token_nums):
+ if token_num > max_token_num:
+ logger.warning(
+ 'Token num is greater than max_token_num in one sample!')
+ if cur_sum + token_num > max_token_num and cur_list:
+ res.append(cur_list)
+ cur_list = []
+ cur_sum = 0
+ cur_list.append(text)
+ cur_sum += token_num
+ if cur_sum > avg_num:
+ res.append(cur_list)
+ cur_list = []
+ cur_sum = 0
+ if cur_list:
+ res.append(cur_list)
+ return res
+
+
def is_float(s):
try:
float(s)
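
A quick sketch of how the relocated helpers behave (illustrative values only, assuming the patch is applied):

```python
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
                                            dict_to_hash, nested_access,
                                            nested_set)

data = {'__dj__stats__': {'text_len': 42}}
assert nested_access(data, '__dj__stats__.text_len') == 42

nested_set(data, 'meta.source', 'novel')  # intermediate dicts are created
assert data['meta']['source'] == 'novel'

print(dict_to_hash({'a': 1}, hash_length=8))  # first 8 chars of sha256 digest

# each sub-list stays under the token limit, with roughly even totals
print(avg_split_string_list_under_limit(['a', 'b', 'c'],
                                        token_nums=[3, 3, 3],
                                        max_token_num=4))
# -> [['a'], ['b'], ['c']]
```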
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 8bc78ad5a..922d44c8b 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -33,14 +33,14 @@ class Fields(object):
event_description = DEFAULT_PREFIX + 'event_description__'
# # a list of characters relevant to the event
relevant_characters = DEFAULT_PREFIX + 'relevant_characters__'
- # # the given main entity for attribute extraction
- main_entity = DEFAULT_PREFIX + 'main_entity__'
- # # the given attribute to be extracted
- attribute = DEFAULT_PREFIX + 'attribute__'
- # # the extracted attribute description
- attribute_description = DEFAULT_PREFIX + 'attribute_description__'
- # # extract from raw data for support the attribute
- attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__'
+ # # the given main entities for attribute extraction
+ main_entities = DEFAULT_PREFIX + 'main_entities__'
+ # # the given attributes to be extracted
+ attributes = DEFAULT_PREFIX + 'attributes__'
+ # # the extracted attribute descriptions
+ attribute_descriptions = DEFAULT_PREFIX + 'attribute_descriptions__'
+ # # texts extracted from the raw data to support the attributes
+ attribute_support_texts = DEFAULT_PREFIX + 'attribute_support_texts__'
# # the nickname relationship
nickname = DEFAULT_PREFIX + 'nickname__'
# # the entity for knowledge graph
@@ -65,6 +65,8 @@ class Fields(object):
relation_strength = DEFAULT_PREFIX + 'relation_strength__'
# # the keyword in a text
keyword = DEFAULT_PREFIX + 'keyword__'
+ # # support text
+ support_text = DEFAULT_PREFIX + 'support_text__'
class StatsKeysMeta(type):
diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py
index e2fc241cd..7a8618660 100644
--- a/data_juicer/utils/file_utils.py
+++ b/data_juicer/utils/file_utils.py
@@ -1,6 +1,5 @@
import asyncio
import copy
-import hashlib
import os
import re
import shutil
@@ -10,6 +9,7 @@
from datasets.utils.extract import ZstdExtractor as Extractor
+from data_juicer.utils.common_utils import dict_to_hash
from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
@@ -127,22 +127,6 @@ def add_suffix_to_filename(filename, suffix):
return new_name
-def dict_to_hash(input_dict, hash_length=None):
- """
- hash a dict to a string with length hash_length
-
- :param input_dict: the given dict
- """
- sorted_items = sorted(input_dict.items())
- dict_string = str(sorted_items).encode()
- hasher = hashlib.sha256()
- hasher.update(dict_string)
- hash_value = hasher.hexdigest()
- if hash_length:
- hash_value = hash_value[:hash_length]
- return hash_value
-
-
def create_directory_if_not_exists(directory_path):
"""
create a directory if not exists, this function is process safe
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 305145e82..94b4440eb 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -11,6 +11,7 @@
from loguru import logger
from data_juicer import cuda_device_count
+from data_juicer.utils.common_utils import nested_access
from data_juicer.utils.lazy_loader import AUTOINSTALL, LazyLoader
from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC
@@ -167,30 +168,11 @@ def __call__(self, messages, **kwargs):
stream=stream,
stream_cls=stream_cls)
result = response.json()
- return self._nested_access(result, self.response_path)
+ return nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return ''
- @staticmethod
- def _nested_access(data, path):
- """
- Access nested data using a dot-separated path.
-
- :param data: A dictionary or a list to access the nested data from.
- :param path: A dot-separated string representing the path to access.
- This can include numeric indices when accessing list
- elements.
- :return: The value located at the specified path, or raises a KeyError
- or IndexError if the path does not exist.
- """
- keys = path.split('.')
- for key in keys:
- # Convert string keys to integers if they are numeric
- key = int(key) if key.isdigit() else key
- data = data[key]
- return data
-
@staticmethod
def _filter_arguments(func, args_dict):
"""
@@ -221,19 +203,18 @@ def prepare_api_model(model,
return_processor=False,
processor_config=None,
**model_params):
- """
- Creates an instance of the APIModel for interacting with OpenAI-like APIs.
+ """Creates a callable API model for interacting with OpenAI-compatible API.
+ The callable supports custom response parsing and works with proxy servers
+ that may be incompatible.
- :param model: The name of the model to be used for making API calls.
+ :param model: The name of the model to interact with.
:param endpoint: The URL endpoint for the API. If provided as a relative
path, it will be appended to the base URL (defined by the
`OPENAI_BASE_URL` environment variable or through an additional
`base_url` parameter). By default, it is set to
'/chat/completions' for OpenAI compatibility.
- :param response_path: A dot-separated string specifying the path to
- extract desired content from the API response. The default value is
- 'choices.0.message.content', which corresponds to the typical
- structure of an OpenAI API response.
+ :param response_path: The dot-separated path to extract desired content
+ from the API response. Defaults to 'choices.0.message.content'.
:param return_processor: A boolean flag indicating whether to return a
processor along with the model. The processor can be used for tasks
like tokenization or encoding. Defaults to False.
@@ -279,8 +260,8 @@ def get_processor():
"- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501
)
- if processor_config is not None \
- and 'pretrained_model_name_or_path' in processor_config:
+ if processor_config is not None and \
+ 'pretrained_model_name_or_path' in processor_config:
processor = transformers.AutoProcessor.from_pretrained(
**processor_config)
else:
diff --git a/demos/role_playing_system_prompt/README_ZH.md b/demos/role_playing_system_prompt/README_ZH.md
new file mode 100644
index 000000000..956c335bb
--- /dev/null
+++ b/demos/role_playing_system_prompt/README_ZH.md
@@ -0,0 +1,49 @@
+# 为LLM构造角色扮演的system prompt
+
+在该Demo中,我们展示了如何通过Data-Juicer的菜谱,生成让LLM扮演剧本中给定角色的system prompt。我们这里以《莲花楼》为例。
+
+## 数据准备
+将《莲花楼》按章节划分,按顺序每个章节对应Data-Juicer的一个sample,放到“text”关键字下。如下json格式:
+```json
+[
+ {"text": "第一章内容"},
+ {"text": "第二章内容"},
+ {"text": "第三章内容"},
+ ...
+]
+```
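+
+下面是一个把小说原文切分为章节并保存为上述 JSON 格式的示意脚本(仅供参考;文件名与章节分隔方式均为假设,请按实际原文调整):
+
+```python
+import json
+import re
+from pathlib import Path
+
+# 假设小说原文保存在 lianhualou.txt,且每章以“第X章”开头(仅为示例)
+novel = Path('lianhualou.txt').read_text(encoding='utf-8')
+chapters = [c.strip() for c in re.split(r'(?=第[零一二三四五六七八九十百0-9]+章)', novel) if c.strip()]
+samples = [{'text': c} for c in chapters]
+Path('lianhualou.json').write_text(json.dumps(samples, ensure_ascii=False, indent=2), encoding='utf-8')
+```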
+
+## 执行
+```shell
+python tools/process_data.py --config ./demos/role_playing_system_prompt/role_playing_system_prompt_test.yaml
+```
+
+## 生成样例
+
+```text
+扮演李莲花与用户进行对话。
+# 角色身份
+原名李相夷,曾是武林盟主,创立四顾门。十年前因中碧茶之毒,隐姓埋名,成为莲花楼的老板,过着市井生活。
+# 角色经历
+李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。
+# 角色性格
+李莲花是一个机智、幽默、善于观察和推理的人物。他表面上看似随和、悠闲,甚至有些懒散,但实际上心思缜密,洞察力极强。他不仅具备敏锐的观察力和独特的思维方式,还拥有深厚的内功和高超的医术。他对朋友忠诚,愿意为了保护他们不惜一切代价,同时在面对敌人时毫不手软。尽管内心充满正义感和责任感,但他选择远离江湖纷争,追求宁静自在的生活。他对过去的自己(李相夷)有着深刻的反思,对乔婉娩的感情复杂,既有愧疚也有关怀。李莲花能够在复杂的环境中保持冷静,巧妙地利用智慧和技能解决问题,展现出非凡的勇气和决心。
+# 角色能力
+李莲花是一位智慧与武艺兼备的高手,拥有深厚的内力、高超的医术和敏锐的洞察力。他擅长使用轻功、剑术和特殊武器,如婆娑步和少师剑,能够在关键时刻化解危机。尽管身体状况不佳,他仍能通过内功恢复体力,运用智谋和技巧应对各种挑战。他在江湖中身份多变,既能以游医身份逍遥自在,也能以李相夷的身份化解武林危机。
+# 人际关系
+方多病 (称呼:方小宝、方大少爷)李莲花的徒弟。百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。
+笛飞声 (称呼:阿飞、笛大盟主)金鸳盟盟主,曾与李相夷激战并重伤李相夷,后因中毒失去内力,与李莲花有复杂恩怨。笛飞声是金鸳盟盟主,十年前因与李相夷一战成名。他利用单孤刀的弟子朴锄山引诱李相夷,最终重伤李相夷,但自己也被李相夷钉在桅杆上。十年后,笛飞声恢复内力,重新执掌金鸳盟,与角丽谯合作,试图利用罗摩天冰和业火痋控制武林。在与李莲花和方多病的多次交手中,笛飞声多次展现强大实力,但也多次被李莲花等人挫败。最终,笛飞声在与李莲花的对决中被制住,但并未被杀死。笛飞声与李莲花约定在东海再战,但李莲花因中毒未赴约。笛飞声在东海之战中并未出现,留下了许多未解之谜。
+乔婉娩 (称呼:乔姑娘)李莲花的前女友。四顾门前任门主李相夷的爱人,现任门主肖紫衿的妻子,江湖中知名侠女。乔婉娩是四顾门的重要人物,与李相夷有着复杂的情感纠葛。在李相夷失踪后,乔婉娩嫁给了肖紫衿,但内心始终未能忘记李相夷。在李莲花(即李相夷)重新出现后,乔婉娩通过种种线索确认了他的身份,但最终选择支持肖紫衿,维护四顾门的稳定。乔婉娩在四顾门的复兴过程中发挥了重要作用,尤其是在调查金鸳盟和南胤阴谋的过程中,她提供了关键的情报和支持。尽管内心充满矛盾,乔婉娩最终决定与肖紫衿共同面对江湖的挑战,展现了她的坚强和智慧。
+肖紫衿 (称呼:紫衿)李莲花的门主兼旧识。四顾门现任门主,曾与李相夷有深厚恩怨,后与乔婉娩成婚。肖紫衿是四顾门的重要人物,与李相夷和乔婉娩关系密切。他曾在李相夷的衣冠冢前与李莲花对峙,质问他为何归来,并坚持要与李莲花决斗。尽管李莲花展示了武功,但肖紫衿最终选择不与他继续争斗。肖紫衿在乔婉娩与李相夷的误会中扮演了关键角色,一度因嫉妒取消了与乔婉娩的婚事。后来,肖紫衿在乔婉娩的支持下担任四顾门的新门主,致力于复兴四顾门。在与单孤刀的对抗中,肖紫衿展现了坚定的决心和领导能力,最终带领四顾门取得了胜利。
+单孤刀 (称呼:师兄)李莲花的师兄兼敌人。单孤刀,李莲花的师兄,四顾门创始人之一,因不满李相夷与金鸳盟签订协定而独自行动,最终被金鸳盟杀害。单孤刀是李莲花的师兄,与李相夷一同创立四顾门。单孤刀性格争强好胜,难以容人,最终因不满李相夷与金鸳盟签订协定,决定独自行动。单孤刀被金鸳盟杀害,李相夷得知后悲愤交加,誓言与金鸳盟不死不休。单孤刀的死成为李相夷心中的一大阴影,多年后李莲花在调查中发现单孤刀并非真正死亡,而是诈死以实现自己的野心。最终,单孤刀在与李莲花和方多病的对决中失败,被轩辕箫的侍卫杀死。
+# 语言风格
+李莲花的语言风格幽默诙谐,充满智慧和机智,善于用轻松的语气化解紧张的气氛。他常用比喻、反讽和夸张来表达复杂的观点,同时在关键时刻能简洁明了地揭示真相。他的言语中带有调侃和自嘲,但又不失真诚和温情,展现出一种从容不迫的态度。无论是面对朋友还是敌人,李莲花都能以幽默和智慧赢得尊重。
+供参考语言风格的部分李莲花台词:
+李莲花:你问我干吗?该启程了啊。
+李莲花:说起师门,你怎么也算云隐山一份子啊?不如趁今日叩拜了你师祖婆婆,再正儿八经给我这个师父磕头敬了茶,往后我守山中、你也尽心在跟前罢?
+李莲花:恭贺肖大侠和乔姑娘,喜结连理。
+李莲花淡淡一笑:放心吧,该看到的,都看到了。
+李莲花:如果现在去百川院,你家旺福就白死了。
+```
+
+
diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml
new file mode 100644
index 000000000..eadac45da
--- /dev/null
+++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml
@@ -0,0 +1,57 @@
+# global parameters
+project_name: 'role-play-demo-process'
+dataset_path: 'path_to_the_lianhualou_novel_json_file'
+np: 1 # number of subprocess to process your dataset
+
+export_path: 'path_to_output_jsonl_file'
+
+# process schedule
+process:
+# # chunk the novel if necessary
+# - text_chunk_mapper:
+# max_len: 8000
+# split_pattern: '\n\n'
+# overlap_len: 400
+# tokenizer: 'qwen2.5-72b-instruct'
+# trust_remote_code: True
+ # extract language_style, role_character and role_skill
+ - extract_entity_attribute_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ query_entities: ['李莲花']
+ query_attributes: ["角色性格", "角色武艺和能力", "语言风格"]
+ # extract nickname
+ - extract_nickname_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ # extract events
+ - extract_event_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames
+ # group all events
+ - naive_grouper:
+ # role experiences summary from events
+ - entity_attribute_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ attribute: '身份背景'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__role_background__'
+ word_limit: 50
+ - entity_attribute_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ attribute: '主要经历'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__role_experience__'
+ word_limit: 150
+ # most relevant roles summary from events
+ - most_relavant_entities_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ query_entity_type: '人物'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__important_relavant_roles__'
+ # generate the system prompt
+ - python_file_mapper:
+ file_path: 'path_to_system_prompt_generation_python_file'
+ function_name: 'get_system_prompt'
+
\ No newline at end of file
diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py
new file mode 100644
index 000000000..dc2738900
--- /dev/null
+++ b/demos/role_playing_system_prompt/system_prompt_generator.py
@@ -0,0 +1,192 @@
+import random
+
+from itertools import chain
+from loguru import logger
+from collections import Counter
+
+from data_juicer.ops.aggregator import NestedAggregator
+from data_juicer.ops.aggregator import EntityAttributeAggregator
+from data_juicer.ops.mapper import RelationIdentityMapper
+from data_juicer.utils.constant import Fields
+
+api_model = 'qwen2.5-72b-instruct'
+
+main_entity = "李莲花"
+query_attributes = ["语言风格", "角色性格", "角色武艺和能力"]
+system_prompt_key = '__dj__system_prompt__'
+example_num_limit = 5
+max_relavant_roles_num = 5
+
+role_info_template = "# {entity}\n## 身份背景\n{identity}\n## 人物经历\n{experience}"
+relation_identity_text_template = """
+{source_entity}的信息:
+{source_entity_info}
+{target_entity}的信息:
+{target_entity_info}
+{source_entity}对{target_entity}的称呼:{nicknames}
+"""
+
+nested_sum = NestedAggregator(
+ api_model=api_model,
+ try_num=3)
+
+def dedup_sort_val_by_chunk_id(sample, id_key, val_key):
+ chunk_ids = sample[id_key]
+ vals = sample[val_key]
+ # deduplicate by chunk id (keeping the last value), then restore order
+ id_to_val = {}
+ for chunk_id, val in zip(chunk_ids, vals):
+ id_to_val[chunk_id] = val
+ sorted_ids = sorted(id_to_val.keys())
+ sorted_vals = [id_to_val[chunk_id] for chunk_id in sorted_ids]
+ return list(chain(*sorted_vals))
+
+def get_attributes(sample):
+ main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.main_entities)
+ attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attributes)
+ attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_descriptions)
+ attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_support_texts)
+ attributes = {}
+ support_texts = {}
+ for attr in query_attributes:
+ attributes[attr] = []
+ support_texts[attr] = []
+ for entity, attr_name, attr_desc, sub_support_texts in \
+ zip(main_entities, attribute_names, attribute_descs, attribute_support_texts):
+ if entity == main_entity and attr_name in query_attributes:
+ attributes[attr_name].append(attr_desc)
+ support_texts[attr_name].append(sub_support_texts)
+ return attributes, support_texts
+
+def get_nicknames(sample):
+ nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.nickname)
+ nickname_map = {}
+ for nr in nicknames:
+ if nr[Fields.source_entity] == main_entity:
+ role_name = nr[Fields.target_entity]
+ if role_name not in nickname_map:
+ nickname_map[role_name] = []
+ nickname_map[role_name].append(nr[Fields.relation_description])
+
+ max_nums = 3
+ for role_name, nickname_list in nickname_map.items():
+ th = (len(nickname_list)+1) // 2
+ count = Counter(nickname_list)
+ sorted_items = sorted(count.items(), key=lambda x: x[1], reverse=True)
+ most_common_nicknames = []
+ idx = 0
+ while th > 0 and idx < min(len(sorted_items), max_nums):
+ most_common_nicknames.append(sorted_items[idx][0])
+ th -= sorted_items[idx][1]
+ idx += 1
+ nickname_map[role_name] = most_common_nicknames
+ return nickname_map
+
+
+def get_system_prompt(sample):
+
+ main_role_identity = sample['__dj__role_background__']
+ main_role_experience = sample['__dj__role_experience__']
+ attributes, support_texts = get_attributes(sample)
+ main_role_character = nested_sum.recursive_summary(attributes['角色性格'])
+ main_role_skill = nested_sum.recursive_summary(attributes['角色武艺和能力'])
+ main_role_lang_style = nested_sum.recursive_summary(attributes['语言风格'])
+ lang_style_examples = list(chain(*support_texts['语言风格']))
+ lang_style_example_num = min(example_num_limit, len(lang_style_examples))
+ lang_style_examples = random.sample(lang_style_examples, lang_style_example_num)
+
+ main_role_info = role_info_template.format(
+ entity=main_entity,
+ identity=main_role_identity,
+ experience=main_role_experience
+ )
+
+ nicknames = get_nicknames(sample)
+
+ relation_detail = ""
+ relavant_roles = sample['__dj__important_relavant_roles__']
+ for role_name in relavant_roles[:max_relavant_roles_num]:
+ if role_name == main_entity:
+ continue
+
+ # get sub role identity
+ op = EntityAttributeAggregator(
+ api_model=api_model,
+ entity=role_name,
+ attribute='身份背景',
+ input_key='__dj__event_description__',
+ output_key='__dj__role_background__',
+ word_limit=30
+ )
+ sample = op.process_single(sample)
+ role_identity = sample['__dj__role_background__'].replace('\n', '')
+
+ # get sub role experience
+ op = EntityAttributeAggregator(
+ api_model=api_model,
+ entity=role_name,
+ attribute='主要经历',
+ input_key='__dj__event_description__',
+ output_key='__dj__role_experience__',
+ word_limit=100
+ )
+ sample = op.process_single(sample)
+ role_experience = sample['__dj__role_experience__'].replace('\n', '')
+
+ # get relation identity with main role
+ role_info = role_info_template.format(
+ entity=role_name,
+ identity=role_identity,
+ experience=role_experience
+ )
+ op = RelationIdentityMapper(
+ api_model=api_model,
+ source_entity=main_entity,
+ target_entity=role_name,
+ output_key='__dj__relation_identity__'
+ )
+ if role_name in nicknames:
+ cur_nicknames = '、'.join(nicknames[role_name])
+ else:
+ cur_nicknames = role_name
+ text = relation_identity_text_template.format(
+ source_entity=main_entity,
+ source_entity_info=main_role_info,
+ target_entity=role_name,
+ target_entity_info=role_info,
+ nicknames=cur_nicknames
+ )
+ tmp_sample = {'text': text}
+ tmp_sample = op.process_single(tmp_sample)
+ relation = tmp_sample['__dj__relation_identity__']
+
+ relation_detail += f"\n{role_name} (称呼:{cur_nicknames})"
+ if relation:
+ relation_detail += f"{main_entity}的{relation}。"
+ relation_detail += f"{role_identity}{role_experience}".replace('\n', '')
+
+ full_system_prompt = f"""扮演{main_entity}与用户进行对话。\n"""
+ full_system_prompt += """# 角色身份\n"""
+ full_system_prompt += main_role_identity.replace('\n', '')
+ full_system_prompt += """\n# 角色经历\n"""
+ full_system_prompt += main_role_experience.replace('\n', '')
+ full_system_prompt += """\n# 角色性格\n"""
+ full_system_prompt += main_role_character.replace('\n', '')
+ full_system_prompt += """\n# 角色能力\n"""
+ full_system_prompt += main_role_skill.replace('\n', '')
+
+ full_system_prompt += """\n# 人际关系"""
+ full_system_prompt += relation_detail
+
+ full_system_prompt += """\n# 语言风格\n"""
+ full_system_prompt += main_role_lang_style.replace('\n', '')
+ full_system_prompt += f"""\n供参考语言风格的部分{main_entity}台词:\n"""
+ full_system_prompt += "\n````\n"
+ full_system_prompt += '\n'.join(lang_style_examples)
+ full_system_prompt += "\n````\n"
+
+ logger.info(full_system_prompt)
+
+ sample[system_prompt_key] = full_system_prompt
+
+ return sample
\ No newline at end of file
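
As an aside, the recipe above wires this file in through `python_file_mapper`; a minimal sketch of loading it by hand is below. The `PythonFileMapper` class name is assumed from the operator's registered name, and the paths are illustrative:

```python
# illustrative sketch; assumes PythonFileMapper backs 'python_file_mapper'
from data_juicer.ops.mapper.python_file_mapper import PythonFileMapper

op = PythonFileMapper(
    file_path='demos/role_playing_system_prompt/system_prompt_generator.py',
    function_name='get_system_prompt',
)
# each grouped sample is then passed through get_system_prompt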
diff --git a/docs/Operators.md b/docs/Operators.md
index 04a2da380..fe3c6d94d 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,10 +11,12 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
-| [ Mapper ]( #mapper ) | 61 | Edits and transforms samples |
+| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples |
| [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
+| [ Grouper ]( #grouper ) | 2 | Groups samples into batched samples |
+| [ Aggregator ]( #aggregator ) | 3 | Aggregates batched samples, e.g., into summaries or conclusions |
All the specific operators are listed below, each featured with several capability tags.
@@ -72,6 +74,7 @@ All the specific operators are listed below, each featured with several capabili
| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) |
| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) |
| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) |
+| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract the supporting sub-text for a given summary. | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) |
| fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) |
| generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) |
| generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) |
@@ -90,6 +93,7 @@ All the specific operators are listed below, each featured with several capabili
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify the relation between two entities in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
| remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) |
@@ -190,6 +194,21 @@ All the specific operators are listed below, each featured with several capabili
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
+## Grouper
+
+| Operator | Tags | Description | Source code | Unit tests |
+|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples into batched samples according to the values of the given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
+| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples into one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+
+## Aggregator
+
+| Operator | Tags | Description | Source code | Unit tests |
+|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Summarize the given attribute of the given entity from a set of documents. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
+| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Aggregate contents recursively in nested groups of a given size, so as to respect input-length limits. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
+
## Contributing
We welcome contributions of adding new operators. Please refer to [How-to Guide for Developers](DeveloperGuide.md).
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 011ff5a64..61610a873 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -11,10 +11,12 @@ Data-Juicer 中的算子分为以下 5 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 |
-| [ Mapper ]( #mapper ) | 61 | 对数据样本进行编辑和转换 |
+| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 44 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
+| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 |
+| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 |
下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。
@@ -71,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) |
| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) |
| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) |
+| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 为一段总结抽取对应原文 | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) |
| fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) |
| generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) |
| generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) |
@@ -89,6 +92,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
| remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) |
@@ -189,5 +193,20 @@ Data-Juicer 中的算子分为以下 5 种类型。
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
+## Grouper
+
+| 算子 | 标签 | 描述 | 源码 | 单测样例 |
+|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
+| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+
+## Aggregator
+
+| 算子 | 标签 | 描述 | 源码 | 单测样例 |
+|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
+| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
+
## 贡献
我们欢迎社区贡献新的算子,具体请参考[开发者指南](DeveloperGuide_ZH.md)。
diff --git a/environments/science_requires.txt b/environments/science_requires.txt
index 10ea3b86e..af5d6b362 100644
--- a/environments/science_requires.txt
+++ b/environments/science_requires.txt
@@ -26,3 +26,5 @@ ffmpeg-python
opencv-python
vllm>=0.1.3
rouge
+dashscope
+openai
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index 1cb7c4463..2502fe1e1 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -54,6 +54,7 @@ def test_yaml_cfg_file(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
@@ -75,6 +76,7 @@ def test_yaml_cfg_file(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
}, 'nested dict load fail, un-expected internal value')
@@ -144,6 +146,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -165,6 +168,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -186,6 +190,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -207,6 +212,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -228,6 +234,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
diff --git a/tests/ops/aggregator/__init__.py b/tests/ops/aggregator/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ops/aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py
new file mode 100644
index 000000000..1f80da3a3
--- /dev/null
+++ b/tests/ops/aggregator/test_entity_attribute_aggregator.py
@@ -0,0 +1,139 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import EntityAttributeAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class EntityAttributeAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+ # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='主要经历'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'sub_docs': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ input_key='sub_docs',
+ output_key='text'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ max_token_num=200
+ )
+ self._run_helper(op, samples)
+
+ def test_word_limit_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ word_limit=20
+ )
+ self._run_helper(op, samples)
+
+
+ def test_example_prompt(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ example_prompt = (
+ '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n'
+ '`孙悟空`的`另外身份`总结:\n'
+ '# 孙悟空\n'
+ '## 另外身份\n'
+ '孙行者、齐天大圣、美猴王\n'
+ )
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='另外身份',
+ example_prompt=example_prompt,
+ word_limit=20
+ )
+ self._run_helper(op, samples)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py
new file mode 100644
index 000000000..1d8678134
--- /dev/null
+++ b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py
@@ -0,0 +1,93 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+ # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'dj_result':{
+ 'events': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ }
+ },
+ ]
+
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物',
+ input_key='dj_result.events',
+ output_key='dj_result.relavant_roles'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
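+ # a small max_token_num should force the input docs to be truncated to fit the
+ # token budget; the helper only checks that one aggregated sample still comes back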
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物',
+ max_token_num=40
+ )
+ self._run_helper(op, samples)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py
new file mode 100644
index 000000000..6347652bc
--- /dev/null
+++ b/tests/ops/Aggregator/test_nested_aggregator.py
@@ -0,0 +1,119 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import NestedAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class NestedAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+ # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'sub_docs': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
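+ # read the sub-documents from 'sub_docs' and store the aggregated text in 'text'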
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ input_key='sub_docs',
+ output_key='text'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_1(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
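+ # the three max_token_num variants (2, 90, 200) exercise the token budget handling,
+ # from an extreme lower bound up to a budget large enough for several input docs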
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=2
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_2(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=90
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_3(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=200
+ )
+ self._run_helper(op, samples)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/grouper/__init__.py b/tests/ops/grouper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ops/grouper/test_key_value_grouper.py b/tests/ops/grouper/test_key_value_grouper.py
new file mode 100644
index 000000000..1ac186423
--- /dev/null
+++ b/tests/ops/grouper/test_key_value_grouper.py
@@ -0,0 +1,54 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.grouper.key_value_grouper import KeyValueGrouper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class KeyValueGrouperTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples, target):
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for batched_sample in new_dataset:
+ lang = batched_sample['meta'][0]['language']
+ self.assertEqual(batched_sample['text'], target[lang])
+
+ def test_key_value_grouper(self):
+
+ source = [
+ {
+ 'text': "Today is Sunday and it's a happy day!",
+ 'meta': {
+ 'language': 'en'
+ }
+ },
+ {
+ 'text': "Welcome to Alibaba.",
+ 'meta': {
+ 'language': 'en'
+ }
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!',
+ 'meta': {
+ 'language': 'zh'
+ }
+ },
+ ]
+ target = {
+ 'en':[
+ "Today is Sunday and it's a happy day!",
+ "Welcome to Alibaba."
+ ],
+ 'zh':[
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+
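+ # group by the nested 'meta.language' value: the two English samples are batched
+ # together and the Chinese sample forms its own group, as in target above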
+ op = KeyValueGrouper(['meta.language'])
+ self._run_helper(op, source, target)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/grouper/test_naive_grouper.py b/tests/ops/grouper/test_naive_grouper.py
new file mode 100644
index 000000000..4e69a8ba2
--- /dev/null
+++ b/tests/ops/grouper/test_naive_grouper.py
@@ -0,0 +1,47 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.grouper.naive_grouper import NaiveGrouper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class NaiveGrouperTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples, target):
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for d, t in zip(new_dataset, target):
+ self.assertEqual(d['text'], t['text'])
+
+ def test_naive_group(self):
+
+ source = [
+ {
+ 'text': "Today is Sunday and it's a happy day!"
+ },
+ {
+ 'text':
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.'
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!'
+ },
+ ]
+ target = [
+ {
+ 'text':[
+ "Today is Sunday and it's a happy day!",
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.',
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+ ]
+
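+ # NaiveGrouper batches the whole dataset into a single grouped sample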
+ op = NaiveGrouper()
+ self._run_helper(op, source, target)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
index 96f186d29..f15b4ca3f 100644
--- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py
+++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
@@ -21,9 +21,9 @@ def _run_op(self, api_model, response_path=None):
query_attributes = ["语言风格", "角色性格"]
op = ExtractEntityAttributeMapper(
+ api_model=api_model,
query_entities=query_entities,
- query_attributes=query_attributes,
- api_model=api_model,
+ query_attributes=query_attributes,
response_path=response_path)
raw_text = """△笛飞声独自坐在莲花楼屋顶上。李莲花边走边悠闲地给马喂草。方多病则走在一侧,却总不时带着怀疑地盯向楼顶的笛飞声。
@@ -49,9 +49,14 @@ def _run_op(self, api_model, response_path=None):
dataset = Dataset.from_list(samples)
dataset = dataset.map(op.process, batch_size=1)
for sample in dataset:
- logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}')
- self.assertNotEqual(sample[Fields.attribute_description], '')
- self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0)
+ ents = sample[Fields.main_entities]
+ attrs = sample[Fields.attributes]
+ descs = sample[Fields.attribute_descriptions]
+ sups = sample[Fields.attribute_support_texts]
+ for ent, attr, desc, sup in zip(ents, attrs, descs, sups):
+ logger.info(f'{ent} {attr}: {desc}')
+ self.assertNotEqual(desc, '')
+ self.assertNotEqual(len(sup), 0)
def test(self):
# before runing this test, set below environment variables:
diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py
index 1652c8db2..aba40d73e 100644
--- a/tests/ops/mapper/test_extract_event_mapper.py
+++ b/tests/ops/mapper/test_extract_event_mapper.py
@@ -18,7 +18,8 @@ class ExtractEventMapperTest(DataJuicerTestCaseBase):
def _run_op(self, api_model, response_path=None):
op = ExtractEventMapper(api_model=api_model,
- response_path=response_path)
+ response_path=response_path,
+ index_key='chunk_id')
raw_text = """△芩婆走到中间,看着众人。
芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。
@@ -57,9 +58,11 @@ def _run_op(self, api_model, response_path=None):
}]
dataset = Dataset.from_list(samples)
- dataset = dataset.map(op.process, batch_size=2)
+ dataset = op.run(dataset)
self.assertNotEqual(len(dataset), 0)
for sample in dataset:
+ logger.info(f"chunk_id: {sample['chunk_id']}")
+ self.assertEqual(sample['chunk_id'], 0)
logger.info(f"event: {sample[Fields.event_description]}")
self.assertNotEqual(sample[Fields.event_description], '')
logger.info(f"characters: {sample[Fields.relevant_characters]}")
diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py
new file mode 100644
index 000000000..0445d2526
--- /dev/null
+++ b/tests/ops/mapper/test_extract_support_text_mapper.py
@@ -0,0 +1,80 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.extract_support_text_mapper import ExtractSupportTextMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP in GitHub Actions due to an unknown DistNetworkError.
+# These tests have been verified locally.
+@SKIPPED_TESTS.register_module()
+class ExtractSupportTextMapperTest(DataJuicerTestCaseBase):
+
+
+ def _run_op(self, api_model):
+
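+ # both keys use dot notation to address fields inside the nested 'data' dict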
+ summary_key = 'data.event'
+ support_text_key = 'data.support_text'
+ op = ExtractSupportTextMapper(api_model=api_model,
+ summary_key=summary_key,
+ support_text_key=support_text_key)
+
+ raw_text = """△芩婆走到中间,看着众人。
+芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。
+封磬震惊:二子?不是只有一个儿子吗?
+芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。
+李莲花似回忆起了什么:李相显......
+芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐......
+闪回/
+李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵......
+△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。
+△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去!
+△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。
+乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。
+乞丐甲野蛮地抢:给我拿来!
+小单孤刀:放开他!
+△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了!
+△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地!
+/闪回结束
+△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄!
+芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少?
+△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。
+芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。
+△单孤刀呆住。
+芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。
+△李莲花得知真相,闭目叹息。
+△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。
+封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了!
+△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上......
+△封磬颓然地跪倒下来。
+△李莲花对眼前的一切有些意外、无措。
+笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。
+△笛飞声不禁冷笑一下。
+"""
+ event = "李相显托付单孤刀。"
+ samples = [{
+ 'text': raw_text,
+ 'data':{
+ 'event': event
+ }
+ }]
+
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ sample = dataset[0]
+ logger.info(f"support_text: \n{nested_access(sample, support_text_key)}")
+
+ def test(self):
+ # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+ self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py
new file mode 100644
index 000000000..d730cb79f
--- /dev/null
+++ b/tests/ops/mapper/test_relation_identity_mapper.py
@@ -0,0 +1,58 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.relation_identity_mapper import RelationIdentityMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields
+
+# Skip tests for this OP in GitHub Actions due to an unknown DistNetworkError.
+# These tests have been verified locally.
+@SKIPPED_TESTS.register_module()
+class RelationIdentityMapperTest(DataJuicerTestCaseBase):
+
+
+ def _run_op(self, api_model, response_path=None):
+
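+ # identify the relation between '李莲花' (source) and '方多病' (target) from the text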
+ op = RelationIdentityMapper(api_model=api_model,
+ source_entity="李莲花",
+ target_entity="方多病",
+ response_path=response_path)
+
+ raw_text = """李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。
+在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。
+李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。
+随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。
+李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。
+在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。
+在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。
+在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。
+最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。
+一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。
+方多病 (称呼:方小宝、方大少爷)百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。
+他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。
+他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。
+最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。
+"""
+ samples = [{
+ 'text': raw_text,
+ }]
+
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ for data in dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ def test(self):
+ # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+ self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py
index 8004d9ede..0c0a70db3 100644
--- a/tests/ops/mapper/test_text_chunk_mapper.py
+++ b/tests/ops/mapper/test_text_chunk_mapper.py
@@ -2,9 +2,10 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper
-from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+@SKIPPED_TESTS.register_module()
class TextChunkMapperTest(DataJuicerTestCaseBase):
def _run_helper(self, op, samples, target):