From feb2836c47a222d09f9af139a242b15a1840c220 Mon Sep 17 00:00:00 2001 From: Ceceliachenen Date: Mon, 9 Sep 2024 20:33:21 +0800 Subject: [PATCH] add multi headings (#207) * add multi headings * add multi headings * add multi headings * add multi headings --- .../integrations/readers/pai_pdf_reader.py | 73 +++++++++++++++++++ .../test_post_process_multi_level_headings.py | 54 ++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 tests/utils/test_post_process_multi_level_headings.py diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py index 1671136d..9436bce9 100644 --- a/src/pai_rag/integrations/readers/pai_pdf_reader.py +++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py @@ -19,6 +19,7 @@ from PIL import Image from rapidocr_onnxruntime import RapidOCR from rapid_table import RapidTable +from operator import itemgetter import logging import os @@ -37,6 +38,7 @@ PAGE_TABLE_SUMMARY_MAX_TOKEN = 400 IMAGE_URL_PATTERN = r"(https?://[^\s]+?[\s\w.-]*\.(jpg|jpeg|png|gif|bmp))" IMAGE_COMBINED_PATTERN = r"!\[.*?\]\((https?://[^\s()]+|/[^()\s]+(?:\s[^()\s]*)?/\S*?\.(jpg|jpeg|png|gif|bmp))\)" +DEFAULT_HEADING_DIFF_THRESHOLD = 2 class PaiPDFReader(BaseReader): @@ -277,6 +279,74 @@ def process_table(self, markdown_content, json_data): print(f"警告:图片文件不存在 {img_path}") return markdown_content + def post_process_multi_level_headings(self, json_data, md_content): + logger.info( + "*****************************start process headings*****************************" + ) + pages_list = json_data["pdf_info"] + if not pages_list: + return md_content + text_height_min = float("inf") + text_height_max = 0 + title_list = [] + for page in pages_list: + page_infos = page["preproc_blocks"] + for item in page_infos: + if not item.get("lines", None) or len(item["lines"]) <= 0: + continue + x0, y0, x1, y1 = item["lines"][0]["bbox"] + content_height = y1 - y0 + if item["type"] == "title": + title_height = int(content_height) + title_text = "" + for line in item["lines"]: + for span in line["spans"]: + if span["type"] == "inline_equation": + span["content"] = " $" + span["content"] + "$ " + title_text += span["content"] + title_text = title_text.replace("\\", "\\\\") + title_list.append((title_text, title_height)) + elif item["type"] == "text": + if content_height < text_height_min: + text_height_min = content_height + if content_height > text_height_max: + text_height_max = content_height + + sorted_list = sorted(title_list, key=itemgetter(1), reverse=True) + diff_list = [ + (sorted_list[i][1] - sorted_list[i + 1][1], i) + for i in range(len(sorted_list) - 1) + ] + sorted_diff = sorted(diff_list, key=itemgetter(0), reverse=True) + slice_index = [] + for diff, index in sorted_diff: + # 标题差的绝对值超过2,则认为是下一级标题 + # markdown 中,# 表示一级标题,## 表示二级标题,以此类推,最多有6级标题,最多能有5次切分 + if diff >= DEFAULT_HEADING_DIFF_THRESHOLD and len(slice_index) <= 5: + slice_index.append(index) + slice_index.sort(reverse=True) + rank = 1 + cur_index = 0 + if len(slice_index) > 0: + cur_index = slice_index.pop() + for index, (title_text, title_height) in enumerate(sorted_list): + if index > cur_index: + rank += 1 + if len(slice_index) > 0: + cur_index = slice_index.pop() + else: + cur_index = len(sorted_list) - 1 + title_level = "#" * rank + " " + if text_height_min <= text_height_max and int( + text_height_min + ) <= title_height <= int(text_height_max): + title_level = "" + old_title = "# " + title_text + new_title = title_level + title_text + md_content = re.sub(re.escape(old_title), new_title, md_content) + + return md_content + def parse_pdf( self, pdf_path: str, @@ -334,6 +404,9 @@ def parse_pdf( pipe.pipe_parse() content_list = pipe.pipe_mk_uni_format(temp_file_path, drop_mode="none") md_content = pipe.pipe_mk_markdown(temp_file_path, drop_mode="none") + md_content = self.post_process_multi_level_headings( + pipe.pdf_mid_data, md_content + ) md_content = self.process_table(md_content, content_list) new_md_content = self.replace_image_paths(pdf_name, md_content) diff --git a/tests/utils/test_post_process_multi_level_headings.py b/tests/utils/test_post_process_multi_level_headings.py new file mode 100644 index 00000000..ccdccd80 --- /dev/null +++ b/tests/utils/test_post_process_multi_level_headings.py @@ -0,0 +1,54 @@ +import json +from pai_rag.integrations.readers.pai_pdf_reader import PaiPDFReader + +json_str = r"""{ + "pdf_info": [ + { + "preproc_blocks": [ + { + "type": "title", + "bbox": [ + 59, + 67, + 196, + 89 + ], + "lines": [ + { + "bbox": [ + 57, + 70, + 72, + 85 + ], + "spans": [ + { + "bbox": [ + 57, + 70, + 72, + 85 + ], + "score": 0.75, + "content": "\\leftarrow", + "type": "inline_equation" + } + ] + } + ] + } + ] + } + ] +}""" + +md_str = """# $\leftarrow$ """ + + +def test_post_process_multi_level_headings(): + pdf_process = PaiPDFReader() + json_content = json.loads(json_str) + md_content_escape = pdf_process.post_process_multi_level_headings( + json_content, md_str + ) + assert md_content_escape == "# $\leftarrow$ "