From 4933f5d0f6e9cce9f8bbb7cf82879bee05cb6fd8 Mon Sep 17 00:00:00 2001
From: panxuchen
Date: Fri, 15 Nov 2024 15:30:08 +0800
Subject: [PATCH 01/22] add ray minhash deduplicator

---
 configs/config_all.yaml | 11 +
 data_juicer/core/ray_data.py | 4 +-
 data_juicer/ops/deduplicator/__init__.py | 3 +-
 .../ray_redis_minhash_deduplicator.py | 365 ++++++++++++++++++
 data_juicer/utils/constant.py | 1 +
 docs/Operators.md | 1 +
 docs/Operators_ZH.md | 1 +
 environments/dist_requires.txt | 2 +-
 8 files changed, 385 insertions(+), 3 deletions(-)
 create mode 100644 data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 90fc18875..d251d24a2 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -643,6 +643,17 @@ process:
     redis_port: 6380 # the port of redis instance, please note that the default port of redis is 6379 which is the same as default port for ray, so we need to modify the default redis config to use it in other port
     lowercase: false # whether to convert text to lower case
     ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
+  - ray_redis_minhash_deduplicator: # the document deduplicator that can run on multiple nodes using the MinHashLSH algorithm
+    redis_address: 'redis://localhost:6379' # the address of the redis instance
+    tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
+    window_size: 5 # window size of shingling
+    num_permutations: 256 # number of permutations in minhash computing
+    jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
+    num_bands: null # number of bands in LSH. Defaults to None, in which case it is determined by an optimal parameter search that minimizes the weighted sum of the probabilities of false positives and false negatives
+    num_rows_per_band: null # number of rows in each band in LSH. Defaults to None, in which case it is determined by the same optimal parameter search
+    lowercase: true # whether to convert text to lower case
+    ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing minhash
+    tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization
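+    # When num_bands/num_rows_per_band are both null, they are picked by the same optimal_param search
+    # imported from document_minhash_deduplicator: choose b bands of r rows (b * r <= num_permutations) so that
+    # the standard MinHashLSH collision probability P(s) = 1 - (1 - s^r)^b approximates a step at jaccard_threshold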
# Selector ops - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 0c131561e..30848bcf3 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -5,7 +5,7 @@ from data_juicer import cuda_device_count from data_juicer.core.data import DJDataset -from data_juicer.ops import Filter, Mapper +from data_juicer.ops import Deduplicator, Filter, Mapper from data_juicer.utils.constant import Fields from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.process_utils import calculate_np @@ -123,6 +123,8 @@ def _run_single_op(self, op): self.data.write_json(op.stats_export_path, force_ascii=False) self.data = self.data.filter(op.process) + elif isinstance(op, Deduplicator): + self.data = op.run(self.data) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 56aec0e10..c368e196d 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -5,6 +5,7 @@ from .ray_basic_deduplicator import RayBasicDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator +from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -12,5 +13,5 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'VideoDeduplicator' + 'RayRedisMinhashDeduplicator', 'VideoDeduplicator' ] diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py new file mode 100644 index 000000000..8883feb4b --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -0,0 +1,365 @@ +import random +import time +from collections import defaultdict +from typing import Optional + +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated + +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + +redis = LazyLoader('redis', 'redis') + + +def retry_on_busy(func): + + def wrapper(*args, **kwargs): + max_retries = 10 + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + if 'BUSY' in str(e) and attempt < max_retries - 1: + time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) + else: + raise + + return wrapper + + +class RedisUnionFind: + + def __init__(self, + prefix: str, + redis_address: str = 'redis://localhost:6379'): + self.prefix = prefix + self.redis_address = redis_address + self.redis = redis.from_url(url=redis_address) + self.set_key = f'{prefix}_UF_SET' + self.rank_key = f'{prefix}_UF_RANK' + self.incur_id_key = f'{prefix}_UF_INCURID' + + # Lua scripts + 
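+        # The Lua script below runs find() with path compression plus
+        # union-by-rank directly inside Redis (KEYS[1] holds parents,
+        # KEYS[2] holds ranks). Redis executes Lua scripts atomically, so
+        # unions issued concurrently from different Ray workers cannot
+        # interleave partial updates to these hashes.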
self.union_script = self.redis.register_script(""" + local function find(x) + local path = {} + while true do + local parent = redis.call('HGET', KEYS[1], x) + if not parent then + return nil + end + if parent == x then + break + end + table.insert(path, x) + x = parent + end + for _, node in ipairs(path) do + redis.call('HSET', KEYS[1], node, x) + end + return x + end + + local root_x = find(ARGV[1]) + local root_y = find(ARGV[2]) + if not root_x then + redis.call('HSET', KEYS[1], ARGV[1], ARGV[1]) + redis.call('HSET', KEYS[2], ARGV[1], 0) + root_x = ARGV[1] + end + if not root_y then + redis.call('HSET', KEYS[1], ARGV[2], ARGV[2]) + redis.call('HSET', KEYS[2], ARGV[2], 0) + root_y = ARGV[2] + end + if root_x == root_y then + return root_x + end + local rank_x = tonumber(redis.call('HGET', KEYS[2], root_x)) + local rank_y = tonumber(redis.call('HGET', KEYS[2], root_y)) + if rank_x < rank_y then + redis.call('HSET', KEYS[1], root_x, root_y) + return root_y + elseif rank_x > rank_y then + redis.call('HSET', KEYS[1], root_y, root_x) + return root_x + else + redis.call('HSET', KEYS[1], root_y, root_x) + redis.call('HINCRBY', KEYS[2], root_x, 1) + return root_x + end + """) + + def get_uid(self): + return int(self.redis.incr(self.incur_id_key)) + + @retry_on_busy + def union(self, x, y): + return self.union_script(keys=[self.set_key, self.rank_key], + args=[x, y]) + + def is_ancestor(self, x): + ancestor = self.redis.hget(self.set_key, x) + return ancestor is None or int(ancestor) == x + + def __reduce__(self): + return (RedisUnionFind, (self.prefix, self.redis_address)) + + def clean(self): + self.redis.delete(self.set_key, self.rank_key, self.incur_id_key) + + +OP_NAME = 'ray_redis_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayRedisMinhashDeduplicator(Deduplicator): + """ + A basic exact matching deduplicator for RAY. + Although its functionality is deduplication, + it is implemented as Filter sub-class. + """ + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + redis_address: str = 'redis://localhost:6379', + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provided the model path in the 'tokenizer_model' field. + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. 
Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. + :param redis_address: address of your redis instance, e.g. + 'redis://localhost:6379' + """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + + def run(self, dataset): + from ray.data.aggregate import AggregateFn + + union_find = RedisUnionFind(self.redis_address) + + def add_uid_column(table: pa.Table) -> pa.Table: + new_column_data = [union_find.get_uid() for _ in range(len(table))] + new_table = table.append_column(HashKeys.uid, [new_column_data]) + return new_table + + def calculate_minhash(table: pa.Table) -> pa.Table: + ids = table.column(HashKeys.uid).to_pandas() + texts = table.column(self.text_key).to_pandas() + hashes = texts.apply(lambda x: self.compute_minhash(x)) + hashes = pa.Array.from_pandas(hashes).flatten() + + repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands)) + + return pa.Table.from_arrays([repeated_ids, hashes], + names=[HashKeys.uid, HashKeys.minhash]) + + def _is_null(r): + return pd.isnull(r) + + class UnionFn(AggregateFn): + + def __init__(self, union_find): + union_find = union_find + + def accumulate(cur, row): + if _is_null(row): + return cur + elif _is_null(cur): + return row[HashKeys.uid] + else: + root = union_find.union(row[HashKeys.uid], cur) + return int(root) + + def merge(a, b): + if _is_null(a): + return b + if _is_null(b): + return a + root = union_find.union(a, b) + return int(root) + + super().__init__( + 
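+                # ray.data's AggregateFn is configured through callbacks:
+                #   init           -> initial accumulator per group key (None here),
+                #   accumulate_row -> folds one grouped row into the accumulator,
+                #   merge          -> combines partial accumulators from different blocks.
+                # Both accumulate and merge route uids into the shared RedisUnionFind,
+                # so samples that share an identical minhash band get unioned.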
init=lambda k: None, + accumulate_row=accumulate, + merge=merge, + name='union', + ) + + def filter_with_union_find(table: pa.Table) -> pa.Table: + uids = table.column(HashKeys.uid).to_pandas() + mask = pa.Array.from_pandas( + uids.apply(lambda x: union_find.is_ancestor(x))) + return table.filter(mask) + + dataset_with_id = dataset.map_batches( + add_uid_column, batch_format='pyarrow').materialize() + dataset_with_id.map_batches(calculate_minhash, + batch_format='pyarrow').groupby( + HashKeys.minhash).aggregate( + UnionFn(union_find)).materialize() + result = dataset_with_id.map_batches(filter_with_union_find, + batch_format='pyarrow') + logger.info(f'Keep {result.count()} samples after MinHash dedup.') + union_find.clean() + return result + + def compute_minhash(self, text): + """ + Compute minhash values for the sample. + + :param sample: input sample + :return: sample with minhash value. + """ + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + tokens = self.tokenizer.encode(text, out_type=str) + tokens = { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + # compute minhash value + hv = np.array([sha1_hash32(token) for token in tokens], + dtype=np.uint64) + phv = np.bitwise_and( + ((hv * np.tile(self.perm_a, + (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, + MAX_HASH) + hash_values = np.vstack([ + phv, + np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH + ]).min(axis=0) + return [ + bytes(hash_values[start:end].byteswap().data) + + start.to_bytes(8, byteorder='little') + for start, end in self.hash_ranges + # groupby minhash||brand_id + ] diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ab88035b9..1fe8d7002 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -216,6 +216,7 @@ class StatsKeys(object, metaclass=StatsKeysMeta): class HashKeys(object): + uid = DEFAULT_PREFIX + 'uid' hash = DEFAULT_PREFIX + 'hash' minhash = DEFAULT_PREFIX + 'minhash' simhash = DEFAULT_PREFIX + 'simhash' diff --git a/docs/Operators.md b/docs/Operators.md index 7717ba434..f1a20c9ef 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -173,6 +173,7 @@ All the specific operators are listed below, each featured with several capabili | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using SimHash | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | 
[tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray and Redis | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level by comparing MD5 hash on ray | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents on ray | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents on ray | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 81aee2149..b1194f250 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -172,6 +172,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 SimHash 在文档级别对样本去重 | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 
MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较 MD5 哈希值在文档级别对样本去重,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt index 4060a654f..b6ab28d06 100644 --- a/environments/dist_requires.txt +++ b/environments/dist_requires.txt @@ -1,2 +1,2 @@ -ray==2.31.0 +ray<=2.38.0 redis>=5.0.0 From 6b79f9004b82fe43c231e97d1d9dc4f2c00b6182 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Fri, 15 Nov 2024 16:24:03 +0800 Subject: [PATCH 02/22] fix redis prefix --- .../ops/deduplicator/ray_redis_minhash_deduplicator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 8883feb4b..4a414fa00 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -1,5 +1,6 @@ import random import time +import uuid from collections import defaultdict from typing import Optional @@ -233,11 +234,12 @@ def __init__( ) for _ in range(self.num_permutation)], dtype=np.uint64, ).T + self.redis_address = redis_address def run(self, dataset): from ray.data.aggregate import AggregateFn - union_find = RedisUnionFind(self.redis_address) + union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], redis_address=self.redis_address) def add_uid_column(table: pa.Table) -> pa.Table: new_column_data = [union_find.get_uid() for _ in range(len(table))] From 991e2906615c8fdd67179f41dd5ae9b808ba23e6 Mon Sep 17 00:00:00 2001 From: panxuchen Date: Fri, 15 Nov 2024 16:28:18 +0800 Subject: [PATCH 03/22] fix redis prefix --- data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 4a414fa00..72c250af1 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -239,7 +239,8 @@ def __init__( def run(self, dataset): from ray.data.aggregate import AggregateFn - union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], redis_address=self.redis_address) + union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], + redis_address=self.redis_address) def add_uid_column(table: pa.Table) -> pa.Table: new_column_data = [union_find.get_uid() for _ in range(len(table))] From 31338d147dd477b0fb1980ff899c04034d8d0e46 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 19 Nov 2024 11:20:34 +0800 Subject: [PATCH 04/22] add ray_minhash_deduplicator --- 
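This commit adds a Redis-free variant: LSH band collisions are unioned inside plain Ray actors (UnionFindWithMerge) instead of a shared Redis instance, and the per-actor forests are folded together at the end via merge()/merge_list(). Below is a minimal, Ray-free sketch of the disjoint-set structure these actors hold; the names are illustrative only, not part of the patch.

    class UnionFind:
        """Disjoint-set with path compression; the smallest key becomes the
        root, matching the `if px > py: swap` convention in the patch."""

        def __init__(self):
            self.parent = {}               # roots are simply absent from the dict

        def find(self, x):
            path = []
            while x in self.parent:        # climb until a root (missing key) is reached
                path.append(x)
                x = self.parent[x]
            for node in path:              # path compression: re-point visited nodes at the root
                self.parent[node] = x
            return x

        def union(self, x, y):
            px, py = self.find(x), self.find(y)
            if px == py:
                return
            if px > py:
                px, py = py, px
            self.parent[py] = px           # keep the smaller id as the canonical ancestor

    uf = UnionFind()
    uf.union('doc_b', 'doc_a')
    uf.union('doc_c', 'doc_b')
    assert uf.find('doc_c') == 'doc_a'     # duplicates collapse onto a single kept representative

The remote class in the patch layers merge()/merge_list() on top of this so that partial forests built by independent actors can be combined pairwise.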
data_juicer/ops/deduplicator/__init__.py | 3 +- .../deduplicator/ray_minhash_deduplicator.py | 371 ++++++++++++++++++ .../ray_redis_minhash_deduplicator.py | 40 +- 3 files changed, 399 insertions(+), 15 deletions(-) create mode 100644 data_juicer/ops/deduplicator/ray_minhash_deduplicator.py diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index c368e196d..0e95ae956 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -5,6 +5,7 @@ from .ray_basic_deduplicator import RayBasicDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator +from .ray_minhash_deduplicator import RayMinhashDeduplicator from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -13,5 +14,5 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'RayRedisMinhashDeduplicator', 'VideoDeduplicator' + 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py new file mode 100644 index 000000000..16755b9f6 --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py @@ -0,0 +1,371 @@ +import random +import time +import uuid +from collections import defaultdict +from typing import Optional + +import ray +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated +from typing import Dict + +from data_juicer.utils.constant import HashKeys, Fields +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + + +@ray.remote +class UnionFindWithMerge: + def __init__(self): + """Initialization method.""" + self.parent: Dict[str | bytes, str] = {} + + @staticmethod + def find_with_parent(parent, x): + x_list = [] + while x in parent: + x_list.append(x) + x = parent[x] + for xx in x_list: + parent[xx] = x + return x + + def find(self, x): + return self.find_with_parent(self.parent, x) + + def union(self, x, y): + px = self.find(x) + py = self.find(y) + if px == py: + return + if px > py: + px, py = py, px + self.parent[py] = px + + def union_list(self, x_list): + px_list = [self.find(x) for x in x_list] + p = min(px_list) + for px in px_list: + if p != px: + self.parent[px] = p + + def is_ancestor(self, x): + assert (x not in self.parent) == (x == self.parent.get(x, x)), f'{x}, {self.parent.get(x, x)}' + return x not in self.parent + + def get_num(self): + return len(self.parent) + + def get_nodes(self): + return set(self.parent.keys()) + + def get_parent(self): + return self.parent + + def merge(self, union_find_set): + union_find_set_parent = ray.get(union_find_set.get_parent.remote()) + union_find_set_nodes = set(union_find_set_parent.keys()) + parrent_nodes = set(self.parent.keys()) + for x in union_find_set_nodes: + px = 
self.find_with_parent(union_find_set_parent, x) + if x in parrent_nodes: + py = self.find(x) + self.union(px, py) + else: + self.parent[x] = px + + def merge_list(self, union_find_set_list): + for union_find_set in union_find_set_list: + self.merge(union_find_set) + + +OP_NAME = 'ray_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayMinhashDeduplicator(Deduplicator): + """ + A basic exact matching deduplicator for RAY. + Although its functionality is deduplication, + it is implemented as Filter sub-class. + """ + + # TODO: Set a more reasonable value + EMPTY_HASH_VALUE = 'EMPTY' + _batched_op = True + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + union_find_parallel_num: Optional[int] = 16, + union_find_merge_num: Optional[int] = 2, + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provided the model path in the 'tokenizer_model' field. + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. 
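+        :param union_find_parallel_num: number of union-find actors to
+            create for the distributed union-find (defaults to 16).
+        :param union_find_merge_num: number of union-find actors folded
+            together in each round of the final merge (defaults to 2).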
+ """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + + self.init_union_find(union_find_parallel_num, union_find_merge_num) + + + def init_union_find(self, union_find_parallel_num, union_find_merge_num): + self.union_find_parallel_num = union_find_parallel_num # 2 # 16 + self.union_find_merge_num = union_find_merge_num + self.union_find_list = [ + UnionFindWithMerge.remote() + for _ in range(self.union_find_parallel_num) + ] + + def compute_stats(self, samples: pa.Table) -> pa.Table: + samples_list = samples[self.text_key] + uuid_list = [uuid.uuid4().hex for _ in range(samples.num_rows)] + all_hash_values = [[] for _ in range(self.num_bands)] + + for text in samples_list: + text = text.as_py() + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + tokens = self.tokenizer.encode(text, out_type=str) + tokens = { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + if len(tokens) > 0: + hv = np.array( + 
[sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + else: + hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) + for i, (start, end) in enumerate(self.hash_ranges): + all_hash_values[i].append( + i.to_bytes(4, 'big') + + bytes(hash_values[start:end].byteswap().data) + ) + + samples = samples.append_column(HashKeys.uid, pa.array(uuid_list)) + for i, hash_values in enumerate(all_hash_values): + samples = samples.append_column(HashKeys.minhash + f"_{i}", pa.array(hash_values)) + return samples + + def map_batched(self, samples: pa.Table) -> pa.Table: + table = pa.Table.from_arrays( + [ + pa.concat_arrays( + [samples[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) + ), + pa.concat_arrays( + [ + samples[HashKeys.minhash + f'_{i}'].combine_chunks() + for i in range(len(self.hash_ranges)) + ] + ), + ], + names=[HashKeys.uid, HashKeys.minhash] + ) + return table + + def agg_func(self, group: pa.Table) -> pa.Table: + if group.num_rows != 1: + uuid_list = [uid.as_py() for uid in group[HashKeys.uid]] + union_find_id = np.random.randint(0, self.union_find_parallel_num) + union_find = self.union_find_list[union_find_id] + ray.get(union_find.union_list.remote(uuid_list)) + return group + + def merge(self): + union_find_list = self.union_find_list + while len(union_find_list) > 1: + new_union_find_list = [] + task_list = [] + buffer = [] + for union_find in union_find_list: + buffer.append(union_find) + if len(buffer) == self.union_find_merge_num: + new_union_find_list.append(buffer[0]) + task_list.append(buffer[0].merge_list.remote(buffer[1:])) + buffer = [] + if len(buffer) > 0: + new_union_find_list.append(buffer[0]) + if len(buffer) > 1: + task_list.append(buffer[0].merge_list.remote(buffer[1:])) + ray.get(task_list) + union_find_list = new_union_find_list + self.parent = ray.get(union_find_list[0].get_nodes.remote()) + + def filter_with_union_find(self, samples: pa.Table) -> pa.Table: + mask = [ + uid.as_py() not in self.parent + for uid in samples[HashKeys.uid] + ] + return samples.filter(mask) + + def run(self, dataset): + # import time + # start_time = time.time() + dataset = dataset.map_batches( + self.compute_stats, + batch_format='pyarrow', + ).materialize() + drop_columns = [] + for i in range(len(self.hash_ranges)): + drop_column = HashKeys.minhash + f'_{i}' + drop_columns.append(drop_column) + # end_time = time.time() + # print(f'minhash time = {end_time - start_time}') + + # start_time = time.time() + dataset.map_batches( + self.map_batched, + batch_format='pyarrow', + ).groupby( + HashKeys.minhash + ).map_groups( + self.agg_func, batch_format='pyarrow' + ).materialize() + # end_time = time.time() + # print(f'group time = {end_time - start_time}') + # start_time = time.time() + self.merge() + # end_time = time.time() + # print(f'merge time = {end_time - start_time}') + result = dataset.drop_columns( + drop_columns + ).map_batches( + self.filter_with_union_find, + batch_format='pyarrow' + ).materialize() + logger.info(f'Keep {result.count()} samples after MinHash dedup.') + return result diff --git a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py index 72c250af1..ee5478e3b 100644 --- a/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py @@ -144,7 +144,7 @@ def 
__init__( num_bands: Optional[PositiveInt] = None, num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, - redis_address: str = 'redis://localhost:6379', + redis_address: str = 'redis://localhost:6380', *args, **kwargs, ): @@ -303,7 +303,7 @@ def filter_with_union_find(table: pa.Table) -> pa.Table: HashKeys.minhash).aggregate( UnionFn(union_find)).materialize() result = dataset_with_id.map_batches(filter_with_union_find, - batch_format='pyarrow') + batch_format='pyarrow').materialize() logger.info(f'Keep {result.count()} samples after MinHash dedup.') union_find.clean() return result @@ -349,20 +349,32 @@ def compute_minhash(self, text): raise NotImplementedError( f'Unimplemented tokenization method [{self.tokenization}]') - # compute minhash value - hv = np.array([sha1_hash32(token) for token in tokens], - dtype=np.uint64) - phv = np.bitwise_and( - ((hv * np.tile(self.perm_a, - (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, - MAX_HASH) - hash_values = np.vstack([ - phv, - np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH - ]).min(axis=0) + # # compute minhash value + # hv = np.array([sha1_hash32(token) for token in tokens], + # dtype=np.uint64) + # phv = np.bitwise_and( + # ((hv * np.tile(self.perm_a, + # (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, + # MAX_HASH) + # hash_values = np.vstack([ + # phv, + # np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH + # ]).min(axis=0) + if len(tokens) > 0: + hv = np.array( + [sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + else: + hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) return [ bytes(hash_values[start:end].byteswap().data) + - start.to_bytes(8, byteorder='little') + start.to_bytes(4, byteorder='little') for start, end in self.hash_ranges # groupby minhash||brand_id ] From 139507204bb94a35e23424c73ede64595f098b54 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 21 Nov 2024 10:27:46 +0800 Subject: [PATCH 05/22] add bts and multi redis --- data_juicer/ops/deduplicator/__init__.py | 6 +- .../ray_bts_minhash_deduplicator.py | 489 ++++++++++++++++++ .../deduplicator/ray_minhash_deduplicator.py | 3 + .../ray_multi_redis_minhash_deduplicator.py | 473 +++++++++++++++++ 4 files changed, 970 insertions(+), 1 deletion(-) create mode 100644 data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py create mode 100644 data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 0e95ae956..f8be0101d 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -6,7 +6,9 @@ from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator from .ray_minhash_deduplicator import RayMinhashDeduplicator +from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator +from .ray_multi_redis_minhash_deduplicator import RayMultiRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -14,5 +16,7 @@ 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 
'RayImageDeduplicator', 'RayVideoDeduplicator', - 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', 'VideoDeduplicator', + 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', + 'RayMinhashDeduplicator', 'RayBTSMinhashDeduplicator', + 'RayMultiRedisMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py new file mode 100644 index 000000000..a92b83cce --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -0,0 +1,489 @@ +import random +import time +import uuid +from collections import defaultdict +from typing import Optional + +import ray +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated +from typing import Dict + +from data_juicer.utils.constant import HashKeys, Fields +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + + +def merge_edge_list_dict_list(edge_list_dict_list): + final_edge_list_dict = {} + for edge_list_dict in edge_list_dict_list: + for hash_v, edge_list in edge_list_dict.items(): + if hash_v not in final_edge_list_dict: + final_edge_list_dict[hash_v] = [] + final_edge_list_dict[hash_v].extend(edge_list) + return final_edge_list_dict + + +def BTS_hash(x, parallel_num): + return int(x[-8:], 16) % parallel_num + + +@ray.remote +class BTSUnionFind: + def __init__(self, parallel_num, parallel_id): + self.parallel_num = parallel_num + self.parallel_id = parallel_id + self.parent = {} + self.edge_buffer = [] + self.edge_list_dict = {} + + def init_union_find_list(self, union_find_list): + self.union_find_list = union_find_list + + def receive_edges(self, edge_list): + self.edge_buffer.extend(edge_list) + + def balanced_union_find(self): + parent = self.parent.copy() + for x, y in self.edge_buffer: + self.union(x, y) + self.edge_buffer = [] + self.rebalancing() + for u in parent: + if parent[u] != self.parent.get(u, None): + return True + return False + + def hash(self, u): + return BTS_hash(u, self.parallel_num) + + def distribute_edge(self, u, v): + hash_u = self.hash(u) + hash_v = self.hash(v) + # if hash_u != self.parallel_id: + if True: + if hash_u not in self.edge_list_dict: + self.edge_list_dict[hash_u] = [] + self.edge_list_dict[hash_u].append((u, v)) + # if hash_v != self.parallel_id and hash_u != hash_v: + if hash_u != hash_v: + if hash_v not in self.edge_list_dict: + self.edge_list_dict[hash_v] = [] + self.edge_list_dict[hash_v].append((v, u)) + + def simplify_edge_list(self): + if self.parallel_id in self.edge_list_dict: + self.edge_buffer.extend(self.edge_list_dict[self.parallel_id]) + del self.edge_list_dict[self.parallel_id] + + def get_edge_list_dict(self): + return self.edge_list_dict + + def edge_redistribution(self): + self.edge_list_dict = {} + for u in self.parent: + v = self.parent[u] + self.distribute_edge(u, v) + self.simplify_edge_list() + # print(f'{self.parallel_id} {self.edge_list_dict}') + self.parent = {} + + def communication(self): + self.edge_list_dict = {} + del_list = [] + for u in self.parent: + hash_u = self.hash(u) + v = self.parent[u] + if self.parent[u] != self.old_parent[u] 
or (hash_u != self.parallel_id and v not in self.parent): + self.distribute_edge(u, v) + if hash_u != self.parallel_id: + del_list.append(u) + for u in del_list: + del self.parent[u] + self.simplify_edge_list() + # return len(self.edge_list_dict) > 0 + + def find(self, x): + if x not in self.parent: + return x + else: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + px = self.find(x) + py = self.find(y) + if px == py: + return + if px > py: + px, py = py, px + self.parent[py] = px + + def union_list(self, x_list): + px_list = [self.find(x) for x in x_list] + p = min(px_list) + for px in px_list: + if p != px: + self.parent[px] = p + + def rebalancing(self): + self.old_parent = self.parent.copy() + new_px_dict = {} + for x in self.parent: + hash_x = self.hash(x) + px = self.find(x) + key = (px, hash_x) + if key not in new_px_dict: + new_px_dict[key] = x + else: + new_px_dict[key] = min(new_px_dict[key], x) + px_set = set(px for px, _ in new_px_dict) + for px in px_set: + hash_px = self.hash(px) + key = (px, hash_px) + if key not in new_px_dict: + new_px_dict[key] = px + else: + new_px_dict[key] = min(new_px_dict[key], px) + + for x in self.parent: + hash_x = self.hash(x) + px = self.find(x) + key = (px, hash_x) + if x == new_px_dict[key]: + continue + self.parent[x] = new_px_dict[key] + + def get_parent(self): + return self.parent + + def get_nodes(self): + return set(self.parent.keys()) + + +OP_NAME = 'ray_bts_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayBTSMinhashDeduplicator(Deduplicator): + """ + A basic exact matching deduplicator for RAY. + Although its functionality is deduplication, + it is implemented as Filter sub-class. + """ + + # TODO: Set a more reasonable value + EMPTY_HASH_VALUE = 'EMPTY' + _batched_op = True + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + union_find_parallel_num: Optional[int] = 16, + union_find_merge_num: Optional[int] = 2, + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provided the model path in the 'tokenizer_model' field. + :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. 
Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. + """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + + self.union_find_parallel_num = union_find_parallel_num + self.union_find_merge_num = union_find_merge_num + self.union_find_list = [ + BTSUnionFind.remote(union_find_parallel_num, i) + for i in range(self.union_find_parallel_num) + ] + + def compute_stats(self, samples: pa.Table) -> pa.Table: + samples_list = samples[self.text_key] + uuid_list = [uuid.uuid4().hex for _ in range(samples.num_rows)] + all_hash_values = [[] for _ in range(self.num_bands)] + + for text in samples_list: + text = text.as_py() + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + tokens = self.tokenizer.encode(text, 
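+                    # out_type=str returns subword piece strings rather than
+                    # ids, so the shingles below are built over text pieces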
out_type=str) + tokens = { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + if len(tokens) > 0: + hv = np.array( + [sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + else: + hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) + for i, (start, end) in enumerate(self.hash_ranges): + all_hash_values[i].append( + i.to_bytes(4, 'big') + + bytes(hash_values[start:end].byteswap().data) + ) + + samples = samples.append_column(HashKeys.uid, pa.array(uuid_list)) + for i, hash_values in enumerate(all_hash_values): + samples = samples.append_column(HashKeys.minhash + f"_{i}", pa.array(hash_values)) + return samples + + def map_batched(self, samples: pa.Table) -> pa.Table: + table = pa.Table.from_arrays( + [ + pa.concat_arrays( + [samples[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) + ), + # pa.array( + # [uid.as_py() for uid in samples[HashKeys.uid]] * len(self.hash_ranges) + # ), + pa.concat_arrays( + [ + samples[HashKeys.minhash + f'_{i}'].combine_chunks() + for i in range(len(self.hash_ranges)) + ] + ), + ], + names=[HashKeys.uid, HashKeys.minhash] + ) + return table + + def agg_func(self, group: pa.Table) -> pa.Table: + if group.num_rows != 1: + uuid_list = [uid.as_py() for uid in group[HashKeys.uid]] + union_find_id = np.random.randint(0, self.union_find_parallel_num) + union_find = self.union_find_list[union_find_id] + ray.get(union_find.union_list.remote(uuid_list)) + return group + + def merge(self): + ray.get([ + union_find.rebalancing.remote() + for union_find in self.union_find_list + ]) + ray.get([ + union_find.edge_redistribution.remote() + for union_find in self.union_find_list + ]) + edge_list_dict_list = ray.get([ + union_find.get_edge_list_dict.remote() + for union_find in self.union_find_list + ]) + edge_list_dict = merge_edge_list_dict_list(edge_list_dict_list) + ray.get([ + self.union_find_list[i].receive_edges.remote(edge_list) + for i, edge_list in edge_list_dict.items() + ]) + ray.get([ + union_find.balanced_union_find.remote() + for union_find in self.union_find_list + ]) + while True: + ray.get([ + union_find.communication.remote() + for union_find in self.union_find_list + ]) + edge_list_dict_list = ray.get([ + union_find.get_edge_list_dict.remote() + for union_find in self.union_find_list + ]) + edge_list_dict = merge_edge_list_dict_list(edge_list_dict_list) + ray.get([ + self.union_find_list[i].receive_edges.remote(edge_list) + for i, edge_list in edge_list_dict.items() + ]) + update_list = ray.get([ + union_find.balanced_union_find.remote() + for union_find in self.union_find_list + ]) + + break_flag = True + for update in update_list: + if update: + break_flag = False + break + if break_flag: + break + self.parents = ray.get([ + union_find.get_nodes.remote() + for union_find in self.union_find_list + ]) + + def is_dup(self, uid): + part = BTS_hash(uid, self.union_find_parallel_num) + return uid in self.parents[part] + + def filter_with_union_find(self, samples: pa.Table) -> pa.Table: + mask = [ + not self.is_dup(uid.as_py()) + for uid in samples[HashKeys.uid] + ] + return samples.filter(mask) + + def run(self, dataset): + import time + start_time = time.time() + dataset = dataset.map_batches( + self.compute_stats, + batch_format='pyarrow', 
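+            # .materialize() on the next line forces this MinHash/banding stage
+            # to execute eagerly, so the timing printed below measures it alone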
+ ).materialize() + drop_columns = [] + for i in range(len(self.hash_ranges)): + drop_column = HashKeys.minhash + f'_{i}' + drop_columns.append(drop_column) + end_time = time.time() + print(f'minhash time = {end_time - start_time}') + + start_time = time.time() + dataset.map_batches( + self.map_batched, + batch_format='pyarrow', + ).groupby( + HashKeys.minhash + ).map_groups( + self.agg_func, batch_format='pyarrow' + ).materialize() + end_time = time.time() + print(f'group time = {end_time - start_time}') + start_time = time.time() + self.merge() + end_time = time.time() + print(f'merge time = {end_time - start_time}') + result = dataset.drop_columns( + drop_columns + ).map_batches( + self.filter_with_union_find, + batch_format='pyarrow' + ).materialize() + logger.info(f'Keep {result.count()} samples after MinHash dedup.') + return result diff --git a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py index 16755b9f6..e204422ec 100644 --- a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py @@ -286,6 +286,9 @@ def map_batched(self, samples: pa.Table) -> pa.Table: pa.concat_arrays( [samples[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) ), + # pa.array( + # [uid.as_py() for uid in samples[HashKeys.uid]] * len(self.hash_ranges) + # ), pa.concat_arrays( [ samples[HashKeys.minhash + f'_{i}'].combine_chunks() diff --git a/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py new file mode 100644 index 000000000..56c5dfcd5 --- /dev/null +++ b/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py @@ -0,0 +1,473 @@ +import random +import time +import uuid +from collections import defaultdict +from typing import Optional +import ray + +import numpy as np +import pandas as pd +import pyarrow as pa +import regex +from loguru import logger +from pydantic import Field, PositiveInt +from typing_extensions import Annotated +import concurrent + +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import prepare_sentencepiece_model + +from ..base_op import OPERATORS, Deduplicator +from ..common.helper_func import split_on_whitespace +from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, + optimal_param, sha1_hash32) + +redis = LazyLoader('redis', 'redis') + + +def retry_on_busy(func): + + def wrapper(*args, **kwargs): + max_retries = 10 + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + if 'BUSY' in str(e) and attempt < max_retries - 1: + time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) + else: + raise + + return wrapper + + +class RedisUnionFind: + + def __init__(self, + prefix: str, + redis_address: str = 'redis://localhost:6380'): + self.prefix = prefix + self.redis_address = redis_address + self.redis = redis.from_url(url=redis_address) + self.set_key = f'{prefix}_UF_SET' + self.incur_id_key = f'{prefix}_UF_INCURID' + + # Lua scripts + self.union_script = self.redis.register_script(""" + local function find(x) + local path = {} + while true do + local parent = redis.call('HGET', KEYS[1], x) + if not parent then + break + end + table.insert(path, x) + x = parent + end + for _, node in ipairs(path) do + redis.call('HSET', KEYS[1], node, x) + end + return x + end + + local root_x = find(ARGV[1]) + local root_y = 
find(ARGV[2]) + if root_x == root_y then + return root_x + end + if root_x < root_y then + redis.call('HSET', KEYS[1], root_y, root_x) + return root_x + else + redis.call('HSET', KEYS[1], root_x, root_y) + return root_y + end + """) + + self.merge_script = self.redis.register_script(""" + local function find(key, x) + local path = {} + while true do + local parent = redis.call('HGET', key, x) + if not parent then + break + end + table.insert(path, x) + x = parent + end + for _, node in ipairs(path) do + redis.call('HSET', key, node, x) + end + return x + end + + local function merge(key) + local nodes = redis.call('HKEYS', key) + for _, node in ipairs(nodes) do + local root = find(key, node) + local root_x = find(KEYS[1], node) + local root_y = find(KEYS[1], root) + if root_x < root_y then + redis.call('HSET', KEYS[1], root_y, root_x) + elseif root_x > root_y then + redis.call('HSET', KEYS[1], root_x, root_y) + end + end + end + + for _, key in ipairs(ARGV) do + merge(key) + end + """) + + def get_uid(self): + return int(self.redis.incr(self.incur_id_key)) + + @retry_on_busy + def union(self, x, y) -> int: + return int(self.union_script(keys=[self.set_key], + args=[x, y])) + + @retry_on_busy + def merge(self, set_keys): + # self.merge_script(keys=[self.set_key] + set_keys, args=set_keys) + for set_key in set_keys: + for x, y in self.redis.hgetall(set_key).items(): + self.union(x, y) + # for x in self.redis.hkeys(set_key): + # y = self.redis.hget(set_key, x) + # self.redis. + + def get_nodes(self): + return set(int(x) for x in self.redis.hkeys(self.set_key)) + + def get_data(self): + result = {} + for x in self.get_nodes(): + y = int(self.redis.hget(self.set_key, x)) + result[x] = y + return result + + def is_ancestor(self, x): + ancestor = self.redis.hget(self.set_key, x) + return ancestor is None or int(ancestor) == x + + def __reduce__(self): + return (RedisUnionFind, (self.prefix, self.redis_address)) + + def clean(self): + self.redis.delete(self.set_key, self.incur_id_key) + + +OP_NAME = 'ray_multi_redis_minhash_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +class RayMultiRedisMinhashDeduplicator(Deduplicator): + """ + A basic exact matching deduplicator for RAY. + Although its functionality is deduplication, + it is implemented as Filter sub-class. + """ + + def __init__( + self, + tokenization: str = 'space', + window_size: PositiveInt = 5, + lowercase: bool = True, + ignore_pattern: Optional[str] = None, + num_permutations: PositiveInt = 256, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, + redis_address: str = 'redis://localhost:6380', + union_find_parallel_num: Optional[int] = 16, + union_find_merge_num: Optional[int] = 2, + *args, + **kwargs, + ): + """ + Initialization method. + + :param tokenization: tokenization method for sample texts. It + should be one of [space, punctuation, character, + sentencepiece]. For English-like languages, we recommend + to use 'space', for Chinese-like languages, we recommend + to use 'character', and for multiple languages, we recommend + to use 'sentencepiece'. If using 'sentencepiece', please + provided the model path in the 'tokenizer_model' field. 
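
For readers not fluent in Redis Lua scripting: the union script registered above is a standard union-find with path compression over a single Redis hash, always keeping the smaller id as the root. A minimal in-memory sketch of the same logic (the LocalUnionFind name is hypothetical and not part of this patch):

class LocalUnionFind:
    """In-memory mirror of the Lua union script above (sketch only)."""

    def __init__(self):
        # parent maps node -> parent; a node missing from the dict is its own root
        self.parent = {}

    def find(self, x):
        path = []
        while x in self.parent:      # walk up to the root
            path.append(x)
            x = self.parent[x]
        for node in path:            # path compression, as in the Lua find()
            self.parent[node] = x
        return x

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return root_x
        if root_x < root_y:          # the smaller id always wins as root
            self.parent[root_y] = root_x
            return root_x
        self.parent[root_x] = root_y
        return root_y

uf = LocalUnionFind()
assert uf.union(3, 5) == 3 and uf.union(5, 1) == 1 and uf.find(3) == 1

The Redis version keeps the same node-to-parent mapping in the hash at set_key, so every Ray worker that connects to the same Redis instance sees one shared structure.
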
+ :param window_size: window size of shingling + :param lowercase: whether to convert text to lower case first + :param ignore_pattern: whether to ignore sub-strings with + specific pattern when computing minhash + :param num_permutations: number of permutations in minhash + computing + :param jaccard_threshold: the min jaccard similarity threshold + in near-duplicate detection. When the jaccard similarity of + two sample texts is >= this threshold, they are regarded as + similar samples and this op will only keep one of them after + deduplication + :param num_bands: number of bands in LSH. Default it's None, and + it will be determined by an optimal params computation + algorithm by minimize the weighted sum of probs of False + Positives and False Negatives + :param num_rows_per_band: number of rows in each band in LSH. + Default it's None, and it will be determined by an optimal + params computation algorithm + :param tokenizer_model: path for the sentencepiece model, used for + sentencepiece tokenization. + :param redis_address: address of your redis instance, e.g. + 'redis://localhost:6379' + """ + super().__init__(*args, **kwargs) + # about minhash computation + self.tokenization = tokenization + self.window_size = window_size + self.lowercase = lowercase + self.ignore_pattern = ignore_pattern + if self.ignore_pattern: + self.ignore_pattern = regex.compile(self.ignore_pattern) + + # check parameters + if self.ignore_pattern and self.tokenization == 'punctuation': + logger.warning('Be careful that tokenization with punctuations ' + 'won\'t work if the ignore pattern includes ' + 'punctuations.') + self.punctuation_pattern = regex.compile(r'\p{P}') + + if self.tokenization == 'sentencepiece': + if tokenizer_model is None: + raise ValueError("To use 'sentencepiece' tokenization, " + "'tokenizer_model' is required.") + self.tokenizer = prepare_sentencepiece_model(tokenizer_model) + else: + self.tokenizer = None + + # about deduplication + self.num_permutation = num_permutations + self.jaccard_threshold = jaccard_threshold + self.num_bands = num_bands + self.num_rows_per_band = num_rows_per_band + + # initialize deduplication parameters + # check number of bands and rows + if self.num_bands is None or self.num_rows_per_band is None: + self.num_bands, self.num_rows_per_band = optimal_param( + self.jaccard_threshold, + self.num_permutation, + ) + + # compute hash ranges and create hash tables + self.hash_ranges = [(i * self.num_rows_per_band, + (i + 1) * self.num_rows_per_band) + for i in range(self.num_bands)] + self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] + + # generate permutations + gen = np.random.RandomState(seed=42) + self.perm_a, self.perm_b = np.array( + [( + gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), + gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), + ) for _ in range(self.num_permutation)], + dtype=np.uint64, + ).T + self.redis_address = redis_address + self.union_find_parallel_num = union_find_parallel_num + self.union_find_merge_num = union_find_merge_num + + def run(self, dataset): + from ray.data.aggregate import AggregateFn + + # union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], + # redis_address=self.redis_address) + union_find_list = [ + RedisUnionFind(prefix=uuid.uuid4().hex[:8] + f'_{i}', redis_address=self.redis_address) + for i in range(self.union_find_parallel_num) + ] + + def add_uid_column(table: pa.Table) -> pa.Table: + new_column_data = [union_find_list[0].get_uid() for _ in range(len(table))] + new_table = 
table.append_column(HashKeys.uid, [new_column_data]) + return new_table + + def calculate_minhash(table: pa.Table) -> pa.Table: + ids = table.column(HashKeys.uid).to_pandas() + texts = table.column(self.text_key).to_pandas() + hashes = texts.apply(lambda x: self.compute_minhash(x)) + hashes = pa.Array.from_pandas(hashes).flatten() + + repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands)) + + return pa.Table.from_arrays([repeated_ids, hashes], + names=[HashKeys.uid, HashKeys.minhash]) + + class UnionFn(AggregateFn): + + def __init__(self, union_find_list): + # union_find = union_find + union_find_num = len(union_find_list) + + def accumulate(cur, row): + if cur is None: + return int.from_bytes(row[HashKeys.minhash][:8], byteorder='big') % union_find_num, row[HashKeys.uid] + else: + assert cur[0] == int.from_bytes(row[HashKeys.minhash][:8], byteorder='big') % union_find_num + union_find = union_find_list[cur[0]] + root = union_find.union(row[HashKeys.uid], cur[1]) + return cur[0], root + + def merge(a, b): + if a is None: + return b + if b is None: + return a + assert a[0] == b[0] + union_find = union_find_list[a[0]] + root = union_find.union(a[1], b[1]) + # root = union_find.union(a, b) + return a[0], root + + super().__init__( + init=lambda k: None, + accumulate_row=accumulate, + merge=merge, + name='union', + ) + + dataset_with_id = dataset.map_batches( + add_uid_column, batch_format='pyarrow').materialize() + dataset_with_id.map_batches( + calculate_minhash, + batch_format='pyarrow' + ).groupby( + HashKeys.minhash + ).aggregate( + UnionFn(union_find_list) + ).materialize() + + # results = [] + # for union_find in union_find_list: + # results.append(union_find.get_data()) + @ray.remote + def merge(x, keys): + x.merge(keys) + + merge_list = union_find_list + while len(merge_list) > 1: + new_merge_list, buffer = [], [] + task_list = [] + for union_find in merge_list: + buffer.append(union_find) + if len(buffer) == self.union_find_merge_num: + new_merge_list.append(buffer[0]) + keys = [u.set_key for u in buffer[1:]] + task_list.append( + merge.remote(buffer[0], keys) + ) + buffer = [] + if len(buffer) > 0: + new_merge_list.append(buffer[0]) + if len(buffer) > 1: + keys = [u.set_key for u in buffer[1:]] + task_list.append( + merge.remote(buffer[0], keys) + ) + ray.get(task_list) + merge_list = new_merge_list + # for m in merge_list: + # results.append(m.get_data()) + + # results.append(merge_list[0].get_data()) + # import json + # with open(f'data_{len(results)}.json', 'w') as f: + # json.dump(results, f) + dup_ids = merge_list[0].get_nodes() + + def filter_with_union_find(table: pa.Table) -> pa.Table: + uids = table.column(HashKeys.uid).to_pandas() + mask = pa.Array.from_pandas( + uids.apply(lambda x: x not in dup_ids)) + return table.filter(mask) + + result = dataset_with_id.map_batches( + filter_with_union_find, + batch_format='pyarrow' + ).materialize() + logger.info(f'Keep {result.count()} samples after MinHash dedup.') + for union_find in union_find_list: + union_find.clean() + return result + + def compute_minhash(self, text): + """ + Compute minhash values for the sample. + + :param sample: input sample + :return: sample with minhash value. 
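
The merge stage above reduces the per-shard RedisUnionFind stores to a single one in rounds: each round groups union_find_merge_num shards, merges the tail of each group into its head via a remote task, and keeps only the heads for the next round. The control flow, with plain sets standing in for the actors (a sketch, not the patch's code):

def reduce_in_rounds(shards, merge_num):
    # merge groups of `merge_num` shards per round until a single shard remains
    while len(shards) > 1:
        next_round = []
        for i in range(0, len(shards), merge_num):
            head, *rest = shards[i:i + merge_num]
            for other in rest:
                head.update(other)   # stands in for head.merge(set keys of the others)
            next_round.append(head)
        shards = next_round
    return shards[0]

shards = [{1, 2}, {3}, {4, 5}, {6}, {7}]
assert reduce_in_rounds(shards, merge_num=2) == {1, 2, 3, 4, 5, 6, 7}

In the patch the merges within one round run concurrently as merge.remote tasks; the sketch only shows the grouping.
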
+ """ + if self.lowercase: + text = text.lower() + if self.ignore_pattern: + text = self.ignore_pattern.sub('', text) + + # get tokens for different tokenization method + tokens = set() + if self.tokenization == 'character': + tokens = { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + tokens = self.punctuation_pattern.split(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + tokens = split_on_whitespace(text) + tokens = { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + tokens = self.tokenizer.encode(text, out_type=str) + tokens = { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + + # # compute minhash value + # hv = np.array([sha1_hash32(token) for token in tokens], + # dtype=np.uint64) + # phv = np.bitwise_and( + # ((hv * np.tile(self.perm_a, + # (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, + # MAX_HASH) + # hash_values = np.vstack([ + # phv, + # np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH + # ]).min(axis=0) + if len(tokens) > 0: + hv = np.array( + [sha1_hash32(token) for token in tokens], + dtype=np.uint64 + ) + phv = ( + (hv[:, None] * self.perm_a[None, :] + + self.perm_b) % MERSENNE_PRIME + ).astype(np.uint32) + hash_values = phv.min(axis=0) + else: + hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) + return [ + bytes(hash_values[start:end].byteswap().data) + + start.to_bytes(4, byteorder='little') + for start, end in self.hash_ranges + # groupby minhash||brand_id + ] From 112662fd11e01c8c29b5b686ae526b059908b7e6 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 26 Nov 2024 10:41:12 +0800 Subject: [PATCH 06/22] optimize v1 in bts minhash --- data_juicer/ops/deduplicator/__init__.py | 2 + .../ray_bts_minhash_deduplicator.py | 230 +++++++++--------- .../deduplicator/ray_minhash_deduplicator.py | 4 +- 3 files changed, 119 insertions(+), 117 deletions(-) diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index f8be0101d..9830f4e52 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -7,6 +7,7 @@ from .ray_image_deduplicator import RayImageDeduplicator from .ray_minhash_deduplicator import RayMinhashDeduplicator from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator +from .ray_bts_v2_minhash_deduplicator import RayBTSV2MinhashDeduplicator from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_multi_redis_minhash_deduplicator import RayMultiRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator @@ -18,5 +19,6 @@ 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', 'RayMinhashDeduplicator', 'RayBTSMinhashDeduplicator', + 'RayBTSV2MinhashDeduplicator', 'RayMultiRedisMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index a92b83cce..aa78a5978 100644 --- 
a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -35,32 +35,59 @@ def merge_edge_list_dict_list(edge_list_dict_list): def BTS_hash(x, parallel_num): - return int(x[-8:], 16) % parallel_num + # return int(x[-8:], 16) % parallel_num + return int.from_bytes(x, byteorder='big') % parallel_num + + +@ray.remote +class EdgeBuffer: + def __init__(self): + self.edge_dict = {} + + def clear(self): + self.edge_dict = {} + + def set_edges(self, edge_dict): + self.edge_dict = edge_dict + + def get_edges(self, key): + return self.edge_dict.get(key, []) @ray.remote class BTSUnionFind: - def __init__(self, parallel_num, parallel_id): + def __init__(self, parallel_num, parallel_id, remote_edge_buffers): self.parallel_num = parallel_num self.parallel_id = parallel_id self.parent = {} + self.old_parent = {} + self.remote_edge_buffers = remote_edge_buffers self.edge_buffer = [] self.edge_list_dict = {} def init_union_find_list(self, union_find_list): self.union_find_list = union_find_list - def receive_edges(self, edge_list): - self.edge_buffer.extend(edge_list) + def receive_edges(self): + edge_list = ray.get([ + remote_edge_buffer.get_edges.remote(self.parallel_id) + for remote_edge_buffer in self.remote_edge_buffers + ]) + for edges in edge_list: + self.edge_buffer.extend(edges) def balanced_union_find(self): - parent = self.parent.copy() + self.receive_edges() for x, y in self.edge_buffer: self.union(x, y) self.edge_buffer = [] self.rebalancing() - for u in parent: - if parent[u] != self.parent.get(u, None): + old_parent_keys = set(self.old_parent.keys()) + parent_keys = set(self.parent.keys()) + if old_parent_keys ^ parent_keys: + return True + for u in parent_keys: + if self.old_parent.get(u, u) != self.parent.get(u, u): return True return False @@ -70,33 +97,30 @@ def hash(self, u): def distribute_edge(self, u, v): hash_u = self.hash(u) hash_v = self.hash(v) - # if hash_u != self.parallel_id: - if True: - if hash_u not in self.edge_list_dict: - self.edge_list_dict[hash_u] = [] - self.edge_list_dict[hash_u].append((u, v)) - # if hash_v != self.parallel_id and hash_u != hash_v: + if hash_u not in self.edge_list_dict: + self.edge_list_dict[hash_u] = [] + self.edge_list_dict[hash_u].append((u, v)) if hash_u != hash_v: if hash_v not in self.edge_list_dict: self.edge_list_dict[hash_v] = [] - self.edge_list_dict[hash_v].append((v, u)) + self.edge_list_dict[hash_v].append((u, v)) - def simplify_edge_list(self): + def set_edge_buffer(self): if self.parallel_id in self.edge_list_dict: - self.edge_buffer.extend(self.edge_list_dict[self.parallel_id]) + self.edge_buffer = self.edge_list_dict[self.parallel_id] del self.edge_list_dict[self.parallel_id] - - def get_edge_list_dict(self): - return self.edge_list_dict + else: + self.edge_buffer = [] + ray.get(self.remote_edge_buffers[self.parallel_id].set_edges.remote(self.edge_list_dict)) def edge_redistribution(self): + self.rebalancing() self.edge_list_dict = {} for u in self.parent: v = self.parent[u] self.distribute_edge(u, v) - self.simplify_edge_list() - # print(f'{self.parallel_id} {self.edge_list_dict}') self.parent = {} + self.set_edge_buffer() def communication(self): self.edge_list_dict = {} @@ -104,14 +128,14 @@ def communication(self): for u in self.parent: hash_u = self.hash(u) v = self.parent[u] - if self.parent[u] != self.old_parent[u] or (hash_u != self.parallel_id and v not in self.parent): + if self.parent[u] != self.old_parent.get(u, u) or (hash_u != self.parallel_id 
and v not in self.parent): self.distribute_edge(u, v) if hash_u != self.parallel_id: del_list.append(u) + self.old_parent = self.parent.copy() for u in del_list: del self.parent[u] - self.simplify_edge_list() - # return len(self.edge_list_dict) > 0 + self.set_edge_buffer() def find(self, x): if x not in self.parent: @@ -137,7 +161,6 @@ def union_list(self, x_list): self.parent[px] = p def rebalancing(self): - self.old_parent = self.parent.copy() new_px_dict = {} for x in self.parent: hash_x = self.hash(x) @@ -170,6 +193,12 @@ def get_parent(self): def get_nodes(self): return set(self.parent.keys()) + def is_dup(self, queries): + return [ + query in self.parent + for query in queries + ] + OP_NAME = 'ray_bts_minhash_deduplicator' @@ -198,7 +227,6 @@ def __init__( num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, union_find_parallel_num: Optional[int] = 16, - union_find_merge_num: Optional[int] = 2, *args, **kwargs, ): @@ -288,18 +316,19 @@ def __init__( ).T self.union_find_parallel_num = union_find_parallel_num - self.union_find_merge_num = union_find_merge_num + self.remote_edge_buffers = [ + EdgeBuffer.remote() + for i in range(self.union_find_parallel_num) + ] self.union_find_list = [ - BTSUnionFind.remote(union_find_parallel_num, i) + BTSUnionFind.remote(union_find_parallel_num, i, self.remote_edge_buffers) for i in range(self.union_find_parallel_num) ] - def compute_stats(self, samples: pa.Table) -> pa.Table: - samples_list = samples[self.text_key] - uuid_list = [uuid.uuid4().hex for _ in range(samples.num_rows)] + def calc_minhash(self, text_list: pa.Array) -> pa.Table: all_hash_values = [[] for _ in range(self.num_bands)] - for text in samples_list: + for text in text_list: text = text.as_py() if self.lowercase: text = text.lower() @@ -352,121 +381,90 @@ def compute_stats(self, samples: pa.Table) -> pa.Table: i.to_bytes(4, 'big') + bytes(hash_values[start:end].byteswap().data) ) - - samples = samples.append_column(HashKeys.uid, pa.array(uuid_list)) - for i, hash_values in enumerate(all_hash_values): - samples = samples.append_column(HashKeys.minhash + f"_{i}", pa.array(hash_values)) - return samples - - def map_batched(self, samples: pa.Table) -> pa.Table: - table = pa.Table.from_arrays( - [ - pa.concat_arrays( - [samples[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) - ), - # pa.array( - # [uid.as_py() for uid in samples[HashKeys.uid]] * len(self.hash_ranges) - # ), - pa.concat_arrays( - [ - samples[HashKeys.minhash + f'_{i}'].combine_chunks() - for i in range(len(self.hash_ranges)) - ] - ), - ], - names=[HashKeys.uid, HashKeys.minhash] - ) - return table + return all_hash_values def agg_func(self, group: pa.Table) -> pa.Table: if group.num_rows != 1: uuid_list = [uid.as_py() for uid in group[HashKeys.uid]] - union_find_id = np.random.randint(0, self.union_find_parallel_num) + # union_find_id = np.random.randint(0, self.union_find_parallel_num) + min_uuid = min(uuid_list) + union_find_id = BTS_hash(min_uuid, self.union_find_parallel_num) union_find = self.union_find_list[union_find_id] ray.get(union_find.union_list.remote(uuid_list)) return group def merge(self): - ray.get([ - union_find.rebalancing.remote() - for union_find in self.union_find_list - ]) ray.get([ union_find.edge_redistribution.remote() for union_find in self.union_find_list ]) - edge_list_dict_list = ray.get([ - union_find.get_edge_list_dict.remote() - for union_find in self.union_find_list - ]) - edge_list_dict = merge_edge_list_dict_list(edge_list_dict_list) - 
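
The BTSUnionFind above shards nodes across workers by hashing their ids; distribute_edge sends every edge to the worker that owns u and, when the owners differ, also to the worker that owns v, and the workers then repeat union, rebalancing and communication rounds until no parent entry changes. The routing rule in isolation (hypothetical helper names, plain dicts instead of actors):

from collections import defaultdict

def owner(node, parallel_num):
    # stand-in for BTS_hash: which worker owns this node id
    return node % parallel_num

def route_edges(edges, parallel_num):
    # mirrors distribute_edge above: an edge goes to the worker owning u and,
    # when that differs, also to the worker owning v
    per_worker = defaultdict(list)
    for u, v in edges:
        hash_u, hash_v = owner(u, parallel_num), owner(v, parallel_num)
        per_worker[hash_u].append((u, v))
        if hash_v != hash_u:
            per_worker[hash_v].append((u, v))
    return per_worker

edges = [(0, 5), (5, 9), (2, 4)]
assert dict(route_edges(edges, parallel_num=4)) == {
    0: [(0, 5), (2, 4)], 1: [(0, 5), (5, 9)], 2: [(2, 4)]}

After routing, each worker unions only its local edges and forwards newly discovered cross-worker parents in the next communication round.
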
ray.get([ - self.union_find_list[i].receive_edges.remote(edge_list) - for i, edge_list in edge_list_dict.items() - ]) - ray.get([ - union_find.balanced_union_find.remote() - for union_find in self.union_find_list - ]) - while True: + while any( ray.get([ - union_find.communication.remote() - for union_find in self.union_find_list - ]) - edge_list_dict_list = ray.get([ - union_find.get_edge_list_dict.remote() + union_find.balanced_union_find.remote() for union_find in self.union_find_list ]) - edge_list_dict = merge_edge_list_dict_list(edge_list_dict_list) + ): ray.get([ - self.union_find_list[i].receive_edges.remote(edge_list) - for i, edge_list in edge_list_dict.items() - ]) - update_list = ray.get([ - union_find.balanced_union_find.remote() + union_find.communication.remote() for union_find in self.union_find_list ]) - break_flag = True - for update in update_list: - if update: - break_flag = False - break - if break_flag: - break - self.parents = ray.get([ - union_find.get_nodes.remote() - for union_find in self.union_find_list - ]) - - def is_dup(self, uid): - part = BTS_hash(uid, self.union_find_parallel_num) - return uid in self.parents[part] - def filter_with_union_find(self, samples: pa.Table) -> pa.Table: + hash_id_list = [] + query_dict = {} + for uid in samples[HashKeys.uid]: + uid = uid.as_py() + hash_id = BTS_hash(uid, self.union_find_parallel_num) + hash_id_list.append(hash_id) + if hash_id not in query_dict: + query_dict[hash_id] = [] + query_dict[hash_id].append(uid) + results = ray.get([self.union_find_list[hash_id].is_dup.remote(query) for hash_id, query in query_dict.items()]) + result_dict = { + hash_id: result + for hash_id, result in zip(query_dict.keys(), results) + } mask = [ - not self.is_dup(uid.as_py()) - for uid in samples[HashKeys.uid] + not result_dict[hash_id].pop(0) + for hash_id in hash_id_list ] return samples.filter(mask) def run(self, dataset): - import time start_time = time.time() + def add_uid_column(table: pa.Table) -> pa.Table: + uuid_list = [uuid.uuid4().bytes for _ in range(table.num_rows)] + new_table = table.append_column(HashKeys.uid, pa.array(uuid_list)) + return new_table + dataset = dataset.map_batches( - self.compute_stats, + add_uid_column, batch_format='pyarrow', ).materialize() - drop_columns = [] - for i in range(len(self.hash_ranges)): - drop_column = HashKeys.minhash + f'_{i}' - drop_columns.append(drop_column) end_time = time.time() - print(f'minhash time = {end_time - start_time}') + print(f'uid time = {end_time - start_time}') + + def minhash_with_uid(table: pa.Table) -> pa.Table: + minhash_values = self.calc_minhash(table[self.text_key]) + new_table = pa.Table.from_arrays( + [ + pa.concat_arrays( + [table[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) + ), + pa.concat_arrays( + [ + pa.array(minhash_values[i]) + for i in range(len(self.hash_ranges)) + ] + ), + ], + names=[HashKeys.uid, HashKeys.minhash] + ) + return new_table start_time = time.time() dataset.map_batches( - self.map_batched, + minhash_with_uid, batch_format='pyarrow', ).groupby( HashKeys.minhash @@ -479,11 +477,11 @@ def run(self, dataset): self.merge() end_time = time.time() print(f'merge time = {end_time - start_time}') - result = dataset.drop_columns( - drop_columns - ).map_batches( + result = dataset.map_batches( self.filter_with_union_find, batch_format='pyarrow' + ).drop_columns( + HashKeys.uid ).materialize() logger.info(f'Keep {result.count()} samples after MinHash dedup.') return result diff --git 
a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py index e204422ec..7534a8f43 100644 --- a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py @@ -218,7 +218,7 @@ def init_union_find(self, union_find_parallel_num, union_find_merge_num): def compute_stats(self, samples: pa.Table) -> pa.Table: samples_list = samples[self.text_key] - uuid_list = [uuid.uuid4().hex for _ in range(samples.num_rows)] + uuid_list = [uuid.uuid4().bytes for _ in range(samples.num_rows)] all_hash_values = [[] for _ in range(self.num_bands)] for text in samples_list: @@ -369,6 +369,8 @@ def run(self, dataset): ).map_batches( self.filter_with_union_find, batch_format='pyarrow' + ).drop_columns( + HashKeys.uid ).materialize() logger.info(f'Keep {result.count()} samples after MinHash dedup.') return result From 2b9bf6dab9b9365b823434710ea3e2784fa8f38c Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 26 Nov 2024 15:19:16 +0800 Subject: [PATCH 07/22] fix in ray data --- data_juicer/core/ray_data.py | 74 ++++++++++++++---------- data_juicer/ops/deduplicator/__init__.py | 2 - 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 30848bcf3..621e68cd9 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,4 +1,5 @@ import os +from functools import partial import pyarrow as pa from loguru import logger @@ -13,28 +14,26 @@ rd = LazyLoader('rd', 'ray.data') -def is_valid_path(item, dataset_dir): - full_path = os.path.abspath(os.path.join(dataset_dir, item)) - return os.path.exists(full_path) +def get_abs_path(path, dataset_dir): + full_path = os.path.abspath(os.path.join(dataset_dir, path)) + if os.path.exists(full_path): + return full_path + else: + return path -def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): +def convert_to_absolute_paths(samples, dataset_dir, path_keys): + samples = samples.to_pydict() for key in path_keys: - if key not in dict_with_paths: - continue - if isinstance(dict_with_paths[key], list): - dict_with_paths[key] = [ - os.path.abspath(os.path.join(dataset_dir, item)) - if isinstance(item, str) and is_valid_path(dataset_dir, item) - else item for item in dict_with_paths[key] - ] - elif isinstance(dict_with_paths[key], str): - dict_with_paths[key] = os.path.abspath( - os.path.join(dataset_dir, - dict_with_paths[key])) if is_valid_path( - dict_with_paths[key], - dataset_dir) else dict_with_paths[key] - return dict_with_paths + for idx in range(len(samples[key])): + paths = samples[key][idx] + if isinstance(paths, str): + samples[key][idx] = get_abs_path(paths, dataset_dir) + elif isinstance(paths, list): + samples[key][idx] = [ + get_abs_path(item, dataset_dir) for item in paths + ] + return pa.Table.from_pydict(samples) # TODO: check path for nestdataset @@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg): Set all the path in input data to absolute path. Checks dataset_dir and project_dir for valid paths. 
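
The rewritten path handling above moves from a per-row map() to a single map_batches() pass in pyarrow: the batch is converted to a Python dict, every configured media column is fixed up (lists of paths included), and the table is rebuilt. A trimmed sketch of that batch-level fix-up, assuming the columns hold strings or lists of strings:

import os
import pyarrow as pa

def to_abs(path, dataset_dir):
    full = os.path.abspath(os.path.join(dataset_dir, path))
    return full if os.path.exists(full) else path

def fix_paths(batch: pa.Table, dataset_dir: str, path_keys) -> pa.Table:
    data = batch.to_pydict()
    for key in path_keys:
        data[key] = [
            [to_abs(p, dataset_dir) for p in item] if isinstance(item, list)
            else to_abs(item, dataset_dir)
            for item in data[key]
        ]
    return pa.Table.from_pydict(data)

batch = pa.Table.from_pydict({'text': ['a'], 'images': [['img/0.png']]})
fixed = fix_paths(batch, dataset_dir='/data/demo', path_keys=['images'])

In the patch, dataset_dir and path_keys are bound with functools.partial before the function is handed to map_batches.
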
""" - if not (cfg.video_key in dataset.columns() or cfg.image_key - in dataset.columns() or cfg.audio_key in dataset.columns()): - return dataset - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map(lambda item: convert_to_absolute_paths( - item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) - logger.info(f"transfer {dataset.count()} sample's paths") + path_keys = [] + columns = dataset.columns() + for key in [cfg.video_key, cfg.image_key, cfg.audio_key]: + if key in columns: + path_keys.append(key) + if len(path_keys) > 0: + dataset_dir = os.path.dirname(dataset_path) + dataset = dataset.map_batches(partial(convert_to_absolute_paths, + dataset_dir=dataset_dir, + path_keys=path_keys), + batch_format='pyarrow', + zero_copy_batch=True) return dataset def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: + columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - columns = dataset.columns() if Fields.stats not in columns: - logger.info(f'columns {columns}') def process_batch_arrow(table: pa.Table) -> pa.Table: new_column_data = [{} for _ in range(len(table))] @@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc): return 1.0 / proc_per_gpu +def filter_batch(batch, filter_func): + mask = pa.array(filter_func(batch.to_pydict())) + return batch.filter(mask) + + class RayDataset(DJDataset): def __init__(self, @@ -122,7 +130,15 @@ def _run_single_op(self, op): if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) - self.data = self.data.filter(op.process) + if op.is_batched_op(): + self.data = self.data.map_batches(partial( + filter_batch, filter_func=op.process), + batch_format='pyarrow', + batch_size=batch_size, + num_gpus=num_gpus, + zero_copy_batch=True) + else: + self.data = self.data.filter(op.process) elif isinstance(op, Deduplicator): self.data = op.run(self.data) else: diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 9830f4e52..f8be0101d 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -7,7 +7,6 @@ from .ray_image_deduplicator import RayImageDeduplicator from .ray_minhash_deduplicator import RayMinhashDeduplicator from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator -from .ray_bts_v2_minhash_deduplicator import RayBTSV2MinhashDeduplicator from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator from .ray_multi_redis_minhash_deduplicator import RayMultiRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator @@ -19,6 +18,5 @@ 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', 'RayMinhashDeduplicator', 'RayBTSMinhashDeduplicator', - 'RayBTSV2MinhashDeduplicator', 'RayMultiRedisMinhashDeduplicator', 'VideoDeduplicator', ] From 420bf15c99306e3672fd011ca8da1c469debc966 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 26 Nov 2024 20:29:05 +0800 Subject: [PATCH 08/22] fix in bts minhash --- .../ray_bts_minhash_deduplicator.py | 54 ++++++++++++++++--- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index aa78a5978..b85eed08f 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ 
b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Optional +import os import ray import numpy as np import pandas as pd @@ -36,7 +37,19 @@ def merge_edge_list_dict_list(edge_list_dict_list): def BTS_hash(x, parallel_num): # return int(x[-8:], 16) % parallel_num - return int.from_bytes(x, byteorder='big') % parallel_num + # return int.from_bytes(x, byteorder='big') % parallel_num + return x % parallel_num + + +@ray.remote +class IdGenerator: + def __init__(self): + self.next_id = 0 + + def get_next_id(self, count): + current_id = self.next_id + self.next_id += count + return range(current_id, self.next_id) @ray.remote @@ -193,6 +206,14 @@ def get_parent(self): def get_nodes(self): return set(self.parent.keys()) + def squeeze(self): + dup_keys = { + x + for x in self.parent + if self.hash(x) == self.parallel_id + } + self.parent = dup_keys + def is_dup(self, queries): return [ query in self.parent @@ -227,6 +248,7 @@ def __init__( num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, union_find_parallel_num: Optional[int] = 16, + tmp_file_name: Optional[str] = './output/ray-dedup-tmp/', *args, **kwargs, ): @@ -325,6 +347,8 @@ def __init__( for i in range(self.union_find_parallel_num) ] + self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name) + def calc_minhash(self, text_list: pa.Array) -> pa.Table: all_hash_values = [[] for _ in range(self.num_bands)] @@ -408,6 +432,10 @@ def merge(self): union_find.communication.remote() for union_find in self.union_find_list ]) + ray.get([ + union_find.squeeze.remote() + for union_find in self.union_find_list + ]) def filter_with_union_find(self, samples: pa.Table) -> pa.Table: hash_id_list = [] @@ -432,15 +460,28 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: def run(self, dataset): start_time = time.time() + id_generator = IdGenerator.remote() def add_uid_column(table: pa.Table) -> pa.Table: - uuid_list = [uuid.uuid4().bytes for _ in range(table.num_rows)] - new_table = table.append_column(HashKeys.uid, pa.array(uuid_list)) + # uuid_list = [uuid.uuid4().bytes for _ in range(table.num_rows)] + # new_table = table.append_column(HashKeys.uid, pa.array(uuid_list)) + # return new_table + num_rows = len(table) + ids = ray.get(id_generator.get_next_id.remote(num_rows)) + new_table = table.append_column(HashKeys.uid, pa.array(list(ids))) return new_table - dataset = dataset.map_batches( + # dataset = dataset.map_batches( + # add_uid_column, + # batch_format='pyarrow', + # ).materialize() + dataset.map_batches( add_uid_column, batch_format='pyarrow', - ).materialize() + ).write_json( + self.tmp_file_name, + force_ascii=False + ) + dataset = ray.data.read_json(self.tmp_file_name) end_time = time.time() print(f'uid time = {end_time - start_time}') @@ -482,6 +523,5 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: batch_format='pyarrow' ).drop_columns( HashKeys.uid - ).materialize() - logger.info(f'Keep {result.count()} samples after MinHash dedup.') + ) return result From d4506eabbf74160c744017a8e4dfc046bff955ae Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 27 Nov 2024 06:08:23 +0000 Subject: [PATCH 09/22] memory reduce in bts minhash --- .../ray_bts_minhash_deduplicator.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 
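
The IdGenerator actor above replaces random uuid4 keys with dense integer uids: because a single actor hands out contiguous blocks, every batch gets globally unique ids with one remote call and no coordination between tasks. Standalone usage looks roughly like this (the real call sits inside add_uid_column):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class IdGenerator:
    def __init__(self):
        self.next_id = 0

    def get_next_id(self, count):
        # hand out the half-open block [current_id, current_id + count)
        current_id = self.next_id
        self.next_id += count
        return range(current_id, self.next_id)

id_generator = IdGenerator.remote()
first_batch = list(ray.get(id_generator.get_next_id.remote(3)))   # [0, 1, 2]
second_batch = list(ray.get(id_generator.get_next_id.remote(2)))  # [3, 4]
assert first_batch == [0, 1, 2] and second_batch == [3, 4]

Dense integer uids are also what makes the `x % parallel_num` sharding and the later squeeze() step cheap.
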
b85eed08f..6cf009f9f 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -213,6 +213,9 @@ def squeeze(self): if self.hash(x) == self.parallel_id } self.parent = dup_keys + self.old_parent = {} + self.edge_buffer = [] + ray.get(self.remote_edge_buffers[self.parallel_id].clear.remote()) def is_dup(self, queries): return [ @@ -248,7 +251,7 @@ def __init__( num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, union_find_parallel_num: Optional[int] = 16, - tmp_file_name: Optional[str] = './output/ray-dedup-tmp/', + tmp_file_name: Optional[str] = './outputs/ray-dedup-tmp/', *args, **kwargs, ): @@ -347,7 +350,7 @@ def __init__( for i in range(self.union_find_parallel_num) ] - self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name) + self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4())) def calc_minhash(self, text_list: pa.Array) -> pa.Table: all_hash_values = [[] for _ in range(self.num_bands)] @@ -456,24 +459,19 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: not result_dict[hash_id].pop(0) for hash_id in hash_id_list ] - return samples.filter(mask) + columns_to_keep = [name for name in samples.column_names if name != HashKeys.uid] + del hash_id_list, query_dict, result_dict + return samples.select(columns_to_keep).filter(mask) def run(self, dataset): start_time = time.time() id_generator = IdGenerator.remote() def add_uid_column(table: pa.Table) -> pa.Table: - # uuid_list = [uuid.uuid4().bytes for _ in range(table.num_rows)] - # new_table = table.append_column(HashKeys.uid, pa.array(uuid_list)) - # return new_table num_rows = len(table) ids = ray.get(id_generator.get_next_id.remote(num_rows)) new_table = table.append_column(HashKeys.uid, pa.array(list(ids))) return new_table - # dataset = dataset.map_batches( - # add_uid_column, - # batch_format='pyarrow', - # ).materialize() dataset.map_batches( add_uid_column, batch_format='pyarrow', @@ -520,8 +518,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: print(f'merge time = {end_time - start_time}') result = dataset.map_batches( self.filter_with_union_find, - batch_format='pyarrow' - ).drop_columns( - HashKeys.uid + batch_format='pyarrow', + zero_copy_batch=True, ) return result From fea44eb1fc9406cb36ff8cfb20baa37023084a95 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 28 Nov 2024 06:34:20 +0000 Subject: [PATCH 10/22] agg opt --- data_juicer/core/ray_data.py | 14 +-- data_juicer/core/ray_executor.py | 2 +- .../ray_bts_minhash_deduplicator.py | 112 ++++++++++++------ 3 files changed, 85 insertions(+), 43 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 621e68cd9..54418441f 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -61,15 +61,15 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - if Fields.stats not in columns: + # if Fields.stats not in columns: - def process_batch_arrow(table: pa.Table) -> pa.Table: - new_column_data = [{} for _ in range(len(table))] - new_talbe = table.append_column(Fields.stats, [new_column_data]) - return new_talbe + # def process_batch_arrow(table: pa.Table) -> pa.Table: + # new_column_data = [{} for _ in range(len(table))] + # new_talbe = table.append_column(Fields.stats, [new_column_data]) + # 
return new_talbe - dataset = dataset.map_batches(process_batch_arrow, - batch_format='pyarrow') + # dataset = dataset.map_batches(process_batch_arrow, + # batch_format='pyarrow') return dataset diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 6b93fd3dd..82e85ffc7 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -56,7 +56,7 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = rd.read_json(self.cfg.dataset_path) + dataset = rd.read_json(self.cfg.dataset_path, ray_remote_args=dict(scheduling_strategy="SPREAD")) # convert all the path in dataset to absolute path dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 6cf009f9f..cfd8244c5 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -13,7 +13,9 @@ from loguru import logger from pydantic import Field, PositiveInt from typing_extensions import Annotated -from typing import Dict +from typing import Dict, Union +from ray.data.block import BlockAccessor +from ray.data.aggregate import AggregateFn from data_juicer.utils.constant import HashKeys, Fields from data_juicer.utils.lazy_loader import LazyLoader @@ -25,19 +27,7 @@ optimal_param, sha1_hash32) -def merge_edge_list_dict_list(edge_list_dict_list): - final_edge_list_dict = {} - for edge_list_dict in edge_list_dict_list: - for hash_v, edge_list in edge_list_dict.items(): - if hash_v not in final_edge_list_dict: - final_edge_list_dict[hash_v] = [] - final_edge_list_dict[hash_v].extend(edge_list) - return final_edge_list_dict - - def BTS_hash(x, parallel_num): - # return int(x[-8:], 16) % parallel_num - # return int.from_bytes(x, byteorder='big') % parallel_num return x % parallel_num @@ -49,10 +39,10 @@ def __init__(self): def get_next_id(self, count): current_id = self.next_id self.next_id += count - return range(current_id, self.next_id) + return (current_id, self.next_id) -@ray.remote +@ray.remote(scheduling_strategy="SPREAD") class EdgeBuffer: def __init__(self): self.edge_dict = {} @@ -64,10 +54,10 @@ def set_edges(self, edge_dict): self.edge_dict = edge_dict def get_edges(self, key): - return self.edge_dict.get(key, []) + return self.edge_dict.pop(key, []) -@ray.remote +@ray.remote(scheduling_strategy="SPREAD") class BTSUnionFind: def __init__(self, parallel_num, parallel_id, remote_edge_buffers): self.parallel_num = parallel_num @@ -78,21 +68,17 @@ def __init__(self, parallel_num, parallel_id, remote_edge_buffers): self.edge_buffer = [] self.edge_list_dict = {} - def init_union_find_list(self, union_find_list): - self.union_find_list = union_find_list - - def receive_edges(self): + def balanced_union_find(self): + for x, y in self.edge_buffer: + self.union(x, y) edge_list = ray.get([ remote_edge_buffer.get_edges.remote(self.parallel_id) for remote_edge_buffer in self.remote_edge_buffers ]) for edges in edge_list: - self.edge_buffer.extend(edges) - - def balanced_union_find(self): - self.receive_edges() - for x, y in self.edge_buffer: - self.union(x, y) + for x, y in edges: + self.union(x, y) + del edge_list self.edge_buffer = [] self.rebalancing() old_parent_keys = set(self.old_parent.keys()) @@ -125,6 +111,7 @@ def set_edge_buffer(self): else: 
self.edge_buffer = [] ray.get(self.remote_edge_buffers[self.parallel_id].set_edges.remote(self.edge_list_dict)) + self.edge_list_dict = {} def edge_redistribution(self): self.rebalancing() @@ -250,7 +237,8 @@ def __init__( num_bands: Optional[PositiveInt] = None, num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, - union_find_parallel_num: Optional[int] = 16, + union_find_parallel_num: Union[str, int] = 'auto', + union_threshold: Optional[int] = 128, tmp_file_name: Optional[str] = './outputs/ray-dedup-tmp/', *args, **kwargs, @@ -340,7 +328,11 @@ def __init__( dtype=np.uint64, ).T + if union_find_parallel_num == 'auto': + union_find_parallel_num = int(ray.cluster_resources().get('CPU', 32)) // 2 + logger.info(f'union_find_parallel_num = {union_find_parallel_num}') self.union_find_parallel_num = union_find_parallel_num + self.union_threshold = union_threshold self.remote_edge_buffers = [ EdgeBuffer.remote() for i in range(self.union_find_parallel_num) @@ -468,18 +460,18 @@ def run(self, dataset): id_generator = IdGenerator.remote() def add_uid_column(table: pa.Table) -> pa.Table: num_rows = len(table) - ids = ray.get(id_generator.get_next_id.remote(num_rows)) - new_table = table.append_column(HashKeys.uid, pa.array(list(ids))) + min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows)) + new_table = table.append_column(HashKeys.uid, pa.array(list(range(min_id, max_id)))) return new_table dataset.map_batches( add_uid_column, batch_format='pyarrow', - ).write_json( + ).write_parquet( self.tmp_file_name, force_ascii=False - ) - dataset = ray.data.read_json(self.tmp_file_name) + ) # TODO: balance file size + dataset = ray.data.read_parquet(self.tmp_file_name, ray_remote_args=dict(scheduling_strategy="SPREAD")) end_time = time.time() print(f'uid time = {end_time - start_time}') @@ -501,14 +493,64 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: ) return new_table + class UnionFn(AggregateFn): + + def __init__(self, union_find_list, union_threshold=128): + union_find_parallel_num = len(union_find_list) + union_threshold = union_threshold + def union_list(uuid_list): + min_uuid = min(uuid_list) + if len(uuid_list) > 1: + union_find_id = BTS_hash(min_uuid, union_find_parallel_num) + union_find = union_find_list[union_find_id] + ray.get(union_find.union_list.remote(uuid_list)) + return min_uuid + + def accumulate(cur, block): + uuid_list = [] + if cur is not None: + uuid_list.extend(cur) + block_acc = BlockAccessor.for_block(block) + for row in block_acc.iter_rows(public_row_format=False): + uuid_list.append(row[HashKeys.uid]) + if len(uuid_list) > union_threshold: + uuid_list = [union_list(uuid_list)] + return uuid_list + + def merge(a, b): + if a is None: + return b + if b is None: + return a + uuid_list = a + b + if len(uuid_list) > union_threshold: + uuid_list = [union_list(uuid_list)] + return uuid_list + + def finalize(a): + if a is None: + return 0 + return union_list(a) + + super().__init__( + init=lambda k: None, + accumulate_block=accumulate, + merge=merge, + finalize=finalize, + name='union', + ) + start_time = time.time() dataset.map_batches( minhash_with_uid, batch_format='pyarrow', ).groupby( HashKeys.minhash - ).map_groups( - self.agg_func, batch_format='pyarrow' + ).aggregate( + UnionFn( + self.union_find_list, + self.union_threshold, + ) ).materialize() end_time = time.time() print(f'group time = {end_time - start_time}') From 7fb0c59f2a5378c024266736d2b94f4237259108 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: 
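
The union_threshold parameter above bounds the size of the aggregation state: a group's accumulator keeps raw uids only until the list grows past the threshold, at which point they are unioned remotely and replaced by a single representative. Stripped of the Ray plumbing, the bookkeeping is just this (sketch with hypothetical names):

def add_uid(state, uid, union_list, union_threshold=4):
    # accumulate one uid into the per-group state, collapsing when it gets large
    state = (state or []) + [uid]
    if len(state) > union_threshold:
        state = [union_list(state)]
    return state

unions = []
def record_union(uids):
    # stand-in for union_find.union_list.remote(uids); returns the representative id
    unions.append(sorted(uids))
    return min(uids)

state = None
for uid in [9, 3, 7, 5, 8, 2]:
    state = add_uid(state, uid, record_union)
assert state == [3, 2] and unions == [[3, 5, 7, 8, 9]]
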
Fri, 29 Nov 2024 10:07:34 +0000 Subject: [PATCH 11/22] remove groupby --- .../ray_bts_minhash_deduplicator.py | 140 +++++++----------- 1 file changed, 56 insertions(+), 84 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index cfd8244c5..f8285f680 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -160,6 +160,10 @@ def union_list(self, x_list): if p != px: self.parent[px] = p + def union_batch_list(self, batch_x_list): + for x_list in batch_x_list: + self.union_list(x_list) + def rebalancing(self): new_px_dict = {} for x in self.parent: @@ -211,6 +215,28 @@ def is_dup(self, queries): ] +@ray.remote(scheduling_strategy="SPREAD") +class HashTable: + def __init__(self): + self.hash_table = {} + + def add_key_value_pairs(self, pairs): + for key, value in pairs: + if key not in self.hash_table: + self.hash_table[key] = [] + self.hash_table[key].append(value) + + def union(self, union_find): + task = [] + for value in self.hash_table.values(): + if len(value) > 1: + # task.append(union_find.union_list.remote(value)) + task.append(value) + # ray.get(task) + ray.get(union_find.union_batch_list.remote(task)) + del self.hash_table + + OP_NAME = 'ray_bts_minhash_deduplicator' @@ -316,7 +342,6 @@ def __init__( self.hash_ranges = [(i * self.num_rows_per_band, (i + 1) * self.num_rows_per_band) for i in range(self.num_bands)] - self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] # generate permutations gen = np.random.RandomState(seed=42) @@ -341,14 +366,19 @@ def __init__( BTSUnionFind.remote(union_find_parallel_num, i, self.remote_edge_buffers) for i in range(self.union_find_parallel_num) ] + self.hash_tables = [ + HashTable.remote() + for i in range(self.union_find_parallel_num) + ] self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4())) - def calc_minhash(self, text_list: pa.Array) -> pa.Table: - all_hash_values = [[] for _ in range(self.num_bands)] + def calc_minhash(self, text_list: pa.Array, uid_list: pa.Array) -> pa.Table: + pairs = {} - for text in text_list: + for text, uid in zip(text_list, uid_list): text = text.as_py() + uid = uid.as_py() if self.lowercase: text = text.lower() if self.ignore_pattern: @@ -396,23 +426,27 @@ def calc_minhash(self, text_list: pa.Array) -> pa.Table: else: hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) for i, (start, end) in enumerate(self.hash_ranges): - all_hash_values[i].append( - i.to_bytes(4, 'big') + - bytes(hash_values[start:end].byteswap().data) - ) - return all_hash_values - - def agg_func(self, group: pa.Table) -> pa.Table: - if group.num_rows != 1: - uuid_list = [uid.as_py() for uid in group[HashKeys.uid]] - # union_find_id = np.random.randint(0, self.union_find_parallel_num) - min_uuid = min(uuid_list) - union_find_id = BTS_hash(min_uuid, self.union_find_parallel_num) - union_find = self.union_find_list[union_find_id] - ray.get(union_find.union_list.remote(uuid_list)) - return group + hash_value = i.to_bytes(4, 'big') + bytes(hash_values[start:end].byteswap().data) + hash_table_id = hash_values[start] % self.union_find_parallel_num + if hash_table_id not in pairs: + pairs[hash_table_id] = [] + pairs[hash_table_id].append((hash_value, uid)) + ray.get([ + self.hash_tables[i].add_key_value_pairs.remote(p) + for i, p in pairs.items() + ]) + # for i, (start, end) in enumerate(self.hash_ranges): + # 
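
The HashTable actors above are the LSH buckets in distributed form: calc_minhash routes each (band key, uid) pair to one of union_find_parallel_num actors, and any bucket that ends up holding more than one uid becomes a candidate duplicate group to be unioned. Locally the idea reduces to this (sketch):

from collections import defaultdict

buckets = defaultdict(list)

def add_key_value_pairs(pairs):
    # same role as HashTable.add_key_value_pairs above, without the actor
    for key, uid in pairs:
        buckets[key].append(uid)

add_key_value_pairs([(b'band0-abc', 1), (b'band0-abc', 7), (b'band1-xyz', 3)])
duplicate_groups = [uids for uids in buckets.values() if len(uids) > 1]
assert duplicate_groups == [[1, 7]]   # -> union_find.union_list([1, 7])

Routing pairs by a slice of the band hash keeps every bucket on exactly one actor, so no cross-actor lookup is needed before the union step.
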
all_hash_values[i].append( + # i.to_bytes(4, 'big') + + # bytes(hash_values[start:end].byteswap().data) + # ) + # return all_hash_values def merge(self): + ray.get([ + self.hash_tables[i].union.remote(self.union_find_list[i]) + for i in range(self.union_find_parallel_num) + ]) ray.get([ union_find.edge_redistribution.remote() for union_find in self.union_find_list @@ -476,81 +510,19 @@ def add_uid_column(table: pa.Table) -> pa.Table: print(f'uid time = {end_time - start_time}') def minhash_with_uid(table: pa.Table) -> pa.Table: - minhash_values = self.calc_minhash(table[self.text_key]) + self.calc_minhash(table[self.text_key], table[HashKeys.uid]) new_table = pa.Table.from_arrays( [ - pa.concat_arrays( - [table[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) - ), - pa.concat_arrays( - [ - pa.array(minhash_values[i]) - for i in range(len(self.hash_ranges)) - ] - ), ], - names=[HashKeys.uid, HashKeys.minhash] + names=[ + ] ) return new_table - class UnionFn(AggregateFn): - - def __init__(self, union_find_list, union_threshold=128): - union_find_parallel_num = len(union_find_list) - union_threshold = union_threshold - def union_list(uuid_list): - min_uuid = min(uuid_list) - if len(uuid_list) > 1: - union_find_id = BTS_hash(min_uuid, union_find_parallel_num) - union_find = union_find_list[union_find_id] - ray.get(union_find.union_list.remote(uuid_list)) - return min_uuid - - def accumulate(cur, block): - uuid_list = [] - if cur is not None: - uuid_list.extend(cur) - block_acc = BlockAccessor.for_block(block) - for row in block_acc.iter_rows(public_row_format=False): - uuid_list.append(row[HashKeys.uid]) - if len(uuid_list) > union_threshold: - uuid_list = [union_list(uuid_list)] - return uuid_list - - def merge(a, b): - if a is None: - return b - if b is None: - return a - uuid_list = a + b - if len(uuid_list) > union_threshold: - uuid_list = [union_list(uuid_list)] - return uuid_list - - def finalize(a): - if a is None: - return 0 - return union_list(a) - - super().__init__( - init=lambda k: None, - accumulate_block=accumulate, - merge=merge, - finalize=finalize, - name='union', - ) - start_time = time.time() dataset.map_batches( minhash_with_uid, batch_format='pyarrow', - ).groupby( - HashKeys.minhash - ).aggregate( - UnionFn( - self.union_find_list, - self.union_threshold, - ) ).materialize() end_time = time.time() print(f'group time = {end_time - start_time}') From 73d3f8325e8271fdeea971313399f5bffc851f95 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 2 Dec 2024 02:26:51 +0000 Subject: [PATCH 12/22] memory reduce --- .../ray_bts_minhash_deduplicator.py | 84 ++++++++----------- 1 file changed, 36 insertions(+), 48 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index f8285f680..42c037524 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -59,15 +59,34 @@ def get_edges(self, key): @ray.remote(scheduling_strategy="SPREAD") class BTSUnionFind: - def __init__(self, parallel_num, parallel_id, remote_edge_buffers): + def __init__(self, union_threshold, parallel_num, parallel_id, remote_edge_buffers): + self.union_threshold = union_threshold self.parallel_num = parallel_num self.parallel_id = parallel_id + self.hash_table = {} self.parent = {} self.old_parent = {} self.remote_edge_buffers = remote_edge_buffers self.edge_buffer = [] self.edge_list_dict = {} + def 
add_key_value_pairs(self, pairs): + key_set = set() + for key, value in pairs: + if key not in self.hash_table: + self.hash_table[key] = [] + self.hash_table[key].append(value) + if len(self.hash_table[key]) > self.union_threshold: + key_set.add(key) + for key in key_set: + self.hash_table[key] = [self.union_list(self.hash_table[key])] + + def flush_key_value_pairs(self): + for value in self.hash_table.values(): + if len(value) > 1: + self.union_list(value) + del self.hash_table + def balanced_union_find(self): for x, y in self.edge_buffer: self.union(x, y) @@ -114,6 +133,7 @@ def set_edge_buffer(self): self.edge_list_dict = {} def edge_redistribution(self): + self.flush_key_value_pairs() self.rebalancing() self.edge_list_dict = {} for u in self.parent: @@ -159,6 +179,7 @@ def union_list(self, x_list): for px in px_list: if p != px: self.parent[px] = p + return p def union_batch_list(self, batch_x_list): for x_list in batch_x_list: @@ -215,28 +236,6 @@ def is_dup(self, queries): ] -@ray.remote(scheduling_strategy="SPREAD") -class HashTable: - def __init__(self): - self.hash_table = {} - - def add_key_value_pairs(self, pairs): - for key, value in pairs: - if key not in self.hash_table: - self.hash_table[key] = [] - self.hash_table[key].append(value) - - def union(self, union_find): - task = [] - for value in self.hash_table.values(): - if len(value) > 1: - # task.append(union_find.union_list.remote(value)) - task.append(value) - # ray.get(task) - ray.get(union_find.union_batch_list.remote(task)) - del self.hash_table - - OP_NAME = 'ray_bts_minhash_deduplicator' @@ -264,7 +263,7 @@ def __init__( num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, union_find_parallel_num: Union[str, int] = 'auto', - union_threshold: Optional[int] = 128, + union_threshold: Optional[int] = 256, tmp_file_name: Optional[str] = './outputs/ray-dedup-tmp/', *args, **kwargs, @@ -363,11 +362,12 @@ def __init__( for i in range(self.union_find_parallel_num) ] self.union_find_list = [ - BTSUnionFind.remote(union_find_parallel_num, i, self.remote_edge_buffers) - for i in range(self.union_find_parallel_num) - ] - self.hash_tables = [ - HashTable.remote() + BTSUnionFind.remote( + union_threshold, + union_find_parallel_num, + i, + self.remote_edge_buffers + ) for i in range(self.union_find_parallel_num) ] @@ -432,21 +432,11 @@ def calc_minhash(self, text_list: pa.Array, uid_list: pa.Array) -> pa.Table: pairs[hash_table_id] = [] pairs[hash_table_id].append((hash_value, uid)) ray.get([ - self.hash_tables[i].add_key_value_pairs.remote(p) + self.union_find_list[i].add_key_value_pairs.remote(p) for i, p in pairs.items() ]) - # for i, (start, end) in enumerate(self.hash_ranges): - # all_hash_values[i].append( - # i.to_bytes(4, 'big') + - # bytes(hash_values[start:end].byteswap().data) - # ) - # return all_hash_values def merge(self): - ray.get([ - self.hash_tables[i].union.remote(self.union_find_list[i]) - for i in range(self.union_find_parallel_num) - ]) ray.get([ union_find.edge_redistribution.remote() for union_find in self.union_find_list @@ -476,7 +466,10 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: if hash_id not in query_dict: query_dict[hash_id] = [] query_dict[hash_id].append(uid) - results = ray.get([self.union_find_list[hash_id].is_dup.remote(query) for hash_id, query in query_dict.items()]) + results = ray.get([ + self.union_find_list[hash_id].is_dup.remote(query) + for hash_id, query in query_dict.items() + ]) result_dict = { hash_id: result for hash_id, 
result in zip(query_dict.keys(), results) @@ -511,18 +504,13 @@ def add_uid_column(table: pa.Table) -> pa.Table: def minhash_with_uid(table: pa.Table) -> pa.Table: self.calc_minhash(table[self.text_key], table[HashKeys.uid]) - new_table = pa.Table.from_arrays( - [ - ], - names=[ - ] - ) - return new_table + return table start_time = time.time() dataset.map_batches( minhash_with_uid, batch_format='pyarrow', + zero_copy_batch=True, ).materialize() end_time = time.time() print(f'group time = {end_time - start_time}') From d880b0df03d643ff95a986159891b33274c2ac7d Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 2 Dec 2024 11:34:13 +0000 Subject: [PATCH 13/22] merge add uid and minhash calculation --- .../ray_bts_minhash_deduplicator.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 42c037524..385d53cfe 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -1,24 +1,18 @@ -import random import time import uuid -from collections import defaultdict from typing import Optional import os import ray import numpy as np -import pandas as pd import pyarrow as pa import regex from loguru import logger from pydantic import Field, PositiveInt from typing_extensions import Annotated -from typing import Dict, Union -from ray.data.block import BlockAccessor -from ray.data.aggregate import AggregateFn +from typing import List, Union -from data_juicer.utils.constant import HashKeys, Fields -from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.constant import HashKeys from data_juicer.utils.model_utils import prepare_sentencepiece_model from ..base_op import OPERATORS, Deduplicator @@ -373,12 +367,11 @@ def __init__( self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4())) - def calc_minhash(self, text_list: pa.Array, uid_list: pa.Array) -> pa.Table: + def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: pairs = {} for text, uid in zip(text_list, uid_list): text = text.as_py() - uid = uid.as_py() if self.lowercase: text = text.lower() if self.ignore_pattern: @@ -485,14 +478,16 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: def run(self, dataset): start_time = time.time() id_generator = IdGenerator.remote() - def add_uid_column(table: pa.Table) -> pa.Table: + def minhash_with_uid(table: pa.Table) -> pa.Table: num_rows = len(table) min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows)) - new_table = table.append_column(HashKeys.uid, pa.array(list(range(min_id, max_id)))) + uid_list = range(min_id, max_id) + self.calc_minhash(table[self.text_key], uid_list) + new_table = table.append_column(HashKeys.uid, pa.array(list(uid_list))) return new_table dataset.map_batches( - add_uid_column, + minhash_with_uid, batch_format='pyarrow', ).write_parquet( self.tmp_file_name, @@ -500,20 +495,8 @@ def add_uid_column(table: pa.Table) -> pa.Table: ) # TODO: balance file size dataset = ray.data.read_parquet(self.tmp_file_name, ray_remote_args=dict(scheduling_strategy="SPREAD")) end_time = time.time() - print(f'uid time = {end_time - start_time}') + print(f'MinHash time = {end_time - start_time}') - def minhash_with_uid(table: pa.Table) -> pa.Table: - self.calc_minhash(table[self.text_key], table[HashKeys.uid]) - return table - - start_time = time.time() 
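
filter_with_union_find above avoids one remote call per row: uids are bucketed by their owning shard, each shard answers its whole batch of is_dup queries in a single call, and the per-row keep mask is rebuilt in the original order by popping answers shard by shard. The same scatter/gather pattern in plain Python (hypothetical helper names):

def build_keep_mask(uids, parallel_num, is_dup_per_shard):
    # scatter: group uids by owning shard, remembering each row's shard
    shard_ids, queries = [], {}
    for uid in uids:
        shard = uid % parallel_num
        shard_ids.append(shard)
        queries.setdefault(shard, []).append(uid)
    # one batched query per shard, then gather answers back in row order
    answers = {shard: is_dup_per_shard(shard, q) for shard, q in queries.items()}
    return [not answers[shard].pop(0) for shard in shard_ids]

# hypothetical shard state: shard s marks a uid as duplicate if it sits in its set
dup_sets = {0: {4}, 1: {5}}
mask = build_keep_mask(
    [3, 4, 5, 6], parallel_num=2,
    is_dup_per_shard=lambda s, q: [u in dup_sets.get(s, set()) for u in q])
assert mask == [True, False, False, True]
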
- dataset.map_batches( - minhash_with_uid, - batch_format='pyarrow', - zero_copy_batch=True, - ).materialize() - end_time = time.time() - print(f'group time = {end_time - start_time}') start_time = time.time() self.merge() end_time = time.time() From 6c69aafe4bcd1ed8fd06f952309fea30215796f1 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 3 Dec 2024 05:21:18 +0000 Subject: [PATCH 14/22] format and speed up in bts_minhash --- .../ray_bts_minhash_deduplicator.py | 142 ++++++++---------- 1 file changed, 63 insertions(+), 79 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 385d53cfe..97f4377f0 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -1,5 +1,6 @@ import time import uuid +from collections import defaultdict from typing import Optional import os @@ -21,10 +22,6 @@ optimal_param, sha1_hash32) -def BTS_hash(x, parallel_num): - return x % parallel_num - - @ray.remote class IdGenerator: def __init__(self): @@ -65,15 +62,12 @@ def __init__(self, union_threshold, parallel_num, parallel_id, remote_edge_buffe self.edge_list_dict = {} def add_key_value_pairs(self, pairs): - key_set = set() for key, value in pairs: if key not in self.hash_table: self.hash_table[key] = [] self.hash_table[key].append(value) if len(self.hash_table[key]) > self.union_threshold: - key_set.add(key) - for key in key_set: - self.hash_table[key] = [self.union_list(self.hash_table[key])] + self.hash_table[key] = [self.union_list(self.hash_table[key])] def flush_key_value_pairs(self): for value in self.hash_table.values(): @@ -94,21 +88,11 @@ def balanced_union_find(self): del edge_list self.edge_buffer = [] self.rebalancing() - old_parent_keys = set(self.old_parent.keys()) - parent_keys = set(self.parent.keys()) - if old_parent_keys ^ parent_keys: - return True - for u in parent_keys: - if self.old_parent.get(u, u) != self.parent.get(u, u): - return True - return False - - def hash(self, u): - return BTS_hash(u, self.parallel_num) + return self.old_parent != self.parent def distribute_edge(self, u, v): - hash_u = self.hash(u) - hash_v = self.hash(v) + hash_u = u % self.parallel_num + hash_v = v % self.parallel_num if hash_u not in self.edge_list_dict: self.edge_list_dict[hash_u] = [] self.edge_list_dict[hash_u].append((u, v)) @@ -130,8 +114,7 @@ def edge_redistribution(self): self.flush_key_value_pairs() self.rebalancing() self.edge_list_dict = {} - for u in self.parent: - v = self.parent[u] + for u, v in self.parent.items(): self.distribute_edge(u, v) self.parent = {} self.set_edge_buffer() @@ -139,9 +122,8 @@ def edge_redistribution(self): def communication(self): self.edge_list_dict = {} del_list = [] - for u in self.parent: - hash_u = self.hash(u) - v = self.parent[u] + for u, v in self.parent.items(): + hash_u = u % self.parallel_num if self.parent[u] != self.old_parent.get(u, u) or (hash_u != self.parallel_id and v not in self.parent): self.distribute_edge(u, v) if hash_u != self.parallel_id: @@ -157,7 +139,7 @@ def find(self, x): else: self.parent[x] = self.find(self.parent[x]) return self.parent[x] - + def union(self, x, y): px = self.find(x) py = self.find(y) @@ -166,7 +148,7 @@ def union(self, x, y): if px > py: px, py = py, px self.parent[py] = px - + def union_list(self, x_list): px_list = [self.find(x) for x in x_list] p = min(px_list) @@ -175,14 +157,10 @@ def union_list(self, x_list): 
self.parent[px] = p return p - def union_batch_list(self, batch_x_list): - for x_list in batch_x_list: - self.union_list(x_list) - def rebalancing(self): new_px_dict = {} for x in self.parent: - hash_x = self.hash(x) + hash_x = x % self.parallel_num px = self.find(x) key = (px, hash_x) if key not in new_px_dict: @@ -191,7 +169,7 @@ def rebalancing(self): new_px_dict[key] = min(new_px_dict[key], x) px_set = set(px for px, _ in new_px_dict) for px in px_set: - hash_px = self.hash(px) + hash_px = px % self.parallel_num key = (px, hash_px) if key not in new_px_dict: new_px_dict[key] = px @@ -199,24 +177,18 @@ def rebalancing(self): new_px_dict[key] = min(new_px_dict[key], px) for x in self.parent: - hash_x = self.hash(x) + hash_x = x % self.parallel_num px = self.find(x) key = (px, hash_x) if x == new_px_dict[key]: continue self.parent[x] = new_px_dict[key] - def get_parent(self): - return self.parent - - def get_nodes(self): - return set(self.parent.keys()) - def squeeze(self): dup_keys = { x for x in self.parent - if self.hash(x) == self.parallel_id + if x % self.parallel_num == self.parallel_id } self.parent = dup_keys self.old_parent = {} @@ -317,6 +289,38 @@ def __init__( else: self.tokenizer = None + if self.tokenization == 'character': + def tokenization_func(text): + return { + str.encode(text[i:i + self.window_size]) + for i in range(len(text) - self.window_size) + } + elif self.tokenization == 'punctuation': + def tokenization_func(text): + tokens = self.punctuation_pattern.split(text) + return { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'space': + def tokenization_func(text): + tokens = split_on_whitespace(text) + return { + str.encode(' '.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + elif self.tokenization == 'sentencepiece': + def tokenization_func(text): + tokens = self.tokenizer.encode(text, out_type=str) + return { + str.encode(''.join(tokens[i:i + self.window_size])) + for i in range(len(tokens) - self.window_size) + } + else: + raise NotImplementedError( + f'Unimplemented tokenization method [{self.tokenization}]') + self.tokenization_func = tokenization_func + # about deduplication self.num_permutation = num_permutations self.jaccard_threshold = jaccard_threshold @@ -367,6 +371,10 @@ def __init__( self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4())) + empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32) + self.empty_hash_value = b'\x00\x00\x00\x00' + empty_hash_value.tobytes() + self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num) + def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: pairs = {} @@ -377,34 +385,7 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: if self.ignore_pattern: text = self.ignore_pattern.sub('', text) - # get tokens for different tokenization method - tokens = set() - if self.tokenization == 'character': - tokens = { - str.encode(text[i:i + self.window_size]) - for i in range(len(text) - self.window_size) - } - elif self.tokenization == 'punctuation': - tokens = self.punctuation_pattern.split(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'space': - tokens = split_on_whitespace(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - 
self.window_size) - } - elif self.tokenization == 'sentencepiece': - tokens = self.tokenizer.encode(text, out_type=str) - tokens = { - str.encode(''.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - else: - raise NotImplementedError( - f'Unimplemented tokenization method [{self.tokenization}]') + tokens = self.tokenization_func(text) if len(tokens) > 0: hv = np.array( @@ -416,19 +397,21 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: + self.perm_b) % MERSENNE_PRIME ).astype(np.uint32) hash_values = phv.min(axis=0) + for i, (start, end) in enumerate(self.hash_ranges): + hash_value = i.to_bytes(4, 'big') + hash_values[start:end].tobytes() + hash_table_id = hash_values[start] % self.union_find_parallel_num + if hash_table_id not in pairs: + pairs[hash_table_id] = [] + pairs[hash_table_id].append((hash_value, uid)) else: - hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) - for i, (start, end) in enumerate(self.hash_ranges): - hash_value = i.to_bytes(4, 'big') + bytes(hash_values[start:end].byteswap().data) - hash_table_id = hash_values[start] % self.union_find_parallel_num - if hash_table_id not in pairs: - pairs[hash_table_id] = [] - pairs[hash_table_id].append((hash_value, uid)) + if self.empty_hash_table_id not in pairs: + pairs[self.empty_hash_table_id] = [] + pairs[self.empty_hash_table_id].append((self.empty_hash_value, uid)) ray.get([ self.union_find_list[i].add_key_value_pairs.remote(p) for i, p in pairs.items() ]) - + def merge(self): ray.get([ union_find.edge_redistribution.remote() @@ -454,7 +437,7 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: query_dict = {} for uid in samples[HashKeys.uid]: uid = uid.as_py() - hash_id = BTS_hash(uid, self.union_find_parallel_num) + hash_id = uid % self.union_find_parallel_num hash_id_list.append(hash_id) if hash_id not in query_dict: query_dict[hash_id] = [] @@ -489,6 +472,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: dataset.map_batches( minhash_with_uid, batch_format='pyarrow', + zero_copy_batch=True, ).write_parquet( self.tmp_file_name, force_ascii=False From 788827c42b762db22d56596e5dbd60c0ed107529 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 5 Dec 2024 02:26:24 +0000 Subject: [PATCH 15/22] ray.get fix --- .../ray_bts_minhash_deduplicator.py | 140 +++++++++++++----- 1 file changed, 101 insertions(+), 39 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 97f4377f0..81415f072 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -27,6 +27,7 @@ class IdGenerator: def __init__(self): self.next_id = 0 + @ray.method(num_returns=2) def get_next_id(self, count): current_id = self.next_id self.next_id += count @@ -50,7 +51,15 @@ def get_edges(self, key): @ray.remote(scheduling_strategy="SPREAD") class BTSUnionFind: - def __init__(self, union_threshold, parallel_num, parallel_id, remote_edge_buffers): + def __init__( + self, + union_threshold, + parallel_num, + parallel_id, + remote_edge_buffers, + max_pending_edge_buffer_task, + num_edge_buffer_task_returns, + ): self.union_threshold = union_threshold self.parallel_num = parallel_num self.parallel_id = parallel_id @@ -60,6 +69,8 @@ def __init__(self, union_threshold, parallel_num, parallel_id, remote_edge_buffe self.remote_edge_buffers = remote_edge_buffers 
self.edge_buffer = [] self.edge_list_dict = {} + self.max_pending_edge_buffer_task = max_pending_edge_buffer_task + self.num_edge_buffer_task_returns = num_edge_buffer_task_returns def add_key_value_pairs(self, pairs): for key, value in pairs: @@ -78,15 +89,27 @@ def flush_key_value_pairs(self): def balanced_union_find(self): for x, y in self.edge_buffer: self.union(x, y) - edge_list = ray.get([ - remote_edge_buffer.get_edges.remote(self.parallel_id) - for remote_edge_buffer in self.remote_edge_buffers - ]) + self.edge_buffer = [] + result_refs = [] + for remote_edge_buffer in self.remote_edge_buffers: + if len(result_refs) > self.max_pending_edge_buffer_task: + ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_edge_buffer_task_returns + ) + edge_list = ray.get(ready_refs) + for edges in edge_list: + for x, y in edges: + self.union(x, y) + del ready_refs + result_refs.append( + remote_edge_buffer.get_edges.remote(self.parallel_id) + ) + edge_list = ray.get(result_refs) for edges in edge_list: for x, y in edges: self.union(x, y) - del edge_list - self.edge_buffer = [] + del edge_list, result_refs self.rebalancing() return self.old_parent != self.parent @@ -195,10 +218,11 @@ def squeeze(self): self.edge_buffer = [] ray.get(self.remote_edge_buffers[self.parallel_id].clear.remote()) - def is_dup(self, queries): + def dup_idx(self, queries): return [ - query in self.parent - for query in queries + idx + for uid, idx in queries + if uid in self.parent ] @@ -230,6 +254,11 @@ def __init__( tokenizer_model: Optional[str] = None, union_find_parallel_num: Union[str, int] = 'auto', union_threshold: Optional[int] = 256, + max_pending_edge_buffer_task: Optional[int] = 20, + num_edge_buffer_task_returns: Optional[int] = 10, + max_pending_filter_tasks: Optional[int] = 20, + num_filter_task_returns: Optional[int] = 10, + merge_batch_size: Optional[int] = 1000, tmp_file_name: Optional[str] = './outputs/ray-dedup-tmp/', *args, **kwargs, @@ -352,19 +381,28 @@ def tokenization_func(text): if union_find_parallel_num == 'auto': union_find_parallel_num = int(ray.cluster_resources().get('CPU', 32)) // 2 + + self.max_pending_edge_buffer_task = max_pending_edge_buffer_task + self.num_edge_buffer_task_returns = num_edge_buffer_task_returns + self.max_pending_filter_tasks = max_pending_filter_tasks + self.num_filter_task_returns = num_filter_task_returns + self.merge_batch_size = min(merge_batch_size, union_find_parallel_num) + logger.info(f'union_find_parallel_num = {union_find_parallel_num}') self.union_find_parallel_num = union_find_parallel_num self.union_threshold = union_threshold self.remote_edge_buffers = [ EdgeBuffer.remote() - for i in range(self.union_find_parallel_num) + for _ in range(self.union_find_parallel_num) ] self.union_find_list = [ BTSUnionFind.remote( - union_threshold, - union_find_parallel_num, + self.union_threshold, + self.union_find_parallel_num, i, - self.remote_edge_buffers + self.remote_edge_buffers, + self.max_pending_edge_buffer_task, + self.num_edge_buffer_task_returns, ) for i in range(self.union_find_parallel_num) ] @@ -407,55 +445,78 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: if self.empty_hash_table_id not in pairs: pairs[self.empty_hash_table_id] = [] pairs[self.empty_hash_table_id].append((self.empty_hash_value, uid)) - ray.get([ - self.union_find_list[i].add_key_value_pairs.remote(p) - for i, p in pairs.items() - ]) + result_refs = [] + for i, p in pairs.items(): + if len(result_refs) > self.max_pending_filter_tasks: + 
ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_filter_task_returns + ) + ray.get(ready_refs) + result_refs.append( + self.union_find_list[i].add_key_value_pairs.remote(p) + ) + ray.get(result_refs) + + def merge_op_batch(self, object_refs): + results = [] + while object_refs: + ready_refs, object_refs = ray.wait(object_refs, num_returns=self.merge_batch_size) + results.extend(ray.get(ready_refs)) + return results def merge(self): - ray.get([ + self.merge_op_batch([ union_find.edge_redistribution.remote() for union_find in self.union_find_list ]) while any( - ray.get([ + self.merge_op_batch([ union_find.balanced_union_find.remote() for union_find in self.union_find_list ]) ): - ray.get([ + self.merge_op_batch([ union_find.communication.remote() for union_find in self.union_find_list ]) - ray.get([ + self.merge_op_batch([ union_find.squeeze.remote() for union_find in self.union_find_list ]) def filter_with_union_find(self, samples: pa.Table) -> pa.Table: - hash_id_list = [] query_dict = {} - for uid in samples[HashKeys.uid]: + for idx, uid in enumerate(samples[HashKeys.uid]): uid = uid.as_py() hash_id = uid % self.union_find_parallel_num - hash_id_list.append(hash_id) if hash_id not in query_dict: query_dict[hash_id] = [] - query_dict[hash_id].append(uid) - results = ray.get([ - self.union_find_list[hash_id].is_dup.remote(query) - for hash_id, query in query_dict.items() - ]) - result_dict = { - hash_id: result - for hash_id, result in zip(query_dict.keys(), results) - } - mask = [ - not result_dict[hash_id].pop(0) - for hash_id in hash_id_list + query_dict[hash_id].append((uid, idx)) + mask = np.ones(len(samples), dtype=np.bool_) + result_refs = [] + for hash_id, query in query_dict.items(): + if len(result_refs) > self.max_pending_filter_tasks: + ready_refs, result_refs = ray.wait( + result_refs, + num_returns=self.num_filter_task_returns + ) + results = ray.get(ready_refs) + for result in results: + mask[result] = False + del ready_refs + result_refs.append( + self.union_find_list[hash_id].dup_idx.remote(query) + ) + results = ray.get(result_refs) + for result in results: + mask[result] = False + del query_dict, results + columns_to_keep = [ + name + for name in samples.column_names + if name != HashKeys.uid ] - columns_to_keep = [name for name in samples.column_names if name != HashKeys.uid] - del hash_id_list, query_dict, result_dict return samples.select(columns_to_keep).filter(mask) def run(self, dataset): @@ -490,4 +551,5 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: batch_format='pyarrow', zero_copy_batch=True, ) + # logger.info(f'origin count = {dataset.count()}, keep count = {result.count()}') return result From 531955d5fa16e0c839edd33306b743a18f3610d1 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 5 Dec 2024 06:22:38 +0000 Subject: [PATCH 16/22] ray.get fix --- .../ops/deduplicator/ray_bts_minhash_deduplicator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 81415f072..63040d9d3 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -252,7 +252,7 @@ def __init__( num_bands: Optional[PositiveInt] = None, num_rows_per_band: Optional[PositiveInt] = None, tokenizer_model: Optional[str] = None, - union_find_parallel_num: Union[str, int] = 'auto', + union_find_parallel_num: 
Union[int, str] = 'auto', union_threshold: Optional[int] = 256, max_pending_edge_buffer_task: Optional[int] = 20, num_edge_buffer_task_returns: Optional[int] = 10, @@ -381,6 +381,8 @@ def tokenization_func(text): if union_find_parallel_num == 'auto': union_find_parallel_num = int(ray.cluster_resources().get('CPU', 32)) // 2 + else: + union_find_parallel_num = int(union_find_parallel_num) self.max_pending_edge_buffer_task = max_pending_edge_buffer_task self.num_edge_buffer_task_returns = num_edge_buffer_task_returns @@ -461,7 +463,10 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: def merge_op_batch(self, object_refs): results = [] while object_refs: - ready_refs, object_refs = ray.wait(object_refs, num_returns=self.merge_batch_size) + ready_refs, object_refs = ray.wait( + object_refs, + num_returns=min(self.merge_batch_size, len(object_refs)) + ) results.extend(ray.get(ready_refs)) return results From 62caefe72dba2da6e47fde7a13637909ea06423f Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 16 Dec 2024 07:02:35 +0000 Subject: [PATCH 17/22] update docs and code format --- configs/config_all.yaml | 18 + data_juicer/core/ray_executor.py | 2 +- data_juicer/ops/deduplicator/__init__.py | 5 +- .../ray_bts_minhash_deduplicator.py | 107 ++-- .../deduplicator/ray_minhash_deduplicator.py | 376 -------------- .../ray_multi_redis_minhash_deduplicator.py | 473 ------------------ docs/Operators.md | 1 + docs/Operators_ZH.md | 3 +- 8 files changed, 98 insertions(+), 887 deletions(-) delete mode 100644 data_juicer/ops/deduplicator/ray_minhash_deduplicator.py delete mode 100644 data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index d251d24a2..df4bf91ad 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -654,6 +654,24 @@ process: lowercase: true # whether to convert text to lower case ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash. tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization. + - ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm + tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece] + window_size: 5 # window size of shingling + num_permutations: 256 # number of permutations in minhash computing + jaccard_threshold: 0.7 # the min jaccard similarity threshold in near-duplicate detection. When the jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication + num_bands: null # number of bands in LSH. Default it's None, and it will be determined by an optimal params computation algorithm by minimize the weighted sum of probs of False Positives and False Negatives + num_rows_per_band: null # number of rows in each band in LSH. Default it's None, and it will be determined by an optimal params computation algorithm + lowercase: true # whether to convert text to lower case + ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash. + tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization. + union_find_parallel_num: 'auto' # number of parallel workers for union-find algorithm. Default it's 'auto', and it will be determined by half of the number of CPUs. 
+ union_threshold: 256 # threshold for minhash values group to perform union-find algorightm. + max_pending_edge_buffer_task: 20 # max number of pending edge buffer ray tasks. + num_edge_buffer_task_returns: 10 # number of edge buffer tasks for `ray.wait` to return. + max_pending_filter_tasks: 20 # max number of pending filter ray tasks. + num_filter_task_returns: 10 # number of filter tasks for `ray.wait` to return. + merge_batch_size: 1000 # batch size for BTS operations. + tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication. # Selector ops - frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 82e85ffc7..6b93fd3dd 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -56,7 +56,7 @@ def run(self, load_data_np=None): from data_juicer.format.formatter import FORMATTERS dataset = FORMATTERS.modules[obj_name](**args).load_dataset() else: - dataset = rd.read_json(self.cfg.dataset_path, ray_remote_args=dict(scheduling_strategy="SPREAD")) + dataset = rd.read_json(self.cfg.dataset_path) # convert all the path in dataset to absolute path dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index f8be0101d..3e9f55f47 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -5,10 +5,8 @@ from .ray_basic_deduplicator import RayBasicDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator -from .ray_minhash_deduplicator import RayMinhashDeduplicator from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator from .ray_redis_minhash_deduplicator import RayRedisMinhashDeduplicator -from .ray_multi_redis_minhash_deduplicator import RayMultiRedisMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator @@ -17,6 +15,5 @@ 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', 'RayImageDeduplicator', 'RayRedisMinhashDeduplicator', - 'RayMinhashDeduplicator', 'RayBTSMinhashDeduplicator', - 'RayMultiRedisMinhashDeduplicator', 'VideoDeduplicator', + 'RayBTSMinhashDeduplicator', 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 63040d9d3..20e998e0f 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -22,10 +22,13 @@ optimal_param, sha1_hash32) +BATCH_SIZE = 1000 + + @ray.remote class IdGenerator: - def __init__(self): - self.next_id = 0 + def __init__(self, start_id = 0): + self.next_id = start_id @ray.method(num_returns=2) def get_next_id(self, count): @@ -52,14 +55,14 @@ def get_edges(self, key): @ray.remote(scheduling_strategy="SPREAD") class BTSUnionFind: def __init__( - self, - union_threshold, - parallel_num, - parallel_id, - remote_edge_buffers, - max_pending_edge_buffer_task, - num_edge_buffer_task_returns, - ): + self, + union_threshold, + parallel_num, + parallel_id, + remote_edge_buffers, + max_pending_edge_buffer_task, + num_edge_buffer_task_returns, + ): self.union_threshold = 
union_threshold self.parallel_num = parallel_num self.parallel_id = parallel_id @@ -114,8 +117,8 @@ def balanced_union_find(self): return self.old_parent != self.parent def distribute_edge(self, u, v): - hash_u = u % self.parallel_num - hash_v = v % self.parallel_num + hash_u = u // BATCH_SIZE % self.parallel_num + hash_v = v // BATCH_SIZE % self.parallel_num if hash_u not in self.edge_list_dict: self.edge_list_dict[hash_u] = [] self.edge_list_dict[hash_u].append((u, v)) @@ -130,7 +133,11 @@ def set_edge_buffer(self): del self.edge_list_dict[self.parallel_id] else: self.edge_buffer = [] - ray.get(self.remote_edge_buffers[self.parallel_id].set_edges.remote(self.edge_list_dict)) + ray.get( + self.remote_edge_buffers[self.parallel_id].set_edges.remote( + self.edge_list_dict + ) + ) self.edge_list_dict = {} def edge_redistribution(self): @@ -146,8 +153,9 @@ def communication(self): self.edge_list_dict = {} del_list = [] for u, v in self.parent.items(): - hash_u = u % self.parallel_num - if self.parent[u] != self.old_parent.get(u, u) or (hash_u != self.parallel_id and v not in self.parent): + hash_u = u // BATCH_SIZE % self.parallel_num + if self.parent[u] != self.old_parent.get(u, u) or \ + (hash_u != self.parallel_id and v not in self.parent): self.distribute_edge(u, v) if hash_u != self.parallel_id: del_list.append(u) @@ -183,7 +191,7 @@ def union_list(self, x_list): def rebalancing(self): new_px_dict = {} for x in self.parent: - hash_x = x % self.parallel_num + hash_x = x // BATCH_SIZE % self.parallel_num px = self.find(x) key = (px, hash_x) if key not in new_px_dict: @@ -192,7 +200,7 @@ def rebalancing(self): new_px_dict[key] = min(new_px_dict[key], x) px_set = set(px for px, _ in new_px_dict) for px in px_set: - hash_px = px % self.parallel_num + hash_px = px // BATCH_SIZE % self.parallel_num key = (px, hash_px) if key not in new_px_dict: new_px_dict[key] = px @@ -200,7 +208,7 @@ def rebalancing(self): new_px_dict[key] = min(new_px_dict[key], px) for x in self.parent: - hash_x = x % self.parallel_num + hash_x = x // BATCH_SIZE % self.parallel_num px = self.find(x) key = (px, hash_x) if x == new_px_dict[key]: @@ -211,7 +219,7 @@ def squeeze(self): dup_keys = { x for x in self.parent - if x % self.parallel_num == self.parallel_id + if x // BATCH_SIZE % self.parallel_num == self.parallel_id } self.parent = dup_keys self.old_parent = {} @@ -293,6 +301,22 @@ def __init__( params computation algorithm :param tokenizer_model: path for the sentencepiece model, used for sentencepiece tokenization. + :param union_find_parallel_num: number of parallel workers for + union-find algorithm. Default it's 'auto', and it will be + determined by half of the number of CPUs. + :param union_threshold: threshold for minhash values group to + perform union-find algorightm. Default it's 256. + :param max_pending_edge_buffer_task: max number of pending edge buffer + ray tasks. Default it's 20. + :param num_edge_buffer_task_returns: number of edge buffer tasks for + `ray.wait` to return. Default it's 10. + :param max_pending_filter_tasks: max number of pending filter ray + tasks. Default it's 20. + :param num_filter_task_returns: number of filter tasks for `ray.wait` + to return. Default it's 10. + :param merge_batch_size: batch size for BTS operations. Default + it's 1000. + :param tmp_file_name: the temporary folder name for deduplication. 
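For orientation, here is a minimal usage sketch of this operator with the parameters documented above. It is not part of this patch: the input path, the ray.init() call, and the presence of a 'text' column in the dataset are assumptions for illustration only.

# Hypothetical usage sketch, not part of this patch.
# Assumes a running Ray cluster and a JSONL dataset with a 'text' field.
import ray
import ray.data as rd

from data_juicer.ops.deduplicator import RayBTSMinhashDeduplicator

ray.init()

op = RayBTSMinhashDeduplicator(
    tokenization='space',            # shingle over whitespace tokens
    window_size=5,                   # 5-gram shingles
    num_permutations=256,            # MinHash signature length
    jaccard_threshold=0.7,           # near-duplicate similarity threshold
    union_find_parallel_num='auto',  # roughly half of the cluster CPUs
    union_threshold=256,             # per-key group size that triggers an eager union
    merge_batch_size=1000,           # ray.wait batch size during the BTS merge
    tmp_file_name='./outputs/ray-dedup-tmp/',  # intermediate parquet location
)

dataset = rd.read_json('demo-dataset.jsonl')  # hypothetical input path
deduped = op.run(dataset)                     # returns the deduplicated Ray dataset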
""" super().__init__(*args, **kwargs) # about minhash computation @@ -380,7 +404,9 @@ def tokenization_func(text): ).T if union_find_parallel_num == 'auto': - union_find_parallel_num = int(ray.cluster_resources().get('CPU', 32)) // 2 + union_find_parallel_num = int( + ray.cluster_resources().get('CPU') / 2 + ) else: union_find_parallel_num = int(union_find_parallel_num) @@ -409,11 +435,20 @@ def tokenization_func(text): for i in range(self.union_find_parallel_num) ] - self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, str(uuid.uuid4())) + self.tmp_file_name = os.path.join( + os.getcwd(), tmp_file_name, str(uuid.uuid4()) + ) - empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32) - self.empty_hash_value = b'\x00\x00\x00\x00' + empty_hash_value.tobytes() - self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num) + empty_hash_value = np.full( + (self.num_rows_per_band,), + MAX_HASH, + dtype=np.uint32 + ) + self.empty_hash_value = b'\x00\x00\x00\x00' \ + + empty_hash_value.tobytes() + self.empty_hash_table_id = int( + MAX_HASH % self.union_find_parallel_num + ) def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: pairs = {} @@ -438,15 +473,19 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: ).astype(np.uint32) hash_values = phv.min(axis=0) for i, (start, end) in enumerate(self.hash_ranges): - hash_value = i.to_bytes(4, 'big') + hash_values[start:end].tobytes() - hash_table_id = hash_values[start] % self.union_find_parallel_num + hash_value = i.to_bytes(4, 'big') \ + + hash_values[start:end].tobytes() + hash_table_id = hash_values[start] \ + % self.union_find_parallel_num if hash_table_id not in pairs: pairs[hash_table_id] = [] pairs[hash_table_id].append((hash_value, uid)) else: if self.empty_hash_table_id not in pairs: pairs[self.empty_hash_table_id] = [] - pairs[self.empty_hash_table_id].append((self.empty_hash_value, uid)) + pairs[self.empty_hash_table_id].append( + (self.empty_hash_value, uid) + ) result_refs = [] for i, p in pairs.items(): if len(result_refs) > self.max_pending_filter_tasks: @@ -494,7 +533,7 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: query_dict = {} for idx, uid in enumerate(samples[HashKeys.uid]): uid = uid.as_py() - hash_id = uid % self.union_find_parallel_num + hash_id = uid // BATCH_SIZE % self.union_find_parallel_num if hash_id not in query_dict: query_dict[hash_id] = [] query_dict[hash_id].append((uid, idx)) @@ -529,10 +568,15 @@ def run(self, dataset): id_generator = IdGenerator.remote() def minhash_with_uid(table: pa.Table) -> pa.Table: num_rows = len(table) - min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows)) + min_id, max_id = ray.get( + id_generator.get_next_id.remote(num_rows) + ) uid_list = range(min_id, max_id) self.calc_minhash(table[self.text_key], uid_list) - new_table = table.append_column(HashKeys.uid, pa.array(list(uid_list))) + new_table = table.append_column( + HashKeys.uid, + pa.array(list(uid_list)) + ) return new_table dataset.map_batches( @@ -543,7 +587,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: self.tmp_file_name, force_ascii=False ) # TODO: balance file size - dataset = ray.data.read_parquet(self.tmp_file_name, ray_remote_args=dict(scheduling_strategy="SPREAD")) + dataset = ray.data.read_parquet(self.tmp_file_name) end_time = time.time() print(f'MinHash time = {end_time - start_time}') @@ -556,5 +600,4 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: batch_format='pyarrow', 
zero_copy_batch=True, ) - # logger.info(f'origin count = {dataset.count()}, keep count = {result.count()}') return result diff --git a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py deleted file mode 100644 index 7534a8f43..000000000 --- a/data_juicer/ops/deduplicator/ray_minhash_deduplicator.py +++ /dev/null @@ -1,376 +0,0 @@ -import random -import time -import uuid -from collections import defaultdict -from typing import Optional - -import ray -import numpy as np -import pandas as pd -import pyarrow as pa -import regex -from loguru import logger -from pydantic import Field, PositiveInt -from typing_extensions import Annotated -from typing import Dict - -from data_juicer.utils.constant import HashKeys, Fields -from data_juicer.utils.lazy_loader import LazyLoader -from data_juicer.utils.model_utils import prepare_sentencepiece_model - -from ..base_op import OPERATORS, Deduplicator -from ..common.helper_func import split_on_whitespace -from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, - optimal_param, sha1_hash32) - - -@ray.remote -class UnionFindWithMerge: - def __init__(self): - """Initialization method.""" - self.parent: Dict[str | bytes, str] = {} - - @staticmethod - def find_with_parent(parent, x): - x_list = [] - while x in parent: - x_list.append(x) - x = parent[x] - for xx in x_list: - parent[xx] = x - return x - - def find(self, x): - return self.find_with_parent(self.parent, x) - - def union(self, x, y): - px = self.find(x) - py = self.find(y) - if px == py: - return - if px > py: - px, py = py, px - self.parent[py] = px - - def union_list(self, x_list): - px_list = [self.find(x) for x in x_list] - p = min(px_list) - for px in px_list: - if p != px: - self.parent[px] = p - - def is_ancestor(self, x): - assert (x not in self.parent) == (x == self.parent.get(x, x)), f'{x}, {self.parent.get(x, x)}' - return x not in self.parent - - def get_num(self): - return len(self.parent) - - def get_nodes(self): - return set(self.parent.keys()) - - def get_parent(self): - return self.parent - - def merge(self, union_find_set): - union_find_set_parent = ray.get(union_find_set.get_parent.remote()) - union_find_set_nodes = set(union_find_set_parent.keys()) - parrent_nodes = set(self.parent.keys()) - for x in union_find_set_nodes: - px = self.find_with_parent(union_find_set_parent, x) - if x in parrent_nodes: - py = self.find(x) - self.union(px, py) - else: - self.parent[x] = px - - def merge_list(self, union_find_set_list): - for union_find_set in union_find_set_list: - self.merge(union_find_set) - - -OP_NAME = 'ray_minhash_deduplicator' - - -@OPERATORS.register_module(OP_NAME) -class RayMinhashDeduplicator(Deduplicator): - """ - A basic exact matching deduplicator for RAY. - Although its functionality is deduplication, - it is implemented as Filter sub-class. - """ - - # TODO: Set a more reasonable value - EMPTY_HASH_VALUE = 'EMPTY' - _batched_op = True - - def __init__( - self, - tokenization: str = 'space', - window_size: PositiveInt = 5, - lowercase: bool = True, - ignore_pattern: Optional[str] = None, - num_permutations: PositiveInt = 256, - jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, - num_bands: Optional[PositiveInt] = None, - num_rows_per_band: Optional[PositiveInt] = None, - tokenizer_model: Optional[str] = None, - union_find_parallel_num: Optional[int] = 16, - union_find_merge_num: Optional[int] = 2, - *args, - **kwargs, - ): - """ - Initialization method. 
- - :param tokenization: tokenization method for sample texts. It - should be one of [space, punctuation, character, - sentencepiece]. For English-like languages, we recommend - to use 'space', for Chinese-like languages, we recommend - to use 'character', and for multiple languages, we recommend - to use 'sentencepiece'. If using 'sentencepiece', please - provided the model path in the 'tokenizer_model' field. - :param window_size: window size of shingling - :param lowercase: whether to convert text to lower case first - :param ignore_pattern: whether to ignore sub-strings with - specific pattern when computing minhash - :param num_permutations: number of permutations in minhash - computing - :param jaccard_threshold: the min jaccard similarity threshold - in near-duplicate detection. When the jaccard similarity of - two sample texts is >= this threshold, they are regarded as - similar samples and this op will only keep one of them after - deduplication - :param num_bands: number of bands in LSH. Default it's None, and - it will be determined by an optimal params computation - algorithm by minimize the weighted sum of probs of False - Positives and False Negatives - :param num_rows_per_band: number of rows in each band in LSH. - Default it's None, and it will be determined by an optimal - params computation algorithm - :param tokenizer_model: path for the sentencepiece model, used for - sentencepiece tokenization. - """ - super().__init__(*args, **kwargs) - # about minhash computation - self.tokenization = tokenization - self.window_size = window_size - self.lowercase = lowercase - self.ignore_pattern = ignore_pattern - if self.ignore_pattern: - self.ignore_pattern = regex.compile(self.ignore_pattern) - - # check parameters - if self.ignore_pattern and self.tokenization == 'punctuation': - logger.warning('Be careful that tokenization with punctuations ' - 'won\'t work if the ignore pattern includes ' - 'punctuations.') - self.punctuation_pattern = regex.compile(r'\p{P}') - - if self.tokenization == 'sentencepiece': - if tokenizer_model is None: - raise ValueError("To use 'sentencepiece' tokenization, " - "'tokenizer_model' is required.") - self.tokenizer = prepare_sentencepiece_model(tokenizer_model) - else: - self.tokenizer = None - - # about deduplication - self.num_permutation = num_permutations - self.jaccard_threshold = jaccard_threshold - self.num_bands = num_bands - self.num_rows_per_band = num_rows_per_band - - # initialize deduplication parameters - # check number of bands and rows - if self.num_bands is None or self.num_rows_per_band is None: - self.num_bands, self.num_rows_per_band = optimal_param( - self.jaccard_threshold, - self.num_permutation, - ) - - # compute hash ranges and create hash tables - self.hash_ranges = [(i * self.num_rows_per_band, - (i + 1) * self.num_rows_per_band) - for i in range(self.num_bands)] - self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] - - # generate permutations - gen = np.random.RandomState(seed=42) - self.perm_a, self.perm_b = np.array( - [( - gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), - gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), - ) for _ in range(self.num_permutation)], - dtype=np.uint64, - ).T - - self.init_union_find(union_find_parallel_num, union_find_merge_num) - - - def init_union_find(self, union_find_parallel_num, union_find_merge_num): - self.union_find_parallel_num = union_find_parallel_num # 2 # 16 - self.union_find_merge_num = union_find_merge_num - self.union_find_list = [ - 
UnionFindWithMerge.remote() - for _ in range(self.union_find_parallel_num) - ] - - def compute_stats(self, samples: pa.Table) -> pa.Table: - samples_list = samples[self.text_key] - uuid_list = [uuid.uuid4().bytes for _ in range(samples.num_rows)] - all_hash_values = [[] for _ in range(self.num_bands)] - - for text in samples_list: - text = text.as_py() - if self.lowercase: - text = text.lower() - if self.ignore_pattern: - text = self.ignore_pattern.sub('', text) - - # get tokens for different tokenization method - tokens = set() - if self.tokenization == 'character': - tokens = { - str.encode(text[i:i + self.window_size]) - for i in range(len(text) - self.window_size) - } - elif self.tokenization == 'punctuation': - tokens = self.punctuation_pattern.split(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'space': - tokens = split_on_whitespace(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'sentencepiece': - tokens = self.tokenizer.encode(text, out_type=str) - tokens = { - str.encode(''.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - else: - raise NotImplementedError( - f'Unimplemented tokenization method [{self.tokenization}]') - - if len(tokens) > 0: - hv = np.array( - [sha1_hash32(token) for token in tokens], - dtype=np.uint64 - ) - phv = ( - (hv[:, None] * self.perm_a[None, :] - + self.perm_b) % MERSENNE_PRIME - ).astype(np.uint32) - hash_values = phv.min(axis=0) - else: - hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) - for i, (start, end) in enumerate(self.hash_ranges): - all_hash_values[i].append( - i.to_bytes(4, 'big') + - bytes(hash_values[start:end].byteswap().data) - ) - - samples = samples.append_column(HashKeys.uid, pa.array(uuid_list)) - for i, hash_values in enumerate(all_hash_values): - samples = samples.append_column(HashKeys.minhash + f"_{i}", pa.array(hash_values)) - return samples - - def map_batched(self, samples: pa.Table) -> pa.Table: - table = pa.Table.from_arrays( - [ - pa.concat_arrays( - [samples[HashKeys.uid].combine_chunks()] * len(self.hash_ranges) - ), - # pa.array( - # [uid.as_py() for uid in samples[HashKeys.uid]] * len(self.hash_ranges) - # ), - pa.concat_arrays( - [ - samples[HashKeys.minhash + f'_{i}'].combine_chunks() - for i in range(len(self.hash_ranges)) - ] - ), - ], - names=[HashKeys.uid, HashKeys.minhash] - ) - return table - - def agg_func(self, group: pa.Table) -> pa.Table: - if group.num_rows != 1: - uuid_list = [uid.as_py() for uid in group[HashKeys.uid]] - union_find_id = np.random.randint(0, self.union_find_parallel_num) - union_find = self.union_find_list[union_find_id] - ray.get(union_find.union_list.remote(uuid_list)) - return group - - def merge(self): - union_find_list = self.union_find_list - while len(union_find_list) > 1: - new_union_find_list = [] - task_list = [] - buffer = [] - for union_find in union_find_list: - buffer.append(union_find) - if len(buffer) == self.union_find_merge_num: - new_union_find_list.append(buffer[0]) - task_list.append(buffer[0].merge_list.remote(buffer[1:])) - buffer = [] - if len(buffer) > 0: - new_union_find_list.append(buffer[0]) - if len(buffer) > 1: - task_list.append(buffer[0].merge_list.remote(buffer[1:])) - ray.get(task_list) - union_find_list = new_union_find_list - self.parent = 
ray.get(union_find_list[0].get_nodes.remote()) - - def filter_with_union_find(self, samples: pa.Table) -> pa.Table: - mask = [ - uid.as_py() not in self.parent - for uid in samples[HashKeys.uid] - ] - return samples.filter(mask) - - def run(self, dataset): - # import time - # start_time = time.time() - dataset = dataset.map_batches( - self.compute_stats, - batch_format='pyarrow', - ).materialize() - drop_columns = [] - for i in range(len(self.hash_ranges)): - drop_column = HashKeys.minhash + f'_{i}' - drop_columns.append(drop_column) - # end_time = time.time() - # print(f'minhash time = {end_time - start_time}') - - # start_time = time.time() - dataset.map_batches( - self.map_batched, - batch_format='pyarrow', - ).groupby( - HashKeys.minhash - ).map_groups( - self.agg_func, batch_format='pyarrow' - ).materialize() - # end_time = time.time() - # print(f'group time = {end_time - start_time}') - # start_time = time.time() - self.merge() - # end_time = time.time() - # print(f'merge time = {end_time - start_time}') - result = dataset.drop_columns( - drop_columns - ).map_batches( - self.filter_with_union_find, - batch_format='pyarrow' - ).drop_columns( - HashKeys.uid - ).materialize() - logger.info(f'Keep {result.count()} samples after MinHash dedup.') - return result diff --git a/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py deleted file mode 100644 index 56c5dfcd5..000000000 --- a/data_juicer/ops/deduplicator/ray_multi_redis_minhash_deduplicator.py +++ /dev/null @@ -1,473 +0,0 @@ -import random -import time -import uuid -from collections import defaultdict -from typing import Optional -import ray - -import numpy as np -import pandas as pd -import pyarrow as pa -import regex -from loguru import logger -from pydantic import Field, PositiveInt -from typing_extensions import Annotated -import concurrent - -from data_juicer.utils.constant import HashKeys -from data_juicer.utils.lazy_loader import LazyLoader -from data_juicer.utils.model_utils import prepare_sentencepiece_model - -from ..base_op import OPERATORS, Deduplicator -from ..common.helper_func import split_on_whitespace -from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, - optimal_param, sha1_hash32) - -redis = LazyLoader('redis', 'redis') - - -def retry_on_busy(func): - - def wrapper(*args, **kwargs): - max_retries = 10 - for attempt in range(max_retries): - try: - return func(*args, **kwargs) - except Exception as e: - if 'BUSY' in str(e) and attempt < max_retries - 1: - time.sleep(random.uniform(0.1, 0.3) * (2**attempt)) - else: - raise - - return wrapper - - -class RedisUnionFind: - - def __init__(self, - prefix: str, - redis_address: str = 'redis://localhost:6380'): - self.prefix = prefix - self.redis_address = redis_address - self.redis = redis.from_url(url=redis_address) - self.set_key = f'{prefix}_UF_SET' - self.incur_id_key = f'{prefix}_UF_INCURID' - - # Lua scripts - self.union_script = self.redis.register_script(""" - local function find(x) - local path = {} - while true do - local parent = redis.call('HGET', KEYS[1], x) - if not parent then - break - end - table.insert(path, x) - x = parent - end - for _, node in ipairs(path) do - redis.call('HSET', KEYS[1], node, x) - end - return x - end - - local root_x = find(ARGV[1]) - local root_y = find(ARGV[2]) - if root_x == root_y then - return root_x - end - if root_x < root_y then - redis.call('HSET', KEYS[1], root_y, root_x) - return root_x - else - redis.call('HSET', 
KEYS[1], root_x, root_y) - return root_y - end - """) - - self.merge_script = self.redis.register_script(""" - local function find(key, x) - local path = {} - while true do - local parent = redis.call('HGET', key, x) - if not parent then - break - end - table.insert(path, x) - x = parent - end - for _, node in ipairs(path) do - redis.call('HSET', key, node, x) - end - return x - end - - local function merge(key) - local nodes = redis.call('HKEYS', key) - for _, node in ipairs(nodes) do - local root = find(key, node) - local root_x = find(KEYS[1], node) - local root_y = find(KEYS[1], root) - if root_x < root_y then - redis.call('HSET', KEYS[1], root_y, root_x) - elseif root_x > root_y then - redis.call('HSET', KEYS[1], root_x, root_y) - end - end - end - - for _, key in ipairs(ARGV) do - merge(key) - end - """) - - def get_uid(self): - return int(self.redis.incr(self.incur_id_key)) - - @retry_on_busy - def union(self, x, y) -> int: - return int(self.union_script(keys=[self.set_key], - args=[x, y])) - - @retry_on_busy - def merge(self, set_keys): - # self.merge_script(keys=[self.set_key] + set_keys, args=set_keys) - for set_key in set_keys: - for x, y in self.redis.hgetall(set_key).items(): - self.union(x, y) - # for x in self.redis.hkeys(set_key): - # y = self.redis.hget(set_key, x) - # self.redis. - - def get_nodes(self): - return set(int(x) for x in self.redis.hkeys(self.set_key)) - - def get_data(self): - result = {} - for x in self.get_nodes(): - y = int(self.redis.hget(self.set_key, x)) - result[x] = y - return result - - def is_ancestor(self, x): - ancestor = self.redis.hget(self.set_key, x) - return ancestor is None or int(ancestor) == x - - def __reduce__(self): - return (RedisUnionFind, (self.prefix, self.redis_address)) - - def clean(self): - self.redis.delete(self.set_key, self.incur_id_key) - - -OP_NAME = 'ray_multi_redis_minhash_deduplicator' - - -@OPERATORS.register_module(OP_NAME) -class RayMultiRedisMinhashDeduplicator(Deduplicator): - """ - A basic exact matching deduplicator for RAY. - Although its functionality is deduplication, - it is implemented as Filter sub-class. - """ - - def __init__( - self, - tokenization: str = 'space', - window_size: PositiveInt = 5, - lowercase: bool = True, - ignore_pattern: Optional[str] = None, - num_permutations: PositiveInt = 256, - jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, - num_bands: Optional[PositiveInt] = None, - num_rows_per_band: Optional[PositiveInt] = None, - tokenizer_model: Optional[str] = None, - redis_address: str = 'redis://localhost:6380', - union_find_parallel_num: Optional[int] = 16, - union_find_merge_num: Optional[int] = 2, - *args, - **kwargs, - ): - """ - Initialization method. - - :param tokenization: tokenization method for sample texts. It - should be one of [space, punctuation, character, - sentencepiece]. For English-like languages, we recommend - to use 'space', for Chinese-like languages, we recommend - to use 'character', and for multiple languages, we recommend - to use 'sentencepiece'. If using 'sentencepiece', please - provided the model path in the 'tokenizer_model' field. - :param window_size: window size of shingling - :param lowercase: whether to convert text to lower case first - :param ignore_pattern: whether to ignore sub-strings with - specific pattern when computing minhash - :param num_permutations: number of permutations in minhash - computing - :param jaccard_threshold: the min jaccard similarity threshold - in near-duplicate detection. 
When the jaccard similarity of - two sample texts is >= this threshold, they are regarded as - similar samples and this op will only keep one of them after - deduplication - :param num_bands: number of bands in LSH. Default it's None, and - it will be determined by an optimal params computation - algorithm by minimize the weighted sum of probs of False - Positives and False Negatives - :param num_rows_per_band: number of rows in each band in LSH. - Default it's None, and it will be determined by an optimal - params computation algorithm - :param tokenizer_model: path for the sentencepiece model, used for - sentencepiece tokenization. - :param redis_address: address of your redis instance, e.g. - 'redis://localhost:6379' - """ - super().__init__(*args, **kwargs) - # about minhash computation - self.tokenization = tokenization - self.window_size = window_size - self.lowercase = lowercase - self.ignore_pattern = ignore_pattern - if self.ignore_pattern: - self.ignore_pattern = regex.compile(self.ignore_pattern) - - # check parameters - if self.ignore_pattern and self.tokenization == 'punctuation': - logger.warning('Be careful that tokenization with punctuations ' - 'won\'t work if the ignore pattern includes ' - 'punctuations.') - self.punctuation_pattern = regex.compile(r'\p{P}') - - if self.tokenization == 'sentencepiece': - if tokenizer_model is None: - raise ValueError("To use 'sentencepiece' tokenization, " - "'tokenizer_model' is required.") - self.tokenizer = prepare_sentencepiece_model(tokenizer_model) - else: - self.tokenizer = None - - # about deduplication - self.num_permutation = num_permutations - self.jaccard_threshold = jaccard_threshold - self.num_bands = num_bands - self.num_rows_per_band = num_rows_per_band - - # initialize deduplication parameters - # check number of bands and rows - if self.num_bands is None or self.num_rows_per_band is None: - self.num_bands, self.num_rows_per_band = optimal_param( - self.jaccard_threshold, - self.num_permutation, - ) - - # compute hash ranges and create hash tables - self.hash_ranges = [(i * self.num_rows_per_band, - (i + 1) * self.num_rows_per_band) - for i in range(self.num_bands)] - self.hash_tables = [defaultdict(set) for _ in range(self.num_bands)] - - # generate permutations - gen = np.random.RandomState(seed=42) - self.perm_a, self.perm_b = np.array( - [( - gen.randint(1, MERSENNE_PRIME, dtype=np.uint64), - gen.randint(0, MERSENNE_PRIME, dtype=np.uint64), - ) for _ in range(self.num_permutation)], - dtype=np.uint64, - ).T - self.redis_address = redis_address - self.union_find_parallel_num = union_find_parallel_num - self.union_find_merge_num = union_find_merge_num - - def run(self, dataset): - from ray.data.aggregate import AggregateFn - - # union_find = RedisUnionFind(prefix=uuid.uuid4().hex[:8], - # redis_address=self.redis_address) - union_find_list = [ - RedisUnionFind(prefix=uuid.uuid4().hex[:8] + f'_{i}', redis_address=self.redis_address) - for i in range(self.union_find_parallel_num) - ] - - def add_uid_column(table: pa.Table) -> pa.Table: - new_column_data = [union_find_list[0].get_uid() for _ in range(len(table))] - new_table = table.append_column(HashKeys.uid, [new_column_data]) - return new_table - - def calculate_minhash(table: pa.Table) -> pa.Table: - ids = table.column(HashKeys.uid).to_pandas() - texts = table.column(self.text_key).to_pandas() - hashes = texts.apply(lambda x: self.compute_minhash(x)) - hashes = pa.Array.from_pandas(hashes).flatten() - - repeated_ids = pa.Array.from_pandas(ids.repeat(self.num_bands)) - 
- return pa.Table.from_arrays([repeated_ids, hashes], - names=[HashKeys.uid, HashKeys.minhash]) - - class UnionFn(AggregateFn): - - def __init__(self, union_find_list): - # union_find = union_find - union_find_num = len(union_find_list) - - def accumulate(cur, row): - if cur is None: - return int.from_bytes(row[HashKeys.minhash][:8], byteorder='big') % union_find_num, row[HashKeys.uid] - else: - assert cur[0] == int.from_bytes(row[HashKeys.minhash][:8], byteorder='big') % union_find_num - union_find = union_find_list[cur[0]] - root = union_find.union(row[HashKeys.uid], cur[1]) - return cur[0], root - - def merge(a, b): - if a is None: - return b - if b is None: - return a - assert a[0] == b[0] - union_find = union_find_list[a[0]] - root = union_find.union(a[1], b[1]) - # root = union_find.union(a, b) - return a[0], root - - super().__init__( - init=lambda k: None, - accumulate_row=accumulate, - merge=merge, - name='union', - ) - - dataset_with_id = dataset.map_batches( - add_uid_column, batch_format='pyarrow').materialize() - dataset_with_id.map_batches( - calculate_minhash, - batch_format='pyarrow' - ).groupby( - HashKeys.minhash - ).aggregate( - UnionFn(union_find_list) - ).materialize() - - # results = [] - # for union_find in union_find_list: - # results.append(union_find.get_data()) - @ray.remote - def merge(x, keys): - x.merge(keys) - - merge_list = union_find_list - while len(merge_list) > 1: - new_merge_list, buffer = [], [] - task_list = [] - for union_find in merge_list: - buffer.append(union_find) - if len(buffer) == self.union_find_merge_num: - new_merge_list.append(buffer[0]) - keys = [u.set_key for u in buffer[1:]] - task_list.append( - merge.remote(buffer[0], keys) - ) - buffer = [] - if len(buffer) > 0: - new_merge_list.append(buffer[0]) - if len(buffer) > 1: - keys = [u.set_key for u in buffer[1:]] - task_list.append( - merge.remote(buffer[0], keys) - ) - ray.get(task_list) - merge_list = new_merge_list - # for m in merge_list: - # results.append(m.get_data()) - - # results.append(merge_list[0].get_data()) - # import json - # with open(f'data_{len(results)}.json', 'w') as f: - # json.dump(results, f) - dup_ids = merge_list[0].get_nodes() - - def filter_with_union_find(table: pa.Table) -> pa.Table: - uids = table.column(HashKeys.uid).to_pandas() - mask = pa.Array.from_pandas( - uids.apply(lambda x: x not in dup_ids)) - return table.filter(mask) - - result = dataset_with_id.map_batches( - filter_with_union_find, - batch_format='pyarrow' - ).materialize() - logger.info(f'Keep {result.count()} samples after MinHash dedup.') - for union_find in union_find_list: - union_find.clean() - return result - - def compute_minhash(self, text): - """ - Compute minhash values for the sample. - - :param sample: input sample - :return: sample with minhash value. 
- """ - if self.lowercase: - text = text.lower() - if self.ignore_pattern: - text = self.ignore_pattern.sub('', text) - - # get tokens for different tokenization method - tokens = set() - if self.tokenization == 'character': - tokens = { - str.encode(text[i:i + self.window_size]) - for i in range(len(text) - self.window_size) - } - elif self.tokenization == 'punctuation': - tokens = self.punctuation_pattern.split(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'space': - tokens = split_on_whitespace(text) - tokens = { - str.encode(' '.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - elif self.tokenization == 'sentencepiece': - tokens = self.tokenizer.encode(text, out_type=str) - tokens = { - str.encode(''.join(tokens[i:i + self.window_size])) - for i in range(len(tokens) - self.window_size) - } - else: - raise NotImplementedError( - f'Unimplemented tokenization method [{self.tokenization}]') - - # # compute minhash value - # hv = np.array([sha1_hash32(token) for token in tokens], - # dtype=np.uint64) - # phv = np.bitwise_and( - # ((hv * np.tile(self.perm_a, - # (len(hv), 1)).T).T + self.perm_b) % MERSENNE_PRIME, - # MAX_HASH) - # hash_values = np.vstack([ - # phv, - # np.ones(self.num_permutation, dtype=np.uint64) * MAX_HASH - # ]).min(axis=0) - if len(tokens) > 0: - hv = np.array( - [sha1_hash32(token) for token in tokens], - dtype=np.uint64 - ) - phv = ( - (hv[:, None] * self.perm_a[None, :] - + self.perm_b) % MERSENNE_PRIME - ).astype(np.uint32) - hash_values = phv.min(axis=0) - else: - hash_values = np.full_like(self.perm_a, MAX_HASH, dtype=np.uint32) - return [ - bytes(hash_values[start:end].byteswap().data) + - start.to_bytes(4, byteorder='little') - for start, end in self.hash_ranges - # groupby minhash||brand_id - ] diff --git a/docs/Operators.md b/docs/Operators.md index f1a20c9ef..0c8d708e6 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -174,6 +174,7 @@ All the specific operators are listed below, each featured with several capabili | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | | ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH based on Ray and Redis | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | +| ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level using MinHashLSH 
based on Ray | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Deduplicates samples at document-level by comparing MD5 hash on ray | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Deduplicates samples at document-level using exact matching of images between documents on ray | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Deduplicates samples at document-level using exact matching of videos between documents on ray | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index b1194f250..57b009238 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -172,7 +172,8 @@ Data-Juicer 中的算子分为以下 5 种类型。 | document_simhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 SimHash 在文档级别对样本去重 | [code](../data_juicer/ops/deduplicator/document_simhash_deduplicator.py) | [tests](../tests/ops/deduplicator/test_document_simhash_deduplicator.py) | | image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/image_deduplicator.py) | [tests](../tests/ops/deduplicator/test_image_deduplicator.py) | | video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本 | [code](../data_juicer/ops/deduplicator/video_deduplicator.py) | [tests](../tests/ops/deduplicator/test_video_deduplicator.py) | -| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | +| ray_redis_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式(基于Redis) | [code](../data_juicer/ops/deduplicator/ray_redis_minhash_deduplicator.py) | - | +| ray_bts_minhash_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用 MinHashLSH 在文档级别对样本去重,面向 RAY 分布式模式 | [code](../data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py) | - | | ray_document_deduplicator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) 
![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较 MD5 哈希值在文档级别对样本去重,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_document_deduplicator.py) | - | | ray_image_deduplicator | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 使用文档之间图像的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_image_deduplicator.py) | - | | ray_video_deduplicator | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 使用文档之间视频的精确匹配在文档级别删除重复样本,面向RAY分布式模式 | [code](../data_juicer/ops/deduplicator/ray_video_deduplicator.py) | - | From 8dda1aafcd9477515b008cfc7a94d483df68a66d Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 17 Dec 2024 06:12:47 +0000 Subject: [PATCH 18/22] update docs --- data_juicer/core/ray_data.py | 14 +++++----- .../ray_bts_minhash_deduplicator.py | 28 +++++++++++++++---- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 54418441f..621e68cd9 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -61,15 +61,15 @@ def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - # if Fields.stats not in columns: + if Fields.stats not in columns: - # def process_batch_arrow(table: pa.Table) -> pa.Table: - # new_column_data = [{} for _ in range(len(table))] - # new_talbe = table.append_column(Fields.stats, [new_column_data]) - # return new_talbe + def process_batch_arrow(table: pa.Table) -> pa.Table: + new_column_data = [{} for _ in range(len(table))] + new_talbe = table.append_column(Fields.stats, [new_column_data]) + return new_talbe - # dataset = dataset.map_batches(process_batch_arrow, - # batch_format='pyarrow') + dataset = dataset.map_batches(process_batch_arrow, + batch_format='pyarrow') return dataset diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 20e998e0f..ba87edda9 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -7,13 +7,14 @@ import ray import numpy as np import pyarrow as pa +import pyarrow.parquet as pq import regex from loguru import logger from pydantic import Field, PositiveInt from typing_extensions import Annotated from typing import List, Union -from data_juicer.utils.constant import HashKeys +from data_juicer.utils.constant import HashKeys, Fields from data_juicer.utils.model_utils import prepare_sentencepiece_model from ..base_op import OPERATORS, Deduplicator @@ -54,6 +55,12 @@ def get_edges(self, key): @ray.remote(scheduling_strategy="SPREAD") class BTSUnionFind: + """ + A distributed implementation of Union-Find with load balancing. 
+ + The original paper on BTS Union-Find is available at: + https://ieeexplore.ieee.org/document/10598116 + """ def __init__( self, union_threshold, @@ -438,6 +445,7 @@ def tokenization_func(text): self.tmp_file_name = os.path.join( os.getcwd(), tmp_file_name, str(uuid.uuid4()) ) + os.makedirs(self.tmp_file_name) empty_hash_value = np.full( (self.num_rows_per_band,), @@ -577,16 +585,24 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: HashKeys.uid, pa.array(list(uid_list)) ) - return new_table + if not new_table[Fields.stats][0].as_py(): + columns_to_keep = [ + name + for name in new_table.column_names + if name != Fields.stats + ] + new_table = new_table.select(columns_to_keep) + pq.write_table( + new_table, + os.path.join(self.tmp_file_name, f'{min_id}.parquet') + ) + return pa.Table.from_arrays([]) dataset.map_batches( minhash_with_uid, batch_format='pyarrow', zero_copy_batch=True, - ).write_parquet( - self.tmp_file_name, - force_ascii=False - ) # TODO: balance file size + ).materialize() dataset = ray.data.read_parquet(self.tmp_file_name) end_time = time.time() print(f'MinHash time = {end_time - start_time}') From 9d893be688a7e9ad719c3778620b9ecb204db3b8 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 20 Dec 2024 07:15:21 +0000 Subject: [PATCH 19/22] fix in Fields.stats --- data_juicer/core/ray_data.py | 20 +++++++++---------- data_juicer/core/ray_executor.py | 4 ++-- .../ray_bts_minhash_deduplicator.py | 18 +++++------------ 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 621e68cd9..2966e75e8 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -58,18 +58,8 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg): def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: - columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - if Fields.stats not in columns: - - def process_batch_arrow(table: pa.Table) -> pa.Table: - new_column_data = [{} for _ in range(len(table))] - new_talbe = table.append_column(Fields.stats, [new_column_data]) - return new_talbe - - dataset = dataset.map_batches(process_batch_arrow, - batch_format='pyarrow') return dataset @@ -123,6 +113,16 @@ def _run_single_op(self, op): batch_format='pyarrow', num_gpus=num_gpus) elif isinstance(op, Filter): + columns = self.data.columns() + if Fields.stats not in columns: + + def process_batch_arrow(table: pa.Table) -> pa.Table: + new_column_data = [{} for _ in range(len(table))] + new_talbe = table.append_column(Fields.stats, [new_column_data]) + return new_talbe + + self.data = self.data.map_batches(process_batch_arrow, + batch_format='pyarrow') self.data = self.data.map_batches(op.compute_stats, batch_size=batch_size, batch_format='pyarrow', diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 6b93fd3dd..f146ffc02 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -68,10 +68,10 @@ def run(self, load_data_np=None): logger.info('Processing data...') tstart = time.time() dataset.process(ops) - tend = time.time() - logger.info(f'All Ops are done in {tend - tstart:.3f}s.') # 4. 
data export logger.info('Exporting dataset to disk...') dataset.data.write_json(self.cfg.export_path, force_ascii=False) + tend = time.time() + logger.info(f'All Ops are done in {tend - tstart:.3f}s.') return dataset diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index ba87edda9..676aa0185 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -585,24 +585,16 @@ def minhash_with_uid(table: pa.Table) -> pa.Table: HashKeys.uid, pa.array(list(uid_list)) ) - if not new_table[Fields.stats][0].as_py(): - columns_to_keep = [ - name - for name in new_table.column_names - if name != Fields.stats - ] - new_table = new_table.select(columns_to_keep) - pq.write_table( - new_table, - os.path.join(self.tmp_file_name, f'{min_id}.parquet') - ) - return pa.Table.from_arrays([]) + return new_table dataset.map_batches( minhash_with_uid, batch_format='pyarrow', zero_copy_batch=True, - ).materialize() + ).write_parquet( + self.tmp_file_name, + force_ascii=False + ) dataset = ray.data.read_parquet(self.tmp_file_name) end_time = time.time() print(f'MinHash time = {end_time - start_time}') From 897032e002545273bd981b0a7189c6f4d17c7cf1 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 20 Dec 2024 10:03:38 +0000 Subject: [PATCH 20/22] code format --- data_juicer/core/ray_data.py | 2 -- data_juicer/ops/deduplicator/__init__.py | 2 +- .../ray_bts_minhash_deduplicator.py | 28 +++++++++---------- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 48e26827d..65babcdda 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,6 +1,4 @@ from __future__ import annotations -import os -from functools import partial import os from functools import partial diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 71c5e4863..29967770d 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -3,9 +3,9 @@ from .document_simhash_deduplicator import DocumentSimhashDeduplicator from .image_deduplicator import ImageDeduplicator from .ray_basic_deduplicator import RayBasicDeduplicator +from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator from .ray_document_deduplicator import RayDocumentDeduplicator from .ray_image_deduplicator import RayImageDeduplicator -from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator from .ray_video_deduplicator import RayVideoDeduplicator from .video_deduplicator import VideoDeduplicator diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 676aa0185..55151906c 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -1,20 +1,17 @@ +import os import time import uuid -from collections import defaultdict -from typing import Optional +from typing import List, Optional, Union -import os -import ray import numpy as np import pyarrow as pa -import pyarrow.parquet as pq +import ray import regex from loguru import logger from pydantic import Field, PositiveInt from typing_extensions import Annotated -from typing import List, Union -from data_juicer.utils.constant import HashKeys, Fields +from data_juicer.utils.constant import 
HashKeys from data_juicer.utils.model_utils import prepare_sentencepiece_model from ..base_op import OPERATORS, Deduplicator @@ -22,13 +19,12 @@ from .document_minhash_deduplicator import (MAX_HASH, MERSENNE_PRIME, optimal_param, sha1_hash32) - BATCH_SIZE = 1000 @ray.remote class IdGenerator: - def __init__(self, start_id = 0): + def __init__(self, start_id=0): self.next_id = start_id @ray.method(num_returns=2) @@ -38,7 +34,7 @@ def get_next_id(self, count): return (current_id, self.next_id) -@ray.remote(scheduling_strategy="SPREAD") +@ray.remote(scheduling_strategy='SPREAD') class EdgeBuffer: def __init__(self): self.edge_dict = {} @@ -53,7 +49,7 @@ def get_edges(self, key): return self.edge_dict.pop(key, []) -@ray.remote(scheduling_strategy="SPREAD") +@ray.remote(scheduling_strategy='SPREAD') class BTSUnionFind: """ A distributed implementation of Union-Find with load balancing. @@ -161,8 +157,9 @@ def communication(self): del_list = [] for u, v in self.parent.items(): hash_u = u // BATCH_SIZE % self.parallel_num - if self.parent[u] != self.old_parent.get(u, u) or \ - (hash_u != self.parallel_id and v not in self.parent): + if self.parent[u] != self.old_parent.get(u, u) or ( + hash_u != self.parallel_id and v not in self.parent + ): self.distribute_edge(u, v) if hash_u != self.parallel_id: del_list.append(u) @@ -311,7 +308,7 @@ def __init__( :param union_find_parallel_num: number of parallel workers for union-find algorithm. Default it's 'auto', and it will be determined by half of the number of CPUs. - :param union_threshold: threshold for minhash values group to + :param union_threshold: threshold for minhash values group to perform union-find algorightm. Default it's 256. :param max_pending_edge_buffer_task: max number of pending edge buffer ray tasks. Default it's 20. 
@@ -476,7 +473,7 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: dtype=np.uint64 ) phv = ( - (hv[:, None] * self.perm_a[None, :] + (hv[:, None] * self.perm_a[None, :] + self.perm_b) % MERSENNE_PRIME ).astype(np.uint32) hash_values = phv.min(axis=0) @@ -574,6 +571,7 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: def run(self, dataset): start_time = time.time() id_generator = IdGenerator.remote() + def minhash_with_uid(table: pa.Table) -> pa.Table: num_rows = len(table) min_id, max_id = ray.get( From 87030770d3767c1587ecb6457ec83d897b9ac1c7 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 24 Dec 2024 17:02:18 +0800 Subject: [PATCH 21/22] code format --- data_juicer/core/ray_data.py | 5 +- data_juicer/ops/deduplicator/__init__.py | 15 ++- .../ray_bts_minhash_deduplicator.py | 126 +++++++----------- 3 files changed, 59 insertions(+), 87 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 65babcdda..5375fe2a6 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -119,12 +119,11 @@ def _run_single_op(self, op): elif isinstance(op, Filter): columns = self.data.columns() if Fields.stats not in columns: + def process_batch_arrow(table: pyarrow.Table): new_column_data = [{} for _ in range(len(table))] new_talbe = table.append_column( - Fields.stats, - [new_column_data] - ) + Fields.stats, [new_column_data]) return new_talbe self.data = self.data.map_batches(process_batch_arrow, diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 29967770d..494ac099b 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -10,8 +10,15 @@ from .video_deduplicator import VideoDeduplicator __all__ = [ - 'DocumentDeduplicator', 'DocumentMinhashDeduplicator', - 'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator', - 'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator', - 'RayImageDeduplicator', 'RayBTSMinhashDeduplicator', 'VideoDeduplicator', + 'DocumentDeduplicator', + 'DocumentMinhashDeduplicator', + 'DocumentSimhashDeduplicator', + 'ImageDeduplicator', + 'RayBasicDeduplicator', + 'RayDocumentDeduplicator', + 'RayImageDeduplicator', + 'RayVideoDeduplicator', + 'RayImageDeduplicator', + 'RayBTSMinhashDeduplicator', + 'VideoDeduplicator', ] diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py index 55151906c..b54fa56ce 100644 --- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py @@ -24,6 +24,7 @@ @ray.remote class IdGenerator: + def __init__(self, start_id=0): self.next_id = start_id @@ -36,6 +37,7 @@ def get_next_id(self, count): @ray.remote(scheduling_strategy='SPREAD') class EdgeBuffer: + def __init__(self): self.edge_dict = {} @@ -57,6 +59,7 @@ class BTSUnionFind: The original paper on BTS Union-Find is available at: https://ieeexplore.ieee.org/document/10598116 """ + def __init__( self, union_threshold, @@ -100,17 +103,14 @@ def balanced_union_find(self): for remote_edge_buffer in self.remote_edge_buffers: if len(result_refs) > self.max_pending_edge_buffer_task: ready_refs, result_refs = ray.wait( - result_refs, - num_returns=self.num_edge_buffer_task_returns - ) + result_refs, num_returns=self.num_edge_buffer_task_returns) edge_list = ray.get(ready_refs) for edges in edge_list: 
for x, y in edges: self.union(x, y) del ready_refs result_refs.append( - remote_edge_buffer.get_edges.remote(self.parallel_id) - ) + remote_edge_buffer.get_edges.remote(self.parallel_id)) edge_list = ray.get(result_refs) for edges in edge_list: for x, y in edges: @@ -136,11 +136,8 @@ def set_edge_buffer(self): del self.edge_list_dict[self.parallel_id] else: self.edge_buffer = [] - ray.get( - self.remote_edge_buffers[self.parallel_id].set_edges.remote( - self.edge_list_dict - ) - ) + ray.get(self.remote_edge_buffers[self.parallel_id].set_edges.remote( + self.edge_list_dict)) self.edge_list_dict = {} def edge_redistribution(self): @@ -158,8 +155,7 @@ def communication(self): for u, v in self.parent.items(): hash_u = u // BATCH_SIZE % self.parallel_num if self.parent[u] != self.old_parent.get(u, u) or ( - hash_u != self.parallel_id and v not in self.parent - ): + hash_u != self.parallel_id and v not in self.parent): self.distribute_edge(u, v) if hash_u != self.parallel_id: del_list.append(u) @@ -231,11 +227,7 @@ def squeeze(self): ray.get(self.remote_edge_buffers[self.parallel_id].clear.remote()) def dup_idx(self, queries): - return [ - idx - for uid, idx in queries - if uid in self.parent - ] + return [idx for uid, idx in queries if uid in self.parent] OP_NAME = 'ray_bts_minhash_deduplicator' @@ -347,12 +339,14 @@ def __init__( self.tokenizer = None if self.tokenization == 'character': + def tokenization_func(text): return { str.encode(text[i:i + self.window_size]) for i in range(len(text) - self.window_size) } elif self.tokenization == 'punctuation': + def tokenization_func(text): tokens = self.punctuation_pattern.split(text) return { @@ -360,6 +354,7 @@ def tokenization_func(text): for i in range(len(tokens) - self.window_size) } elif self.tokenization == 'space': + def tokenization_func(text): tokens = split_on_whitespace(text) return { @@ -367,6 +362,7 @@ def tokenization_func(text): for i in range(len(tokens) - self.window_size) } elif self.tokenization == 'sentencepiece': + def tokenization_func(text): tokens = self.tokenizer.encode(text, out_type=str) return { @@ -408,9 +404,8 @@ def tokenization_func(text): ).T if union_find_parallel_num == 'auto': - union_find_parallel_num = int( - ray.cluster_resources().get('CPU') / 2 - ) + union_find_parallel_num = int(ray.cluster_resources().get('CPU') / + 2) else: union_find_parallel_num = int(union_find_parallel_num) @@ -424,8 +419,7 @@ def tokenization_func(text): self.union_find_parallel_num = union_find_parallel_num self.union_threshold = union_threshold self.remote_edge_buffers = [ - EdgeBuffer.remote() - for _ in range(self.union_find_parallel_num) + EdgeBuffer.remote() for _ in range(self.union_find_parallel_num) ] self.union_find_list = [ BTSUnionFind.remote( @@ -435,25 +429,19 @@ def tokenization_func(text): self.remote_edge_buffers, self.max_pending_edge_buffer_task, self.num_edge_buffer_task_returns, - ) - for i in range(self.union_find_parallel_num) + ) for i in range(self.union_find_parallel_num) ] - self.tmp_file_name = os.path.join( - os.getcwd(), tmp_file_name, str(uuid.uuid4()) - ) + self.tmp_file_name = os.path.join(os.getcwd(), tmp_file_name, + str(uuid.uuid4())) os.makedirs(self.tmp_file_name) - empty_hash_value = np.full( - (self.num_rows_per_band,), - MAX_HASH, - dtype=np.uint32 - ) + empty_hash_value = np.full((self.num_rows_per_band, ), + MAX_HASH, + dtype=np.uint32) self.empty_hash_value = b'\x00\x00\x00\x00' \ + empty_hash_value.tobytes() - self.empty_hash_table_id = int( - MAX_HASH % self.union_find_parallel_num - ) + 
self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num) def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: pairs = {} @@ -468,14 +456,10 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: tokens = self.tokenization_func(text) if len(tokens) > 0: - hv = np.array( - [sha1_hash32(token) for token in tokens], - dtype=np.uint64 - ) - phv = ( - (hv[:, None] * self.perm_a[None, :] - + self.perm_b) % MERSENNE_PRIME - ).astype(np.uint32) + hv = np.array([sha1_hash32(token) for token in tokens], + dtype=np.uint64) + phv = ((hv[:, None] * self.perm_a[None, :] + self.perm_b) % + MERSENNE_PRIME).astype(np.uint32) hash_values = phv.min(axis=0) for i, (start, end) in enumerate(self.hash_ranges): hash_value = i.to_bytes(4, 'big') \ @@ -489,28 +473,24 @@ def calc_minhash(self, text_list: pa.Array, uid_list: List) -> pa.Table: if self.empty_hash_table_id not in pairs: pairs[self.empty_hash_table_id] = [] pairs[self.empty_hash_table_id].append( - (self.empty_hash_value, uid) - ) + (self.empty_hash_value, uid)) result_refs = [] for i, p in pairs.items(): if len(result_refs) > self.max_pending_filter_tasks: ready_refs, result_refs = ray.wait( - result_refs, - num_returns=self.num_filter_task_returns - ) + result_refs, num_returns=self.num_filter_task_returns) ray.get(ready_refs) result_refs.append( - self.union_find_list[i].add_key_value_pairs.remote(p) - ) + self.union_find_list[i].add_key_value_pairs.remote(p)) ray.get(result_refs) def merge_op_batch(self, object_refs): results = [] while object_refs: - ready_refs, object_refs = ray.wait( - object_refs, - num_returns=min(self.merge_batch_size, len(object_refs)) - ) + ready_refs, object_refs = ray.wait(object_refs, + num_returns=min( + self.merge_batch_size, + len(object_refs))) results.extend(ray.get(ready_refs)) return results @@ -520,18 +500,16 @@ def merge(self): for union_find in self.union_find_list ]) while any( - self.merge_op_batch([ - union_find.balanced_union_find.remote() - for union_find in self.union_find_list - ]) - ): + self.merge_op_batch([ + union_find.balanced_union_find.remote() + for union_find in self.union_find_list + ])): self.merge_op_batch([ union_find.communication.remote() for union_find in self.union_find_list ]) self.merge_op_batch([ - union_find.squeeze.remote() - for union_find in self.union_find_list + union_find.squeeze.remote() for union_find in self.union_find_list ]) def filter_with_union_find(self, samples: pa.Table) -> pa.Table: @@ -547,24 +525,19 @@ def filter_with_union_find(self, samples: pa.Table) -> pa.Table: for hash_id, query in query_dict.items(): if len(result_refs) > self.max_pending_filter_tasks: ready_refs, result_refs = ray.wait( - result_refs, - num_returns=self.num_filter_task_returns - ) + result_refs, num_returns=self.num_filter_task_returns) results = ray.get(ready_refs) for result in results: mask[result] = False del ready_refs result_refs.append( - self.union_find_list[hash_id].dup_idx.remote(query) - ) + self.union_find_list[hash_id].dup_idx.remote(query)) results = ray.get(result_refs) for result in results: mask[result] = False del query_dict, results columns_to_keep = [ - name - for name in samples.column_names - if name != HashKeys.uid + name for name in samples.column_names if name != HashKeys.uid ] return samples.select(columns_to_keep).filter(mask) @@ -574,25 +547,18 @@ def run(self, dataset): def minhash_with_uid(table: pa.Table) -> pa.Table: num_rows = len(table) - min_id, max_id = ray.get( - 
id_generator.get_next_id.remote(num_rows) - ) + min_id, max_id = ray.get(id_generator.get_next_id.remote(num_rows)) uid_list = range(min_id, max_id) self.calc_minhash(table[self.text_key], uid_list) - new_table = table.append_column( - HashKeys.uid, - pa.array(list(uid_list)) - ) + new_table = table.append_column(HashKeys.uid, + pa.array(list(uid_list))) return new_table dataset.map_batches( minhash_with_uid, batch_format='pyarrow', zero_copy_batch=True, - ).write_parquet( - self.tmp_file_name, - force_ascii=False - ) + ).write_parquet(self.tmp_file_name, force_ascii=False) dataset = ray.data.read_parquet(self.tmp_file_name) end_time = time.time() print(f'MinHash time = {end_time - start_time}') From 4e3b59852682d92f3f13984a330f57014d83fd71 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 24 Dec 2024 09:12:19 +0000 Subject: [PATCH 22/22] fix requires --- environments/dist_requires.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt index b6ab28d06..90f28ef7a 100644 --- a/environments/dist_requires.txt +++ b/environments/dist_requires.txt @@ -1,2 +1,2 @@ -ray<=2.38.0 +ray>=2.31.0 redis>=5.0.0
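
For readers following the diffs, the core machinery introduced by this series can be reproduced in a few standalone sketches. First, the signature math used by RayBTSMinhashDeduplicator.calc_minhash. This is a minimal single-process version: sha1_hash32, MERSENNE_PRIME and MAX_HASH are re-declared locally so it runs without data-juicer (the op imports them from document_minhash_deduplicator), and the byte-string shingles plus the seed-42 permutations are made-up example inputs, not values the op itself would produce.

import hashlib
import struct

import numpy as np

MERSENNE_PRIME = np.uint64((1 << 61) - 1)
MAX_HASH = np.uint64((1 << 32) - 1)


def sha1_hash32(data):
    # First 4 bytes of SHA-1, read as a little-endian uint32.
    return struct.unpack('<I', hashlib.sha1(data).digest()[:4])[0]


def minhash_signature(tokens, perm_a, perm_b):
    # One permuted hash per (token, permutation) pair, then a column-wise
    # minimum, mirroring the vectorised computation in calc_minhash.
    if not tokens:
        # Samples with no tokens fall back to an all-MAX_HASH signature.
        return np.full_like(perm_a, MAX_HASH, dtype=np.uint32)
    hv = np.array([sha1_hash32(t) for t in tokens], dtype=np.uint64)
    phv = ((hv[:, None] * perm_a[None, :] + perm_b)
           % MERSENNE_PRIME).astype(np.uint32)
    return phv.min(axis=0)


num_permutations = 256
gen = np.random.RandomState(seed=42)
perm_a = gen.randint(1, int(MERSENNE_PRIME), size=num_permutations,
                     dtype=np.uint64)
perm_b = gen.randint(0, int(MERSENNE_PRIME), size=num_permutations,
                     dtype=np.uint64)

sig1 = minhash_signature([b'a rose is', b'rose is a', b'is a rose'],
                         perm_a, perm_b)
sig2 = minhash_signature([b'a rose is', b'rose is a', b'is a red'],
                         perm_a, perm_b)
# The fraction of matching slots estimates the Jaccard similarity.
print((sig1 == sig2).mean())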
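
The banding step that turns a 256-slot signature into LSH keys can be sketched the same way. num_bands and num_rows_per_band below are illustrative values (the real ones come from the user config or the optimal_param search), and candidate_groups stands in for the bucket building the op performs before shipping each bucket to a BTSUnionFind actor chosen, roughly, by hashing the key modulo union_find_parallel_num.

from collections import defaultdict

import numpy as np

num_bands, num_rows_per_band = 32, 8  # illustrative split of 256 permutations
hash_ranges = [(i * num_rows_per_band, (i + 1) * num_rows_per_band)
               for i in range(num_bands)]


def band_keys(signature):
    # One bytes key per band, prefixed with the band index so keys from
    # different bands never collide; mirrors
    # i.to_bytes(4, 'big') + hash_values[start:end].tobytes() in the op.
    for i, (start, end) in enumerate(hash_ranges):
        yield i.to_bytes(4, 'big') + signature[start:end].tobytes()


def candidate_groups(signatures):
    # Documents sharing any full band land in the same bucket; each bucket
    # later becomes a chain of union() calls on the matching actor.
    buckets = defaultdict(list)
    for doc_id, sig in signatures.items():
        for key in band_keys(sig):
            buckets[key].append(doc_id)
    return [ids for ids in buckets.values() if len(ids) > 1]


sigs = {
    0: np.zeros(256, dtype=np.uint32),
    1: np.zeros(256, dtype=np.uint32),   # identical: shares all 32 bands
    2: np.arange(256, dtype=np.uint32),  # different: shares none
}
print(candidate_groups(sigs))  # 32 buckets, one per shared band, each [0, 1]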
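
BTSUnionFind shards a plain union-find across several Ray actors (uids are routed by uid // BATCH_SIZE % parallel_num and merged through the EdgeBuffer actors), but the per-actor bookkeeping is the classic structure below. The union() body is not visible in these hunks, so treating the smaller uid as the root is an assumption of this sketch; it is what makes exactly one sample per duplicate cluster survive the dup_idx filter.

class UnionFind:

    def __init__(self):
        self.parent = {}  # uid -> parent uid; cluster roots have no entry

    def find(self, x):
        root = x
        while self.parent.get(root, root) != root:
            root = self.parent[root]
        # Path compression: point every node on the walk at the root.
        while self.parent.get(x, x) != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        if root_x > root_y:  # assumed: the smaller uid becomes the root
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x

    def is_duplicate(self, x):
        # Mirrors dup_idx(): any uid recorded in parent is dropped, so only
        # the root of each cluster is kept after deduplication.
        return x in self.parent


uf = UnionFind()
uf.union(3, 7)
uf.union(7, 11)
print(uf.is_duplicate(3), uf.is_duplicate(7), uf.is_duplicate(11))
# False True True: uid 3 is the cluster representative that survives.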
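
The contiguous-uid assignment relies on a single IdGenerator actor so that concurrent minhash_with_uid batches never hand out overlapping ids. A runnable toy version follows; the increment inside get_next_id is not shown in the hunk, so the two-line bookkeeping here is the assumed implementation.

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class IdGenerator:

    def __init__(self, start_id=0):
        self.next_id = start_id

    @ray.method(num_returns=2)
    def get_next_id(self, count):
        # Assumed bookkeeping: reserve `count` ids, return both range ends.
        current_id = self.next_id
        self.next_id += count
        return current_id, self.next_id


id_generator = IdGenerator.remote()
min_id, max_id = ray.get(id_generator.get_next_id.remote(3))
print(list(range(min_id, max_id)))  # [0, 1, 2] for the first 3-row batch
min_id, max_id = ray.get(id_generator.get_next_id.remote(2))
print(list(range(min_id, max_id)))  # [3, 4] for the next 2-row batch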
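
Finally, the Fields.stats handling that moved from preprocess_dataset into the Filter branch boils down to appending an empty struct column per row with PyArrow. The snippet assumes '__dj__stats__' as the constant's value purely for display; only the append_column pattern is taken from the diffs.

import pyarrow as pa

STATS_KEY = '__dj__stats__'  # assumed stand-in for Fields.stats


def process_batch_arrow(table: pa.Table) -> pa.Table:
    # One empty struct per row; Filter.compute_stats fills them in later.
    new_column_data = [{} for _ in range(len(table))]
    return table.append_column(STATS_KEY, [new_column_data])


batch = pa.table({'text': ['doc a', 'doc b', 'doc c']})
print(process_batch_arrow(batch).column_names)  # ['text', '__dj__stats__']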