diff --git a/README.md b/README.md
index 518e54713..586869b0a 100644
--- a/README.md
+++ b/README.md
@@ -333,9 +333,16 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml
# use command line tool
dj-analyze --config configs/demo/analyzer.yaml
+
+# you can also use auto mode to avoid writing a recipe. It will analyze a small
+# part (e.g. 1000 samples, specified by argument `auto_num`) of your dataset
+# with all Filters that produce stats.
+dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
```
-- **Note:** Analyzer only compute stats of Filter ops. So extra Mapper or Deduplicator ops will be ignored in the analysis process.
+- **Note:** The Analyzer only computes stats for Filters that produce stats and for other OPs that produce tags/categories in the meta field. All other OPs are ignored during analysis. We use the following registries to decorate such OPs, as sketched in the example below:
+ - `NON_STATS_FILTERS`: decorate Filters that **DO NOT** produce any stats.
+  - `TAGGING_OPS`: decorate OPs that **DO** produce tags/categories in the meta field.
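+
+  For example, a new OP can be attached to these registries in the same way it is registered to `OPERATORS` (a minimal sketch; the decorator-style call below is an assumption based on the `register_module` interface used elsewhere in the codebase):
+
+  ```python
+  # hypothetical example: mark a custom filter that produces no stats
+  from data_juicer.ops.base_op import NON_STATS_FILTERS, OPERATORS, Filter
+
+  @NON_STATS_FILTERS.register_module('my_no_stats_filter')
+  @OPERATORS.register_module('my_no_stats_filter')
+  class MyNoStatsFilter(Filter):
+      ...
+  ```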
### Data Visualization
diff --git a/README_ZH.md b/README_ZH.md
index 366fcb004..42612964a 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -310,9 +310,15 @@ python tools/analyze_data.py --config configs/demo/analyzer.yaml
# 使用命令行工具
dj-analyze --config configs/demo/analyzer.yaml
+
+# 你也可以使用"自动"模式来避免写一个新的数据菜谱。它会使用全部可产出统计信息的 Filter 来分析
+# 你的数据集的一小部分(如1000条样本,可通过 `auto_num` 参数指定)
+dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000]
```
-* **注意**:Analyzer 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。
+* **注意**:Analyzer 只会使用能在 stats 字段里产出统计信息的 Filter 算子,以及能在 meta 字段里产出 tags 或类别标签的其他算子进行分析,其余算子会在分析过程中被忽略。我们使用以下两种注册器来装饰相关的算子:
+ * `NON_STATS_FILTERS`:装饰那些**不能**产出任何统计信息的 Filter 算子。
+ * `TAGGING_OPS`:装饰那些能在 meta 字段中产出 tags 或类别标签的算子。
### 数据可视化
diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 1003c89af..82cd6824e 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -79,9 +79,9 @@ process:
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
+ api_model: 'gpt-4o' # API model name.
query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried.
query_attributes: ["人物性格"] # Attribute list to be queried.
- api_model: 'gpt-4o' # API model name.
entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction.
entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted.
attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description.
@@ -153,6 +153,18 @@ process:
drop_text: false # If drop the text in the output.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - extract_support_text_mapper: # extract support sub text for a summary.
+ api_model: 'gpt-4o' # API model name.
+ summary_key: '__dj__event_description__' # The field name to store the input summary. Support for nested keys such as "__dj__stats__.text_len".
+ support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # System prompt for the task.
+ input_template: null # Template for building the model input.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ drop_text: false # If drop the text in the output.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- fix_unicode_mapper: # fix unicode errors in text.
- generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples.
hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs.
@@ -259,12 +271,27 @@ process:
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- - python_python_mapper: # executing Python lambda function defined in a file.
+  - python_file_mapper: # executing Python function defined in a file.
file_path: '' # The path to the Python file containing the function to be executed.
function_name: 'process_single' # The name of the function defined in the file to be executed.
- python_lambda_mapper: # executing Python lambda function on data samples.
lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
batched: False # A boolean indicating whether to process input data in batches.
+  - relation_identity_mapper: # identify the relation between two entities in the text.
+ api_model: 'gpt-4o' # API model name.
+    source_entity: '孙悟空' # The source entity of the relation to be identified.
+ target_entity: '猪八戒' # The target entity of the relation to be identified.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is input_key in default.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+    system_prompt_template: null # System prompt template for the task. It should contain placeholders for entity1 and entity2.
+ input_template: null # Template for building the model input.
+ output_pattern_template: null # Regular expression template for parsing model output.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ drop_text: false # If drop the text in the output.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
doc_type: tex # comment type you want to remove. Only support 'tex' for now.
@@ -567,7 +594,7 @@ process:
vertical_flip: false # flip frame image vertically (top to bottom).
reduce_mode: avg # reduce mode when one text corresponds to multiple videos in a chunk, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all videos meet the filter condition
- mem_required: '1GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched
+    mem_required: '1500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- video_motion_score_filter: # Keep samples with video motion scores within a specific range.
min_score: 0.25 # the minimum motion score to keep samples
max_score: 10000.0 # the maximum motion score to keep samples
@@ -693,3 +720,55 @@ process:
top_ratio: # ratio of selected top samples
topk: # number of selected top sample
reverse: True # determine the sorting rule, if reverse=True, then sort in descending order
+
+# Grouper ops.
+ - naive_grouper: # Group all samples to one batched sample.
+  - key_value_grouper: # Group samples to batched samples according to values in given keys.
+    group_by_keys: null # Group samples according to values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default.
+
+# Aggregator ops.
+  - entity_attribute_aggregator: # Return the conclusion of the given entity's attribute from some docs.
+ api_model: 'gpt-4o' # API model name.
+ entity: '孙悟空' # The given entity.
+ attribute: '人物经历' # The given attribute.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default.
+    word_limit: 100 # The expected output length in words, used in the prompt.
+ max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+    system_prompt_template: null # System prompt template for the task. It should contain placeholders for the given entity and attribute.
+ example_prompt: null # The example part in the system prompt.
+ input_template: null # The input template.
+    output_pattern_template: null # Regular expression template for parsing the model output.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+ - most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance.
+ api_model: 'gpt-4o' # API model name.
+ entity: '孙悟空' # The given entity.
+    query_entity_type: '人物' # The type of the queried relevant entities.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default.
+ max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+    system_prompt_template: null # System prompt template for the task. It should contain placeholders for the given entity and entity_type.
+ input_template: null # The input template.
+ output_pattern: null # The output pattern.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
+  - nested_aggregator: # Aggregate contents in a nested manner, a given number of samples at a time, to respect the input length limitation.
+ api_model: 'gpt-4o' # API model name.
+ input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default.
+ output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default.
+ max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None.
+ api_endpoint: null # URL endpoint for the API.
+ response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
+ system_prompt: null # The system prompt.
+ sub_doc_template: null # The template for input text in each sample.
+ input_template: null # The input template.
+ try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
+ model_params: {} # Parameters for initializing the API model.
+ sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py
index 91ce93bae..7b7173c37 100644
--- a/data_juicer/__init__.py
+++ b/data_juicer/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '1.0.1'
+__version__ = '1.0.2'
import os
import subprocess
diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py
index 775b42683..ce5b3617d 100644
--- a/data_juicer/analysis/column_wise_analysis.py
+++ b/data_juicer/analysis/column_wise_analysis.py
@@ -4,8 +4,9 @@
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
+from wordcloud import WordCloud
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
from .overall_analysis import OverallAnalysis
@@ -69,6 +70,12 @@ def __init__(self,
stats into one image file
"""
self.stats = pd.DataFrame(dataset[Fields.stats])
+ self.meta = pd.DataFrame(dataset[Fields.meta])
+ # remove non-tag columns
+ meta_columns = self.meta.columns
+ for col_name in meta_columns:
+ if not col_name.startswith(DEFAULT_PREFIX):
+ self.meta = self.meta.drop(col_name, axis=1)
self.output_path = output_path
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
@@ -100,8 +107,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
width_unit = 4
height_unit = 6
- columns = self.stats.columns
- num = len(columns)
+ stats_and_meta = pd.concat([self.stats, self.meta], axis=1)
+ all_columns = stats_and_meta.columns
+ num = len(all_columns)
# get the recommended "best" number of columns and rows
rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol)
@@ -114,9 +122,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
fig = plt.figure(figsize=(rec_width, rec_height),
layout='constrained')
subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01)
- for i, column_name in enumerate(tqdm(columns.to_list(),
- desc='Column')):
- data = self.stats[column_name]
+ for i, column_name in enumerate(
+ tqdm(all_columns.to_list(), desc='Column')):
+ data = stats_and_meta[column_name]
# explode data to flatten inner list
data = data.explode().infer_objects()
grid = grid_indexes[i]
@@ -145,33 +153,39 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False):
else:
axes = [None] * num_subcol
- # draw histogram
- self.draw_hist(axes[0],
- data,
- os.path.join(self.output_path,
- f'{column_name}-hist.png'),
- percentiles=percentiles)
-
- # draw box
- self.draw_box(axes[1],
- data,
- os.path.join(self.output_path,
- f'{column_name}-box.png'),
- percentiles=percentiles)
+ if not skip_export:
+ # draw histogram
+ self.draw_hist(axes[0],
+ data,
+ os.path.join(self.output_path,
+ f'{column_name}-hist.png'),
+ percentiles=percentiles)
+
+ # draw box
+ self.draw_box(axes[1],
+ data,
+ os.path.join(self.output_path,
+ f'{column_name}-box.png'),
+ percentiles=percentiles)
else:
# object (string) or string list -- only draw histogram for
# this stat
if self.save_stats_in_one_file:
- axes = subfig.subplots(1, 1)
+ axes = subfig.subplots(1, num_subcol)
else:
- axes = None
+ axes = [None] * num_subcol
if not skip_export:
self.draw_hist(
- axes, data,
+ axes[0], data,
os.path.join(self.output_path,
f'{column_name}-hist.png'))
+ self.draw_wordcloud(
+ axes[1], data,
+ os.path.join(self.output_path,
+ f'{column_name}-wordcloud.png'))
+
# add a title to the figure of this stat
if self.save_stats_in_one_file:
subfig.suptitle(f'{data.name}',
@@ -203,10 +217,7 @@ def draw_hist(self, ax, data, save_path, percentiles=None, show=False):
"""
# recommended number of bins
data_num = len(data)
- if data_num >= 100:
- rec_bins = int(math.sqrt(len(data)))
- else:
- rec_bins = None
+ rec_bins = max(int(math.sqrt(data_num)), 10)
# if ax is None, using plot method in pandas
if ax is None:
@@ -297,3 +308,33 @@ def draw_box(self, ax, data, save_path, percentiles=None, show=False):
# accumulated overlapped figures in different draw_xxx function
# calling
ax.clear()
+
+ def draw_wordcloud(self, ax, data, save_path, show=False):
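+        """
+        Draw a wordcloud graph for string-type data (e.g. tags/categories).
+
+        :param ax: the axes to draw into
+        :param data: data to draw
+        :param save_path: the path to save the wordcloud graph
+        :param show: whether to show the figure in a window. It's False by
+            default.
+        """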
+ word_list = data.tolist()
+ word_nums = {}
+ for w in word_list:
+ if w in word_nums:
+ word_nums[w] += 1
+ else:
+ word_nums[w] = 1
+
+ wc = WordCloud(width=400, height=320)
+ wc.generate_from_frequencies(word_nums)
+
+ if ax is None:
+ ax = plt.figure(figsize=(20, 16))
+ else:
+ ax.imshow(wc, interpolation='bilinear')
+ ax.axis('off')
+
+ if not self.save_stats_in_one_file:
+ # save into file
+ wc.to_file(save_path)
+
+ if show:
+ plt.show()
+ else:
+ # if no showing, we need to clear this axes to avoid
+ # accumulated overlapped figures in different draw_xxx function
+ # calling
+ ax.clear()
diff --git a/data_juicer/analysis/measure.py b/data_juicer/analysis/measure.py
index fe54cdabd..bd97e811c 100644
--- a/data_juicer/analysis/measure.py
+++ b/data_juicer/analysis/measure.py
@@ -1,9 +1,13 @@
+import numpy as np
+
from data_juicer.utils.lazy_loader import LazyLoader
torch = LazyLoader('torch', 'torch')
td = LazyLoader('td', 'torch.distributions')
F = LazyLoader('F', 'torch.nn.functional')
+stats = LazyLoader('stats', 'scipy.stats')
+
class Measure(object):
"""Base class for Measure distribution.
@@ -48,6 +52,15 @@ def _convert_to_categorical(self, p):
else:
return td.Categorical(torch.tensor(p))
+ def _convert_to_ndarray(self, p):
+ """
+        Convert input data to a numpy ndarray.
+ :param p: input data, now support
+ [`scalar`,`list`, `tuple`, `torch binary file`, and `Categorical`].
+        :return: numpy ndarray
+ """
+ return self._convert_to_tensor(p).numpy()
+
class KLDivMeasure(Measure):
"""
@@ -108,3 +121,101 @@ class EntropyMeasure(Measure):
def measure(self, p):
p = self._convert_to_categorical(p)
return p.entropy()
+
+
+class RelatedTTestMeasure(Measure):
+ """
+    Measure the T-Test for two related distributions on their histograms with
+    the same bins.
+
+ Ref:
+ https://en.wikipedia.org/wiki/Student%27s_t-test
+
+ For continuous features or distributions, the input could be dataset stats
+ list.
+ For discrete features or distributions, the input could be the tags or the
+ categories list.
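+
+    A minimal usage sketch (the call pattern mirrors how the Adapter invokes
+    this measure during insight mining; the sample values are illustrative
+    only):
+
+        measure = RelatedTTestMeasure()
+        # continuous stats before and after an OP
+        res = measure([1.2, 3.4, 2.2, 5.1], [1.0, 3.1, 2.0, 4.8])
+        print(res.statistic, res.pvalue)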
+ """
+ name = 't-test'
+
+ @staticmethod
+ def stats_to_hist(p, q):
+ p = np.array(p)
+ q = np.array(q)
+
+ # get common maximum number of data samples, and max/min values
+ max_data_num = max(len(p), len(q))
+ min_val = min(min(p), min(q))
+ max_val = max(max(p), max(q))
+
+ # get a recommended number of bins
+ rec_bins = max(int(np.sqrt(max_data_num)), 10)
+
+ # get the common bin edges
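+        # append the global min/max to p so that its histogram spans the value
+        # range shared with q; the two extra counts are removed right below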
+ common_p = np.append(p, [min_val, max_val])
+ hist_p, bin_edges = np.histogram(common_p, bins=rec_bins)
+ # restore the hist of the original p
+ hist_p[0] -= 1
+ hist_p[-1] -= 1
+ # get the hist of the original q using the common bin edges
+ hist_q, _ = np.histogram(q, bins=bin_edges)
+ return hist_p, hist_q, bin_edges
+
+ @staticmethod
+ def category_to_hist(p, q):
+
+ def flatten_list(lst):
+ res = []
+ for s in lst:
+ if isinstance(s, list):
+ res.extend(flatten_list(s))
+ else:
+ res.append(s)
+ return res
+
+ # flatten the list
+ p = flatten_list(p)
+ q = flatten_list(q)
+
+ # get the common categories
+ cat_p = set(p)
+ cat_q = set(q)
+ cat_common = cat_p.union(cat_q)
+
+ # get category distributions
+ count_p = {cat: 0 for cat in cat_common}
+ count_q = {cat: 0 for cat in cat_common}
+ for cat in p:
+ count_p[cat] += 1
+ for cat in q:
+ count_q[cat] += 1
+
+ # only keep distribution values sorted by counts
+ sorted_cat = list(count_p.items())
+ sorted_cat.sort(key=lambda it: it[1], reverse=True)
+ sorted_cat = [it[0] for it in sorted_cat]
+ # get the value dist
+ hist_p = [count_p[cat] for cat in sorted_cat]
+ hist_q = [count_q[cat] for cat in sorted_cat]
+
+ return hist_p, hist_q, count_p, count_q, sorted_cat
+
+ def measure(self, p, q):
+ """
+ :param p: the first feature or distribution. (stats/tags/categories)
+ :param q: the second feature or distribution. (stats/tags/categories)
+ :return: the T-Test results object -- ([ref](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats._result_classes.TtestResult.html#scipy.stats._result_classes.TtestResult)) # noqa: E501
+ """
+ ele = p[0]
+ while isinstance(ele, list):
+ ele = ele[0]
+ if isinstance(ele, str):
+ # discrete tags or categories
+ hist_p, hist_q = self.category_to_hist(p, q)[:2]
+ else:
+ # continuous stats
+ hist_p, hist_q = self.stats_to_hist(p, q)[:2]
+
+ # compute the t-test and pval for hist_p and hist_q
+ ttest_res = stats.ttest_rel(hist_p, hist_q)
+ return ttest_res
diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py
index 04eefb178..696b25946 100644
--- a/data_juicer/analysis/overall_analysis.py
+++ b/data_juicer/analysis/overall_analysis.py
@@ -5,7 +5,7 @@
from loguru import logger
from tqdm import tqdm
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
def _single_column_analysis(col, *args, **kwargs):
@@ -25,6 +25,12 @@ def __init__(self, dataset, output_path):
:param output_path: path to store the analysis results.
"""
self.stats = pd.DataFrame(dataset[Fields.stats])
+ self.meta = pd.DataFrame(dataset[Fields.meta])
+ # remove non-tag columns
+ meta_columns = self.meta.columns
+ for col_name in meta_columns:
+ if not col_name.startswith(DEFAULT_PREFIX):
+ self.meta = self.meta.drop(col_name, axis=1)
self.output_path = output_path
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
@@ -71,10 +77,14 @@ def analyze(self, percentiles=[], num_proc=1, skip_export=False):
# merge default and customized percentiles and get overall information
percentiles = list(set(percentiles + self.default_percentiles))
+ # merge stats and meta
+ stats_and_meta = pd.concat([self.stats, self.meta], axis=1)
+ all_columns = stats_and_meta.columns
+
results = []
pool = Pool(num_proc)
- for col_name in self.stats.columns:
- this_col = self.refine_single_column(self.stats[col_name])
+ for col_name in all_columns:
+ this_col = self.refine_single_column(stats_and_meta[col_name])
res = pool.apply_async(_single_column_analysis,
kwds={
'col': this_col,
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index aba3dd31c..0585ac8c4 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -23,7 +23,7 @@
global_parser = None
-def init_configs(args: Optional[List[str]] = None):
+def init_configs(args: Optional[List[str]] = None, which_entry: object = None):
"""
initialize the jsonargparse parser and parse configs from one of:
1. POSIX-style commands line args;
@@ -32,14 +32,29 @@ def init_configs(args: Optional[List[str]] = None):
4. hard-coded defaults
:param args: list of params, e.g., ['--conifg', 'cfg.yaml'], defaut None.
+ :param which_entry: which entry to init configs (executor/analyzer)
:return: a global cfg object used by the Executor or Analyzer
"""
parser = ArgumentParser(default_env=True, default_config_files=None)
- parser.add_argument('--config',
- action=ActionConfigFile,
- help='Path to a dj basic configuration file.',
- required=True)
+ # required but mutually exclusive args group
+ required_group = parser.add_mutually_exclusive_group(required=True)
+ required_group.add_argument('--config',
+ action=ActionConfigFile,
+ help='Path to a dj basic configuration file.')
+ required_group.add_argument('--auto',
+ action='store_true',
+                                help='Whether to use an auto analyzing '
+ 'strategy instead of a specific data '
+ 'recipe. If a specific config file is '
+ 'given by --config arg, this arg is '
+ 'disabled. Only available for Analyzer.')
+
+ parser.add_argument('--auto_num',
+ type=PositiveInt,
+ default=1000,
+ help='The number of samples to be analyzed '
+ 'automatically. It\'s 1000 in default.')
parser.add_argument(
'--hpo_config',
@@ -97,7 +112,7 @@ def init_configs(args: Optional[List[str]] = None):
parser.add_argument(
'--export_path',
type=str,
- default='./outputs/hello_world.jsonl',
+ default='./outputs/hello_world/hello_world.jsonl',
help='Path to export and save the output processed dataset. The '
'directory to store the processed dataset will be the work '
'directory of this process.')
@@ -275,6 +290,22 @@ def init_configs(args: Optional[List[str]] = None):
help='Number of samples extracted by tracer to show the dataset '
'difference before and after a op. Only available when '
'open_tracer is true.')
+ parser.add_argument(
+ '--open_insight_mining',
+ type=bool,
+ default=False,
+ help='Whether to open insight mining to trace the OP-wise stats/tags '
+             'changes during processing. It might take more time when '
+             'insight mining is enabled.')
+ parser.add_argument(
+ '--op_list_to_mine',
+ type=List[str],
+ default=[],
+ help='Which OPs will be applied on the dataset to mine the insights '
+ 'in their stats changes. Only those OPs that produce stats or '
+ 'meta are valid. If it\'s empty, all OPs that produce stats and '
+             'meta will be involved. Only available when open_insight_mining '
+ 'is true.')
parser.add_argument(
'--op_fusion',
type=bool,
@@ -339,6 +370,14 @@ def init_configs(args: Optional[List[str]] = None):
try:
cfg = parser.parse_args(args=args)
+
+ # check the entry
+ from data_juicer.core.analyzer import Analyzer
+ if not isinstance(which_entry, Analyzer) and cfg.auto:
+ err_msg = '--auto argument can only be used for analyzer!'
+ logger.error(err_msg)
+ raise NotImplementedError(err_msg)
+
cfg = init_setup_from_cfg(cfg)
cfg = update_op_process(cfg, parser)
@@ -493,6 +532,10 @@ def init_setup_from_cfg(cfg: Namespace):
SpecialTokens.image = cfg.image_special_token
SpecialTokens.eoc = cfg.eoc_special_token
+ # add all filters that produce stats
+ if cfg.auto:
+ cfg.process = load_ops_with_stats_meta()
+
# Apply text_key modification during initializing configs
# users can freely specify text_key for different ops using `text_key`
# otherwise, set arg text_key of each op to text_keys
@@ -500,34 +543,48 @@ def init_setup_from_cfg(cfg: Namespace):
text_key = cfg.text_keys[0]
else:
text_key = cfg.text_keys
- for op in cfg.process:
+ op_attrs = {
+ 'text_key': text_key,
+ 'image_key': cfg.image_key,
+ 'audio_key': cfg.audio_key,
+ 'video_key': cfg.video_key,
+ 'num_proc': cfg.np,
+ 'turbo': cfg.turbo,
+ }
+ cfg.process = update_op_attr(cfg.process, op_attrs)
+
+ return cfg
+
+
+def load_ops_with_stats_meta():
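+    """
+    Build an OP list covering all Filters that produce stats and all OPs that
+    produce tags/categories in meta, used by the `--auto` analysis mode.
+    """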
+ import pkgutil
+
+ import data_juicer.ops.filter as djfilter
+ from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS
+ stats_filters = [{
+ filter_name: {}
+ } for _, filter_name, _ in pkgutil.iter_modules(djfilter.__path__)
+ if filter_name not in NON_STATS_FILTERS.modules]
+ meta_ops = [{op_name: {}} for op_name in TAGGING_OPS.modules]
+ return stats_filters + meta_ops
+
+
+def update_op_attr(op_list: list, attr_dict: dict = None):
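+    """
+    Fill in the default OP attributes (e.g. text_key, num_proc) for each OP in
+    the list when they are not explicitly set in the OP's own arguments.
+    """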
+ if not attr_dict:
+ return op_list
+ updated_op_list = []
+ for op in op_list:
for op_name in op:
args = op[op_name]
if args is None:
- args = {
- 'text_key': text_key,
- 'image_key': cfg.image_key,
- 'audio_key': cfg.audio_key,
- 'video_key': cfg.video_key,
- 'num_proc': cfg.np,
- 'turbo': cfg.turbo,
- }
+ args = attr_dict
else:
- if 'text_key' not in args or args['text_key'] is None:
- args['text_key'] = text_key
- if 'image_key' not in args or args['image_key'] is None:
- args['image_key'] = cfg.image_key
- if 'audio_key' not in args or args['audio_key'] is None:
- args['audio_key'] = cfg.audio_key
- if 'video_key' not in args or args['video_key'] is None:
- args['video_key'] = cfg.video_key
- if 'num_proc' not in args or args['num_proc'] is None:
- args['num_proc'] = cfg.np
- if 'turbo' not in args or args['turbo'] is None:
- args['turbo'] = cfg.turbo
+ for key in attr_dict:
+ if key not in args or args[key] is None:
+ args[key] = attr_dict[key]
op[op_name] = args
-
- return cfg
+ updated_op_list.append(op)
+ return updated_op_list
def _collect_config_info_from_class_docs(configurable_ops, parser):
@@ -570,8 +627,13 @@ def sort_op_by_types_and_names(op_name_classes):
if 'deduplicator' in name]
selector_ops = [(name, c) for (name, c) in op_name_classes
if 'selector' in name]
+ grouper_ops = [(name, c) for (name, c) in op_name_classes
+ if 'grouper' in name]
+ aggregator_ops = [(name, c) for (name, c) in op_name_classes
+ if 'aggregator' in name]
ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted(
- deduplicator_ops) + sorted(selector_ops)
+ deduplicator_ops) + sorted(selector_ops) + sorted(grouper_ops) + \
+ sorted(aggregator_ops)
return ops_sorted_by_types
@@ -636,7 +698,10 @@ def update_op_process(cfg, parser):
temp_args = namespace_to_arg_list(temp_cfg,
includes=recognized_args,
excludes=['config'])
- temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args
+ if temp_cfg.config:
+ temp_args = ['--config', temp_cfg.config[0].absolute] + temp_args
+ else:
+ temp_args = ['--auto'] + temp_args
temp_parser.parse_args(temp_args)
return cfg
@@ -662,6 +727,8 @@ def namespace_to_arg_list(namespace, prefix='', includes=None, excludes=None):
def config_backup(cfg: Namespace):
+ if not cfg.config:
+ return
cfg_path = cfg.config[0].absolute
work_dir = cfg.work_dir
target_path = os.path.join(work_dir, os.path.basename(cfg_path))
diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py
index 5ab6e6ec8..64fd622f0 100644
--- a/data_juicer/core/adapter.py
+++ b/data_juicer/core/adapter.py
@@ -1,8 +1,15 @@
-from datasets import concatenate_datasets
+import json
+import os
+from copy import deepcopy
+
+from datasets import Dataset, concatenate_datasets
from datasets.config import DEFAULT_MAX_BATCH_SIZE
+from data_juicer.analysis.measure import RelatedTTestMeasure
from data_juicer.core.monitor import Monitor
from data_juicer.ops import UNFORKABLE
+from data_juicer.utils.cache_utils import dataset_cache_control
+from data_juicer.utils.constant import Fields
from data_juicer.utils.process_utils import setup_mp
@@ -12,6 +19,11 @@ class Adapter:
def __init__(self, cfg: dict):
self.cfg = cfg
+
+ # insight mining related
+ self.enable_insight_mining = self.cfg.open_insight_mining
+
+ # resource probe related
self.idle_resources = Monitor.monitor_current_resources()
@staticmethod
@@ -108,25 +120,21 @@ def adapt_workloads(self, dataset, operators):
return bs_per_op
+ @dataset_cache_control(on=True)
def probe_small_batch(self, dataset, operators):
"""
Perform small batch pre-execution to probe available resources,
current load and estimated OP speed, returning load factors and speed
ranks for each OP.
- Notice: the probe should be run with cache enabled.
+ Notice: the probe should be run with cache enabled to avoid removing
+ the cache files of the input dataset.
:param dataset: The dataset to pre-execute small batch on
:param operators: The OP list to be pre-execution and probe
:return: A list of probe results for each OP and the length of data
batch to probe.
"""
- # record the cache state and enable the cache
- from datasets import (disable_caching, enable_caching,
- is_caching_enabled)
- previous_state = is_caching_enabled()
- if not previous_state:
- enable_caching()
# take a small batch
data_batch = self.take_batch(dataset, self.cfg)
@@ -135,10 +143,6 @@ def probe_small_batch(self, dataset, operators):
# analyze resource utilization
analysis_res = Monitor.analyze_resource_util_list(resource_util_list)
- # if the cache is disabled before, disable it again
- if not previous_state:
- disable_caching()
-
return analysis_res, len(data_batch)
def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9):
@@ -177,3 +181,100 @@ def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9):
batch_size_per_op.append(bs_this_op)
return batch_size_per_op
+
+ @dataset_cache_control(on=True)
+ def analyze_small_batch(self, dataset, current_state):
+ """
+ Perform small batch analysis to probe the current OP-wise stats/meta
+ distributions. The analyzed results will be stored in the directory
+ `{work_dir}/insight_mining`.
+
+ Notice: the probe should be run with cache enabled to avoid removing
+ the cache files of the input dataset.
+
+ :param dataset: The dataset to analyze small batch on
+ :param current_state: A string to indicate the current state of the
+            input dataset. It usually consists of the index of the OP that was
+            just processed and the OP name, e.g. "1_text_length_filter".
+ """
+ # prepare analyzer config
+ new_cfg = deepcopy(self.cfg)
+ # check ops to mine
+ new_cfg.auto = True
+ new_cfg.config = None
+ if len(new_cfg.op_list_to_mine) > 0:
+ new_cfg.process = [{
+ op_name: {}
+ } for op_name in new_cfg.op_list_to_mine]
+ # update work dir
+ new_cfg.work_dir = os.path.join(new_cfg.work_dir, 'insight_mining',
+ current_state)
+ new_cfg.export_path = os.path.join(new_cfg.work_dir,
+ f'{current_state}.jsonl')
+ # close insight mining and monitor for inner analysis
+ new_cfg.open_insight_mining = False
+ new_cfg.open_monitor = False
+
+ # init the analyzer
+ from data_juicer.core.analyzer import Analyzer
+ analyzer = Analyzer(new_cfg)
+
+ # remove existing stats and meta in dataset
+ target_fields = {Fields.stats, Fields.meta}
+ target_fields = target_fields.intersection(set(dataset.features))
+ if len(target_fields) > 0:
+ dataset = dataset.remove_columns(list(target_fields))
+ analyzer.run(dataset, skip_return=True)
+
+ def insight_mining(self, pval_th=0.05):
+ """
+ Mining the insights from the OP-wise analysis results. For now, we use
+ T-Test to check the significance of stats/meta changes before and after
+ each OP processing. If the p-value is less than a given threshold
+ (usually 0.05), we think the stats/meta changes are significant. The
+ insight mining results will be stored in the file
+ `{work_dir}/insight_mining/insight_mining.json`.
+
+ :param pval_th: the threshold of p-value.
+ """
+ work_dir = os.path.join(self.cfg.work_dir, 'insight_mining')
+ res_order = [
+ d for d in os.listdir(work_dir)
+ if os.path.isdir(os.path.join(work_dir, d))
+ ]
+ res_order.sort()
+
+ # collect analysis results
+ analysis_results = {}
+ for res_dir in res_order:
+ res = Dataset.from_json(
+ os.path.join(work_dir, res_dir,
+ f'{res_dir}_stats.jsonl')).flatten()
+ analysis_results[res_dir] = res
+
+ # distribution change significance analysis
+ ttest_measure = RelatedTTestMeasure()
+
+ sig_res = {}
+ # i = 0 is the original dataset
+ for i in range(1, len(res_order)):
+ prev_res = analysis_results[res_order[i - 1]]
+ curr_res = analysis_results[res_order[i]]
+
+ # only consider common stats and meta
+ common_features = list(
+ set(prev_res.features).intersection(set(curr_res.features)))
+ curr_sig_res = {}
+ for feat in common_features:
+ ttest_res = ttest_measure(prev_res[feat], curr_res[feat])
+ curr_sig_res[feat] = {
+ 't-statistic (standardized mean difference)':
+ ttest_res.statistic,
+ 'p-value': ttest_res.pvalue,
+ 'significant':
+ True if ttest_res.pvalue < pval_th else False,
+ }
+ sig_res[res_order[i]] = curr_sig_res
+
+ with open(os.path.join(work_dir, 'insight_mining.json'), 'w') as out:
+ json.dump(sig_res, out)
diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py
index 2ae4d3511..d9ac586e9 100644
--- a/data_juicer/core/analyzer.py
+++ b/data_juicer/core/analyzer.py
@@ -1,6 +1,7 @@
import os
-from typing import Optional
+from typing import Optional, Union
+from datasets import Dataset
from jsonargparse import Namespace
from loguru import logger
from pydantic import PositiveInt
@@ -8,11 +9,12 @@
from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis
from data_juicer.config import init_configs
from data_juicer.format import load_formatter
-from data_juicer.ops import Filter, load_ops
+from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS, Filter, load_ops
from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.utils import cache_utils
from .adapter import Adapter
+from .data import NestedDataset
from .exporter import Exporter
@@ -33,7 +35,7 @@ def __init__(self, cfg: Optional[Namespace] = None):
:param cfg: optional jsonargparse Namespace dict.
"""
- self.cfg = init_configs() if cfg is None else cfg
+ self.cfg = init_configs(which_entry=self) if cfg is None else cfg
self.work_dir = self.cfg.work_dir
@@ -71,22 +73,31 @@ def __init__(self, cfg: Optional[Namespace] = None):
self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis')
def run(self,
+ dataset: Union[Dataset, NestedDataset] = None,
load_data_np: Optional[PositiveInt] = None,
skip_export: bool = False,
skip_return: bool = False):
"""
Running the dataset analysis pipeline.
+ :param dataset: a Dataset object to be analyzed.
:param load_data_np: number of workers when loading the dataset.
:param skip_export: whether export the results into disk
:param skip_return: skip return for API called.
:return: analyzed dataset.
"""
# 1. format data
- logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
- dataset = self.formatter.load_dataset(load_data_np, self.cfg)
+ if dataset is None:
+ logger.info('Loading dataset from data formatter...')
+ dataset = self.formatter.load_dataset(load_data_np, self.cfg)
+ else:
+ logger.info(f'Using existing dataset {dataset}')
+ if self.cfg.auto:
+            # if it's auto analysis, only analyze a small part of the input
+            # dataset to save time and computing resources
+ dataset = dataset.take(min(len(dataset), self.cfg.auto_num))
# extract processes
logger.info('Preparing process operators...')
@@ -107,16 +118,26 @@ def run(self,
logger.info('Computing the stats of dataset...')
stats_collected = False
for op in ops:
- if isinstance(op, Filter):
+ if isinstance(op, Filter) \
+ and op._name not in NON_STATS_FILTERS.modules:
original_process = op.process
op.process = None
- dataset = dataset.process(op, work_dir=self.work_dir)
+ dataset = dataset.process(op,
+ work_dir=self.work_dir,
+ open_monitor=self.cfg.open_monitor)
op.process = original_process
stats_collected = True
+ elif op._name in TAGGING_OPS.modules:
+ dataset = dataset.process(op,
+ work_dir=self.work_dir,
+ open_monitor=self.cfg.open_monitor)
+ stats_collected = True
if not stats_collected:
- logger.warning('No stats collected. Please add some Filter ops to '
- 'the process list in configs.')
- return dataset
+ logger.warning(
+ 'No stats/meta collected. Please add some Filter OPs or '
+ 'Tagging OPs to the process list in configs.')
+ if not skip_return:
+ return dataset
# 3. data export
logger.info('Exporting dataset to disk...')
diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index 361f6e8a0..d0f8083e1 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -172,6 +172,7 @@ def process(
exporter=None,
checkpointer=None,
tracer=None,
+ adapter=None,
open_monitor=True,
):
if operators is None:
@@ -185,9 +186,19 @@ def process(
if open_monitor:
resource_util_list = []
+ # whether to enable insight mining
+ enable_insight_mining = adapter.enable_insight_mining \
+ if adapter else False
+ # record the analysis results of the original dataset
+ if enable_insight_mining:
+ logger.info('Analyze small batch for the original dataset for '
+ 'insight mining...')
+ adapter.analyze_small_batch(self, '0_original')
+
dataset = self
+ op_num = len(operators)
try:
- for op in operators:
+ for idx, op in enumerate(operators, start=1):
mp_context = ['forkserver', 'spawn'] if (
op.use_cuda()
or op._name in unforkable_operators) else None
@@ -211,8 +222,16 @@ def process(
if open_monitor:
resource_util_list.append(resource_util_per_op)
end = time()
- logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
- f'Left {len(dataset)} samples.')
+ logger.info(
+ f'[{idx}/{op_num}] OP [{op._name}] Done in '
+ f'{end - start:.3f}s. Left {len(dataset)} samples.')
+
+ # record the analysis results of the current dataset
+ if enable_insight_mining:
+ logger.info(
+ f'Analyze small batch for the current dataset after '
+ f'OP [{op._name}] for insight mining...')
+ adapter.analyze_small_batch(dataset, f'{idx}_{op._name}')
except: # noqa: E722
logger.error(f'An error occurred during Op [{op._name}].')
traceback.print_exc()
@@ -223,6 +242,7 @@ def process(
'last op...')
dataset.cleanup_cache_files()
checkpointer.save_ckpt(dataset)
+ # make summarization on the monitor results
if work_dir and open_monitor:
# get the analyzed version
resource_util_list = Monitor.analyze_resource_util_list(
@@ -234,6 +254,10 @@ def process(
json.dump(resource_util_list, out)
Monitor.draw_resource_util_graph(resource_util_list,
monitor_dir)
+ # make summarization on the insight mining results
+ if work_dir and enable_insight_mining:
+ logger.info('Insight mining for each OP...')
+ adapter.insight_mining()
return dataset
def update_args(self, args, kargs, is_filter=False):
diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py
index f78059247..7f0d93a66 100644
--- a/data_juicer/core/executor.py
+++ b/data_juicer/core/executor.py
@@ -199,6 +199,7 @@ def run(self,
exporter=self.exporter,
checkpointer=self.ckpt_manager,
tracer=self.tracer,
+ adapter=self.adapter,
open_monitor=self.cfg.open_monitor,
)
tend = time()
diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py
index 72b555d34..dbdb4fb9f 100644
--- a/data_juicer/core/exporter.py
+++ b/data_juicer/core/exporter.py
@@ -106,10 +106,15 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
:param export_stats: whether to export stats of dataset.
:return:
"""
- if Fields.stats in dataset.features and export_stats:
+ if export_stats:
# export stats of datasets into a single file.
logger.info('Exporting computed stats into a single file...')
- ds_stats = dataset.select_columns(Fields.stats)
+ export_columns = []
+ if Fields.stats in dataset.features:
+ export_columns.append(Fields.stats)
+ if Fields.meta in dataset.features:
+ export_columns.append(Fields.meta)
+ ds_stats = dataset.select_columns(export_columns)
stats_file = export_path.replace('.' + suffix, '_stats.jsonl')
Exporter.to_jsonl(
ds_stats,
@@ -119,7 +124,7 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True):
if self.export_ds:
# fetch the corresponding export method according to the suffix
if not self.keep_stats_in_res_ds:
- extra_fields = {Fields.stats}
+ extra_fields = {Fields.stats, Fields.meta}
feature_fields = set(dataset.features.keys())
removed_fields = extra_fields.intersection(feature_fields)
dataset = dataset.remove_columns(removed_fields)
diff --git a/data_juicer/core/monitor.py b/data_juicer/core/monitor.py
index 0210e3732..d5fdee241 100644
--- a/data_juicer/core/monitor.py
+++ b/data_juicer/core/monitor.py
@@ -15,7 +15,13 @@ def resource_monitor(mdict, interval):
while True:
this_states.append(Monitor.monitor_current_resources())
time.sleep(interval)
- if mdict['stop']:
+ try:
+ stop_sign = mdict['stop']
+ except (BrokenPipeError, FileNotFoundError):
+                # mdict crashes because the main process has been terminated
+                # already, which is not a fault of the monitor itself
+ return
+ if stop_sign:
break
mdict['resource'] = this_states
diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py
index 2064b8a17..568f88e41 100644
--- a/data_juicer/core/ray_data.py
+++ b/data_juicer/core/ray_data.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
import os
from functools import partial
+from typing import Any, Dict, List, Literal, Optional, Union
-import pyarrow as pa
+import pyarrow
from loguru import logger
from data_juicer import cuda_device_count
@@ -12,6 +15,7 @@
from data_juicer.utils.process_utils import calculate_np
rd = LazyLoader('rd', 'ray.data')
+ds = LazyLoader('ds', 'ray.data.datasource')
def get_abs_path(path, dataset_dir):
@@ -33,7 +37,7 @@ def convert_to_absolute_paths(samples, dataset_dir, path_keys):
samples[key][idx] = [
get_abs_path(item, dataset_dir) for item in paths
]
- return pa.Table.from_pydict(samples)
+ return pyarrow.Table.from_pydict(samples)
# TODO: check path for nestdataset
@@ -63,7 +67,7 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
if Fields.stats not in columns:
- def process_batch_arrow(table: pa.Table) -> pa.Table:
+ def process_batch_arrow(table: pyarrow.Table) -> pyarrow.Table:
new_column_data = [{} for _ in range(len(table))]
new_talbe = table.append_column(Fields.stats, [new_column_data])
return new_talbe
@@ -81,7 +85,7 @@ def get_num_gpus(op, op_proc):
def filter_batch(batch, filter_func):
- mask = pa.array(filter_func(batch.to_pydict()))
+ mask = pyarrow.array(filter_func(batch.to_pydict()))
return batch.filter(mask)
@@ -174,3 +178,89 @@ def _run_single_op(self, op):
import traceback
traceback.print_exc()
exit(1)
+
+ @classmethod
+ def read_json(cls, paths: Union[str, List[str]]) -> RayDataset:
+ # Note: a temp solution for reading json stream
+ # TODO: replace with ray.data.read_json_stream once it is available
+ import pyarrow.json as js
+ try:
+ js.open_json
+ return read_json_stream(paths)
+ except AttributeError:
+ return rd.read_json(paths)
+
+
+class JSONStreamDatasource(ds.JSONDatasource):
+ """
+ A temp Datasource for reading json stream.
+
+ Note:
+
+ Depends on a customized `pyarrow` with `open_json` method.
+ """
+
+ def _read_stream(self, f: 'pyarrow.NativeFile', path: str):
+ from pyarrow.json import open_json
+
+ try:
+ reader = open_json(
+ f,
+ read_options=self.read_options,
+ **self.arrow_json_args,
+ )
+ schema = None
+ while True:
+ try:
+ batch = reader.read_next_batch()
+ table = pyarrow.Table.from_batches([batch], schema=schema)
+ if schema is None:
+ schema = table.schema
+ yield table
+ except StopIteration:
+ return
+ except pyarrow.lib.ArrowInvalid as e:
+ raise ValueError(f'Failed to read JSON file: {path}.') from e
+
+
+def read_json_stream(
+ paths: Union[str, List[str]],
+ *,
+ filesystem: Optional['pyarrow.fs.FileSystem'] = None,
+ parallelism: int = -1,
+ ray_remote_args: Dict[str, Any] = None,
+ arrow_open_stream_args: Optional[Dict[str, Any]] = None,
+ meta_provider=None,
+ partition_filter=None,
+ partitioning=ds.partitioning.Partitioning('hive'),
+ include_paths: bool = False,
+ ignore_missing_paths: bool = False,
+ shuffle: Union[Literal['files'], None] = None,
+ file_extensions: Optional[List[str]] = ['json', 'jsonl'],
+ concurrency: Optional[int] = None,
+ override_num_blocks: Optional[int] = None,
+ **arrow_json_args,
+) -> rd.Dataset:
+ if meta_provider is None:
+ meta_provider = ds.file_meta_provider.DefaultFileMetadataProvider()
+
+ datasource = JSONStreamDatasource(
+ paths,
+ arrow_json_args=arrow_json_args,
+ filesystem=filesystem,
+ open_stream_args=arrow_open_stream_args,
+ meta_provider=meta_provider,
+ partition_filter=partition_filter,
+ partitioning=partitioning,
+ ignore_missing_paths=ignore_missing_paths,
+ shuffle=shuffle,
+ include_paths=include_paths,
+ file_extensions=file_extensions,
+ )
+ return rd.read_datasource(
+ datasource,
+ parallelism=parallelism,
+ ray_remote_args=ray_remote_args,
+ concurrency=concurrency,
+ override_num_blocks=override_num_blocks,
+ )
diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py
index 1d90e31b3..41990b36a 100644
--- a/data_juicer/core/ray_executor.py
+++ b/data_juicer/core/ray_executor.py
@@ -61,7 +61,7 @@ def run(self, load_data_np=None):
from data_juicer.format.formatter import FORMATTERS
dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
else:
- dataset = rd.read_json(self.cfg.dataset_path)
+ dataset = RayDataset.read_json(self.cfg.dataset_path)
# convert all the path in dataset to absolute path
dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg)
diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py
index c7ab44c25..2ab622266 100644
--- a/data_juicer/ops/__init__.py
+++ b/data_juicer/ops/__init__.py
@@ -1,5 +1,6 @@
-from . import deduplicator, filter, mapper, selector
-from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Mapper,
+from . import aggregator, deduplicator, filter, grouper, mapper, selector
+from .base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE,
+ Aggregator, Deduplicator, Filter, Grouper, Mapper,
Selector)
from .load import load_ops
@@ -9,4 +10,6 @@
'Mapper',
'Deduplicator',
'Selector',
+ 'Grouper',
+ 'Aggregator',
]
diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py
new file mode 100644
index 000000000..4afe2974a
--- /dev/null
+++ b/data_juicer/ops/aggregator/__init__.py
@@ -0,0 +1,8 @@
+from .entity_attribute_aggregator import EntityAttributeAggregator
+from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
+from .nested_aggregator import NestedAggregator
+
+__all__ = [
+ 'NestedAggregator', 'EntityAttributeAggregator',
+ 'MostRelavantEntitiesAggregator'
+]
diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
new file mode 100644
index 000000000..96fbbb63f
--- /dev/null
+++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -0,0 +1,200 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
+ is_string_list, nested_access,
+ nested_set)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from .nested_aggregator import NestedAggregator
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'entity_attribute_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class EntityAttributeAggregator(Aggregator):
+ """
+    Return the conclusion of the given entity's attribute from some docs.
+ """
+
+ DEFAULT_SYSTEM_TEMPLATE = (
+ '给定与`{entity}`相关的一些文档,总结`{entity}`的`{attribute}`。\n'
+ '要求:\n'
+ '- 尽量使用原文专有名词\n'
+ '- 联系上下文,自动忽略上下文不一致的细节错误\n'
+ '- 只对文档中与`{entity}`的`{attribute}`有关的内容进行总结\n'
+ '- 字数限制在**{word_limit}字以内**\n'
+ '- 要求输出格式如下:\n'
+ '# {entity}\n'
+ '## {attribute}\n'
+ '...\n'
+ '{example}')
+
+ DEFAULT_EXAMPLE_PROMPT = ('- 例如,根据相关文档总结`孙悟空`的`出身背景`,**100字**以内的样例如下:\n'
+ '`孙悟空`的`出身背景`总结:\n'
+ '# 孙悟空\n'
+ '## 出身背景\n'
+ '号称齐天大圣,花果山水帘洞的美猴王、西行取经队伍中的大师兄。'
+ '师父是唐僧玄奘,曾拜菩提祖师学艺。'
+ '亲生父母未知,自石头中孕育而生。自认斗战胜佛,最怕观世音菩萨和紧箍咒。\n')
+
+ DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n'
+ '{sub_docs}\n\n'
+ '`{entity}`的`{attribute}`总结:\n')
+
+ DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ entity: str = None,
+ attribute: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ word_limit: PositiveInt = 100,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ example_prompt: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param entity: The given entity.
+ :param attribute: The given attribute.
+ :param input_key: The input field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is text_key
+ in default.
+ :param output_key: The output field key in the samples. Support for
+ nested keys such as "__dj__stats__.text_len". It is same as the
+ input_key in default.
+        :param word_limit: The expected output length in words, used in
+            the prompt.
+ :param max_token_num: The max token num of the total tokens of the
+ sub documents. Without limitation if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: The system prompt template.
+ :param example_prompt: The example part in the system prompt.
+ :param input_template: The input template.
+        :param output_pattern_template: The regular expression template for
+            parsing the model output.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if entity is None or attribute is None:
+ raise ValueError('The entity and attribute cannot be None!')
+
+ self.entity = entity
+ self.attribute = attribute
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.word_limit = word_limit
+ self.max_token_num = max_token_num
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_TEMPLATE
+ self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ output_pattern_template = output_pattern_template or \
+ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+ self.system_prompt = system_prompt_template.format(
+ entity=self.entity,
+ attribute=self.attribute,
+ word_limit=self.word_limit,
+ example=self.example_prompt)
+ self.output_pattern = output_pattern_template.format(
+ entity=entity, attribute=attribute)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+ self.nested_sum = NestedAggregator(model=api_model,
+ max_token_num=max_token_num,
+ api_endpoint=api_endpoint,
+ response_path=response_path,
+ try_num=try_num,
+ model_params=model_params,
+ sampling_params=sampling_params)
+
+ def parse_output(self, response):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(response)
+ if matches:
+ result = matches[0].strip()
+ else:
+ result = ''
+
+ return result
+
+ def attribute_summary(self, sub_docs, rank=None):
+ if not sub_docs:
+ return ''
+
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
+ self.max_token_num)
+ results = []
+ for docs in group_docs:
+ doc_str = '\n\n'.join(docs)
+ input_prompt = self.input_template.format(entity=self.entity,
+ attribute=self.attribute,
+ sub_docs=doc_str)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = ''
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+ results.append(result)
+
+ return self.nested_sum.recursive_summary(results)
+
+ def process_single(self, sample=None, rank=None):
+
+ # if not batched sample
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample = nested_set(sample, self.output_key,
+ self.attribute_summary(sub_docs, rank=rank))
+
+ return sample
diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
new file mode 100644
index 000000000..69e1a209c
--- /dev/null
+++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
@@ -0,0 +1,183 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (is_string_list, nested_access,
+ nested_set)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..common import split_text_by_punctuation
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'most_relavant_entities_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class MostRelavantEntitiesAggregator(Aggregator):
+ """
+ Extract entities closely related to a given entity from some texts,
+ and sort them in descending order of importance.
+ """
+
+ DEFAULT_SYSTEM_TEMPLATE = (
+ '给定与`{entity}`相关的一些文档,'
+ '总结一些与`{entity}`最为相关的`{entity_type}`。\n'
+ '要求:\n'
+ '- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n'
+ '- 请按照人物的重要性进行排序,**越重要人物在列表越前面**。\n'
+ '- 你的返回格式如下:\n'
+ '## 分析\n'
+ '你对各个{entity_type}与{entity}关联度的分析\n'
+ '## 列表\n'
+ '人物1, 人物2, 人物3, ...')
+
+ DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n'
+ '{sub_docs}\n\n'
+ '与`{entity}`最相关的一些`{entity_type}`:\n')
+
+ DEFAULT_OUTPUT_PATTERN = r'\#\#\s*列表\s*(.*?)\Z'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ entity: str = None,
+ query_entity_type: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param entity: The given entity.
+        :param query_entity_type: The type of the queried relevant entities.
+        :param input_key: The input field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is text_key
+            by default.
+        :param output_key: The output field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is the same as
+            the input_key by default.
+        :param max_token_num: The maximum token number of the total tokens of
+            the sub documents. No limitation if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: The system prompt template.
+ :param input_template: The input template.
+ :param output_pattern: The output pattern.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if entity is None or query_entity_type is None:
+ raise ValueError(
+ 'The entity and query_entity_type cannot be None!')
+
+ self.entity = entity
+ self.query_entity_type = query_entity_type
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.max_token_num = max_token_num
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_TEMPLATE
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
+ self.system_prompt = system_prompt_template.format(
+ entity=entity, entity_type=query_entity_type)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+
+ def parse_output(self, response):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(response)
+ if matches:
+ result = matches[0].strip()
+ else:
+ result = ''
+ result = split_text_by_punctuation(result)
+
+ return result
+
+ def query_most_relavant_entities(self, sub_docs, rank=None):
+ if not sub_docs:
+            return []
+
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ if self.max_token_num is None:
+ final_docs = sub_docs
+ else:
+ final_docs = []
+ total_num = 0
+ for token_num, doc in zip(token_nums, sub_docs):
+ total_num += token_num
+ if total_num > self.max_token_num:
+ break
+ final_docs.append(doc)
+
+ doc_str = '\n\n'.join(final_docs)
+ input_prompt = self.input_template.format(
+ entity=self.entity,
+ entity_type=self.query_entity_type,
+ sub_docs=doc_str)
+
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = []
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ return result
+
+ def process_single(self, sample=None, rank=None):
+
+        # skip the sample if the input is not a batched list of strings
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample = nested_set(
+ sample, self.output_key,
+ self.query_most_relavant_entities(sub_docs, rank=rank))
+
+ return sample
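+
+
+# A minimal illustrative sketch (values are made up): given a batched sample
+# whose input_key holds a list of sub documents about the configured entity,
+# the op queries the API model once and parses the trailing "## 列表" section
+# of the reply, e.g. a reply ending with "## 列表\n猪八戒, 唐僧" is split by
+# split_text_by_punctuation into ['猪八戒', '唐僧'] (assuming it splits on the
+# comma) and written back to output_key.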
diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py
new file mode 100644
index 000000000..124eb1470
--- /dev/null
+++ b/data_juicer/ops/aggregator/nested_aggregator.py
@@ -0,0 +1,179 @@
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Aggregator
+from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
+ is_string_list, nested_access)
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+torch = LazyLoader('torch', 'torch')
+vllm = LazyLoader('vllm', 'vllm')
+
+OP_NAME = 'nested_aggregator'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class NestedAggregator(Aggregator):
+ """
+    Considering the limitation of the input length, aggregate the
+    contents of batched samples in a nested (recursive) manner.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('给定一些文档碎片,将这些文档整合成一个文档总结。\n'
+ '要求:\n'
+ '- 总结的长度与文档碎片的平均长度基本一致\n'
+ '- 不要包含主观看法\n'
+ '- 注意要尽可能保留文本的专有名词\n'
+ '- 只输出文档总结不要输出其他内容\n'
+ '- 参考如下样例:\n'
+ '文档碎片:\n'
+ '唐僧师徒四人行至白虎岭,遇上了变化多端的白骨精。\n\n'
+ '文档碎片:\n'
+ '白骨精首次变身少女送斋,被孙悟空识破打死,唐僧责怪悟空。\n\n'
+ '文档碎片:\n'
+ '妖怪再变老妇寻女,又被悟空击毙,师傅更加不满,念紧箍咒惩罚。\n\n'
+ '文档碎片:\n'
+ '不甘心的白骨精第三次化作老公公来诱骗,依旧逃不过金睛火眼。\n\n'
+ '文档碎片:\n'
+ '最终,在观音菩萨的帮助下,真相大白,唐僧明白了自己的误解。\n\n'
+ '\n'
+ '文档总结:\n'
+ '唐僧师徒在白虎岭三遇白骨精变化诱惑,悟空屡次识破击毙妖怪却遭误解,最终观音相助真相大白。')
+
+ DEFAULT_INPUT_TEMPLATE = ('{sub_docs}\n\n'
+ '文档总结:\n')
+
+ DEFAULT_SUB_DOC_TEMPLATE = '文档碎片:\n{text}\n'
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ input_key: str = None,
+ output_key: str = None,
+ max_token_num: Optional[PositiveInt] = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ sub_doc_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+        :param input_key: The input field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is text_key
+            by default.
+        :param output_key: The output field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is the same as
+            the input_key by default.
+        :param max_token_num: The maximum token number of the total tokens of
+            the sub documents. No limitation if it is None.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: The system prompt.
+ :param sub_doc_template: The template for input text in each sample.
+ :param input_template: The input template.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+ self.max_token_num = max_token_num
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.sub_doc_template = sub_doc_template or \
+ self.DEFAULT_SUB_DOC_TEMPLATE
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ return_processor=True,
+ **model_params)
+
+ self.try_num = try_num
+
+ def parse_output(self, response):
+
+ def if_match(text):
+ quotes = [("'", "'"), ('"', '"'), ('“', '”'), ('‘', '’'),
+ ('`', '`')]
+ if len(text) < 2:
+ return False
+ if (text[0], text[-1]) in quotes:
+ return True
+ else:
+ return False
+
+ text = response.strip()
+ while if_match(text):
+ text = text[1:-1].strip()
+ return text
+
+ def recursive_summary(self, sub_docs, rank=None):
+ if not sub_docs:
+ return ''
+ if len(sub_docs) == 1:
+ return sub_docs[0]
+ model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
+ token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
+ group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
+ self.max_token_num)
+ # merge every two if every single sub doc is a group
+ group_num = len(group_docs)
+ if group_num == len(sub_docs):
+ group_docs = [
+ group_docs[i] +
+ group_docs[i + 1] if i + 1 < group_num else group_docs[i]
+ for i in range(0, group_num, 2)
+ ]
+ results = []
+ for docs in group_docs:
+ doc_strs = [self.sub_doc_template.format(text=d) for d in docs]
+ input_prompt = self.input_template.format(
+ sub_docs='\n'.join(doc_strs))
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ result = ''
+ for i in range(self.try_num):
+ try:
+ response = model(messages, **self.sampling_params)
+ result = self.parse_output(response)
+ if len(result) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+ results.append(result)
+ return self.recursive_summary(results)
+
+ def process_single(self, sample=None, rank=None):
+
+        # skip the sample if the input is not a batched list of strings
+ sub_docs = nested_access(sample, self.input_key)
+ if not is_string_list(sub_docs):
+ return sample
+
+ sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank)
+
+ return sample
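+
+
+# A minimal illustrative sketch of the recursion above (values are made up):
+# with sub_docs = ['d1', 'd2', 'd3', 'd4'] and a token limit that leaves each
+# doc in its own group, the groups are merged pairwise into
+# [['d1', 'd2'], ['d3', 'd4']], each group is summarized by the API model into
+# ['s12', 's34'], and recursive_summary() recurses on these partial summaries
+# until a single summary string remains.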
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 1a2fb0171..9e39c50ab 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -14,6 +14,8 @@
OPERATORS = Registry('Operators')
UNFORKABLE = Registry('Unforkable')
+NON_STATS_FILTERS = Registry('Non-stats Filters')
+TAGGING_OPS = Registry('Tagging Operators')
def convert_list_dict_to_dict_list(samples):
@@ -131,6 +133,11 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+    :param query_key: the key name of field that stores sample queries
+    :param response_key: the key name of field that stores responses
+    :param history_key: the key name of field that stores history of
+        queries and responses
+    :param index_key: the key name to store the sample index added before
+        processing; no index is added if it is None
"""
# init data keys
self.text_key = kwargs.get('text_key', 'text')
@@ -142,6 +149,8 @@ def __init__(self, *args, **kwargs):
self.response_key = kwargs.get('response_key', 'response')
self.history_key = kwargs.get('history_key', 'history')
+ self.index_key = kwargs.get('index_key', None)
+
self.batch_size = kwargs.get('batch_size', 1000)
# whether the model can be accelerated using cuda
@@ -216,6 +225,26 @@ def run(self, dataset):
from data_juicer.core.data import NestedDataset
if not isinstance(dataset, NestedDataset):
dataset = NestedDataset(dataset)
+ # add meta field for OPs that produce tags
+ if self._name in TAGGING_OPS.modules \
+ and Fields.meta not in dataset.features:
+ from data_juicer.core.data import add_same_content_to_new_column
+ dataset = dataset.map(add_same_content_to_new_column,
+ fn_kwargs={
+ 'new_column_name': Fields.meta,
+ 'initial_value': {}
+ },
+ num_proc=self.runtime_np(),
+ batch_size=self.batch_size,
+ desc='Adding new column for meta')
+ if self.index_key is not None:
+
+ def add_index(sample, idx):
+ sample[self.index_key] = idx
+ return sample
+
+ dataset = dataset.map(add_index, with_indices=True)
+
return dataset
def empty_history(self):
@@ -236,6 +265,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Mapper, self).__init__(*args, **kwargs)
@@ -319,6 +352,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Filter, self).__init__(*args, **kwargs)
self.stats_export_path = kwargs.get('stats_export_path', None)
@@ -387,7 +424,9 @@ def process_single(self, sample):
def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
dataset = super(Filter, self).run(dataset)
- if Fields.stats not in dataset.features:
+ # add stats field for Filters that produce stats
+ if self._name not in NON_STATS_FILTERS.modules \
+ and Fields.stats not in dataset.features:
from data_juicer.core.data import add_same_content_to_new_column
dataset = dataset.map(add_same_content_to_new_column,
fn_kwargs={
@@ -430,6 +469,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Deduplicator, self).__init__(*args, **kwargs)
@@ -489,6 +532,10 @@ def __init__(self, *args, **kwargs):
to be processed
:param video_key: the key name of field that stores sample video list
to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
"""
super(Selector, self).__init__(*args, **kwargs)
@@ -507,3 +554,90 @@ def run(self, dataset, *, exporter=None, tracer=None):
if tracer:
tracer.trace_filter(self._name, dataset, new_dataset)
return new_dataset
+
+
+class Grouper(OP):
+
+ def __init__(self, *args, **kwargs):
+ """
+        Base class that groups samples into batched samples.
+
+ :param text_key: the key name of field that stores sample texts
+ to be processed
+ :param image_key: the key name of field that stores sample image list
+ to be processed
+ :param audio_key: the key name of field that stores sample audio list
+ to be processed
+ :param video_key: the key name of field that stores sample video list
+ to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
+ """
+ super(Grouper, self).__init__(*args, **kwargs)
+
+ def process(self, dataset):
+ """
+        Dataset --> batched samples.
+
+        :param dataset: input dataset
+        :return: a list of batched samples.
+ """
+ raise NotImplementedError
+
+ def run(self, dataset, *, exporter=None, tracer=None):
+ dataset = super(Grouper, self).run(dataset)
+ batched_samples = self.process(dataset)
+ from data_juicer.core.data import NestedDataset
+ new_dataset = NestedDataset.from_list(batched_samples)
+ if tracer:
+ tracer.trace_filter(self._name, dataset, new_dataset)
+ return new_dataset
+
+
+class Aggregator(OP):
+
+ def __init__(self, *args, **kwargs):
+ """
+        Base class that aggregates batched samples into single samples.
+
+ :param text_key: the key name of field that stores sample texts
+ to be processed
+ :param image_key: the key name of field that stores sample image list
+ to be processed
+ :param audio_key: the key name of field that stores sample audio list
+ to be processed
+ :param video_key: the key name of field that stores sample video list
+ to be processed
+    :param query_key: the key name of field that stores sample queries
+ :param response_key: the key name of field that stores responses
+ :param history_key: the key name of field that stores history of
+ queries and responses
+ """
+ super(Aggregator, self).__init__(*args, **kwargs)
+ self.process = catch_map_single_exception(self.process_single)
+
+ def process_single(self, sample):
+ """
+        For sample level, batched sample --> sample.
+        The input must be the output of some Grouper OP.
+
+ :param sample: batched sample to aggregate
+ :return: aggregated sample
+ """
+ raise NotImplementedError
+
+ def run(self, dataset, *, exporter=None, tracer=None):
+ dataset = super(Aggregator, self).run(dataset)
+ new_dataset = dataset.map(
+ self.process,
+ num_proc=self.runtime_np(),
+ with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
+ desc=self._name + '_process',
+ )
+ if tracer:
+ tracer.trace_mapper(self._name, dataset, new_dataset,
+ self.text_key)
+ return new_dataset
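+
+
+# A minimal illustrative usage sketch (not part of this module; the recipe
+# snippet below is an assumption about how these OPs are chained, and the
+# argument value is only an example):
+#
+#   process:
+#     - naive_grouper:          # batch all samples into one
+#     - nested_aggregator:      # summarize the batched text field
+#         api_model: 'gpt-4o'
+#
+# Grouper.run() rebuilds the dataset from the batched samples returned by
+# process(), and Aggregator.run() maps process_single() over those batched
+# samples.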
diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index dad6818e1..8cb986b2b 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -63,3 +63,10 @@
'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
'WordRepetitionFilter', 'WordsNumFilter'
]
+
+NON_STATS_FILTERS = [
+ 'specified_field_filter',
+ 'specified_numeric_field_filter',
+ 'suffix_filter',
+ 'video_tagging_from_frames_filter',
+]
diff --git a/data_juicer/ops/filter/image_aesthetics_filter.py b/data_juicer/ops/filter/image_aesthetics_filter.py
index bbaba15eb..723845a5d 100644
--- a/data_juicer/ops/filter/image_aesthetics_filter.py
+++ b/data_juicer/ops/filter/image_aesthetics_filter.py
@@ -46,7 +46,7 @@ def __init__(self,
:param args: Extra positional arguments.
:param kwargs: Extra keyword arguments.
"""
-
+ kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
if hf_scorer_model == '':
hf_scorer_model = \
diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py
index 603a48518..aea409ec4 100644
--- a/data_juicer/ops/filter/image_nsfw_filter.py
+++ b/data_juicer/ops/filter/image_nsfw_filter.py
@@ -41,6 +41,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.score_threshold = score_threshold
if any_or_all not in ['any', 'all']:
diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py
index dc36cd68a..6881eccf5 100644
--- a/data_juicer/ops/filter/image_text_matching_filter.py
+++ b/data_juicer/ops/filter/image_text_matching_filter.py
@@ -52,6 +52,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py
index d43c9bc3f..9a3f9361b 100644
--- a/data_juicer/ops/filter/image_text_similarity_filter.py
+++ b/data_juicer/ops/filter/image_text_similarity_filter.py
@@ -53,6 +53,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py
index 0d9eead6a..b752736a4 100644
--- a/data_juicer/ops/filter/image_watermark_filter.py
+++ b/data_juicer/ops/filter/image_watermark_filter.py
@@ -45,6 +45,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '500MB')
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if any_or_all not in ['any', 'all']:
diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
index 98a2dfb1f..9dec0dc3c 100644
--- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py
+++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py
@@ -114,6 +114,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.min_recall = min_recall
self.max_recall = max_recall
diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py
index 86aff2426..41addf8da 100644
--- a/data_juicer/ops/filter/specified_field_filter.py
+++ b/data_juicer/ops/filter/specified_field_filter.py
@@ -1,9 +1,12 @@
from typing import List
-from ..base_op import OPERATORS, Filter
+from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter
+OP_NAME = 'specified_field_filter'
-@OPERATORS.register_module('specified_field_filter')
+
+@NON_STATS_FILTERS.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
class SpecifiedFieldFilter(Filter):
"""
Filter based on specified field information.
diff --git a/data_juicer/ops/filter/specified_numeric_field_filter.py b/data_juicer/ops/filter/specified_numeric_field_filter.py
index 693be3392..c7a1d301a 100644
--- a/data_juicer/ops/filter/specified_numeric_field_filter.py
+++ b/data_juicer/ops/filter/specified_numeric_field_filter.py
@@ -1,6 +1,6 @@
import sys
-from ..base_op import OPERATORS, Filter
+from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter
def is_number(s):
@@ -13,7 +13,11 @@ def is_number(s):
return False
-@OPERATORS.register_module('specified_numeric_field_filter')
+OP_NAME = 'specified_numeric_field_filter'
+
+
+@NON_STATS_FILTERS.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
class SpecifiedNumericFieldFilter(Filter):
"""
Filter based on specified numeric field information.
diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py
index ea7868399..7aaca53a7 100644
--- a/data_juicer/ops/filter/suffix_filter.py
+++ b/data_juicer/ops/filter/suffix_filter.py
@@ -2,10 +2,13 @@
from data_juicer.utils.constant import Fields
-from ..base_op import OPERATORS, Filter
+from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter
+OP_NAME = 'suffix_filter'
-@OPERATORS.register_module('suffix_filter')
+
+@NON_STATS_FILTERS.register_module(OP_NAME)
+@OPERATORS.register_module(OP_NAME)
class SuffixFilter(Filter):
"""Filter to keep samples with specified suffix."""
diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py
index 5e674162d..f65334f56 100644
--- a/data_juicer/ops/filter/video_aesthetics_filter.py
+++ b/data_juicer/ops/filter/video_aesthetics_filter.py
@@ -73,7 +73,7 @@ def __init__(self,
:param args: Extra positional arguments.
:param kwargs: Extra keyword arguments.
"""
-
+ kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
if hf_scorer_model == '':
hf_scorer_model = \
diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py
index 6b3e92641..da793ccf4 100644
--- a/data_juicer/ops/filter/video_frames_text_similarity_filter.py
+++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py
@@ -74,6 +74,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py
index 27bafe1d0..a1dd9d214 100644
--- a/data_juicer/ops/filter/video_nsfw_filter.py
+++ b/data_juicer/ops/filter/video_nsfw_filter.py
@@ -65,6 +65,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '1GB')
super().__init__(*args, **kwargs)
self.score_threshold = score_threshold
if frame_sampling_method not in ['all_keyframes', 'uniform']:
diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py
index 7c41b5521..2436d886c 100644
--- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py
+++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py
@@ -5,7 +5,8 @@
from data_juicer.utils.constant import Fields
-from ..base_op import OPERATORS, UNFORKABLE, Filter
+from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE,
+ Filter)
from ..mapper.video_tagging_from_frames_mapper import \
VideoTaggingFromFramesMapper
from ..op_fusion import LOADED_VIDEOS
@@ -13,6 +14,8 @@
OP_NAME = 'video_tagging_from_frames_filter'
+@NON_STATS_FILTERS.register_module(OP_NAME)
+@TAGGING_OPS.register_module(OP_NAME)
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
@@ -61,6 +64,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '9GB')
super().__init__(*args, **kwargs)
if contain not in ['any', 'all']:
raise ValueError(f'the containing type [{contain}] is not '
@@ -90,7 +94,7 @@ def compute_stats_single(self, sample, rank=None, context=False):
return sample
def process_single(self, sample, rank=None):
- video_tags = sample[self.tag_field_name]
+ video_tags = sample[Fields.meta][self.tag_field_name]
if len(video_tags) <= 0:
return True
diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py
index 2b7e30f8f..959c91e23 100644
--- a/data_juicer/ops/filter/video_watermark_filter.py
+++ b/data_juicer/ops/filter/video_watermark_filter.py
@@ -69,6 +69,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '500MB')
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if frame_sampling_method not in ['all_keyframes', 'uniform']:
diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py
new file mode 100644
index 000000000..048b305e4
--- /dev/null
+++ b/data_juicer/ops/grouper/__init__.py
@@ -0,0 +1,4 @@
+from .key_value_grouper import KeyValueGrouper
+from .naive_grouper import NaiveGrouper
+
+__all__ = ['NaiveGrouper', 'KeyValueGrouper']
diff --git a/data_juicer/ops/grouper/key_value_grouper.py b/data_juicer/ops/grouper/key_value_grouper.py
new file mode 100644
index 000000000..3d786319f
--- /dev/null
+++ b/data_juicer/ops/grouper/key_value_grouper.py
@@ -0,0 +1,51 @@
+from typing import List, Optional
+
+from data_juicer.utils.common_utils import dict_to_hash, nested_access
+
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
+from .naive_grouper import NaiveGrouper
+
+
+@OPERATORS.register_module('key_value_grouper')
+class KeyValueGrouper(Grouper):
+    """Group samples into batched samples according to the values of given keys."""
+
+ def __init__(self,
+ group_by_keys: Optional[List[str]] = None,
+ *args,
+ **kwargs):
+ """
+ Initialization method.
+
+        :param group_by_keys: group samples according to the values in the
+            keys. Support for nested keys such as "__dj__stats__.text_len".
+            It is [self.text_key] by default.
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+
+ self.group_by_keys = group_by_keys or [self.text_key]
+ self.naive_grouper = NaiveGrouper()
+
+ def process(self, dataset):
+
+ if len(dataset) == 0:
+ return dataset
+
+ sample_map = {}
+ for sample in dataset:
+ cur_dict = {}
+ for key in self.group_by_keys:
+ cur_dict[key] = nested_access(sample, key)
+ sample_key = dict_to_hash(cur_dict)
+ if sample_key in sample_map:
+ sample_map[sample_key].append(sample)
+ else:
+ sample_map[sample_key] = [sample]
+
+ batched_samples = [
+ convert_list_dict_to_dict_list(sample_map[k]) for k in sample_map
+ ]
+
+ return batched_samples
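+
+
+# A minimal illustrative sketch (values are made up): with
+# group_by_keys=['meta.language'], the samples
+#   [{'text': 'a', 'meta': {'language': 'en'}},
+#    {'text': 'b', 'meta': {'language': 'zh'}},
+#    {'text': 'c', 'meta': {'language': 'en'}}]
+# are hashed by their language value and grouped into two batched samples,
+# roughly {'text': ['a', 'c'], ...} and {'text': ['b'], ...}.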
diff --git a/data_juicer/ops/grouper/naive_grouper.py b/data_juicer/ops/grouper/naive_grouper.py
new file mode 100644
index 000000000..4633dc48e
--- /dev/null
+++ b/data_juicer/ops/grouper/naive_grouper.py
@@ -0,0 +1,24 @@
+from ..base_op import OPERATORS, Grouper, convert_list_dict_to_dict_list
+
+
+@OPERATORS.register_module('naive_grouper')
+class NaiveGrouper(Grouper):
+    """Group all samples into one batched sample."""
+
+ def __init__(self, *args, **kwargs):
+ """
+ Initialization method.
+
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+
+ def process(self, dataset):
+
+ if len(dataset) == 0:
+ return dataset
+
+ batched_sample = convert_list_dict_to_dict_list(dataset)
+
+ return [batched_sample]
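+
+
+# A minimal illustrative sketch (values are made up):
+# process([{'text': 'doc A'}, {'text': 'doc B'}]) returns
+# [{'text': ['doc A', 'doc B']}], i.e. the whole dataset becomes one batched
+# sample via convert_list_dict_to_dict_list.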
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py
index 3c95d1fe6..9b86b83dc 100644
--- a/data_juicer/ops/mapper/__init__.py
+++ b/data_juicer/ops/mapper/__init__.py
@@ -14,6 +14,7 @@
from .extract_event_mapper import ExtractEventMapper
from .extract_keyword_mapper import ExtractKeywordMapper
from .extract_nickname_mapper import ExtractNicknameMapper
+from .extract_support_text_mapper import ExtractSupportTextMapper
from .fix_unicode_mapper import FixUnicodeMapper
from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper
from .generate_qa_from_text_mapper import GenerateQAFromTextMapper
@@ -32,6 +33,7 @@
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .python_file_mapper import PythonFileMapper
from .python_lambda_mapper import PythonLambdaMapper
+from .relation_identity_mapper import RelationIdentityMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
from .remove_comments_mapper import RemoveCommentsMapper
from .remove_header_mapper import RemoveHeaderMapper
@@ -71,25 +73,26 @@
'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
'ExpandMacroMapper', 'ExtractEntityAttributeMapper',
'ExtractEntityRelationMapper', 'ExtractEventMapper',
- 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper',
+ 'ExtractKeywordMapper', 'ExtractNicknameMapper',
+ 'ExtractSupportTextMapper', 'FixUnicodeMapper',
'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper',
'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper',
'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper',
'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
- 'PythonFileMapper', 'PythonLambdaMapper', 'RemoveBibliographyMapper',
- 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
- 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
- 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
- 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
- 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
- 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
- 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper',
- 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
- 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
- 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
- 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
- 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
- 'WhitespaceNormalizationMapper'
+ 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper',
+ 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
+ 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
+ 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
+ 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
+ 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
+ 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
+ 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
+ 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
+ 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
+ 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
+ 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
+ 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
+ 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
]
diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
index fd93cfe03..0fc76b11f 100644
--- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
+++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py
@@ -1,5 +1,4 @@
import re
-from itertools import chain
from typing import Dict, List, Optional
from loguru import logger
@@ -19,34 +18,37 @@ class ExtractEntityAttributeMapper(Mapper):
Extract attributes for given entities from the text
"""
- _batched_op = True
-
DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
'给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n'
'要求:\n'
'- 摘录的示例应该简短。\n'
'- 遵循如下的回复格式:\n'
+ '# {entity}\n'
'## {attribute}:\n'
- '{entity}的{attribute}描述...\n'
- '### 代表性示例1:\n'
- '说明{entity}该{attribute}的原文摘录1...\n'
- '### 代表性示例2:\n'
- '说明{entity}该{attribute}的原文摘录2...\n'
+ '...\n'
+ '### 代表性示例摘录1:\n'
+ '```\n'
+ '...\n'
+ '```\n'
+ '### 代表性示例摘录2:\n'
+ '```\n'
+ '...\n'
+ '```\n'
'...\n')
DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)'
- DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)'
+ DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)' # noqa: E501
def __init__(self,
+ api_model: str = 'gpt-4o',
query_entities: List[str] = [],
query_attributes: List[str] = [],
- api_model: str = 'gpt-4o',
*,
- entity_key: str = Fields.main_entity,
- attribute_key: str = Fields.attribute,
- attribute_desc_key: str = Fields.attribute_description,
- support_text_key: str = Fields.attribute_support_text,
+ entity_key: str = Fields.main_entities,
+ attribute_key: str = Fields.attributes,
+ attribute_desc_key: str = Fields.attribute_descriptions,
+ support_text_key: str = Fields.attribute_support_texts,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt_template: Optional[str] = None,
@@ -60,9 +62,9 @@ def __init__(self,
**kwargs):
"""
Initialization method.
+ :param api_model: API model name.
:param query_entities: Entity list to be queried.
:param query_attributes: Attribute list to be queried.
- :param api_model: API model name.
:param entity_key: The field name to store the given main entity for
attribute extraction. It's "__dj__entity__" in default.
:param entity_attribute_key: The field name to store the given
@@ -135,7 +137,7 @@ def parse_output(self, raw_output, attribute_name):
return attribute, demos
- def _process_single_sample(self, text='', rank=None):
+ def _process_single_text(self, text='', rank=None):
client = get_model(self.model_key, rank=rank)
entities, attributes, descs, demo_lists = [], [], [], []
@@ -168,31 +170,17 @@ def _process_single_sample(self, text='', rank=None):
return entities, attributes, descs, demo_lists
- def process_batched(self, samples, rank=None):
-
- sample_num = len(samples[self.text_key])
+ def process_single(self, sample, rank=None):
- entities, attributes, descs, demo_lists = [], [], [], []
- for text in samples[self.text_key]:
- res = self._process_single_sample(text, rank=rank)
- cur_ents, cur_attrs, cur_descs, cur_demos = res
- entities.append(cur_ents)
- attributes.append(cur_attrs)
- descs.append(cur_descs)
- demo_lists.append(cur_demos)
+ res = self._process_single_text(sample[self.text_key], rank=rank)
+ entities, attributes, descs, demo_lists = res
if self.drop_text:
- samples.pop(self.text_key)
-
- for key in samples:
- samples[key] = [[samples[key][i]] * len(descs[i])
- for i in range(sample_num)]
- samples[self.entity_key] = entities
- samples[self.attribute_key] = attributes
- samples[self.attribute_desc_key] = descs
- samples[self.support_text_key] = demo_lists
+ sample.pop(self.text_key)
- for key in samples:
- samples[key] = list(chain(*samples[key]))
+ sample[self.entity_key] = entities
+ sample[self.attribute_key] = attributes
+ sample[self.attribute_desc_key] = descs
+ sample[self.support_text_key] = demo_lists
- return samples
+ return sample
diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py
new file mode 100644
index 000000000..34bdbe653
--- /dev/null
+++ b/data_juicer/ops/mapper/extract_support_text_mapper.py
@@ -0,0 +1,132 @@
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_access, nested_set
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'extract_support_text_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class ExtractSupportTextMapper(Mapper):
+ """
+ Extract support sub text for a summary.
+ """
+
+ DEFAULT_SYSTEM_PROMPT = ('你将扮演一个文本摘录助手的角色。你的主要任务是基于给定'
+ '的文章(称为“原文”)以及对原文某个部分的简短描述或总结'
+ '(称为“总结”),准确地识别并提取出与该总结相对应的原文'
+ '片段。\n'
+ '要求:\n'
+ '- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n'
+ '- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n'
+ '- 下面是一个例子帮助理解这一过程:\n'
+ '### 原文:\n'
+ '《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创'
+ '作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰'
+ '历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突'
+ '。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二'
+ '姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红'
+ '楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,'
+ '也深刻反映了人物的性格特点和命运走向。\n\n'
+ '### 总结:\n'
+ '描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n'
+ '### 原文摘录:\n'
+ '其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐'
+ '之间的争斗,生动描绘了权力争夺下的女性形象。')
+ DEFAULT_INPUT_TEMPLATE = ('### 原文:\n{text}\n\n'
+ '### 总结:\n{summary}\n\n'
+ '### 原文摘录:\n')
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ *,
+ summary_key: str = Fields.event_description,
+ support_text_key: str = Fields.support_text,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt: Optional[str] = None,
+ input_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ drop_text: bool = False,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param summary_key: The field name to store the input summary.
+ Support for nested keys such as "__dj__stats__.text_len".
+            It's "__dj__event_description__" by default.
+        :param support_text_key: The field name to store the output
+            support text for the summary. It's "__dj__support_text__" by
+            default.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt: System prompt for the task.
+ :param input_template: Template for building the model input.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+        :param drop_text: Whether to drop the text in the output.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ self.summary_key = summary_key
+ self.support_text_key = support_text_key
+
+ self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+ self.drop_text = drop_text
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ summary = nested_access(sample, self.summary_key)
+ if not isinstance(summary, str):
+            logger.warning('Invalid input summary!')
+ return sample
+
+ input_prompt = self.input_template.format(text=sample[self.text_key],
+ summary=summary)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+
+ support_text = ''
+ for i in range(self.try_num):
+ try:
+ response = client(messages, **self.sampling_params)
+ support_text = response.strip()
+ if len(support_text) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+        # fall back to the summary if no support text was returned
+ if not support_text:
+ support_text = summary
+
+ sample = nested_set(sample, self.support_text_key, support_text)
+ return sample
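+
+
+# A minimal illustrative sketch (values are made up): for a sample like
+#   {'text': '<full article>',
+#    '__dj__event_description__': '<one-sentence summary>'}
+# the op asks the API model for the article span that best supports the
+# summary and stores it under '__dj__support_text__'; if every attempt fails
+# or returns an empty string, the summary itself is used as the fallback.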
diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
index 6f5ad7dab..0c0d084b3 100644
--- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
+++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py
@@ -9,7 +9,7 @@
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
-from ..base_op import OPERATORS, UNFORKABLE, Mapper
+from ..base_op import OPERATORS, Mapper
torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')
@@ -19,7 +19,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromExamplesMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
index 248dba428..0f3a1cfef 100644
--- a/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
+++ b/data_juicer/ops/mapper/generate_qa_from_text_mapper.py
@@ -3,7 +3,7 @@
from loguru import logger
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -14,7 +14,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromTextMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py
index 0bc486193..98bb3ad7c 100644
--- a/data_juicer/ops/mapper/image_captioning_mapper.py
+++ b/data_juicer/ops/mapper/image_captioning_mapper.py
@@ -81,6 +81,8 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '16GB')
+
super().__init__(*args, **kwargs)
if keep_candidate_mode not in [
diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py
index c53d6f56d..53e315844 100644
--- a/data_juicer/ops/mapper/image_diffusion_mapper.py
+++ b/data_juicer/ops/mapper/image_diffusion_mapper.py
@@ -91,6 +91,7 @@ def __init__(self,
:param hf_img2seq: model name on huggingface to generate caption if
caption_key is None.
"""
+ kwargs.setdefault('mem_required', '8GB')
super().__init__(*args, **kwargs)
self._init_parameters = self.remove_extra_parameters(locals())
self.strength = strength
diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py
index d47fbf0ef..dc2099b78 100644
--- a/data_juicer/ops/mapper/image_tagging_mapper.py
+++ b/data_juicer/ops/mapper/image_tagging_mapper.py
@@ -7,7 +7,7 @@
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model
-from ..base_op import OPERATORS, UNFORKABLE, Mapper
+from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES
torch = LazyLoader('torch', 'torch')
@@ -16,6 +16,7 @@
OP_NAME = 'image_tagging_mapper'
+@TAGGING_OPS.register_module(OP_NAME)
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
@@ -36,6 +37,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '9GB')
super().__init__(*args, **kwargs)
self.model_key = prepare_model(
model_type='recognizeAnything',
@@ -46,12 +48,13 @@ def __init__(self,
def process_single(self, sample, rank=None, context=False):
# check if it's generated already
- if self.tag_field_name in sample:
+ if self.tag_field_name in sample[Fields.meta]:
return sample
# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
- sample[self.tag_field_name] = np.array([[]], dtype=np.str_)
+ sample[Fields.meta][self.tag_field_name] = np.array([[]],
+ dtype=np.str_)
return sample
# load images
@@ -74,5 +77,5 @@ def process_single(self, sample, rank=None, context=False):
sorted_word_list = [item for item, _ in word_count.most_common()]
image_tags.append(np.array(sorted_word_list, dtype=np.str_))
- sample[self.tag_field_name] = image_tags
+ sample[Fields.meta][self.tag_field_name] = image_tags
return sample
diff --git a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py
index 3563a112b..974730ec5 100644
--- a/data_juicer/ops/mapper/optimize_qa_mapper.py
+++ b/data_juicer/ops/mapper/optimize_qa_mapper.py
@@ -3,7 +3,7 @@
from loguru import logger
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -14,7 +14,6 @@
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeQAMapper(Mapper):
"""
diff --git a/data_juicer/ops/mapper/optimize_query_mapper.py b/data_juicer/ops/mapper/optimize_query_mapper.py
index dd227b4c1..9ccd84bb1 100644
--- a/data_juicer/ops/mapper/optimize_query_mapper.py
+++ b/data_juicer/ops/mapper/optimize_query_mapper.py
@@ -1,11 +1,10 @@
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper
OP_NAME = 'optimize_query_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeQueryMapper(OptimizeQAMapper):
"""
diff --git a/data_juicer/ops/mapper/optimize_response_mapper.py b/data_juicer/ops/mapper/optimize_response_mapper.py
index 158159a9d..f6026b8dc 100644
--- a/data_juicer/ops/mapper/optimize_response_mapper.py
+++ b/data_juicer/ops/mapper/optimize_response_mapper.py
@@ -1,11 +1,10 @@
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from data_juicer.ops.base_op import OPERATORS
from data_juicer.ops.mapper.optimize_qa_mapper import OptimizeQAMapper
OP_NAME = 'optimize_response_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
-@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class OptimizeResponseMapper(OptimizeQAMapper):
"""
diff --git a/data_juicer/ops/mapper/relation_identity_mapper.py b/data_juicer/ops/mapper/relation_identity_mapper.py
new file mode 100644
index 000000000..29994d744
--- /dev/null
+++ b/data_juicer/ops/mapper/relation_identity_mapper.py
@@ -0,0 +1,155 @@
+import re
+from typing import Dict, Optional
+
+from loguru import logger
+from pydantic import PositiveInt
+
+from data_juicer.ops.base_op import OPERATORS, Mapper
+from data_juicer.utils.common_utils import nested_access, nested_set
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+OP_NAME = 'relation_identity_mapper'
+
+
+# TODO: LLM-based inference.
+@OPERATORS.register_module(OP_NAME)
+class RelationIdentityMapper(Mapper):
+ """
+    Identify the relation between two entities in the text.
+ """
+
+ DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
+ '给定关于{entity1}和{entity2}的文本信息。'
+ '判断{entity1}和{entity2}之间的关系。\n'
+ '要求:\n'
+ '- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n'
+ '- 输出关系时不要参杂任何标点符号\n'
+ '- 需要你进行合理的推理才能得出结论\n'
+ '- 如果两个人物身份是同一个人,输出关系为:另一个身份\n'
+ '- 输出格式为:\n'
+ '分析推理:...\n'
+ '所以{entity2}是{entity1}的:...\n'
+ '- 注意输出的是{entity2}是{entity1}的什么关系,而不是{entity1}是{entity2}的什么关系')
+ DEFAULT_INPUT_TEMPLATE = '关于{entity1}和{entity2}的文本信息:\n```\n{text}\n```\n'
+ DEFAULT_OUTPUT_PATTERN_TEMPLATE = r"""
+ \s*分析推理:\s*(.*?)\s*
+ \s*所以{entity2}是{entity1}的:\s*(.*?)\Z
+ """
+
+ def __init__(self,
+ api_model: str = 'gpt-4o',
+ source_entity: str = None,
+ target_entity: str = None,
+ input_key: str = None,
+ output_key: str = None,
+ *,
+ api_endpoint: Optional[str] = None,
+ response_path: Optional[str] = None,
+ system_prompt_template: Optional[str] = None,
+ input_template: Optional[str] = None,
+ output_pattern_template: Optional[str] = None,
+ try_num: PositiveInt = 3,
+ drop_text: bool = False,
+ model_params: Dict = {},
+ sampling_params: Dict = {},
+ **kwargs):
+ """
+ Initialization method.
+ :param api_model: API model name.
+ :param source_entity: The source entity of the relation to be
+ identified.
+ :param target_entity: The target entity of the relation to be
+ identified.
+        :param input_key: The input field key in the samples. Support for
+            nested keys such as "__dj__stats__.text_len". It is text_key
+            by default.
+        :param output_key: The output field key in the samples. Support
+            for nested keys such as "__dj__stats__.text_len". It is
+            input_key by default.
+ :param api_endpoint: URL endpoint for the API.
+ :param response_path: Path to extract content from the API response.
+ Defaults to 'choices.0.message.content'.
+ :param system_prompt_template: System prompt template for the task.
+ :param input_template: Template for building the model input.
+ :param output_pattern_template: Regular expression template for
+ parsing model output.
+ :param try_num: The number of retry attempts when there is an API
+ call error or output parsing error.
+        :param drop_text: Whether to drop the text in the output.
+ :param model_params: Parameters for initializing the API model.
+ :param sampling_params: Extra parameters passed to the API call.
+ e.g {'temperature': 0.9, 'top_p': 0.95}
+ :param kwargs: Extra keyword arguments.
+ """
+ super().__init__(**kwargs)
+
+ if source_entity is None or target_entity is None:
+ logger.warning('source_entity and target_entity cannot be None')
+
+ self.source_entity = source_entity
+ self.target_entity = target_entity
+
+ self.input_key = input_key or self.text_key
+ self.output_key = output_key or self.input_key
+
+ system_prompt_template = system_prompt_template or \
+ self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
+ self.system_prompt = system_prompt_template.format(
+ entity1=source_entity, entity2=target_entity)
+ self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
+ output_pattern_template = output_pattern_template or \
+ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
+ self.output_pattern = output_pattern_template.format(
+ entity1=source_entity, entity2=target_entity)
+
+ self.sampling_params = sampling_params
+ self.model_key = prepare_model(model_type='api',
+ model=api_model,
+ endpoint=api_endpoint,
+ response_path=response_path,
+ **model_params)
+
+ self.try_num = try_num
+ self.drop_text = drop_text
+
+ def parse_output(self, raw_output):
+ pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
+ matches = pattern.findall(raw_output)
+
+ relation = ''
+
+ for match in matches:
+ _, relation = match
+ relation = relation.strip()
+
+ return relation
+
+ def process_single(self, sample, rank=None):
+ client = get_model(self.model_key, rank=rank)
+
+ text = nested_access(sample, self.input_key)
+ input_prompt = self.input_template.format(entity1=self.source_entity,
+ entity2=self.target_entity,
+ text=text)
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': input_prompt
+ }]
+ relation = ''
+ for i in range(self.try_num):
+ try:
+ output = client(messages, **self.sampling_params)
+ relation = self.parse_output(output)
+ if len(relation) > 0:
+ break
+ except Exception as e:
+ logger.warning(f'Exception: {e}')
+
+ sample = nested_set(sample, self.output_key, relation)
+ if self.drop_text:
+ sample.pop(self.text_key)
+
+ return sample
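+
+
+# A minimal illustrative sketch (values are made up): with source_entity='甲'
+# and target_entity='乙', a model reply such as
+#   '分析推理:...\n所以乙是甲的:同事'
+# is parsed by the output pattern above into relation == '同事', which is then
+# written to output_key (and text_key is dropped when drop_text is True).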
diff --git a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py
index 4833409a4..75ffb9b3a 100644
--- a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py
@@ -32,6 +32,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs):
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '30GB')
super().__init__(*args, **kwargs)
AUTOINSTALL.check([
'transformers', 'transformers_stream_generator', 'einops',
diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
index dbf614510..d4c664c5f 100644
--- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
@@ -108,6 +108,7 @@ def __init__(
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '20GB')
super().__init__(*args, **kwargs)
if keep_candidate_mode not in [
diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py
index b2f4c8139..67eb7e234 100644
--- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py
@@ -81,6 +81,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '40GB')
super().__init__(*args, **kwargs)
AUTOINSTALL.check([
'torch',
diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
index 04cd641ab..737626260 100644
--- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
@@ -108,6 +108,7 @@ def __init__(
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '20GB')
super().__init__(*args, **kwargs)
if keep_candidate_mode not in [
diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py
index 763a3381c..7302953f2 100644
--- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py
+++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py
@@ -6,13 +6,14 @@
from data_juicer.utils.mm_utils import extract_audio_from_video
from data_juicer.utils.model_utils import get_model, prepare_model
-from ..base_op import OPERATORS, Mapper
+from ..base_op import OPERATORS, TAGGING_OPS, Mapper
torch = LazyLoader('torch', 'torch')
OP_NAME = 'video_tagging_from_audio_mapper'
+@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class VideoTaggingFromAudioMapper(Mapper):
"""Mapper to generate video tags from audio streams extracted by video
@@ -37,6 +38,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '500MB')
super().__init__(*args, **kwargs)
AUTOINSTALL.check(['torchaudio'])
self.model_key = prepare_model(model_type='huggingface',
@@ -49,12 +51,13 @@ def __init__(self,
def process_single(self, sample, rank=None):
# check if it's generated already
- if self.tag_field_name in sample:
+ if self.tag_field_name in sample[Fields.meta]:
return sample
# there is no video in this sample
if self.video_key not in sample or not sample[self.video_key]:
- sample[self.tag_field_name] = np.array([], dtype=np.str_)
+ sample[Fields.meta][self.tag_field_name] = np.array([],
+ dtype=np.str_)
return sample
# load video paths
@@ -89,5 +92,6 @@ def process_single(self, sample, rank=None):
predicted_tag_id = torch.argmax(logits, dim=-1).item()
predicted_tag = model.config.id2label[predicted_tag_id]
video_audio_tags.append(predicted_tag)
- sample[self.tag_field_name] = np.array(video_audio_tags, dtype=np.str_)
+ sample[Fields.meta][self.tag_field_name] = np.array(video_audio_tags,
+ dtype=np.str_)
return sample
diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
index 26227738b..31927e1b2 100644
--- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
+++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
@@ -10,7 +10,7 @@
load_data_with_context, load_video)
from data_juicer.utils.model_utils import get_model, prepare_model
-from ..base_op import OPERATORS, UNFORKABLE, Mapper
+from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_VIDEOS
ram = LazyLoader('ram', 'ram')
@@ -19,6 +19,7 @@
OP_NAME = 'video_tagging_from_frames_mapper'
+@TAGGING_OPS.register_module(OP_NAME)
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
@@ -55,6 +56,7 @@ def __init__(self,
:param args: extra args
:param kwargs: extra args
"""
+ kwargs.setdefault('mem_required', '9GB')
super().__init__(*args, **kwargs)
if frame_sampling_method not in ['all_keyframes', 'uniform']:
raise ValueError(
@@ -72,12 +74,13 @@ def __init__(self,
def process_single(self, sample, rank=None, context=False):
# check if it's generated already
- if self.tag_field_name in sample:
+ if self.tag_field_name in sample[Fields.meta]:
return sample
# there is no video in this sample
if self.video_key not in sample or not sample[self.video_key]:
- sample[self.tag_field_name] = np.array([[]], dtype=np.str_)
+ sample[Fields.meta][self.tag_field_name] = np.array([[]],
+ dtype=np.str_)
return sample
# load videos
@@ -114,5 +117,5 @@ def process_single(self, sample, rank=None, context=False):
for vid_key in videos:
close_video(videos[vid_key])
- sample[self.tag_field_name] = video_tags
+ sample[Fields.meta][self.tag_field_name] = video_tags
return sample
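Both tagging mappers now write their outputs into the nested `Fields.meta` field instead of the top level of each sample. A hedged sketch of how downstream code might read the tags after this change (the helper below is illustrative and not part of the patch):

```python
from data_juicer.utils.constant import Fields


def get_video_frame_tags(sample: dict):
    # Tags produced by video_tagging_from_frames_mapper now live under
    # sample[Fields.meta] rather than directly on the sample.
    meta = sample.get(Fields.meta, {})
    return meta.get(Fields.video_frame_tags, [])
```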
diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py
index 50116ea0c..5ea9091b0 100644
--- a/data_juicer/utils/auto_install_mapping.py
+++ b/data_juicer/utils/auto_install_mapping.py
@@ -79,5 +79,21 @@
['rouge', 'torch', 'transformers', 'vllm'],
'video_tagging_from_frames_mapper': ['ram', 'torch'],
'text_entity_dependency_filter': ['spacy-pkuseg'],
- 'optimize_response_mapper': ['torch', 'transformers', 'vllm']
+ 'optimize_response_mapper': ['torch', 'transformers', 'vllm'],
+ 'text_chunk_mapper': ['transformers', 'dashscope', 'openai'],
+ 'entity_attribute_aggregator': ['transformers', 'dashscope', 'openai'],
+ 'most_relavant_entities_aggregator':
+ ['transformers', 'dashscope', 'openai'],
+ 'nested_aggregator': ['transformers', 'dashscope', 'openai'],
+ 'calibrate_qa_mapper': ['openai'],
+ 'calibrate_query_mapper': ['openai'],
+ 'calibrate_response_mapper': ['openai'],
+ 'extract_entity_attribute_mapper': ['openai'],
+ 'extract_entity_relation_mapper': ['openai'],
+ 'extract_event_mapper': ['openai'],
+ 'extract_keyword_mapper': ['openai'],
+ 'extract_nickname_mapper': ['openai'],
+ 'extract_support_text_mapper': ['openai'],
+ 'pair_preference_mapper': ['openai'],
+ 'relation_identity_mapper': ['openai'],
}
diff --git a/data_juicer/utils/cache_utils.py b/data_juicer/utils/cache_utils.py
index 7d815db2c..51138d7ed 100644
--- a/data_juicer/utils/cache_utils.py
+++ b/data_juicer/utils/cache_utils.py
@@ -1,4 +1,7 @@
import os
+from functools import wraps
+
+from datasets import disable_caching, enable_caching, is_caching_enabled
# Default cache location
DEFAULT_CACHE_HOME = '~/.cache'
@@ -21,3 +24,47 @@
DEFAULT_DATA_JUICER_MODELS_CACHE)
CACHE_COMPRESS = None
+
+
+class DatasetCacheControl:
+ """Define a range that change the cache state temporarily."""
+
+ def __init__(self, on: bool = False):
+ self.on = on
+
+ def __enter__(self):
+ """
+ Record the original cache state and turn it to the target state.
+ """
+ self.previous_state = is_caching_enabled()
+ if self.on:
+ enable_caching()
+ else:
+ disable_caching()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Restore the original cache state.
+ """
+ if self.previous_state:
+ enable_caching()
+ else:
+ disable_caching()
+
+
+def dataset_cache_control(on):
+ """
+    A decorator for functions that need to temporarily switch the dataset
+    cache state while they run.
+ """
+
+ def dataset_cache_decorator(func):
+
+ @wraps(func)
+ def wrapped_function(*args, **kwargs):
+ with DatasetCacheControl(on=on):
+ return func(*args, **kwargs)
+
+ return wrapped_function
+
+ return dataset_cache_decorator
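A minimal usage sketch of the two cache helpers added above; the calls mirror the definitions in this hunk, while the function bodies are just placeholders:

```python
from data_juicer.utils.cache_utils import DatasetCacheControl, dataset_cache_control

# As a context manager: caching is switched off only inside the block, and the
# previous state of the Hugging Face datasets cache is restored on exit.
with DatasetCacheControl(on=False):
    pass  # run dataset operations here without caching


# As a decorator: the wrapped function always runs with caching enabled,
# regardless of the caller's current cache setting.
@dataset_cache_control(on=True)
def cached_step():
    pass
```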
diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py
index 959831c5d..bd649bb96 100644
--- a/data_juicer/utils/common_utils.py
+++ b/data_juicer/utils/common_utils.py
@@ -1,6 +1,8 @@
+import hashlib
import sys
import numpy as np
+from loguru import logger
def stats_to_number(s, reverse=True):
@@ -21,6 +23,122 @@ def stats_to_number(s, reverse=True):
return sys.maxsize
+def dict_to_hash(input_dict: dict, hash_length=None):
+ """
+    Hash a dict into a hex string of length hash_length.
+
+    :param input_dict: the given dict.
+    :param hash_length: the length to truncate the hash value to. Keep the
+        full sha256 hex digest if it's None.
+ """
+ sorted_items = sorted(input_dict.items())
+ dict_string = str(sorted_items).encode()
+ hasher = hashlib.sha256()
+ hasher.update(dict_string)
+ hash_value = hasher.hexdigest()
+ if hash_length:
+ hash_value = hash_value[:hash_length]
+ return hash_value
+
+
+def nested_access(data, path, digit_allowed=True):
+ """
+ Access nested data using a dot-separated path.
+
+ :param data: A dictionary or a list to access the nested data from.
+ :param path: A dot-separated string representing the path to access.
+ This can include numeric indices when accessing list
+ elements.
+    :param digit_allowed: whether to convert numeric string keys into
+        integer indices.
+    :return: The value located at the specified path, or None (with a
+        warning logged) if the path does not exist.
+ """
+ keys = path.split('.')
+ for key in keys:
+ # Convert string keys to integers if they are numeric
+ key = int(key) if key.isdigit() and digit_allowed else key
+ try:
+ data = data[key]
+ except Exception:
+            logger.warning(f'Inaccessible dot-separated path: {path}!')
+ return None
+ return data
+
+
+def nested_set(data: dict, path: str, val):
+ """
+    Set val into the nested data at the dot-separated path.
+
+    :param data: A dictionary with nested format.
+    :param path: A dot-separated string representing the path to set.
+        Missing intermediate keys are created as empty dicts.
+    :param val: The value to set at the given path.
+    :return: The nested data after val is set.
+ """
+ keys = path.split('.')
+ cur = data
+ for key in keys[:-1]:
+ if key not in cur:
+ cur[key] = {}
+ cur = cur[key]
+ cur[keys[-1]] = val
+ return data
+
+
+def is_string_list(var):
+ """
+    Return whether the given variable is a list of strings.
+
+    :param var: the input variable.
+ """
+ return isinstance(var, list) and all(isinstance(it, str) for it in var)
+
+
+def avg_split_string_list_under_limit(str_list: list,
+ token_nums: list,
+ max_token_num=None):
+ """
+    Split the string list into several sub-lists such that the total token
+    num of each sub-list is less than max_token_num, while keeping the total
+    token nums of the sub-lists similar to each other.
+
+    :param str_list: input string list.
+    :param token_nums: token num of each string in str_list.
+    :param max_token_num: max total token num of each sub string list.
+ """
+ if max_token_num is None:
+ return [str_list]
+
+ if len(str_list) != len(token_nums):
+ logger.warning('The length of str_list and token_nums must be equal!')
+ return [str_list]
+
+ total_num = sum(token_nums)
+ if total_num <= max_token_num:
+ return [str_list]
+
+ group_num = total_num // max_token_num + 1
+ avg_num = total_num / group_num
+ res = []
+ cur_list = []
+ cur_sum = 0
+ for text, token_num in zip(str_list, token_nums):
+ if token_num > max_token_num:
+ logger.warning(
+ 'Token num is greater than max_token_num in one sample!')
+ if cur_sum + token_num > max_token_num and cur_list:
+ res.append(cur_list)
+ cur_list = []
+ cur_sum = 0
+ cur_list.append(text)
+ cur_sum += token_num
+ if cur_sum > avg_num:
+ res.append(cur_list)
+ cur_list = []
+ cur_sum = 0
+ if cur_list:
+ res.append(cur_list)
+ return res
+
+
def is_float(s):
try:
float(s)
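The helpers above are reused by several of the new OPs; a short illustrative sketch of how they behave (the sample dict and its values are made up):

```python
from data_juicer.utils.common_utils import dict_to_hash, nested_access, nested_set

sample = {'__dj__stats__': {'text_len': 120}, 'answers': ['yes', 'no']}

nested_access(sample, '__dj__stats__.text_len')  # -> 120
nested_access(sample, 'answers.1')               # -> 'no' (numeric index into a list)
nested_access(sample, 'missing.key')             # -> None, with a warning logged

nested_set(sample, 'meta.source', 'demo')        # adds {'meta': {'source': 'demo'}}

dict_to_hash({'a': 1}, hash_length=8)            # -> first 8 hex chars of the sha256 digest
```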
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 8bc78ad5a..30686693e 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -16,13 +16,17 @@ class Fields(object):
context = DEFAULT_PREFIX + 'context__'
suffix = DEFAULT_PREFIX + 'suffix__'
- video_frames = DEFAULT_PREFIX + 'video_frames__'
+ # tags in meta
# video_frame_tags
video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
+ # video_audio_tags
video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
# image_tags
image_tags = DEFAULT_PREFIX + 'image_tags__'
+ # video_frames
+ video_frames = DEFAULT_PREFIX + 'video_frames__'
+
# the name of the original file from which this sample was derived.
source_file = DEFAULT_PREFIX + 'source_file__'
@@ -33,14 +37,14 @@ class Fields(object):
event_description = DEFAULT_PREFIX + 'event_description__'
# # a list of characters relevant to the event
relevant_characters = DEFAULT_PREFIX + 'relevant_characters__'
- # # the given main entity for attribute extraction
- main_entity = DEFAULT_PREFIX + 'main_entity__'
- # # the given attribute to be extracted
- attribute = DEFAULT_PREFIX + 'attribute__'
- # # the extracted attribute description
- attribute_description = DEFAULT_PREFIX + 'attribute_description__'
- # # extract from raw data for support the attribute
- attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__'
+ # # the given main entities for attribute extraction
+ main_entities = DEFAULT_PREFIX + 'main_entities__'
+ # # the given attributes to be extracted
+ attributes = DEFAULT_PREFIX + 'attributes__'
+ # # the extracted attribute descriptions
+ attribute_descriptions = DEFAULT_PREFIX + 'attribute_descriptions__'
+    # # texts extracted from the raw data that support the attributes
+ attribute_support_texts = DEFAULT_PREFIX + 'attribute_support_texts__'
# # the nickname relationship
nickname = DEFAULT_PREFIX + 'nickname__'
# # the entity for knowledge graph
@@ -65,6 +69,8 @@ class Fields(object):
relation_strength = DEFAULT_PREFIX + 'relation_strength__'
# # the keyword in a text
keyword = DEFAULT_PREFIX + 'keyword__'
+ # # support text
+ support_text = DEFAULT_PREFIX + 'support_text__'
class StatsKeysMeta(type):
diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py
index e2fc241cd..7a8618660 100644
--- a/data_juicer/utils/file_utils.py
+++ b/data_juicer/utils/file_utils.py
@@ -1,6 +1,5 @@
import asyncio
import copy
-import hashlib
import os
import re
import shutil
@@ -10,6 +9,7 @@
from datasets.utils.extract import ZstdExtractor as Extractor
+from data_juicer.utils.common_utils import dict_to_hash
from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
@@ -127,22 +127,6 @@ def add_suffix_to_filename(filename, suffix):
return new_name
-def dict_to_hash(input_dict, hash_length=None):
- """
- hash a dict to a string with length hash_length
-
- :param input_dict: the given dict
- """
- sorted_items = sorted(input_dict.items())
- dict_string = str(sorted_items).encode()
- hasher = hashlib.sha256()
- hasher.update(dict_string)
- hash_value = hasher.hexdigest()
- if hash_length:
- hash_value = hash_value[:hash_length]
- return hash_value
-
-
def create_directory_if_not_exists(directory_path):
"""
create a directory if not exists, this function is process safe
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 305145e82..94b4440eb 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -11,6 +11,7 @@
from loguru import logger
from data_juicer import cuda_device_count
+from data_juicer.utils.common_utils import nested_access
from data_juicer.utils.lazy_loader import AUTOINSTALL, LazyLoader
from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC
@@ -167,30 +168,11 @@ def __call__(self, messages, **kwargs):
stream=stream,
stream_cls=stream_cls)
result = response.json()
- return self._nested_access(result, self.response_path)
+ return nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return ''
- @staticmethod
- def _nested_access(data, path):
- """
- Access nested data using a dot-separated path.
-
- :param data: A dictionary or a list to access the nested data from.
- :param path: A dot-separated string representing the path to access.
- This can include numeric indices when accessing list
- elements.
- :return: The value located at the specified path, or raises a KeyError
- or IndexError if the path does not exist.
- """
- keys = path.split('.')
- for key in keys:
- # Convert string keys to integers if they are numeric
- key = int(key) if key.isdigit() else key
- data = data[key]
- return data
-
@staticmethod
def _filter_arguments(func, args_dict):
"""
@@ -221,19 +203,18 @@ def prepare_api_model(model,
return_processor=False,
processor_config=None,
**model_params):
- """
- Creates an instance of the APIModel for interacting with OpenAI-like APIs.
+ """Creates a callable API model for interacting with OpenAI-compatible API.
+ The callable supports custom response parsing and works with proxy servers
+ that may be incompatible.
- :param model: The name of the model to be used for making API calls.
+ :param model: The name of the model to interact with.
:param endpoint: The URL endpoint for the API. If provided as a relative
path, it will be appended to the base URL (defined by the
`OPENAI_BASE_URL` environment variable or through an additional
`base_url` parameter). By default, it is set to
'/chat/completions' for OpenAI compatibility.
- :param response_path: A dot-separated string specifying the path to
- extract desired content from the API response. The default value is
- 'choices.0.message.content', which corresponds to the typical
- structure of an OpenAI API response.
+ :param response_path: The dot-separated path to extract desired content
+ from the API response. Defaults to 'choices.0.message.content'.
:param return_processor: A boolean flag indicating whether to return a
processor along with the model. The processor can be used for tasks
like tokenization or encoding. Defaults to False.
@@ -279,8 +260,8 @@ def get_processor():
"- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501
)
- if processor_config is not None \
- and 'pretrained_model_name_or_path' in processor_config:
+ if processor_config is not None and \
+ 'pretrained_model_name_or_path' in processor_config:
processor = transformers.AutoProcessor.from_pretrained(
**processor_config)
else:
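For reference, the shared `nested_access` helper is what resolves `response_path` (default `'choices.0.message.content'`) against the API response. A hedged sketch with a mocked OpenAI-style payload, so no real API call is involved:

```python
from data_juicer.utils.common_utils import nested_access

# A mocked chat-completion response (illustrative values only).
mock_response = {
    'choices': [
        {'message': {'role': 'assistant', 'content': 'Hello!'}},
    ],
}

assert nested_access(mock_response, 'choices.0.message.content') == 'Hello!'
```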
diff --git a/demos/role_playing_system_prompt/README_ZH.md b/demos/role_playing_system_prompt/README_ZH.md
new file mode 100644
index 000000000..956c335bb
--- /dev/null
+++ b/demos/role_playing_system_prompt/README_ZH.md
@@ -0,0 +1,49 @@
+# 为LLM构造角色扮演的system prompt
+
+在该Demo中,我们展示了如何通过Data-Juicer的菜谱,生成让LLM扮演剧本中给定角色的system prompt。我们这里以《莲花楼》为例。
+
+## 数据准备
+将《莲花楼》按章节划分,按顺序每个章节对应Data-Juicer的一个sample,放到“text”关键字下。如下json格式:
+```json
+[
+    {"text": "第一章内容"},
+    {"text": "第二章内容"},
+    {"text": "第三章内容"},
+    ...
+]
+```
+
+## 执行
+```shell
+python tools/process_data.py --config ./demos/role_playing_system_prompt/role_playing_system_prompt.yaml
+```
+
+## 生成样例
+
+```text
+扮演李莲花与用户进行对话。
+# 角色身份
+原名李相夷,曾是武林盟主,创立四顾门。十年前因中碧茶之毒,隐姓埋名,成为莲花楼的老板,过着市井生活。
+# 角色经历
+李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。
+# 角色性格
+李莲花是一个机智、幽默、善于观察和推理的人物。他表面上看似随和、悠闲,甚至有些懒散,但实际上心思缜密,洞察力极强。他不仅具备敏锐的观察力和独特的思维方式,还拥有深厚的内功和高超的医术。他对朋友忠诚,愿意为了保护他们不惜一切代价,同时在面对敌人时毫不手软。尽管内心充满正义感和责任感,但他选择远离江湖纷争,追求宁静自在的生活。他对过去的自己(李相夷)有着深刻的反思,对乔婉娩的感情复杂,既有愧疚也有关怀。李莲花能够在复杂的环境中保持冷静,巧妙地利用智慧和技能解决问题,展现出非凡的勇气和决心。
+# 角色能力
+李莲花是一位智慧与武艺兼备的高手,拥有深厚的内力、高超的医术和敏锐的洞察力。他擅长使用轻功、剑术和特殊武器,如婆娑步和少师剑,能够在关键时刻化解危机。尽管身体状况不佳,他仍能通过内功恢复体力,运用智谋和技巧应对各种挑战。他在江湖中身份多变,既能以游医身份逍遥自在,也能以李相夷的身份化解武林危机。
+# 人际关系
+方多病 (称呼:方小宝、方大少爷)李莲花的徒弟。百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。
+笛飞声 (称呼:阿飞、笛大盟主)金鸳盟盟主,曾与李相夷激战并重伤李相夷,后因中毒失去内力,与李莲花有复杂恩怨。笛飞声是金鸳盟盟主,十年前因与李相夷一战成名。他利用单孤刀的弟子朴锄山引诱李相夷,最终重伤李相夷,但自己也被李相夷钉在桅杆上。十年后,笛飞声恢复内力,重新执掌金鸳盟,与角丽谯合作,试图利用罗摩天冰和业火痋控制武林。在与李莲花和方多病的多次交手中,笛飞声多次展现强大实力,但也多次被李莲花等人挫败。最终,笛飞声在与李莲花的对决中被制住,但并未被杀死。笛飞声与李莲花约定在东海再战,但李莲花因中毒未赴约。笛飞声在东海之战中并未出现,留下了许多未解之谜。
+乔婉娩 (称呼:乔姑娘)李莲花的前女友。四顾门前任门主李相夷的爱人,现任门主肖紫衿的妻子,江湖中知名侠女。乔婉娩是四顾门的重要人物,与李相夷有着复杂的情感纠葛。在李相夷失踪后,乔婉娩嫁给了肖紫衿,但内心始终未能忘记李相夷。在李莲花(即李相夷)重新出现后,乔婉娩通过种种线索确认了他的身份,但最终选择支持肖紫衿,维护四顾门的稳定。乔婉娩在四顾门的复兴过程中发挥了重要作用,尤其是在调查金鸳盟和南胤阴谋的过程中,她提供了关键的情报和支持。尽管内心充满矛盾,乔婉娩最终决定与肖紫衿共同面对江湖的挑战,展现了她的坚强和智慧。
+肖紫衿 (称呼:紫衿)李莲花的门主兼旧识。四顾门现任门主,曾与李相夷有深厚恩怨,后与乔婉娩成婚。肖紫衿是四顾门的重要人物,与李相夷和乔婉娩关系密切。他曾在李相夷的衣冠冢前与李莲花对峙,质问他为何归来,并坚持要与李莲花决斗。尽管李莲花展示了武功,但肖紫衿最终选择不与他继续争斗。肖紫衿在乔婉娩与李相夷的误会中扮演了关键角色,一度因嫉妒取消了与乔婉娩的婚事。后来,肖紫衿在乔婉娩的支持下担任四顾门的新门主,致力于复兴四顾门。在与单孤刀的对抗中,肖紫衿展现了坚定的决心和领导能力,最终带领四顾门取得了胜利。
+单孤刀 (称呼:师兄)李莲花的师兄兼敌人。单孤刀,李莲花的师兄,四顾门创始人之一,因不满李相夷与金鸳盟签订协定而独自行动,最终被金鸳盟杀害。单孤刀是李莲花的师兄,与李相夷一同创立四顾门。单孤刀性格争强好胜,难以容人,最终因不满李相夷与金鸳盟签订协定,决定独自行动。单孤刀被金鸳盟杀害,李相夷得知后悲愤交加,誓言与金鸳盟不死不休。单孤刀的死成为李相夷心中的一大阴影,多年后李莲花在调查中发现单孤刀并非真正死亡,而是诈死以实现自己的野心。最终,单孤刀在与李莲花和方多病的对决中失败,被轩辕箫的侍卫杀死。
+# 语言风格
+李莲花的语言风格幽默诙谐,充满智慧和机智,善于用轻松的语气化解紧张的气氛。他常用比喻、反讽和夸张来表达复杂的观点,同时在关键时刻能简洁明了地揭示真相。他的言语中带有调侃和自嘲,但又不失真诚和温情,展现出一种从容不迫的态度。无论是面对朋友还是敌人,李莲花都能以幽默和智慧赢得尊重。
+供参考语言风格的部分李莲花台词:
+李莲花:你问我干吗?该启程了啊。
+李莲花:说起师门,你怎么也算云隐山一份子啊?不如趁今日叩拜了你师祖婆婆,再正儿八经给我这个师父磕头敬了茶,往后我守山中、你也尽心在跟前罢?
+李莲花:恭贺肖大侠和乔姑娘,喜结连理。
+李莲花淡淡一笑:放心吧,该看到的,都看到了。
+李莲花:如果现在去百川院,你家旺福就白死了。
+```
+
+
diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml
new file mode 100644
index 000000000..eadac45da
--- /dev/null
+++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml
@@ -0,0 +1,57 @@
+# global parameters
+project_name: 'role-play-demo-process'
+dataset_path: 'path_to_the_lianhualou_novel_json_file'
+np: 1 # number of subprocess to process your dataset
+
+export_path: 'path_to_output_jsonl_file'
+
+# process schedule
+process:
+# # chunk the novel if necessary
+# - text_chunk_mapper:
+# max_len: 8000
+# split_pattern: '\n\n'
+# overlap_len: 400
+# tokenizer: 'qwen2.5-72b-instruct'
+# trust_remote_code: True
+  # extract language_style, role_character and role_skill
+ - extract_entity_attribute_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ query_entities: ['李莲花']
+ query_attributes: ["角色性格", "角色武艺和能力", "语言风格"]
+ # extract nickname
+ - extract_nickname_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ # extract events
+ - extract_event_mapper:
+ api_model: 'qwen2.5-72b-instruct'
+ index_key: 'chunk_id' # chunk_id for deduplicating attributes and nicknames
+ # group all events
+ - naive_grouper:
+ # role experiences summary from events
+ - entity_attribute_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ attribute: '身份背景'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__role_background__'
+ word_limit: 50
+ - entity_attribute_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ attribute: '主要经历'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__role_experience__'
+ word_limit: 150
+  # most relevant roles summary from events
+ - most_relavant_entities_aggregator:
+ api_model: 'qwen2.5-72b-instruct'
+ entity: '李莲花'
+ query_entity_type: '人物'
+ input_key: '__dj__event_description__'
+ output_key: '__dj__important_relavant_roles__'
+ # generate the system prompt
+ - python_file_mapper:
+      file_path: 'path_to_system_prompt_generation_python_file'
+ function_name: 'get_system_prompt'
+
\ No newline at end of file
diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py
new file mode 100644
index 000000000..dc2738900
--- /dev/null
+++ b/demos/role_playing_system_prompt/system_prompt_generator.py
@@ -0,0 +1,192 @@
+import random
+from collections import Counter
+from itertools import chain
+
+from loguru import logger
+
+from data_juicer.ops.aggregator import NestedAggregator
+from data_juicer.ops.aggregator import EntityAttributeAggregator
+from data_juicer.ops.mapper import RelationIdentityMapper
+from data_juicer.utils.constant import Fields
+
+api_model = 'qwen2.5-72b-instruct'
+
+main_entity = "李莲花"
+query_attributes = ["语言风格", "角色性格", "角色武艺和能力"]
+system_prompt_key = '__dj__system_prompt__'
+example_num_limit = 5
+max_relavant_roles_num = 5
+
+role_info_template = "# {entity}\n## 身份背景\n{identity}\n## 人物经历\n{experience}"
+relation_identity_text_template = """
+{source_entity}的信息:
+{source_entity_info}
+{target_entity}的信息:
+{target_entity_info}
+{source_entity}对{target_entity}的称呼:{nicknames}
+"""
+
+nested_sum = NestedAggregator(
+ api_model=api_model,
+ try_num=3)
+
+def dedup_sort_val_by_chunk_id(sample, id_key, val_key):
+ chunk_ids = sample[id_key]
+ vals = sample[val_key]
+ id_to_val = {}
+ for id, val in zip(chunk_ids, vals):
+ id_to_val[id] = val
+ sorted_ids = list(id_to_val.keys())
+ sorted_ids.sort()
+ sorted_vals = [id_to_val[id] for id in sorted_ids]
+ return list(chain(*sorted_vals))
+
+def get_attributes(sample):
+ main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.main_entities)
+ attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attributes)
+ attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_descriptions)
+ attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_support_texts)
+ attributes = {}
+ support_texts = {}
+ for attr in query_attributes:
+ attributes[attr] = []
+ support_texts[attr] = []
+ for entity, attr_name, attr_desc, sub_support_texts in \
+ zip(main_entities, attribute_names, attribute_descs, attribute_support_texts):
+ if entity == main_entity and attr_name in query_attributes:
+ attributes[attr_name].append(attr_desc)
+ support_texts[attr_name].append(sub_support_texts)
+ return attributes, support_texts
+
+def get_nicknames(sample):
+ nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.nickname)
+ nickname_map = {}
+ for nr in nicknames:
+ if nr[Fields.source_entity] == main_entity:
+ role_name = nr[Fields.target_entity]
+ if role_name not in nickname_map:
+ nickname_map[role_name] = []
+ nickname_map[role_name].append(nr[Fields.relation_description])
+
+ max_nums = 3
+ for role_name, nickname_list in nickname_map.items():
+ th = (len(nickname_list)+1) // 2
+ count = Counter(nickname_list)
+ sorted_items = sorted(count.items(), key=lambda x: x[1], reverse=True)
+ most_common_nicknames = []
+ idx = 0
+ while th > 0 and idx < min(len(sorted_items), max_nums):
+ most_common_nicknames.append(sorted_items[idx][0])
+ th -= sorted_items[idx][1]
+ idx += 1
+ nickname_map[role_name] = most_common_nicknames
+ return nickname_map
+
+
+def get_system_prompt(sample):
+
+ main_role_identity = sample['__dj__role_background__']
+ main_role_experience = sample['__dj__role_experience__']
+ attributes, support_texts = get_attributes(sample)
+ main_role_character = nested_sum.recursive_summary(attributes['角色性格'])
+ main_role_skill = nested_sum.recursive_summary(attributes['角色武艺和能力'])
+ main_role_lang_style = nested_sum.recursive_summary(attributes['语言风格'])
+ lang_style_examples = list(chain(*support_texts['语言风格']))
+ lang_style_example_num = min(example_num_limit, len(lang_style_examples))
+ lang_style_examples = random.sample(lang_style_examples, lang_style_example_num)
+
+ main_role_info = role_info_template.format(
+ entity=main_entity,
+ identity=main_role_identity,
+ experience=main_role_experience
+ )
+
+ nicknames = get_nicknames(sample)
+
+ relation_detail = ""
+ relavant_roles = sample['__dj__important_relavant_roles__']
+ for role_name in relavant_roles[:max_relavant_roles_num]:
+ if role_name == main_entity:
+ continue
+
+ # get sub role identity
+ op = EntityAttributeAggregator(
+ api_model=api_model,
+ entity=role_name,
+ attribute='身份背景',
+ input_key='__dj__event_description__',
+ output_key='__dj__role_background__',
+ word_limit=30
+ )
+ sample = op.process_single(sample)
+ role_identity = sample['__dj__role_background__'].replace('\n', '')
+
+ # get sub role experience
+ op = EntityAttributeAggregator(
+ api_model=api_model,
+ entity=role_name,
+ attribute='主要经历',
+ input_key='__dj__event_description__',
+ output_key='__dj__role_experience__',
+ word_limit=100
+ )
+ sample = op.process_single(sample)
+ role_experience = sample['__dj__role_experience__'].replace('\n', '')
+
+ # get relation identity with main role
+ role_info = role_info_template.format(
+ entity=role_name,
+ identity=role_identity,
+ experience=role_experience
+ )
+ op = RelationIdentityMapper(
+ api_model=api_model,
+ source_entity=main_entity,
+ target_entity=role_name,
+ output_key='__dj__relation_identity__'
+ )
+ if role_name in nicknames:
+ cur_nicknames = '、'.join(nicknames[role_name])
+ else:
+ cur_nicknames = role_name
+ text = relation_identity_text_template.format(
+ source_entity=main_entity,
+ source_entity_info=main_role_info,
+ target_entity=role_name,
+ target_entity_info=role_info,
+            nicknames=cur_nicknames
+ )
+ tmp_sample = {'text': text}
+ tmp_sample = op.process_single(tmp_sample)
+ relation = tmp_sample['__dj__relation_identity__']
+
+ relation_detail += f"\n{role_name} (称呼:{cur_nicknames})"
+ if relation:
+ relation_detail += f"{main_entity}的{relation}。"
+ relation_detail += f"{role_identity}{role_experience}".replace('\n', '')
+
+ full_system_prompt = f"""扮演{main_entity}与用户进行对话。\n"""
+ full_system_prompt += """# 角色身份\n"""
+ full_system_prompt += main_role_identity.replace('\n', '')
+ full_system_prompt += """\n# 角色经历\n"""
+ full_system_prompt += main_role_experience.replace('\n', '')
+ full_system_prompt += """\n# 角色性格\n"""
+ full_system_prompt += main_role_character.replace('\n', '')
+ full_system_prompt += """\n# 角色能力\n"""
+ full_system_prompt += main_role_skill.replace('\n', '')
+
+ full_system_prompt += """\n# 人际关系"""
+ full_system_prompt += relation_detail
+
+ full_system_prompt += """\n# 语言风格\n"""
+ full_system_prompt += main_role_lang_style.replace('\n', '')
+ full_system_prompt += f"""\n供参考语言风格的部分{main_entity}台词:\n"""
+ full_system_prompt += "\n````\n"
+ full_system_prompt += '\n'.join(lang_style_examples)
+ full_system_prompt += "\n````\n"
+
+ logger.info(full_system_prompt)
+
+ sample[system_prompt_key] = full_system_prompt
+
+ return sample
\ No newline at end of file
diff --git a/docs/Operators.md b/docs/Operators.md
index 04a2da380..fe3c6d94d 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,10 +11,12 @@ The operators in Data-Juicer are categorized into 5 types.
| Type | Number | Description |
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
-| [ Mapper ]( #mapper ) | 61 | Edits and transforms samples |
+| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples |
| [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
+| [ Grouper ]( #grouper ) | 2 | Groups samples into batched samples |
+| [ Aggregator ]( #aggregator ) | 3 | Aggregates batched samples, e.g., into a summary or conclusion |
All the specific operators are listed below, each featured with several capability tags.
@@ -72,6 +74,7 @@ All the specific operators are listed below, each featured with several capabili
| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) |
| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) |
| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) |
+| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract supporting sub-text for a given summary. | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) |
| fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) |
| generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) |
| generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) |
@@ -90,6 +93,7 @@ All the specific operators are listed below, each featured with several capabili
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify the relation between two entities in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
| remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) |
@@ -190,6 +194,21 @@ All the specific operators are listed below, each featured with several capabili
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
+## Grouper
+
+| Operator | Tags | Description | Source code | Unit tests |
+|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples into batched samples according to the values of the given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
+| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples into one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+
+## Aggregator
+
+| Operator | Tags | Description | Source code | Unit tests |
+|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Summarize the given entity's attribute from the given documents. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to the given entity from the given texts and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
+| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Aggregate contents recursively, a given number of samples at a time, to respect the input length limit. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
+
## Contributing
We welcome contributions of adding new operators. Please refer to [How-to Guide for Developers](DeveloperGuide.md).
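To make the new Grouper and Aggregator categories concrete, here is a small Python sketch chaining the two. The constructor arguments mirror the unit tests added later in this patch; the `NaiveGrouper` class name and the `.run()` call are inferred from the linked source files and those tests, the dataset content is made up, and the aggregator requires `OPENAI_BASE_URL`/`OPENAI_API_KEY` to be set:

```python
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.aggregator import EntityAttributeAggregator
from data_juicer.ops.grouper import NaiveGrouper

dataset = Dataset.from_list([{'text': 'chapter 1 ...'}, {'text': 'chapter 2 ...'}])

# Group all samples into one batched sample, then summarize an entity's
# attribute from the grouped texts.
grouped = NaiveGrouper().run(dataset)
aggregated = EntityAttributeAggregator(
    api_model='qwen2.5-72b-instruct',
    entity='李莲花',
    attribute='身份背景',
).run(grouped)
```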
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index 011ff5a64..61610a873 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -11,10 +11,12 @@ Data-Juicer 中的算子分为以下 5 种类型。
| 类型 | 数量 | 描述 |
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 |
-| [ Mapper ]( #mapper ) | 61 | 对数据样本进行编辑和转换 |
+| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 44 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |
+| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 |
+| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 |
下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。
@@ -71,6 +73,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) |
| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) |
| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) |
+| extract_support_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 为一段总结抽取对应原文 | [code](../data_juicer/ops/mapper/extract_support_text_mapper.py) | [tests](../tests/ops/mapper/test_extract_support_text_mapper.py) |
| fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) |
| generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) |
| generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) |
@@ -89,6 +92,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) |
| python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) |
| python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) |
+| relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) |
| remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) |
| remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) |
| remove_header_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档头,例如标题、章节数字/名称等 | [code](../data_juicer/ops/mapper/remove_header_mapper.py) | [tests](../tests/ops/mapper/test_remove_header_mapper.py) |
@@ -189,5 +193,20 @@ Data-Juicer 中的算子分为以下 5 种类型。
| range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) |
| topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) |
+## Grouper
+
+| 算子 | 标签 | 描述 | 源码 | 单测样例 |
+|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) |
+| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) |
+
+## Aggregator
+
+| 算子 | 标签 | 描述 | 源码 | 单测样例 |
+|-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) |
+| most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) |
+| nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) |
+
## 贡献
我们欢迎社区贡献新的算子,具体请参考[开发者指南](DeveloperGuide_ZH.md)。
diff --git a/environments/dev_requires.txt b/environments/dev_requires.txt
index 0ecd058c4..44dd79158 100644
--- a/environments/dev_requires.txt
+++ b/environments/dev_requires.txt
@@ -4,4 +4,4 @@ sphinx
sphinx-autobuild
sphinx_rtd_theme
recommonmark
-wandb
+wandb<=0.19.0
diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt
index 414458edc..71aa0ba38 100644
--- a/environments/minimal_requires.txt
+++ b/environments/minimal_requires.txt
@@ -33,3 +33,4 @@ pydantic>=2.0
Pillow
fastapi[standard]>=0.100
httpx
+wordcloud
diff --git a/environments/sandbox_requires.txt b/environments/sandbox_requires.txt
index 7f1d27a25..6a1791cf8 100644
--- a/environments/sandbox_requires.txt
+++ b/environments/sandbox_requires.txt
@@ -1,5 +1,4 @@
torch>=1.11.0
-wandb
fire
pyspark
# vbench-related
diff --git a/environments/science_requires.txt b/environments/science_requires.txt
index 10ea3b86e..af5d6b362 100644
--- a/environments/science_requires.txt
+++ b/environments/science_requires.txt
@@ -26,3 +26,5 @@ ffmpeg-python
opencv-python
vllm>=0.1.3
rouge
+dashscope
+openai
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index 1cb7c4463..2502fe1e1 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -54,6 +54,7 @@ def test_yaml_cfg_file(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
@@ -75,6 +76,7 @@ def test_yaml_cfg_file(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
}, 'nested dict load fail, un-expected internal value')
@@ -144,6 +146,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -165,6 +168,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -186,6 +190,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -207,6 +212,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
self.assertDictEqual(
@@ -228,6 +234,7 @@ def test_mixture_cfg(self):
'mem_required': 0,
'turbo': False,
'batch_size': 1000,
+ 'index_key': None,
}
})
diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/Aggregator/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/Aggregator/test_entity_attribute_aggregator.py
new file mode 100644
index 000000000..1f80da3a3
--- /dev/null
+++ b/tests/ops/Aggregator/test_entity_attribute_aggregator.py
@@ -0,0 +1,139 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import EntityAttributeAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class EntityAttributeAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+        # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='主要经历'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'sub_docs': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ input_key='sub_docs',
+ output_key='text'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ max_token_num=200
+ )
+ self._run_helper(op, samples)
+
+ def test_word_limit_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='身份背景',
+ word_limit=20
+ )
+ self._run_helper(op, samples)
+
+
+ def test_example_prompt(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+        example_prompt = (
+ '- 例如,根据相关文档总结`孙悟空`的`另外身份`,样例如下:\n'
+ '`孙悟空`的`另外身份`总结:\n'
+ '# 孙悟空\n'
+ '## 另外身份\n'
+ '孙行者、齐天大圣、美猴王\n'
+ )
+ op = EntityAttributeAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ attribute='另外身份',
+ example_prompt=example_prompt,
+ word_limit=20
+ )
+ self._run_helper(op, samples)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py
new file mode 100644
index 000000000..1d8678134
--- /dev/null
+++ b/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py
@@ -0,0 +1,93 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+        # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'dj_result':{
+ 'events': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ }
+ },
+ ]
+
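+        # input_key / output_key support dot-separated paths into nested fields such as 'dj_result.events'.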
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物',
+ input_key='dj_result.events',
+ output_key='dj_result.relavant_roles'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = MostRelavantEntitiesAggregator(
+ api_model='qwen2.5-72b-instruct',
+ entity='李莲花',
+ query_entity_type='人物',
+ max_token_num=40
+ )
+ self._run_helper(op, samples)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/Aggregator/test_nested_aggregator.py
new file mode 100644
index 000000000..6347652bc
--- /dev/null
+++ b/tests/ops/Aggregator/test_nested_aggregator.py
@@ -0,0 +1,119 @@
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.aggregator import NestedAggregator
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+
+
+@SKIPPED_TESTS.register_module()
+class NestedAggregatorTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples):
+
+        # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for data in new_dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ self.assertEqual(len(new_dataset), len(samples))
+
+ def test_default_aggregator(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct'
+ )
+ self._run_helper(op, samples)
+
+ def test_input_output(self):
+ samples = [
+ {
+ 'sub_docs': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ input_key='sub_docs',
+ output_key='text'
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_1(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
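+        # max_token_num bounds the input tokens per aggregation call; this test uses an extreme budget of 2.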
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=2
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_2(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=90
+ )
+ self._run_helper(op, samples)
+
+ def test_max_token_num_3(self):
+ samples = [
+ {
+ 'text': [
+ "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。",
+ "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。",
+ '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。',
+ '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。',
+ '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'
+ ]
+ },
+ ]
+ op = NestedAggregator(
+ api_model='qwen2.5-72b-instruct',
+ max_token_num=200
+ )
+ self._run_helper(op, samples)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py b/tests/ops/filter/test_video_tagging_from_frames_filter.py
index bc4f67fb4..4018136ec 100644
--- a/tests/ops/filter/test_video_tagging_from_frames_filter.py
+++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py
@@ -6,6 +6,7 @@
from data_juicer.ops.filter.video_tagging_from_frames_filter import \
VideoTaggingFromFramesFilter
from data_juicer.utils.mm_utils import SpecialTokens
+from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
class VideoTaggingFromFramesFilterTest(DataJuicerTestCaseBase):
@@ -21,8 +22,11 @@ def _run_video_tagging_from_frames_filter(self,
target_list,
num_proc=1):
dataset = Dataset.from_list(source_list)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
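+        # The filter now stores its tags under Fields.meta, so make sure the meta column exists first.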
+ if Fields.meta not in dataset.features:
+ dataset = dataset.add_column(name=Fields.meta,
+ column=[{}] * dataset.num_rows)
+ dataset = dataset.map(op.compute_stats, num_proc=num_proc)
+ dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.select_columns(column_names=['text', 'videos'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
diff --git a/tests/ops/grouper/__init__.py b/tests/ops/grouper/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/ops/grouper/test_key_value_grouper.py b/tests/ops/grouper/test_key_value_grouper.py
new file mode 100644
index 000000000..1ac186423
--- /dev/null
+++ b/tests/ops/grouper/test_key_value_grouper.py
@@ -0,0 +1,54 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.grouper.key_value_grouper import KeyValueGrouper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class KeyValueGrouperTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples, target):
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for batched_sample in new_dataset:
+ lang = batched_sample['meta'][0]['language']
+ self.assertEqual(batched_sample['text'], target[lang])
+
+ def test_key_value_grouper(self):
+
+ source = [
+ {
+ 'text': "Today is Sunday and it's a happy day!",
+ 'meta': {
+ 'language': 'en'
+ }
+ },
+ {
+ 'text': "Welcome to Alibaba.",
+ 'meta': {
+ 'language': 'en'
+ }
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!',
+ 'meta': {
+ 'language': 'zh'
+ }
+ },
+ ]
+ target = {
+ 'en':[
+ "Today is Sunday and it's a happy day!",
+ "Welcome to Alibaba."
+ ],
+ 'zh':[
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+
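+        # Group samples that share the same value at the nested key 'meta.language' into one batched sample per language.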
+ op = KeyValueGrouper(['meta.language'])
+ self._run_helper(op, source, target)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/grouper/test_naive_grouper.py b/tests/ops/grouper/test_naive_grouper.py
new file mode 100644
index 000000000..4e69a8ba2
--- /dev/null
+++ b/tests/ops/grouper/test_naive_grouper.py
@@ -0,0 +1,47 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.grouper.naive_grouper import NaiveGrouper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class NaiveGrouperTest(DataJuicerTestCaseBase):
+
+ def _run_helper(self, op, samples, target):
+ dataset = Dataset.from_list(samples)
+ new_dataset = op.run(dataset)
+
+ for d, t in zip(new_dataset, target):
+ self.assertEqual(d['text'], t['text'])
+
+ def test_naive_group(self):
+
+ source = [
+ {
+ 'text': "Today is Sunday and it's a happy day!"
+ },
+ {
+ 'text':
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.'
+ },
+ {
+ 'text': '欢迎来到阿里巴巴!'
+ },
+ ]
+ target = [
+ {
+ 'text':[
+ "Today is Sunday and it's a happy day!",
+ "Sur la plateforme MT4, plusieurs manières d'accéder à \n"
+ 'ces fonctionnalités sont conçues simultanément.',
+ '欢迎来到阿里巴巴!'
+ ]
+ }
+ ]
+
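+        # NaiveGrouper batches all samples in the dataset into a single grouped sample.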
+ op = NaiveGrouper()
+ self._run_helper(op, source, target)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
index 96f186d29..f15b4ca3f 100644
--- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py
+++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py
@@ -21,9 +21,9 @@ def _run_op(self, api_model, response_path=None):
query_attributes = ["语言风格", "角色性格"]
op = ExtractEntityAttributeMapper(
+ api_model=api_model,
query_entities=query_entities,
- query_attributes=query_attributes,
- api_model=api_model,
+ query_attributes=query_attributes,
response_path=response_path)
raw_text = """△笛飞声独自坐在莲花楼屋顶上。李莲花边走边悠闲地给马喂草。方多病则走在一侧,却总不时带着怀疑地盯向楼顶的笛飞声。
@@ -49,9 +49,14 @@ def _run_op(self, api_model, response_path=None):
dataset = Dataset.from_list(samples)
dataset = dataset.map(op.process, batch_size=1)
for sample in dataset:
- logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}')
- self.assertNotEqual(sample[Fields.attribute_description], '')
- self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0)
+ ents = sample[Fields.main_entities]
+ attrs = sample[Fields.attributes]
+ descs = sample[Fields.attribute_descriptions]
+ sups = sample[Fields.attribute_support_texts]
+ for ent, attr, desc, sup in zip(ents, attrs, descs, sups):
+ logger.info(f'{ent} {attr}: {desc}')
+ self.assertNotEqual(desc, '')
+ self.assertNotEqual(len(sup), 0)
def test(self):
# before runing this test, set below environment variables:
diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py
index 1652c8db2..aba40d73e 100644
--- a/tests/ops/mapper/test_extract_event_mapper.py
+++ b/tests/ops/mapper/test_extract_event_mapper.py
@@ -18,7 +18,8 @@ class ExtractEventMapperTest(DataJuicerTestCaseBase):
def _run_op(self, api_model, response_path=None):
op = ExtractEventMapper(api_model=api_model,
- response_path=response_path)
+ response_path=response_path,
+ index_key='chunk_id')
raw_text = """△芩婆走到中间,看着众人。
芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。
@@ -57,9 +58,11 @@ def _run_op(self, api_model, response_path=None):
}]
dataset = Dataset.from_list(samples)
- dataset = dataset.map(op.process, batch_size=2)
+ dataset = op.run(dataset)
self.assertNotEqual(len(dataset), 0)
for sample in dataset:
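+            # index_key='chunk_id' records which input chunk produced each event; only one chunk here, so it is always 0.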
+ logger.info(f"chunk_id: {sample['chunk_id']}")
+ self.assertEqual(sample['chunk_id'], 0)
logger.info(f"event: {sample[Fields.event_description]}")
self.assertNotEqual(sample[Fields.event_description], '')
logger.info(f"characters: {sample[Fields.relevant_characters]}")
diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py
new file mode 100644
index 000000000..0445d2526
--- /dev/null
+++ b/tests/ops/mapper/test_extract_support_text_mapper.py
@@ -0,0 +1,80 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.extract_support_text_mapper import ExtractSupportTextMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.common_utils import nested_access
+
+# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class ExtractSupportTextMapperTest(DataJuicerTestCaseBase):
+
+ def _run_op(self, api_model):
+
+ summary_key = 'data.event'
+ support_text_key = 'data.support_text'
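+        # Both keys are dot paths into the nested 'data' field of each sample.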
+ op = ExtractSupportTextMapper(api_model=api_model,
+ summary_key=summary_key,
+ support_text_key=support_text_key)
+
+ raw_text = """△芩婆走到中间,看着众人。
+芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。
+封磬震惊:二子?不是只有一个儿子吗?
+芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。
+李莲花似回忆起了什么:李相显......
+芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐......
+闪回/
+李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵......
+△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。
+△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去!
+△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。
+乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。
+乞丐甲野蛮地抢:给我拿来!
+小单孤刀:放开他!
+△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了!
+△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地!
+/闪回结束
+△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄!
+芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少?
+△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。
+芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。
+△单孤刀呆住。
+芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。
+△李莲花得知真相,闭目叹息。
+△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。
+封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了!
+△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上......
+△封磬颓然地跪倒下来。
+△李莲花对眼前的一切有些意外、无措。
+笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。
+△笛飞声不禁冷笑一下。
+"""
+ event = "李相显托付单孤刀。"
+ samples = [{
+ 'text': raw_text,
+ 'data':{
+ 'event': event
+ }
+ }]
+
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ sample = dataset[0]
+ logger.info(f"support_text: \n{nested_access(sample, support_text_key)}")
+
+ def test(self):
+        # before running this test, set the environment variables below:
+ # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+ # export OPENAI_API_KEY=your_dashscope_key
+ self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_image_tagging_mapper.py b/tests/ops/mapper/test_image_tagging_mapper.py
index 9ec3e4d22..d2bbddec2 100644
--- a/tests/ops/mapper/test_image_tagging_mapper.py
+++ b/tests/ops/mapper/test_image_tagging_mapper.py
@@ -24,6 +24,9 @@ def _run_image_tagging_mapper(self,
target_list,
num_proc=1):
dataset = Dataset.from_list(source_list)
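+        # Image tags are now written under Fields.meta, so add an empty meta column if it is missing.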
+ if Fields.meta not in dataset.features:
+ dataset = dataset.add_column(name=Fields.meta,
+ column=[{}] * dataset.num_rows)
dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -38,23 +41,26 @@ def test(self):
}]
tgt_list = [{
'images': [self.img1_path],
- Fields.image_tags: [[
- 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
- 'chair', 'pillar', 'comfort', 'side table', 'floor',
- 'hardwood floor', 'headboard', 'linen', 'mattress',
- 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
- 'stool', 'white', 'window', 'wood floor']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
+ 'chair', 'pillar', 'comfort', 'side table', 'floor',
+ 'hardwood floor', 'headboard', 'linen', 'mattress',
+ 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
+ 'stool', 'white', 'window', 'wood floor']]},
}, {
'images': [self.img2_path],
- Fields.image_tags: [[
- 'advertisement', 'back', 'bus', 'car', 'city bus',
- 'city street', 'curb', 'decker bus', 'drive', 'license plate',
- 'road', 'street scene', 'tour bus', 'travel', 'white']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'advertisement', 'back', 'bus', 'car', 'city bus',
+ 'city street', 'curb', 'decker bus', 'drive', 'license plate',
+ 'road', 'street scene', 'tour bus', 'travel', 'white']]},
}, {
'images': [self.img3_path],
- Fields.image_tags: [[
- 'alley', 'black', 'building', 'catch', 'person', 'pavement',
- 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'alley', 'black', 'building', 'catch', 'person', 'pavement',
+ 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]},
}]
op = ImageTaggingMapper()
self._run_image_tagging_mapper(op, ds_list, tgt_list)
@@ -67,13 +73,15 @@ def test_no_images(self):
}]
tgt_list = [{
'images': [],
- Fields.image_tags: [[]],
+ Fields.meta: {
+ Fields.image_tags: [[]]},
}, {
'images': [self.img2_path],
- Fields.image_tags: [[
- 'advertisement', 'back', 'bus', 'car', 'city bus',
- 'city street', 'curb', 'decker bus', 'drive', 'license plate',
- 'road', 'street scene', 'tour bus', 'travel', 'white']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'advertisement', 'back', 'bus', 'car', 'city bus',
+ 'city street', 'curb', 'decker bus', 'drive', 'license plate',
+ 'road', 'street scene', 'tour bus', 'travel', 'white']]},
}]
op = ImageTaggingMapper()
self._run_image_tagging_mapper(op, ds_list, tgt_list)
@@ -90,23 +98,26 @@ def test_specified_tag_field_name(self):
}]
tgt_list = [{
'images': [self.img1_path],
- tag_field_name: [[
- 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
- 'chair', 'pillar', 'comfort', 'side table', 'floor',
- 'hardwood floor', 'headboard', 'linen', 'mattress',
- 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
- 'stool', 'white', 'window', 'wood floor']],
+ Fields.meta: {
+ tag_field_name: [[
+ 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
+ 'chair', 'pillar', 'comfort', 'side table', 'floor',
+ 'hardwood floor', 'headboard', 'linen', 'mattress',
+ 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
+ 'stool', 'white', 'window', 'wood floor']]},
}, {
'images': [self.img2_path],
- tag_field_name: [[
- 'advertisement', 'back', 'bus', 'car', 'city bus',
- 'city street', 'curb', 'decker bus', 'drive', 'license plate',
- 'road', 'street scene', 'tour bus', 'travel', 'white']],
+ Fields.meta: {
+ tag_field_name: [[
+ 'advertisement', 'back', 'bus', 'car', 'city bus',
+ 'city street', 'curb', 'decker bus', 'drive', 'license plate',
+ 'road', 'street scene', 'tour bus', 'travel', 'white']]},
}, {
'images': [self.img3_path],
- tag_field_name: [[
- 'alley', 'black', 'building', 'catch', 'person', 'pavement',
- 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']],
+ Fields.meta: {
+ tag_field_name: [[
+ 'alley', 'black', 'building', 'catch', 'person', 'pavement',
+ 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]},
}]
op = ImageTaggingMapper(tag_field_name=tag_field_name)
self._run_image_tagging_mapper(op, ds_list, tgt_list)
@@ -126,23 +137,26 @@ def test_multi_process(self):
}]
tgt_list = [{
'images': [self.img1_path],
- Fields.image_tags: [[
- 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
- 'chair', 'pillar', 'comfort', 'side table', 'floor',
- 'hardwood floor', 'headboard', 'linen', 'mattress',
- 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
- 'stool', 'white', 'window', 'wood floor']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling',
+ 'chair', 'pillar', 'comfort', 'side table', 'floor',
+ 'hardwood floor', 'headboard', 'linen', 'mattress',
+ 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp',
+ 'stool', 'white', 'window', 'wood floor']]},
}, {
'images': [self.img2_path],
- Fields.image_tags: [[
- 'advertisement', 'back', 'bus', 'car', 'city bus',
- 'city street', 'curb', 'decker bus', 'drive', 'license plate',
- 'road', 'street scene', 'tour bus', 'travel', 'white']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'advertisement', 'back', 'bus', 'car', 'city bus',
+ 'city street', 'curb', 'decker bus', 'drive', 'license plate',
+ 'road', 'street scene', 'tour bus', 'travel', 'white']]},
}, {
'images': [self.img3_path],
- Fields.image_tags: [[
- 'alley', 'black', 'building', 'catch', 'person', 'pavement',
- 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']],
+ Fields.meta: {
+ Fields.image_tags: [[
+ 'alley', 'black', 'building', 'catch', 'person', 'pavement',
+ 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]},
}]
op = ImageTaggingMapper()
self._run_image_tagging_mapper(op,
diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py
new file mode 100644
index 000000000..d730cb79f
--- /dev/null
+++ b/tests/ops/mapper/test_relation_identity_mapper.py
@@ -0,0 +1,58 @@
+import unittest
+import json
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.mapper.relation_identity_mapper import RelationIdentityMapper
+from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
+ DataJuicerTestCaseBase)
+from data_juicer.utils.constant import Fields
+
+# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
+# These tests have been tested locally.
+@SKIPPED_TESTS.register_module()
+class RelationIdentityMapperTest(DataJuicerTestCaseBase):
+
+ def _run_op(self, api_model, response_path=None):
+
+ op = RelationIdentityMapper(api_model=api_model,
+ source_entity="李莲花",
+ target_entity="方多病",
+ response_path=response_path)
+
+ raw_text = """李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。
+在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。
+李莲花与方多病合作,解决了灵山派掌门王青山的假死案,揭露了朴管家的罪行。
+随后,他与方多病和笛飞声一起调查了玉秋霜的死亡案,最终揭露了玉红烛的阴谋。在朴锄山,李莲花和方多病调查了七具无头尸事件,发现男童的真实身份是笛飞声。
+李莲花利用飞猿爪偷走男童手中的观音垂泪,导致笛飞声恢复内力,但李莲花巧妙逃脱。李莲花与方多病继续合作,调查了少师剑被盗案,揭露了静仁和尚的阴谋。
+在采莲庄,他解决了新娘溺水案,找到了狮魂的线索,并在南门园圃挖出单孤刀的药棺。在玉楼春的案件中,李莲花和方多病揭露了玉楼春的阴谋,救出了被拐的清儿。
+在石寿村,他们发现了柔肠玉酿的秘密,并救出了被控制的武林高手。李莲花与方多病在白水园设下机关,救出方多病的母亲何晓惠,并最终在云隐山找到了治疗碧茶之毒的方法。
+在天机山庄,他揭露了单孤刀的野心,救出了被控制的大臣。在皇宫,李莲花与方多病揭露了魔僧和单孤刀的阴谋,成功解救了皇帝。
+最终,李莲花在东海之滨与笛飞声的决斗中未出现,留下一封信,表示自己已无法赴约。
+一年后,方多病在东海畔的柯厝村找到了李莲花,此时的李莲花双目失明,右手残废,但心态平和,过着简单的生活。
+方多病 (称呼:方小宝、方大少爷)百川院刑探,单孤刀之子,李相夷的徒弟。方多病通过百川院的考核,成为刑探,并在百川院内展示了自己是李相夷的弟子,获得暂时的录用。
+他接到任务前往嘉州调查金鸳盟的余孽,期间与李莲花相识并合作破案。方多病在调查过程中逐渐了解到自己的身世,发现自己的生父是单孤刀。
+他与李莲花、笛飞声等人多次合作,共同对抗金鸳盟和单孤刀的阴谋。方多病在一系列案件中展现了出色的推理能力和武艺,逐渐成长为一名优秀的刑探。
+最终,方多病在天机山庄和皇宫的斗争中发挥了关键作用,帮助李莲花等人挫败了单孤刀的野心。在李莲花中毒后,方多病决心为他寻找解毒之法,展现了深厚的友情。
+"""
+ samples = [{
+ 'text': raw_text,
+ }]
+
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+ for data in dataset:
+ for k in data:
+ logger.info(f"{k}: {data[k]}")
+
+ def test(self):
+        # before running this test, set the environment variables below:
+        # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+        # export OPENAI_API_KEY=your_dashscope_key
+ self._run_op('qwen2.5-72b-instruct')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py
index 8004d9ede..0c0a70db3 100644
--- a/tests/ops/mapper/test_text_chunk_mapper.py
+++ b/tests/ops/mapper/test_text_chunk_mapper.py
@@ -2,9 +2,10 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper
-from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
+@SKIPPED_TESTS.register_module()
class TextChunkMapperTest(DataJuicerTestCaseBase):
def _run_helper(self, op, samples, target):
diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py
index 8bbf05933..00a376170 100644
--- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py
+++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py
@@ -31,8 +31,11 @@ def _run_video_tagging_from_audio_mapper(self,
tag_field_name=Fields.video_audio_tags,
num_proc=1):
dataset = Dataset.from_list(source_list)
+ if Fields.meta not in dataset.features:
+ dataset = dataset.add_column(name=Fields.meta,
+ column=[{}] * dataset.num_rows)
dataset = dataset.map(op.process, num_proc=num_proc)
- res_list = dataset.select_columns([tag_field_name])[tag_field_name]
+ res_list = dataset.flatten().select_columns([f'{Fields.meta}.{tag_field_name}'])[f'{Fields.meta}.{tag_field_name}']
self.assertEqual(res_list, target_list)
def test(self):
diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py
index 4484df754..31fc04c3b 100644
--- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py
+++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py
@@ -25,6 +25,9 @@ def _run_video_tagging_from_frames_mapper(self,
target_list,
num_proc=1):
dataset = Dataset.from_list(source_list)
+ if Fields.meta not in dataset.features:
+ dataset = dataset.add_column(name=Fields.meta,
+ column=[{}] * dataset.num_rows)
dataset = dataset.map(op.process, num_proc=num_proc)
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -46,30 +49,33 @@ def test(self):
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
'videos': [self.vid1_path],
- Fields.video_frame_tags: [[
- 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
- 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
- 'sky'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
+ 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
+ 'sky'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
'videos': [self.vid2_path],
- Fields.video_frame_tags: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid3_path],
- Fields.video_frame_tags: [[
- 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
- 'conversation', 'round table', 'closet', 'computer', 'girl',
- 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
- 'selfie', 'stand'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
+ 'conversation', 'round table', 'closet', 'computer', 'girl',
+ 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
+ 'selfie', 'stand'
+ ]]}
}]
op = VideoTaggingFromFramesMapper()
self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)
@@ -87,16 +93,18 @@ def test_no_video(self):
'text':
f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
'videos': [],
- Fields.video_frame_tags: [[]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[]]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
'videos': [self.vid2_path],
- Fields.video_frame_tags: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ]]}
}]
op = VideoTaggingFromFramesMapper()
self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)
@@ -120,30 +128,33 @@ def test_specified_tag_field_name(self):
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
'videos': [self.vid1_path],
- tag_field_name: [[
- 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
- 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
- 'sky'
- ]]
+ Fields.meta: {
+ tag_field_name: [[
+ 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
+ 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
+ 'sky'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
'videos': [self.vid2_path],
- tag_field_name: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ]]
+ Fields.meta: {
+ tag_field_name: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid3_path],
- tag_field_name: [[
- 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
- 'conversation', 'round table', 'closet', 'computer', 'girl',
- 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
- 'selfie', 'stand'
- ]]
+ Fields.meta: {
+ tag_field_name: [[
+ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
+ 'conversation', 'round table', 'closet', 'computer', 'girl',
+ 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
+ 'selfie', 'stand'
+ ]]}
}]
op = VideoTaggingFromFramesMapper(tag_field_name=tag_field_name)
self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)
@@ -165,30 +176,33 @@ def test_uniform(self):
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
'videos': [self.vid1_path],
- Fields.video_frame_tags: [[
- 'cartoon', 'animal', 'anime', 'game', 'screenshot',
- 'video game', 'cartoon character', 'robe', 'ray', 'text',
- 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'cartoon', 'animal', 'anime', 'game', 'screenshot',
+ 'video game', 'cartoon character', 'robe', 'ray', 'text',
+ 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
'videos': [self.vid2_path],
- Fields.video_frame_tags: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy',
- 'tennis racket', 'blind', 'game controller', 'remote', 'stand',
- 'video game', 'Wii controller', 'play', 'baseball uniform',
- 'toy', 'green']]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy',
+ 'tennis racket', 'blind', 'game controller', 'remote', 'stand',
+ 'video game', 'Wii controller', 'play', 'baseball uniform',
+ 'toy', 'green']]}
}, {
'text':
f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid3_path],
- Fields.video_frame_tags: [[
- 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person',
- 'round table', 'computer', 'girl', 'man', 'closet', 'laptop',
- 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand',
- 'point'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person',
+ 'round table', 'computer', 'girl', 'man', 'closet', 'laptop',
+ 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand',
+ 'point'
+ ]]}
}]
op = VideoTaggingFromFramesMapper(frame_sampling_method='uniform',
frame_num=10)
@@ -216,30 +230,33 @@ def test_multi_process(self):
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
'videos': [self.vid1_path],
- Fields.video_frame_tags: [[
- 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
- 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
- 'sky'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
+ 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
+ 'sky'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
'videos': [self.vid2_path],
- Fields.video_frame_tags: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid3_path],
- Fields.video_frame_tags: [[
- 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
- 'conversation', 'round table', 'closet', 'computer', 'girl',
- 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
- 'selfie', 'stand'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
+ 'conversation', 'round table', 'closet', 'computer', 'girl',
+ 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
+ 'selfie', 'stand'
+ ]]}
}]
op = VideoTaggingFromFramesMapper()
self._run_video_tagging_from_frames_mapper(op,
@@ -268,44 +285,47 @@ def test_multi_chunk(self):
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。',
'videos': [self.vid1_path, self.vid2_path],
- Fields.video_frame_tags:
- [[
- 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
- 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
- 'sky'
- ], [
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags:
+ [[
+ 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
+ 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
+ 'sky'
+ ], [
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid2_path, self.vid3_path],
- Fields.video_frame_tags: [[
- 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
- 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
- 'ball', 'person'
- ], [
- 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
- 'conversation', 'round table', 'closet', 'computer', 'girl',
- 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
- 'selfie', 'stand'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy',
+ 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket',
+ 'ball', 'person'
+ ], [
+ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
+ 'conversation', 'round table', 'closet', 'computer', 'girl',
+ 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
+ 'selfie', 'stand'
+ ]]}
}, {
'text':
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
'videos': [self.vid1_path, self.vid3_path],
- Fields.video_frame_tags: [[
- 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
- 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
- 'sky'
- ], [
- 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
- 'conversation', 'round table', 'closet', 'computer', 'girl',
- 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
- 'selfie', 'stand'
- ]]
+ Fields.meta: {
+ Fields.video_frame_tags: [[
+ 'animal', 'ray', 'text', 'writing', 'yellow', 'game',
+ 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe',
+ 'sky'
+ ], [
+ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf',
+ 'conversation', 'round table', 'closet', 'computer', 'girl',
+ 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand',
+ 'selfie', 'stand'
+ ]]}
}]
op = VideoTaggingFromFramesMapper()
self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)